aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/llm_agent.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/llm_agent.go')
-rw-r--r--pkg/aflow/llm_agent.go9
1 files changed, 8 insertions, 1 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 3c416b37c..d5e4d6d4d 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -9,6 +9,7 @@ import (
"maps"
"net/http"
"reflect"
+ "strings"
"time"
"github.com/google/syzkaller/pkg/aflow/trajectory"
@@ -289,8 +290,9 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig,
req []*genai.Content) (*genai.GenerateContentResponse, error) {
backoff := time.Second
+ model := ctx.modelName(a.Model)
for try := 0; ; try++ {
- resp, err := ctx.generateContent(ctx.modelName(a.Model), cfg, req)
+ resp, err := ctx.generateContent(model, cfg, req)
var apiErr genai.APIError
if err != nil && try < 100 && errors.As(err, &apiErr) &&
apiErr.Code == http.StatusServiceUnavailable {
@@ -298,6 +300,11 @@ func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfi
backoff = min(backoff+time.Second, 10*time.Second)
continue
}
+ if err != nil && errors.As(err, &apiErr) && apiErr.Code == http.StatusTooManyRequests &&
+ strings.Contains(apiErr.Message, "Quota exceeded for metric") &&
+ strings.Contains(apiErr.Message, "generate_requests_per_model_per_day") {
+ return resp, &modelQuotaError{model}
+ }
return resp, err
}
}