diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-26 15:07:38 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-26 15:57:53 +0000 |
| commit | dd7d1ced5000d600f31d0be2ff901c5827a349c6 (patch) | |
| tree | a2de2d8bc949f0167eaa7ef91357d6159cd007f4 /pkg | |
| parent | 0a7bbd79ddce993add61e99ffe0e9983dd56257d (diff) | |
pkg/aflow: fix Temperature handling
If LLMAgent.Temperature is assigned an untyped float constant (e.g. 0.5),
it will be given the default type float64 rather than float32, so recast
such values to float32 before use.
Also cap Temperature at the model's supported MaxTemperature.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/execute.go | 29 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 2 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 10 | ||||
| -rw-r--r-- | pkg/aflow/testdata/TestToolMisbehavior.llm.json | 10 |
4 files changed, 35 insertions, 16 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 405498800..ee83b541f 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map state: maps.Clone(inputs), onEvent: onEvent, } + defer c.close() if s := ctx.Value(stubContextKey); s != nil { c.stubContext = *s.(*stubContext) @@ -143,10 +144,17 @@ var ( createClientOnce sync.Once createClientErr error client *genai.Client - modelList = make(map[string]bool) + modelList = make(map[string]*modelInfo) stubContextKey = contextKeyType(1) ) +type modelInfo struct { + Thinking bool + MaxTemperature float32 + InputTokenLimit int + OutputTokenLimit int +} + func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { const modelPrefix = "models/" @@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte createClientErr = err return } - modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking + if !slices.Contains(m.SupportedActions, "generateContent") || + strings.Contains(m.Name, "-image") || + strings.Contains(m.Name, "-audio") { + continue + } + modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{ + Thinking: m.Thinking, + MaxTemperature: m.MaxTemperature, + InputTokenLimit: int(m.InputTokenLimit), + OutputTokenLimit: int(m.OutputTokenLimit), + } } }) if createClientErr != nil { return nil, createClientErr } - thinking, ok := modelList[model] - if !ok { + info := modelList[model] + if info == nil { models := slices.Collect(maps.Keys(modelList)) slices.Sort(models) return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } - if thinking { + *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature) + if info.Thinking { // Don't alter the original object (that may affect request caching). 
cfgCopy := *cfg cfg = &cfgCopy diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 67b51dd79..203ff798c 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -321,7 +321,7 @@ func TestToolMisbehavior(t *testing.T) { &LLMAgent{ Name: "smarty", Model: "model", - Temperature: 1, + Temperature: 0.5, Reply: "Reply", Outputs: LLMOutputs[struct { diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index 05e875f28..6cb84b5a3 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -35,7 +35,7 @@ type LLMAgent struct { // Value that controls the degree of randomness in token selection. // Lower temperatures are good for prompts that require a less open-ended or creative response, // while higher temperatures can lead to more diverse or creative results. - // Must be assigned a float32 value in the range [0, 2]. + // Must be assigned a number in the range [0, 2]. Temperature any // If set, the agent will generate that many candidates and the outputs will be arrays // instead of scalars. 
@@ -245,7 +245,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m } return &genai.GenerateContentConfig{ ResponseModalities: []string{"TEXT"}, - Temperature: genai.Ptr(a.Temperature.(float32)), + Temperature: genai.Ptr(float32(a.Temperature.(float64))), SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser), Tools: tools, }, instruction, toolMap @@ -399,10 +399,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) { ctx.requireNotEmpty(a.Name, "Model", a.Model) ctx.requireNotEmpty(a.Name, "Reply", a.Reply) if temp, ok := a.Temperature.(int); ok { - a.Temperature = float32(temp) + a.Temperature = float64(temp) } - if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 { - ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]") + if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 { + ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]") } if a.Candidates < 0 || a.Candidates > 100 { ctx.errorf(a.Name, "Candidates must be in the range [0, 100]") diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json index 370954d98..6c6478020 100644 --- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json +++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json @@ -10,7 +10,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -132,7 +132,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -367,7 +367,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -618,7 +618,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -915,7 +915,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ |
