aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Dmitry Vyukov <dvyukov@google.com> 2026-01-14 11:35:20 +0100
committer Dmitry Vyukov <dvyukov@google.com> 2026-01-14 11:07:16 +0000
commit a9d6a79219801d2130df3b1a792c57f0e5428e9f (patch)
tree bc900771cf25374ed86011f4c0a85e7eb4647d2e
parent 1b03c2cc6e672ed19398ca4a9ce22da45299e68a (diff)
pkg/aflow: allow to specify model per-flow
We may want to use a weaker model for some workflows. Allow different workflows to use different models.
-rw-r--r-- dashboard/app/ai.go 7
-rw-r--r-- dashboard/app/ai_test.go 36
-rw-r--r-- dashboard/app/aidb/crud.go 10
-rw-r--r-- dashboard/dashapi/ai.go 6
-rw-r--r-- pkg/aflow/execute.go 8
-rw-r--r-- pkg/aflow/flow.go 12
-rw-r--r-- pkg/aflow/flow/assessment/kcsan.go 1
-rw-r--r-- pkg/aflow/flow/assessment/moderation.go 1
-rw-r--r-- pkg/aflow/flow/patching/patching.go 1
-rw-r--r-- pkg/aflow/flow_test.go 4
-rw-r--r-- syz-agent/agent.go 19
-rw-r--r-- tools/syz-aflow/aflow.go 49
12 files changed, 90 insertions, 64 deletions
diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go
index 29bfc38d1..2912addea 100644
--- a/dashboard/app/ai.go
+++ b/dashboard/app/ai.go
@@ -207,9 +207,14 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
}
func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) {
- if len(req.Workflows) == 0 || req.CodeRevision == "" || req.LLMModel == "" {
+ if len(req.Workflows) == 0 || req.CodeRevision == "" {
return nil, fmt.Errorf("invalid request")
}
+ for _, flow := range req.Workflows {
+ if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" {
+ return nil, fmt.Errorf("invalid request")
+ }
+ }
if err := aidb.UpdateWorkflows(ctx, req.Workflows); err != nil {
return nil, fmt.Errorf("failed UpdateWorkflows: %w", err)
}
diff --git a/dashboard/app/ai_test.go b/dashboard/app/ai_test.go
index bd3935b18..b775b2a89 100644
--- a/dashboard/app/ai_test.go
+++ b/dashboard/app/ai_test.go
@@ -63,11 +63,10 @@ func TestAIBugWorkflows(t *testing.T) {
_, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-foo"},
- {Type: "patching", Name: "patching-bar"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-foo", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -77,25 +76,23 @@ func TestAIBugWorkflows(t *testing.T) {
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-bar"},
- {Type: "patching", Name: "patching-baz"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-baz", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err)
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-bar"},
- {Type: "patching", Name: "patching-qux"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan-foo"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-qux", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -117,9 +114,8 @@ func TestAIJob(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -137,9 +133,8 @@ func TestAIJob(t *testing.T) {
resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err2)
@@ -214,9 +209,8 @@ func TestAIAssessmentKCSAN(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)},
+ {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"},
},
})
require.NoError(t, err)
diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go
index a01c370b6..4f7e93f6a 100644
--- a/dashboard/app/aidb/crud.go
+++ b/dashboard/app/aidb/crud.go
@@ -128,11 +128,13 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
}
job = jobs[0]
}
- job.Started = spanner.NullTime{
- Time: TimeNow(ctx),
- Valid: true,
+ job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true}
+ for _, flow := range req.Workflows {
+ if job.Workflow == flow.Name {
+ job.LLMModel = flow.LLMModel
+ break
+ }
}
- job.LLMModel = req.LLMModel
job.CodeRevision = req.CodeRevision
mut, err := spanner.InsertOrUpdateStruct("Jobs", job)
if err != nil {
diff --git a/dashboard/dashapi/ai.go b/dashboard/dashapi/ai.go
index 893ed7215..8134e5744 100644
--- a/dashboard/dashapi/ai.go
+++ b/dashboard/dashapi/ai.go
@@ -9,14 +9,14 @@ import (
)
type AIJobPollReq struct {
- LLMModel string // LLM model that will be used to execute jobs
CodeRevision string // git commit of the syz-agent server
Workflows []AIWorkflow
}
type AIWorkflow struct {
- Type ai.WorkflowType
- Name string
+ Type ai.WorkflowType
+ Name string
+ LLMModel string // LLM model that will be used to execute this workflow
}
type AIJobPollResp struct {
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 6e724988e..4133c01c9 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -18,9 +18,11 @@ import (
"google.golang.org/genai"
)
-// https://ai.google.dev/gemini-api/docs/models
-const DefaultModel = "gemini-3-pro-preview"
-
+// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
+// The model argument sets Gemini model name to execute the workflow.
+// The workdir argument should point to a dir owned by aflow to store private data,
+// it can be shared across parallel executions in the same process, and preferably
+// preserved across process restarts for caching purposes.
func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any,
cache *Cache, onEvent onEvent) (map[string]any, error) {
if err := flow.checkInputs(inputs); err != nil {
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
index 6325b2fd2..b4cbe2201 100644
--- a/pkg/aflow/flow.go
+++ b/pkg/aflow/flow.go
@@ -22,8 +22,9 @@ import (
// Actions are nodes of the graph, and they consume/produce some named values
// (input/output fields, and intermediate values consumed by other actions).
type Flow struct {
- Name string // Empty for the main workflow for the workflow type.
- Root Action
+ Name string // Empty for the main workflow for the workflow type.
+ Model string // The default Gemini model name to execute this workflow.
+ Root Action
*FlowType
}
@@ -35,6 +36,12 @@ type FlowType struct {
extractOutputs func(map[string]any) map[string]any
}
+// See https://ai.google.dev/gemini-api/docs/models
+const (
+ BestExpensiveModel = "gemini-3-pro-preview"
+ GoodBalancedModel = "gemini-3-flash-preview"
+)
+
var Flows = make(map[string]*Flow)
// Register a workflow type (characterized by Inputs and Outputs),
@@ -88,6 +95,7 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
actions: make(map[string]bool),
state: make(map[string]*varState),
}
+ ctx.requireNotEmpty(flow.Name, "Model", flow.Model)
provideOutputs[Inputs](ctx, "flow inputs")
flow.Root.verify(ctx)
requireInputs[Outputs](ctx, "flow outputs")
diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go
index 92033aff7..67d695eb9 100644
--- a/pkg/aflow/flow/assessment/kcsan.go
+++ b/pkg/aflow/flow/assessment/kcsan.go
@@ -23,6 +23,7 @@ func init() {
ai.WorkflowAssessmentKCSAN,
"assess if a KCSAN report is about a benign race that only needs annotations or not",
&aflow.Flow{
+ Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
kernel.Checkout,
diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go
index 4b78f901c..8d9ac4a0b 100644
--- a/pkg/aflow/flow/assessment/moderation.go
+++ b/pkg/aflow/flow/assessment/moderation.go
@@ -33,6 +33,7 @@ func init() {
ai.WorkflowModeration,
"assess if a bug report is consistent and actionable or not",
&aflow.Flow{
+ Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
aflow.NewFuncAction("extract-crash-type", extractCrashType),
diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go
index ef10f1a2e..766cf089f 100644
--- a/pkg/aflow/flow/patching/patching.go
+++ b/pkg/aflow/flow/patching/patching.go
@@ -43,6 +43,7 @@ func init() {
ai.WorkflowPatching,
"generate a kernel patch fixing a provided bug reproducer",
&aflow.Flow{
+ Model: aflow.BestExpensiveModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
baseCommitPicker,
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 8ab8016f3..19abce6b1 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -76,7 +76,8 @@ func TestWorkflow(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Name: "flow",
+ Name: "flow",
+ Model: "model",
Root: NewPipeline(
NewFuncAction("func-action",
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -530,6 +531,7 @@ func TestNoInputs(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
+ Model: "model",
Root: NewFuncAction("func-action",
func(ctx *Context, args flowInputs) (flowOutputs, error) {
return flowOutputs{}, nil
diff --git a/syz-agent/agent.go b/syz-agent/agent.go
index 1d2f96509..c5aad2470 100644
--- a/syz-agent/agent.go
+++ b/syz-agent/agent.go
@@ -47,7 +47,7 @@ type Config struct {
CacheSize uint64 `json:"cache_size"`
// Use fixed base commit for patching jobs (for testing).
FixedBaseCommit string `json:"fixed_base_commit"`
- // Use this LLM model (for testing, if empty use a default model).
+ // Use this LLM model (for testing, if empty use workflow-default model).
Model string `json:"model"`
}
@@ -70,7 +70,6 @@ func run(configFile string, exitOnUpgrade, autoUpdate bool) error {
SyzkallerRepo: "https://github.com/google/syzkaller.git",
SyzkallerBranch: "master",
CacheSize: 1 << 40, // 1TB should be enough for everyone!
- Model: aflow.DefaultModel,
}
if err := config.LoadFile(configFile, cfg); err != nil {
return fmt.Errorf("failed to load config: %w", err)
@@ -169,13 +168,17 @@ type Server struct {
func (s *Server) poll(ctx context.Context) (
bool, error) {
req := &dashapi.AIJobPollReq{
- LLMModel: s.cfg.Model,
CodeRevision: prog.GitRevision,
}
for _, flow := range aflow.Flows {
+ model := flow.Model
+ if s.cfg.Model != "" {
+ model = s.cfg.Model
+ }
req.Workflows = append(req.Workflows, dashapi.AIWorkflow{
- Type: flow.Type,
- Name: flow.Name,
+ Type: flow.Type,
+ Name: flow.Name,
+ LLMModel: model,
})
}
resp, err := s.dash.AIJobPoll(req)
@@ -207,6 +210,10 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma
if flow == nil {
return nil, fmt.Errorf("unsupported flow %q", req.Workflow)
}
+ model := flow.Model
+ if s.cfg.Model != "" {
+ model = s.cfg.Model
+ }
inputs := map[string]any{
"Syzkaller": osutil.Abs(filepath.FromSlash("syzkaller/current")),
"CodesearchToolBin": s.cfg.CodesearchToolBin,
@@ -223,5 +230,5 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma
Span: span,
})
}
- return flow.Execute(ctx, s.cfg.Model, s.workdir, inputs, s.cache, onEvent)
+ return flow.Execute(ctx, model, s.workdir, inputs, s.cache, onEvent)
}
diff --git a/tools/syz-aflow/aflow.go b/tools/syz-aflow/aflow.go
index b9b59e90e..d915c2061 100644
--- a/tools/syz-aflow/aflow.go
+++ b/tools/syz-aflow/aflow.go
@@ -33,7 +33,7 @@ func main() {
flagFlow = flag.String("workflow", "", "workflow to execute")
flagInput = flag.String("input", "", "input json file with workflow arguments")
flagWorkdir = flag.String("workdir", "", "directory for kernel checkout, kernel builds, etc")
- flagModel = flag.String("model", aflow.DefaultModel, "use this LLM model")
+ flagModel = flag.String("model", "", "use this LLM model, if empty use the workflow default model")
flagCacheSize = flag.String("cache-size", "10GB", "max cache size (e.g. 100MB, 5GB, 1TB)")
flagDownloadBug = flag.String("download-bug", "", "extid of a bug to download from the dashboard"+
" and save into -input file")
@@ -73,33 +73,14 @@ func main() {
}
}
-func parseSize(s string) (uint64, error) {
- var size uint64
- var suffix string
- if _, err := fmt.Sscanf(s, "%d%s", &size, &suffix); err != nil {
- return 0, fmt.Errorf("failed to parse cache size %q: %w", s, err)
- }
- switch suffix {
- case "KB":
- size <<= 10
- case "MB":
- size <<= 20
- case "GB":
- size <<= 30
- case "TB":
- size <<= 40
- case "":
- default:
- return 0, fmt.Errorf("unknown size suffix %q", suffix)
- }
- return size, nil
-}
-
func run(ctx context.Context, model, flowName, inputFile, workdir string, cacheSize uint64) error {
flow := aflow.Flows[flowName]
if flow == nil {
return fmt.Errorf("workflow %q is not found", flowName)
}
+ if model == "" {
+ model = flow.Model
+ }
inputData, err := os.ReadFile(inputFile)
if err != nil {
return fmt.Errorf("failed to open -input file: %w", err)
@@ -194,3 +175,25 @@ func getAccessToken() (string, error) {
return token.AccessToken, nil
}
+
+func parseSize(s string) (uint64, error) {
+ var size uint64
+ var suffix string
+ if _, err := fmt.Sscanf(s, "%d%s", &size, &suffix); err != nil {
+ return 0, fmt.Errorf("failed to parse cache size %q: %w", s, err)
+ }
+ switch suffix {
+ case "KB":
+ size <<= 10
+ case "MB":
+ size <<= 20
+ case "GB":
+ size <<= 30
+ case "TB":
+ size <<= 40
+ case "":
+ default:
+ return 0, fmt.Errorf("unknown size suffix %q", suffix)
+ }
+ return size, nil
+}