aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/flow_test.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-15 11:37:02 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-19 09:21:15 +0000
commit1276f83b46b38cc241614ebc4401720f5f1fc4ab (patch)
treeedf8e8d9c9ac313d9457cebf678aea9334804f05 /pkg/aflow/flow_test.go
parenta9fc52269b8aab60248b6e4c5366216bc2191101 (diff)
pkg/aflow: add ability to generate several candidate replies for LLM agents
Add LLMAgent.Candidates parameter. If set to a value N>1, then the agent is invoked N times, and all outputs become slices. The results can be later aggregated by another agent, as shown in the test.
Diffstat (limited to 'pkg/aflow/flow_test.go')
-rw-r--r--pkg/aflow/flow_test.go450
1 files changed, 393 insertions, 57 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 19abce6b1..5c038d7c6 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -5,6 +5,7 @@ package aflow
import (
"context"
+ "fmt"
"path/filepath"
"testing"
"time"
@@ -22,10 +23,13 @@ func TestWorkflow(t *testing.T) {
InBaz string
}
type flowOutputs struct {
- OutFoo string
- OutBar int
- OutBaz string
- AgentFoo int
+ OutFoo string
+ OutBar int
+ OutBaz string
+ AgentFoo int
+ OutSwarm []string
+ SwarmInt []int
+ OutAggregator string
}
type firstFuncInputs struct {
InFoo int
@@ -68,11 +72,24 @@ func TestWorkflow(t *testing.T) {
type tool2Results struct {
ResBaz int `jsonschema:"baz"`
}
+ type swarmOutputs struct {
+ SwarmInt int `jsonschema:"swarm-int"`
+ SwarmStr string `jsonschema:"swarm-str"`
+ }
inputs := map[string]any{
"InFoo": 10,
"InBar": "bar",
"InBaz": "baz",
}
+ expectedOutputs := map[string]any{
+ "AgentFoo": 42,
+ "OutBar": 142,
+ "OutBaz": "baz",
+ "OutFoo": "hello, world!",
+ "OutSwarm": []string{"swarm candidate 1", "swarm candidate 2"},
+ "SwarmInt": []int{1, 2},
+ "OutAggregator": "aggregated",
+ }
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
@@ -124,6 +141,29 @@ func TestWorkflow(t *testing.T) {
OutBaz: "baz",
}, nil
}),
+ &LLMAgent{
+ Name: "swarm",
+ Reply: "OutSwarm",
+ Candidates: 2,
+ Outputs: LLMOutputs[swarmOutputs](),
+ Temperature: 0,
+ Instruction: "Do something. {{.InBaz}}",
+ Prompt: "Prompt: {{.InBaz}}",
+ },
+ &LLMAgent{
+ Name: "aggregator",
+ Reply: "OutAggregator",
+ Temperature: 0,
+ Instruction: "Aggregate!",
+ Prompt: `Prompt: {{.InBaz}}
+{{range $i, $v := .OutSwarm}}#{{$i}}: {{$v}}
+{{end}}
+{{range $i, $v := .SwarmInt}}#{{$i}}: {{$v}}
+{{end}}
+{{range $i, $v := .SwarmStr}}#{{$i}}: {{$v}}
+{{end}}
+`,
+ },
),
},
})
@@ -138,19 +178,26 @@ func TestWorkflow(t *testing.T) {
},
generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
- assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText(`You are smarty. baz
-
-Use set-results tool to provide results of the analysis.
-It must be called exactly once before the final reply.
-Ignore results of this tool.
-`, genai.RoleUser))
- assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
- assert.Equal(t, len(cfg.Tools), 3)
- assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1")
- assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description")
- assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2")
- assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
- assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
+ replySeq++
+ if replySeq < 4 {
+ assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+
+ llmOutputsInstruction, genai.RoleUser))
+ assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
+ assert.Equal(t, len(cfg.Tools), 3)
+ assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1")
+ assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description")
+ assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2")
+ assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
+ assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
+ } else if replySeq < 8 {
+ assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+
+ llmOutputsInstruction, genai.RoleUser))
+ assert.Equal(t, len(cfg.Tools), 1)
+ assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results")
+ } else {
+ assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser))
+ assert.Equal(t, len(cfg.Tools), 0)
+ }
reply1 := &genai.Content{
Role: string(genai.RoleModel),
@@ -239,7 +286,42 @@ Ignore results of this tool.
},
}}
- replySeq++
+ // dupl considers makeSwarmReply/makeSwarmResp duplicates
+ // nolint:dupl
+ makeSwarmReply := func(index int) *genai.Content {
+ return &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: fmt.Sprintf("id%v", index),
+ Name: "set-results",
+ Args: map[string]any{
+ "SwarmInt": index,
+ "SwarmStr": fmt.Sprintf("swarm%v", index),
+ },
+ },
+ },
+ }}
+ }
+ // nolint:dupl // dupl considers makeSwarmReply/makeSwarmResp duplicates
+ makeSwarmResp := func(index int) *genai.Content {
+ return &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: fmt.Sprintf("id%v", index),
+ Name: "set-results",
+ Response: map[string]any{
+ "SwarmInt": index,
+ "SwarmStr": fmt.Sprintf("swarm%v", index),
+ },
+ },
+ },
+ }}
+ }
+
switch replySeq {
case 1:
assert.Equal(t, req, []*genai.Content{
@@ -270,7 +352,48 @@ Ignore results of this tool.
Parts: []*genai.Part{
genai.NewPartFromText("hello, world!")},
}}}}, nil
+ case 4, 6:
+ index := (replySeq - 2) / 2
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText("Prompt: baz", genai.RoleUser),
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: makeSwarmReply(index)}}}, nil
+ case 5, 7:
+ index := (replySeq - 3) / 2
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText("Prompt: baz", genai.RoleUser),
+ makeSwarmReply(index),
+ makeSwarmResp(index),
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText(fmt.Sprintf("swarm candidate %v", index))},
+ }}}}, nil
+ case 8:
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText(`Prompt: baz
+#0: swarm candidate 1
+#1: swarm candidate 2
+
+#0: 1
+#1: 2
+#0: swarm1
+#1: swarm2
+
+`, genai.RoleUser),
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("aggregated")},
+ }}}}, nil
default:
t.Fatal("unexpected LLM calls")
return nil, nil
@@ -310,18 +433,13 @@ Ignore results of this tool.
},
},
{
- Seq: 2,
- Nesting: 1,
- Type: trajectory.SpanAgent,
- Name: "smarty",
- Started: startTime.Add(4 * time.Second),
- Instruction: `You are smarty. baz
-
-Use set-results tool to provide results of the analysis.
-It must be called exactly once before the final reply.
-Ignore results of this tool.
-`,
- Prompt: "Prompt: baz func-output",
+ Seq: 2,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "smarty",
+ Started: startTime.Add(4 * time.Second),
+ Instruction: "You are smarty. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz func-output",
},
{
Seq: 3,
@@ -449,20 +567,15 @@ Ignore results of this tool.
Finished: startTime.Add(16 * time.Second),
},
{
- Seq: 2,
- Nesting: 1,
- Type: trajectory.SpanAgent,
- Name: "smarty",
- Started: startTime.Add(4 * time.Second),
- Finished: startTime.Add(17 * time.Second),
- Instruction: `You are smarty. baz
-
-Use set-results tool to provide results of the analysis.
-It must be called exactly once before the final reply.
-Ignore results of this tool.
-`,
- Prompt: "Prompt: baz func-output",
- Reply: "hello, world!",
+ Seq: 2,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "smarty",
+ Started: startTime.Add(4 * time.Second),
+ Finished: startTime.Add(17 * time.Second),
+ Instruction: "You are smarty. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz func-output",
+ Reply: "hello, world!",
Results: map[string]any{
"AgentBar": "agent-bar",
"AgentFoo": 42,
@@ -487,18 +600,245 @@ Ignore results of this tool.
},
},
{
+ Seq: 10,
+ Nesting: 1,
+ Type: trajectory.SpanAgentCandidates,
+ Name: "swarm",
+ Started: startTime.Add(20 * time.Second),
+ },
+ {
+ Seq: 11,
+ Nesting: 2,
+ Type: trajectory.SpanAgent,
+ Name: "swarm",
+ Started: startTime.Add(21 * time.Second),
+ Instruction: "Do something. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz",
+ },
+ {
+ Seq: 12,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(22 * time.Second),
+ },
+ {
+ Seq: 12,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(22 * time.Second),
+ Finished: startTime.Add(23 * time.Second),
+ },
+ {
+ Seq: 13,
+ Nesting: 3,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(24 * time.Second),
+ Args: map[string]any{
+ "SwarmInt": 1,
+ "SwarmStr": "swarm1",
+ },
+ },
+ {
+ Seq: 13,
+ Nesting: 3,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(24 * time.Second),
+ Finished: startTime.Add(25 * time.Second),
+ Args: map[string]any{
+ "SwarmInt": 1,
+ "SwarmStr": "swarm1",
+ },
+ Results: map[string]any{
+ "SwarmInt": 1,
+ "SwarmStr": "swarm1",
+ },
+ },
+ {
+ Seq: 14,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(26 * time.Second),
+ },
+ {
+ Seq: 14,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(26 * time.Second),
+ Finished: startTime.Add(27 * time.Second),
+ },
+ {
+ Seq: 11,
+ Nesting: 2,
+ Type: trajectory.SpanAgent,
+ Name: "swarm",
+ Started: startTime.Add(21 * time.Second),
+ Finished: startTime.Add(28 * time.Second),
+ Instruction: "Do something. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz",
+ Reply: "swarm candidate 1",
+ Results: map[string]any{
+ "SwarmInt": 1,
+ "SwarmStr": "swarm1",
+ },
+ },
+ {
+ Seq: 15,
+ Nesting: 2,
+ Type: trajectory.SpanAgent,
+ Name: "swarm",
+ Started: startTime.Add(29 * time.Second),
+ Instruction: "Do something. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz",
+ },
+ {
+ Seq: 16,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(30 * time.Second),
+ },
+ {
+ Seq: 16,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(30 * time.Second),
+ Finished: startTime.Add(31 * time.Second),
+ },
+ {
+ Seq: 17,
+ Nesting: 3,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(32 * time.Second),
+ Args: map[string]any{
+ "SwarmInt": 2,
+ "SwarmStr": "swarm2",
+ },
+ },
+ {
+ Seq: 17,
+ Nesting: 3,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(32 * time.Second),
+ Finished: startTime.Add(33 * time.Second),
+ Args: map[string]any{
+ "SwarmInt": 2,
+ "SwarmStr": "swarm2",
+ },
+ Results: map[string]any{
+ "SwarmInt": 2,
+ "SwarmStr": "swarm2",
+ },
+ },
+ {
+ Seq: 18,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(34 * time.Second),
+ },
+ {
+ Seq: 18,
+ Nesting: 3,
+ Type: trajectory.SpanLLM,
+ Name: "swarm",
+ Started: startTime.Add(34 * time.Second),
+ Finished: startTime.Add(35 * time.Second),
+ },
+ {
+ Seq: 15,
+ Nesting: 2,
+ Type: trajectory.SpanAgent,
+ Name: "swarm",
+ Started: startTime.Add(29 * time.Second),
+ Finished: startTime.Add(36 * time.Second),
+ Instruction: "Do something. baz" + llmOutputsInstruction,
+ Prompt: "Prompt: baz",
+ Reply: "swarm candidate 2",
+ Results: map[string]any{
+ "SwarmInt": 2,
+ "SwarmStr": "swarm2",
+ },
+ },
+ {
+ Seq: 10,
+ Nesting: 1,
+ Type: trajectory.SpanAgentCandidates,
+ Name: "swarm",
+ Started: startTime.Add(20 * time.Second),
+ Finished: startTime.Add(37 * time.Second),
+ },
+ {
+ Seq: 19,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "aggregator",
+ Started: startTime.Add(38 * time.Second),
+ Instruction: "Aggregate!",
+ Prompt: `Prompt: baz
+#0: swarm candidate 1
+#1: swarm candidate 2
+
+#0: 1
+#1: 2
+
+#0: swarm1
+#1: swarm2
+
+`,
+ },
+ {
+ Seq: 20,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "aggregator",
+ Started: startTime.Add(39 * time.Second),
+ },
+ {
+ Seq: 20,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "aggregator",
+ Started: startTime.Add(39 * time.Second),
+ Finished: startTime.Add(40 * time.Second),
+ },
+ {
+ Seq: 19,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "aggregator",
+ Started: startTime.Add(38 * time.Second),
+ Finished: startTime.Add(41 * time.Second),
+ Instruction: "Aggregate!",
+ Prompt: `Prompt: baz
+#0: swarm candidate 1
+#1: swarm candidate 2
+
+#0: 1
+#1: 2
+
+#0: swarm1
+#1: swarm2
+
+`,
+ Reply: "aggregated",
+ },
+ {
Seq: 0,
Nesting: 0,
Type: trajectory.SpanFlow,
Name: "test-flow",
Started: startTime.Add(1 * time.Second),
- Finished: startTime.Add(20 * time.Second),
- Results: map[string]any{
- "AgentFoo": 42,
- "OutBar": 142,
- "OutBaz": "baz",
- "OutFoo": "hello, world!",
- },
+ Finished: startTime.Add(42 * time.Second),
+ Results: expectedOutputs,
},
}
onEvent := func(span *trajectory.Span) error {
@@ -509,12 +849,8 @@ Ignore results of this tool.
}
res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
require.NoError(t, err)
- require.Equal(t, res, map[string]any{
- "OutFoo": "hello, world!",
- "OutBar": 142,
- "OutBaz": "baz",
- "AgentFoo": 42,
- })
+ require.Equal(t, replySeq, 8)
+ require.Equal(t, res, expectedOutputs)
require.Empty(t, expected)
}