diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-27 12:36:21 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-27 13:08:50 +0000 |
| commit | bc3f8e2821237a4e2f6ef31b6e8ad796159ef44f (patch) | |
| tree | 29483afa70a95c19a64a09141538cedc1f91b918 /pkg | |
| parent | cdeabb956e8764d60b78ba11c55b0146365fca9d (diff) | |
pkg/aflow: unit-test LLM error handling
The LLM error handling is getting more and more complex, so factor it into a separate function and cover it with unit tests.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/llm_agent.go | 63 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent_test.go | 84 |
2 files changed, 122 insertions, 25 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index cb5267f0c..0997b9f09 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -383,39 +383,18 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *traj func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { - backoff := time.Second + var backoff time.Duration for try := 0; ; try++ { resp, err := a.generateContentCached(ctx, cfg, req, candidate) - var apiErr genai.APIError - if err == nil || !errors.As(err, &apiErr) { - return resp, err - } - if try < 100 && apiErr.Code == http.StatusServiceUnavailable { - time.Sleep(backoff) - backoff = min(backoff+time.Second, 10*time.Second) + retry, err := parseLLMError(resp, err, ctx.modelName(a.Model), try, &backoff) + if retry != 0 { + time.Sleep(retry) continue } - if apiErr.Code == http.StatusTooManyRequests && - strings.Contains(apiErr.Message, "Quota exceeded for metric") { - if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil { - sec, _ := strconv.Atoi(match[1]) - time.Sleep(time.Duration(sec+1) * time.Second) - continue - } - if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { - return resp, &modelQuotaError{ctx.modelName(a.Model)} - } - } - if apiErr.Code == http.StatusBadRequest && - strings.Contains(apiErr.Message, "The input token count exceeds the maximum") { - return resp, &tokenOverflowError{err} - } return resp, err } } -var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]") - func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig, req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { type Cached struct { @@ -437,6 +416,40 @@ func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateConten return cached.Reply, err } +func parseLLMError(resp 
*genai.GenerateContentResponse, err error, model string, + try int, backoff *time.Duration) (time.Duration, error) { + var apiErr genai.APIError + if err == nil || !errors.As(err, &apiErr) { + return 0, err + } + if try < maxLLMRetryIters && apiErr.Code == http.StatusServiceUnavailable { + *backoff = min(*backoff+time.Second, maxLLMBackoff) + return *backoff, nil + } + if apiErr.Code == http.StatusTooManyRequests && + strings.Contains(apiErr.Message, "Quota exceeded for metric") { + if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil { + sec, _ := strconv.Atoi(match[1]) + return time.Duration(sec+1) * time.Second, nil + } + if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { + return 0, &modelQuotaError{model} + } + } + if apiErr.Code == http.StatusBadRequest && + strings.Contains(apiErr.Message, "The input token count exceeds the maximum") { + return 0, &tokenOverflowError{err} + } + return 0, err +} + +const ( + maxLLMRetryIters = 100 + maxLLMBackoff = 10 * time.Second +) + +var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]") + func (a *LLMAgent) verify(ctx *verifyContext) { ctx.requireNotEmpty(a.Name, "Name", a.Name) ctx.requireNotEmpty(a.Name, "Model", a.Model) diff --git a/pkg/aflow/llm_agent_test.go b/pkg/aflow/llm_agent_test.go new file mode 100644 index 000000000..adf2a1278 --- /dev/null +++ b/pkg/aflow/llm_agent_test.go @@ -0,0 +1,84 @@ +// Copyright 2026 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package aflow + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestParseLLMError(t *testing.T) { + type Test struct { + resp *genai.GenerateContentResponse + inputErr error + retry time.Duration + outputErr error + } + // nolint:lll + tests := []Test{ + { + resp: nil, + inputErr: nil, + }, + { + resp: nil, + inputErr: fmt.Errorf("non API error"), + outputErr: fmt.Errorf("non API error"), + }, + { + resp: nil, + inputErr: genai.APIError{ + Code: 429, + Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_paid_tier_input_token_count, limit: 1000000, model: gemini-3-flash Please retry in 24.180878813s.`, + }, + retry: 25 * time.Second, + }, + { + resp: nil, + inputErr: genai.APIError{ + Code: 429, + Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. 
* Quota exceeded for metric: generativelanguage.googleapis.com/generate_requests_per_model_per_day, limit: 0`, + }, + outputErr: &modelQuotaError{"smarty"}, + }, + { + resp: nil, + inputErr: genai.APIError{ + Code: 400, + Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`, + }, + outputErr: &tokenOverflowError{genai.APIError{ + Code: 400, + Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`, + }}, + }, + } + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + var backoff time.Duration + retry, err := parseLLMError(test.resp, test.inputErr, "smarty", 1, &backoff) + assert.Equal(t, test.retry, retry) + assert.Equal(t, test.outputErr, err) + }) + } +} + +func TestParseLLMErrorBackoff(t *testing.T) { + var backoff time.Duration + err0 := genai.APIError{Code: http.StatusServiceUnavailable} + for try := 0; try < maxLLMRetryIters; try++ { + retry, err := parseLLMError(nil, err0, "model", try, &backoff) + require.Equal(t, retry, min(maxLLMBackoff, time.Duration(try+1)*time.Second)) + require.NoError(t, err) + } + retry, err := parseLLMError(nil, err0, "model", maxLLMRetryIters, &backoff) + require.Equal(t, retry, time.Duration(0)) + require.Equal(t, err, err0) +} |
