aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Dmitry Vyukov <dvyukov@google.com> 2026-01-16 20:48:47 +0100
committer: Dmitry Vyukov <dvyukov@google.com> 2026-01-20 21:12:57 +0000
commit: 4dc35ec28780d6a78e8afcf2650d4ada4fcd245c (patch)
tree: 2d230546858e301914fc8f3d92fb83935ba7a796
parent: 91e26ec437abcd42a8255aa88e31b45da059529e (diff)
pkg/aflow: handle common LLM mis-behaviors wrt tool calling
Gracefully handle (reply to LLM with error):
 - incorrect tool name
 - incorrect tool arg type
 - missing tool arg

Silently handle:
 - more than one call to set-results
 - excessive tool args

Fixes #6604
-rw-r--r-- pkg/aflow/flow.go 4
-rw-r--r-- pkg/aflow/flow_test.go 238
-rw-r--r-- pkg/aflow/func_action.go 2
-rw-r--r-- pkg/aflow/func_tool.go 10
-rw-r--r-- pkg/aflow/llm_agent.go 45
-rw-r--r-- pkg/aflow/schema.go 22
6 files changed, 302 insertions, 19 deletions
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
index a391f5a01..d1bbe4c69 100644
--- a/pkg/aflow/flow.go
+++ b/pkg/aflow/flow.go
@@ -57,12 +57,12 @@ func register[Inputs, Outputs any](typ ai.WorkflowType, description string,
Type: typ,
Description: description,
checkInputs: func(inputs map[string]any) error {
- _, err := convertFromMap[Inputs](inputs, false)
+ _, err := convertFromMap[Inputs](inputs, false, false)
return err
},
extractOutputs: func(state map[string]any) map[string]any {
// Ensure that we actually have all outputs.
- tmp, err := convertFromMap[Outputs](state, false)
+ tmp, err := convertFromMap[Outputs](state, false, false)
if err != nil {
panic(err)
}
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 3ee3f545b..795b52d6e 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -916,7 +916,7 @@ func TestNoInputs(t *testing.T) {
onEvent := func(span *trajectory.Span) error { return nil }
_, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.Equal(t, err.Error(), "flow inputs are missing:"+
- " field InBar is not present when converting map to aflow.flowInputs")
+ " field \"InBar\" is not present when converting map to aflow.flowInputs")
}
func TestQuotaResetTime(t *testing.T) {
@@ -940,3 +940,239 @@ func TestQuotaResetTime(t *testing.T) {
assert.Equal(t, test.reset, got, "when: %v", test.when)
}
}
+
+func TestToolMisbehavior(t *testing.T) {
+ type flowOutputs struct {
+ Reply string
+ AdditionalOutput int
+ }
+ type tool1Args struct {
+ Tool1Arg string `jsonschema:"arg"`
+ }
+ type tool2Args struct {
+ Tool2Arg int `jsonschema:"arg"`
+ }
+ type tool2Results struct {
+ Result int `jsonschema:"arg"`
+ }
+ flows := make(map[string]*Flow)
+ err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Name: "flow",
+ Root: NewPipeline(
+ &LLMAgent{
+ Name: "smarty",
+ Model: "model1",
+ Temperature: 1,
+ Reply: "Reply",
+
+ Outputs: LLMOutputs[struct {
+ AdditionalOutput int `jsonschema:"arg"`
+ }](),
+ Instruction: "Do something!",
+ Prompt: "Data",
+ Tools: []Tool{
+ NewFuncTool("tool1", func(ctx *Context, state struct{}, args tool1Args) (struct{}, error) {
+ return struct{}{}, nil
+ }, "tool description"),
+ NewFuncTool("tool2", func(ctx *Context, state struct{}, args tool2Args) (tool2Results, error) {
+ return tool2Results{42}, nil
+ }, "tool description"),
+ },
+ },
+ ),
+ },
+ })
+ require.NoError(t, err)
+ replySeq := 0
+ stub := &stubContext{
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ replySeq++
+ switch replySeq {
+ case 1:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ // This tool call is OK, and the tool must be called.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id1",
+ Name: "tool1",
+ Args: map[string]any{
+ "Tool1Arg": "string",
+ },
+ },
+ },
+ // Incorrect argument type.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id2",
+ Name: "tool2",
+ Args: map[string]any{
+ "Tool2Arg": "string-instead-of-int",
+ },
+ },
+ },
+ // Missing argument.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id3",
+ Name: "tool2",
+ Args: map[string]any{},
+ },
+ },
+ // Excessive argument.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id4",
+ Name: "tool2",
+ Args: map[string]any{
+ "Tool2Arg": 0,
+ "Tool2Arg2": 100,
+ },
+ },
+ },
+ // Tool that does not exist.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id5",
+ Name: "tool3",
+ Args: map[string]any{
+ "Arg": 0,
+ },
+ },
+ },
+ // Wrong arg for set-results (this must not count as a successful set-results call).
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id6",
+ Name: "set-results",
+ Args: map[string]any{
+ "WrongArg": 0,
+ },
+ },
+ },
+ }}}}}, nil
+ case 2:
+ require.Equal(t, len(req), 3)
+ assert.Equal(t, req[2], &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id1",
+ Name: "tool1",
+ Response: map[string]any{},
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id2",
+ Name: "tool2",
+ Response: map[string]any{
+ "error": "argument \"Tool2Arg\" has wrong type: got string, want int",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id3",
+ Name: "tool2",
+ Response: map[string]any{
+ "error": "missing argument \"Tool2Arg\"",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id4",
+ Name: "tool2",
+ Response: map[string]any{
+ "Result": 42,
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id5",
+ Name: "tool3",
+ Response: map[string]any{
+ "error": "tool \"tool3\" does not exist, please correct the name",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id6",
+ Name: "set-results",
+ Response: map[string]any{
+ "error": "missing argument \"AdditionalOutput\"",
+ },
+ },
+ },
+ }})
+ // Now it tries to provide the final result without having successfully called set-results.
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("I am done")},
+ }}}}, nil
+ case 3:
+ // Reply that set-results wasn't called.
+ require.Equal(t, len(req), 5)
+ assert.Equal(t, req[4], genai.NewContentFromText(llmMissingOutputs, genai.RoleUser))
+ // Now call it twice.
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id1",
+ Name: "set-results",
+ Args: map[string]any{
+ "AdditionalOutput": 1,
+ },
+ },
+ },
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id2",
+ Name: "set-results",
+ Args: map[string]any{
+ "AdditionalOutput": 2,
+ },
+ },
+ },
+ }}}}}, nil
+ case 4:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("Finally done")},
+ }}}}, nil
+ default:
+ t.Fatal("unexpected LLM calls")
+ return nil, nil
+ }
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ onEvent := func(span *trajectory.Span) error { return nil }
+ res, err := flows["test-flow"].Execute(ctx, "", workdir, map[string]any{}, cache, onEvent)
+ require.NoError(t, err)
+ require.Equal(t, replySeq, 4)
+ require.Equal(t, res, map[string]any{
+ "Reply": "Finally done",
+ "AdditionalOutput": 2,
+ })
+}
diff --git a/pkg/aflow/func_action.go b/pkg/aflow/func_action.go
index a54579320..b2906500f 100644
--- a/pkg/aflow/func_action.go
+++ b/pkg/aflow/func_action.go
@@ -22,7 +22,7 @@ type funcAction[Args, Results any] struct {
}
func (a *funcAction[Args, Results]) execute(ctx *Context) error {
- args, err := convertFromMap[Args](ctx.state, false)
+ args, err := convertFromMap[Args](ctx.state, false, false)
if err != nil {
return err
}
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go
index cd069db84..48b47b1e5 100644
--- a/pkg/aflow/func_tool.go
+++ b/pkg/aflow/func_tool.go
@@ -40,11 +40,17 @@ func (t *funcTool[State, Args, Results]) declaration() *genai.FunctionDeclaratio
}
func (t *funcTool[State, Args, Results]) execute(ctx *Context, args map[string]any) (map[string]any, error) {
- state, err := convertFromMap[State](ctx.state, false)
+ state, err := convertFromMap[State](ctx.state, false, false)
if err != nil {
return nil, err
}
- a, err := convertFromMap[Args](args, true)
+ // We parse args in non-strict mode too.
+ // LLM shouldn't provide excessive args, but they are known to mess up things
+ // in all possible ways occasionally. Generally we want to handle such cases
+ // in some way, rather than fail the whole workflow. We could reply to it
+ // with an error about this, but it's unclear if the additional round-trip
+ // is worth it, since it already provided all the actual arguments.
+ a, err := convertFromMap[Args](args, false, true)
if err != nil {
return nil, err
}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index d5e4d6d4d..b473c9e7a 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -93,6 +93,10 @@ It must be called exactly once before the final reply.
Ignore results of this tool.
`
+const llmMissingOutputs = `You did not call set-results tool.
+Please call set-results tool to provide results of the analysis.
+`
+
type llmOutputs struct {
tool Tool
provideOutputs func(*verifyContext, string, bool)
@@ -180,12 +184,15 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err := ctx.finishSpan(reqSpan, respErr); err != nil {
return "", nil, err
}
+ req = append(req, resp.Candidates[0].Content)
if len(calls) == 0 {
// This is the final reply.
- if a.Outputs != nil && outputs == nil {
- return "", nil, fmt.Errorf("LLM did not call tool to set outputs")
+ if a.Outputs == nil || outputs != nil {
+ return reply, outputs, nil
}
- return reply, outputs, nil
+ // LLM did not call set-results.
+ req = append(req, genai.NewContentFromText(llmMissingOutputs, genai.RoleUser))
+ continue
}
// This is not the final reply, LLM asked to execute some tools.
// Append the current reply, and tool responses to the next request.
@@ -193,11 +200,10 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err != nil {
return "", nil, err
}
- if outputs != nil && outputs1 != nil {
- return "", nil, fmt.Errorf("LLM called outputs tool twice")
- }
+ // Overwrite previous outputs, if LLM calls the tool more than once.
+ // It shouldn't, but this seems to be the easiest way to handle it gracefully.
outputs = outputs1
- req = append(req, resp.Candidates[0].Content, responses)
+ req = append(req, responses)
}
}
@@ -231,16 +237,35 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
}
var outputs map[string]any
for _, call := range calls {
+ appendPart := func(results map[string]any) {
+ responses.Parts = append(responses.Parts, &genai.Part{
+ FunctionResponse: &genai.FunctionResponse{
+ ID: call.ID,
+ Name: call.Name,
+ Response: results,
+ },
+ })
+ }
tool := tools[call.Name]
if tool == nil {
- return nil, nil, fmt.Errorf("no tool %q", call.Name)
+ appendPart(map[string]any{
+ "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name),
+ })
+ continue
}
results, err := tool.execute(ctx, call.Args)
if err != nil {
+ if argsErr := new(toolArgsError); errors.As(err, &argsErr) {
+ // LLM provided wrong arguments to the tool,
+ // return the error back to the LLM instead of failing.
+ appendPart(map[string]any{
+ "error": err.Error(),
+ })
+ continue
+ }
return nil, nil, err
}
- responses.Parts = append(responses.Parts, genai.NewPartFromFunctionResponse(call.Name, results))
- responses.Parts[len(responses.Parts)-1].FunctionResponse.ID = call.ID
+ appendPart(results)
if a.Outputs != nil && tool == a.Outputs.tool {
outputs = results
}
diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go
index e34d465ea..2b2d77f76 100644
--- a/pkg/aflow/schema.go
+++ b/pkg/aflow/schema.go
@@ -54,13 +54,19 @@ func convertToMap[T any](val T) map[string]any {
// convertFromMap converts an untyped map to a struct.
// It always ensures that all struct fields are present in the map.
// In the strict mode it also checks that the map does not contain any other unused elements.
-func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
+// If tool is set, return errors in the form suitable to return back to LLM
+// during tool arguments conversion.
+func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
m = maps.Clone(m)
var val T
for name, field := range foreachField(&val) {
f, ok := m[name]
if !ok {
- return val, fmt.Errorf("field %v is not present when converting map to %T", name, val)
+ if tool {
+ return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
+ } else {
+ return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
+ }
}
delete(m, name)
if mm, ok := f.(map[string]any); ok && field.Type() == reflect.TypeFor[json.RawMessage]() {
@@ -69,8 +75,16 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
return val, err
}
field.Set(reflect.ValueOf(json.RawMessage(raw)))
- } else {
+ } else if field.Type() == reflect.TypeOf(f) {
field.Set(reflect.ValueOf(f))
+ } else {
+ if tool {
+ return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name())}
+ } else {
+ return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name())
+ }
}
}
if strict && len(m) != 0 {
@@ -79,6 +93,8 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
return val, nil
}
+type toolArgsError struct{ error }
+
// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
return func(yield func(string, reflect.Value) bool) {