diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-16 20:48:47 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 21:12:57 +0000 |
| commit | 4dc35ec28780d6a78e8afcf2650d4ada4fcd245c (patch) | |
| tree | 2d230546858e301914fc8f3d92fb83935ba7a796 /pkg | |
| parent | 91e26ec437abcd42a8255aa88e31b45da059529e (diff) | |
pkg/aflow: handle common LLM misbehaviors with respect to tool calling
Gracefully handle (reply to LLM with error):
- incorrect tool name
- incorrect tool arg type
- missing tool arg
Silently handle:
- more than one call to set-results
- excessive tool args
Fixes #6604
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/flow.go | 4 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 238 | ||||
| -rw-r--r-- | pkg/aflow/func_action.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/func_tool.go | 10 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 45 | ||||
| -rw-r--r-- | pkg/aflow/schema.go | 22 |
6 files changed, 302 insertions, 19 deletions
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go index a391f5a01..d1bbe4c69 100644 --- a/pkg/aflow/flow.go +++ b/pkg/aflow/flow.go @@ -57,12 +57,12 @@ func register[Inputs, Outputs any](typ ai.WorkflowType, description string, Type: typ, Description: description, checkInputs: func(inputs map[string]any) error { - _, err := convertFromMap[Inputs](inputs, false) + _, err := convertFromMap[Inputs](inputs, false, false) return err }, extractOutputs: func(state map[string]any) map[string]any { // Ensure that we actually have all outputs. - tmp, err := convertFromMap[Outputs](state, false) + tmp, err := convertFromMap[Outputs](state, false, false) if err != nil { panic(err) } diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 3ee3f545b..795b52d6e 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -916,7 +916,7 @@ func TestNoInputs(t *testing.T) { onEvent := func(span *trajectory.Span) error { return nil } _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.Equal(t, err.Error(), "flow inputs are missing:"+ - " field InBar is not present when converting map to aflow.flowInputs") + " field \"InBar\" is not present when converting map to aflow.flowInputs") } func TestQuotaResetTime(t *testing.T) { @@ -940,3 +940,239 @@ func TestQuotaResetTime(t *testing.T) { assert.Equal(t, test.reset, got, "when: %v", test.when) } } + +func TestToolMisbehavior(t *testing.T) { + type flowOutputs struct { + Reply string + AdditionalOutput int + } + type tool1Args struct { + Tool1Arg string `jsonschema:"arg"` + } + type tool2Args struct { + Tool2Arg int `jsonschema:"arg"` + } + type tool2Results struct { + Result int `jsonschema:"arg"` + } + flows := make(map[string]*Flow) + err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{ + { + Name: "flow", + Root: NewPipeline( + &LLMAgent{ + Name: "smarty", + Model: "model1", + Temperature: 1, + Reply: "Reply", + + Outputs: LLMOutputs[struct { + AdditionalOutput int 
`jsonschema:"arg"` + }](), + Instruction: "Do something!", + Prompt: "Data", + Tools: []Tool{ + NewFuncTool("tool1", func(ctx *Context, state struct{}, args tool1Args) (struct{}, error) { + return struct{}{}, nil + }, "tool description"), + NewFuncTool("tool2", func(ctx *Context, state struct{}, args tool2Args) (tool2Results, error) { + return tool2Results{42}, nil + }, "tool description"), + }, + }, + ), + }, + }) + require.NoError(t, err) + replySeq := 0 + stub := &stubContext{ + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + replySeq++ + switch replySeq { + case 1: + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + // This tool call is OK, and the tool must be called. + { + FunctionCall: &genai.FunctionCall{ + ID: "id1", + Name: "tool1", + Args: map[string]any{ + "Tool1Arg": "string", + }, + }, + }, + // Incorrect argument type. + { + FunctionCall: &genai.FunctionCall{ + ID: "id2", + Name: "tool2", + Args: map[string]any{ + "Tool2Arg": "string-instead-of-int", + }, + }, + }, + // Missing argument. + { + FunctionCall: &genai.FunctionCall{ + ID: "id3", + Name: "tool2", + Args: map[string]any{}, + }, + }, + // Excessive argument. + { + FunctionCall: &genai.FunctionCall{ + ID: "id4", + Name: "tool2", + Args: map[string]any{ + "Tool2Arg": 0, + "Tool2Arg2": 100, + }, + }, + }, + // Tool that does not exist. + { + FunctionCall: &genai.FunctionCall{ + ID: "id5", + Name: "tool3", + Args: map[string]any{ + "Arg": 0, + }, + }, + }, + // Wrong arg for set-results (should not count as it was called). 
+ { + FunctionCall: &genai.FunctionCall{ + ID: "id6", + Name: "set-results", + Args: map[string]any{ + "WrongArg": 0, + }, + }, + }, + }}}}}, nil + case 2: + require.Equal(t, len(req), 3) + assert.Equal(t, req[2], &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id1", + Name: "tool1", + Response: map[string]any{}, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id2", + Name: "tool2", + Response: map[string]any{ + "error": "argument \"Tool2Arg\" has wrong type: got string, want int", + }, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id3", + Name: "tool2", + Response: map[string]any{ + "error": "missing argument \"Tool2Arg\"", + }, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id4", + Name: "tool2", + Response: map[string]any{ + "Result": 42, + }, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id5", + Name: "tool3", + Response: map[string]any{ + "error": "tool \"tool3\" does not exist, please correct the name", + }, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id6", + Name: "set-results", + Response: map[string]any{ + "error": "missing argument \"AdditionalOutput\"", + }, + }, + }, + }}) + // Now it tries to provide the final result w/o calling set-results (successfully). + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + {Content: &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + genai.NewPartFromText("I am done")}, + }}}}, nil + case 3: + // Reply that set-results wasn't called. + require.Equal(t, len(req), 5) + assert.Equal(t, req[4], genai.NewContentFromText(llmMissingOutputs, genai.RoleUser)) + // Now call it twice. 
+ return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id1", + Name: "set-results", + Args: map[string]any{ + "AdditionalOutput": 1, + }, + }, + }, + { + FunctionCall: &genai.FunctionCall{ + ID: "id2", + Name: "set-results", + Args: map[string]any{ + "AdditionalOutput": 2, + }, + }, + }, + }}}}}, nil + case 4: + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + {Content: &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + genai.NewPartFromText("Finally done")}, + }}}}, nil + default: + t.Fatal("unexpected LLM calls") + return nil, nil + } + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + onEvent := func(span *trajectory.Span) error { return nil } + res, err := flows["test-flow"].Execute(ctx, "", workdir, map[string]any{}, cache, onEvent) + require.NoError(t, err) + require.Equal(t, replySeq, 4) + require.Equal(t, res, map[string]any{ + "Reply": "Finally done", + "AdditionalOutput": 2, + }) +} diff --git a/pkg/aflow/func_action.go b/pkg/aflow/func_action.go index a54579320..b2906500f 100644 --- a/pkg/aflow/func_action.go +++ b/pkg/aflow/func_action.go @@ -22,7 +22,7 @@ type funcAction[Args, Results any] struct { } func (a *funcAction[Args, Results]) execute(ctx *Context) error { - args, err := convertFromMap[Args](ctx.state, false) + args, err := convertFromMap[Args](ctx.state, false, false) if err != nil { return err } diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go index cd069db84..48b47b1e5 100644 --- a/pkg/aflow/func_tool.go +++ b/pkg/aflow/func_tool.go @@ -40,11 +40,17 @@ func (t *funcTool[State, Args, Results]) declaration() *genai.FunctionDeclaratio } func (t *funcTool[State, Args, Results]) 
execute(ctx *Context, args map[string]any) (map[string]any, error) { - state, err := convertFromMap[State](ctx.state, false) + state, err := convertFromMap[State](ctx.state, false, false) if err != nil { return nil, err } - a, err := convertFromMap[Args](args, true) + // We parse args in non-strict mode too. + // LLM shouldn't provide excessive args, but they are known to mess up things + // in all possible ways occasionally. Generally we want to handle such cases + // in some way, rather than fail the whole workflow. We could reply to it + // with an error about this, but it's unclear if the additional round-trip + // worth it, it already provided all the actual arguments. + a, err := convertFromMap[Args](args, false, true) if err != nil { return nil, err } diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index d5e4d6d4d..b473c9e7a 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -93,6 +93,10 @@ It must be called exactly once before the final reply. Ignore results of this tool. ` +const llmMissingOutputs = `You did not call set-results tool. +Please call set-results tool to provide results of the analysis. +` + type llmOutputs struct { tool Tool provideOutputs func(*verifyContext, string, bool) @@ -180,12 +184,15 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma if err := ctx.finishSpan(reqSpan, respErr); err != nil { return "", nil, err } + req = append(req, resp.Candidates[0].Content) if len(calls) == 0 { // This is the final reply. - if a.Outputs != nil && outputs == nil { - return "", nil, fmt.Errorf("LLM did not call tool to set outputs") + if a.Outputs == nil || outputs != nil { + return reply, outputs, nil } - return reply, outputs, nil + // LLM did not call set-results. + req = append(req, genai.NewContentFromText(llmMissingOutputs, genai.RoleUser)) + continue } // This is not the final reply, LLM asked to execute some tools. 
// Append the current reply, and tool responses to the next request. @@ -193,11 +200,10 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma if err != nil { return "", nil, err } - if outputs != nil && outputs1 != nil { - return "", nil, fmt.Errorf("LLM called outputs tool twice") - } + // Overwrite previous outputs, if LLM calls the tool more than once. + // It shouldn't, but this seems to be the easiest way to handle it gracefully. outputs = outputs1 - req = append(req, resp.Candidates[0].Content, responses) + req = append(req, responses) } } @@ -231,16 +237,35 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai } var outputs map[string]any for _, call := range calls { + appendPart := func(results map[string]any) { + responses.Parts = append(responses.Parts, &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + ID: call.ID, + Name: call.Name, + Response: results, + }, + }) + } tool := tools[call.Name] if tool == nil { - return nil, nil, fmt.Errorf("no tool %q", call.Name) + appendPart(map[string]any{ + "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name), + }) + continue } results, err := tool.execute(ctx, call.Args) if err != nil { + if argsErr := new(toolArgsError); errors.As(err, &argsErr) { + // LLM provided wrong arguments to the tool, + // return the error back to the LLM instead of failing. 
+ appendPart(map[string]any{ + "error": err.Error(), + }) + continue + } return nil, nil, err } - responses.Parts = append(responses.Parts, genai.NewPartFromFunctionResponse(call.Name, results)) - responses.Parts[len(responses.Parts)-1].FunctionResponse.ID = call.ID + appendPart(results) if a.Outputs != nil && tool == a.Outputs.tool { outputs = results } diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go index e34d465ea..2b2d77f76 100644 --- a/pkg/aflow/schema.go +++ b/pkg/aflow/schema.go @@ -54,13 +54,19 @@ func convertToMap[T any](val T) map[string]any { // convertFromMap converts an untyped map to a struct. // It always ensures that all struct fields are present in the map. // In the strict mode it also checks that the map does not contain any other unused elements. -func convertFromMap[T any](m map[string]any, strict bool) (T, error) { +// If tool is set, return errors in the form suitable to return back to LLM +// during tool arguments conversion. +func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { m = maps.Clone(m) var val T for name, field := range foreachField(&val) { f, ok := m[name] if !ok { - return val, fmt.Errorf("field %v is not present when converting map to %T", name, val) + if tool { + return val, &toolArgsError{fmt.Errorf("missing argument %q", name)} + } else { + return val, fmt.Errorf("field %q is not present when converting map to %T", name, val) + } } delete(m, name) if mm, ok := f.(map[string]any); ok && field.Type() == reflect.TypeFor[json.RawMessage]() { @@ -69,8 +75,16 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) { return val, err } field.Set(reflect.ValueOf(json.RawMessage(raw))) - } else { + } else if field.Type() == reflect.TypeOf(f) { field.Set(reflect.ValueOf(f)) + } else { + if tool { + return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v", + name, f, field.Type().Name())} + } else { + return val, fmt.Errorf("field %q has wrong type: got %T, want 
%v", + name, f, field.Type().Name()) + } } } if strict && len(m) != 0 { @@ -79,6 +93,8 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) { return val, nil } +type toolArgsError struct{ error } + // foreachField iterates over all public fields of the struct provided in data. func foreachField(data any) iter.Seq2[string, reflect.Value] { return func(yield func(string, reflect.Value) bool) { |
