aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/func_tool_test.go
blob: 429566dbef03c66c084ed9b227a08985b1bcaef9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Copyright 2026 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"context"
	"errors"
	"path/filepath"
	"testing"

	"github.com/google/syzkaller/pkg/aflow/trajectory"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/genai"
)

// TestToolErrors checks how errors returned by a FuncTool are surfaced while
// an LLMAgent flow runs.
//
// The stubbed model first calls the "faulty" tool with CallError=true. The
// tool returns a BadCallError, and the test verifies that this error is fed
// back to the model as a function response ({"error": "you are wrong"}). The
// flow is not aborted, since the model gets a second turn. On that turn the
// model calls the tool with CallError=false. The tool returns a plain error,
// and the test verifies that this error aborts Execute with a message that
// includes the tool name and the call arguments.
func TestToolErrors(t *testing.T) {
	type flowOutputs struct {
		Reply string
	}
	type toolArgs struct {
		CallError bool `jsonschema:"call error"`
	}
	flows := make(map[string]*Flow)
	err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
		{
			Root: &LLMAgent{
				Name:        "smarty",
				Model:       "model",
				Reply:       "Reply",
				Temperature: 0,
				Instruction: "Do something!",
				Prompt:      "Prompt",
				Tools: []Tool{
					NewFuncTool("faulty", func(ctx *Context, state struct{}, args toolArgs) (struct{}, error) {
						if args.CallError {
							// Recoverable: reported back to the model.
							return struct{}{}, BadCallError("you are wrong")
						}
						// Non-recoverable: expected to abort the flow.
						return struct{}{}, errors.New("hard error")
					}, "tool 1 description"),
				},
			},
		},
	})
	require.NoError(t, err)
	replySeq := 0
	stub := &stubContext{
		// nolint:dupl
		generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
			*genai.GenerateContentResponse, error) {
			replySeq++
			switch replySeq {
			case 1:
				return &genai.GenerateContentResponse{
					Candidates: []*genai.Candidate{{
						Content: &genai.Content{
							Role: string(genai.RoleModel),
							Parts: []*genai.Part{
								{
									FunctionCall: &genai.FunctionCall{
										ID:   "id0",
										Name: "faulty",
										Args: map[string]any{
											"CallError": true,
										},
									},
								},
							}}}}}, nil
			case 2:
				// NOTE(review): req[0] is presumably the user prompt and req[1]
				// the model's function call from turn 1. Confirm against the
				// LLMAgent request-building code. Guard the index so a shorter
				// history fails the test cleanly instead of panicking.
				if !assert.GreaterOrEqual(t, len(req), 3, "request history too short") {
					return nil, errors.New("unexpected request history")
				}
				assert.Equal(t, &genai.Content{
					Role: string(genai.RoleUser),
					Parts: []*genai.Part{
						{
							FunctionResponse: &genai.FunctionResponse{
								ID:   "id0",
								Name: "faulty",
								Response: map[string]any{
									"error": "you are wrong",
								},
							},
						}}}, req[2])
				return &genai.GenerateContentResponse{
					Candidates: []*genai.Candidate{{
						Content: &genai.Content{
							Role: string(genai.RoleModel),
							Parts: []*genai.Part{
								{
									FunctionCall: &genai.FunctionCall{
										ID:   "id0",
										Name: "faulty",
										Args: map[string]any{
											"CallError": false,
										},
									},
								},
							}}}}}, nil
			default:
				t.Fatal("unexpected LLM calls")
				return nil, nil
			}
		},
	}
	ctx := context.WithValue(context.Background(), stubContextKey, stub)
	workdir := t.TempDir()
	cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
	require.NoError(t, err)
	onEvent := func(span *trajectory.Span) error { return nil }
	_, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
	// EqualError fails cleanly (rather than panicking) if Execute unexpectedly
	// returns a nil error.
	require.EqualError(t, err, "tool faulty failed: error: hard error\nargs: map[CallError:false]")
}