diff options
Diffstat (limited to 'pkg/aflow/llm_agent.go')
| -rw-r--r-- | pkg/aflow/llm_agent.go | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go index 3c416b37c..d5e4d6d4d 100644 --- a/pkg/aflow/llm_agent.go +++ b/pkg/aflow/llm_agent.go @@ -9,6 +9,7 @@ import ( "maps" "net/http" "reflect" + "strings" "time" "github.com/google/syzkaller/pkg/aflow/trajectory" @@ -289,8 +290,9 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) ( func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { backoff := time.Second + model := ctx.modelName(a.Model) for try := 0; ; try++ { - resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req) + resp, err := ctx.generateContent(model, cfg, req) var apiErr genai.APIError if err != nil && try < 100 && errors.As(err, &apiErr) && apiErr.Code == http.StatusServiceUnavailable { @@ -298,6 +300,11 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi backoff = min(backoff+time.Second, 10*time.Second) continue } + if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests && + strings.Contains(apiErr.Message, "Quota exceeded for metric") && + strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") { + return resp, &modelQuotaError{model} + } return resp, err } } |
