diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/execute.go | 66 | ||||
| -rw-r--r-- | pkg/aflow/flow.go | 12 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/kcsan.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/flow/assessment/moderation.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/flow/patching/patching.go | 4 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 42 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 19 | ||||
| -rw-r--r-- | pkg/aflow/trajectory/trajectory.go | 1 |
8 files changed, 94 insertions, 54 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 482a58fb4..96f15c13b 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -20,7 +20,8 @@ import ( ) // Execute executes the given AI workflow with provided inputs and returns workflow outputs. -// The model argument sets Gemini model name to execute the workflow. +// The model argument overrides Gemini models used to execute LLM agents, +// if not set, then default models for each agent are used. // The workdir argument should point to a dir owned by aflow to store private data, // it can be shared across parallel executions in the same process, and preferably // preserved across process restarts for caching purposes. @@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s return nil, fmt.Errorf("flow inputs are missing: %w", err) } ctx := &Context{ - Context: c, - Workdir: osutil.Abs(workdir), - cache: cache, - state: maps.Clone(inputs), - onEvent: onEvent, + Context: c, + Workdir: osutil.Abs(workdir), + llmModel: model, + cache: cache, + state: maps.Clone(inputs), + onEvent: onEvent, } defer ctx.close() if s := c.Value(stubContextKey); s != nil { @@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s ctx.timeNow = time.Now } if ctx.generateContent == nil { - var err error - ctx.generateContent, err = contentGenerator(c, model) - if err != nil { - return nil, err - } + ctx.generateContent = ctx.generateContentGemini } span := &trajectory.Span{ Type: trajectory.SpanFlow, @@ -91,9 +89,7 @@ type flowError struct { } type ( - onEvent func(*trajectory.Span) error - generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) ( - *genai.GenerateContentResponse, error) + onEvent func(*trajectory.Span) error contextKeyType int ) @@ -105,7 +101,8 @@ var ( stubContextKey = contextKeyType(1) ) -func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) { +func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig, + req []*genai.Content) (*genai.GenerateContentResponse, error) { const modelPrefix = "models/" createClientOnce.Do(func() { if os.Getenv("GOOGLE_API_KEY") == "" { @@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e " (see https://ai.google.dev/gemini-api/docs/api-key)") return } - client, createClientErr = genai.NewClient(ctx, nil) + client, createClientErr = genai.NewClient(ctx.Context, nil) if createClientErr != nil { return } - for m, err := range client.Models.All(ctx) { + for m, err := range client.Models.All(ctx.Context) { if err != nil { createClientErr = err return @@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e slices.Sort(models) return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } - return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { - if thinking { - cfg.ThinkingConfig = &genai.ThinkingConfig{ - // We capture them in the trajectory for analysis. - IncludeThoughts: true, - // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request"). - // See https://ai.google.dev/gemini-api/docs/thinking#set-budget - // However, thoughts output also consumes total output token budget. - // We may consider adjusting ThinkingLevel parameter. - ThinkingBudget: genai.Ptr[int32](-1), - } + if thinking { + cfg.ThinkingConfig = &genai.ThinkingConfig{ + // We capture them in the trajectory for analysis. + IncludeThoughts: true, + // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request"). + // See https://ai.google.dev/gemini-api/docs/thinking#set-budget + // However, thoughts output also consumes total output token budget. + // We may consider adjusting ThinkingLevel parameter. + ThinkingBudget: genai.Ptr[int32](-1), } - return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg) - }, nil + } + return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg) } type Context struct { Context context.Context Workdir string + llmModel string cache *Cache cachedDirs []string state map[string]any @@ -164,7 +160,15 @@ type Context struct { type stubContext struct { timeNow func() time.Time - generateContent generateContentFunc + generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) ( + *genai.GenerateContentResponse, error) +} + +func (ctx *Context) modelName(model string) string { + if ctx.llmModel != "" { + return ctx.llmModel + } + return model } func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) { diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go index b4cbe2201..6325b2fd2 100644 --- a/pkg/aflow/flow.go +++ b/pkg/aflow/flow.go @@ -22,9 +22,8 @@ import ( // Actions are nodes of the graph, and they consume/produce some named values // (input/output fields, and intermediate values consumed by other actions). type Flow struct { - Name string // Empty for the main workflow for the workflow type. - Model string // The default Gemini model name to execute this workflow. - Root Action + Name string // Empty for the main workflow for the workflow type. + Root Action *FlowType } @@ -36,12 +35,6 @@ type FlowType struct { extractOutputs func(map[string]any) map[string]any } -// See https://ai.google.dev/gemini-api/docs/models -const ( - BestExpensiveModel = "gemini-3-pro-preview" - GoodBalancedModel = "gemini-3-flash-preview" -) - var Flows = make(map[string]*Flow) // Register a workflow type (characterized by Inputs and Outputs), @@ -95,7 +88,6 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error { actions: make(map[string]bool), state: make(map[string]*varState), } - ctx.requireNotEmpty(flow.Name, "Model", flow.Model) provideOutputs[Inputs](ctx, "flow inputs") flow.Root.verify(ctx) requireInputs[Outputs](ctx, "flow outputs") diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go index 67d695eb9..6bfc7bb12 100644 --- a/pkg/aflow/flow/assessment/kcsan.go +++ b/pkg/aflow/flow/assessment/kcsan.go @@ -23,7 +23,6 @@ func init() { ai.WorkflowAssessmentKCSAN, "assess if a KCSAN report is about a benign race that only needs annotations or not", &aflow.Flow{ - Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ kernel.Checkout, @@ -31,6 +30,7 @@ func init() { codesearcher.PrepareIndex, &aflow.LLMAgent{ Name: "expert", + Model: aflow.GoodBalancedModel, Reply: "Explanation", Outputs: aflow.LLMOutputs[struct { Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."` diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go index 8d9ac4a0b..b13ee1e7d 100644 --- a/pkg/aflow/flow/assessment/moderation.go +++ b/pkg/aflow/flow/assessment/moderation.go @@ -33,7 +33,6 @@ func init() { ai.WorkflowModeration, "assess if a bug report is consistent and actionable or not", &aflow.Flow{ - Model: aflow.GoodBalancedModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ aflow.NewFuncAction("extract-crash-type", extractCrashType), @@ -42,6 +41,7 @@ func init() { codesearcher.PrepareIndex, &aflow.LLMAgent{ Name: "expert", + Model: aflow.GoodBalancedModel, Reply: "Explanation", Outputs: aflow.LLMOutputs[struct { Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."` diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go index 766cf089f..856962e6c 100644 --- a/pkg/aflow/flow/patching/patching.go +++ b/pkg/aflow/flow/patching/patching.go @@ -43,7 +43,6 @@ func init() { ai.WorkflowPatching, "generate a kernel patch fixing a provided bug reproducer", &aflow.Flow{ - Model: aflow.BestExpensiveModel, Root: &aflow.Pipeline{ Actions: []aflow.Action{ baseCommitPicker, @@ -54,6 +53,7 @@ func init() { codesearcher.PrepareIndex, &aflow.LLMAgent{ Name: "debugger", + Model: aflow.BestExpensiveModel, Reply: "BugExplanation", Temperature: 1, Instruction: debuggingInstruction, @@ -62,6 +62,7 @@ func init() { }, &aflow.LLMAgent{ Name: "diff-generator", + Model: aflow.BestExpensiveModel, Reply: "PatchDiff", Temperature: 1, Instruction: diffInstruction, @@ -70,6 +71,7 @@ func init() { }, &aflow.LLMAgent{ Name: "description-generator", + Model: aflow.BestExpensiveModel, Reply: "PatchDescription", Temperature: 1, Instruction: descriptionInstruction, diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 5c038d7c6..97e2dd93a 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Name: "flow", - Model: "model", + Name: "flow", Root: NewPipeline( NewFuncAction("func-action", func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { @@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "smarty", + Model: "model1", Reply: "OutFoo", Outputs: LLMOutputs[agentOutputs](), Temperature: 0, @@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) { }), &LLMAgent{ Name: "swarm", + Model: "model2", Reply: "OutSwarm", Candidates: 2, Outputs: LLMOutputs[swarmOutputs](), @@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) { }, &LLMAgent{ Name: "aggregator", + Model: "model3", Reply: "OutAggregator", Temperature: 0, Instruction: "Aggregate!", @@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) { stubTime = stubTime.Add(time.Second) return stubTime }, - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { replySeq++ if replySeq < 4 { + assert.Equal(t, model, "model1") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) @@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) { assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") } else if replySeq < 8 { + assert.Equal(t, model, "model2") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+ llmOutputsInstruction, genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 1) assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results") } else { + assert.Equal(t, model, "model3") assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser)) assert.Equal(t, len(cfg.Tools), 0) } @@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Instruction: "You are smarty. baz" + llmOutputsInstruction, Prompt: "Prompt: baz func-output", @@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), }, { @@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(5 * time.Second), Finished: startTime.Add(6 * time.Second), Thoughts: "I am thinking I need to call some tools", @@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), }, { @@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(11 * time.Second), Finished: startTime.Add(12 * time.Second), Thoughts: "Completly blank.Whatever.", @@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), }, { @@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "smarty", + Model: "model1", Started: startTime.Add(15 * time.Second), Finished: startTime.Add(16 * time.Second), }, @@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "smarty", + Model: "model1", Started: startTime.Add(4 * time.Second), Finished: startTime.Add(17 * time.Second), Instruction: "You are smarty. baz" + llmOutputsInstruction, @@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), }, { @@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(22 * time.Second), Finished: startTime.Add(23 * time.Second), }, @@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), }, { @@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(26 * time.Second), Finished: startTime.Add(27 * time.Second), }, @@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(21 * time.Second), Finished: startTime.Add(28 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, @@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, Prompt: "Prompt: baz", @@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), }, { @@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(30 * time.Second), Finished: startTime.Add(31 * time.Second), }, @@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), }, { @@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) { Nesting: 3, Type: trajectory.SpanLLM, Name: "swarm", + Model: "model2", Started: startTime.Add(34 * time.Second), Finished: startTime.Add(35 * time.Second), }, @@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanAgent, Name: "swarm", + Model: "model2", Started: startTime.Add(29 * time.Second), Finished: startTime.Add(36 * time.Second), Instruction: "Do something. baz" + llmOutputsInstruction, @@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Instruction: "Aggregate!", Prompt: `Prompt: baz @@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), }, { @@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) { Nesting: 2, Type: trajectory.SpanLLM, Name: "aggregator", + Model: "model3", Started: startTime.Add(39 * time.Second), Finished: startTime.Add(40 * time.Second), }, @@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) { Nesting: 1, Type: trajectory.SpanAgent, Name: "aggregator", + Model: "model3", Started: startTime.Add(38 * time.Second), Finished: startTime.Add(41 * time.Second), Instruction: "Aggregate!", @@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) { expected = expected[1:] return nil } - res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.NoError(t, err) require.Equal(t, replySeq, 8) require.Equal(t, res, expectedOutputs) @@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) { flows := make(map[string]*Flow) err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ { - Model: "model", Root: NewFuncAction("func-action", func(ctx *Context, args flowInputs) (flowOutputs, error) { return flowOutputs{}, nil @@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) { }) require.NoError(t, err) stub := &stubContext{ - generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) ( *genai.GenerateContentResponse, error) { return nil, nil }, @@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) { cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) require.NoError(t, err) onEvent := func(span *trajectory.Span) error { return nil } - _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent) require.Equal(t, err.Error(), "flow inputs are missing:"+ " field InBar is not present when converting map to aflow.flowInputs") } diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index 02d3bca85..3c416b37c 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -18,6 +18,9 @@ import ( type LLMAgent struct { // For logging/debugging. Name string + // The default Gemini model name to execute this workflow. + // Use the consts defined below. + Model string // Name of the state variable to store the final reply of the agent. // These names can be used in subsequent action instructions/prompts, // and as final workflow outputs. @@ -43,6 +46,13 @@ type LLMAgent struct { Tools []Tool } +// Consts to use for LLMAgent.Model. +// See https://ai.google.dev/gemini-api/docs/models +const ( + BestExpensiveModel = "gemini-3-pro-preview" + GoodBalancedModel = "gemini-3-flash-preview" +) + // Tool represents a custom tool an LLMAgent can invoke. // Use NewFuncTool to create function-based tools. type Tool interface { @@ -134,6 +144,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) { Name: a.Name, Instruction: instruction, Prompt: formatTemplate(a.Prompt, ctx.state), + Model: ctx.modelName(a.Model), } if err := ctx.startSpan(span); err != nil { return "", nil, err @@ -152,8 +163,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)} for { reqSpan := &trajectory.Span{ - Type: trajectory.SpanLLM, - Name: a.Name, + Type: trajectory.SpanLLM, + Name: a.Name, + Model: ctx.modelName(a.Model), } if err := ctx.startSpan(reqSpan); err != nil { return "", nil, err @@ -278,7 +290,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi req []*genai.Content) (*genai.GenerateContentResponse, error) { backoff := time.Second for try := 0; ; try++ { - resp, err := ctx.generateContent(cfg, req) + resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req) var apiErr genai.APIError if err != nil && try < 100 && errors.As(err, &apiErr) && apiErr.Code == http.StatusServiceUnavailable { @@ -292,6 +304,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi func (a *LLMAgent) verify(vctx *verifyContext) { vctx.requireNotEmpty(a.Name, "Name", a.Name) + vctx.requireNotEmpty(a.Name, "Model", a.Model) vctx.requireNotEmpty(a.Name, "Reply", a.Reply) if temp, ok := a.Temperature.(int); ok { a.Temperature = float32(temp) diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go index 49e36933b..aa9558708 100644 --- a/pkg/aflow/trajectory/trajectory.go +++ b/pkg/aflow/trajectory/trajectory.go @@ -20,6 +20,7 @@ type Span struct { Nesting int Type SpanType Name string // flow/action/tool name + Model string // LLM model name for agent/LLM spans Started time.Time Finished time.Time Error string // relevant if Finished is set |
