aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/flow_test.go6
-rw-r--r--pkg/aflow/llm_agent.go7
2 files changed, 10 insertions, 3 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 795b52d6e..00b2dd8b4 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -184,7 +184,7 @@ func TestWorkflow(t *testing.T) {
if replySeq < 4 {
assert.Equal(t, model, "model1")
assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText("You are smarty. baz"+
- llmOutputsInstruction, genai.RoleUser))
+ llmMultipleToolsInstruction+llmOutputsInstruction, genai.RoleUser))
assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
assert.Equal(t, len(cfg.Tools), 3)
assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1")
@@ -444,7 +444,7 @@ func TestWorkflow(t *testing.T) {
Name: "smarty",
Model: "model1",
Started: startTime.Add(4 * time.Second),
- Instruction: "You are smarty. baz" + llmOutputsInstruction,
+ Instruction: "You are smarty. baz" + llmMultipleToolsInstruction + llmOutputsInstruction,
Prompt: "Prompt: baz func-output",
},
{
@@ -586,7 +586,7 @@ func TestWorkflow(t *testing.T) {
Model: "model1",
Started: startTime.Add(4 * time.Second),
Finished: startTime.Add(17 * time.Second),
- Instruction: "You are smarty. baz" + llmOutputsInstruction,
+ Instruction: "You are smarty. baz" + llmMultipleToolsInstruction + llmOutputsInstruction,
Prompt: "Prompt: baz func-output",
Reply: "hello, world!",
Results: map[string]any{
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
index 5934bf9bd..e5b753607 100644
--- a/pkg/aflow/llm_agent.go
+++ b/pkg/aflow/llm_agent.go
@@ -93,6 +93,10 @@ It must be called exactly once before the final reply.
Ignore results of this tool.
`
+const llmMultipleToolsInstruction = `
+Prefer calling several tools at the same time to save round-trips.
+`
+
const llmMissingOutputs = `You did not call set-results tool.
Please call set-results tool to provide results of the analysis.
`
@@ -210,6 +214,9 @@ func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools ma
func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, map[string]Tool) {
instruction := formatTemplate(a.Instruction, ctx.state)
toolList := a.Tools
+ if len(toolList) != 0 {
+ instruction += llmMultipleToolsInstruction
+ }
if a.Outputs != nil {
instruction += llmOutputsInstruction
toolList = append(toolList, a.Outputs.tool)