diff options
Diffstat (limited to 'pkg/aflow/flow_test.go')
| -rw-r--r-- | pkg/aflow/flow_test.go | 238 |
1 files changed, 237 insertions, 1 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 3ee3f545b..795b52d6e 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -916,7 +916,7 @@ func TestNoInputs(t *testing.T) {
 	onEvent := func(span *trajectory.Span) error { return nil }
 	_, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
 	require.Equal(t, err.Error(), "flow inputs are missing:"+
-		" field InBar is not present when converting map to aflow.flowInputs")
+		" field \"InBar\" is not present when converting map to aflow.flowInputs")
 }
 
 func TestQuotaResetTime(t *testing.T) {
@@ -940,3 +940,239 @@ func TestQuotaResetTime(t *testing.T) {
 		assert.Equal(t, test.reset, got, "when: %v", test.when)
 	}
 }
+
+func TestToolMisbehavior(t *testing.T) { // scripted LLM misuses tools and set-results; see the cases below
+	type flowOutputs struct {
+		Reply            string
+		AdditionalOutput int
+	}
+	type tool1Args struct {
+		Tool1Arg string `jsonschema:"arg"`
+	}
+	type tool2Args struct {
+		Tool2Arg int `jsonschema:"arg"`
+	}
+	type tool2Results struct {
+		Result int `jsonschema:"arg"`
+	}
+	flows := make(map[string]*Flow)
+	err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
+		{
+			Name: "flow",
+			Root: NewPipeline(
+				&LLMAgent{
+					Name:        "smarty",
+					Model:       "model1",
+					Temperature: 1,
+					Reply:       "Reply",
+
+					Outputs: LLMOutputs[struct {
+						AdditionalOutput int `jsonschema:"arg"`
+					}](),
+					Instruction: "Do something!",
+					Prompt:      "Data",
+					Tools: []Tool{
+						NewFuncTool("tool1", func(ctx *Context, state struct{}, args tool1Args) (struct{}, error) {
+							return struct{}{}, nil
+						}, "tool description"),
+						NewFuncTool("tool2", func(ctx *Context, state struct{}, args tool2Args) (tool2Results, error) {
+							return tool2Results{42}, nil
+						}, "tool description"),
+					},
+				},
+			),
+		},
+	})
+	require.NoError(t, err)
+	replySeq := 0 // number of generateContent calls made so far; selects the scripted reply
+	stub := &stubContext{
+		generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
+			*genai.GenerateContentResponse, error) {
+			replySeq++
+			switch replySeq {
+			case 1: // first call: reply with one valid and five malformed function calls
+				return &genai.GenerateContentResponse{
+					Candidates: []*genai.Candidate{{Content: &genai.Content{
+						Role: string(genai.RoleModel),
+						Parts: []*genai.Part{
+							// Valid call: case 2 expects an empty (error-free) response for it.
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id1",
+									Name: "tool1",
+									Args: map[string]any{
+										"Tool1Arg": "string",
+									},
+								},
+							},
+							// Wrong argument type (string where tool2Args declares int): expect a type error.
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id2",
+									Name: "tool2",
+									Args: map[string]any{
+										"Tool2Arg": "string-instead-of-int",
+									},
+								},
+							},
+							// Argument omitted entirely: expect a "missing argument" error.
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id3",
+									Name: "tool2",
+									Args: map[string]any{},
+								},
+							},
+							// Extra unknown argument: case 2 still expects tool2's normal result (42).
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id4",
+									Name: "tool2",
+									Args: map[string]any{
+										"Tool2Arg":  0,
+										"Tool2Arg2": 100,
+									},
+								},
+							},
+							// Tool name that was never registered: expect a "does not exist" error.
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id5",
+									Name: "tool3",
+									Args: map[string]any{
+										"Arg": 0,
+									},
+								},
+							},
+							// Wrong arg for set-results: must fail and must not count as a set-results call (see case 3).
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id6",
+									Name: "set-results",
+									Args: map[string]any{
+										"WrongArg": 0,
+									},
+								},
+							},
+						}}}}}, nil
+			case 2: // second call: check the per-call responses that were sent back to the model
+				require.Equal(t, len(req), 3) // NOTE(review): testify documents (expected, actual) order - confirm
+				assert.Equal(t, req[2], &genai.Content{
+					Role: string(genai.RoleUser),
+					Parts: []*genai.Part{
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:       "id1",
+								Name:     "tool1",
+								Response: map[string]any{},
+							},
+						},
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:   "id2",
+								Name: "tool2",
+								Response: map[string]any{
+									"error": "argument \"Tool2Arg\" has wrong type: got string, want int",
+								},
+							},
+						},
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:   "id3",
+								Name: "tool2",
+								Response: map[string]any{
+									"error": "missing argument \"Tool2Arg\"",
+								},
+							},
+						},
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:   "id4",
+								Name: "tool2",
+								Response: map[string]any{
+									"Result": 42,
+								},
+							},
+						},
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:   "id5",
+								Name: "tool3",
+								Response: map[string]any{
+									"error": "tool \"tool3\" does not exist, please correct the name",
+								},
+							},
+						},
+						{
+							FunctionResponse: &genai.FunctionResponse{
+								ID:   "id6",
+								Name: "set-results",
+								Response: map[string]any{
+									"error": "missing argument \"AdditionalOutput\"",
+								},
+							},
+						},
+					}})
+				// The model now answers with plain text although no set-results call has succeeded yet.
+				return &genai.GenerateContentResponse{
+					Candidates: []*genai.Candidate{
+						{Content: &genai.Content{
+							Role: string(genai.RoleUser), // NOTE(review): model reply tagged RoleUser, unlike cases 1 and 3 - confirm
+							Parts: []*genai.Part{
+								genai.NewPartFromText("I am done")},
+						}}}}, nil
+			case 3:
+				// The last request entry must be the llmMissingOutputs reminder, since set-results never succeeded.
+				require.Equal(t, len(req), 5)
+				assert.Equal(t, req[4], genai.NewContentFromText(llmMissingOutputs, genai.RoleUser))
+				// Call set-results twice in one reply; the final assertion expects the second value (2).
+				return &genai.GenerateContentResponse{
+					Candidates: []*genai.Candidate{{Content: &genai.Content{
+						Role: string(genai.RoleModel),
+						Parts: []*genai.Part{
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id1",
+									Name: "set-results",
+									Args: map[string]any{
+										"AdditionalOutput": 1,
+									},
+								},
+							},
+							{
+								FunctionCall: &genai.FunctionCall{
+									ID:   "id2",
+									Name: "set-results",
+									Args: map[string]any{
+										"AdditionalOutput": 2,
+									},
+								},
+							},
+						}}}}}, nil
+			case 4: // fourth call: plain-text reply that the final assertion expects as "Reply"
+				return &genai.GenerateContentResponse{
+					Candidates: []*genai.Candidate{
+						{Content: &genai.Content{
+							Role: string(genai.RoleUser),
+							Parts: []*genai.Part{
+								genai.NewPartFromText("Finally done")},
+						}}}}, nil
+			default:
+				t.Fatal("unexpected LLM calls")
+				return nil, nil
+			}
+		},
+	}
+	ctx := context.WithValue(context.Background(), stubContextKey, stub)
+	workdir := t.TempDir()
+	cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+	require.NoError(t, err)
+	onEvent := func(span *trajectory.Span) error { return nil }
+	res, err := flows["test-flow"].Execute(ctx, "", workdir, map[string]any{}, cache, onEvent)
+	require.NoError(t, err)
+	require.Equal(t, replySeq, 4) // all four scripted replies must have been consumed
+	require.Equal(t, res, map[string]any{
+		"Reply":            "Finally done",
+		"AdditionalOutput": 2,
+	})
+}
|
