diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/execute.go | 8 | ||||
| -rw-r--r-- | pkg/aflow/flow.go | 12 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/kcsan.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/moderation.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow/patching/patching.go | 1 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 4 |
6 files changed, 21 insertions, 6 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 6e724988e..4133c01c9 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -18,9 +18,11 @@ import ( "google.golang.org/genai" ) -// https://ai.google.dev/gemini-api/docs/models -const DefaultModel = "gemini-3-pro-preview" - +// Execute executes the given AI workflow with provided inputs and returns workflow outputs. +// The model argument sets Gemini model name to execute the workflow. +// The workdir argument should point to a dir owned by aflow to store private data, +// it can be shared across parallel executions in the same process, and preferably +// preserved across process restarts for caching purposes. func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any, cache *Cache, onEvent onEvent) (map[string]any, error) { if err := flow.checkInputs(inputs); err != nil { diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go index 6325b2fd2..b4cbe2201 100644 --- a/pkg/aflow/flow.go +++ b/pkg/aflow/flow.go @@ -22,8 +22,9 @@ import ( // Actions are nodes of the graph, and they consume/produce some named values // (input/output fields, and intermediate values consumed by other actions). type Flow struct { - Name string // Empty for the main workflow for the workflow type. - Root Action + Name string // Empty for the main workflow for the workflow type. + Model string // The default Gemini model name to execute this workflow. + Root Action *FlowType } @@ -35,6 +36,12 @@ type FlowType struct { extractOutputs func(map[string]any) map[string]any } +// See https://ai.google.dev/gemini-api/docs/models +const ( + BestExpensiveModel = "gemini-3-pro-preview" + GoodBalancedModel = "gemini-3-flash-preview" +) + var Flows = make(map[string]*Flow) // Register a workflow type (characterized by Inputs and Outputs), @@ -88,6 +95,7 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error { actions: make(map[string]bool), state: make(map[string]*varState), } + ctx.requireNotEmpty(flow.Name, "Model", flow.Model) provideOutputs[Inputs](ctx, "flow inputs") flow.Root.verify(ctx) requireInputs[Outputs](ctx, "flow outputs") diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go index 92033aff7..67d695eb9 100644 --- a/pkg/aflow/flow/assessment/kcsan.go +++ b/pkg/aflow/flow/assessment/kcsan.go @@ -23,6 +23,7 @@ func init() { ai.WorkflowAssessmentKCSAN, "assess if a KCSAN report is about a benign race that only needs annotations or not", &aflow.Flow{ + Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ kernel.Checkout, diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go index 4b78f901c..8d9ac4a0b 100644 --- a/pkg/aflow/flow/assessment/moderation.go +++ b/pkg/aflow/flow/assessment/moderation.go @@ -33,6 +33,7 @@ func init() { ai.WorkflowModeration, "assess if a bug report is consistent and actionable or not", &aflow.Flow{ + Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ aflow.NewFuncAction("extract-crash-type", extractCrashType), diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go index ef10f1a2e..766cf089f 100644 --- a/pkg/aflow/flow/patching/patching.go +++ b/pkg/aflow/flow/patching/patching.go @@ -43,6 +43,7 @@ func init() { ai.WorkflowPatching, "generate a kernel patch fixing a provided bug reproducer", &aflow.Flow{ + Model: aflow.BestExpensiveModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ baseCommitPicker, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 8ab8016f3..19abce6b1 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -76,7 +76,8 @@ func TestWorkflow(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Name: "flow", + Name: "flow", + Model: "model", Root: NewPipeline( NewFuncAction("func-action", func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { @@ -530,6 +531,7 @@ func TestNoInputs(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { + Model: "model", Root: NewFuncAction("func-action", func(ctx *Context, args flowInputs) (flowOutputs, error) { return flowOutputs{}, nil |
