diff options
Diffstat (limited to 'pkg/aflow')
| -rw-r--r-- | pkg/aflow/func_tool.go | 13 | ||||
| -rw-r--r-- | pkg/aflow/func_tool_test.go | 111 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 21 | ||||
| -rw-r--r-- | pkg/aflow/schema.go | 8 | ||||
| -rw-r--r-- | pkg/aflow/tool/codesearcher/codesearcher.go | 39 |
5 files changed, 150 insertions, 42 deletions
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go index 48b47b1e5..dde359485 100644 --- a/pkg/aflow/func_tool.go +++ b/pkg/aflow/func_tool.go @@ -4,6 +4,8 @@ package aflow import ( + "errors" + "github.com/google/syzkaller/pkg/aflow/trajectory" "google.golang.org/genai" ) @@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, } } +// BadCallError creates an error that means that LLM made a bad tool call, +// the provided message will be returned to the LLM as an error, +// instead of failing the whole workflow. +func BadCallError(message string) error { + return &badCallError{errors.New(message)} +} + +type badCallError struct { + error +} + type funcTool[State, Args, Results any] struct { Name string Description string diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go new file mode 100644 index 000000000..429566dbe --- /dev/null +++ b/pkg/aflow/func_tool_test.go @@ -0,0 +1,111 @@ +// Copyright 2026 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestToolErrors(t *testing.T) { + type flowOutputs struct { + Reply string + } + type toolArgs struct { + CallError bool `jsonschema:"call error"` + } + flows := make(map[string]*Flow) + err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{ + { + Root: &LLMAgent{ + Name: "smarty", + Model: "model", + Reply: "Reply", + Temperature: 0, + Instruction: "Do something!", + Prompt: "Prompt", + Tools: []Tool{ + NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) { + if args.CallError { + return struct{}{}, BadCallError("you are wrong") + } + return struct{}{}, errors.New("hard error") + }, "tool 1 description"), + }, + }, + }, + }) + require.NoError(t, err) + replySeq := 0 + stub := &stubContext{ + // nolint:dupl + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + replySeq++ + switch replySeq { + case 1: + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": true, + }, + }, + }, + }}}}}, nil + case 2: + assert.Equal(t, req[2], &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id0", + Name: "faulty", + Response: map[string]any{ + "error": "you are wrong", + }, + }, + }}}) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": false, + }, + }, + }, + }}}}}, nil + default: + t.Fatal("unexpected LLM calls") + return nil, nil + } + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + onEvent := func(span *trajectory.Span) error { return nil } + _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent) + require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]") +} diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index b473c9e7a..5934bf9bd 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai }, }) } + appendError := func(message string) { + appendPart(map[string]any{"error": message}) + } tool := tools[call.Name] if tool == nil { - appendPart(map[string]any{ - "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name), - }) + appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name)) continue } results, err := tool.execute(ctx, call.Args) if err != nil { - if argsErr := new(toolArgsError); errors.As(err, &argsErr) { - // LLM provided wrong arguments to the tool, - // return the error back to the LLM instead of failing. - appendPart(map[string]any{ - "error": err.Error(), - }) + // LLM provided wrong arguments to the tool, + // or the tool returned error message to the LLM. + // Return the error back to the LLM instead of failing. + if callErr := new(badCallError); errors.As(err, &callErr) { + appendError(err.Error()) continue } - return nil, nil, err + return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v", + call.Name, err, call.Args) } appendPart(results) if a.Outputs != nil && tool == a.Outputs.tool { diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go index 2b2d77f76..0b0eb8994 100644 --- a/pkg/aflow/schema.go +++ b/pkg/aflow/schema.go @@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { f, ok := m[name] if !ok { if tool { - return val, &toolArgsError{fmt.Errorf("missing argument %q", name)} + return val, BadCallError(fmt.Sprintf("missing argument %q", name)) } else { return val, fmt.Errorf("field %q is not present when converting map to %T", name, val) } @@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { field.Set(reflect.ValueOf(f)) } else { if tool { - return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v", - name, f, field.Type().Name())} + return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v", + name, f, field.Type().Name())) } else { return val, fmt.Errorf("field %q has wrong type: got %T, want %v", name, f, field.Type().Name()) @@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { return val, nil } -type toolArgsError struct{ error } - // foreachField iterates over all public fields of the struct provided in data. func foreachField(data any) iter.Seq2[string, reflect.Value] { return func(yield func(string, reflect.Value) bool) { diff --git a/pkg/aflow/tool/codesearcher/codesearcher.go b/pkg/aflow/tool/codesearcher/codesearcher.go index 34db81b80..922e32569 100644 --- a/pkg/aflow/tool/codesearcher/codesearcher.go +++ b/pkg/aflow/tool/codesearcher/codesearcher.go @@ -68,7 +68,6 @@ type dirIndexArgs struct { } type dirIndexResult struct { - Missing bool `jsonschema:"Set to true if the requested directory does not exist."` Subdirs []string `jsonschema:"List of direct subdirectories."` Files []string `jsonschema:"List of source files."` } @@ -78,7 +77,6 @@ type readFileArgs struct { } type readFileResult struct { - Missing bool `jsonschema:"Set to true if the requested file does not exist."` Contents string `jsonschema:"File contents."` } @@ -87,7 +85,6 @@ type fileIndexArgs struct { } type fileIndexResult struct { - Missing bool `jsonschema:"Set to true if the file with the given name does not exist."` Entities []indexEntity `jsonschema:"List of entites defined in the file."` } @@ -103,7 +100,6 @@ type defCommentArgs struct { } type defCommentResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` Kind string `jsonschema:"Kind of the entity: function, struct, variable."` Comment string `jsonschema:"Source comment for the entity."` } @@ -117,7 +113,6 @@ type defSourceArgs struct { // nolint: lll type defSourceResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` SourceFile string `jsonschema:"Source file path where the entity is defined."` SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."` } @@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) { } func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) { - ok, subdirs, files, err := state.Index.DirIndex(args.Dir) - res := dirIndexResult{ - Missing: !ok, + subdirs, files, err := state.Index.DirIndex(args.Dir) + return dirIndexResult{ Subdirs: subdirs, Files: files, - } - return res, err + }, err } func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) { - ok, contents, err := state.Index.ReadFile(args.File) - res := readFileResult{ - Missing: !ok, + contents, err := state.Index.ReadFile(args.File) + return readFileResult{ Contents: contents, - } - return res, err + }, err } func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) { - ok, entities, err := state.Index.FileIndex(args.SourceFile) - res := fileIndexResult{ - Missing: !ok, - } + entities, err := state.Index.FileIndex(args.SourceFile) + res := fileIndexResult{} for _, ent := range entities { res.Entities = append(res.Entities, indexEntity{ Kind: ent.Kind, @@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) { info, err := state.Index.DefinitionComment(args.SourceFile, args.Name) - if err != nil || info == nil { - return defCommentResult{ - Missing: info == nil, - }, err + if err != nil { + return defCommentResult{}, err } return defCommentResult{ Kind: info.Kind, @@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) { info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines) - if err != nil || info == nil { - return defSourceResult{ - Missing: info == nil, - }, err + if err != nil { + return defSourceResult{}, err } return defSourceResult{ SourceFile: info.File, |
