aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/execute.go8
-rw-r--r--pkg/aflow/flow.go12
-rw-r--r--pkg/aflow/flow/assessment/kcsan.go1
-rw-r--r--pkg/aflow/flow/assessment/moderation.go1
-rw-r--r--pkg/aflow/flow/patching/patching.go1
-rw-r--r--pkg/aflow/flow_test.go4
6 files changed, 21 insertions, 6 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 6e724988e..4133c01c9 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -18,9 +18,11 @@ import (
"google.golang.org/genai"
)
-// https://ai.google.dev/gemini-api/docs/models
-const DefaultModel = "gemini-3-pro-preview"
-
+// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
+// The model argument sets Gemini model name to execute the workflow.
+// The workdir argument should point to a dir owned by aflow to store private data,
+// it can be shared across parallel executions in the same process, and preferably
+// preserved across process restarts for caching purposes.
func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any,
cache *Cache, onEvent onEvent) (map[string]any, error) {
if err := flow.checkInputs(inputs); err != nil {
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
index 6325b2fd2..b4cbe2201 100644
--- a/pkg/aflow/flow.go
+++ b/pkg/aflow/flow.go
@@ -22,8 +22,9 @@ import (
// Actions are nodes of the graph, and they consume/produce some named values
// (input/output fields, and intermediate values consumed by other actions).
type Flow struct {
- Name string // Empty for the main workflow for the workflow type.
- Root Action
+ Name string // Empty for the main workflow for the workflow type.
+ Model string // The default Gemini model name to execute this workflow.
+ Root Action
*FlowType
}
@@ -35,6 +36,12 @@ type FlowType struct {
extractOutputs func(map[string]any) map[string]any
}
+// See https://ai.google.dev/gemini-api/docs/models
+const (
+ BestExpensiveModel = "gemini-3-pro-preview"
+ GoodBalancedModel = "gemini-3-flash-preview"
+)
+
var Flows = make(map[string]*Flow)
// Register a workflow type (characterized by Inputs and Outputs),
@@ -88,6 +95,7 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
actions: make(map[string]bool),
state: make(map[string]*varState),
}
+ ctx.requireNotEmpty(flow.Name, "Model", flow.Model)
provideOutputs[Inputs](ctx, "flow inputs")
flow.Root.verify(ctx)
requireInputs[Outputs](ctx, "flow outputs")
diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go
index 92033aff7..67d695eb9 100644
--- a/pkg/aflow/flow/assessment/kcsan.go
+++ b/pkg/aflow/flow/assessment/kcsan.go
@@ -23,6 +23,7 @@ func init() {
ai.WorkflowAssessmentKCSAN,
"assess if a KCSAN report is about a benign race that only needs annotations or not",
&aflow.Flow{
+ Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
kernel.Checkout,
diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go
index 4b78f901c..8d9ac4a0b 100644
--- a/pkg/aflow/flow/assessment/moderation.go
+++ b/pkg/aflow/flow/assessment/moderation.go
@@ -33,6 +33,7 @@ func init() {
ai.WorkflowModeration,
"assess if a bug report is consistent and actionable or not",
&aflow.Flow{
+ Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
aflow.NewFuncAction("extract-crash-type", extractCrashType),
diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go
index ef10f1a2e..766cf089f 100644
--- a/pkg/aflow/flow/patching/patching.go
+++ b/pkg/aflow/flow/patching/patching.go
@@ -43,6 +43,7 @@ func init() {
ai.WorkflowPatching,
"generate a kernel patch fixing a provided bug reproducer",
&aflow.Flow{
+ Model: aflow.BestExpensiveModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
baseCommitPicker,
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 8ab8016f3..19abce6b1 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -76,7 +76,8 @@ func TestWorkflow(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Name: "flow",
+ Name: "flow",
+ Model: "model",
Root: NewPipeline(
NewFuncAction("func-action",
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -530,6 +531,7 @@ func TestNoInputs(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
+ Model: "model",
Root: NewFuncAction("func-action",
func(ctx *Context, args flowInputs) (flowOutputs, error) {
return flowOutputs{}, nil