diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 15:03:18 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-21 13:38:45 +0000 |
| commit | c9820ab0fe2dce0914ff01bcaf3829ca82150eb2 (patch) | |
| tree | 350f597291bec42753754dc1948acf8166c5ba68 /pkg | |
| parent | f1d5c3ecdec0b86db1df926ccd3553157988690d (diff) | |
pkg/aflow: cache LLM requests
Using cached replies is faster, cheaper, and more reliable.
Espcially handy during development when the same workflows
are retried lots of times with some changes.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/execute.go | 3 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 60 | ||||
| -rw-r--r-- | pkg/aflow/func_tool_test.go | 3 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 43 |
4 files changed, 64 insertions, 45 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 19f8d3aec..3e1a6a112 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -178,6 +178,9 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } if thinking { + // Don't alter the original object (that may affect request caching). + cfgCopy := *cfg + cfg = &cfgCopy cfg.ThinkingConfig = &genai.ThinkingConfig{ // We capture them in the trajectory for analysis. IncludeThoughts: true, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index ff1fe9dfa..5294bc9ea 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -225,7 +225,7 @@ func TestWorkflow(t *testing.T) { ID: "id1", Name: "tool2", Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, }, }, @@ -265,7 +265,7 @@ func TestWorkflow(t *testing.T) { ID: "id2", Name: "set-results", Args: map[string]any{ - "AgentFoo": 42, + "AgentFoo": 42.0, "AgentBar": "agent-bar", }, }, @@ -294,8 +294,6 @@ func TestWorkflow(t *testing.T) { }, }} - // dupl considers makeSwarmReply/makeSwarmResp duplicates - // nolint:dupl makeSwarmReply := func(index int) *genai.Content { return &genai.Content{ Role: string(genai.RoleModel), @@ -305,14 +303,13 @@ func TestWorkflow(t *testing.T) { ID: fmt.Sprintf("id%v", index), Name: "set-results", Args: map[string]any{ - "SwarmInt": index, + "SwarmInt": float64(index), "SwarmStr": fmt.Sprintf("swarm%v", index), }, }, }, }} } - // nolint:dupl // dupl considers makeSwarmReply/makeSwarmResp duplicates makeSwarmResp := func(index int) *genai.Content { return &genai.Content{ Role: string(genai.RoleUser), @@ -410,7 +407,7 @@ func TestWorkflow(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) // nolint: dupl expected := []*trajectory.Span{ @@ -502,7 +499,7 @@ func TestWorkflow(t *testing.T) { Name: "tool2", Started: startTime.Add(9 * time.Second), Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, }, { @@ -513,7 +510,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(9 * time.Second), Finished: startTime.Add(10 * time.Second), Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, Results: map[string]any{ "ResBaz": 300, @@ -545,7 +542,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(13 * time.Second), Args: map[string]any{ "AgentBar": "agent-bar", - "AgentFoo": 42, + "AgentFoo": 42.0, }, }, { @@ -557,7 +554,7 @@ func TestWorkflow(t *testing.T) { Finished: startTime.Add(14 * time.Second), Args: map[string]any{ "AgentBar": "agent-bar", - "AgentFoo": 42, + "AgentFoo": 42.0, }, Results: map[string]any{ "AgentBar": "agent-bar", @@ -656,7 +653,7 @@ func TestWorkflow(t *testing.T) { Name: "set-results", Started: startTime.Add(24 * time.Second), Args: map[string]any{ - "SwarmInt": 1, + "SwarmInt": 1.0, "SwarmStr": "swarm1", }, }, @@ -668,7 +665,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(24 * time.Second), Finished: startTime.Add(25 * time.Second), Args: map[string]any{ - "SwarmInt": 1, + "SwarmInt": 1.0, "SwarmStr": "swarm1", }, Results: map[string]any{ @@ -743,7 +740,7 @@ func TestWorkflow(t *testing.T) { Name: "set-results", Started: startTime.Add(32 * time.Second), Args: map[string]any{ - "SwarmInt": 2, + "SwarmInt": 2.0, "SwarmStr": "swarm2", }, }, @@ -755,7 +752,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(32 * time.Second), Finished: startTime.Add(33 * time.Second), Args: map[string]any{ - "SwarmInt": 2, + "SwarmInt": 2.0, "SwarmStr": "swarm2", }, Results: map[string]any{ @@ -1023,7 +1020,6 @@ func TestToolMisbehavior(t *testing.T) { FunctionCall: &genai.FunctionCall{ ID: "id3", Name: "tool2", - Args: map[string]any{}, }, }, // Excessive argument. @@ -1032,8 +1028,8 @@ func TestToolMisbehavior(t *testing.T) { ID: "id4", Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, }, }, @@ -1172,7 +1168,7 @@ func TestToolMisbehavior(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) expected := []*trajectory.Span{ { @@ -1247,14 +1243,12 @@ func TestToolMisbehavior(t *testing.T) { Nesting: 2, Type: trajectory.SpanTool, Name: "tool2", - Args: map[string]any{}, }, { Seq: 5, Nesting: 2, Type: trajectory.SpanTool, Name: "tool2", - Args: map[string]any{}, Error: "missing argument \"Tool2Arg\"", }, { @@ -1263,8 +1257,8 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, }, { @@ -1273,8 +1267,8 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, Results: map[string]any{ "Result": 42, @@ -1286,7 +1280,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool3", Args: map[string]any{ - "Arg": 0, + "Arg": 0.0, }, }, { @@ -1295,7 +1289,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool3", Args: map[string]any{ - "Arg": 0, + "Arg": 0.0, }, Error: "tool \"tool3\" does not exist, please correct the name", }, @@ -1305,7 +1299,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "WrongArg": 0, + "WrongArg": 0.0, }, }, { @@ -1314,7 +1308,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "WrongArg": 0, + "WrongArg": 0.0, }, Error: "missing argument \"AdditionalOutput\"", }, @@ -1352,7 +1346,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 1, + "AdditionalOutput": 1.0, }, }, { @@ -1361,7 +1355,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 1, + "AdditionalOutput": 1.0, }, Results: map[string]any{ "AdditionalOutput": 1, @@ -1373,7 +1367,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 2, + "AdditionalOutput": 2.0, }, }, { @@ -1382,7 +1376,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 2, + "AdditionalOutput": 2.0, }, Results: map[string]any{ "AdditionalOutput": 2, diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go index 429566dbe..2076e0bb9 100644 --- a/pkg/aflow/func_tool_test.go +++ b/pkg/aflow/func_tool_test.go @@ -8,6 +8,7 @@ import ( "errors" "path/filepath" "testing" + "time" "github.com/google/syzkaller/pkg/aflow/trajectory" "github.com/stretchr/testify/assert" @@ -103,7 +104,7 @@ func TestToolErrors(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) onEvent := func(span *trajectory.Span) error { return nil } _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent) diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index e5391dfcf..406947e25 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -13,6 +13,7 @@ import ( "time" "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/google/syzkaller/pkg/hash" "google.golang.org/genai" ) @@ -109,7 +110,7 @@ type llmOutputs struct { func (a *LLMAgent) execute(ctx *Context) error { if a.Candidates <= 1 { - reply, outputs, err := a.executeOne(ctx) + reply, outputs, err := a.executeOne(ctx, 0) if err != nil { return err } @@ -132,7 +133,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error { var replies []string allOutputs := map[string]any{} for candidate := 0; candidate < a.Candidates; candidate++ { - reply, outputs, err := a.executeOne(ctx) + reply, outputs, err := a.executeOne(ctx, candidate) if err != nil { return err } @@ -146,7 +147,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error { return nil } -func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { +func (a *LLMAgent) executeOne(ctx *Context, candidate int) (string, map[string]any, error) { cfg, instruction, tools := a.config(ctx) span := &trajectory.Span{ Type: trajectory.SpanAgent, @@ -158,7 +159,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { if err := ctx.startSpan(span); err != nil { return "", nil, err } - reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt) + reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt, candidate) if err == nil { span.Reply = reply span.Results = outputs @@ -166,8 +167,8 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { return reply, outputs, ctx.finishSpan(span, err) } -func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) ( - string, map[string]any, error) { +func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, + prompt string, candidate int) (string, map[string]any, error) { var outputs map[string]any req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)} for { @@ -179,7 +180,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma if err := ctx.startSpan(reqSpan); err != nil { return "", nil, err } - resp, err := a.generateContent(ctx, cfg, req) + resp, err := a.generateContent(ctx, cfg, req, candidate) if err != nil { return "", nil, ctx.finishSpan(reqSpan, err) } @@ -328,11 +329,10 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) ( } func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, - req []*genai.Content) (*genai.GenerateContentResponse, error) { + req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { backoff := time.Second - model := ctx.modelName(a.Model) for try := 0; ; try++ { - resp, err := ctx.generateContent(model, cfg, req) + resp, err := a.generateContentCached(ctx, cfg, req, candidate) var apiErr genai.APIError if err != nil && try < 100 && errors.As(err, &apiErr) && apiErr.Code == http.StatusServiceUnavailable { @@ -343,12 +343,33 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests && strings.Contains(apiErr.Message, "Quota exceeded for metric") && strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { - return resp, &modelQuotaError{model} + return resp, &modelQuotaError{ctx.modelName(a.Model)} } return resp, err } } +func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig, + req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { + type Cached struct { + Config *genai.GenerateContentConfig + Request []*genai.Content + Reply *genai.GenerateContentResponse + } + model := ctx.modelName(a.Model) + desc := fmt.Sprintf("model %v, config hash %v, request hash %v, candidate %v", + model, hash.String(cfg), hash.String(req), candidate) + cached, err := CacheObject(ctx, "llm", desc, func() (Cached, error) { + resp, err := ctx.generateContent(model, cfg, req) + return Cached{ + Config: cfg, + Request: req, + Reply: resp, + }, err + }) + return cached.Reply, err +} + func (a *LLMAgent) verify(vctx *verifyContext) { vctx.requireNotEmpty(a.Name, "Name", a.Name) vctx.requireNotEmpty(a.Name, "Model", a.Model) |
