aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/llm_agent.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/llm_agent.go')
-rw-r--r--pkg/aflow/llm_agent.go96
1 files changed, 79 insertions, 17 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index b897643c7..c30143425 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -27,6 +27,9 @@ type LLMAgent struct {
// while higher temperatures can lead to more diverse or creative results.
// Must be assigned a float32 value in the range [0, 2].
Temperature any
+ // If set, the agent will generate that many candidates and the outputs will be arrays
+ // instead of scalars.
+ Candidates int
// Instructions for the agent.
// Formatted as text/template, can use "{{.Variable}}" as placeholders for dynamic content.
// Variables can come from the workflow inputs, or from preceding actions outputs.
@@ -51,25 +54,77 @@ func LLMOutputs[Args any]() *llmOutputs {
tool: NewFuncTool("set-results", func(ctx *Context, state struct{}, args Args) (Args, error) {
return args, nil
}, "Use this tool to provide results of the analysis."),
- provideOutputs: func(ctx *verifyContext, who string) {
- provideOutputs[Args](ctx, who)
+ provideOutputs: func(ctx *verifyContext, who string, many bool) {
+ if many {
+ provideArrayOutputs[Args](ctx, who)
+ } else {
+ provideOutputs[Args](ctx, who)
+ }
+ },
+ append: func(to, from map[string]any) {
+ for name, typ := range foreachFieldOf[Args]() {
+ if to[name] == nil {
+ to[name] = reflect.Zero(reflect.SliceOf(typ)).Interface()
+ }
+ to[name] = reflect.Append(reflect.ValueOf(to[name]), reflect.ValueOf(from[name])).Interface()
+ }
},
- instruction: `
+ }
+}
+
+const llmOutputsInstruction = `
Use set-results tool to provide results of the analysis.
It must be called exactly once before the final reply.
Ignore results of this tool.
-`,
- }
-}
+`
type llmOutputs struct {
tool Tool
- provideOutputs func(*verifyContext, string)
- instruction string
+ provideOutputs func(*verifyContext, string, bool)
+ append func(map[string]any, map[string]any)
}
func (a *LLMAgent) execute(ctx *Context) error {
+ if a.Candidates <= 1 {
+ reply, outputs, err := a.executeOne(ctx)
+ if err != nil {
+ return err
+ }
+ ctx.state[a.Reply] = reply
+ maps.Insert(ctx.state, maps.All(outputs))
+ return nil
+ }
+ span := &trajectory.Span{
+ Type: trajectory.SpanAgentCandidates,
+ Name: a.Name,
+ }
+ if err := ctx.startSpan(span); err != nil {
+ return err
+ }
+ err := a.executeMany(ctx)
+ return ctx.finishSpan(span, err)
+}
+
+func (a *LLMAgent) executeMany(ctx *Context) error {
+ var replies []string
+ allOutputs := map[string]any{}
+ for candidate := 0; candidate < a.Candidates; candidate++ {
+ reply, outputs, err := a.executeOne(ctx)
+ if err != nil {
+ return err
+ }
+ replies = append(replies, reply)
+ if a.Outputs != nil {
+ a.Outputs.append(allOutputs, outputs)
+ }
+ }
+ ctx.state[a.Reply] = replies
+ maps.Insert(ctx.state, maps.All(allOutputs))
+ return nil
+}
+
+func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
cfg, instruction, tools := a.config(ctx)
span := &trajectory.Span{
Type: trajectory.SpanAgent,
@@ -78,12 +133,14 @@ func (a *LLMAgent) execute(ctx *Context) error {
Prompt: formatTemplate(a.Prompt, ctx.state),
}
if err := ctx.startSpan(span); err != nil {
- return err
+ return "", nil, err
}
reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt)
- span.Reply = reply
- span.Results = outputs
- return ctx.finishSpan(span, err)
+ if err == nil {
+ span.Reply = reply
+ span.Results = outputs
+ }
+ return reply, outputs, ctx.finishSpan(span, err)
}
func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) (
@@ -112,8 +169,6 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if a.Outputs != nil && outputs == nil {
return "", nil, fmt.Errorf("LLM did not call tool to set outputs")
}
- ctx.state[a.Reply] = reply
- maps.Insert(ctx.state, maps.All(outputs))
return reply, outputs, nil
}
// This is not the final reply, LLM asked to execute some tools.
@@ -134,7 +189,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
instruction := formatTemplate(a.Instruction, ctx.state)
toolList := a.Tools
if a.Outputs != nil {
- instruction += a.Outputs.instruction
+ instruction += llmOutputsInstruction
toolList = append(toolList, a.Outputs.tool)
}
toolMap := make(map[string]Tool)
@@ -225,6 +280,9 @@ func (a *LLMAgent) verify(vctx *verifyContext) {
if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
vctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
}
+ if a.Candidates < 0 || a.Candidates > 100 {
+ vctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
+ }
// Verify dataflow. All dynamic variables must be provided by inputs,
// or preceding actions.
a.verifyTemplate(vctx, "Instruction", a.Instruction)
@@ -232,9 +290,13 @@ func (a *LLMAgent) verify(vctx *verifyContext) {
for _, tool := range a.Tools {
tool.verify(vctx)
}
- vctx.provideOutput(a.Name, a.Reply, reflect.TypeFor[string](), true)
+ replyType := reflect.TypeFor[string]()
+ if a.Candidates > 1 {
+ replyType = reflect.TypeFor[[]string]()
+ }
+ vctx.provideOutput(a.Name, a.Reply, replyType, true)
if a.Outputs != nil {
- a.Outputs.provideOutputs(vctx, a.Name)
+ a.Outputs.provideOutputs(vctx, a.Name, a.Candidates > 1)
}
}