diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-19 15:15:18 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 21:12:57 +0000 |
| commit | 5b6bebdcb7da46d1471b3aeacb28b54ba905b3b2 (patch) | |
| tree | d60fdce83c9b47fb39327f0bdc40b734a7213985 | |
| parent | 8088ac4199a6e947c38db669c11d4441a9d59581 (diff) | |
pkg/aflow: add BadCallError
The error allows tools to communicate that an error is not an infrastructure error
that must fail the whole workflow, but rather a bad tool invocation by an LLM
(e.g. asking for a non-existent file contents).
Previously in the codesearcher tool we used a separate Missing bool
to communicate that. With the error everything just becomes cleaner and nicer.
The errors also allows all other tools to communicate any errors to the LLM
when the normal results cannot be provided and don't make sense.
| -rw-r--r-- | pkg/aflow/func_tool.go | 13 | ||||
| -rw-r--r-- | pkg/aflow/func_tool_test.go | 111 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 21 | ||||
| -rw-r--r-- | pkg/aflow/schema.go | 8 | ||||
| -rw-r--r-- | pkg/aflow/tool/codesearcher/codesearcher.go | 39 | ||||
| -rw-r--r-- | pkg/codesearch/codesearch.go | 67 | ||||
| -rw-r--r-- | pkg/codesearch/codesearch_test.go | 3 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-def-source-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-escaping | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-file | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-file-index-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-dir | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-escaping | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-missing | 2 |
15 files changed, 195 insertions, 83 deletions
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go index 48b47b1e5..dde359485 100644 --- a/pkg/aflow/func_tool.go +++ b/pkg/aflow/func_tool.go @@ -4,6 +4,8 @@ package aflow import ( + "errors" + "github.com/google/syzkaller/pkg/aflow/trajectory" "google.golang.org/genai" ) @@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, } } +// BadCallError creates an error that means that LLM made a bad tool call, +// the provided message will be returned to the LLM as an error, +// instead of failing the whole workflow. +func BadCallError(message string) error { + return &badCallError{errors.New(message)} +} + +type badCallError struct { + error +} + type funcTool[State, Args, Results any] struct { Name string Description string diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go new file mode 100644 index 000000000..429566dbe --- /dev/null +++ b/pkg/aflow/func_tool_test.go @@ -0,0 +1,111 @@ +// Copyright 2026 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestToolErrors(t *testing.T) { + type flowOutputs struct { + Reply string + } + type toolArgs struct { + CallError bool `jsonschema:"call error"` + } + flows := make(map[string]*Flow) + err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{ + { + Root: &LLMAgent{ + Name: "smarty", + Model: "model", + Reply: "Reply", + Temperature: 0, + Instruction: "Do something!", + Prompt: "Prompt", + Tools: []Tool{ + NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) { + if args.CallError { + return struct{}{}, BadCallError("you are wrong") + } + return struct{}{}, errors.New("hard error") + }, "tool 1 description"), + }, + }, + }, + }) + require.NoError(t, err) + replySeq := 0 + stub := &stubContext{ + // nolint:dupl + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + replySeq++ + switch replySeq { + case 1: + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": true, + }, + }, + }, + }}}}}, nil + case 2: + assert.Equal(t, req[2], &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id0", + Name: "faulty", + Response: map[string]any{ + "error": "you are wrong", + }, + }, + }}}) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": false, + }, + }, + }, + }}}}}, nil + default: + t.Fatal("unexpected LLM calls") + return nil, nil + } + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + onEvent := func(span *trajectory.Span) error { return nil } + _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent) + require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]") +} diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index b473c9e7a..5934bf9bd 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai }, }) } + appendError := func(message string) { + appendPart(map[string]any{"error": message}) + } tool := tools[call.Name] if tool == nil { - appendPart(map[string]any{ - "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name), - }) + appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name)) continue } results, err := tool.execute(ctx, call.Args) if err != nil { - if argsErr := new(toolArgsError); errors.As(err, &argsErr) { - // LLM provided wrong arguments to the tool, - // return the error back to the LLM instead of failing. - appendPart(map[string]any{ - "error": err.Error(), - }) + // LLM provided wrong arguments to the tool, + // or the tool returned error message to the LLM. + // Return the error back to the LLM instead of failing. + if callErr := new(badCallError); errors.As(err, &callErr) { + appendError(err.Error()) continue } - return nil, nil, err + return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v", + call.Name, err, call.Args) } appendPart(results) if a.Outputs != nil && tool == a.Outputs.tool { diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go index 2b2d77f76..0b0eb8994 100644 --- a/pkg/aflow/schema.go +++ b/pkg/aflow/schema.go @@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { f, ok := m[name] if !ok { if tool { - return val, &toolArgsError{fmt.Errorf("missing argument %q", name)} + return val, BadCallError(fmt.Sprintf("missing argument %q", name)) } else { return val, fmt.Errorf("field %q is not present when converting map to %T", name, val) } @@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { field.Set(reflect.ValueOf(f)) } else { if tool { - return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v", - name, f, field.Type().Name())} + return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v", + name, f, field.Type().Name())) } else { return val, fmt.Errorf("field %q has wrong type: got %T, want %v", name, f, field.Type().Name()) @@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { return val, nil } -type toolArgsError struct{ error } - // foreachField iterates over all public fields of the struct provided in data. func foreachField(data any) iter.Seq2[string, reflect.Value] { return func(yield func(string, reflect.Value) bool) { diff --git a/pkg/aflow/tool/codesearcher/codesearcher.go b/pkg/aflow/tool/codesearcher/codesearcher.go index 34db81b80..922e32569 100644 --- a/pkg/aflow/tool/codesearcher/codesearcher.go +++ b/pkg/aflow/tool/codesearcher/codesearcher.go @@ -68,7 +68,6 @@ type dirIndexArgs struct { } type dirIndexResult struct { - Missing bool `jsonschema:"Set to true if the requested directory does not exist."` Subdirs []string `jsonschema:"List of direct subdirectories."` Files []string `jsonschema:"List of source files."` } @@ -78,7 +77,6 @@ type readFileArgs struct { } type readFileResult struct { - Missing bool `jsonschema:"Set to true if the requested file does not exist."` Contents string `jsonschema:"File contents."` } @@ -87,7 +85,6 @@ type fileIndexArgs struct { } type fileIndexResult struct { - Missing bool `jsonschema:"Set to true if the file with the given name does not exist."` Entities []indexEntity `jsonschema:"List of entites defined in the file."` } @@ -103,7 +100,6 @@ type defCommentArgs struct { } type defCommentResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` Kind string `jsonschema:"Kind of the entity: function, struct, variable."` Comment string `jsonschema:"Source comment for the entity."` } @@ -117,7 +113,6 @@ type defSourceArgs struct { // nolint: lll type defSourceResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` SourceFile string `jsonschema:"Source file path where the entity is defined."` SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."` } @@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) { } func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) { - ok, subdirs, files, err := state.Index.DirIndex(args.Dir) - res := dirIndexResult{ - Missing: !ok, + subdirs, files, err := state.Index.DirIndex(args.Dir) + return dirIndexResult{ Subdirs: subdirs, Files: files, - } - return res, err + }, err } func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) { - ok, contents, err := state.Index.ReadFile(args.File) - res := readFileResult{ - Missing: !ok, + contents, err := state.Index.ReadFile(args.File) + return readFileResult{ Contents: contents, - } - return res, err + }, err } func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) { - ok, entities, err := state.Index.FileIndex(args.SourceFile) - res := fileIndexResult{ - Missing: !ok, - } + entities, err := state.Index.FileIndex(args.SourceFile) + res := fileIndexResult{} for _, ent := range entities { res.Entities = append(res.Entities, indexEntity{ Kind: ent.Kind, @@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) { info, err := state.Index.DefinitionComment(args.SourceFile, args.Name) - if err != nil || info == nil { - return defCommentResult{ - Missing: info == nil, - }, err + if err != nil { + return defCommentResult{}, err } return defCommentResult{ Kind: info.Kind, @@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) { info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines) - if err != nil || info == nil { - return defSourceResult{ - Missing: info == nil, - }, err + if err != nil { + return defSourceResult{}, err } return defSourceResult{ SourceFile: info.File, diff --git a/pkg/codesearch/codesearch.go b/pkg/codesearch/codesearch.go index 96ee1c696..6eab749db 100644 --- a/pkg/codesearch/codesearch.go +++ b/pkg/codesearch/codesearch.go @@ -13,6 +13,7 @@ import ( "strings" "syscall" + "github.com/google/syzkaller/pkg/aflow" "github.com/google/syzkaller/pkg/osutil" ) @@ -30,9 +31,9 @@ type Command struct { // Commands are used to run unit tests and for the syz-codesearch tool. var Commands = []Command{ {"dir-index", 1, func(index *Index, args []string) (string, error) { - ok, subdirs, files, err := index.DirIndex(args[0]) - if err != nil || !ok { - return notFound, err + subdirs, files, err := index.DirIndex(args[0]) + if err != nil { + return "", err } b := new(strings.Builder) fmt.Fprintf(b, "directory %v subdirs:\n", args[0]) @@ -46,16 +47,12 @@ var Commands = []Command{ return b.String(), nil }}, {"read-file", 1, func(index *Index, args []string) (string, error) { - ok, contents, err := index.ReadFile(args[0]) - if err != nil || !ok { - return notFound, err - } - return contents, nil + return index.ReadFile(args[0]) }}, {"file-index", 1, func(index *Index, args []string) (string, error) { - ok, entities, err := index.FileIndex(args[0]) - if err != nil || !ok { - return notFound, err + entities, err := index.FileIndex(args[0]) + if err != nil { + return "", err } b := new(strings.Builder) fmt.Fprintf(b, "file %v defines the following entities:\n\n", args[0]) @@ -66,8 +63,8 @@ var Commands = []Command{ }}, {"def-comment", 2, func(index *Index, args []string) (string, error) { info, err := index.DefinitionComment(args[0], args[1]) - if err != nil || info == nil { - return notFound, err + if err != nil { + return "", err } if info.Body == "" { return fmt.Sprintf("%v %v is defined in %v and is not commented\n", @@ -78,8 +75,8 @@ var Commands = []Command{ }}, {"def-source", 3, func(index *Index, args []string) (string, error) { info, err := index.DefinitionSource(args[0], args[1], args[2] == "yes") - if err != nil || info == nil { - return notFound, err + if err != nil { + return "", err } return fmt.Sprintf("%v %v is defined in %v:\n\n%v", info.Kind, args[1], info.File, info.Body), nil }}, @@ -87,8 +84,6 @@ var Commands = []Command{ var SourceExtensions = map[string]bool{".c": true, ".h": true, ".S": true, ".rs": true} -const notFound = "not found\n" - func NewIndex(databaseFile string, srcDirs []string) (*Index, error) { db, err := osutil.ReadJSON[*Database](databaseFile) if err != nil { @@ -118,16 +113,16 @@ type Entity struct { Name string } -func (index *Index) DirIndex(dir string) (bool, []string, []string, error) { +func (index *Index) DirIndex(dir string) ([]string, []string, error) { if err := escaping(dir); err != nil { - return false, nil, nil, nil + return nil, nil, err } exists := false var subdirs, files []string for _, root := range index.srcDirs { exists1, subdirs1, files1, err := dirIndex(root, dir) if err != nil { - return false, nil, nil, err + return nil, nil, err } if exists1 { exists = true @@ -135,18 +130,21 @@ func (index *Index) DirIndex(dir string) (bool, []string, []string, error) { subdirs = append(subdirs, subdirs1...) files = append(files, files1...) } + if !exists { + return nil, nil, aflow.BadCallError("the directory does not exist") + } slices.Sort(subdirs) slices.Sort(files) // Dedup dirs across src/build trees, // also dedup files, but hopefully there are no duplicates. subdirs = slices.Compact(subdirs) files = slices.Compact(files) - return exists, subdirs, files, nil + return subdirs, files, nil } -func (index *Index) ReadFile(file string) (bool, string, error) { +func (index *Index) ReadFile(file string) (string, error) { if err := escaping(file); err != nil { - return false, "", nil + return "", err } for _, dir := range index.srcDirs { data, err := os.ReadFile(filepath.Join(dir, file)) @@ -156,16 +154,21 @@ func (index *Index) ReadFile(file string) (bool, string, error) { } var errno syscall.Errno if errors.As(err, &errno) && errno == syscall.EISDIR { - return false, "", nil + return "", aflow.BadCallError("the file is a directory") } - return false, "", err + return "", err } - return true, string(data), nil + return string(data), nil } - return false, "", nil + return "", aflow.BadCallError("the file does not exist") } -func (index *Index) FileIndex(file string) (bool, []Entity, error) { +func (index *Index) FileIndex(file string) ([]Entity, error) { + file = filepath.Clean(file) + // This allows to distinguish missing files from files that don't define anything. + if _, err := index.ReadFile(file); err != nil { + return nil, err + } var entities []Entity for _, def := range index.db.Definitions { if def.Body.File == file { @@ -175,7 +178,7 @@ func (index *Index) FileIndex(file string) (bool, []Entity, error) { }) } } - return len(entities) != 0, entities, nil + return entities, nil } type EntityInfo struct { @@ -195,7 +198,7 @@ func (index *Index) DefinitionSource(contextFile, name string, includeLines bool func (index *Index) definitionSource(contextFile, name string, comment, includeLines bool) (*EntityInfo, error) { def := index.findDefinition(contextFile, name) if def == nil { - return nil, nil + return nil, aflow.BadCallError("requested entity does not exist") } lineRange := def.Body if comment { @@ -266,7 +269,7 @@ func formatSourceFile(file string, start, end int, includeLines bool) (string, e func escaping(path string) error { if strings.Contains(filepath.Clean(path), "..") { - return errors.New("path is outside of the source tree") + return aflow.BadCallError("path is outside of the source tree") } return nil } @@ -280,7 +283,7 @@ func dirIndex(root, subdir string) (bool, []string, []string, error) { } var errno syscall.Errno if errors.As(err, &errno) && errno == syscall.ENOTDIR { - err = nil + err = aflow.BadCallError("the path is not a directory") } return false, nil, nil, err } diff --git a/pkg/codesearch/codesearch_test.go b/pkg/codesearch/codesearch_test.go index 7af509294..1f353c804 100644 --- a/pkg/codesearch/codesearch_test.go +++ b/pkg/codesearch/codesearch_test.go @@ -53,7 +53,8 @@ func testCommand(t *testing.T, index *Index, covered map[string]bool, file strin } result, err := index.Command(args[0], args[1:]) if err != nil { - t.Fatal(err) + // This is supposed to test aflow.BadCallError messages. + result = err.Error() + "\n" } got := append([]byte(strings.Join(args, " ")+"\n\n"), result...) tooltest.CompareGoldenData(t, file, got) diff --git a/pkg/codesearch/testdata/query-def-source-missing b/pkg/codesearch/testdata/query-def-source-missing index 0b60003c7..dcb8c4d75 100644 --- a/pkg/codesearch/testdata/query-def-source-missing +++ b/pkg/codesearch/testdata/query-def-source-missing @@ -1,3 +1,3 @@ def-source source0.c some_non_existent_function no -not found +requested entity does not exist diff --git a/pkg/codesearch/testdata/query-dir-index-escaping b/pkg/codesearch/testdata/query-dir-index-escaping index fd7b55ff0..f57f0eb82 100644 --- a/pkg/codesearch/testdata/query-dir-index-escaping +++ b/pkg/codesearch/testdata/query-dir-index-escaping @@ -1,3 +1,3 @@ dir-index mm/../../ -not found +path is outside of the source tree diff --git a/pkg/codesearch/testdata/query-dir-index-file b/pkg/codesearch/testdata/query-dir-index-file index eecd67d67..a519fad18 100644 --- a/pkg/codesearch/testdata/query-dir-index-file +++ b/pkg/codesearch/testdata/query-dir-index-file @@ -1,3 +1,3 @@ dir-index source0.c -not found +the path is not a directory diff --git a/pkg/codesearch/testdata/query-dir-index-missing b/pkg/codesearch/testdata/query-dir-index-missing index e028d1be1..102ddbdab 100644 --- a/pkg/codesearch/testdata/query-dir-index-missing +++ b/pkg/codesearch/testdata/query-dir-index-missing @@ -1,3 +1,3 @@ dir-index mm/foobar -not found +the directory does not exist diff --git a/pkg/codesearch/testdata/query-file-index-missing b/pkg/codesearch/testdata/query-file-index-missing index 1be486378..803a66377 100644 --- a/pkg/codesearch/testdata/query-file-index-missing +++ b/pkg/codesearch/testdata/query-file-index-missing @@ -1,3 +1,3 @@ file-index some-non-existent-file.c -not found +the file does not exist diff --git a/pkg/codesearch/testdata/query-read-file-dir b/pkg/codesearch/testdata/query-read-file-dir index 210a326cd..b5008e010 100644 --- a/pkg/codesearch/testdata/query-read-file-dir +++ b/pkg/codesearch/testdata/query-read-file-dir @@ -1,3 +1,3 @@ read-file mm -not found +the file is a directory diff --git a/pkg/codesearch/testdata/query-read-file-escaping b/pkg/codesearch/testdata/query-read-file-escaping index fca2abf6a..ba5d544c2 100644 --- a/pkg/codesearch/testdata/query-read-file-escaping +++ b/pkg/codesearch/testdata/query-read-file-escaping @@ -1,3 +1,3 @@ read-file mm/../../codesearch.go -not found +path is outside of the source tree diff --git a/pkg/codesearch/testdata/query-read-file-missing b/pkg/codesearch/testdata/query-read-file-missing index ac7bead8d..adaff1616 100644 --- a/pkg/codesearch/testdata/query-read-file-missing +++ b/pkg/codesearch/testdata/query-read-file-missing @@ -1,3 +1,3 @@ read-file file-that-does-not-exist.c -not found +the file does not exist |
