1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
// Copyright 2026 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package aflow
import (
"context"
"encoding/json"
"flag"
"os"
"path/filepath"
"reflect"
"slices"
"testing"
"time"
"github.com/google/syzkaller/pkg/aflow/trajectory"
"github.com/google/syzkaller/pkg/osutil"
"github.com/stretchr/testify/require"
"google.golang.org/genai"
)
// flagUpdate is the -update command-line switch. When it is set, golden
// testdata files are rewritten from the actual execution instead of only
// being compared against it.
var (
	flagUpdate = flag.Bool("update", false, "update golden test files to match the actual execution")
)
// testFlow executes the provided test workflow by returning LLM replies from llmReplies.
// The result can be either a map[string]any with Outputs fields, or a string holding
// the expected error message, if an error is expected as the result of the execution.
// llmReplies objects can be either *genai.Part, []*genai.Part, an error, or a
// func(string, *genai.GenerateContentConfig, []*genai.Content) (*genai.GenerateContentResponse, error)
// callback. A callback reply is never consumed: once reached, it serves every
// remaining LLM call, and any replies after it are allowed to stay unused.
// Requests sent to LLM are compared against "testdata/TestName.llm.json" file.
// Resulting trajectory is compared against "testdata/TestName.trajectory.json" file.
// If -update flag is provided, the golden testdata files are updated to match the actual execution.
func testFlow[Inputs, Outputs any](t *testing.T, inputs map[string]any, result any, root Action,
	llmReplies []any, consts map[string]any) {
	flows := make(map[string]*Flow)
	err := register[Inputs, Outputs]("test", "description", flows, []*Flow{{
		Consts: consts,
		Root:   root,
	}})
	require.NoError(t, err)

	type llmRequest struct {
		Model   string
		Config  *genai.GenerateContentConfig `json:",omitempty"`
		Request []*genai.Content
	}
	var requests []llmRequest
	var stubTime time.Time
	var lastConfig genai.GenerateContentConfig
	// Set once a callback reply takes over LLM handling (see doc comment).
	generateContentStub := false
	stub := &stubContext{
		// Deterministic clock: each call advances by exactly one second.
		timeNow: func() time.Time {
			stubTime = stubTime.Add(time.Second)
			return stubTime
		},
		generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
			*genai.GenerateContentResponse, error) {
			// Copy config and req slices, so that future changes to these objects
			// don't affect our stored requests. Note: the copies are shallow.
			var storeCfg *genai.GenerateContentConfig
			if !reflect.DeepEqual(*cfg, lastConfig) {
				// Memorize config only if it has changed from the previous request.
				// Most of the time it's repeated for the same agent.
				lastConfig = *cfg
				cfgCopy := *cfg
				storeCfg = &cfgCopy
			}
			requests = append(requests, llmRequest{model, storeCfg, slices.Clone(req)})
			require.NotEmpty(t, llmReplies, "unexpected LLM call")
			reply := llmReplies[0]
			if cb, ok := reply.(func(string, *genai.GenerateContentConfig, []*genai.Content) (
				*genai.GenerateContentResponse, error)); ok {
				generateContentStub = true
				return cb(model, cfg, req)
			}
			llmReplies = llmReplies[1:]
			var parts []*genai.Part
			switch reply := reply.(type) {
			case error:
				return nil, reply
			case *genai.Part:
				parts = []*genai.Part{reply}
			case []*genai.Part:
				parts = reply
			default:
				t.Fatalf("bad LLM reply type %T", reply)
				return nil, nil
			}
			return &genai.GenerateContentResponse{
				Candidates: []*genai.Candidate{{Content: &genai.Content{
					Role:  string(genai.RoleUser),
					Parts: parts,
				}}},
			}, nil
		},
	}

	var spans []trajectory.Span
	onEvent := func(span *trajectory.Span) error {
		spans = append(spans, *span)
		return nil
	}
	ctx := context.WithValue(context.Background(), stubContextKey, stub)
	workdir := t.TempDir()
	cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
	require.NoError(t, err)
	if inputs == nil {
		inputs = map[string]any{}
	}
	got, err := flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
	// testify expects (expected, actual) argument order for readable diffs.
	switch want := result.(type) {
	case map[string]any:
		require.NoError(t, err)
		require.Equal(t, want, got)
	case string:
		require.EqualError(t, err, want)
	default:
		t.Fatalf("bad result type %T", result)
	}

	// We need to pass spans/requests via double marshal/unmarshal round-trip
	// b/c some values change during the first round-trip (int->float64, jsonschema).
	spansData, err := json.Marshal(spans)
	require.NoError(t, err)
	spans = nil
	require.NoError(t, json.Unmarshal(spansData, &spans))
	requestsData, err := json.Marshal(requests)
	require.NoError(t, err)
	requests = nil
	require.NoError(t, json.Unmarshal(requestsData, &requests))

	trajectoryFile := filepath.Join("testdata", t.Name()+".trajectory.json")
	requestsFile := filepath.Join("testdata", t.Name()+".llm.json")
	if *flagUpdate {
		require.NoError(t, osutil.WriteJSON(trajectoryFile, spans))
		if requests != nil {
			require.NoError(t, osutil.WriteJSON(requestsFile, requests))
		} else if err := os.Remove(requestsFile); err != nil && !os.IsNotExist(err) {
			// A missing file is the desired state; any other failure is real.
			require.NoError(t, err)
		}
	}
	wantSpans, err := osutil.ReadJSON[[]trajectory.Span](trajectoryFile)
	require.NoError(t, err)
	require.Equal(t, wantSpans, spans)
	if requests != nil {
		wantRequests, err := osutil.ReadJSON[[]llmRequest](requestsFile)
		require.NoError(t, err)
		require.Equal(t, wantRequests, requests)
	} else {
		require.False(t, osutil.IsExist(requestsFile),
			"golden file %v exists, but no LLM requests were made", requestsFile)
	}
	require.True(t, len(llmReplies) == 0 || generateContentStub,
		"%v LLM replies were not used", len(llmReplies))
}
// testRegistrationError registers flow under the name "test" using the given
// Inputs/Outputs type parameters. It requires that registration fails with
// exactly the expected error message.
func testRegistrationError[Inputs, Outputs any](t *testing.T, expected string, flow *Flow) {
	registered := make(map[string]*Flow)
	regErr := register[Inputs, Outputs]("test", "description", registered, []*Flow{flow})
	require.EqualError(t, regErr, expected)
}
|