aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/llm_agent.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/llm_agent.go')
-rw-r--r--pkg/aflow/llm_agent.go45
1 files changed, 35 insertions, 10 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index d5e4d6d4d..b473c9e7a 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -93,6 +93,10 @@ It must be called exactly once before the final reply.
Ignore results of this tool.
`
+const llmMissingOutputs = `You did not call set-results tool.
+Please call set-results tool to provide results of the analysis.
+`
+
type llmOutputs struct {
tool Tool
provideOutputs func(*verifyContext, string, bool)
@@ -180,12 +184,15 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err := ctx.finishSpan(reqSpan, respErr); err != nil {
return "", nil, err
}
+ req = append(req, resp.Candidates[0].Content)
if len(calls) == 0 {
// This is the final reply.
- if a.Outputs != nil && outputs == nil {
- return "", nil, fmt.Errorf("LLM did not call tool to set outputs")
+ if a.Outputs == nil || outputs != nil {
+ return reply, outputs, nil
}
- return reply, outputs, nil
+ // LLM did not call set-results.
+ req = append(req, genai.NewContentFromText(llmMissingOutputs, genai.RoleUser))
+ continue
}
// This is not the final reply, LLM asked to execute some tools.
// Append the current reply, and tool responses to the next request.
@@ -193,11 +200,10 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err != nil {
return "", nil, err
}
- if outputs != nil && outputs1 != nil {
- return "", nil, fmt.Errorf("LLM called outputs tool twice")
- }
+ // Overwrite previous outputs, if LLM calls the tool more than once.
+ // It shouldn't, but this seems to be the easiest way to handle it gracefully.
outputs = outputs1
- req = append(req, resp.Candidates[0].Content, responses)
+ req = append(req, responses)
}
}
@@ -231,16 +237,35 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
}
var outputs map[string]any
for _, call := range calls {
+ appendPart := func(results map[string]any) {
+ responses.Parts = append(responses.Parts, &genai.Part{
+ FunctionResponse: &genai.FunctionResponse{
+ ID: call.ID,
+ Name: call.Name,
+ Response: results,
+ },
+ })
+ }
tool := tools[call.Name]
if tool == nil {
- return nil, nil, fmt.Errorf("no tool %q", call.Name)
+ appendPart(map[string]any{
+ "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name),
+ })
+ continue
}
results, err := tool.execute(ctx, call.Args)
if err != nil {
+ if argsErr := new(toolArgsError); errors.As(err, &argsErr) {
+ // LLM provided wrong arguments to the tool,
+ // return the error back to the LLM instead of failing.
+ appendPart(map[string]any{
+ "error": err.Error(),
+ })
+ continue
+ }
return nil, nil, err
}
- responses.Parts = append(responses.Parts, genai.NewPartFromFunctionResponse(call.Name, results))
- responses.Parts[len(responses.Parts)-1].FunctionResponse.ID = call.ID
+ appendPart(results)
if a.Outputs != nil && tool == a.Outputs.tool {
outputs = results
}