diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-15 11:37:02 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-19 09:21:15 +0000 |
| commit | 1276f83b46b38cc241614ebc4401720f5f1fc4ab (patch) | |
| tree | edf8e8d9c9ac313d9457cebf678aea9334804f05 /pkg/aflow/flow_test.go | |
| parent | a9fc52269b8aab60248b6e4c5366216bc2191101 (diff) | |
pkg/aflow: add ability to generate several candidate replies for LLM agents
Add LLMAgent.Candidates parameter.
If set to a value N>1, then the agent is invoked N times,
and all outputs become slices.
The results can be later aggregated by another agent,
as shown in the test.
Diffstat (limited to 'pkg/aflow/flow_test.go')
| -rw-r--r-- | pkg/aflow/flow_test.go | 450 |
1 files changed, 393 insertions, 57 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 19abce6b1..5c038d7c6 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -5,6 +5,7 @@ package aflow import ( "context" + "fmt" "path/filepath" "testing" "time" @@ -22,10 +23,13 @@ func TestWorkflow(t *testing.T) { InBaz string } type flowOutputs struct { - OutFoo string - OutBar int - OutBaz string - AgentFoo int + OutFoo string + OutBar int + OutBaz string + AgentFoo int + OutSwarm []string + SwarmInt []int + OutAggregator string } type firstFuncInputs struct { InFoo int @@ -68,11 +72,24 @@ func TestWorkflow(t *testing.T) { type tool2Results struct { ResBaz int `jsonschema:"baz"` } + type swarmOutputs struct { + SwarmInt int `jsonschema:"swarm-int"` + SwarmStr string `jsonschema:"swarm-str"` + } inputs := map[string]any{ "InFoo": 10, "InBar": "bar", "InBaz": "baz", } + expectedOutputs := map[string]any{ + "AgentFoo": 42, + "OutBar": 142, + "OutBaz": "baz", + "OutFoo": "hello, world!", + "OutSwarm": []string{"swarm candidate 1", "swarm candidate 2"}, + "SwarmInt": []int{1, 2}, + "OutAggregator": "aggregated", + } flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { @@ -124,6 +141,29 @@ func TestWorkflow(t *testing.T) { OutBaz: "baz", }, nil }), + &LLMAgent{ + Name: "swarm", + Reply: "OutSwarm", + Candidates: 2, + Outputs: LLMOutputs[swarmOutputs](), + Temperature: 0, + Instruction: "Do something. {{.InBaz}}", + Prompt: "Prompt: {{.InBaz}}", + }, + &LLMAgent{ + Name: "aggregator", + Reply: "OutAggregator", + Temperature: 0, + Instruction: "Aggregate!", + Prompt: `Prompt: {{.InBaz}} +{{range $i, $v := .OutSwarm}}#{{$i}}: {{$v}} +{{end}} +{{range $i, $v := .SwarmInt}}#{{$i}}: {{$v}} +{{end}} +{{range $i, $v := .SwarmStr}}#{{$i}}: {{$v}} +{{end}} +`, + }, ), }, }) @@ -138,19 +178,26 @@ func TestWorkflow(t *testing.T) { }, generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { - assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText(`You are smarty. baz - -Use set-results tool to provide results of the analysis. -It must be called exactly once before the final reply. -Ignore results of this tool. -`, genai.RoleUser)) - assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) - assert.Equal(t, len(cfg.Tools), 3) - assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1") - assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description") - assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2") - assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") - assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") + replySeq++ + if replySeq < 4 { + assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+ + llmOutputsInstruction, genai.RoleUser)) + assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) + assert.Equal(t, len(cfg.Tools), 3) + assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1") + assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description") + assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2") + assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") + assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") + } else if replySeq < 8 { + assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+ + llmOutputsInstruction, genai.RoleUser)) + assert.Equal(t, len(cfg.Tools), 1) + assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results") + } else { + assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser)) + assert.Equal(t, len(cfg.Tools), 0) + } reply1 := &genai.Content{ Role: string(genai.RoleModel), @@ -239,7 +286,42 @@ Ignore results of this tool. }, }} - replySeq++ + // dupl considers makeSwarmReply/makeSwarmResp duplicates + // nolint:dupl + makeSwarmReply := func(index int) *genai.Content { + return &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: fmt.Sprintf("id%v", index), + Name: "set-results", + Args: map[string]any{ + "SwarmInt": index, + "SwarmStr": fmt.Sprintf("swarm%v", index), + }, + }, + }, + }} + } + // nolint:dupl // dupl considers makeSwarmReply/makeSwarmResp duplicates + makeSwarmResp := func(index int) *genai.Content { + return &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: fmt.Sprintf("id%v", index), + Name: "set-results", + Response: map[string]any{ + "SwarmInt": index, + "SwarmStr": fmt.Sprintf("swarm%v", index), + }, + }, + }, + }} + } + switch replySeq { case 1: assert.Equal(t, req, []*genai.Content{ @@ -270,7 +352,48 @@ Ignore results of this tool. Parts: []*genai.Part{ genai.NewPartFromText("hello, world!")}, }}}}, nil + case 4, 6: + index := (replySeq - 2) / 2 + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText("Prompt: baz", genai.RoleUser), + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{Content: makeSwarmReply(index)}}}, nil + case 5, 7: + index := (replySeq - 3) / 2 + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText("Prompt: baz", genai.RoleUser), + makeSwarmReply(index), + makeSwarmResp(index), + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + {Content: &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + genai.NewPartFromText(fmt.Sprintf("swarm candidate %v", index))}, + }}}}, nil + case 8: + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText(`Prompt: baz +#0: swarm candidate 1 +#1: swarm candidate 2 + +#0: 1 +#1: 2 +#0: swarm1 +#1: swarm2 + +`, genai.RoleUser), + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + {Content: &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + genai.NewPartFromText("aggregated")}, + }}}}, nil default: t.Fatal("unexpected LLM calls") return nil, nil @@ -310,18 +433,13 @@ Ignore results of this tool. }, }, { - Seq: 2, - Nesting: 1, - Type: trajectory.SpanAgent, - Name: "smarty", - Started: startTime.Add(4 * time.Second), - Instruction: `You are smarty. baz - -Use set-results tool to provide results of the analysis. -It must be called exactly once before the final reply. -Ignore results of this tool. -`, - Prompt: "Prompt: baz func-output", + Seq: 2, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "smarty", + Started: startTime.Add(4 * time.Second), + Instruction: "You are smarty. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz func-output", }, { Seq: 3, @@ -449,20 +567,15 @@ Ignore results of this tool. Finished: startTime.Add(16 * time.Second), }, { - Seq: 2, - Nesting: 1, - Type: trajectory.SpanAgent, - Name: "smarty", - Started: startTime.Add(4 * time.Second), - Finished: startTime.Add(17 * time.Second), - Instruction: `You are smarty. baz - -Use set-results tool to provide results of the analysis. -It must be called exactly once before the final reply. -Ignore results of this tool. -`, - Prompt: "Prompt: baz func-output", - Reply: "hello, world!", + Seq: 2, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "smarty", + Started: startTime.Add(4 * time.Second), + Finished: startTime.Add(17 * time.Second), + Instruction: "You are smarty. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz func-output", + Reply: "hello, world!", Results: map[string]any{ "AgentBar": "agent-bar", "AgentFoo": 42, @@ -487,18 +600,245 @@ Ignore results of this tool. }, }, { + Seq: 10, + Nesting: 1, + Type: trajectory.SpanAgentCandidates, + Name: "swarm", + Started: startTime.Add(20 * time.Second), + }, + { + Seq: 11, + Nesting: 2, + Type: trajectory.SpanAgent, + Name: "swarm", + Started: startTime.Add(21 * time.Second), + Instruction: "Do something. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz", + }, + { + Seq: 12, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(22 * time.Second), + }, + { + Seq: 12, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(22 * time.Second), + Finished: startTime.Add(23 * time.Second), + }, + { + Seq: 13, + Nesting: 3, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(24 * time.Second), + Args: map[string]any{ + "SwarmInt": 1, + "SwarmStr": "swarm1", + }, + }, + { + Seq: 13, + Nesting: 3, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(24 * time.Second), + Finished: startTime.Add(25 * time.Second), + Args: map[string]any{ + "SwarmInt": 1, + "SwarmStr": "swarm1", + }, + Results: map[string]any{ + "SwarmInt": 1, + "SwarmStr": "swarm1", + }, + }, + { + Seq: 14, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(26 * time.Second), + }, + { + Seq: 14, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(26 * time.Second), + Finished: startTime.Add(27 * time.Second), + }, + { + Seq: 11, + Nesting: 2, + Type: trajectory.SpanAgent, + Name: "swarm", + Started: startTime.Add(21 * time.Second), + Finished: startTime.Add(28 * time.Second), + Instruction: "Do something. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz", + Reply: "swarm candidate 1", + Results: map[string]any{ + "SwarmInt": 1, + "SwarmStr": "swarm1", + }, + }, + { + Seq: 15, + Nesting: 2, + Type: trajectory.SpanAgent, + Name: "swarm", + Started: startTime.Add(29 * time.Second), + Instruction: "Do something. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz", + }, + { + Seq: 16, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(30 * time.Second), + }, + { + Seq: 16, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(30 * time.Second), + Finished: startTime.Add(31 * time.Second), + }, + { + Seq: 17, + Nesting: 3, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(32 * time.Second), + Args: map[string]any{ + "SwarmInt": 2, + "SwarmStr": "swarm2", + }, + }, + { + Seq: 17, + Nesting: 3, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(32 * time.Second), + Finished: startTime.Add(33 * time.Second), + Args: map[string]any{ + "SwarmInt": 2, + "SwarmStr": "swarm2", + }, + Results: map[string]any{ + "SwarmInt": 2, + "SwarmStr": "swarm2", + }, + }, + { + Seq: 18, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(34 * time.Second), + }, + { + Seq: 18, + Nesting: 3, + Type: trajectory.SpanLLM, + Name: "swarm", + Started: startTime.Add(34 * time.Second), + Finished: startTime.Add(35 * time.Second), + }, + { + Seq: 15, + Nesting: 2, + Type: trajectory.SpanAgent, + Name: "swarm", + Started: startTime.Add(29 * time.Second), + Finished: startTime.Add(36 * time.Second), + Instruction: "Do something. baz" + llmOutputsInstruction, + Prompt: "Prompt: baz", + Reply: "swarm candidate 2", + Results: map[string]any{ + "SwarmInt": 2, + "SwarmStr": "swarm2", + }, + }, + { + Seq: 10, + Nesting: 1, + Type: trajectory.SpanAgentCandidates, + Name: "swarm", + Started: startTime.Add(20 * time.Second), + Finished: startTime.Add(37 * time.Second), + }, + { + Seq: 19, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "aggregator", + Started: startTime.Add(38 * time.Second), + Instruction: "Aggregate!", + Prompt: `Prompt: baz +#0: swarm candidate 1 +#1: swarm candidate 2 + +#0: 1 +#1: 2 + +#0: swarm1 +#1: swarm2 + +`, + }, + { + Seq: 20, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "aggregator", + Started: startTime.Add(39 * time.Second), + }, + { + Seq: 20, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "aggregator", + Started: startTime.Add(39 * time.Second), + Finished: startTime.Add(40 * time.Second), + }, + { + Seq: 19, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "aggregator", + Started: startTime.Add(38 * time.Second), + Finished: startTime.Add(41 * time.Second), + Instruction: "Aggregate!", + Prompt: `Prompt: baz +#0: swarm candidate 1 +#1: swarm candidate 2 + +#0: 1 +#1: 2 + +#0: swarm1 +#1: swarm2 + +`, + Reply: "aggregated", + }, + { Seq: 0, Nesting: 0, Type: trajectory.SpanFlow, Name: "test-flow", Started: startTime.Add(1 * time.Second), - Finished: startTime.Add(20 * time.Second), - Results: map[string]any{ - "AgentFoo": 42, - "OutBar": 142, - "OutBaz": "baz", - "OutFoo": "hello, world!", - }, + Finished: startTime.Add(42 * time.Second), + Results: expectedOutputs, }, } onEvent := func(span *trajectory.Span) error { @@ -509,12 +849,8 @@ Ignore results of this tool. } res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent) require.NoError(t, err) - require.Equal(t, res, map[string]any{ - "OutFoo": "hello, world!", - "OutBar": 142, - "OutBaz": "baz", - "AgentFoo": 42, - }) + require.Equal(t, replySeq, 8) + require.Equal(t, res, expectedOutputs) require.Empty(t, expected) } |
