aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-26 15:07:38 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-26 15:57:53 +0000
commitdd7d1ced5000d600f31d0be2ff901c5827a349c6 (patch)
treea2de2d8bc949f0167eaa7ef91357d6159cd007f4
parent0a7bbd79ddce993add61e99ffe0e9983dd56257d (diff)
pkg/aflow: fix Temperature handling
If LLMAgent.Temperature is assigned an untyped float constant (e.g. 0.5), it will be typed as float64 rather than float32, so recast such values explicitly. Also cap Temperature at the model's supported MaxTemperature.
-rw-r--r--pkg/aflow/execute.go29
-rw-r--r--pkg/aflow/flow_test.go2
-rw-r--r--pkg/aflow/llm_agent.go10
-rw-r--r--pkg/aflow/testdata/TestToolMisbehavior.llm.json10
4 files changed, 35 insertions, 16 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 405498800..ee83b541f 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
state: maps.Clone(inputs),
onEvent: onEvent,
}
+
defer c.close()
if s := ctx.Value(stubContextKey); s != nil {
c.stubContext = *s.(*stubContext)
@@ -143,10 +144,17 @@ var (
createClientOnce sync.Once
createClientErr error
client *genai.Client
- modelList = make(map[string]bool)
+ modelList = make(map[string]*modelInfo)
stubContextKey = contextKeyType(1)
)
+type modelInfo struct {
+ Thinking bool
+ MaxTemperature float32
+ InputTokenLimit int
+ OutputTokenLimit int
+}
+
func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
createClientErr = err
return
}
- modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
+ if !slices.Contains(m.SupportedActions, "generateContent") ||
+ strings.Contains(m.Name, "-image") ||
+ strings.Contains(m.Name, "-audio") {
+ continue
+ }
+ modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
+ Thinking: m.Thinking,
+ MaxTemperature: m.MaxTemperature,
+ InputTokenLimit: int(m.InputTokenLimit),
+ OutputTokenLimit: int(m.OutputTokenLimit),
+ }
}
})
if createClientErr != nil {
return nil, createClientErr
}
- thinking, ok := modelList[model]
- if !ok {
+ info := modelList[model]
+ if info == nil {
models := slices.Collect(maps.Keys(modelList))
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- if thinking {
+ *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
+ if info.Thinking {
// Don't alter the original object (that may affect request caching).
cfgCopy := *cfg
cfg = &cfgCopy
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 67b51dd79..203ff798c 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -321,7 +321,7 @@ func TestToolMisbehavior(t *testing.T) {
&LLMAgent{
Name: "smarty",
Model: "model",
- Temperature: 1,
+ Temperature: 0.5,
Reply: "Reply",
Outputs: LLMOutputs[struct {
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 05e875f28..6cb84b5a3 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -35,7 +35,7 @@ type LLMAgent struct {
// Value that controls the degree of randomness in token selection.
// Lower temperatures are good for prompts that require a less open-ended or creative response,
// while higher temperatures can lead to more diverse or creative results.
- // Must be assigned a float32 value in the range [0, 2].
+ // Must be assigned a number in the range [0, 2].
Temperature any
// If set, the agent will generate that many candidates and the outputs will be arrays
// instead of scalars.
@@ -245,7 +245,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
}
return &genai.GenerateContentConfig{
ResponseModalities: []string{"TEXT"},
- Temperature: genai.Ptr(a.Temperature.(float32)),
+ Temperature: genai.Ptr(float32(a.Temperature.(float64))),
SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
Tools: tools,
}, instruction, toolMap
@@ -399,10 +399,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
ctx.requireNotEmpty(a.Name, "Model", a.Model)
ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
- a.Temperature = float32(temp)
+ a.Temperature = float64(temp)
}
- if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
- ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
+ if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
+ ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
}
if a.Candidates < 0 || a.Candidates > 100 {
ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
index 370954d98..6c6478020 100644
--- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json
+++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
@@ -10,7 +10,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -132,7 +132,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -367,7 +367,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -618,7 +618,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -915,7 +915,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [