aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/execute.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-26 15:07:38 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-26 15:57:53 +0000
commitdd7d1ced5000d600f31d0be2ff901c5827a349c6 (patch)
treea2de2d8bc949f0167eaa7ef91357d6159cd007f4 /pkg/aflow/execute.go
parent0a7bbd79ddce993add61e99ffe0e9983dd56257d (diff)
pkg/aflow: fix Temperature handling
If LLMAgent.Temperature is assigned an untyped float const (0.5) it will be typed as float64 rather than float32. So recast them. Cap Temperature at model's supported MaxTemperature.
Diffstat (limited to 'pkg/aflow/execute.go')
-rw-r--r--pkg/aflow/execute.go29
1 files changed, 24 insertions, 5 deletions
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
index 405498800..ee83b541f 100644
--- a/pkg/aflow/execute.go
+++ b/pkg/aflow/execute.go
@@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
state: maps.Clone(inputs),
onEvent: onEvent,
}
+
defer c.close()
if s := ctx.Value(stubContextKey); s != nil {
c.stubContext = *s.(*stubContext)
@@ -143,10 +144,17 @@ var (
createClientOnce sync.Once
createClientErr error
client *genai.Client
- modelList = make(map[string]bool)
+ modelList = make(map[string]*modelInfo)
stubContextKey = contextKeyType(1)
)
+type modelInfo struct {
+ Thinking bool
+ MaxTemperature float32
+ InputTokenLimit int
+ OutputTokenLimit int
+}
+
func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
createClientErr = err
return
}
- modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
+ if !slices.Contains(m.SupportedActions, "generateContent") ||
+ strings.Contains(m.Name, "-image") ||
+ strings.Contains(m.Name, "-audio") {
+ continue
+ }
+ modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
+ Thinking: m.Thinking,
+ MaxTemperature: m.MaxTemperature,
+ InputTokenLimit: int(m.InputTokenLimit),
+ OutputTokenLimit: int(m.OutputTokenLimit),
+ }
}
})
if createClientErr != nil {
return nil, createClientErr
}
- thinking, ok := modelList[model]
- if !ok {
+ info := modelList[model]
+ if info == nil {
models := slices.Collect(maps.Keys(modelList))
slices.Sort(models)
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
}
- if thinking {
+ *cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
+ if info.Thinking {
// Don't alter the original object (that may affect request caching).
cfgCopy := *cfg
cfg = &cfgCopy