diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-14 11:35:20 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-14 11:07:16 +0000 |
| commit | a9d6a79219801d2130df3b1a792c57f0e5428e9f (patch) | |
| tree | bc900771cf25374ed86011f4c0a85e7eb4647d2e | |
| parent | 1b03c2cc6e672ed19398ca4a9ce22da45299e68a (diff) | |
pkg/aflow: allow specifying the model per-flow
We may want to use a weaker model for some workflows.
Allow using a different model for each workflow.
| -rw-r--r-- | dashboard/app/ai.go | 7 | ||||
| -rw-r--r-- | dashboard/app/ai_test.go | 36 | ||||
| -rw-r--r-- | dashboard/app/aidb/crud.go | 10 | ||||
| -rw-r--r-- | dashboard/dashapi/ai.go | 6 | ||||
| -rw-r--r-- | pkg/aflow/execute.go | 8 | ||||
| -rw-r--r-- | pkg/aflow/flow.go | 12 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/kcsan.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/moderation.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow/patching/patching.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 4 | ||||
| -rw-r--r-- | syz-agent/agent.go | 19 | ||||
| -rw-r--r-- | tools/syz-aflow/aflow.go | 49 |
12 files changed, 90 insertions, 64 deletions
diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go index 29bfc38d1..2912addea 100644 --- a/dashboard/app/ai.go +++ b/dashboard/app/ai.go @@ -207,9 +207,14 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan } func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) { - if len(req.Workflows) == 0 || req.CodeRevision == "" || req.LLMModel == "" { + if len(req.Workflows) == 0 || req.CodeRevision == "" { return nil, fmt.Errorf("invalid request") } + for _, flow := range req.Workflows { + if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" { + return nil, fmt.Errorf("invalid request") + } + } if err := aidb.UpdateWorkflows(ctx, req.Workflows); err != nil { return nil, fmt.Errorf("failed UpdateWorkflows: %w", err) } diff --git a/dashboard/app/ai_test.go b/dashboard/app/ai_test.go index bd3935b18..b775b2a89 100644 --- a/dashboard/app/ai_test.go +++ b/dashboard/app/ai_test.go @@ -63,11 +63,10 @@ func TestAIBugWorkflows(t *testing.T) { _, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching"}, - {Type: "patching", Name: "patching-foo"}, - {Type: "patching", Name: "patching-bar"}, + {Type: "patching", Name: "patching", LLMModel: "smarty"}, + {Type: "patching", Name: "patching-foo", LLMModel: "smarty"}, + {Type: "patching", Name: "patching-bar", LLMModel: "smarty"}, }, }) require.NoError(t, err) @@ -77,25 +76,23 @@ func TestAIBugWorkflows(t *testing.T) { _, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching"}, - {Type: "patching", Name: "patching-bar"}, - {Type: "patching", Name: "patching-baz"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan"}, + {Type: "patching", Name: "patching", LLMModel: "smarty"}, + {Type: "patching", Name: "patching-bar", LLMModel: 
"smarty"}, + {Type: "patching", Name: "patching-baz", LLMModel: "smarty"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, }, }) require.NoError(t, err) _, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: "patching", Name: "patching"}, - {Type: "patching", Name: "patching-bar"}, - {Type: "patching", Name: "patching-qux"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan"}, - {Type: "assessment-kcsan", Name: "assessment-kcsan-foo"}, + {Type: "patching", Name: "patching", LLMModel: "smarty"}, + {Type: "patching", Name: "patching-bar", LLMModel: "smarty"}, + {Type: "patching", Name: "patching-qux", LLMModel: "smarty"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"}, }, }) require.NoError(t, err) @@ -117,9 +114,8 @@ func TestAIJob(t *testing.T) { resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: "assessment-kcsan", Name: "assessment-kcsan"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, }, }) require.NoError(t, err) @@ -137,9 +133,8 @@ func TestAIJob(t *testing.T) { resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: "assessment-kcsan", Name: "assessment-kcsan"}, + {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"}, }, }) require.NoError(t, err2) @@ -214,9 +209,8 @@ func TestAIAssessmentKCSAN(t *testing.T) { resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{ CodeRevision: prog.GitRevision, - LLMModel: "smarty", Workflows: []dashapi.AIWorkflow{ - {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)}, + {Type: ai.WorkflowAssessmentKCSAN, Name: 
string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"}, }, }) require.NoError(t, err) diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go index a01c370b6..4f7e93f6a 100644 --- a/dashboard/app/aidb/crud.go +++ b/dashboard/app/aidb/crud.go @@ -128,11 +128,13 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) { } job = jobs[0] } - job.Started = spanner.NullTime{ - Time: TimeNow(ctx), - Valid: true, + job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true} + for _, flow := range req.Workflows { + if job.Workflow == flow.Name { + job.LLMModel = flow.LLMModel + break + } } - job.LLMModel = req.LLMModel job.CodeRevision = req.CodeRevision mut, err := spanner.InsertOrUpdateStruct("Jobs", job) if err != nil { diff --git a/dashboard/dashapi/ai.go b/dashboard/dashapi/ai.go index 893ed7215..8134e5744 100644 --- a/dashboard/dashapi/ai.go +++ b/dashboard/dashapi/ai.go @@ -9,14 +9,14 @@ import ( ) type AIJobPollReq struct { - LLMModel string // LLM model that will be used to execute jobs CodeRevision string // git commit of the syz-agent server Workflows []AIWorkflow } type AIWorkflow struct { - Type ai.WorkflowType - Name string + Type ai.WorkflowType + Name string + LLMModel string // LLM model that will be used to execute this workflow } type AIJobPollResp struct { diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 6e724988e..4133c01c9 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -18,9 +18,11 @@ import ( "google.golang.org/genai" ) -// https://ai.google.dev/gemini-api/docs/models -const DefaultModel = "gemini-3-pro-preview" - +// Execute executes the given AI workflow with provided inputs and returns workflow outputs. +// The model argument sets Gemini model name to execute the workflow. 
+// The workdir argument should point to a dir owned by aflow to store private data, +// it can be shared across parallel executions in the same process, and preferably +// preserved across process restarts for caching purposes. func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any, cache *Cache, onEvent onEvent) (map[string]any, error) { if err := flow.checkInputs(inputs); err != nil { diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go index 6325b2fd2..b4cbe2201 100644 --- a/pkg/aflow/flow.go +++ b/pkg/aflow/flow.go @@ -22,8 +22,9 @@ import ( // Actions are nodes of the graph, and they consume/produce some named values // (input/output fields, and intermediate values consumed by other actions). type Flow struct { - Name string // Empty for the main workflow for the workflow type. - Root Action + Name string // Empty for the main workflow for the workflow type. + Model string // The default Gemini model name to execute this workflow. + Root Action *FlowType } @@ -35,6 +36,12 @@ type FlowType struct { extractOutputs func(map[string]any) map[string]any } +// See https://ai.google.dev/gemini-api/docs/models +const ( + BestExpensiveModel = "gemini-3-pro-preview" + GoodBalancedModel = "gemini-3-flash-preview" +) + var Flows = make(map[string]*Flow) // Register a workflow type (characterized by Inputs and Outputs), @@ -88,6 +95,7 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error { actions: make(map[string]bool), state: make(map[string]*varState), } + ctx.requireNotEmpty(flow.Name, "Model", flow.Model) provideOutputs[Inputs](ctx, "flow inputs") flow.Root.verify(ctx) requireInputs[Outputs](ctx, "flow outputs") diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go index 92033aff7..67d695eb9 100644 --- a/pkg/aflow/flow/assessment/kcsan.go +++ b/pkg/aflow/flow/assessment/kcsan.go @@ -23,6 +23,7 @@ func init() { ai.WorkflowAssessmentKCSAN, "assess if a KCSAN report is about a 
benign race that only needs annotations or not", &aflow.Flow{ + Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ kernel.Checkout, diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go index 4b78f901c..8d9ac4a0b 100644 --- a/pkg/aflow/flow/assessment/moderation.go +++ b/pkg/aflow/flow/assessment/moderation.go @@ -33,6 +33,7 @@ func init() { ai.WorkflowModeration, "assess if a bug report is consistent and actionable or not", &aflow.Flow{ + Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ aflow.NewFuncAction("extract-crash-type", extractCrashType), diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go index ef10f1a2e..766cf089f 100644 --- a/pkg/aflow/flow/patching/patching.go +++ b/pkg/aflow/flow/patching/patching.go @@ -43,6 +43,7 @@ func init() { ai.WorkflowPatching, "generate a kernel patch fixing a provided bug reproducer", &aflow.Flow{ + Model: aflow.BestExpensiveModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ baseCommitPicker, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 8ab8016f3..19abce6b1 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -76,7 +76,8 @@ func TestWorkflow(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Name: "flow", + Name: "flow", + Model: "model", Root: NewPipeline( NewFuncAction("func-action", func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { @@ -530,6 +531,7 @@ func TestNoInputs(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { + Model: "model", Root: NewFuncAction("func-action", func(ctx *Context, args flowInputs) (flowOutputs, error) { return flowOutputs{}, nil diff --git a/syz-agent/agent.go b/syz-agent/agent.go index 1d2f96509..c5aad2470 100644 --- a/syz-agent/agent.go +++ 
b/syz-agent/agent.go @@ -47,7 +47,7 @@ type Config struct { CacheSize uint64 `json:"cache_size"` // Use fixed base commit for patching jobs (for testing). FixedBaseCommit string `json:"fixed_base_commit"` - // Use this LLM model (for testing, if empty use a default model). + // Use this LLM model (for testing, if empty use workflow-default model). Model string `json:"model"` } @@ -70,7 +70,6 @@ func run(configFile string, exitOnUpgrade, autoUpdate bool) error { SyzkallerRepo: "https://github.com/google/syzkaller.git", SyzkallerBranch: "master", CacheSize: 1 << 40, // 1TB should be enough for everyone! - Model: aflow.DefaultModel, } if err := config.LoadFile(configFile, cfg); err != nil { return fmt.Errorf("failed to load config: %w", err) @@ -169,13 +168,17 @@ type Server struct { func (s *Server) poll(ctx context.Context) ( bool, error) { req := &dashapi.AIJobPollReq{ - LLMModel: s.cfg.Model, CodeRevision: prog.GitRevision, } for _, flow := range aflow.Flows { + model := flow.Model + if s.cfg.Model != "" { + model = s.cfg.Model + } req.Workflows = append(req.Workflows, dashapi.AIWorkflow{ - Type: flow.Type, - Name: flow.Name, + Type: flow.Type, + Name: flow.Name, + LLMModel: model, }) } resp, err := s.dash.AIJobPoll(req) @@ -207,6 +210,10 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma if flow == nil { return nil, fmt.Errorf("unsupported flow %q", req.Workflow) } + model := flow.Model + if s.cfg.Model != "" { + model = s.cfg.Model + } inputs := map[string]any{ "Syzkaller": osutil.Abs(filepath.FromSlash("syzkaller/current")), "CodesearchToolBin": s.cfg.CodesearchToolBin, @@ -223,5 +230,5 @@ func (s *Server) executeJob(ctx context.Context, req *dashapi.AIJobPollResp) (ma Span: span, }) } - return flow.Execute(ctx, s.cfg.Model, s.workdir, inputs, s.cache, onEvent) + return flow.Execute(ctx, model, s.workdir, inputs, s.cache, onEvent) } diff --git a/tools/syz-aflow/aflow.go b/tools/syz-aflow/aflow.go index b9b59e90e..d915c2061 
100644 --- a/tools/syz-aflow/aflow.go +++ b/tools/syz-aflow/aflow.go @@ -33,7 +33,7 @@ func main() { flagFlow = flag.String("workflow", "", "workflow to execute") flagInput = flag.String("input", "", "input json file with workflow arguments") flagWorkdir = flag.String("workdir", "", "directory for kernel checkout, kernel builds, etc") - flagModel = flag.String("model", aflow.DefaultModel, "use this LLM model") + flagModel = flag.String("model", "", "use this LLM model, if empty use the workflow default model") flagCacheSize = flag.String("cache-size", "10GB", "max cache size (e.g. 100MB, 5GB, 1TB)") flagDownloadBug = flag.String("download-bug", "", "extid of a bug to download from the dashboard"+ " and save into -input file") @@ -73,33 +73,14 @@ func main() { } } -func parseSize(s string) (uint64, error) { - var size uint64 - var suffix string - if _, err := fmt.Sscanf(s, "%d%s", &size, &suffix); err != nil { - return 0, fmt.Errorf("failed to parse cache size %q: %w", s, err) - } - switch suffix { - case "KB": - size <<= 10 - case "MB": - size <<= 20 - case "GB": - size <<= 30 - case "TB": - size <<= 40 - case "": - default: - return 0, fmt.Errorf("unknown size suffix %q", suffix) - } - return size, nil -} - func run(ctx context.Context, model, flowName, inputFile, workdir string, cacheSize uint64) error { flow := aflow.Flows[flowName] if flow == nil { return fmt.Errorf("workflow %q is not found", flowName) } + if model == "" { + model = flow.Model + } inputData, err := os.ReadFile(inputFile) if err != nil { return fmt.Errorf("failed to open -input file: %w", err) @@ -194,3 +175,25 @@ func getAccessToken() (string, error) { return token.AccessToken, nil } + +func parseSize(s string) (uint64, error) { + var size uint64 + var suffix string + if _, err := fmt.Sscanf(s, "%d%s", &size, &suffix); err != nil { + return 0, fmt.Errorf("failed to parse cache size %q: %w", s, err) + } + switch suffix { + case "KB": + size <<= 10 + case "MB": + size <<= 20 + case "GB": + size 
<<= 30 + case "TB": + size <<= 40 + case "": + default: + return 0, fmt.Errorf("unknown size suffix %q", suffix) + } + return size, nil +} |
