aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-20 15:03:18 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-21 13:38:45 +0000
commitc9820ab0fe2dce0914ff01bcaf3829ca82150eb2 (patch)
tree350f597291bec42753754dc1948acf8166c5ba68 /pkg
parentf1d5c3ecdec0b86db1df926ccd3553157988690d (diff)
pkg/aflow: cache LLM requests
Using cached replies is faster, cheaper, and more reliable. It is especially handy during development, when the same workflows are retried lots of times with some changes.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/execute.go3
-rw-r--r--pkg/aflow/flow_test.go60
-rw-r--r--pkg/aflow/func_tool_test.go3
-rw-r--r--pkg/aflow/llm_agent.go43
4 files changed, 64 insertions, 45 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 19f8d3aec..3e1a6a112 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -178,6 +178,9 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
if thinking {
+ // Don't alter the original object (that may affect request caching).
+ cfgCopy := *cfg
+ cfg = &cfgCopy
cfg.ThinkingConfig = &genai.ThinkingConfig{
// We capture them in the trajectory for analysis.
IncludeThoughts: true,
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index ff1fe9dfa..5294bc9ea 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -225,7 +225,7 @@ func TestWorkflow(t *testing.T) {
ID: "id1",
Name: "tool2",
Args: map[string]any{
- "ArgBaz": 101,
+ "ArgBaz": 101.0,
},
},
},
@@ -265,7 +265,7 @@ func TestWorkflow(t *testing.T) {
ID: "id2",
Name: "set-results",
Args: map[string]any{
- "AgentFoo": 42,
+ "AgentFoo": 42.0,
"AgentBar": "agent-bar",
},
},
@@ -294,8 +294,6 @@ func TestWorkflow(t *testing.T) {
},
}}
- // dupl considers makeSwarmReply/makeSwarmResp duplicates
- // nolint:dupl
makeSwarmReply := func(index int) *genai.Content {
return &genai.Content{
Role: string(genai.RoleModel),
@@ -305,14 +303,13 @@ func TestWorkflow(t *testing.T) {
ID: fmt.Sprintf("id%v", index),
Name: "set-results",
Args: map[string]any{
- "SwarmInt": index,
+ "SwarmInt": float64(index),
"SwarmStr": fmt.Sprintf("swarm%v", index),
},
},
},
}}
}
- // nolint:dupl // dupl considers makeSwarmReply/makeSwarmResp duplicates
makeSwarmResp := func(index int) *genai.Content {
return &genai.Content{
Role: string(genai.RoleUser),
@@ -410,7 +407,7 @@ func TestWorkflow(t *testing.T) {
}
ctx := context.WithValue(context.Background(), stubContextKey, stub)
workdir := t.TempDir()
- cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
require.NoError(t, err)
// nolint: dupl
expected := []*trajectory.Span{
@@ -502,7 +499,7 @@ func TestWorkflow(t *testing.T) {
Name: "tool2",
Started: startTime.Add(9 * time.Second),
Args: map[string]any{
- "ArgBaz": 101,
+ "ArgBaz": 101.0,
},
},
{
@@ -513,7 +510,7 @@ func TestWorkflow(t *testing.T) {
Started: startTime.Add(9 * time.Second),
Finished: startTime.Add(10 * time.Second),
Args: map[string]any{
- "ArgBaz": 101,
+ "ArgBaz": 101.0,
},
Results: map[string]any{
"ResBaz": 300,
@@ -545,7 +542,7 @@ func TestWorkflow(t *testing.T) {
Started: startTime.Add(13 * time.Second),
Args: map[string]any{
"AgentBar": "agent-bar",
- "AgentFoo": 42,
+ "AgentFoo": 42.0,
},
},
{
@@ -557,7 +554,7 @@ func TestWorkflow(t *testing.T) {
Finished: startTime.Add(14 * time.Second),
Args: map[string]any{
"AgentBar": "agent-bar",
- "AgentFoo": 42,
+ "AgentFoo": 42.0,
},
Results: map[string]any{
"AgentBar": "agent-bar",
@@ -656,7 +653,7 @@ func TestWorkflow(t *testing.T) {
Name: "set-results",
Started: startTime.Add(24 * time.Second),
Args: map[string]any{
- "SwarmInt": 1,
+ "SwarmInt": 1.0,
"SwarmStr": "swarm1",
},
},
@@ -668,7 +665,7 @@ func TestWorkflow(t *testing.T) {
Started: startTime.Add(24 * time.Second),
Finished: startTime.Add(25 * time.Second),
Args: map[string]any{
- "SwarmInt": 1,
+ "SwarmInt": 1.0,
"SwarmStr": "swarm1",
},
Results: map[string]any{
@@ -743,7 +740,7 @@ func TestWorkflow(t *testing.T) {
Name: "set-results",
Started: startTime.Add(32 * time.Second),
Args: map[string]any{
- "SwarmInt": 2,
+ "SwarmInt": 2.0,
"SwarmStr": "swarm2",
},
},
@@ -755,7 +752,7 @@ func TestWorkflow(t *testing.T) {
Started: startTime.Add(32 * time.Second),
Finished: startTime.Add(33 * time.Second),
Args: map[string]any{
- "SwarmInt": 2,
+ "SwarmInt": 2.0,
"SwarmStr": "swarm2",
},
Results: map[string]any{
@@ -1023,7 +1020,6 @@ func TestToolMisbehavior(t *testing.T) {
FunctionCall: &genai.FunctionCall{
ID: "id3",
Name: "tool2",
- Args: map[string]any{},
},
},
// Excessive argument.
@@ -1032,8 +1028,8 @@ func TestToolMisbehavior(t *testing.T) {
ID: "id4",
Name: "tool2",
Args: map[string]any{
- "Tool2Arg": 0,
- "Tool2Arg2": 100,
+ "Tool2Arg": 0.0,
+ "Tool2Arg2": 100.0,
},
},
},
@@ -1172,7 +1168,7 @@ func TestToolMisbehavior(t *testing.T) {
}
ctx := context.WithValue(context.Background(), stubContextKey, stub)
workdir := t.TempDir()
- cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
require.NoError(t, err)
expected := []*trajectory.Span{
{
@@ -1247,14 +1243,12 @@ func TestToolMisbehavior(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanTool,
Name: "tool2",
- Args: map[string]any{},
},
{
Seq: 5,
Nesting: 2,
Type: trajectory.SpanTool,
Name: "tool2",
- Args: map[string]any{},
Error: "missing argument \"Tool2Arg\"",
},
{
@@ -1263,8 +1257,8 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "tool2",
Args: map[string]any{
- "Tool2Arg": 0,
- "Tool2Arg2": 100,
+ "Tool2Arg": 0.0,
+ "Tool2Arg2": 100.0,
},
},
{
@@ -1273,8 +1267,8 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "tool2",
Args: map[string]any{
- "Tool2Arg": 0,
- "Tool2Arg2": 100,
+ "Tool2Arg": 0.0,
+ "Tool2Arg2": 100.0,
},
Results: map[string]any{
"Result": 42,
@@ -1286,7 +1280,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "tool3",
Args: map[string]any{
- "Arg": 0,
+ "Arg": 0.0,
},
},
{
@@ -1295,7 +1289,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "tool3",
Args: map[string]any{
- "Arg": 0,
+ "Arg": 0.0,
},
Error: "tool \"tool3\" does not exist, please correct the name",
},
@@ -1305,7 +1299,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "WrongArg": 0,
+ "WrongArg": 0.0,
},
},
{
@@ -1314,7 +1308,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "WrongArg": 0,
+ "WrongArg": 0.0,
},
Error: "missing argument \"AdditionalOutput\"",
},
@@ -1352,7 +1346,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "AdditionalOutput": 1,
+ "AdditionalOutput": 1.0,
},
},
{
@@ -1361,7 +1355,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "AdditionalOutput": 1,
+ "AdditionalOutput": 1.0,
},
Results: map[string]any{
"AdditionalOutput": 1,
@@ -1373,7 +1367,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "AdditionalOutput": 2,
+ "AdditionalOutput": 2.0,
},
},
{
@@ -1382,7 +1376,7 @@ func TestToolMisbehavior(t *testing.T) {
Type: trajectory.SpanTool,
Name: "set-results",
Args: map[string]any{
- "AdditionalOutput": 2,
+ "AdditionalOutput": 2.0,
},
Results: map[string]any{
"AdditionalOutput": 2,
diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go
index 429566dbe..2076e0bb9 100644
--- a/pkg/aflow/func_tool_test.go
+++ b/pkg/aflow/func_tool_test.go
@@ -8,6 +8,7 @@ import (
"errors"
"path/filepath"
"testing"
+ "time"
"github.com/google/syzkaller/pkg/aflow/trajectory"
"github.com/stretchr/testify/assert"
@@ -103,7 +104,7 @@ func TestToolErrors(t *testing.T) {
}
ctx := context.WithValue(context.Background(), stubContextKey, stub)
workdir := t.TempDir()
- cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now)
require.NoError(t, err)
onEvent := func(span *trajectory.Span) error { return nil }
_, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent)
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index e5391dfcf..406947e25 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/google/syzkaller/pkg/aflow/trajectory"
+ "github.com/google/syzkaller/pkg/hash"
"google.golang.org/genai"
)
@@ -109,7 +110,7 @@ type llmOutputs struct {
func (a *LLMAgent) execute(ctx *Context) error {
if a.Candidates <= 1 {
- reply, outputs, err := a.executeOne(ctx)
+ reply, outputs, err := a.executeOne(ctx, 0)
if err != nil {
return err
}
@@ -132,7 +133,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error {
var replies []string
allOutputs := map[string]any{}
for candidate := 0; candidate < a.Candidates; candidate++ {
- reply, outputs, err := a.executeOne(ctx)
+ reply, outputs, err := a.executeOne(ctx, candidate)
if err != nil {
return err
}
@@ -146,7 +147,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error {
return nil
}
-func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
+func (a *LLMAgent) executeOne(ctx *Context, candidate int) (string, map[string]any, error) {
cfg, instruction, tools := a.config(ctx)
span := &trajectory.Span{
Type: trajectory.SpanAgent,
@@ -158,7 +159,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
if err := ctx.startSpan(span); err != nil {
return "", nil, err
}
- reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt)
+ reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt, candidate)
if err == nil {
span.Reply = reply
span.Results = outputs
@@ -166,8 +167,8 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
return reply, outputs, ctx.finishSpan(span, err)
}
-func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) (
- string, map[string]any, error) {
+func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool,
+ prompt string, candidate int) (string, map[string]any, error) {
var outputs map[string]any
req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
for {
@@ -179,7 +180,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err := ctx.startSpan(reqSpan); err != nil {
return "", nil, err
}
- resp, err := a.generateContent(ctx, cfg, req)
+ resp, err := a.generateContent(ctx, cfg, req, candidate)
if err != nil {
return "", nil, ctx.finishSpan(reqSpan, err)
}
@@ -328,11 +329,10 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
}
func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig,
- req []*genai.Content) (*genai.GenerateContentResponse, error) {
+ req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
backoff := time.Second
- model := ctx.modelName(a.Model)
for try := 0; ; try++ {
- resp, err := ctx.generateContent(model, cfg, req)
+ resp, err := a.generateContentCached(ctx, cfg, req, candidate)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
@@ -343,12 +343,33 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
strings.Contains(apiErr.Message, "Quota exceeded for metric") &&
strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
- return resp, &modelQuotaError{model}
+ return resp, &modelQuotaError{ctx.modelName(a.Model)}
}
return resp, err
}
}
+func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig,
+ req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
+ type Cached struct {
+ Config *genai.GenerateContentConfig
+ Request []*genai.Content
+ Reply *genai.GenerateContentResponse
+ }
+ model := ctx.modelName(a.Model)
+ desc := fmt.Sprintf("model %v, config hash %v, request hash %v, candidate %v",
+ model, hash.String(cfg), hash.String(req), candidate)
+ cached, err := CacheObject(ctx, "llm", desc, func() (Cached, error) {
+ resp, err := ctx.generateContent(model, cfg, req)
+ return Cached{
+ Config: cfg,
+ Request: req,
+ Reply: resp,
+ }, err
+ })
+ return cached.Reply, err
+}
+
func (a *LLMAgent) verify(vctx *verifyContext) {
vctx.requireNotEmpty(a.Name, "Name", a.Name)
vctx.requireNotEmpty(a.Name, "Model", a.Model)