aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/func_tool.go13
-rw-r--r--pkg/aflow/func_tool_test.go111
-rw-r--r--pkg/aflow/llm_agent.go21
-rw-r--r--pkg/aflow/schema.go8
-rw-r--r--pkg/aflow/tool/codesearcher/codesearcher.go39
-rw-r--r--pkg/codesearch/codesearch.go67
-rw-r--r--pkg/codesearch/codesearch_test.go3
-rw-r--r--pkg/codesearch/testdata/query-def-source-missing2
-rw-r--r--pkg/codesearch/testdata/query-dir-index-escaping2
-rw-r--r--pkg/codesearch/testdata/query-dir-index-file2
-rw-r--r--pkg/codesearch/testdata/query-dir-index-missing2
-rw-r--r--pkg/codesearch/testdata/query-file-index-missing2
-rw-r--r--pkg/codesearch/testdata/query-read-file-dir2
-rw-r--r--pkg/codesearch/testdata/query-read-file-escaping2
-rw-r--r--pkg/codesearch/testdata/query-read-file-missing2
15 files changed, 195 insertions, 83 deletions
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go
index 48b47b1e5..dde359485 100644
--- a/pkg/aflow/func_tool.go
+++ b/pkg/aflow/func_tool.go
@@ -4,6 +4,8 @@
package aflow
import (
+ "errors"
+
"github.com/google/syzkaller/pkg/aflow/trajectory"
"google.golang.org/genai"
)
@@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State,
}
}
+// BadCallError creates an error that means that LLM made a bad tool call,
+// the provided message will be returned to the LLM as an error,
+// instead of failing the whole workflow.
+func BadCallError(message string) error {
+ return &badCallError{errors.New(message)}
+}
+
+type badCallError struct {
+ error
+}
+
type funcTool[State, Args, Results any] struct {
Name string
Description string
diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go
new file mode 100644
index 000000000..429566dbe
--- /dev/null
+++ b/pkg/aflow/func_tool_test.go
@@ -0,0 +1,111 @@
+// Copyright 2026 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "context"
+ "errors"
+ "path/filepath"
+ "testing"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/genai"
+)
+
+func TestToolErrors(t *testing.T) {
+ type flowOutputs struct {
+ Reply string
+ }
+ type toolArgs struct {
+ CallError bool `jsonschema:"call error"`
+ }
+ flows := make(map[string]*Flow)
+ err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Root: &LLMAgent{
+ Name: "smarty",
+ Model: "model",
+ Reply: "Reply",
+ Temperature: 0,
+ Instruction: "Do something!",
+ Prompt: "Prompt",
+ Tools: []Tool{
+ NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
+ if args.CallError {
+ return struct{}{}, BadCallError("you are wrong")
+ }
+ return struct{}{}, errors.New("hard error")
+ }, "tool 1 description"),
+ },
+ },
+ },
+ })
+ require.NoError(t, err)
+ replySeq := 0
+ stub := &stubContext{
+ // nolint:dupl
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ replySeq++
+ switch replySeq {
+ case 1:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{
+ Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id0",
+ Name: "faulty",
+ Args: map[string]any{
+ "CallError": true,
+ },
+ },
+ },
+ }}}}}, nil
+ case 2:
+ assert.Equal(t, req[2], &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id0",
+ Name: "faulty",
+ Response: map[string]any{
+ "error": "you are wrong",
+ },
+ },
+ }}})
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{
+ Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id0",
+ Name: "faulty",
+ Args: map[string]any{
+ "CallError": false,
+ },
+ },
+ },
+ }}}}}, nil
+ default:
+ t.Fatal("unexpected LLM calls")
+ return nil, nil
+ }
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ onEvent := func(span *trajectory.Span) error { return nil }
+ _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
+ require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]")
+}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index b473c9e7a..5934bf9bd 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
},
})
}
+ appendError := func(message string) {
+ appendPart(map[string]any{"error": message})
+ }
tool := tools[call.Name]
if tool == nil {
- appendPart(map[string]any{
- "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name),
- })
+ appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name))
continue
}
results, err := tool.execute(ctx, call.Args)
if err != nil {
- if argsErr := new(toolArgsError); errors.As(err, &argsErr) {
- // LLM provided wrong arguments to the tool,
- // return the error back to the LLM instead of failing.
- appendPart(map[string]any{
- "error": err.Error(),
- })
+ // LLM provided wrong arguments to the tool,
+ // or the tool returned error message to the LLM.
+ // Return the error back to the LLM instead of failing.
+ if callErr := new(badCallError); errors.As(err, &callErr) {
+ appendError(err.Error())
continue
}
- return nil, nil, err
+ return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v",
+ call.Name, err, call.Args)
}
appendPart(results)
if a.Outputs != nil && tool == a.Outputs.tool {
diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go
index 2b2d77f76..0b0eb8994 100644
--- a/pkg/aflow/schema.go
+++ b/pkg/aflow/schema.go
@@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
f, ok := m[name]
if !ok {
if tool {
- return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
+ return val, BadCallError(fmt.Sprintf("missing argument %q", name))
} else {
return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
}
@@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
field.Set(reflect.ValueOf(f))
} else {
if tool {
- return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
- name, f, field.Type().Name())}
+ return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name()))
} else {
return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
name, f, field.Type().Name())
@@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
return val, nil
}
-type toolArgsError struct{ error }
-
// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
return func(yield func(string, reflect.Value) bool) {
diff --git a/pkg/aflow/tool/codesearcher/codesearcher.go b/pkg/aflow/tool/codesearcher/codesearcher.go
index 34db81b80..922e32569 100644
--- a/pkg/aflow/tool/codesearcher/codesearcher.go
+++ b/pkg/aflow/tool/codesearcher/codesearcher.go
@@ -68,7 +68,6 @@ type dirIndexArgs struct {
}
type dirIndexResult struct {
- Missing bool `jsonschema:"Set to true if the requested directory does not exist."`
Subdirs []string `jsonschema:"List of direct subdirectories."`
Files []string `jsonschema:"List of source files."`
}
@@ -78,7 +77,6 @@ type readFileArgs struct {
}
type readFileResult struct {
- Missing bool `jsonschema:"Set to true if the requested file does not exist."`
Contents string `jsonschema:"File contents."`
}
@@ -87,7 +85,6 @@ type fileIndexArgs struct {
}
type fileIndexResult struct {
- Missing bool `jsonschema:"Set to true if the file with the given name does not exist."`
Entities []indexEntity `jsonschema:"List of entites defined in the file."`
}
@@ -103,7 +100,6 @@ type defCommentArgs struct {
}
type defCommentResult struct {
- Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
Kind string `jsonschema:"Kind of the entity: function, struct, variable."`
Comment string `jsonschema:"Source comment for the entity."`
}
@@ -117,7 +113,6 @@ type defSourceArgs struct {
// nolint: lll
type defSourceResult struct {
- Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
SourceFile string `jsonschema:"Source file path where the entity is defined."`
SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."`
}
@@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) {
}
func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) {
- ok, subdirs, files, err := state.Index.DirIndex(args.Dir)
- res := dirIndexResult{
- Missing: !ok,
+ subdirs, files, err := state.Index.DirIndex(args.Dir)
+ return dirIndexResult{
Subdirs: subdirs,
Files: files,
- }
- return res, err
+ }, err
}
func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) {
- ok, contents, err := state.Index.ReadFile(args.File)
- res := readFileResult{
- Missing: !ok,
+ contents, err := state.Index.ReadFile(args.File)
+ return readFileResult{
Contents: contents,
- }
- return res, err
+ }, err
}
func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) {
- ok, entities, err := state.Index.FileIndex(args.SourceFile)
- res := fileIndexResult{
- Missing: !ok,
- }
+ entities, err := state.Index.FileIndex(args.SourceFile)
+ res := fileIndexResult{}
for _, ent := range entities {
res.Entities = append(res.Entities, indexEntity{
Kind: ent.Kind,
@@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil
func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) {
info, err := state.Index.DefinitionComment(args.SourceFile, args.Name)
- if err != nil || info == nil {
- return defCommentResult{
- Missing: info == nil,
- }, err
+ if err != nil {
+ return defCommentResult{}, err
}
return defCommentResult{
Kind: info.Kind,
@@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA
func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) {
info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines)
- if err != nil || info == nil {
- return defSourceResult{
- Missing: info == nil,
- }, err
+ if err != nil {
+ return defSourceResult{}, err
}
return defSourceResult{
SourceFile: info.File,
diff --git a/pkg/codesearch/codesearch.go b/pkg/codesearch/codesearch.go
index 96ee1c696..6eab749db 100644
--- a/pkg/codesearch/codesearch.go
+++ b/pkg/codesearch/codesearch.go
@@ -13,6 +13,7 @@ import (
"strings"
"syscall"
+ "github.com/google/syzkaller/pkg/aflow"
"github.com/google/syzkaller/pkg/osutil"
)
@@ -30,9 +31,9 @@ type Command struct {
// Commands are used to run unit tests and for the syz-codesearch tool.
var Commands = []Command{
{"dir-index", 1, func(index *Index, args []string) (string, error) {
- ok, subdirs, files, err := index.DirIndex(args[0])
- if err != nil || !ok {
- return notFound, err
+ subdirs, files, err := index.DirIndex(args[0])
+ if err != nil {
+ return "", err
}
b := new(strings.Builder)
fmt.Fprintf(b, "directory %v subdirs:\n", args[0])
@@ -46,16 +47,12 @@ var Commands = []Command{
return b.String(), nil
}},
{"read-file", 1, func(index *Index, args []string) (string, error) {
- ok, contents, err := index.ReadFile(args[0])
- if err != nil || !ok {
- return notFound, err
- }
- return contents, nil
+ return index.ReadFile(args[0])
}},
{"file-index", 1, func(index *Index, args []string) (string, error) {
- ok, entities, err := index.FileIndex(args[0])
- if err != nil || !ok {
- return notFound, err
+ entities, err := index.FileIndex(args[0])
+ if err != nil {
+ return "", err
}
b := new(strings.Builder)
fmt.Fprintf(b, "file %v defines the following entities:\n\n", args[0])
@@ -66,8 +63,8 @@ var Commands = []Command{
}},
{"def-comment", 2, func(index *Index, args []string) (string, error) {
info, err := index.DefinitionComment(args[0], args[1])
- if err != nil || info == nil {
- return notFound, err
+ if err != nil {
+ return "", err
}
if info.Body == "" {
return fmt.Sprintf("%v %v is defined in %v and is not commented\n",
@@ -78,8 +75,8 @@ var Commands = []Command{
}},
{"def-source", 3, func(index *Index, args []string) (string, error) {
info, err := index.DefinitionSource(args[0], args[1], args[2] == "yes")
- if err != nil || info == nil {
- return notFound, err
+ if err != nil {
+ return "", err
}
return fmt.Sprintf("%v %v is defined in %v:\n\n%v", info.Kind, args[1], info.File, info.Body), nil
}},
@@ -87,8 +84,6 @@ var Commands = []Command{
var SourceExtensions = map[string]bool{".c": true, ".h": true, ".S": true, ".rs": true}
-const notFound = "not found\n"
-
func NewIndex(databaseFile string, srcDirs []string) (*Index, error) {
db, err := osutil.ReadJSON[*Database](databaseFile)
if err != nil {
@@ -118,16 +113,16 @@ type Entity struct {
Name string
}
-func (index *Index) DirIndex(dir string) (bool, []string, []string, error) {
+func (index *Index) DirIndex(dir string) ([]string, []string, error) {
if err := escaping(dir); err != nil {
- return false, nil, nil, nil
+ return nil, nil, err
}
exists := false
var subdirs, files []string
for _, root := range index.srcDirs {
exists1, subdirs1, files1, err := dirIndex(root, dir)
if err != nil {
- return false, nil, nil, err
+ return nil, nil, err
}
if exists1 {
exists = true
@@ -135,18 +130,21 @@ func (index *Index) DirIndex(dir string) (bool, []string, []string, error) {
subdirs = append(subdirs, subdirs1...)
files = append(files, files1...)
}
+ if !exists {
+ return nil, nil, aflow.BadCallError("the directory does not exist")
+ }
slices.Sort(subdirs)
slices.Sort(files)
// Dedup dirs across src/build trees,
// also dedup files, but hopefully there are no duplicates.
subdirs = slices.Compact(subdirs)
files = slices.Compact(files)
- return exists, subdirs, files, nil
+ return subdirs, files, nil
}
-func (index *Index) ReadFile(file string) (bool, string, error) {
+func (index *Index) ReadFile(file string) (string, error) {
if err := escaping(file); err != nil {
- return false, "", nil
+ return "", err
}
for _, dir := range index.srcDirs {
data, err := os.ReadFile(filepath.Join(dir, file))
@@ -156,16 +154,21 @@ func (index *Index) ReadFile(file string) (bool, string, error) {
}
var errno syscall.Errno
if errors.As(err, &errno) && errno == syscall.EISDIR {
- return false, "", nil
+ return "", aflow.BadCallError("the file is a directory")
}
- return false, "", err
+ return "", err
}
- return true, string(data), nil
+ return string(data), nil
}
- return false, "", nil
+ return "", aflow.BadCallError("the file does not exist")
}
-func (index *Index) FileIndex(file string) (bool, []Entity, error) {
+func (index *Index) FileIndex(file string) ([]Entity, error) {
+ file = filepath.Clean(file)
+ // This allows to distinguish missing files from files that don't define anything.
+ if _, err := index.ReadFile(file); err != nil {
+ return nil, err
+ }
var entities []Entity
for _, def := range index.db.Definitions {
if def.Body.File == file {
@@ -175,7 +178,7 @@ func (index *Index) FileIndex(file string) (bool, []Entity, error) {
})
}
}
- return len(entities) != 0, entities, nil
+ return entities, nil
}
type EntityInfo struct {
@@ -195,7 +198,7 @@ func (index *Index) DefinitionSource(contextFile, name string, includeLines bool
func (index *Index) definitionSource(contextFile, name string, comment, includeLines bool) (*EntityInfo, error) {
def := index.findDefinition(contextFile, name)
if def == nil {
- return nil, nil
+ return nil, aflow.BadCallError("requested entity does not exist")
}
lineRange := def.Body
if comment {
@@ -266,7 +269,7 @@ func formatSourceFile(file string, start, end int, includeLines bool) (string, e
func escaping(path string) error {
if strings.Contains(filepath.Clean(path), "..") {
- return errors.New("path is outside of the source tree")
+ return aflow.BadCallError("path is outside of the source tree")
}
return nil
}
@@ -280,7 +283,7 @@ func dirIndex(root, subdir string) (bool, []string, []string, error) {
}
var errno syscall.Errno
if errors.As(err, &errno) && errno == syscall.ENOTDIR {
- err = nil
+ err = aflow.BadCallError("the path is not a directory")
}
return false, nil, nil, err
}
diff --git a/pkg/codesearch/codesearch_test.go b/pkg/codesearch/codesearch_test.go
index 7af509294..1f353c804 100644
--- a/pkg/codesearch/codesearch_test.go
+++ b/pkg/codesearch/codesearch_test.go
@@ -53,7 +53,8 @@ func testCommand(t *testing.T, index *Index, covered map[string]bool, file strin
}
result, err := index.Command(args[0], args[1:])
if err != nil {
- t.Fatal(err)
+ // This is supposed to test aflow.BadCallError messages.
+ result = err.Error() + "\n"
}
got := append([]byte(strings.Join(args, " ")+"\n\n"), result...)
tooltest.CompareGoldenData(t, file, got)
diff --git a/pkg/codesearch/testdata/query-def-source-missing b/pkg/codesearch/testdata/query-def-source-missing
index 0b60003c7..dcb8c4d75 100644
--- a/pkg/codesearch/testdata/query-def-source-missing
+++ b/pkg/codesearch/testdata/query-def-source-missing
@@ -1,3 +1,3 @@
def-source source0.c some_non_existent_function no
-not found
+requested entity does not exist
diff --git a/pkg/codesearch/testdata/query-dir-index-escaping b/pkg/codesearch/testdata/query-dir-index-escaping
index fd7b55ff0..f57f0eb82 100644
--- a/pkg/codesearch/testdata/query-dir-index-escaping
+++ b/pkg/codesearch/testdata/query-dir-index-escaping
@@ -1,3 +1,3 @@
dir-index mm/../../
-not found
+path is outside of the source tree
diff --git a/pkg/codesearch/testdata/query-dir-index-file b/pkg/codesearch/testdata/query-dir-index-file
index eecd67d67..a519fad18 100644
--- a/pkg/codesearch/testdata/query-dir-index-file
+++ b/pkg/codesearch/testdata/query-dir-index-file
@@ -1,3 +1,3 @@
dir-index source0.c
-not found
+the path is not a directory
diff --git a/pkg/codesearch/testdata/query-dir-index-missing b/pkg/codesearch/testdata/query-dir-index-missing
index e028d1be1..102ddbdab 100644
--- a/pkg/codesearch/testdata/query-dir-index-missing
+++ b/pkg/codesearch/testdata/query-dir-index-missing
@@ -1,3 +1,3 @@
dir-index mm/foobar
-not found
+the directory does not exist
diff --git a/pkg/codesearch/testdata/query-file-index-missing b/pkg/codesearch/testdata/query-file-index-missing
index 1be486378..803a66377 100644
--- a/pkg/codesearch/testdata/query-file-index-missing
+++ b/pkg/codesearch/testdata/query-file-index-missing
@@ -1,3 +1,3 @@
file-index some-non-existent-file.c
-not found
+the file does not exist
diff --git a/pkg/codesearch/testdata/query-read-file-dir b/pkg/codesearch/testdata/query-read-file-dir
index 210a326cd..b5008e010 100644
--- a/pkg/codesearch/testdata/query-read-file-dir
+++ b/pkg/codesearch/testdata/query-read-file-dir
@@ -1,3 +1,3 @@
read-file mm
-not found
+the file is a directory
diff --git a/pkg/codesearch/testdata/query-read-file-escaping b/pkg/codesearch/testdata/query-read-file-escaping
index fca2abf6a..ba5d544c2 100644
--- a/pkg/codesearch/testdata/query-read-file-escaping
+++ b/pkg/codesearch/testdata/query-read-file-escaping
@@ -1,3 +1,3 @@
read-file mm/../../codesearch.go
-not found
+path is outside of the source tree
diff --git a/pkg/codesearch/testdata/query-read-file-missing b/pkg/codesearch/testdata/query-read-file-missing
index ac7bead8d..adaff1616 100644
--- a/pkg/codesearch/testdata/query-read-file-missing
+++ b/pkg/codesearch/testdata/query-read-file-missing
@@ -1,3 +1,3 @@
read-file file-that-does-not-exist.c
-not found
+the file does not exist