| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-15 20:53:57 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-20 21:12:57 +0000 |
| commit | 7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch) | |
| tree | 2ccbc85132a170d046837de6bdd8be3317f94060 /pkg/aflow/execute.go | |
| parent | 2494e18d5ced59fc7f0522749041e499d3082a9e (diff) | |
pkg/aflow: make LLM model per-agent rather than per-flow
Having the LLM model per-agent is even more flexible than per-flow:
the more complex tasks during patch generation can run on the most capable model,
while simpler tasks can use less capable models.
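The mechanism is the small `modelName` helper added in the diff below: a non-empty flow-level model overrides an agent's default. A minimal runnable sketch of that precedence (the agent roles and model names here are illustrative, not taken from this commit):

```go
package main

import "fmt"

// Context mirrors the relevant part of pkg/aflow's Context:
// llmModel is the optional model passed to Flow.Execute.
type Context struct {
	llmModel string
}

// modelName matches the helper introduced by this commit:
// an explicit flow-level model wins over the agent's default.
func (ctx *Context) modelName(model string) string {
	if ctx.llmModel != "" {
		return ctx.llmModel
	}
	return model
}

func main() {
	// Hypothetical per-agent defaults: an elaborate model for patch
	// generation, a lighter one for simpler tasks.
	const patchModel = "gemini-2.5-pro"
	const triageModel = "gemini-2.5-flash"

	ctx := &Context{} // no override: each agent keeps its default
	fmt.Println(ctx.modelName(patchModel), ctx.modelName(triageModel))

	ctx.llmModel = "gemini-2.5-flash" // override: one model for all agents
	fmt.Println(ctx.modelName(patchModel), ctx.modelName(triageModel))
}
```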
Diffstat (limited to 'pkg/aflow/execute.go')
| -rw-r--r-- | pkg/aflow/execute.go | 66 |
1 file changed, 35 insertions(+), 31 deletions(-)
```diff
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 482a58fb4..96f15c13b 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -20,7 +20,8 @@ import (
 )
 
 // Execute executes the given AI workflow with provided inputs and returns workflow outputs.
-// The model argument sets Gemini model name to execute the workflow.
+// The model argument overrides Gemini models used to execute LLM agents,
+// if not set, then default models for each agent are used.
 // The workdir argument should point to a dir owned by aflow to store private data,
 // it can be shared across parallel executions in the same process, and preferably
 // preserved across process restarts for caching purposes.
@@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
 		return nil, fmt.Errorf("flow inputs are missing: %w", err)
 	}
 	ctx := &Context{
-		Context: c,
-		Workdir: osutil.Abs(workdir),
-		cache:   cache,
-		state:   maps.Clone(inputs),
-		onEvent: onEvent,
+		Context:  c,
+		Workdir:  osutil.Abs(workdir),
+		llmModel: model,
+		cache:    cache,
+		state:    maps.Clone(inputs),
+		onEvent:  onEvent,
 	}
 	defer ctx.close()
 	if s := c.Value(stubContextKey); s != nil {
@@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
 		ctx.timeNow = time.Now
 	}
 	if ctx.generateContent == nil {
-		var err error
-		ctx.generateContent, err = contentGenerator(c, model)
-		if err != nil {
-			return nil, err
-		}
+		ctx.generateContent = ctx.generateContentGemini
 	}
 	span := &trajectory.Span{
 		Type: trajectory.SpanFlow,
@@ -91,9 +89,7 @@ type flowError struct {
 }
 
 type (
-	onEvent             func(*trajectory.Span) error
-	generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) (
-		*genai.GenerateContentResponse, error)
+	onEvent        func(*trajectory.Span) error
 	contextKeyType int
 )
 
@@ -105,7 +101,8 @@ var (
 	stubContextKey = contextKeyType(1)
 )
 
-func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) {
+func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
+	req []*genai.Content) (*genai.GenerateContentResponse, error) {
 	const modelPrefix = "models/"
 	createClientOnce.Do(func() {
 		if os.Getenv("GOOGLE_API_KEY") == "" {
@@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
 				" (see https://ai.google.dev/gemini-api/docs/api-key)")
 			return
 		}
-		client, createClientErr = genai.NewClient(ctx, nil)
+		client, createClientErr = genai.NewClient(ctx.Context, nil)
 		if createClientErr != nil {
 			return
 		}
-		for m, err := range client.Models.All(ctx) {
+		for m, err := range client.Models.All(ctx.Context) {
 			if err != nil {
 				createClientErr = err
 				return
@@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
 		slices.Sort(models)
 		return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
 	}
-	return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) {
-		if thinking {
-			cfg.ThinkingConfig = &genai.ThinkingConfig{
-				// We capture them in the trajectory for analysis.
-				IncludeThoughts: true,
-				// Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
-				// See https://ai.google.dev/gemini-api/docs/thinking#set-budget
-				// However, thoughts output also consumes total output token budget.
-				// We may consider adjusting ThinkingLevel parameter.
-				ThinkingBudget: genai.Ptr[int32](-1),
-			}
+	if thinking {
+		cfg.ThinkingConfig = &genai.ThinkingConfig{
+			// We capture them in the trajectory for analysis.
+			IncludeThoughts: true,
+			// Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
+			// See https://ai.google.dev/gemini-api/docs/thinking#set-budget
+			// However, thoughts output also consumes total output token budget.
+			// We may consider adjusting ThinkingLevel parameter.
+			ThinkingBudget: genai.Ptr[int32](-1),
 		}
-		return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg)
-	}, nil
+	}
+	return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg)
 }
 
 type Context struct {
 	Context    context.Context
 	Workdir    string
+	llmModel   string
 	cache      *Cache
 	cachedDirs []string
 	state      map[string]any
@@ -164,7 +160,15 @@ type Context struct {
 
 type stubContext struct {
 	timeNow         func() time.Time
-	generateContent generateContentFunc
+	generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) (
+		*genai.GenerateContentResponse, error)
+}
+
+func (ctx *Context) modelName(model string) string {
+	if ctx.llmModel != "" {
+		return ctx.llmModel
+	}
+	return model
 }
 
 func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) {
```
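The shape change is the point of the refactoring: `contentGenerator` validated the single flow model once and returned a closure with that model baked in, while `generateContentGemini` takes the model on every call, so each agent can name its own. The diff does not include the agent-side call sites; the contrast below is a schematic sketch using string-in/string-out stand-ins for the genai types, not the real pkg/aflow API:

```go
package main

import "fmt"

// Before: the model was fixed when the closure was created, once per flow.
type generateContentFunc func(prompt string) (string, error)

func contentGenerator(model string) generateContentFunc {
	return func(prompt string) (string, error) {
		return fmt.Sprintf("[%s] %s", model, prompt), nil
	}
}

// After: the model is an argument, chosen per call (and thus per agent).
func generateContent(model, prompt string) (string, error) {
	return fmt.Sprintf("[%s] %s", model, prompt), nil
}

func main() {
	gen := contentGenerator("gemini-2.5-pro") // one model for the whole flow
	out, _ := gen("generate a patch")
	fmt.Println(out)

	// Each agent can now pass a different model per request.
	out, _ = generateContent("gemini-2.5-flash", "summarize the crash")
	fmt.Println(out)
}
```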
