From c9820ab0fe2dce0914ff01bcaf3829ca82150eb2 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Tue, 20 Jan 2026 15:03:18 +0100 Subject: pkg/aflow: cache LLM requests Using cached replies is faster, cheaper, and more reliable. Especially handy during development when the same workflows are retried lots of times with some changes. --- pkg/aflow/execute.go | 3 +++ pkg/aflow/flow_test.go | 60 ++++++++++++++++++++------------------------- pkg/aflow/func_tool_test.go | 3 ++- pkg/aflow/llm_agent.go | 43 +++++++++++++++++++++++--------- 4 files changed, 64 insertions(+), 45 deletions(-) (limited to 'pkg') diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 19f8d3aec..3e1a6a112 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -178,6 +178,9 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } if thinking { + // Don't alter the original object (that may affect request caching). + cfgCopy := *cfg + cfg = &cfgCopy cfg.ThinkingConfig = &genai.ThinkingConfig{ // We capture them in the trajectory for analysis. 
IncludeThoughts: true, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index ff1fe9dfa..5294bc9ea 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -225,7 +225,7 @@ func TestWorkflow(t *testing.T) { ID: "id1", Name: "tool2", Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, }, }, @@ -265,7 +265,7 @@ func TestWorkflow(t *testing.T) { ID: "id2", Name: "set-results", Args: map[string]any{ - "AgentFoo": 42, + "AgentFoo": 42.0, "AgentBar": "agent-bar", }, }, @@ -294,8 +294,6 @@ func TestWorkflow(t *testing.T) { }, }} - // dupl considers makeSwarmReply/makeSwarmResp duplicates - // nolint:dupl makeSwarmReply := func(index int) *genai.Content { return &genai.Content{ Role: string(genai.RoleModel), @@ -305,14 +303,13 @@ func TestWorkflow(t *testing.T) { ID: fmt.Sprintf("id%v", index), Name: "set-results", Args: map[string]any{ - "SwarmInt": index, + "SwarmInt": float64(index), "SwarmStr": fmt.Sprintf("swarm%v", index), }, }, }, }} } - // nolint:dupl // dupl considers makeSwarmReply/makeSwarmResp duplicates makeSwarmResp := func(index int) *genai.Content { return &genai.Content{ Role: string(genai.RoleUser), @@ -410,7 +407,7 @@ func TestWorkflow(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) // nolint: dupl expected := []*trajectory.Span{ @@ -502,7 +499,7 @@ func TestWorkflow(t *testing.T) { Name: "tool2", Started: startTime.Add(9 * time.Second), Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, }, { @@ -513,7 +510,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(9 * time.Second), Finished: startTime.Add(10 * time.Second), Args: map[string]any{ - "ArgBaz": 101, + "ArgBaz": 101.0, }, Results: map[string]any{ "ResBaz": 300, @@ -545,7 +542,7 @@ func TestWorkflow(t 
*testing.T) { Started: startTime.Add(13 * time.Second), Args: map[string]any{ "AgentBar": "agent-bar", - "AgentFoo": 42, + "AgentFoo": 42.0, }, }, { @@ -557,7 +554,7 @@ func TestWorkflow(t *testing.T) { Finished: startTime.Add(14 * time.Second), Args: map[string]any{ "AgentBar": "agent-bar", - "AgentFoo": 42, + "AgentFoo": 42.0, }, Results: map[string]any{ "AgentBar": "agent-bar", @@ -656,7 +653,7 @@ func TestWorkflow(t *testing.T) { Name: "set-results", Started: startTime.Add(24 * time.Second), Args: map[string]any{ - "SwarmInt": 1, + "SwarmInt": 1.0, "SwarmStr": "swarm1", }, }, @@ -668,7 +665,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(24 * time.Second), Finished: startTime.Add(25 * time.Second), Args: map[string]any{ - "SwarmInt": 1, + "SwarmInt": 1.0, "SwarmStr": "swarm1", }, Results: map[string]any{ @@ -743,7 +740,7 @@ func TestWorkflow(t *testing.T) { Name: "set-results", Started: startTime.Add(32 * time.Second), Args: map[string]any{ - "SwarmInt": 2, + "SwarmInt": 2.0, "SwarmStr": "swarm2", }, }, @@ -755,7 +752,7 @@ func TestWorkflow(t *testing.T) { Started: startTime.Add(32 * time.Second), Finished: startTime.Add(33 * time.Second), Args: map[string]any{ - "SwarmInt": 2, + "SwarmInt": 2.0, "SwarmStr": "swarm2", }, Results: map[string]any{ @@ -1023,7 +1020,6 @@ func TestToolMisbehavior(t *testing.T) { FunctionCall: &genai.FunctionCall{ ID: "id3", Name: "tool2", - Args: map[string]any{}, }, }, // Excessive argument. 
@@ -1032,8 +1028,8 @@ func TestToolMisbehavior(t *testing.T) { ID: "id4", Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, }, }, @@ -1172,7 +1168,7 @@ func TestToolMisbehavior(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) expected := []*trajectory.Span{ { @@ -1247,14 +1243,12 @@ func TestToolMisbehavior(t *testing.T) { Nesting: 2, Type: trajectory.SpanTool, Name: "tool2", - Args: map[string]any{}, }, { Seq: 5, Nesting: 2, Type: trajectory.SpanTool, Name: "tool2", - Args: map[string]any{}, Error: "missing argument \"Tool2Arg\"", }, { @@ -1263,8 +1257,8 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, }, { @@ -1273,8 +1267,8 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool2", Args: map[string]any{ - "Tool2Arg": 0, - "Tool2Arg2": 100, + "Tool2Arg": 0.0, + "Tool2Arg2": 100.0, }, Results: map[string]any{ "Result": 42, @@ -1286,7 +1280,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool3", Args: map[string]any{ - "Arg": 0, + "Arg": 0.0, }, }, { @@ -1295,7 +1289,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "tool3", Args: map[string]any{ - "Arg": 0, + "Arg": 0.0, }, Error: "tool \"tool3\" does not exist, please correct the name", }, @@ -1305,7 +1299,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "WrongArg": 0, + "WrongArg": 0.0, }, }, { @@ -1314,7 +1308,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: 
map[string]any{ - "WrongArg": 0, + "WrongArg": 0.0, }, Error: "missing argument \"AdditionalOutput\"", }, @@ -1352,7 +1346,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 1, + "AdditionalOutput": 1.0, }, }, { @@ -1361,7 +1355,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 1, + "AdditionalOutput": 1.0, }, Results: map[string]any{ "AdditionalOutput": 1, @@ -1373,7 +1367,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 2, + "AdditionalOutput": 2.0, }, }, { @@ -1382,7 +1376,7 @@ func TestToolMisbehavior(t *testing.T) { Type: trajectory.SpanTool, Name: "set-results", Args: map[string]any{ - "AdditionalOutput": 2, + "AdditionalOutput": 2.0, }, Results: map[string]any{ "AdditionalOutput": 2, diff --git a/pkg/aflow/func_tool_test.go b/pkg/aflow/func_tool_test.go index 429566dbe..2076e0bb9 100644 --- a/pkg/aflow/func_tool_test.go +++ b/pkg/aflow/func_tool_test.go @@ -8,6 +8,7 @@ import ( "errors" "path/filepath" "testing" + "time" "github.com/google/syzkaller/pkg/aflow/trajectory" "github.com/stretchr/testify/assert" @@ -103,7 +104,7 @@ func TestToolErrors(t *testing.T) { } ctx := context.WithValue(context.Background(), stubContextKey, stub) workdir := t.TempDir() - cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, time.Now) require.NoError(t, err) onEvent := func(span *trajectory.Span) error { return nil } _, err = flows["test"].Execute(ctx, "", workdir, nil, cache, onEvent) diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index e5391dfcf..406947e25 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -13,6 +13,7 @@ import ( "time" "github.com/google/syzkaller/pkg/aflow/trajectory" + 
"github.com/google/syzkaller/pkg/hash" "google.golang.org/genai" ) @@ -109,7 +110,7 @@ type llmOutputs struct { func (a *LLMAgent) execute(ctx *Context) error { if a.Candidates <= 1 { - reply, outputs, err := a.executeOne(ctx) + reply, outputs, err := a.executeOne(ctx, 0) if err != nil { return err } @@ -132,7 +133,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error { var replies []string allOutputs := map[string]any{} for candidate := 0; candidate < a.Candidates; candidate++ { - reply, outputs, err := a.executeOne(ctx) + reply, outputs, err := a.executeOne(ctx, candidate) if err != nil { return err } @@ -146,7 +147,7 @@ func (a *LLMAgent) executeMany(ctx *Context) error { return nil } -func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { +func (a *LLMAgent) executeOne(ctx *Context, candidate int) (string, map[string]any, error) { cfg, instruction, tools := a.config(ctx) span := &trajectory.Span{ Type: trajectory.SpanAgent, @@ -158,7 +159,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { if err := ctx.startSpan(span); err != nil { return "", nil, err } - reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt) + reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt, candidate) if err == nil { span.Reply = reply span.Results = outputs @@ -166,8 +167,8 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { return reply, outputs, ctx.finishSpan(span, err) } -func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) ( - string, map[string]any, error) { +func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, + prompt string, candidate int) (string, map[string]any, error) { var outputs map[string]any req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)} for { @@ -179,7 +180,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma if err 
:= ctx.startSpan(reqSpan); err != nil { return "", nil, err } - resp, err := a.generateContent(ctx, cfg, req) + resp, err := a.generateContent(ctx, cfg, req, candidate) if err != nil { return "", nil, ctx.finishSpan(reqSpan, err) } @@ -328,11 +329,10 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) ( } func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, - req []*genai.Content) (*genai.GenerateContentResponse, error) { + req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { backoff := time.Second - model := ctx.modelName(a.Model) for try := 0; ; try++ { - resp, err := ctx.generateContent(model, cfg, req) + resp, err := a.generateContentCached(ctx, cfg, req, candidate) var apiErr genai.APIError if err != nil && try < 100 && errors.As(err, &apiErr) && apiErr.Code == http.StatusServiceUnavailable { @@ -343,12 +343,33 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests && strings.Contains(apiErr.Message, "Quota exceeded for metric") && strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { - return resp, &modelQuotaError{model} + return resp, &modelQuotaError{ctx.modelName(a.Model)} } return resp, err } } +func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig, + req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { + type Cached struct { + Config *genai.GenerateContentConfig + Request []*genai.Content + Reply *genai.GenerateContentResponse + } + model := ctx.modelName(a.Model) + desc := fmt.Sprintf("model %v, config hash %v, request hash %v, candidate %v", + model, hash.String(cfg), hash.String(req), candidate) + cached, err := CacheObject(ctx, "llm", desc, func() (Cached, error) { + resp, err := ctx.generateContent(model, cfg, req) + return Cached{ + Config: cfg, + Request: req, + Reply: resp, + }, 
err + }) + return cached.Reply, err +} + func (a *LLMAgent) verify(vctx *verifyContext) { vctx.requireNotEmpty(a.Name, "Name", a.Name) vctx.requireNotEmpty(a.Name, "Model", a.Model) -- cgit mrf-deployment