diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-15 20:53:57 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 21:12:57 +0000 |
| commit | 7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch) | |
| tree | 2ccbc85132a170d046837de6bdd8be3317f94060 | |
| parent | 2494e18d5ced59fc7f0522749041e499d3082a9e (diff) | |
pkg/aflow: make LLM model per-agent rather than per-flow
Having the LLM model per-agent is even more flexible than per-flow.
We can run the more complex tasks during patch generation with the most capable model,
while simpler tasks can use less capable (cheaper) models.
| -rw-r--r-- | dashboard/app/ai.go | 6 | ||||
| -rw-r--r-- | dashboard/app/ai_test.go | 30 | ||||
| -rw-r--r-- | dashboard/app/aidb/crud.go | 7 | ||||
| -rw-r--r-- | dashboard/app/aidb/entities.go | 2 | ||||
| -rw-r--r-- | dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql | 1 | ||||
| -rw-r--r-- | dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql | 1 | ||||
| -rw-r--r-- | dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql | 1 | ||||
| -rw-r--r-- | dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql | 1 | ||||
| -rw-r--r-- | dashboard/app/templates/ai_job.html | 3 | ||||
| -rw-r--r-- | dashboard/app/templates/templates.html | 2 | ||||
| -rw-r--r-- | dashboard/dashapi/ai.go | 5 | ||||
| -rw-r--r-- | pkg/aflow/execute.go | 66 | ||||
| -rw-r--r-- | pkg/aflow/flow.go | 12 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/kcsan.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/moderation.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/flow/patching/patching.go | 4 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 42 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 19 | ||||
| -rw-r--r-- | pkg/aflow/trajectory/trajectory.go | 1 | ||||
| -rw-r--r-- | syz-agent/agent.go | 15 | ||||
| -rw-r--r-- | tools/syz-aflow/aflow.go | 5 |
21 files changed, 127 insertions, 100 deletions
diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go index 8d8767832..39f1479a3 100644 --- a/dashboard/app/ai.go +++ b/dashboard/app/ai.go @@ -48,7 +48,6 @@ type uiAIJob struct { Created time.Time Started time.Time Finished time.Time - LLMModel string CodeRevision string CodeRevisionLink string Error string @@ -68,6 +67,7 @@ type uiAITrajectorySpan struct { Nesting int64 Type string Name string + Model string Duration time.Duration Error string Args string @@ -198,7 +198,6 @@ func makeUIAIJob(job *aidb.Job) *uiAIJob { Created: job.Created, Started: nullTime(job.Started), Finished: nullTime(job.Finished), - LLMModel: job.LLMModel, CodeRevision: job.CodeRevision, CodeRevisionLink: vcs.LogLink(vcs.SyzkallerRepo, job.CodeRevision), Error: job.Error, @@ -220,6 +219,7 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan Nesting: span.Nesting, Type: span.Type, Name: span.Name, + Model: span.Model, Duration: duration, Error: nullString(span.Error), Args: nullJSON(span.Args), @@ -238,7 +238,7 @@ func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) { return nil, fmt.Errorf("invalid request") } for _, flow := range req.Workflows { - if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" { + if flow.Type == "" || flow.Name == "" { return nil, fmt.Errorf("invalid request") } } diff --git a/dashboard/app/ai_test.go b/dashboard/app/ai_test.go index b775b2a89..addf71f5a 100644 --- a/dashboard/app/ai_test.go +++ b/dashboard/app/ai_test.go @@ -64,9 +64,9 @@ func TestAIBugWorkflows(t *testing.T) { _, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-foo", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-bar", LLMModel: "smarty"}, + {Type: "patching", Name: "patching"}, + {Type: "patching", Name: "patching-foo"}, + {Type: "patching", Name: 
"patching-bar"}, }, }) require.NoError(t, err) @@ -77,10 +77,10 @@ func TestAIBugWorkflows(t *testing.T) { _, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-bar", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-baz", LLMModel: "smarty"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, + {Type: "patching", Name: "patching"}, + {Type: "patching", Name: "patching-bar"}, + {Type: "patching", Name: "patching-baz"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan"}, }, }) require.NoError(t, err) @@ -88,11 +88,11 @@ func TestAIBugWorkflows(t *testing.T) { _, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-bar", LLMModel: "smarty"}, - {Type: "patching", Name: "patching-qux", LLMModel: "smarty"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"}, + {Type: "patching", Name: "patching"}, + {Type: "patching", Name: "patching-bar"}, + {Type: "patching", Name: "patching-qux"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan-foo"}, }, }) require.NoError(t, err) @@ -115,7 +115,7 @@ func TestAIJob(t *testing.T) { resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: []dashapi.AIWorkflow{ - {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan"}, }, }) require.NoError(t, err) @@ -134,7 +134,7 @@ func TestAIJob(t *testing.T) { resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: 
[]dashapi.AIWorkflow{ - {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan"}, }, }) require.NoError(t, err2) @@ -210,7 +210,7 @@ func TestAIAssessmentKCSAN(t *testing.T) { resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, Workflows: []dashapi.AIWorkflow{ - {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"}, + {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)}, }, }) require.NoError(t, err) diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go index 4f7e93f6a..872a70ace 100644 --- a/dashboard/app/aidb/crud.go +++ b/dashboard/app/aidb/crud.go @@ -129,12 +129,6 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) { job = jobs[0] } job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true} - for _, flow := range req.Workflows { - if job.Workflow == flow.Name { - job.LLMModel = flow.LLMModel - break - } - } job.CodeRevision = req.CodeRevision mut, err := spanner.InsertOrUpdateStruct("Jobs", job) if err != nil { @@ -184,6 +178,7 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa Nesting: int64(span.Nesting), Type: string(span.Type), Name: span.Name, + Model: span.Model, Started: span.Started, Finished: toNullTime(span.Finished), Error: toNullString(span.Error), diff --git a/dashboard/app/aidb/entities.go b/dashboard/app/aidb/entities.go index 23df884df..0a3e7b164 100644 --- a/dashboard/app/aidb/entities.go +++ b/dashboard/app/aidb/entities.go @@ -28,7 +28,6 @@ type Job struct { Created time.Time Started spanner.NullTime Finished spanner.NullTime - LLMModel string // LLM model used to execute the job, filled when the job is started CodeRevision string // syzkaller revision, filled when the job is started Error string // for finished jobs Args spanner.NullJSON @@ -43,6 +42,7 @@ type TrajectorySpan struct { 
Nesting int64 Type string Name string + Model string Started time.Time Finished spanner.NullTime Error spanner.NullString diff --git a/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql b/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql new file mode 100644 index 000000000..9c8ee7020 --- /dev/null +++ b/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql @@ -0,0 +1 @@ +ALTER TABLE TrajectorySpans DROP COLUMN Model; diff --git a/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql b/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql new file mode 100644 index 000000000..c5cd8821d --- /dev/null +++ b/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql @@ -0,0 +1 @@ +ALTER TABLE TrajectorySpans ADD COLUMN Model STRING(1000); diff --git a/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql b/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql new file mode 100644 index 000000000..1d9885cb8 --- /dev/null +++ b/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql @@ -0,0 +1 @@ +ALTER TABLE Jobs ADD COLUMN LLMModel STRING(1000); diff --git a/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql b/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql new file mode 100644 index 000000000..85b4c74e4 --- /dev/null +++ b/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql @@ -0,0 +1 @@ +ALTER TABLE Jobs DROP COLUMN LLMModel; diff --git a/dashboard/app/templates/ai_job.html b/dashboard/app/templates/ai_job.html index 8f2526c63..f8f2b82bd 100644 --- a/dashboard/app/templates/ai_job.html +++ b/dashboard/app/templates/ai_job.html @@ -65,6 +65,9 @@ Detailed info on a single AI job execution. 
<td> <details> <summary>{{formatDuration $span.Duration}}</summary> + {{if $span.Model}} + <b>Model:</b> <div id="ai_details_div"><pre>{{$span.Model}}</pre></div><br> + {{end}} {{if $span.Error}} <b>Error:</b> <div id="ai_details_div"><pre>{{$span.Error}}</pre></div><br> {{end}} diff --git a/dashboard/app/templates/templates.html b/dashboard/app/templates/templates.html index dee2d1300..d20727a73 100644 --- a/dashboard/app/templates/templates.html +++ b/dashboard/app/templates/templates.html @@ -689,7 +689,6 @@ Use of this source code is governed by Apache 2 LICENSE that can be found in the <th><a onclick="return sortTable(this, 'Created', textSort)" href="#">Created</a></th> <th><a onclick="return sortTable(this, 'Started', textSort)" href="#">Started</a></th> <th><a onclick="return sortTable(this, 'Finished', textSort)" href="#">Finished</a></th> - <th><a onclick="return sortTable(this, 'Model', textSort)" href="#">Model</a></th> <th><a onclick="return sortTable(this, 'Revision', textSort)" href="#">Revision</a></th> <th><a onclick="return sortTable(this, 'Error', textSort)" href="#">Error</a></th> </tr></thead> @@ -710,7 +709,6 @@ Use of this source code is governed by Apache 2 LICENSE that can be found in the <td>{{formatTime $job.Created}}</td> <td>{{formatTime $job.Started}}</td> <td>{{formatTime $job.Finished}}</td> - <td>{{$job.LLMModel}}</td> <td class="tag">{{link $job.CodeRevisionLink $job.CodeRevision}}</td> <td>{{$job.Error}}</td> </tr> diff --git a/dashboard/dashapi/ai.go b/dashboard/dashapi/ai.go index 8134e5744..dfa410402 100644 --- a/dashboard/dashapi/ai.go +++ b/dashboard/dashapi/ai.go @@ -14,9 +14,8 @@ type AIJobPollReq struct { } type AIWorkflow struct { - Type ai.WorkflowType - Name string - LLMModel string // LLM model that will be used to execute this workflow + Type ai.WorkflowType + Name string } type AIJobPollResp struct { diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 482a58fb4..96f15c13b 100644 --- a/pkg/aflow/execute.go 
+++ b/pkg/aflow/execute.go @@ -20,7 +20,8 @@ import ( ) // Execute executes the given AI workflow with provided inputs and returns workflow outputs. -// The model argument sets Gemini model name to execute the workflow. +// The model argument overrides Gemini models used to execute LLM agents, +// if not set, then default models for each agent are used. // The workdir argument should point to a dir owned by aflow to store private data, // it can be shared across parallel executions in the same process, and preferably // preserved across process restarts for caching purposes. @@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s return nil, fmt.Errorf("flow inputs are missing: %w", err) } ctx := &Context{ - Context: c, - Workdir: osutil.Abs(workdir), - cache: cache, - state: maps.Clone(inputs), - onEvent: onEvent, + Context: c, + Workdir: osutil.Abs(workdir), + llmModel: model, + cache: cache, + state: maps.Clone(inputs), + onEvent: onEvent, } defer ctx.close() if s := c.Value(stubContextKey); s != nil { @@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s ctx.timeNow = time.Now } if ctx.generateContent == nil { - var err error - ctx.generateContent, err = contentGenerator(c, model) - if err != nil { - return nil, err - } + ctx.generateContent = ctx.generateContentGemini } span := &trajectory.Span{ Type: trajectory.SpanFlow, @@ -91,9 +89,7 @@ type flowError struct { } type ( - onEvent func(*trajectory.Span) error - generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) ( - *genai.GenerateContentResponse, error) + onEvent func(*trajectory.Span) error contextKeyType int ) @@ -105,7 +101,8 @@ var ( stubContextKey = contextKeyType(1) ) -func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) { +func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig, + req []*genai.Content) 
(*genai.GenerateContentResponse, error) { const modelPrefix = "models/" createClientOnce.Do(func() { if os.Getenv("GOOGLE_API_KEY") == "" { @@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e " (see https://ai.google.dev/gemini-api/docs/api-key)") return } - client, createClientErr = genai.NewClient(ctx, nil) + client, createClientErr = genai.NewClient(ctx.Context, nil) if createClientErr != nil { return } - for m, err := range client.Models.All(ctx) { + for m, err := range client.Models.All(ctx.Context) { if err != nil { createClientErr = err return @@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e slices.Sort(models) return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } - return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { - if thinking { - cfg.ThinkingConfig = &genai.ThinkingConfig{ - // We capture them in the trajectory for analysis. - IncludeThoughts: true, - // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request"). - // See https://ai.google.dev/gemini-api/docs/thinking#set-budget - // However, thoughts output also consumes total output token budget. - // We may consider adjusting ThinkingLevel parameter. - ThinkingBudget: genai.Ptr[int32](-1), - } + if thinking { + cfg.ThinkingConfig = &genai.ThinkingConfig{ + // We capture them in the trajectory for analysis. + IncludeThoughts: true, + // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request"). + // See https://ai.google.dev/gemini-api/docs/thinking#set-budget + // However, thoughts output also consumes total output token budget. + // We may consider adjusting ThinkingLevel parameter. 
+ ThinkingBudget: genai.Ptr[int32](-1), } - return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg) - }, nil + } + return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg) } type Context struct { Context context.Context Workdir string + llmModel string cache *Cache cachedDirs []string state map[string]any @@ -164,7 +160,15 @@ type Context struct { type stubContext struct { timeNow func() time.Time - generateContent generateContentFunc + generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) ( + *genai.GenerateContentResponse, error) +} + +func (ctx *Context) modelName(model string) string { + if ctx.llmModel != "" { + return ctx.llmModel + } + return model } func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) { diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go index b4cbe2201..6325b2fd2 100644 --- a/pkg/aflow/flow.go +++ b/pkg/aflow/flow.go @@ -22,9 +22,8 @@ import ( // Actions are nodes of the graph, and they consume/produce some named values // (input/output fields, and intermediate values consumed by other actions). type Flow struct { - Name string // Empty for the main workflow for the workflow type. - Model string // The default Gemini model name to execute this workflow. - Root Action + Name string // Empty for the main workflow for the workflow type. 
+ Root Action *FlowType } @@ -36,12 +35,6 @@ type FlowType struct { extractOutputs func(map[string]any) map[string]any } -// See https://ai.google.dev/gemini-api/docs/models -const ( - BestExpensiveModel = "gemini-3-pro-preview" - GoodBalancedModel = "gemini-3-flash-preview" -) - var Flows = make(map[string]*Flow) // Register a workflow type (characterized by Inputs and Outputs), @@ -95,7 +88,6 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error { actions: make(map[string]bool), state: make(map[string]*varState), } - ctx.requireNotEmpty(flow.Name, "Model", flow.Model) provideOutputs[Inputs](ctx, "flow inputs") flow.Root.verify(ctx) requireInputs[Outputs](ctx, "flow outputs") diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go index 67d695eb9..6bfc7bb12 100644 --- a/pkg/aflow/flow/assessment/kcsan.go +++ b/pkg/aflow/flow/assessment/kcsan.go @@ -23,7 +23,6 @@ func init() { ai.WorkflowAssessmentKCSAN, "assess if a KCSAN report is about a benign race that only needs annotations or not", &aflow.Flow{ - Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ kernel.Checkout, @@ -31,6 +30,7 @@ func init() { codesearcher.PrepareIndex, &aflow.LLMAgent{ Name: "expert", + Model: aflow.GoodBalancedModel, Reply: "Explanation", Outputs: aflow.LLMOutputs[struct { Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."` diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go index 8d9ac4a0b..b13ee1e7d 100644 --- a/pkg/aflow/flow/assessment/moderation.go +++ b/pkg/aflow/flow/assessment/moderation.go @@ -33,7 +33,6 @@ func init() { ai.WorkflowModeration, "assess if a bug report is consistent and actionable or not", &aflow.Flow{ - Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ aflow.NewFuncAction("extract-crash-type", extractCrashType), @@ -42,6 +41,7 @@ func init() { codesearcher.PrepareIndex, 
&aflow.LLMAgent{ Name: "expert", + Model: aflow.GoodBalancedModel, Reply: "Explanation", Outputs: aflow.LLMOutputs[struct { Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."` diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go index 766cf089f..856962e6c 100644 --- a/pkg/aflow/flow/patching/patching.go +++ b/pkg/aflow/flow/patching/patching.go @@ -43,7 +43,6 @@ func init() { ai.WorkflowPatching, "generate a kernel patch fixing a provided bug reproducer", &aflow.Flow{ - Model: aflow.BestExpensiveModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ baseCommitPicker, @@ -54,6 +53,7 @@ func init() { codesearcher.PrepareIndex, &aflow.LLMAgent{ Name: "debugger", + Model: aflow.BestExpensiveModel, Reply: "BugExplanation", Temperature: 1, Instruction: debuggingInstruction, @@ -62,6 +62,7 @@ func init() { }, &aflow.LLMAgent{ Name: "diff-generator", + Model: aflow.BestExpensiveModel, Reply: "PatchDiff", Temperature: 1, Instruction: diffInstruction, @@ -70,6 +71,7 @@ func init() { }, &aflow.LLMAgent{ Name: "description-generator", + Model: aflow.BestExpensiveModel, Reply: "PatchDescription", Temperature: 1, Instruction: descriptionInstruction, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 5c038d7c6..97e2dd93a 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Name: "flow", - Model: "model", + Name: "flow", Root: NewPipeline( NewFuncAction("func-action", func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { @@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "smarty", + Model: "model1", Reply: "OutFoo", Outputs: LLMOutputs[agentOutputs](), Temperature: 0, @@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "swarm", + Model: "model2", Reply: 
"OutSwarm", Candidates: 2, Outputs: LLMOutputs[swarmOutputs](), @@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) { }, &LLMAgent{ Name: "aggregator", + Model: "model3", Reply: "OutAggregator", Temperature: 0, Instruction: "Aggregate!", @@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) { stubTime = stubTime.Add(time.Second) return stubTime }, - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { replySeq++ if replySeq < 4 { + assert.Equal(t, model, "model1") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) @@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) { assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") } else if replySeq < 8 { + assert.Equal(t, model, "model2") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 1) assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results") } else { + assert.Equal(t, model, "model3") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 0) } @@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Instruction: "You are smarty. 
baz" + llmOutputsInstruction, Prompt: "Prompt: baz func-output", @@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), }, { @@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), Finished: startTime.Add(6 * time.Second), Thoughts: "I am thinking I need to call some tools", @@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), }, { @@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), Finished: startTime.Add(12 * time.Second), Thoughts: "Completly blank.Whatever.", @@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), }, { @@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), Finished: startTime.Add(16 * time.Second), }, @@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Finished: startTime.Add(17 * time.Second), Instruction: "You are smarty. baz" + llmOutputsInstruction, @@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Instruction: "Do something. 
baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), }, { @@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), Finished: startTime.Add(23 * time.Second), }, @@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), }, { @@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), Finished: startTime.Add(27 * time.Second), }, @@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Finished: startTime.Add(28 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, @@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Instruction: "Do something. 
baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), }, { @@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), Finished: startTime.Add(31 * time.Second), }, @@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), }, { @@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), Finished: startTime.Add(35 * time.Second), }, @@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Finished: startTime.Add(36 * time.Second), Instruction: "Do something. 
baz" + llmOutputsInstruction, @@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Instruction: "Aggregate!", Prompt: `Prompt: baz @@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), }, { @@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), Finished: startTime.Add(40 * time.Second), }, @@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Finished: startTime.Add(41 * time.Second), Instruction: "Aggregate!", @@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) { expected = expected[1:] return nil } - res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.NoError(t, err) require.Equal(t, replySeq, 8) require.Equal(t, res, expectedOutputs) @@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Model: "model", Root: NewFuncAction("func-action", func(ctx *Context, args flowInputs) (flowOutputs, error) { return flowOutputs{}, nil @@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) { }) require.NoError(t, err) stub := &stubContext{ - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { return nil, nil }, @@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) { cache, err := newTestCache(t, 
filepath.Join(workdir, "cache"), 0, stub.timeNow) require.NoError(t, err) onEvent := func(span *trajectory.Span) error { return nil } - _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.Equal(t, err.Error(), "flow inputs are missing:"+ " field InBar is not present when converting map to aflow.flowInputs") } diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index 02d3bca85..3c416b37c 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -18,6 +18,9 @@ import ( type LLMAgent struct { // For logging/debugging. Name string + // The default Gemini model name to execute this workflow. + // Use the consts defined below. + Model string // Name of the state variable to store the final reply of the agent. // These names can be used in subsequent action instructions/prompts, // and as final workflow outputs. @@ -43,6 +46,13 @@ type LLMAgent struct { Tools []Tool } +// Consts to use for LLMAgent.Model. +// See https://ai.google.dev/gemini-api/docs/models +const ( + BestExpensiveModel = "gemini-3-pro-preview" + GoodBalancedModel = "gemini-3-flash-preview" +) + // Tool represents a custom tool an LLMAgent can invoke. // Use NewFuncTool to create function-based tools. 
type Tool interface { @@ -134,6 +144,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { Name: a.Name, Instruction: instruction, Prompt: formatTemplate(a.Prompt, ctx.state), + Model: ctx.modelName(a.Model), } if err := ctx.startSpan(span); err != nil { return "", nil, err @@ -152,8 +163,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)} for { reqSpan := &trajectory.Span{ - Type: trajectory.SpanLLM, - Name: a.Name, + Type: trajectory.SpanLLM, + Name: a.Name, + Model: ctx.modelName(a.Model), } if err := ctx.startSpan(reqSpan); err != nil { return "", nil, err @@ -278,7 +290,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi req []*genai.Content) (*genai.GenerateContentResponse, error) { backoff := time.Second for try := 0; ; try++ { - resp, err := ctx.generateContent(cfg, req) + resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req) var apiErr genai.APIError if err != nil && try < 100 && errors.As(err, &apiErr) && apiErr.Code == http.StatusServiceUnavailable { @@ -292,6 +304,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi func (a *LLMAgent) verify(vctx *verifyContext) { vctx.requireNotEmpty(a.Name, "Name", a.Name) + vctx.requireNotEmpty(a.Name, "Model", a.Model) vctx.requireNotEmpty(a.Name, "Reply", a.Reply) if temp, ok := a.Temperature.(int); ok { a.Temperature = float32(temp) diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go index 49e36933b..aa9558708 100644 --- a/pkg/aflow/trajectory/trajectory.go +++ b/pkg/aflow/trajectory/trajectory.go @@ -20,6 +20,7 @@ type Span struct { Nesting int Type SpanType Name string // flow/action/tool name + Model string // LLM model name for agent/LLM spans Started time.Time Finished time.Time Error string // relevant if Finished is set diff --git a/syz-agent/agent.go 
b/syz-agent/agent.go index 54d6a67c6..d070144db 100644 --- a/syz-agent/agent.go +++ b/syz-agent/agent.go @@ -171,14 +171,9 @@ func (s *Server) poll(ctx context.Context) ( CodeRevision: prog.GitRevision, } for _, flow := range aflow.Flows { - model := flow.Model - if s.cfg.Model != "" { - model = s.cfg.Model - } req.Workflows = append(req.Workflows, dashapi.AIWorkflow{ - Type: flow.Type, - Name: flow.Name, - LLMModel: model, + Type: flow.Type, + Name: flow.Name, }) } resp, err := s.dash.AIJobPoll(req) @@ -210,10 +205,6 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma if flow == nil { return nil, fmt.Errorf("unsupported flow %q", req.Workflow) } - model := flow.Model - if s.cfg.Model != "" { - model = s.cfg.Model - } inputs := map[string]any{ "Syzkaller": osutil.Abs(filepath.FromSlash("syzkaller/current")), "CodesearchToolBin": s.cfg.CodesearchToolBin, @@ -230,5 +221,5 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma Span: span, }) } - return flow.Execute(ctx, model, s.workdir, inputs, s.cache, onEvent) + return flow.Execute(ctx, s.cfg.Model, s.workdir, inputs, s.cache, onEvent) } diff --git a/tools/syz-aflow/aflow.go b/tools/syz-aflow/aflow.go index d915c2061..160d4541c 100644 --- a/tools/syz-aflow/aflow.go +++ b/tools/syz-aflow/aflow.go @@ -33,7 +33,7 @@ func main() { flagFlow = flag.String("workflow", "", "workflow to execute") flagInput = flag.String("input", "", "input json file with workflow arguments") flagWorkdir = flag.String("workdir", "", "directory for kernel checkout, kernel builds, etc") - flagModel = flag.String("model", "", "use this LLM model, if empty use the workflow default model") + flagModel = flag.String("model", "", "use this LLM model, if empty use default models") flagCacheSize = flag.String("cache-size", "10GB", "max cache size (e.g. 
100MB, 5GB, 1TB)") flagDownloadBug = flag.String("download-bug", "", "extid of a bug to download from the dashboard"+ " and save into -input file") @@ -78,9 +78,6 @@ func run(ctx context.Context, model, flowName, inputFile, workdir string, cacheS if flow == nil { return fmt.Errorf("workflow %q is not found", flowName) } - if model == "" { - model = flow.Model - } inputData, err := os.ReadFile(inputFile) if err != nil { return fmt.Errorf("failed to open -input file: %w", err) |
