aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/llm_agent.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/llm_agent.go')
-rw-r--r--pkg/aflow/llm_agent.go21
1 files changed, 20 insertions, 1 deletions
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index c30143425..02d3bca85 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -4,9 +4,12 @@
package aflow
import (
+ "errors"
"fmt"
"maps"
+ "net/http"
"reflect"
+ "time"
"github.com/google/syzkaller/pkg/aflow/trajectory"
"google.golang.org/genai"
@@ -155,7 +158,7 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
if err := ctx.startSpan(reqSpan); err != nil {
return "", nil, err
}
- resp, err := ctx.generateContent(cfg, req)
+ resp, err := a.generateContent(ctx, cfg, req)
if err != nil {
return "", nil, ctx.finishSpan(reqSpan, err)
}
@@ -271,6 +274,22 @@ func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
return
}
+func (a *LLMAgent) generateContent(ctx *Context, cfg *genai.GenerateContentConfig,
+ req []*genai.Content) (*genai.GenerateContentResponse, error) {
+ backoff := time.Second
+ for try := 0; ; try++ {
+ resp, err := ctx.generateContent(cfg, req)
+ var apiErr genai.APIError
+ if err != nil && try < 100 && errors.As(err, &apiErr) &&
+ apiErr.Code == http.StatusServiceUnavailable {
+ time.Sleep(backoff)
+ backoff = min(backoff+time.Second, 10*time.Second)
+ continue
+ }
+ return resp, err
+ }
+}
+
func (a *LLMAgent) verify(vctx *verifyContext) {
vctx.requireNotEmpty(a.Name, "Name", a.Name)
vctx.requireNotEmpty(a.Name, "Reply", a.Reply)