diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-27 13:09:37 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-27 13:08:50 +0000 |
| commit | 8d2502e0be2fbb25cb9be08f92afdb564b64c817 (patch) | |
| tree | 37750b681a7bf8f415cb547f738f062dac78cf2e /pkg | |
| parent | 576d15fdab7389b1f07162faf530abfe4e8bf7f9 (diff) | |
pkg/aflow: handle MALFORMED_FUNCTION_CALL LLM replies
In this case we don't get an error from the LLM API,
but the response effectively means an error.
Handle it with a retry.
In such a case we also don't want to cache the response;
handling this required a bit of refactoring.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/llm_agent.go | 115 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent_test.go | 87 |
2 files changed, 123 insertions, 79 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index d06e73642..bb92f1ecf 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -329,24 +329,7 @@ func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *trajectory.Span) ( reply string, calls []*genai.FunctionCall, err error) { - if len(resp.Candidates) == 0 || resp.Candidates[0] == nil { - err = fmt.Errorf("empty model response") - if resp.PromptFeedback != nil { - err = fmt.Errorf("request blocked: %v", resp.PromptFeedback.BlockReasonMessage) - } - return - } candidate := resp.Candidates[0] - if candidate.Content == nil || len(candidate.Content.Parts) == 0 { - err = fmt.Errorf("%v (%v)", candidate.FinishMessage, candidate.FinishReason) - return - } - // We don't expect to receive these fields now. - // Note: CitationMetadata may be present sometimes, but we don't have uses for it. - if candidate.GroundingMetadata != nil || candidate.LogprobsResult != nil { - err = fmt.Errorf("unexpected reply fields (%+v)", *candidate) - return - } if resp.UsageMetadata != nil { // We add ToolUsePromptTokenCount just in case, but Gemini does not use/set it. span.InputTokens = int(resp.UsageMetadata.PromptTokenCount) + @@ -355,13 +338,6 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *traj span.OutputThoughtsTokens = int(resp.UsageMetadata.ThoughtsTokenCount) } for _, part := range candidate.Content.Parts { - // We don't expect to receive these now. 
- if part.VideoMetadata != nil || part.InlineData != nil || - part.FileData != nil || part.FunctionResponse != nil || - part.CodeExecutionResult != nil || part.ExecutableCode != nil { - err = fmt.Errorf("unexpected reply part (%+v)", *part) - return - } if part.FunctionCall != nil { calls = append(calls, part.FunctionCall) } else if part.Thought { @@ -383,12 +359,10 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse, span *traj func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { - var backoff time.Duration for try := 0; ; try++ { - resp, err := a.generateContentCached(ctx, cfg, req, candidate) - retry, err := parseLLMError(resp, err, ctx.modelName(a.Model), try, &backoff) - if retry != 0 { - time.Sleep(retry) + resp, err := a.generateContentCached(ctx, cfg, req, candidate, try) + if retryErr := new(retryError); errors.As(err, &retryErr) { + time.Sleep(retryErr.delay) continue } return resp, err @@ -396,7 +370,7 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi } func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateContentConfig, - req []*genai.Content, candidate int) (*genai.GenerateContentResponse, error) { + req []*genai.Content, candidate, try int) (*genai.GenerateContentResponse, error) { type Cached struct { Config *genai.GenerateContentConfig Request []*genai.Content @@ -407,6 +381,7 @@ func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateConten model, hash.String(cfg), hash.String(req), candidate) cached, err := CacheObject(ctx, "llm", desc, func() (Cached, error) { resp, err := ctx.generateContent(model, cfg, req) + err = parseLLMError(resp, err, model, try) return Cached{ Config: cfg, Request: req, @@ -416,35 +391,78 @@ func (a *LLMAgent) generateContentCached(ctx *Context, cfg *genai.GenerateConten return cached.Reply, err } -func parseLLMError(resp 
*genai.GenerateContentResponse, err error, model string, - try int, backoff *time.Duration) (time.Duration, error) { +func parseLLMError(resp *genai.GenerateContentResponse, err error, model string, try int) error { + err = parseLLMErrorImpl(resp, err, model, try) + if retryErr := new(retryError); errors.As(err, &retryErr) && try >= maxLLMRetryIters { + // We can't retry infinity, so revert back to the original error + // when we reach maxLLMRetryIters. + return retryErr.err + } + return err +} + +func parseLLMErrorImpl(resp *genai.GenerateContentResponse, err error, model string, try int) error { + if err == nil { + return parseLLMResp(resp) + } var apiErr genai.APIError - if err == nil || !errors.As(err, &apiErr) { - return 0, err + if !errors.As(err, &apiErr) { + return err } if try < maxLLMRetryIters && apiErr.Code == http.StatusServiceUnavailable { - *backoff = min(*backoff+time.Second, maxLLMBackoff) - return *backoff, nil + return &retryError{min(time.Duration(try+1)*time.Second, maxLLMBackoff), err} } if apiErr.Code == http.StatusTooManyRequests && strings.Contains(apiErr.Message, "Quota exceeded for metric") { if match := rePleaseRetry.FindStringSubmatch(apiErr.Message); match != nil { sec, _ := strconv.Atoi(match[1]) - return time.Duration(sec+1) * time.Second, nil + return &retryError{time.Duration(sec+1) * time.Second, err} } if strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { - return 0, &modelQuotaError{model} + return &modelQuotaError{model} } } if apiErr.Code == http.StatusBadRequest && strings.Contains(apiErr.Message, "The input token count exceeds the maximum") { - return 0, &tokenOverflowError{err} + return &tokenOverflowError{err} } - if apiErr.Code == http.StatusInternalServerError && try < maxLLMRetryIters { + if apiErr.Code == http.StatusInternalServerError { // Let's assume ISE is just something temporal on the server side. 
- return time.Second, nil + return &retryError{time.Second, err} + } + return err +} + +func parseLLMResp(resp *genai.GenerateContentResponse) error { + if len(resp.Candidates) == 0 || resp.Candidates[0] == nil { + if resp.PromptFeedback != nil { + return fmt.Errorf("request blocked: %v", resp.PromptFeedback.BlockReasonMessage) + } + return fmt.Errorf("empty model response") + } + candidate := resp.Candidates[0] + if candidate.Content == nil || len(candidate.Content.Parts) == 0 { + if candidate.FinishReason == genai.FinishReasonMalformedFunctionCall { + // Let's consider this as a temp error, and that the next time it won't + // generate the same buggy output. In either case we have maxLLMRetryIters. + return &retryError{0, errors.New(string(genai.FinishReasonMalformedFunctionCall))} + } + return fmt.Errorf("%v (%v)", candidate.FinishMessage, candidate.FinishReason) + } + // We don't expect to receive these fields now. + // Note: CitationMetadata may be present sometimes, but we don't have uses for it. + if candidate.GroundingMetadata != nil || candidate.LogprobsResult != nil { + return fmt.Errorf("unexpected reply fields (%+v)", *candidate) + } + for _, part := range candidate.Content.Parts { + // We don't expect to receive these now. 
+ if part.VideoMetadata != nil || part.InlineData != nil || + part.FileData != nil || part.FunctionResponse != nil || + part.CodeExecutionResult != nil || part.ExecutableCode != nil { + return fmt.Errorf("unexpected reply part (%+v)", *part) + } } - return 0, err + return nil } const ( @@ -454,6 +472,19 @@ const ( var rePleaseRetry = regexp.MustCompile("Please retry in ([0-9]+)[.s]") +type retryError struct { + delay time.Duration + err error +} + +func (err *retryError) Error() string { + return fmt.Sprintf("%s (should be retried after %v)", err.err, err.delay) +} + +func (err *retryError) Unwrap() error { + return err.err +} + func (a *LLMAgent) verify(ctx *verifyContext) { ctx.requireNotEmpty(a.Name, "Name", a.Name) ctx.requireNotEmpty(a.Name, "Model", a.Model) diff --git a/pkg/aflow/llm_agent_test.go b/pkg/aflow/llm_agent_test.go index ac3187884..a8aba9056 100644 --- a/pkg/aflow/llm_agent_test.go +++ b/pkg/aflow/llm_agent_test.go @@ -4,6 +4,7 @@ package aflow import ( + "errors" "fmt" "net/http" "testing" @@ -18,13 +19,34 @@ func TestParseLLMError(t *testing.T) { type Test struct { resp *genai.GenerateContentResponse inputErr error - retry time.Duration outputErr error } - // nolint:lll + tpsError := genai.APIError{ + Code: 429, + // nolint:lll + Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_paid_tier_input_token_count, limit: 1000000, model: gemini-3-flash Please retry in 24.180878813s.`, + } + rpmError := genai.APIError{ + Code: 429, + // nolint:lll + Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. 
To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_requests_per_model_per_day, limit: 0`, + } + tokenError := genai.APIError{ + Code: 400, + Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`, + } + iseError := genai.APIError{ + Code: 500, + Message: `Internal error encountered.`, + } + normalResp := &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{ + Content: genai.NewContentFromText("repy", genai.RoleModel), + }}, + } tests := []Test{ { - resp: nil, + resp: normalResp, inputErr: nil, }, { @@ -33,60 +55,51 @@ func TestParseLLMError(t *testing.T) { outputErr: fmt.Errorf("non API error"), }, { - resp: nil, - inputErr: genai.APIError{ - Code: 429, - Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. * Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_paid_tier_input_token_count, limit: 1000000, model: gemini-3-flash Please retry in 24.180878813s.`, - }, - retry: 25 * time.Second, + resp: nil, + inputErr: tpsError, + outputErr: &retryError{25 * time.Second, tpsError}, }, { - resp: nil, - inputErr: genai.APIError{ - Code: 429, - Message: `You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/rate-limit. 
* Quota exceeded for metric: generativelanguage.googleapis.com/generate_requests_per_model_per_day, limit: 0`, - }, + resp: nil, + inputErr: rpmError, outputErr: &modelQuotaError{"smarty"}, }, { - resp: nil, - inputErr: genai.APIError{ - Code: 400, - Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`, - }, - outputErr: &tokenOverflowError{genai.APIError{ - Code: 400, - Message: `The input token count exceeds the maximum number of tokens allowed 1048576.`, - }}, + resp: nil, + inputErr: tokenError, + outputErr: &tokenOverflowError{tokenError}, + }, + { + resp: nil, + inputErr: iseError, + outputErr: &retryError{time.Second, iseError}, }, { - resp: nil, - inputErr: genai.APIError{ - Code: 500, - Message: `Internal error encountered.`, + resp: &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + { + FinishReason: genai.FinishReasonMalformedFunctionCall, + }, + }, }, - retry: time.Second, + outputErr: &retryError{0, errors.New(string(genai.FinishReasonMalformedFunctionCall))}, }, } for i, test := range tests { t.Run(fmt.Sprint(i), func(t *testing.T) { - var backoff time.Duration - retry, err := parseLLMError(test.resp, test.inputErr, "smarty", 1, &backoff) - assert.Equal(t, test.retry, retry) + err := parseLLMError(test.resp, test.inputErr, "smarty", 0) assert.Equal(t, test.outputErr, err) }) } } func TestParseLLMErrorBackoff(t *testing.T) { - var backoff time.Duration err0 := genai.APIError{Code: http.StatusServiceUnavailable} for try := 0; try < maxLLMRetryIters; try++ { - retry, err := parseLLMError(nil, err0, "model", try, &backoff) - require.Equal(t, retry, min(maxLLMBackoff, time.Duration(try+1)*time.Second)) - require.NoError(t, err) + wantDelay := min(maxLLMBackoff, time.Duration(try+1)*time.Second) + err := parseLLMError(nil, err0, "model", try) + require.Equal(t, err, &retryError{wantDelay, err0}) } - retry, err := parseLLMError(nil, err0, "model", maxLLMRetryIters, &backoff) - require.Equal(t, retry, 
time.Duration(0)) + err := parseLLMError(nil, err0, "model", maxLLMRetryIters) require.Equal(t, err, err0) } |
