aboutsummaryrefslogtreecommitdiffstats
path: root/dashboard/app
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-14 11:35:20 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-14 11:07:16 +0000
commita9d6a79219801d2130df3b1a792c57f0e5428e9f (patch)
treebc900771cf25374ed86011f4c0a85e7eb4647d2e /dashboard/app
parent1b03c2cc6e672ed19398ca4a9ce22da45299e68a (diff)
pkg/aflow: allow to specify model per-flow
We may want to use a weaker model for some workflows. Allow to use different models for different workflows.
Diffstat (limited to 'dashboard/app')
-rw-r--r--dashboard/app/ai.go7
-rw-r--r--dashboard/app/ai_test.go36
-rw-r--r--dashboard/app/aidb/crud.go10
3 files changed, 27 insertions, 26 deletions
diff --git a/dashboard/app/ai.go b/dashboard/app/ai.go
index 29bfc38d1..2912addea 100644
--- a/dashboard/app/ai.go
+++ b/dashboard/app/ai.go
@@ -207,9 +207,14 @@ func makeUIAITrajectory(trajetory []*aidb.TrajectorySpan) []*uiAITrajectorySpan
}
func apiAIJobPoll(ctx context.Context, req *dashapi.AIJobPollReq) (any, error) {
- if len(req.Workflows) == 0 || req.CodeRevision == "" || req.LLMModel == "" {
+ if len(req.Workflows) == 0 || req.CodeRevision == "" {
return nil, fmt.Errorf("invalid request")
}
+ for _, flow := range req.Workflows {
+ if flow.Type == "" || flow.Name == "" || flow.LLMModel == "" {
+ return nil, fmt.Errorf("invalid request")
+ }
+ }
if err := aidb.UpdateWorkflows(ctx, req.Workflows); err != nil {
return nil, fmt.Errorf("failed UpdateWorkflows: %w", err)
}
diff --git a/dashboard/app/ai_test.go b/dashboard/app/ai_test.go
index bd3935b18..b775b2a89 100644
--- a/dashboard/app/ai_test.go
+++ b/dashboard/app/ai_test.go
@@ -63,11 +63,10 @@ func TestAIBugWorkflows(t *testing.T) {
_, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-foo"},
- {Type: "patching", Name: "patching-bar"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-foo", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -77,25 +76,23 @@ func TestAIBugWorkflows(t *testing.T) {
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-bar"},
- {Type: "patching", Name: "patching-baz"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-baz", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err)
_, err = c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "patching", Name: "patching"},
- {Type: "patching", Name: "patching-bar"},
- {Type: "patching", Name: "patching-qux"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
- {Type: "assessment-kcsan", Name: "assessment-kcsan-foo"},
+ {Type: "patching", Name: "patching", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-bar", LLMModel: "smarty"},
+ {Type: "patching", Name: "patching-qux", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan-foo", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -117,9 +114,8 @@ func TestAIJob(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err)
@@ -137,9 +133,8 @@ func TestAIJob(t *testing.T) {
resp2, err2 := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: "assessment-kcsan", Name: "assessment-kcsan"},
+ {Type: "assessment-kcsan", Name: "assessment-kcsan", LLMModel: "smarty"},
},
})
require.NoError(t, err2)
@@ -214,9 +209,8 @@ func TestAIAssessmentKCSAN(t *testing.T) {
resp, err := c.aiClient.AIJobPoll(&dashapi.AIJobPollReq{
CodeRevision: prog.GitRevision,
- LLMModel: "smarty",
Workflows: []dashapi.AIWorkflow{
- {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN)},
+ {Type: ai.WorkflowAssessmentKCSAN, Name: string(ai.WorkflowAssessmentKCSAN), LLMModel: "smarty"},
},
})
require.NoError(t, err)
diff --git a/dashboard/app/aidb/crud.go b/dashboard/app/aidb/crud.go
index a01c370b6..4f7e93f6a 100644
--- a/dashboard/app/aidb/crud.go
+++ b/dashboard/app/aidb/crud.go
@@ -128,11 +128,13 @@ func StartJob(ctx context.Context, req *dashapi.AIJobPollReq) (*Job, error) {
}
job = jobs[0]
}
- job.Started = spanner.NullTime{
- Time: TimeNow(ctx),
- Valid: true,
+ job.Started = spanner.NullTime{Time: TimeNow(ctx), Valid: true}
+ for _, flow := range req.Workflows {
+ if job.Workflow == flow.Name {
+ job.LLMModel = flow.LLMModel
+ break
+ }
}
- job.LLMModel = req.LLMModel
job.CodeRevision = req.CodeRevision
mut, err := spanner.InsertOrUpdateStruct("Jobs", job)
if err != nil {