diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 16:07:06 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-21 13:38:45 +0000 |
| commit | 42b96d8b5320697eba2adb686c3206f1d50e14a2 (patch) | |
| tree | aa5dc57eb3cac1ff21d2ddb1f8965ecfa5fbbd4b /pkg | |
| parent | cb6c392c6d56f142bdfe88cd3fa4cdc8f036b460 (diff) | |
pkg/aflow: inject tool errors into trajectory
Currently we handle several errors in LLMAgent (wrong tool name,
wrong tool arguments) and return the error to the LLM,
but nothing is injected into the trajectory about what happened.
This makes the trajectory incomplete and confusing:
one just sees repeated LLM calls without understanding what caused them.
Inject these tool failures into the trace, so that it's clear
what happened.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/flow_test.go | 94 | ||||
| -rw-r--r-- | pkg/aflow/func_tool.go | 13 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 53 |
3 files changed, 114 insertions, 46 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 7b78d0a2d..bab3f429c 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -1226,12 +1226,46 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool2", Args: map[string]any{ + "Tool2Arg": "string-instead-of-int", + }, + }, + { + Seq: 4, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Args: map[string]any{ + "Tool2Arg": "string-instead-of-int", + }, + Error: "argument \"Tool2Arg\" has wrong type: got string, want int", + }, + { + Seq: 5, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Args: map[string]any{}, + }, + { + Seq: 5, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Args: map[string]any{}, + Error: "missing argument \"Tool2Arg\"", + }, + { + Seq: 6, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Args: map[string]any{ "Tool2Arg": 0, "Tool2Arg2": 100, }, }, { - Seq: 4, + Seq: 6, Nesting: 2, Type: trajectory.SpanTool, Name: "tool2", @@ -1244,35 +1278,73 @@ func TestToolMisbehavior(t *testing.T) { }, }, { - Seq: 5, + Seq: 7, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool3", + Args: map[string]any{ + "Arg": 0, + }, + }, + { + Seq: 7, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool3", + Args: map[string]any{ + "Arg": 0, + }, + Error: "tool \"tool3\" does not exist, please correct the name", + }, + { + Seq: 8, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "set-results", + Args: map[string]any{ + "WrongArg": 0, + }, + }, + { + Seq: 8, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "set-results", + Args: map[string]any{ + "WrongArg": 0, + }, + Error: "missing argument \"AdditionalOutput\"", + }, + { + Seq: 9, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", Model: "model", }, { - Seq: 5, + Seq: 9, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", Model: "model", }, { - Seq: 6, + Seq: 10, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", Model: "model", }, { - Seq: 
6, + Seq: 10, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", Model: "model", }, { - Seq: 7, + Seq: 11, Nesting: 2, Type: trajectory.SpanTool, Name: "set-results", @@ -1281,7 +1353,7 @@ func TestToolMisbehavior(t *testing.T) { }, }, { - Seq: 7, + Seq: 11, Nesting: 2, Type: trajectory.SpanTool, Name: "set-results", @@ -1293,7 +1365,7 @@ func TestToolMisbehavior(t *testing.T) { }, }, { - Seq: 8, + Seq: 12, Nesting: 2, Type: trajectory.SpanTool, Name: "set-results", @@ -1302,7 +1374,7 @@ func TestToolMisbehavior(t *testing.T) { }, }, { - Seq: 8, + Seq: 12, Nesting: 2, Type: trajectory.SpanTool, Name: "set-results", @@ -1314,14 +1386,14 @@ func TestToolMisbehavior(t *testing.T) { }, }, { - Seq: 9, + Seq: 13, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", Model: "model", }, { - Seq: 9, + Seq: 13, Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go index dde359485..013e6a917 100644 --- a/pkg/aflow/func_tool.go +++ b/pkg/aflow/func_tool.go @@ -6,7 +6,6 @@ package aflow import ( "errors" - "github.com/google/syzkaller/pkg/aflow/trajectory" "google.golang.org/genai" ) @@ -67,18 +66,8 @@ func (t *funcTool[State, Args, Results]) execute(ctx *Context, args map[string]a if err != nil { return nil, err } - span := &trajectory.Span{ - Type: trajectory.SpanTool, - Name: t.Name, - Args: args, - } - if err := ctx.startSpan(span); err != nil { - return nil, err - } res, err := t.Func(ctx, state, a) - span.Results = convertToMap(res) - err = ctx.finishSpan(span, err) - return span.Results, err + return convertToMap(res), err } func (t *funcTool[State, Args, Results]) verify(ctx *verifyContext) { diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index e5b753607..e5391dfcf 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -244,38 +244,45 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai } var outputs map[string]any for _, call := range calls { - 
appendPart := func(results map[string]any) { - responses.Parts = append(responses.Parts, &genai.Part{ - FunctionResponse: &genai.FunctionResponse{ - ID: call.ID, - Name: call.Name, - Response: results, - }, - }) + span := &trajectory.Span{ + Type: trajectory.SpanTool, + Name: call.Name, + Args: call.Args, } - appendError := func(message string) { - appendPart(map[string]any{"error": message}) + if err := ctx.startSpan(span); err != nil { + return nil, nil, err } + toolErr := BadCallError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name)) tool := tools[call.Name] - if tool == nil { - appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name)) - continue + if tool != nil { + span.Results, toolErr = tool.execute(ctx, call.Args) } - results, err := tool.execute(ctx, call.Args) - if err != nil { + if toolErr != nil { + span.Error = toolErr.Error() + } + if err := ctx.finishSpan(span, nil); err != nil { + return nil, nil, err + } + if toolErr != nil { // LLM provided wrong arguments to the tool, // or the tool returned error message to the LLM. // Return the error back to the LLM instead of failing. - if callErr := new(badCallError); errors.As(err, &callErr) { - appendError(err.Error()) - continue + if callErr := new(badCallError); errors.As(toolErr, &callErr) { + span.Results = map[string]any{"error": toolErr.Error()} + } else { + return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v", + call.Name, toolErr, call.Args) } - return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v", - call.Name, err, call.Args) } - appendPart(results) - if a.Outputs != nil && tool == a.Outputs.tool { - outputs = results + responses.Parts = append(responses.Parts, &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + ID: call.ID, + Name: call.Name, + Response: span.Results, + }, + }) + if toolErr == nil && a.Outputs != nil && tool == a.Outputs.tool { + outputs = span.Results } } return responses, outputs, nil |
