aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-15 20:53:57 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-20 21:12:57 +0000
commit7f5908e77ae0e7fef4b7901341b8c2c4bbb74b28 (patch)
tree2ccbc85132a170d046837de6bdd8be3317f94060 /pkg
parent2494e18d5ced59fc7f0522749041e499d3082a9e (diff)
pkg/aflow: make LLM model per-agent rather than per-flow
Having the LLM model per-agent is even more flexible than per-flow. We can have some more complex tasks during patch generation use the most elaborate model, while simpler ones use less elaborate models.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/execute.go66
-rw-r--r--pkg/aflow/flow.go12
-rw-r--r--pkg/aflow/flow/assessment/kcsan.go2
-rw-r--r--pkg/aflow/flow/assessment/moderation.go2
-rw-r--r--pkg/aflow/flow/patching/patching.go4
-rw-r--r--pkg/aflow/flow_test.go42
-rw-r--r--pkg/aflow/llm_agent.go19
-rw-r--r--pkg/aflow/trajectory/trajectory.go1
8 files changed, 94 insertions, 54 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 482a58fb4..96f15c13b 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -20,7 +20,8 @@ import (
)
// Execute executes the given AI workflow with provided inputs and returns workflow outputs.
-// The model argument sets Gemini model name to execute the workflow.
+// The model argument overrides the Gemini models used to execute LLM agents;
+// if it is not set, the default model of each agent is used.
// The workdir argument should point to a dir owned by aflow to store private data,
// it can be shared across parallel executions in the same process, and preferably
// preserved across process restarts for caching purposes.
@@ -30,11 +31,12 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
return nil, fmt.Errorf("flow inputs are missing: %w", err)
}
ctx := &Context{
- Context: c,
- Workdir: osutil.Abs(workdir),
- cache: cache,
- state: maps.Clone(inputs),
- onEvent: onEvent,
+ Context: c,
+ Workdir: osutil.Abs(workdir),
+ llmModel: model,
+ cache: cache,
+ state: maps.Clone(inputs),
+ onEvent: onEvent,
}
defer ctx.close()
if s := c.Value(stubContextKey); s != nil {
@@ -44,11 +46,7 @@ func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[s
ctx.timeNow = time.Now
}
if ctx.generateContent == nil {
- var err error
- ctx.generateContent, err = contentGenerator(c, model)
- if err != nil {
- return nil, err
- }
+ ctx.generateContent = ctx.generateContentGemini
}
span := &trajectory.Span{
Type: trajectory.SpanFlow,
@@ -91,9 +89,7 @@ type flowError struct {
}
type (
- onEvent func(*trajectory.Span) error
- generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) (
- *genai.GenerateContentResponse, error)
+ onEvent func(*trajectory.Span) error
contextKeyType int
)
@@ -105,7 +101,8 @@ var (
stubContextKey = contextKeyType(1)
)
-func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) {
+func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
+ req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
createClientOnce.Do(func() {
if os.Getenv("GOOGLE_API_KEY") == "" {
@@ -113,11 +110,11 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
" (see https://ai.google.dev/gemini-api/docs/api-key)")
return
}
- client, createClientErr = genai.NewClient(ctx, nil)
+ client, createClientErr = genai.NewClient(ctx.Context, nil)
if createClientErr != nil {
return
}
- for m, err := range client.Models.All(ctx) {
+ for m, err := range client.Models.All(ctx.Context) {
if err != nil {
createClientErr = err
return
@@ -134,25 +131,24 @@ func contentGenerator(ctx context.Context, model string) (generateContentFunc, e
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) {
- if thinking {
- cfg.ThinkingConfig = &genai.ThinkingConfig{
- // We capture them in the trajectory for analysis.
- IncludeThoughts: true,
- // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
- // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
- // However, thoughts output also consumes total output token budget.
- // We may consider adjusting ThinkingLevel parameter.
- ThinkingBudget: genai.Ptr[int32](-1),
- }
+ if thinking {
+ cfg.ThinkingConfig = &genai.ThinkingConfig{
+ // We capture them in the trajectory for analysis.
+ IncludeThoughts: true,
+ // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
+ // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
+ // However, thoughts output also consumes total output token budget.
+ // We may consider adjusting ThinkingLevel parameter.
+ ThinkingBudget: genai.Ptr[int32](-1),
}
- return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg)
- }, nil
+ }
+ return client.Models.GenerateContent(ctx.Context, modelPrefix+model, req, cfg)
}
type Context struct {
Context context.Context
Workdir string
+ llmModel string
cache *Cache
cachedDirs []string
state map[string]any
@@ -164,7 +160,15 @@ type Context struct {
type stubContext struct {
timeNow func() time.Time
- generateContent generateContentFunc
+ generateContent func(string, *genai.GenerateContentConfig, []*genai.Content) (
+ *genai.GenerateContentResponse, error)
+}
+
+func (ctx *Context) modelName(model string) string {
+ if ctx.llmModel != "" {
+ return ctx.llmModel
+ }
+ return model
}
func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) {
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
index b4cbe2201..6325b2fd2 100644
--- a/pkg/aflow/flow.go
+++ b/pkg/aflow/flow.go
@@ -22,9 +22,8 @@ import (
// Actions are nodes of the graph, and they consume/produce some named values
// (input/output fields, and intermediate values consumed by other actions).
type Flow struct {
- Name string // Empty for the main workflow for the workflow type.
- Model string // The default Gemini model name to execute this workflow.
- Root Action
+ Name string // Empty for the main workflow for the workflow type.
+ Root Action
*FlowType
}
@@ -36,12 +35,6 @@ type FlowType struct {
extractOutputs func(map[string]any) map[string]any
}
-// See https://ai.google.dev/gemini-api/docs/models
-const (
- BestExpensiveModel = "gemini-3-pro-preview"
- GoodBalancedModel = "gemini-3-flash-preview"
-)
-
var Flows = make(map[string]*Flow)
// Register a workflow type (characterized by Inputs and Outputs),
@@ -95,7 +88,6 @@ func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
actions: make(map[string]bool),
state: make(map[string]*varState),
}
- ctx.requireNotEmpty(flow.Name, "Model", flow.Model)
provideOutputs[Inputs](ctx, "flow inputs")
flow.Root.verify(ctx)
requireInputs[Outputs](ctx, "flow outputs")
diff --git a/pkg/aflow/flow/assessment/kcsan.go b/pkg/aflow/flow/assessment/kcsan.go
index 67d695eb9..6bfc7bb12 100644
--- a/pkg/aflow/flow/assessment/kcsan.go
+++ b/pkg/aflow/flow/assessment/kcsan.go
@@ -23,7 +23,6 @@ func init() {
ai.WorkflowAssessmentKCSAN,
"assess if a KCSAN report is about a benign race that only needs annotations or not",
&aflow.Flow{
- Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
kernel.Checkout,
@@ -31,6 +30,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "expert",
+ Model: aflow.GoodBalancedModel,
Reply: "Explanation",
Outputs: aflow.LLMOutputs[struct {
Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."`
diff --git a/pkg/aflow/flow/assessment/moderation.go b/pkg/aflow/flow/assessment/moderation.go
index 8d9ac4a0b..b13ee1e7d 100644
--- a/pkg/aflow/flow/assessment/moderation.go
+++ b/pkg/aflow/flow/assessment/moderation.go
@@ -33,7 +33,6 @@ func init() {
ai.WorkflowModeration,
"assess if a bug report is consistent and actionable or not",
&aflow.Flow{
- Model: aflow.GoodBalancedModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
aflow.NewFuncAction("extract-crash-type", extractCrashType),
@@ -42,6 +41,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "expert",
+ Model: aflow.GoodBalancedModel,
Reply: "Explanation",
Outputs: aflow.LLMOutputs[struct {
Confident bool `jsonschema:"If you are confident in the verdict of the analysis or not."`
diff --git a/pkg/aflow/flow/patching/patching.go b/pkg/aflow/flow/patching/patching.go
index 766cf089f..856962e6c 100644
--- a/pkg/aflow/flow/patching/patching.go
+++ b/pkg/aflow/flow/patching/patching.go
@@ -43,7 +43,6 @@ func init() {
ai.WorkflowPatching,
"generate a kernel patch fixing a provided bug reproducer",
&aflow.Flow{
- Model: aflow.BestExpensiveModel,
Root: &aflow.Pipeline{
Actions: []aflow.Action{
baseCommitPicker,
@@ -54,6 +53,7 @@ func init() {
codesearcher.PrepareIndex,
&aflow.LLMAgent{
Name: "debugger",
+ Model: aflow.BestExpensiveModel,
Reply: "BugExplanation",
Temperature: 1,
Instruction: debuggingInstruction,
@@ -62,6 +62,7 @@ func init() {
},
&aflow.LLMAgent{
Name: "diff-generator",
+ Model: aflow.BestExpensiveModel,
Reply: "PatchDiff",
Temperature: 1,
Instruction: diffInstruction,
@@ -70,6 +71,7 @@ func init() {
},
&aflow.LLMAgent{
Name: "description-generator",
+ Model: aflow.BestExpensiveModel,
Reply: "PatchDescription",
Temperature: 1,
Instruction: descriptionInstruction,
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 5c038d7c6..97e2dd93a 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -93,8 +93,7 @@ func TestWorkflow(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Name: "flow",
- Model: "model",
+ Name: "flow",
Root: NewPipeline(
NewFuncAction("func-action",
func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
@@ -107,6 +106,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "smarty",
+ Model: "model1",
Reply: "OutFoo",
Outputs: LLMOutputs[agentOutputs](),
Temperature: 0,
@@ -143,6 +143,7 @@ func TestWorkflow(t *testing.T) {
}),
&LLMAgent{
Name: "swarm",
+ Model: "model2",
Reply: "OutSwarm",
Candidates: 2,
Outputs: LLMOutputs[swarmOutputs](),
@@ -152,6 +153,7 @@ func TestWorkflow(t *testing.T) {
},
&LLMAgent{
Name: "aggregator",
+ Model: "model3",
Reply: "OutAggregator",
Temperature: 0,
Instruction: "Aggregate!",
@@ -176,10 +178,11 @@ func TestWorkflow(t *testing.T) {
stubTime = stubTime.Add(time.Second)
return stubTime
},
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
replySeq++
if replySeq < 4 {
+ assert.Equal(t, model, "model1")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
@@ -190,11 +193,13 @@ func TestWorkflow(t *testing.T) {
assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
} else if replySeq < 8 {
+ assert.Equal(t, model, "model2")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Do something. baz"+
llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 1)
assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "set-results")
} else {
+ assert.Equal(t, model, "model3")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("Aggregate!", genai.RoleUser))
assert.Equal(t, len(cfg.Tools), 0)
}
@@ -437,6 +442,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz func-output",
@@ -446,6 +452,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
},
{
@@ -453,6 +460,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(5 * time.Second),
Finished: startTime.Add(6 * time.Second),
Thoughts: "I am thinking I need to call some tools",
@@ -513,6 +521,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
},
{
@@ -520,6 +529,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(11 * time.Second),
Finished: startTime.Add(12 * time.Second),
Thoughts: "Completly blank.Whatever.",
@@ -556,6 +566,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
},
{
@@ -563,6 +574,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(15 * time.Second),
Finished: startTime.Add(16 * time.Second),
},
@@ -571,6 +583,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "smarty",
+ Model: "model1",
Started: startTime.Add(4 * time.Second),
Finished: startTime.Add(17 * time.Second),
Instruction: "You are smarty. baz" + llmOutputsInstruction,
@@ -611,6 +624,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -620,6 +634,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
},
{
@@ -627,6 +642,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(22 * time.Second),
Finished: startTime.Add(23 * time.Second),
},
@@ -662,6 +678,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
},
{
@@ -669,6 +686,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(26 * time.Second),
Finished: startTime.Add(27 * time.Second),
},
@@ -677,6 +695,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(21 * time.Second),
Finished: startTime.Add(28 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -692,6 +711,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
Prompt: "Prompt: baz",
@@ -701,6 +721,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
},
{
@@ -708,6 +729,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(30 * time.Second),
Finished: startTime.Add(31 * time.Second),
},
@@ -743,6 +765,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
},
{
@@ -750,6 +773,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 3,
Type: trajectory.SpanLLM,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(34 * time.Second),
Finished: startTime.Add(35 * time.Second),
},
@@ -758,6 +782,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanAgent,
Name: "swarm",
+ Model: "model2",
Started: startTime.Add(29 * time.Second),
Finished: startTime.Add(36 * time.Second),
Instruction: "Do something. baz" + llmOutputsInstruction,
@@ -781,6 +806,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Instruction: "Aggregate!",
Prompt: `Prompt: baz
@@ -800,6 +826,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
},
{
@@ -807,6 +834,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 2,
Type: trajectory.SpanLLM,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(39 * time.Second),
Finished: startTime.Add(40 * time.Second),
},
@@ -815,6 +843,7 @@ func TestWorkflow(t *testing.T) {
Nesting: 1,
Type: trajectory.SpanAgent,
Name: "aggregator",
+ Model: "model3",
Started: startTime.Add(38 * time.Second),
Finished: startTime.Add(41 * time.Second),
Instruction: "Aggregate!",
@@ -847,7 +876,7 @@ func TestWorkflow(t *testing.T) {
expected = expected[1:]
return nil
}
- res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ res, err := flows["test-flow"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.NoError(t, err)
require.Equal(t, replySeq, 8)
require.Equal(t, res, expectedOutputs)
@@ -867,7 +896,6 @@ func TestNoInputs(t *testing.T) {
flows := make(map[string]*Flow)
err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
{
- Model: "model",
Root: NewFuncAction("func-action",
func(ctx *Context, args flowInputs) (flowOutputs, error) {
return flowOutputs{}, nil
@@ -876,7 +904,7 @@ func TestNoInputs(t *testing.T) {
})
require.NoError(t, err)
stub := &stubContext{
- generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
*genai.GenerateContentResponse, error) {
return nil, nil
},
@@ -886,7 +914,7 @@ func TestNoInputs(t *testing.T) {
cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
require.NoError(t, err)
onEvent := func(span *trajectory.Span) error { return nil }
- _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ _, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.Equal(t, err.Error(), "flow inputs are missing:"+
" field InBar is not present when converting map to aflow.flowInputs")
}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 02d3bca85..3c416b37c 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -18,6 +18,9 @@ import (
type LLMAgent struct {
// For logging/debugging.
Name string
+ // The default Gemini model name used to execute this agent.
+ // Use the consts defined below.
+ Model string
// Name of the state variable to store the final reply of the agent.
// These names can be used in subsequent action instructions/prompts,
// and as final workflow outputs.
@@ -43,6 +46,13 @@ type LLMAgent struct {
Tools []Tool
}
+// Consts to use for LLMAgent.Model.
+// See https://ai.google.dev/gemini-api/docs/models
+const (
+ BestExpensiveModel = "gemini-3-pro-preview"
+ GoodBalancedModel = "gemini-3-flash-preview"
+)
+
// Tool represents a custom tool an LLMAgent can invoke.
// Use NewFuncTool to create function-based tools.
type Tool interface {
@@ -134,6 +144,7 @@ func (a *LLMAgent) executeOne(ctx *Context) (string, map[string]any, error) {
Name: a.Name,
Instruction: instruction,
Prompt: formatTemplate(a.Prompt, ctx.state),
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(span); err != nil {
return "", nil, err
@@ -152,8 +163,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
for {
reqSpan := &trajectory.Span{
- Type: trajectory.SpanLLM,
- Name: a.Name,
+ Type: trajectory.SpanLLM,
+ Name: a.Name,
+ Model: ctx.modelName(a.Model),
}
if err := ctx.startSpan(reqSpan); err != nil {
return "", nil, err
@@ -278,7 +290,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
req []*genai.Content) (*genai.GenerateContentResponse, error) {
backoff := time.Second
for try := 0; ; try++ {
- resp, err := ctx.generateContent(cfg, req)
+ resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
@@ -292,6 +304,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
func (a *LLMAgent) verify(vctx *verifyContext) {
vctx.requireNotEmpty(a.Name, "Name", a.Name)
+ vctx.requireNotEmpty(a.Name, "Model", a.Model)
vctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
a.Temperature = float32(temp)
diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go
index 49e36933b..aa9558708 100644
--- a/pkg/aflow/trajectory/trajectory.go
+++ b/pkg/aflow/trajectory/trajectory.go
@@ -20,6 +20,7 @@ type Span struct {
Nesting int
Type SpanType
Name string // flow/action/tool name
+ Model string // LLM model name for agent/LLM spans
Started time.Time
Finished time.Time
Error string // relevant if Finished is set