aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/flow_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/flow_test.go')
-rw-r--r--pkg/aflow/flow_test.go42
1 files changed, 35 insertions, 7 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 5c038d7c6..97e2dd93a 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Name: "flow",
- Model: "model",
+ Name: "flow",
Root: NewPipeline(
NewFuncAction("func-action",
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "smarty",
+ Model: "model1",
Reply: "OutFoo",
Outputs: LLMOutputs[agentOutputs](),
Temperature: 0,
@@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "swarm",
+ Model: "model2",
Reply: "OutSwarm",
Candidates: 2,
Outputs: LLMOutputs[swarmOutputs](),
@@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) {
},
&LLMAgent{
Name: "aggregator",
+ Model: "model3",
Reply: "OutAggregator",
Temperature: 0,
Instruction: "Aggregate!",
@@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) {
stubTime = stubTime.Add(time.Second)
return stubTime
},
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
replySeq++
if replySeq < 4 {
+ assert.Equal(t, model, "model1")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
@@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) {
assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
} else if replySeq < 8 {
+ assert.Equal(t, model, "model2")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 1)
assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results")
} else {
+ assert.Equal(t, model, "model3")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 0)
}
@@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz func-output",
@@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
},
{
@@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
Finished: startTime.Add(6 * time.Second),
Thoughts: "I am thinking I need to call some tools",
@@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
},
{
@@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
Finished: startTime.Add(12 * time.Second),
Thoughts: "Completly blank.Whatever.",
@@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
},
{
@@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
Finished: startTime.Add(16 * time.Second),
},
@@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Finished: startTime.Add(17 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
@@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
},
{
@@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
Finished: startTime.Add(23 * time.Second),
},
@@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
},
{
@@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
Finished: startTime.Add(27 * time.Second),
},
@@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Finished: startTime.Add(28 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
},
{
@@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
Finished: startTime.Add(31 * time.Second),
},
@@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
},
{
@@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
Finished: startTime.Add(35 * time.Second),
},
@@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Finished: startTime.Add(36 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Instruction: "Aggregate!",
Prompt: `Prompt: baz
@@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
},
{
@@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
Finished: startTime.Add(40 * time.Second),
},
@@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Finished: startTime.Add(41 * time.Second),
Instruction: "Aggregate!",
@@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) {
expected = expected[1:]
return nil
}
- res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.NoError(t, err)
require.Equal(t, replySeq, 8)
require.Equal(t, res, expectedOutputs)
@@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Model: "model",
Root: NewFuncAction("func-action",
func(ctx *Context, args flowInputs) (flowOutputs, error) {
return flowOutputs{}, nil
@@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) {
})
require.NoError(t, err)
stub := &stubContext{
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
return nil, nil
},
@@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) {
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
require.NoError(t, err)
onEvent := func(span *trajectory.Span) error { return nil }
- _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.Equal(t, err.Error(), "flow inputs are missing:"+
" field InBar is not present when converting map to aflow.flowInputs")
}