diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-15 11:37:02 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-19 09:21:15 +0000 |
| commit | 1276f83b46b38cc241614ebc4401720f5f1fc4ab (patch) | |
| tree | edf8e8d9c9ac313d9457cebf678aea9334804f05 /pkg/aflow/llm_agent.go | |
| parent | a9fc52269b8aab60248b6e4c5366216bc2191101 (diff) | |
pkg/aflow: add ability to generate several candidate replies for LLM agents
Add LLMAgent.Candidates parameter.
If set to a value N>1, then the agent is invoked N times,
and all outputs become slices.
The results can be later aggregated by another agent,
as shown in the test.
Diffstat (limited to 'pkg/aflow/llm_agent.go')
| -rw-r--r-- | pkg/aflow/llm_agent.go | 96 |
1 file changed, 79 insertions, 17 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index b897643c7..c30143425 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -27,6 +27,9 @@ type LLMAgent struct { // while higher temperatures can lead to more diverse or creative results. // Must be assigned a float32 value in the range [0, 2]. Temperature any + // If set, the agent will generate that many candidates and the outputs will be arrays + // instead of scalars. + Candidates int // Instructions for the agent. // Formatted as text/template, can use "{{.Variable}}" as placeholders for dynamic content. // Variables can come from the workflow inputs, or from preceding actions outputs. @@ -51,25 +54,77 @@ func LLMOutputs[Args any]() *llmOutputs { tool: NewFuncTool("set-results", func(ctx *Context, state struct{}, args Args) (Args, error) { return args, nil }, "Use this tool to provide results of the analysis."), - provideOutputs: func(ctx *verifyContext, who string) { - provideOutputs[Args](ctx, who) + provideOutputs: func(ctx *verifyContext, who string, many bool) { + if many { + provideArrayOutputs[Args](ctx, who) + } else { + provideOutputs[Args](ctx, who) + } + }, + append: func(to, from map[string]any) { + for name, typ := range foreachFieldOf[Args]() { + if to[name] == nil { + to[name] = reflect.Zero(reflect.SliceOf(typ)).Interface() + } + to[name] = reflect.Append(reflect.ValueOf(to[name]), reflect.ValueOf(from[name])).Interface() + } }, - instruction: ` + } +} + +const llmOutputsInstruction = ` Use set-results tool to provide results of the analysis. It must be called exactly once before the final reply. Ignore results of this tool. 
-`, - } -} +` type llmOutputs struct { tool Tool - provideOutputs func(*verifyContext, string) - instruction string + provideOutputs func(*verifyContext, string, bool) + append func(map[string]any, map[string]any) } func (a *LLMAgent) execute(ctx *Context) error { + if a.Candidates <= 1 { + reply, outputs, err := a.executeOne(ctx) + if err != nil { + return err + } + ctx.state[a.Reply] = reply + maps.Insert(ctx.state, maps.All(outputs)) + return nil + } + span := &trajectory.Span{ + Type: trajectory.SpanAgentCandidates, + Name: a.Name, + } + if err := ctx.startSpan(span); err != nil { + return err + } + err := a.executeMany(ctx) + return ctx.finishSpan(span, err) +} + +func (a *LLMAgent) executeMany(ctx *Context) error { + var replies []string + allOutputs := map[string]any{} + for candidate := 0; candidate < a.Candidates; candidate++ { + reply, outputs, err := a.executeOne(ctx) + if err != nil { + return err + } + replies = append(replies, reply) + if a.Outputs != nil { + a.Outputs.append(allOutputs, outputs) + } + } + ctx.state[a.Reply] = replies + maps.Insert(ctx.state, maps.All(allOutputs)) + return nil +} + +func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { cfg, instruction, tools := a.config(ctx) span := &trajectory.Span{ Type: trajectory.SpanAgent, @@ -78,12 +133,14 @@ func (a *LLMAgent) execute(ctx *Context) error { Prompt: formatTemplate(a.Prompt, ctx.state), } if err := ctx.startSpan(span); err != nil { - return err + return "", nil, err } reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt) - span.Reply = reply - span.Results = outputs - return ctx.finishSpan(span, err) + if err == nil { + span.Reply = reply + span.Results = outputs + } + return reply, outputs, ctx.finishSpan(span, err) } func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) ( @@ -112,8 +169,6 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma if a.Outputs != nil && 
outputs == nil { return "", nil, fmt.Errorf("LLM did not call tool to set outputs") } - ctx.state[a.Reply] = reply - maps.Insert(ctx.state, maps.All(outputs)) return reply, outputs, nil } // This is not the final reply, LLM asked to execute some tools. @@ -134,7 +189,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m instruction := formatTemplate(a.Instruction, ctx.state) toolList := a.Tools if a.Outputs != nil { - instruction += a.Outputs.instruction + instruction += llmOutputsInstruction toolList = append(toolList, a.Outputs.tool) } toolMap := make(map[string]Tool) @@ -225,6 +280,9 @@ func (a *LLMAgent) verify(vctx *verifyContext) { if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 { vctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]") } + if a.Candidates < 0 || a.Candidates > 100 { + vctx.errorf(a.Name, "Candidates must be in the range [0, 100]") + } // Verify dataflow. All dynamic variables must be provided by inputs, // or preceding actions. a.verifyTemplate(vctx, "Instruction", a.Instruction) @@ -232,9 +290,13 @@ func (a *LLMAgent) verify(vctx *verifyContext) { for _, tool := range a.Tools { tool.verify(vctx) } - vctx.provideOutput(a.Name, a.Reply, reflect.TypeFor[string](), true) + replyType := reflect.TypeFor[string]() + if a.Candidates > 1 { + replyType = reflect.TypeFor[[]string]() + } + vctx.provideOutput(a.Name, a.Reply, replyType, true) if a.Outputs != nil { - a.Outputs.provideOutputs(vctx, a.Name) + a.Outputs.provideOutputs(vctx, a.Name, a.Candidates > 1) } } |
