aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-26 15:07:38 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-26 15:57:53 +0000
commitdd7d1ced5000d600f31d0be2ff901c5827a349c6 (patch)
treea2de2d8bc949f0167eaa7ef91357d6159cd007f4 /pkg
parent0a7bbd79ddce993add61e99ffe0e9983dd56257d (diff)
pkg/aflow: fix Temperature handling
If LLMAgent.Temperature is assigned an untyped float constant (e.g. 0.5), it will be typed as float64 rather than float32, so recast such values. Also cap Temperature at the model's supported MaxTemperature.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/execute.go29
-rw-r--r--pkg/aflow/flow_test.go2
-rw-r--r--pkg/aflow/llm_agent.go10
-rw-r--r--pkg/aflow/testdata/TestToolMisbehavior.llm.json10
4 files changed, 35 insertions, 16 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 405498800..ee83b541f 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
state: maps.Clone(inputs),
onEvent: onEvent,
}
+
defer c.close()
if s := ctx.Value(stubContextKey); s != nil {
c.stubContext = *s.(*stubContext)
@@ -143,10 +144,17 @@ var (
createClientOnce sync.Once
createClientErr error
client *genai.Client
- modelList = make(map[string]bool)
+ modelList = make(map[string]*modelInfo)
stubContextKey = contextKeyType(1)
)
+type modelInfo struct {
+ Thinking bool
+ MaxTemperature float32
+ InputTokenLimit int
+ OutputTokenLimit int
+}
+
func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
createClientErr = err
return
}
- modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
+ if !slices.Contains(m.SupportedActions, "generateContent") ||
+ strings.Contains(m.Name, "-image") ||
+ strings.Contains(m.Name, "-audio") {
+ continue
+ }
+ modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
+ Thinking: m.Thinking,
+ MaxTemperature: m.MaxTemperature,
+ InputTokenLimit: int(m.InputTokenLimit),
+ OutputTokenLimit: int(m.OutputTokenLimit),
+ }
}
})
if createClientErr != nil {
return nil, createClientErr
}
- thinking, ok := modelList[model]
- if !ok {
+ info := modelList[model]
+ if info == nil {
models := slices.Collect(maps.Keys(modelList))
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- if thinking {
+ *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
+ if info.Thinking {
// Don't alter the original object (that may affect request caching).
cfgCopy := *cfg
cfg = &cfgCopy
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 67b51dd79..203ff798c 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -321,7 +321,7 @@ func TestToolMisbehavior(t *testing.T) {
&LLMAgent{
Name: "smarty",
Model: "model",
- Temperature: 1,
+ Temperature: 0.5,
Reply: "Reply",
Outputs: LLMOutputs[struct {
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 05e875f28..6cb84b5a3 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -35,7 +35,7 @@ type LLMAgent struct {
// Value that controls the degree of randomness in token selection.
// Lower temperatures are good for prompts that require a less open-ended or creative response,
// while higher temperatures can lead to more diverse or creative results.
- // Must be assigned a float32 value in the range [0, 2].
+ // Must be assigned a number in the range [0, 2].
Temperature any
// If set, the agent will generate that many candidates and the outputs will be arrays
// instead of scalars.
@@ -245,7 +245,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
}
return &genai.GenerateContentConfig{
ResponseModalities: []string{"TEXT"},
- Temperature: genai.Ptr(a.Temperature.(float32)),
+ Temperature: genai.Ptr(float32(a.Temperature.(float64))),
SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
Tools: tools,
}, instruction, toolMap
@@ -399,10 +399,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
ctx.requireNotEmpty(a.Name, "Model", a.Model)
ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
if temp, ok := a.Temperature.(int); ok {
- a.Temperature = float32(temp)
+ a.Temperature = float64(temp)
}
- if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
- ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
+ if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
+ ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
}
if a.Candidates < 0 || a.Candidates > 100 {
ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")
diff --git a/pkg/aflow/testdata/TestToolMisbehavior.llm.json b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
index 370954d98..6c6478020 100644
--- a/pkg/aflow/testdata/TestToolMisbehavior.llm.json
+++ b/pkg/aflow/testdata/TestToolMisbehavior.llm.json
@@ -10,7 +10,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -132,7 +132,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -367,7 +367,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -618,7 +618,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [
@@ -915,7 +915,7 @@
],
"role": "user"
},
- "temperature": 1,
+ "temperature": 0.5,
"tools": [
{
"functionDeclarations": [