aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-15 20:53:57 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-20 21:12:57 +0000
commit7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch)
tree2ccbc85132a170d046837de6bdd8be3317f94060
parent2494e18d5ced59fc7f0522749041e499d3082a9e (diff)
pkg/aflow: make LLM model per-agent rather than per-flow
Having LLM model per-agent is even more flexible than per-flow. We can have some more complex tasks during patch generation with the most elaborate model, but also some simpler ones with less elaborate models.
-rw-r--r--dashboard/app/ai.go6
-rw-r--r--dashboard/app/ai_test.go30
-rw-r--r--dashboard/app/aidb/crud.go7
-rw-r--r--dashboard/app/aidb/entities.go2
-rw-r--r--dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql1
-rw-r--r--dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql1
-rw-r--r--dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql1
-rw-r--r--dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql1
-rw-r--r--dashboard/app/templates/ai_job.html3
-rw-r--r--dashboard/app/templates/templates.html2
-rw-r--r--dashboard/dashapi/ai.go5
-rw-r--r--pkg/aflow/execute.go66
-rw-r--r--pkg/aflow/flow.go12
-rw-r--r--pkg/aflow/flow/assessment/kcsan.go2
-rw-r--r--pkg/aflow/flow/assessment/moderation.go2
-rw-r--r--pkg/aflow/flow/patching/patching.go4
-rw-r--r--pkg/aflow/flow_test.go42
-rw-r--r--pkg/aflow/llm_agent.go19
-rw-r--r--pkg/aflow/trajectory/trajectory.go1
-rw-r--r--syz-agent/agent.go15
-rw-r--r--tools/syz-aflow/aflow.go5
21 files changed, 127 insertions, 100 deletions
diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go
index 8d8767832..39f1479a3 100644
--- a/dashboard/app/ai.go
+++ b/dashboard/app/ai.go
@@ -48,7 +48,6 @@ type uiAIJob struct {
Created time.Time
Started time.Time
Finished time.Time
- LLMModel string
CodeRevision string
CodeRevisionLink string
Error string
@@ -68,6 +67,7 @@ type uiAITrajectorySpan struct {
Nesting int64
Type string
Name string
+ Model string
Duration time.Duration
Error string
Args string
@@ -198,7 +198,6 @@ func makeUIAIJob(job *aidb.Job) *uiAIJob {
Created: job.Created,
Started: nullTime(job.Started),
Finished: nullTime(job.Finished),
- LLMModel: job.LLMModel,
CodeRevision: job.CodeRevision,
CodeRevisionLink: vcs.LogLink(vcs.SyzkallerRepo, job.CodeRevision),
Error: job.Error,
@@ -220,6 +219,7 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
Nesting: span.Nesting,
Type: span.Type,
Name: span.Name,
+ Model: span.Model,
Duration: duration,
Error: nullString(span.Error),
Args: nullJSON(span.Args),
@@ -238,7 +238,7 @@ func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) {
return nil, fmt.Errorf("invalid request")
}
for _, flow := range req.Workflows {
- if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" {
+ if flow.Type == "" || flow.Name == "" {
return nil, fmt.Errorf("invalid request")
}
}
diff --git a/dashboard/app/ai_test.go b/dashboard/app/ai_test.go
index b775b2a89..addf71f5a 100644
--- a/dashboard/app/ai_test.go
+++ b/dashboard/app/ai_test.go
@@ -64,9 +64,9 @@ func TestAIBugWorkflows(t *testing.T) {
_, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-foo", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching"},
+ {Type: "patching", Name: "patching-foo"},
+ {Type: "patching", Name: "patching-bar"},
},
})
require.NoError(t, err)
@@ -77,10 +77,10 @@ func TestAIBugWorkflows(t *testing.T) {
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-baz", LLMModel: "smarty"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching"},
+ {Type: "patching", Name: "patching-bar"},
+ {Type: "patching", Name: "patching-baz"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan"},
},
})
require.NoError(t, err)
@@ -88,11 +88,11 @@ func TestAIBugWorkflows(t *testing.T) {
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
- {Type: "patching", Name: "patching-qux", LLMModel: "smarty"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching"},
+ {Type: "patching", Name: "patching-bar"},
+ {Type: "patching", Name: "patching-qux"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan-foo"},
},
})
require.NoError(t, err)
@@ -115,7 +115,7 @@ func TestAIJob(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan"},
},
})
require.NoError(t, err)
@@ -134,7 +134,7 @@ func TestAIJob(t *testing.T) {
resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan"},
},
})
require.NoError(t, err2)
@@ -210,7 +210,7 @@ func TestAIAssessmentKCSAN(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
Workflows: []dashapi.AIWorkflow{
- {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"},
+ {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)},
},
})
require.NoError(t, err)
diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go
index 4f7e93f6a..872a70ace 100644
--- a/dashboard/app/aidb/crud.go
+++ b/dashboard/app/aidb/crud.go
@@ -129,12 +129,6 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
job = jobs[0]
}
job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true}
- for _, flow := range req.Workflows {
- if job.Workflow == flow.Name {
- job.LLMModel = flow.LLMModel
- break
- }
- }
job.CodeRevision = req.CodeRevision
mut, err := spanner.InsertOrUpdateStruct("Jobs", job)
if err != nil {
@@ -184,6 +178,7 @@ func StoreTrajectorySpan(ctx context.Context, jobID string, span *trajectory.Spa
Nesting: int64(span.Nesting),
Type: string(span.Type),
Name: span.Name,
+ Model: span.Model,
Started: span.Started,
Finished: toNullTime(span.Finished),
Error: toNullString(span.Error),
diff --git a/dashboard/app/aidb/entities.go b/dashboard/app/aidb/entities.go
index 23df884df..0a3e7b164 100644
--- a/dashboard/app/aidb/entities.go
+++ b/dashboard/app/aidb/entities.go
@@ -28,7 +28,6 @@ type Job struct {
Created time.Time
Started spanner.NullTime
Finished spanner.NullTime
- LLMModel string // LLM model used to execute the job, filled when the job is started
CodeRevision string // syzkaller revision, filled when the job is started
Error string // for finished jobs
Args spanner.NullJSON
@@ -43,6 +42,7 @@ type TrajectorySpan struct {
Nesting int64
Type string
Name string
+ Model string
Started time.Time
Finished spanner.NullTime
Error spanner.NullString
diff --git a/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql b/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql
new file mode 100644
index 000000000..9c8ee7020
--- /dev/null
+++ b/dashboard/app/aidb/migrations/3_add_trajectory_model.down.sql
@@ -0,0 +1 @@
+ALTER TABLE TrajectorySpans DROP COLUMN Model;
diff --git a/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql b/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql
new file mode 100644
index 000000000..c5cd8821d
--- /dev/null
+++ b/dashboard/app/aidb/migrations/3_add_trajectory_model.up.sql
@@ -0,0 +1 @@
+ALTER TABLE TrajectorySpans ADD COLUMN Model STRING(1000);
diff --git a/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql b/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql
new file mode 100644
index 000000000..1d9885cb8
--- /dev/null
+++ b/dashboard/app/aidb/migrations/4_remove_jobs_model.down.sql
@@ -0,0 +1 @@
+ALTER TABLE Jobs ADD COLUMN LLMModel STRING(1000);
diff --git a/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql b/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql
new file mode 100644
index 000000000..85b4c74e4
--- /dev/null
+++ b/dashboard/app/aidb/migrations/4_remove_jobs_model.up.sql
@@ -0,0 +1 @@
+ALTER TABLE Jobs DROP COLUMN LLMModel;
diff --git a/dashboard/app/templates/ai_job.html b/dashboard/app/templates/ai_job.html
index 8f2526c63..f8f2b82bd 100644
--- a/dashboard/app/templates/ai_job.html
+++ b/dashboard/app/templates/ai_job.html
@@ -65,6 +65,9 @@ Detailed info on a single AI job execution.
<td>
<details>
<summary>{{formatDuration $span.Duration}}</summary>
+ {{if $span.Model}}
+ <b>Model:</b> <div id="ai_details_div"><pre>{{$span.Model}}</pre></div><br>
+ {{end}}
{{if $span.Error}}
<b>Error:</b> <div id="ai_details_div"><pre>{{$span.Error}}</pre></div><br>
{{end}}
diff --git a/dashboard/app/templates/templates.html b/dashboard/app/templates/templates.html
index dee2d1300..d20727a73 100644
--- a/dashboard/app/templates/templates.html
+++ b/dashboard/app/templates/templates.html
@@ -689,7 +689,6 @@ Use of this source code is governed by Apache 2 LICENSE that can be found in the
<th><a onclick="return sortTable(this, 'Created', textSort)" href="#">Created</a></th>
<th><a onclick="return sortTable(this, 'Started', textSort)" href="#">Started</a></th>
<th><a onclick="return sortTable(this, 'Finished', textSort)" href="#">Finished</a></th>
- <th><a onclick="return sortTable(this, 'Model', textSort)" href="#">Model</a></th>
<th><a onclick="return sortTable(this, 'Revision', textSort)" href="#">Revision</a></th>
<th><a onclick="return sortTable(this, 'Error', textSort)" href="#">Error</a></th>
</tr></thead>
@@ -710,7 +709,6 @@ Use of this source code is governed by Apache 2 LICENSE that can be found in the
<td>{{formatTime $job.Created}}</td>
<td>{{formatTime $job.Started}}</td>
<td>{{formatTime $job.Finished}}</td>
- <td>{{$job.LLMModel}}</td>
<td class="tag">{{link $job.CodeRevisionLink $job.CodeRevision}}</td>
<td>{{$job.Error}}</td>
</tr>
diff --git a/dashboard/dashapi/ai.go b/dashboard/dashapi/ai.go
index 8134e5744..dfa410402 100644
--- a/dashboard/dashapi/ai.go
+++ b/dashboard/dashapi/ai.go
@@ -14,9 +14,8 @@ type AIJobPollReq struct {
}
type AIWorkflow struct {
- Type ai.WorkflowType
- Name string
- LLMModel string // LLM model that will be used to execute this workflow
+ Type ai.WorkflowType
+ Name string
}
type AIJobPollResp struct {
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 482a58fb4..96f15c13b 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -20,7 +20,8 @@ import (
)
// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
-// The model argument sets Gemini model name to execute the workflow.
+// The model argument overrides Gemini models used to execute LLM agents,
+// if not set, then default models for each agent are used.
// The workdir argument should point to a dir owned by aflow to store private data,
// it can be shared across parallel executions in the same process, and preferably
// preserved across process restarts for caching purposes.
@@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
return nil, fmt.Errorf("flow inputs are missing: %w", err)
}
ctx := &Context{
- Context: c,
- Workdir: osutil.Abs(workdir),
- cache: cache,
- state: maps.Clone(inputs),
- onEvent: onEvent,
+ Context: c,
+ Workdir: osutil.Abs(workdir),
+ llmModel: model,
+ cache: cache,
+ state: maps.Clone(inputs),
+ onEvent: onEvent,
}
defer ctx.close()
if s := c.Value(stubContextKey); s != nil {
@@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
ctx.timeNow = time.Now
}
if ctx.generateContent == nil {
- var err error
- ctx.generateContent, err = contentGenerator(c, model)
- if err != nil {
- return nil, err
- }
+ ctx.generateContent = ctx.generateContentGemini
}
span := &trajectory.Span{
Type: trajectory.SpanFlow,
@@ -91,9 +89,7 @@ type flowError struct {
}
type (
- onEvent func(*trajectory.Span) error
- generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) (
- *genai.GenerateContentResponse, error)
+ onEvent func(*trajectory.Span) error
contextKeyType int
)
@@ -105,7 +101,8 @@ var (
stubContextKey = contextKeyType(1)
)
-func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) {
+func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
+ req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
createClientOnce.Do(func() {
if os.Getenv("GOOGLE_API_KEY") == "" {
@@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
" (see https://ai.google.dev/gemini-api/docs/api-key)")
return
}
- client, createClientErr = genai.NewClient(ctx, nil)
+ client, createClientErr = genai.NewClient(ctx.Context, nil)
if createClientErr != nil {
return
}
- for m, err := range client.Models.All(ctx) {
+ for m, err := range client.Models.All(ctx.Context) {
if err != nil {
createClientErr = err
return
@@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) {
- if thinking {
- cfg.ThinkingConfig = &genai.ThinkingConfig{
- // We capture them in the trajectory for analysis.
- IncludeThoughts: true,
- // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
- // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
- // However, thoughts output also consumes total output token budget.
- // We may consider adjusting ThinkingLevel parameter.
- ThinkingBudget: genai.Ptr[int32](-1),
- }
+ if thinking {
+ cfg.ThinkingConfig = &genai.ThinkingConfig{
+ // We capture them in the trajectory for analysis.
+ IncludeThoughts: true,
+ // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
+ // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
+ // However, thoughts output also consumes total output token budget.
+ // We may consider adjusting ThinkingLevel parameter.
+ ThinkingBudget: genai.Ptr[int32](-1),
}
- return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg)
- }, nil
+ }
+ return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg)
}
type Context struct {
Context context.Context
Workdir string
+ llmModel string
cache *Cache
cachedDirs []string
state map[string]any
@@ -164,7 +160,15 @@ type Context struct {
type stubContext struct {
timeNow func() time.Time
- generateContent generateContentFunc
+ generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) (
+ *genai.GenerateContentResponse, error)
+}
+
+func (ctx *Context) modelName(model string) string {
+ if ctx.llmModel != "" {
+ return ctx.llmModel
+ }
+ return model
}
func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) {
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
index b4cbe2201..6325b2fd2 100644
--- a/pkg/aflow/flow.go
+++ b/pkg/aflow/flow.go
@@ -22,9 +22,8 @@ import (
// Actions are nodes of the graph, and they consume/produce some named values
// (input/output fields, and intermediate values consumed by other actions).
type Flow struct {
- Name string // Empty for the main workflow for the workflow type.
- Model string // The default Gemini model name to execute this workflow.
- Root Action
+ Name string // Empty for the main workflow for the workflow type.
+ Root Action
*FlowType
}
@@ -36,12 +35,6 @@ type FlowType struct {
extractOutputs func(map[string]any) map[string]any
}
-// See https://ai.google.dev/gemini-api/docs/models
-const (
- BestExpensiveModel = "gemini-3-pro-preview"
- GoodBalancedModel = "gemini-3-flash-preview"
-)
-
var Flows = make(map[string]*Flow)
// Register a workflow type (characterized by Inputs and Outputs),
@@ -95,7 +88,6 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
actions: make(map[string]bool),
state: make(map[string]*varState),
}
- ctx.requireNotEmpty(flow.Name, "Model", flow.Model)
provideOutputs[Inputs](ctx, "flow inputs")
flow.Root.verify(ctx)
requireInputs[Outputs](ctx, "flow outputs")
diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go
index 67d695eb9..6bfc7bb12 100644
--- a/pkg/aflow/flow/assessment/kcsan.go
+++ b/pkg/aflow/flow/assessment/kcsan.go
@@ -23,7 +23,6 @@ func init() {
ai.WorkflowAssessmentKCSAN,
"assess if a KCSAN report is about a benign race that only needs annotations or not",
&aflow.Flow{
- Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
kernel.Checkout,
@@ -31,6 +30,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "expert",
+ Model: aflow.GoodBalancedModel,
Reply: "Explanation",
Outputs: aflow.LLMOutputs[struct {
Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."`
diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go
index 8d9ac4a0b..b13ee1e7d 100644
--- a/pkg/aflow/flow/assessment/moderation.go
+++ b/pkg/aflow/flow/assessment/moderation.go
@@ -33,7 +33,6 @@ func init() {
ai.WorkflowModeration,
"assess if a bug report is consistent and actionable or not",
&aflow.Flow{
- Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
aflow.NewFuncAction("extract-crash-type", extractCrashType),
@@ -42,6 +41,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "expert",
+ Model: aflow.GoodBalancedModel,
Reply: "Explanation",
Outputs: aflow.LLMOutputs[struct {
Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."`
diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go
index 766cf089f..856962e6c 100644
--- a/pkg/aflow/flow/patching/patching.go
+++ b/pkg/aflow/flow/patching/patching.go
@@ -43,7 +43,6 @@ func init() {
ai.WorkflowPatching,
"generate a kernel patch fixing a provided bug reproducer",
&aflow.Flow{
- Model: aflow.BestExpensiveModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
baseCommitPicker,
@@ -54,6 +53,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "debugger",
+ Model: aflow.BestExpensiveModel,
Reply: "BugExplanation",
Temperature: 1,
Instruction: debuggingInstruction,
@@ -62,6 +62,7 @@ func init() {
},
&aflow.LLMAgent{
Name: "diff-generator",
+ Model: aflow.BestExpensiveModel,
Reply: "PatchDiff",
Temperature: 1,
Instruction: diffInstruction,
@@ -70,6 +71,7 @@ func init() {
},
&aflow.LLMAgent{
Name: "description-generator",
+ Model: aflow.BestExpensiveModel,
Reply: "PatchDescription",
Temperature: 1,
Instruction: descriptionInstruction,
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 5c038d7c6..97e2dd93a 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Name: "flow",
- Model: "model",
+ Name: "flow",
Root: NewPipeline(
NewFuncAction("func-action",
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "smarty",
+ Model: "model1",
Reply: "OutFoo",
Outputs: LLMOutputs[agentOutputs](),
Temperature: 0,
@@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "swarm",
+ Model: "model2",
Reply: "OutSwarm",
Candidates: 2,
Outputs: LLMOutputs[swarmOutputs](),
@@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) {
},
&LLMAgent{
Name: "aggregator",
+ Model: "model3",
Reply: "OutAggregator",
Temperature: 0,
Instruction: "Aggregate!",
@@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) {
stubTime = stubTime.Add(time.Second)
return stubTime
},
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
replySeq++
if replySeq < 4 {
+ assert.Equal(t, model, "model1")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
@@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) {
assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
} else if replySeq < 8 {
+ assert.Equal(t, model, "model2")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 1)
assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results")
} else {
+ assert.Equal(t, model, "model3")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 0)
}
@@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz func-output",
@@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
},
{
@@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
Finished: startTime.Add(6 * time.Second),
Thoughts: "I am thinking I need to call some tools",
@@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
},
{
@@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
Finished: startTime.Add(12 * time.Second),
Thoughts: "Completly blank.Whatever.",
@@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
},
{
@@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
Finished: startTime.Add(16 * time.Second),
},
@@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Finished: startTime.Add(17 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
@@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
},
{
@@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
Finished: startTime.Add(23 * time.Second),
},
@@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
},
{
@@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
Finished: startTime.Add(27 * time.Second),
},
@@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Finished: startTime.Add(28 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
},
{
@@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
Finished: startTime.Add(31 * time.Second),
},
@@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
},
{
@@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
Finished: startTime.Add(35 * time.Second),
},
@@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Finished: startTime.Add(36 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Instruction: "Aggregate!",
Prompt: `Prompt: baz
@@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
},
{
@@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
Finished: startTime.Add(40 * time.Second),
},
@@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Finished: startTime.Add(41 * time.Second),
Instruction: "Aggregate!",
@@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) {
expected = expected[1:]
return nil
}
- res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.NoError(t, err)
require.Equal(t, replySeq, 8)
require.Equal(t, res, expectedOutputs)
@@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Model: "model",
Root: NewFuncAction("func-action",
func(ctx *Context, args flowInputs) (flowOutputs, error) {
return flowOutputs{}, nil
@@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) {
})
require.NoError(t, err)
stub := &stubContext{
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
return nil, nil
},
@@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) {
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
require.NoError(t, err)
onEvent := func(span *trajectory.Span) error { return nil }
- _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.Equal(t, err.Error(), "flow inputs are missing:"+
" field InBar is not present when converting map to aflow.flowInputs")
}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 02d3bca85..3c416b37c 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -18,6 +18,9 @@ import (
type LLMAgent struct {
// For logging/debugging.
Name string
+ // The default Gemini model name to execute this workflow.
+ // Use the consts defined below.
+ Model string
// Name of the state variable to store the final reply of the agent.
// These names can be used in subsequent action instructions/prompts,
// and as final workflow outputs.
@@ -43,6 +46,13 @@ type LLMAgent struct {
Tools []Tool
}
+// Consts to use for LLMAgent.Model.
+// See https://ai.google.dev/gemini-api/docs/models
+const (
+ BestExpensiveModel = "gemini-3-pro-preview"
+ GoodBalancedModel = "gemini-3-flash-preview"
+)
+
// Tool represents a custom tool an LLMAgent can invoke.
// Use NewFuncTool to create function-based tools.
type Tool interface {
@@ -134,6 +144,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
Name: a.Name,
Instruction: instruction,
Prompt: formatTemplate(a.Prompt, ctx.state),
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(span); err != nil {
return "", nil, err
@@ -152,8 +163,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
for {
reqSpan := &trajectory.Span{
- Type: trajectory.SpanLLM,
- Name: a.Name,
+ Type: trajectory.SpanLLM,
+ Name: a.Name,
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(reqSpan); err != nil {
return "", nil, err
@@ -278,7 +290,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
req []*genai.Content) (*genai.GenerateContentResponse, error) {
backoff := time.Second
for try := 0; ; try++ {
- resp, err := ctx.generateContent(cfg, req)
+ resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
@@ -292,6 +304,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
func (a *LLMAgent) verify(vctx *verifyContext) {
vctx.requireNotEmpty(a.Name, "Name", a.Name)
+ vctx.requireNotEmpty(a.Name, "Model", a.Model)
vctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
a.Temperature = float32(temp)
diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go
index 49e36933b..aa9558708 100644
--- a/pkg/aflow/trajectory/trajectory.go
+++ b/pkg/aflow/trajectory/trajectory.go
@@ -20,6 +20,7 @@ type Span struct {
Nesting int
Type SpanType
Name string // flow/action/tool name
+ Model string // LLM model name for agent/LLM spans
Started time.Time
Finished time.Time
Error string // relevant if Finished is set
diff --git a/syz-agent/agent.go b/syz-agent/agent.go
index 54d6a67c6..d070144db 100644
--- a/syz-agent/agent.go
+++ b/syz-agent/agent.go
@@ -171,14 +171,9 @@ func (s *Server) poll(ctx context.Context) (
CodeRevision: prog.GitRevision,
}
for _, flow := range aflow.Flows {
- model := flow.Model
- if s.cfg.Model != "" {
- model = s.cfg.Model
- }
req.Workflows = append(req.Workflows, dashapi.AIWorkflow{
- Type: flow.Type,
- Name: flow.Name,
- LLMModel: model,
+ Type: flow.Type,
+ Name: flow.Name,
})
}
resp, err := s.dash.AIJobPoll(req)
@@ -210,10 +205,6 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma
if flow == nil {
return nil, fmt.Errorf("unsupported flow %q", req.Workflow)
}
- model := flow.Model
- if s.cfg.Model != "" {
- model = s.cfg.Model
- }
inputs := map[string]any{
"Syzkaller": osutil.Abs(filepath.FromSlash("syzkaller/current")),
"CodesearchToolBin": s.cfg.CodesearchToolBin,
@@ -230,5 +221,5 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma
Span: span,
})
}
- return flow.Execute(ctx, model, s.workdir, inputs, s.cache, onEvent)
+ return flow.Execute(ctx, s.cfg.Model, s.workdir, inputs, s.cache, onEvent)
}
diff --git a/tools/syz-aflow/aflow.go b/tools/syz-aflow/aflow.go
index d915c2061..160d4541c 100644
--- a/tools/syz-aflow/aflow.go
+++ b/tools/syz-aflow/aflow.go
@@ -33,7 +33,7 @@ func main() {
flagFlow = flag.String("workflow", "", "workflow to execute")
flagInput = flag.String("input", "", "input json file with workflow arguments")
flagWorkdir = flag.String("workdir", "", "directory for kernel checkout, kernel builds, etc")
- flagModel = flag.String("model", "", "use this LLM model, if empty use the workflow default model")
+ flagModel = flag.String("model", "", "use this LLM model, if empty use default models")
flagCacheSize = flag.String("cache-size", "10GB", "max cache size (e.g. 100MB, 5GB, 1TB)")
flagDownloadBug = flag.String("download-bug", "", "extid of a bug to download from the dashboard"+
" and save into -input file")
@@ -78,9 +78,6 @@ func run(ctx context.Context, model, flowName, inputFile, workdir string, cacheS
if flow == nil {
return fmt.Errorf("workflow %q is not found", flowName)
}
- if model == "" {
- model = flow.Model
- }
inputData, err := os.ReadFile(inputFile)
if err != nil {
return fmt.Errorf("failed to open -input file: %w", err)