From dd7d1ced5000d600f31d0be2ff901c5827a349c6 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Mon, 26 Jan 2026 15:07:38 +0100 Subject: pkg/aflow: fix Temperature handling If LLMAgent.Temperature is assigned an untyped float const (0.5) it will be typed as float64 rather than float32. So recast them. Cap Temperature at model's supported MaxTemperature. --- pkg/aflow/execute.go | 29 ++++++++++++++++++++----- pkg/aflow/flow_test.go | 2 +- pkg/aflow/llm_agent.go | 10 ++++----- pkg/aflow/testdata/TestToolMisbehavior.llm.json | 10 ++++----- 4 files changed, 35 insertions(+), 16 deletions(-) (limited to 'pkg') diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go index 405498800..ee83b541f 100644 --- a/pkg/aflow/execute.go +++ b/pkg/aflow/execute.go @@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map state: maps.Clone(inputs), onEvent: onEvent, } + defer c.close() if s := ctx.Value(stubContextKey); s != nil { c.stubContext = *s.(*stubContext) @@ -143,10 +144,17 @@ var ( createClientOnce sync.Once createClientErr error client *genai.Client - modelList = make(map[string]bool) + modelList = make(map[string]*modelInfo) stubContextKey = contextKeyType(1) ) +type modelInfo struct { + Thinking bool + MaxTemperature float32 + InputTokenLimit int + OutputTokenLimit int +} + func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { const modelPrefix = "models/" @@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte createClientErr = err return } - modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking + if !slices.Contains(m.SupportedActions, "generateContent") || + strings.Contains(m.Name, "-image") || + strings.Contains(m.Name, "-audio") { + continue + } + modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{ + Thinking: m.Thinking, + MaxTemperature: m.MaxTemperature, + InputTokenLimit: int(m.InputTokenLimit), + OutputTokenLimit: int(m.OutputTokenLimit), + } } }) if createClientErr != nil { return nil, createClientErr } - thinking, ok := modelList[model] - if !ok { + info := modelList[model] + if info == nil { models := slices.Collect(maps.Keys(modelList)) slices.Sort(models) return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) } - if thinking { + *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature) + if info.Thinking { // Don't alter the original object (that may affect request caching). cfgCopy := *cfg cfg = &cfgCopy diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go index 67b51dd79..203ff798c 100644 --- a/pkg/aflow/flow_test.go +++ b/pkg/aflow/flow_test.go @@ -321,7 +321,7 @@ func TestToolMisbehavior(t *testing.T) { &LLMAgent{ Name: "smarty", Model: "model", - Temperature: 1, + Temperature: 0.5, Reply: "Reply", Outputs: LLMOutputs[struct { diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index 05e875f28..6cb84b5a3 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -35,7 +35,7 @@ type LLMAgent struct { // Value that controls the degree of randomness in token selection. // Lower temperatures are good for prompts that require a less open-ended or creative response, // while higher temperatures can lead to more diverse or creative results. - // Must be assigned a float32 value in the range [0, 2]. + // Must be assigned a number in the range [0, 2]. Temperature any // If set, the agent will generate that many candidates and the outputs will be arrays // instead of scalars. @@ -245,7 +245,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m } return &genai.GenerateContentConfig{ ResponseModalities: []string{"TEXT"}, - Temperature: genai.Ptr(a.Temperature.(float32)), + Temperature: genai.Ptr(float32(a.Temperature.(float64))), SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser), Tools: tools, }, instruction, toolMap @@ -399,10 +399,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) { ctx.requireNotEmpty(a.Name, "Model", a.Model) ctx.requireNotEmpty(a.Name, "Reply", a.Reply) if temp, ok := a.Temperature.(int); ok { - a.Temperature = float32(temp) + a.Temperature = float64(temp) } - if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 { - ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]") + if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 { + ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]") } if a.Candidates < 0 || a.Candidates > 100 { ctx.errorf(a.Name, "Candidates must be in the range [0, 100]") diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json index 370954d98..6c6478020 100644 --- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json +++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json @@ -10,7 +10,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -132,7 +132,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -367,7 +367,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -618,7 +618,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ @@ -915,7 +915,7 @@ ], "role": "user" }, - "temperature": 1, + "temperature": 0.5, "tools": [ { "functionDeclarations": [ -- cgit mrf-deployment