author	Dmitry Vyukov <dvyukov@google.com>	2026-01-26 15:07:39 +0100
committer	Dmitry Vyukov <dvyukov@google.com>	2026-01-26 15:57:53 +0000
commit	4964e02465f2e9b4f98fbbf49f2fe058892b715d (patch)
tree	be5a499ae380e3e507963541f5e9f6d76cc5e3a2 /pkg
parent	dd7d1ced5000d600f31d0be2ff901c5827a349c6 (diff)
pkg/aflow/trajectory: add token usage
Diffstat (limited to 'pkg')
-rw-r--r--	pkg/aflow/llm_agent.go             | 16 +++++++++++-----
-rw-r--r--	pkg/aflow/trajectory/trajectory.go |  8 ++++++++
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 6cb84b5a3..2c83be2eb 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -192,8 +192,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
 	if err != nil {
 		return "", nil, ctx.finishSpan(reqSpan, err)
 	}
-	reply, thoughts, calls, respErr := a.parseResponse(resp)
-	reqSpan.Thoughts = thoughts
+	reply, calls, respErr := a.parseResponse(resp, reqSpan)
 	if err := ctx.finishSpan(reqSpan, respErr); err != nil {
 		return "", nil, err
 	}
@@ -302,8 +301,8 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai
 	return responses, outputs, nil
 }
 
-func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
-	reply, thoughts string, calls []*genai.FunctionCall, err error) {
+func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *trajectory.Span) (
+	reply string, calls []*genai.FunctionCall, err error) {
 	if len(resp.Candidates) == 0 || resp.Candidates[0] == nil {
 		err = fmt.Errorf("empty model response")
 		if resp.PromptFeedback != nil {
@@ -322,6 +321,13 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
 		err = fmt.Errorf("unexpected reply fields (%+v)", *candidate)
 		return
 	}
+	if resp.UsageMetadata != nil {
+		// We add ToolUsePromptTokenCount just in case, but Gemini does not use/set it.
+		span.InputTokens = int(resp.UsageMetadata.PromptTokenCount) +
+			int(resp.UsageMetadata.ToolUsePromptTokenCount)
+		span.OutputTokens = int(resp.UsageMetadata.CandidatesTokenCount)
+		span.OutputThoughtsTokens = int(resp.UsageMetadata.ThoughtsTokenCount)
+	}
 	for _, part := range candidate.Content.Parts {
 		// We don't expect to receive these now.
 		if part.VideoMetadata != nil || part.InlineData != nil ||
@@ -333,7 +339,7 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
 		if part.FunctionCall != nil {
 			calls = append(calls, part.FunctionCall)
 		} else if part.Thought {
-			thoughts += part.Text
+			span.Thoughts += part.Text
 		} else {
 			reply += part.Text
 		}
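
To see the new accounting in isolation, here is a minimal, self-contained sketch of what parseResponse now records. The Span type is cut down to just the token fields added to trajectory.Span below; the genai metadata fields are the ones used in the hunk above:

package main

import (
	"fmt"

	"google.golang.org/genai"
)

// Span is reduced to the token-usage fields added to trajectory.Span.
type Span struct {
	InputTokens          int
	OutputTokens         int
	OutputThoughtsTokens int
}

// recordUsage mirrors the accounting added to parseResponse: prompt tokens
// (plus tool-use prompt tokens, which Gemini currently leaves unset) count
// as input, candidate tokens as output, and thought tokens separately.
func recordUsage(span *Span, md *genai.GenerateContentResponseUsageMetadata) {
	if md == nil {
		return
	}
	span.InputTokens = int(md.PromptTokenCount) + int(md.ToolUsePromptTokenCount)
	span.OutputTokens = int(md.CandidatesTokenCount)
	span.OutputThoughtsTokens = int(md.ThoughtsTokenCount)
}

func main() {
	span := new(Span)
	recordUsage(span, &genai.GenerateContentResponseUsageMetadata{
		PromptTokenCount:     1200,
		CandidatesTokenCount: 300,
		ThoughtsTokenCount:   150,
	})
	fmt.Printf("tokens: input=%v output=%v thoughts=%v\n",
		span.InputTokens, span.OutputTokens, span.OutputThoughtsTokens)
	// Output: tokens: input=1200 output=300 thoughts=150
}

Running it prints the same "tokens: input=... output=... thoughts=..." line that Span.String emits in the trajectory.go change below.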
diff --git a/pkg/aflow/trajectory/trajectory.go b/pkg/aflow/trajectory/trajectory.go
index ad22b018e..fc63235cf 100644
--- a/pkg/aflow/trajectory/trajectory.go
+++ b/pkg/aflow/trajectory/trajectory.go
@@ -38,6 +38,12 @@ type Span struct {
 	// LLM invocation.
 	Thoughts string `json:",omitzero"`
+
+	// For details see:
+	// https://pkg.go.dev/google.golang.org/genai#GenerateContentResponseUsageMetadata
+	InputTokens          int `json:",omitzero"`
+	OutputTokens         int `json:",omitzero"`
+	OutputThoughtsTokens int `json:",omitzero"`
 }
 
 type SpanType string
@@ -89,6 +95,8 @@ func (span *Span) String() string {
 		}
 		fmt.Fprintf(sb, "reply:\n%v\n", span.Reply)
 	case SpanLLM:
+		fmt.Fprintf(sb, "tokens: input=%v output=%v thoughts=%v\n",
+			span.InputTokens, span.OutputTokens, span.OutputThoughtsTokens)
 		if span.Thoughts != "" {
 			fmt.Fprintf(sb, "thoughts:\n%v\n", span.Thoughts)
 		}
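
With per-span counters in place, whole-run totals are a simple fold. A hypothetical aggregation helper, assuming spans are available as a flat slice (a shape this diff does not define):

// TotalTokens sums usage over a slice of spans. The flat []*Span input is
// an assumption for illustration; only the per-span fields come from this
// change.
func TotalTokens(spans []*Span) (input, output, thoughts int) {
	for _, span := range spans {
		input += span.InputTokens
		output += span.OutputTokens
		thoughts += span.OutputThoughtsTokens
	}
	return input, output, thoughts
}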