about summary refs log tree commit diff stats
path: root/pkg/aflow/llm_agent.go
diff options
context:
space:
mode:
author: Dmitry Vyukov <dvyukov@google.com> 2026-01-15 20:53:57 +0100
committer: Dmitry Vyukov <dvyukov@google.com> 2026-01-20 21:12:57 +0000
commit: 7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch)
tree: 2ccbc85132a170d046837de6bdd8be3317f94060 /pkg/aflow/llm_agent.go
parent: 2494e18d5ced59fc7f0522749041e499d3082a9e (diff)
pkg/aflow: make LLM model per-agent rather than per-flow
Having LLM model per-agent is even more flexible than per-flow. We can have some more complex tasks during patch generation with the most elaborate model, but also some simpler ones with less elaborate models.
Diffstat (limited to 'pkg/aflow/llm_agent.go')
-rw-r--r-- pkg/aflow/llm_agent.go | 19
1 file changed, 16 insertions(+), 3 deletions(-)
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 02d3bca85..3c416b37c 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -18,6 +18,9 @@ import (
type LLMAgent struct {
// For logging/debugging.
Name string
+ // The default Gemini model name to execute this workflow.
+ // Use the consts defined below.
+ Model string
// Name of the state variable to store the final reply of the agent.
// These names can be used in subsequent action instructions/prompts,
// and as final workflow outputs.
@@ -43,6 +46,13 @@ type LLMAgent struct {
Tools []Tool
}
+// Consts to use for LLMAgent.Model.
+// See https://ai.google.dev/gemini-api/docs/models
+const (
+ BestExpensiveModel = "gemini-3-pro-preview"
+ GoodBalancedModel = "gemini-3-flash-preview"
+)
+
// Tool represents a custom tool an LLMAgent can invoke.
// Use NewFuncTool to create function-based tools.
type Tool interface {
@@ -134,6 +144,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
Name: a.Name,
Instruction: instruction,
Prompt: formatTemplate(a.Prompt, ctx.state),
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(span); err != nil {
return "", nil, err
@@ -152,8 +163,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
for {
reqSpan := &trajectory.Span{
- Type: trajectory.SpanLLM,
- Name: a.Name,
+ Type: trajectory.SpanLLM,
+ Name: a.Name,
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(reqSpan); err != nil {
return "", nil, err
@@ -278,7 +290,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
req []*genai.Content) (*genai.GenerateContentResponse, error) {
backoff := time.Second
for try := 0; ; try++ {
- resp, err := ctx.generateContent(cfg, req)
+ resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
@@ -292,6 +304,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
func (a *LLMAgent) verify(vctx *verifyContext) {
vctx.requireNotEmpty(a.Name, "Name", a.Name)
+ vctx.requireNotEmpty(a.Name, "Model", a.Model)
vctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
a.Temperature = float32(temp)