diff options
| -rw-r--r-- | pkg/aflow/func_tool.go | 13 | ||||
| -rw-r--r-- | pkg/aflow/func_tool_test.go | 111 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 21 | ||||
| -rw-r--r-- | pkg/aflow/schema.go | 8 | ||||
| -rw-r--r-- | pkg/aflow/tool/codesearcher/codesearcher.go | 39 | ||||
| -rw-r--r-- | pkg/codesearch/codesearch.go | 67 | ||||
| -rw-r--r-- | pkg/codesearch/codesearch_test.go | 3 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-def-source-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-escaping | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-file | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-dir-index-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-file-index-missing | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-dir | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-escaping | 2 | ||||
| -rw-r--r-- | pkg/codesearch/testdata/query-read-file-missing | 2 |
15 files changed, 195 insertions, 83 deletions
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go index 48b47b1e5..dde359485 100644 --- a/pkg/aflow/func_tool.go +++ b/pkg/aflow/func_tool.go @@ -4,6 +4,8 @@ package aflow import ( + "errors" + "github.com/google/syzkaller/pkg/aflow/trajectory" "google.golang.org/genai" ) @@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, } } +// BadCallError creates an error that means that LLM made a bad tool call, +// the provided message will be returned to the LLM as an error, +// instead of failing the whole workflow. +func BadCallError(message string) error { + return &badCallError{errors.New(message)} +} + +type badCallError struct { + error +} + type funcTool[State, Args, Results any] struct { Name string Description string diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go new file mode 100644 index 000000000..429566dbe --- /dev/null +++ b/pkg/aflow/func_tool_test.go @@ -0,0 +1,111 @@ +// Copyright 2026 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestToolErrors(t *testing.T) { + type flowOutputs struct { + Reply string + } + type toolArgs struct { + CallError bool `jsonschema:"call error"` + } + flows := make(map[string]*Flow) + err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{ + { + Root: &LLMAgent{ + Name: "smarty", + Model: "model", + Reply: "Reply", + Temperature: 0, + Instruction: "Do something!", + Prompt: "Prompt", + Tools: []Tool{ + NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) { + if args.CallError { + return struct{}{}, BadCallError("you are wrong") + } + return struct{}{}, errors.New("hard error") + }, "tool 1 description"), + }, + }, + }, + }) + require.NoError(t, err) + replySeq := 0 + stub := &stubContext{ + // nolint:dupl + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + replySeq++ + switch replySeq { + case 1: + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": true, + }, + }, + }, + }}}}}, nil + case 2: + assert.Equal(t, req[2], &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id0", + Name: "faulty", + Response: map[string]any{ + "error": "you are wrong", + }, + }, + }}}) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "faulty", + Args: map[string]any{ + "CallError": false, + }, + }, + }, + }}}}}, nil + default: + t.Fatal("unexpected LLM calls") + return nil, nil + } + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + onEvent := func(span *trajectory.Span) error { return nil } + _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent) + require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]") +} diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index b473c9e7a..5934bf9bd 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai }, }) } + appendError := func(message string) { + appendPart(map[string]any{"error": message}) + } tool := tools[call.Name] if tool == nil { - appendPart(map[string]any{ - "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name), - }) + appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name)) continue } results, err := tool.execute(ctx, call.Args) if err != nil { - if argsErr := new(toolArgsError); errors.As(err, &argsErr) { - // LLM provided wrong arguments to the tool, - // return the error back to the LLM instead of failing. - appendPart(map[string]any{ - "error": err.Error(), - }) + // LLM provided wrong arguments to the tool, + // or the tool returned error message to the LLM. + // Return the error back to the LLM instead of failing. + if callErr := new(badCallError); errors.As(err, &callErr) { + appendError(err.Error()) continue } - return nil, nil, err + return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v", + call.Name, err, call.Args) } appendPart(results) if a.Outputs != nil && tool == a.Outputs.tool { diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go index 2b2d77f76..0b0eb8994 100644 --- a/pkg/aflow/schema.go +++ b/pkg/aflow/schema.go @@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { f, ok := m[name] if !ok { if tool { - return val, &toolArgsError{fmt.Errorf("missing argument %q", name)} + return val, BadCallError(fmt.Sprintf("missing argument %q", name)) } else { return val, fmt.Errorf("field %q is not present when converting map to %T", name, val) } @@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { field.Set(reflect.ValueOf(f)) } else { if tool { - return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v", - name, f, field.Type().Name())} + return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v", + name, f, field.Type().Name())) } else { return val, fmt.Errorf("field %q has wrong type: got %T, want %v", name, f, field.Type().Name()) @@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) { return val, nil } -type toolArgsError struct{ error } - // foreachField iterates over all public fields of the struct provided in data. func foreachField(data any) iter.Seq2[string, reflect.Value] { return func(yield func(string, reflect.Value) bool) { diff --git a/pkg/aflow/tool/codesearcher/codesearcher.go b/pkg/aflow/tool/codesearcher/codesearcher.go index 34db81b80..922e32569 100644 --- a/pkg/aflow/tool/codesearcher/codesearcher.go +++ b/pkg/aflow/tool/codesearcher/codesearcher.go @@ -68,7 +68,6 @@ type dirIndexArgs struct { } type dirIndexResult struct { - Missing bool `jsonschema:"Set to true if the requested directory does not exist."` Subdirs []string `jsonschema:"List of direct subdirectories."` Files []string `jsonschema:"List of source files."` } @@ -78,7 +77,6 @@ type readFileArgs struct { } type readFileResult struct { - Missing bool `jsonschema:"Set to true if the requested file does not exist."` Contents string `jsonschema:"File contents."` } @@ -87,7 +85,6 @@ type fileIndexArgs struct { } type fileIndexResult struct { - Missing bool `jsonschema:"Set to true if the file with the given name does not exist."` Entities []indexEntity `jsonschema:"List of entites defined in the file."` } @@ -103,7 +100,6 @@ type defCommentArgs struct { } type defCommentResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` Kind string `jsonschema:"Kind of the entity: function, struct, variable."` Comment string `jsonschema:"Source comment for the entity."` } @@ -117,7 +113,6 @@ type defSourceArgs struct { // nolint: lll type defSourceResult struct { - Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."` SourceFile string `jsonschema:"Source file path where the entity is defined."` SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."` } @@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) { } func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) { - ok, subdirs, files, err := state.Index.DirIndex(args.Dir) - res := dirIndexResult{ - Missing: !ok, + subdirs, files, err := state.Index.DirIndex(args.Dir) + return dirIndexResult{ Subdirs: subdirs, Files: files, - } - return res, err + }, err } func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) { - ok, contents, err := state.Index.ReadFile(args.File) - res := readFileResult{ - Missing: !ok, + contents, err := state.Index.ReadFile(args.File) + return readFileResult{ Contents: contents, - } - return res, err + }, err } func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) { - ok, entities, err := state.Index.FileIndex(args.SourceFile) - res := fileIndexResult{ - Missing: !ok, - } + entities, err := state.Index.FileIndex(args.SourceFile) + res := fileIndexResult{} for _, ent := range entities { res.Entities = append(res.Entities, indexEntity{ Kind: ent.Kind, @@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) { info, err := state.Index.DefinitionComment(args.SourceFile, args.Name) - if err != nil || info == nil { - return defCommentResult{ - Missing: info == nil, - }, err + if err != nil { + return defCommentResult{}, err } return defCommentResult{ Kind: info.Kind, @@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) { info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines) - if err != nil || info == nil { - return defSourceResult{ - Missing: info == nil, - }, err + if err != nil { + return defSourceResult{}, err } return defSourceResult{ SourceFile: info.File, diff --git a/pkg/codesearch/codesearch.go b/pkg/codesearch/codesearch.go index 96ee1c696..6eab749db 100644 --- a/pkg/codesearch/codesearch.go +++ b/pkg/codesearch/codesearch.go @@ -13,6 +13,7 @@ import ( "strings" "syscall" + "github.com/google/syzkaller/pkg/aflow" "github.com/google/syzkaller/pkg/osutil" ) @@ -30,9 +31,9 @@ type Command struct { // Commands are used to run unit tests and for the syz-codesearch tool. var Commands = []Command{ {"dir-index", 1, func(index *Index, args []string) (string, error) { - ok, subdirs, files, err := index.DirIndex(args[0]) - if err != nil || !ok { - return notFound, err + subdirs, files, err := index.DirIndex(args[0]) + if err != nil { + return "", err } b := new(strings.Builder) fmt.Fprintf(b, "directory %v subdirs:\n", args[0]) @@ -46,16 +47,12 @@ var Commands = []Command{ return b.String(), nil }}, {"read-file", 1, func(index *Index, args []string) (string, error) { - ok, contents, err := index.ReadFile(args[0]) - if err != nil || !ok { - return notFound, err - } - return contents, nil + return index.ReadFile(args[0]) }}, {"file-index", 1, func(index *Index, args []string) (string, error) { - ok, entities, err := index.FileIndex(args[0]) - if err != nil || !ok { - return notFound, err + entities, err := index.FileIndex(args[0]) + if err != nil { + return "", err } b := new(strings.Builder) fmt.Fprintf(b, "file %v defines the following entities:\n\n", args[0]) @@ -66,8 +63,8 @@ var Commands = []Command{ }}, {"def-comment", 2, func(index *Index, args []string) (string, error) { info, err := index.DefinitionComment(args[0], args[1]) - if err != nil || info == nil { - return notFound, err + if err != nil { + return "", err } if info.Body == "" { return fmt.Sprintf("%v %v is defined in %v and is not commented\n", @@ -78,8 +75,8 @@ var Commands = []Command{ }}, {"def-source", 3, func(index *Index, args []string) (string, error) { info, err := index.DefinitionSource(args[0], args[1], args[2] == "yes") - if err != nil || info == nil { - return notFound, err + if err != nil { + return "", err } return fmt.Sprintf("%v %v is defined in %v:\n\n%v", info.Kind, args[1], info.File, info.Body), nil }}, @@ -87,8 +84,6 @@ var Commands = []Command{ var SourceExtensions = map[string]bool{".c": true, ".h": true, ".S": true, ".rs": true} -const notFound = "not found\n" - func NewIndex(databaseFile string, srcDirs []string) (*Index, error) { db, err := osutil.ReadJSON[*Database](databaseFile) if err != nil { @@ -118,16 +113,16 @@ type Entity struct { Name string } -func (index *Index) DirIndex(dir string) (bool, []string, []string, error) { +func (index *Index) DirIndex(dir string) ([]string, []string, error) { if err := escaping(dir); err != nil { - return false, nil, nil, nil + return nil, nil, err } exists := false var subdirs, files []string for _, root := range index.srcDirs { exists1, subdirs1, files1, err := dirIndex(root, dir) if err != nil { - return false, nil, nil, err + return nil, nil, err } if exists1 { exists = true @@ -135,18 +130,21 @@ func (index *Index) DirIndex(dir string) (bool, []string, []string, error) { subdirs = append(subdirs, subdirs1...) files = append(files, files1...) } + if !exists { + return nil, nil, aflow.BadCallError("the directory does not exist") + } slices.Sort(subdirs) slices.Sort(files) // Dedup dirs across src/build trees, // also dedup files, but hopefully there are no duplicates. subdirs = slices.Compact(subdirs) files = slices.Compact(files) - return exists, subdirs, files, nil + return subdirs, files, nil } -func (index *Index) ReadFile(file string) (bool, string, error) { +func (index *Index) ReadFile(file string) (string, error) { if err := escaping(file); err != nil { - return false, "", nil + return "", err } for _, dir := range index.srcDirs { data, err := os.ReadFile(filepath.Join(dir, file)) @@ -156,16 +154,21 @@ func (index *Index) ReadFile(file string) (bool, string, error) { } var errno syscall.Errno if errors.As(err, &errno) && errno == syscall.EISDIR { - return false, "", nil + return "", aflow.BadCallError("the file is a directory") } - return false, "", err + return "", err } - return true, string(data), nil + return string(data), nil } - return false, "", nil + return "", aflow.BadCallError("the file does not exist") } -func (index *Index) FileIndex(file string) (bool, []Entity, error) { +func (index *Index) FileIndex(file string) ([]Entity, error) { + file = filepath.Clean(file) + // This allows to distinguish missing files from files that don't define anything. + if _, err := index.ReadFile(file); err != nil { + return nil, err + } var entities []Entity for _, def := range index.db.Definitions { if def.Body.File == file { @@ -175,7 +178,7 @@ func (index *Index) FileIndex(file string) (bool, []Entity, error) { }) } } - return len(entities) != 0, entities, nil + return entities, nil } type EntityInfo struct { @@ -195,7 +198,7 @@ func (index *Index) DefinitionSource(contextFile, name string, includeLines bool func (index *Index) definitionSource(contextFile, name string, comment, includeLines bool) (*EntityInfo, error) { def := index.findDefinition(contextFile, name) if def == nil { - return nil, nil + return nil, aflow.BadCallError("requested entity does not exist") } lineRange := def.Body if comment { @@ -266,7 +269,7 @@ func formatSourceFile(file string, start, end int, includeLines bool) (string, e func escaping(path string) error { if strings.Contains(filepath.Clean(path), "..") { - return errors.New("path is outside of the source tree") + return aflow.BadCallError("path is outside of the source tree") } return nil } @@ -280,7 +283,7 @@ func dirIndex(root, subdir string) (bool, []string, []string, error) { } var errno syscall.Errno if errors.As(err, &errno) && errno == syscall.ENOTDIR { - err = nil + err = aflow.BadCallError("the path is not a directory") } return false, nil, nil, err } diff --git a/pkg/codesearch/codesearch_test.go b/pkg/codesearch/codesearch_test.go index 7af509294..1f353c804 100644 --- a/pkg/codesearch/codesearch_test.go +++ b/pkg/codesearch/codesearch_test.go @@ -53,7 +53,8 @@ func testCommand(t *testing.T, index *Index, covered map[string]bool, file strin } result, err := index.Command(args[0], args[1:]) if err != nil { - t.Fatal(err) + // This is supposed to test aflow.BadCallError messages. + result = err.Error() + "\n" } got := append([]byte(strings.Join(args, " ")+"\n\n"), result...) tooltest.CompareGoldenData(t, file, got) diff --git a/pkg/codesearch/testdata/query-def-source-missing b/pkg/codesearch/testdata/query-def-source-missing index 0b60003c7..dcb8c4d75 100644 --- a/pkg/codesearch/testdata/query-def-source-missing +++ b/pkg/codesearch/testdata/query-def-source-missing @@ -1,3 +1,3 @@ def-source source0.c some_non_existent_function no -not found +requested entity does not exist diff --git a/pkg/codesearch/testdata/query-dir-index-escaping b/pkg/codesearch/testdata/query-dir-index-escaping index fd7b55ff0..f57f0eb82 100644 --- a/pkg/codesearch/testdata/query-dir-index-escaping +++ b/pkg/codesearch/testdata/query-dir-index-escaping @@ -1,3 +1,3 @@ dir-index mm/../../ -not found +path is outside of the source tree diff --git a/pkg/codesearch/testdata/query-dir-index-file b/pkg/codesearch/testdata/query-dir-index-file index eecd67d67..a519fad18 100644 --- a/pkg/codesearch/testdata/query-dir-index-file +++ b/pkg/codesearch/testdata/query-dir-index-file @@ -1,3 +1,3 @@ dir-index source0.c -not found +the path is not a directory diff --git a/pkg/codesearch/testdata/query-dir-index-missing b/pkg/codesearch/testdata/query-dir-index-missing index e028d1be1..102ddbdab 100644 --- a/pkg/codesearch/testdata/query-dir-index-missing +++ b/pkg/codesearch/testdata/query-dir-index-missing @@ -1,3 +1,3 @@ dir-index mm/foobar -not found +the directory does not exist diff --git a/pkg/codesearch/testdata/query-file-index-missing b/pkg/codesearch/testdata/query-file-index-missing index 1be486378..803a66377 100644 --- a/pkg/codesearch/testdata/query-file-index-missing +++ b/pkg/codesearch/testdata/query-file-index-missing @@ -1,3 +1,3 @@ file-index some-non-existent-file.c -not found +the file does not exist diff --git a/pkg/codesearch/testdata/query-read-file-dir b/pkg/codesearch/testdata/query-read-file-dir index 210a326cd..b5008e010 100644 --- a/pkg/codesearch/testdata/query-read-file-dir +++ b/pkg/codesearch/testdata/query-read-file-dir @@ -1,3 +1,3 @@ read-file mm -not found +the file is a directory diff --git a/pkg/codesearch/testdata/query-read-file-escaping b/pkg/codesearch/testdata/query-read-file-escaping index fca2abf6a..ba5d544c2 100644 --- a/pkg/codesearch/testdata/query-read-file-escaping +++ b/pkg/codesearch/testdata/query-read-file-escaping @@ -1,3 +1,3 @@ read-file mm/../../codesearch.go -not found +path is outside of the source tree diff --git a/pkg/codesearch/testdata/query-read-file-missing b/pkg/codesearch/testdata/query-read-file-missing index ac7bead8d..adaff1616 100644 --- a/pkg/codesearch/testdata/query-read-file-missing +++ b/pkg/codesearch/testdata/query-read-file-missing @@ -1,3 +1,3 @@ read-file file-that-does-not-exist.c -not found +the file does not exist |
