about | summary | refs | log | tree | commit | diff | stats
path: root/pkg
diff options
context:
space:
mode:
author Dmitry Vyukov <dvyukov@google.com> 2026-01-27 12:36:21 +0100
committer Dmitry Vyukov <dvyukov@google.com> 2026-01-27 13:08:50 +0000
commit bc3f8e2821237a4e2f6ef31b6e8ad796159ef44f (patch)
tree 29483afa70a95c19a64a09141538cedc1f91b918 /pkg
parent cdeabb956e8764d60b78ba11c55b0146365fca9d (diff)
pkg/aflow: unit-test LLM error handling
It's getting more and more complex.
Diffstat (limited to 'pkg')
-rw-r--r-- pkg/aflow/llm_agent.go 63
-rw-r--r-- pkg/aflow/llm_agent_test.go 84
2 files changed, 122 insertions, 25 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index cb5267f0c..0997b9f09 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -383,39 +383,18 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *traj
func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig,
req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
- backoff := time.Second
+ var backoff time.Duration
for try := 0; ; try++ {
resp, err := a.generateContentCached(ctx, cfg, req, candidate)
- var apiErr genai.APIError
- if err == nil || !errors.As(err, &apiErr) {
- return resp, err
- }
- if try < 100 && apiErr.Code == http.StatusServiceUnavailable {
- time.Sleep(backoff)
- backoff = min(backoff+time.Second, 10*time.Second)
+ retry, err := parseLLMError(resp, err, ctx.modelName(a.Model), try, &backoff)
+ if retry != 0 {
+ time.Sleep(retry)
continue
}
- if apiErr.Code == http.StatusTooManyRequests &&
- strings.Contains(apiErr.Message, "Quota exceeded for metric") {
- if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
- sec, _ := strconv.Atoi(match[1])
- time.Sleep(time.Duration(sec+1) * time.Second)
- continue
- }
- if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
- return resp, &modelQuotaError{ctx.modelName(a.Model)}
- }
- }
- if apiErr.Code == http.StatusBadRequest &&
- strings.Contains(apiErr.Message, "The input token count exceeds the maximum") {
- return resp, &tokenOverflowError{err}
- }
return resp, err
}
}
-var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")
-
func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig,
req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) {
type Cached struct {
@@ -437,6 +416,40 @@ func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateConten
return cached.Reply, err
}
+// parseLLMError classifies the outcome of an LLM request and decides whether
+// the request should be retried.
+//
+// A non-zero returned duration (always paired with a nil error) means the caller
+// should wait that long and retry. Otherwise the duration is 0 and the returned
+// error is the final result of the request:
+//   - nil and non-API errors are returned as is;
+//   - an exhausted per-model daily quota is returned as modelQuotaError for model;
+//   - an input token overflow is returned as tokenOverflowError holding err;
+//   - any other API error is returned as is.
+//
+// try is the zero-based attempt number; it bounds only the service unavailable
+// retries. backoff is caller-owned state that is advanced on every service
+// unavailable retry. resp is currently not inspected.
+func parseLLMError(resp *genai.GenerateContentResponse, err error, model string,
+	try int, backoff *time.Duration) (time.Duration, error) {
+	var apiErr genai.APIError
+	if err == nil || !errors.As(err, &apiErr) {
+		return 0, err
+	}
+	if try < maxLLMRetryIters && apiErr.Code == http.StatusServiceUnavailable {
+		// Linear backoff: 1s, 2s, ... capped at maxLLMBackoff.
+		*backoff = min(*backoff+time.Second, maxLLMBackoff)
+		return *backoff, nil
+	}
+	if apiErr.Code == http.StatusTooManyRequests &&
+		strings.Contains(apiErr.Message, "Quota exceeded for metric") {
+		if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil {
+			// Parse as a 32-bit value so that neither sec+1 nor the conversion
+			// to nanoseconds below can overflow. An overflown delay may become
+			// negative, and sleeping for a negative duration returns immediately,
+			// which would turn the retry loop into a busy loop.
+			sec, parseErr := strconv.ParseInt(match[1], 10, 32)
+			if parseErr != nil {
+				// The hint is unusable, so don't guess a delay.
+				return 0, err
+			}
+			// The fractional part of the hint is dropped by the regexp,
+			// so round up by a whole second.
+			return time.Duration(sec+1) * time.Second, nil
+		}
+		if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
+			return 0, &modelQuotaError{model}
+		}
+	}
+	if apiErr.Code == http.StatusBadRequest &&
+		strings.Contains(apiErr.Message, "The input token count exceeds the maximum") {
+		return 0, &tokenOverflowError{err}
+	}
+	return 0, err
+}
+
+const (
+	// maxLLMRetryIters bounds the number of attempts for which
+	// service unavailable errors are retried.
+	maxLLMRetryIters = 100
+	// maxLLMBackoff is the upper bound of the service unavailable backoff.
+	maxLLMBackoff = 10 * time.Second
+)
+
+// rePleaseRetry extracts the whole number of seconds from the
+// "Please retry in <seconds>s" hint of quota errors.
+var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]")
+
func (a *LLMAgent) verify(ctx *verifyContext) {
ctx.requireNotEmpty(a.Name, "Name", a.Name)
ctx.requireNotEmpty(a.Name, "Model", a.Model)
diff --git a/pkg/aflow/llm_agent_test.go b/pkg/aflow/llm_agent_test.go
new file mode 100644
index 000000000..adf2a1278
--- /dev/null
+++ b/pkg/aflow/llm_agent_test.go
@@ -0,0 +1,84 @@
+// Copyright 2026 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/genai"
+)
+
+// TestParseLLMError runs parseLLMError over a table of request outcomes
+// and checks the retry delay and the resulting error for each of them.
+func TestParseLLMError(t *testing.T) {
+	// The token overflow API error serves both as the input
+	// and as the payload of the expected tokenOverflowError.
+	overflow := genai.APIError{
+		Code:    400,
+		Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`,
+	}
+	// nolint:lll
+	cases := []struct {
+		resp      *genai.GenerateContentResponse
+		in        error
+		wantRetry time.Duration
+		wantErr   error
+	}{
+		// Successful request: no retry and no error.
+		{},
+		// Errors that are not API errors pass through unchanged.
+		{
+			in:      fmt.Errorf("non API error"),
+			wantErr: fmt.Errorf("non API error"),
+		},
+		// Quota error with a retry hint: the hint is rounded up to whole seconds.
+		{
+			in: genai.APIError{
+				Code:    429,
+				Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_paid_tier_input_token_count, limit: 1000000, model: gemini-3-flash Please retry in 24.180878813s.`,
+			},
+			wantRetry: 25 * time.Second,
+		},
+		// Daily per-model quota without a retry hint.
+		{
+			in: genai.APIError{
+				Code:    429,
+				Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_requests_per_model_per_day, limit: 0`,
+			},
+			wantErr: &modelQuotaError{"smarty"},
+		},
+		// Input token overflow.
+		{
+			in:      overflow,
+			wantErr: &tokenOverflowError{overflow},
+		},
+	}
+	for idx, tc := range cases {
+		t.Run(fmt.Sprint(idx), func(t *testing.T) {
+			var backoff time.Duration
+			gotRetry, gotErr := parseLLMError(tc.resp, tc.in, "smarty", 1, &backoff)
+			assert.Equal(t, tc.wantRetry, gotRetry)
+			assert.Equal(t, tc.wantErr, gotErr)
+		})
+	}
+}
+
+// TestParseLLMErrorBackoff checks the handling of service unavailable errors:
+// the first maxLLMRetryIters attempts are retried with a delay that grows by
+// a second per attempt up to maxLLMBackoff, and after that the original error
+// is returned without a retry.
+func TestParseLLMErrorBackoff(t *testing.T) {
+	var backoff time.Duration
+	unavailable := genai.APIError{Code: http.StatusServiceUnavailable}
+	for try := 0; try < maxLLMRetryIters; try++ {
+		retry, err := parseLLMError(nil, unavailable, "model", try, &backoff)
+		// testify expects (expected, actual); the reverse order
+		// produces misleading failure messages.
+		require.Equal(t, min(maxLLMBackoff, time.Duration(try+1)*time.Second), retry)
+		require.NoError(t, err)
+	}
+	// The retry budget is exhausted: no more retries, the error is surfaced as is.
+	retry, err := parseLLMError(nil, unavailable, "model", maxLLMRetryIters, &backoff)
+	require.Equal(t, time.Duration(0), retry)
+	require.Equal(t, unavailable, err)
+}