diff options
Diffstat (limited to 'pkg/aflow/flow_test.go')
| -rw-r--r-- | pkg/aflow/flow_test.go | 42 |
1 files changed, 35 insertions, 7 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 5c038d7c6..97e2dd93a 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Name: "flow", - Model: "model", + Name: "flow", Root: NewPipeline( NewFuncAction("func-action", func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { @@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "smarty", + Model: "model1", Reply: "OutFoo", Outputs: LLMOutputs[agentOutputs](), Temperature: 0, @@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "swarm", + Model: "model2", Reply: "OutSwarm", Candidates: 2, Outputs: LLMOutputs[swarmOutputs](), @@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) { }, &LLMAgent{ Name: "aggregator", + Model: "model3", Reply: "OutAggregator", Temperature: 0, Instruction: "Aggregate!", @@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) { stubTime = stubTime.Add(time.Second) return stubTime }, - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { replySeq++ if replySeq < 4 { + assert.Equal(t, model, "model1") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) @@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) { assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") } else if replySeq < 8 { + assert.Equal(t, model, "model2") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 1) assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results") } else { + assert.Equal(t, model, "model3") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 0) } @@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Instruction: "You are smarty. baz" + llmOutputsInstruction, Prompt: "Prompt: baz func-output", @@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), }, { @@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), Finished: startTime.Add(6 * time.Second), Thoughts: "I am thinking I need to call some tools", @@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), }, { @@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), Finished: startTime.Add(12 * time.Second), Thoughts: "Completly blank.Whatever.", @@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), }, { @@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), Finished: startTime.Add(16 * time.Second), }, @@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Finished: startTime.Add(17 * time.Second), Instruction: "You are smarty. baz" + llmOutputsInstruction, @@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), }, { @@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), Finished: startTime.Add(23 * time.Second), }, @@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), }, { @@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), Finished: startTime.Add(27 * time.Second), }, @@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Finished: startTime.Add(28 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, @@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), }, { @@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), Finished: startTime.Add(31 * time.Second), }, @@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), }, { @@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), Finished: startTime.Add(35 * time.Second), }, @@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Finished: startTime.Add(36 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, @@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Instruction: "Aggregate!", Prompt: `Prompt: baz @@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), }, { @@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), Finished: startTime.Add(40 * time.Second), }, @@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Finished: startTime.Add(41 * time.Second), Instruction: "Aggregate!", @@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) { expected = expected[1:] return nil } - res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.NoError(t, err) require.Equal(t, replySeq, 8) require.Equal(t, res, expectedOutputs) @@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Model: "model", Root: NewFuncAction("func-action", func(ctx *Context, args flowInputs) (flowOutputs, error) { return flowOutputs{}, nil @@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) { }) require.NoError(t, err) stub := &stubContext{ - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { return nil, nil }, @@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) { cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) require.NoError(t, err) onEvent := func(span *trajectory.Span) error { return nil } - _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.Equal(t, err.Error(), "flow inputs are missing:"+ " field InBar is not present when converting map to aflow.flowInputs") } |
