author    Dmitry Vyukov <dvyukov@google.com>  2026-01-15 20:53:57 +0100
committer Dmitry Vyukov <dvyukov@google.com>  2026-01-20 21:12:57 +0000
commit    7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch)
tree      2ccbc85132a170d046837de6bdd8be3317f94060 /pkg/aflow/execute.go
parent    2494e18d5ced59fc7f0522749041e499d3082a9e (diff)
pkg/aflow: make LLM model per-agent rather than per-flow
Having the LLM model per-agent is even more flexible than per-flow. During patch generation we can run the more complex tasks with the most capable model, while the simpler ones use less capable models.
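
To make the new semantics concrete, here is a minimal runnable sketch of the override behavior; the Agent type and the model names are illustrative stand-ins, not the real pkg/aflow types. It mirrors the Context.modelName helper added in this patch: a non-empty flow-level model wins, otherwise each agent falls back to its own default.

package main

import "fmt"

// Agent is a hypothetical stand-in for an aflow LLM agent;
// the real type is defined elsewhere in pkg/aflow.
type Agent struct {
	Name  string
	Model string // per-agent default model
}

// modelName mirrors Context.modelName from this patch:
// a non-empty flow-level override wins over the agent default.
func modelName(override, agentModel string) string {
	if override != "" {
		return override
	}
	return agentModel
}

func main() {
	agents := []Agent{
		{Name: "patch-gen", Model: "gemini-2.5-pro"},   // complex task, capable model
		{Name: "triage",    Model: "gemini-2.5-flash"}, // simple task, cheaper model
	}
	// With an empty override every agent uses its own default;
	// a non-empty override forces a single model for the whole flow.
	for _, a := range agents {
		fmt.Printf("%s -> %s\n", a.Name, modelName("", a.Model))
	}
}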
Diffstat (limited to 'pkg/aflow/execute.go')
-rw-r--r--  pkg/aflow/execute.go | 66
1 file changed, 35 insertions(+), 31 deletions(-)
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 482a58fb4..96f15c13b 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -20,7 +20,8 @@ import (
)
// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
-// The model argument sets Gemini model name to execute the workflow.
+// The model argument overrides the Gemini models used to execute LLM agents;
+// if not set, the default model of each agent is used.
// The workdir argument should point to a dir owned by aflow to store private data,
// it can be shared across parallel executions in the same process, and preferably
// preserved across process restarts for caching purposes.
@@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
return nil, fmt.Errorf("flow inputs are missing: %w", err)
}
ctx := &Context{
- Context: c,
- Workdir: osutil.Abs(workdir),
- cache: cache,
- state: maps.Clone(inputs),
- onEvent: onEvent,
+ Context: c,
+ Workdir: osutil.Abs(workdir),
+ llmModel: model,
+ cache: cache,
+ state: maps.Clone(inputs),
+ onEvent: onEvent,
}
defer ctx.close()
if s := c.Value(stubContextKey); s != nil {
@@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
ctx.timeNow = time.Now
}
if ctx.generateContent == nil {
- var err error
- ctx.generateContent, err = contentGenerator(c, model)
- if err != nil {
- return nil, err
- }
+ ctx.generateContent = ctx.generateContentGemini
}
span := &trajectory.Span{
Type: trajectory.SpanFlow,
@@ -91,9 +89,7 @@ type flowError struct {
}
type (
- onEvent func(*trajectory.Span) error
- generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) (
- *genai.GenerateContentResponse, error)
+ onEvent func(*trajectory.Span) error
contextKeyType int
)
@@ -105,7 +101,8 @@ var (
stubContextKey = contextKeyType(1)
)
-func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) {
+func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
+ req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
createClientOnce.Do(func() {
if os.Getenv("GOOGLE_API_KEY") == "" {
@@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
" (see https://ai.google.dev/gemini-api/docs/api-key)")
return
}
- client, createClientErr = genai.NewClient(ctx, nil)
+ client, createClientErr = genai.NewClient(ctx.Context, nil)
if createClientErr != nil {
return
}
- for m, err := range client.Models.All(ctx) {
+ for m, err := range client.Models.All(ctx.Context) {
if err != nil {
createClientErr = err
return
@@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) {
- if thinking {
- cfg.ThinkingConfig = &genai.ThinkingConfig{
- // We capture them in the trajectory for analysis.
- IncludeThoughts: true,
- // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
- // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
- // However, thoughts output also consumes total output token budget.
- // We may consider adjusting ThinkingLevel parameter.
- ThinkingBudget: genai.Ptr[int32](-1),
- }
+ if thinking {
+ cfg.ThinkingConfig = &genai.ThinkingConfig{
+ // We capture them in the trajectory for analysis.
+ IncludeThoughts: true,
+ // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
+ // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
+ // However, thoughts output also consumes total output token budget.
+ // We may consider adjusting ThinkingLevel parameter.
+ ThinkingBudget: genai.Ptr[int32](-1),
}
- return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg)
- }, nil
+ }
+ return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg)
}
type Context struct {
Context context.Context
Workdir string
+ llmModel string
cache *Cache
cachedDirs []string
state map[string]any
@@ -164,7 +160,15 @@ type Context struct {
type stubContext struct {
timeNow func() time.Time
- generateContent generateContentFunc
+ generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) (
+ *genai.GenerateContentResponse, error)
+}
+
+func (ctx *Context) modelName(model string) string {
+ if ctx.llmModel != "" {
+ return ctx.llmModel
+ }
+ return model
}
func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) {
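
For reference, a self-contained sketch of the generateContentGemini call pattern above, built only from the google.golang.org/genai calls that appear in this patch (NewClient, Models.GenerateContent, ThinkingConfig, genai.Ptr) plus the SDK's Text helpers; the model name and prompt are illustrative, and error handling is simplified.

package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/genai"
)

func main() {
	ctx := context.Background()
	// As in execute.go, a nil config makes NewClient read GOOGLE_API_KEY
	// from the environment.
	client, err := genai.NewClient(ctx, nil)
	if err != nil {
		log.Fatal(err)
	}
	cfg := &genai.GenerateContentConfig{
		ThinkingConfig: &genai.ThinkingConfig{
			// Thoughts are captured for later analysis, as in the patch.
			IncludeThoughts: true,
			// -1 enables "dynamic thinking", matching the patch above.
			ThinkingBudget: genai.Ptr[int32](-1),
		},
	}
	// "models/" matches the modelPrefix constant in generateContentGemini;
	// the model name itself is illustrative.
	resp, err := client.Models.GenerateContent(ctx,
		"models/gemini-2.5-flash", genai.Text("Summarize this bug report."), cfg)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text())
}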