aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow')
-rw-r--r--pkg/aflow/func_tool.go13
-rw-r--r--pkg/aflow/func_tool_test.go111
-rw-r--r--pkg/aflow/llm_agent.go21
-rw-r--r--pkg/aflow/schema.go8
-rw-r--r--pkg/aflow/tool/codesearcher/codesearcher.go39
5 files changed, 150 insertions, 42 deletions
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go
index 48b47b1e5..dde359485 100644
--- a/pkg/aflow/func_tool.go
+++ b/pkg/aflow/func_tool.go
@@ -4,6 +4,8 @@
package aflow
import (
+ "errors"
+
"github.com/google/syzkaller/pkg/aflow/trajectory"
"google.golang.org/genai"
)
@@ -24,6 +26,17 @@ func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State,
}
}
+// BadCallError creates an error that means that LLM made a bad tool call,
+// the provided message will be returned to the LLM as an error,
+// instead of failing the whole workflow.
+func BadCallError(message string) error {
+ return &badCallError{errors.New(message)}
+}
+
+type badCallError struct {
+ error
+}
+
type funcTool[State, Args, Results any] struct {
Name string
Description string
diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go
new file mode 100644
index 000000000..429566dbe
--- /dev/null
+++ b/pkg/aflow/func_tool_test.go
@@ -0,0 +1,111 @@
+// Copyright 2026 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "context"
+ "errors"
+ "path/filepath"
+ "testing"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/genai"
+)
+
+func TestToolErrors(t *testing.T) {
+ type flowOutputs struct {
+ Reply string
+ }
+ type toolArgs struct {
+ CallError bool `jsonschema:"call error"`
+ }
+ flows := make(map[string]*Flow)
+ err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Root: &LLMAgent{
+ Name: "smarty",
+ Model: "model",
+ Reply: "Reply",
+ Temperature: 0,
+ Instruction: "Do something!",
+ Prompt: "Prompt",
+ Tools: []Tool{
+ NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
+ if args.CallError {
+ return struct{}{}, BadCallError("you are wrong")
+ }
+ return struct{}{}, errors.New("hard error")
+ }, "tool 1 description"),
+ },
+ },
+ },
+ })
+ require.NoError(t, err)
+ replySeq := 0
+ stub := &stubContext{
+ // nolint:dupl
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ replySeq++
+ switch replySeq {
+ case 1:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{
+ Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id0",
+ Name: "faulty",
+ Args: map[string]any{
+ "CallError": true,
+ },
+ },
+ },
+ }}}}}, nil
+ case 2:
+ assert.Equal(t, req[2], &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id0",
+ Name: "faulty",
+ Response: map[string]any{
+ "error": "you are wrong",
+ },
+ },
+ }}})
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{
+ Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id0",
+ Name: "faulty",
+ Args: map[string]any{
+ "CallError": false,
+ },
+ },
+ },
+ }}}}}, nil
+ default:
+ t.Fatal("unexpected LLM calls")
+ return nil, nil
+ }
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ onEvent := func(span *trajectory.Span) error { return nil }
+ _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
+ require.Equal(t, err.Error(), "tool faulty failed: error: hard error\nargs: map[CallError:false]")
+}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index b473c9e7a..5934bf9bd 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -246,24 +246,25 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
},
})
}
+ appendError := func(message string) {
+ appendPart(map[string]any{"error": message})
+ }
tool := tools[call.Name]
if tool == nil {
- appendPart(map[string]any{
- "error": fmt.Sprintf("tool %q does not exist, please correct the name", call.Name),
- })
+ appendError(fmt.Sprintf("tool %q does not exist, please correct the name", call.Name))
continue
}
results, err := tool.execute(ctx, call.Args)
if err != nil {
- if argsErr := new(toolArgsError); errors.As(err, &argsErr) {
- // LLM provided wrong arguments to the tool,
- // return the error back to the LLM instead of failing.
- appendPart(map[string]any{
- "error": err.Error(),
- })
+ // LLM provided wrong arguments to the tool,
+ // or the tool returned error message to the LLM.
+ // Return the error back to the LLM instead of failing.
+ if callErr := new(badCallError); errors.As(err, &callErr) {
+ appendError(err.Error())
continue
}
- return nil, nil, err
+ return nil, nil, fmt.Errorf("tool %v failed: error: %w\nargs: %+v",
+ call.Name, err, call.Args)
}
appendPart(results)
if a.Outputs != nil && tool == a.Outputs.tool {
diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go
index 2b2d77f76..0b0eb8994 100644
--- a/pkg/aflow/schema.go
+++ b/pkg/aflow/schema.go
@@ -63,7 +63,7 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
f, ok := m[name]
if !ok {
if tool {
- return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
+ return val, BadCallError(fmt.Sprintf("missing argument %q", name))
} else {
return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
}
@@ -79,8 +79,8 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
field.Set(reflect.ValueOf(f))
} else {
if tool {
- return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
- name, f, field.Type().Name())}
+ return val, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name()))
} else {
return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
name, f, field.Type().Name())
@@ -93,8 +93,6 @@ func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
return val, nil
}
-type toolArgsError struct{ error }
-
// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
return func(yield func(string, reflect.Value) bool) {
diff --git a/pkg/aflow/tool/codesearcher/codesearcher.go b/pkg/aflow/tool/codesearcher/codesearcher.go
index 34db81b80..922e32569 100644
--- a/pkg/aflow/tool/codesearcher/codesearcher.go
+++ b/pkg/aflow/tool/codesearcher/codesearcher.go
@@ -68,7 +68,6 @@ type dirIndexArgs struct {
}
type dirIndexResult struct {
- Missing bool `jsonschema:"Set to true if the requested directory does not exist."`
Subdirs []string `jsonschema:"List of direct subdirectories."`
Files []string `jsonschema:"List of source files."`
}
@@ -78,7 +77,6 @@ type readFileArgs struct {
}
type readFileResult struct {
- Missing bool `jsonschema:"Set to true if the requested file does not exist."`
Contents string `jsonschema:"File contents."`
}
@@ -87,7 +85,6 @@ type fileIndexArgs struct {
}
type fileIndexResult struct {
- Missing bool `jsonschema:"Set to true if the file with the given name does not exist."`
Entities []indexEntity `jsonschema:"List of entites defined in the file."`
}
@@ -103,7 +100,6 @@ type defCommentArgs struct {
}
type defCommentResult struct {
- Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
Kind string `jsonschema:"Kind of the entity: function, struct, variable."`
Comment string `jsonschema:"Source comment for the entity."`
}
@@ -117,7 +113,6 @@ type defSourceArgs struct {
// nolint: lll
type defSourceResult struct {
- Missing bool `jsonschema:"Set to true if the entity with the given name does not exist."`
SourceFile string `jsonschema:"Source file path where the entity is defined."`
SourceCode string `jsonschema:"Source code of the entity definition. It is prefixed with line numbers, so that they can be referenced in other tool invocations."`
}
@@ -159,29 +154,23 @@ func prepare(ctx *aflow.Context, args prepareArgs) (prepareResult, error) {
}
func dirIndex(ctx *aflow.Context, state prepareResult, args dirIndexArgs) (dirIndexResult, error) {
- ok, subdirs, files, err := state.Index.DirIndex(args.Dir)
- res := dirIndexResult{
- Missing: !ok,
+ subdirs, files, err := state.Index.DirIndex(args.Dir)
+ return dirIndexResult{
Subdirs: subdirs,
Files: files,
- }
- return res, err
+ }, err
}
func readFile(ctx *aflow.Context, state prepareResult, args readFileArgs) (readFileResult, error) {
- ok, contents, err := state.Index.ReadFile(args.File)
- res := readFileResult{
- Missing: !ok,
+ contents, err := state.Index.ReadFile(args.File)
+ return readFileResult{
Contents: contents,
- }
- return res, err
+ }, err
}
func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fileIndexResult, error) {
- ok, entities, err := state.Index.FileIndex(args.SourceFile)
- res := fileIndexResult{
- Missing: !ok,
- }
+ entities, err := state.Index.FileIndex(args.SourceFile)
+ res := fileIndexResult{}
for _, ent := range entities {
res.Entities = append(res.Entities, indexEntity{
Kind: ent.Kind,
@@ -193,10 +182,8 @@ func fileIndex(ctx *aflow.Context, state prepareResult, args fileIndexArgs) (fil
func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentArgs) (defCommentResult, error) {
info, err := state.Index.DefinitionComment(args.SourceFile, args.Name)
- if err != nil || info == nil {
- return defCommentResult{
- Missing: info == nil,
- }, err
+ if err != nil {
+ return defCommentResult{}, err
}
return defCommentResult{
Kind: info.Kind,
@@ -206,10 +193,8 @@ func definitionComment(ctx *aflow.Context, state prepareResult, args defCommentA
func definitionSource(ctx *aflow.Context, state prepareResult, args defSourceArgs) (defSourceResult, error) {
info, err := state.Index.DefinitionSource(args.SourceFile, args.Name, args.IncludeLines)
- if err != nil || info == nil {
- return defSourceResult{
- Missing: info == nil,
- }, err
+ if err != nil {
+ return defSourceResult{}, err
}
return defSourceResult{
SourceFile: info.File,