aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/flow_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/flow_test.go')
-rw-r--r--pkg/aflow/flow_test.go238
1 files changed, 237 insertions, 1 deletions
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
index 3ee3f545b..795b52d6e 100644
--- a/pkg/aflow/flow_test.go
+++ b/pkg/aflow/flow_test.go
@@ -916,7 +916,7 @@ func TestNoInputs(t *testing.T) {
onEvent := func(span *trajectory.Span) error { return nil }
_, err = flows["test"].Execute(ctx, "", workdir, inputs, cache, onEvent)
require.Equal(t, err.Error(), "flow inputs are missing:"+
- " field InBar is not present when converting map to aflow.flowInputs")
+ " field \"InBar\" is not present when converting map to aflow.flowInputs")
}
func TestQuotaResetTime(t *testing.T) {
@@ -940,3 +940,239 @@ func TestQuotaResetTime(t *testing.T) {
assert.Equal(t, test.reset, got, "when: %v", test.when)
}
}
+
+func TestToolMisbehavior(t *testing.T) {
+ type flowOutputs struct {
+ Reply string
+ AdditionalOutput int
+ }
+ type tool1Args struct {
+ Tool1Arg string `jsonschema:"arg"`
+ }
+ type tool2Args struct {
+ Tool2Arg int `jsonschema:"arg"`
+ }
+ type tool2Results struct {
+ Result int `jsonschema:"arg"`
+ }
+ flows := make(map[string]*Flow)
+ err := register[struct{}, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Name: "flow",
+ Root: NewPipeline(
+ &LLMAgent{
+ Name: "smarty",
+ Model: "model1",
+ Temperature: 1,
+ Reply: "Reply",
+
+ Outputs: LLMOutputs[struct {
+ AdditionalOutput int `jsonschema:"arg"`
+ }](),
+ Instruction: "Do something!",
+ Prompt: "Data",
+ Tools: []Tool{
+ NewFuncTool("tool1", func(ctx *Context, state struct{}, args tool1Args) (struct{}, error) {
+ return struct{}{}, nil
+ }, "tool description"),
+ NewFuncTool("tool2", func(ctx *Context, state struct{}, args tool2Args) (tool2Results, error) {
+ return tool2Results{42}, nil
+ }, "tool description"),
+ },
+ },
+ ),
+ },
+ })
+ require.NoError(t, err)
+ replySeq := 0
+ stub := &stubContext{
+ generateContent: func(model string, cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ replySeq++
+ switch replySeq {
+ case 1:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ // This tool call is OK, and the tool must be called.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id1",
+ Name: "tool1",
+ Args: map[string]any{
+ "Tool1Arg": "string",
+ },
+ },
+ },
+ // Incorrect argument type.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id2",
+ Name: "tool2",
+ Args: map[string]any{
+ "Tool2Arg": "string-instead-of-int",
+ },
+ },
+ },
+ // Missing argument.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id3",
+ Name: "tool2",
+ Args: map[string]any{},
+ },
+ },
+ // Excessive argument.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id4",
+ Name: "tool2",
+ Args: map[string]any{
+ "Tool2Arg": 0,
+ "Tool2Arg2": 100,
+ },
+ },
+ },
+ // Tool that does not exist.
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id5",
+ Name: "tool3",
+ Args: map[string]any{
+ "Arg": 0,
+ },
+ },
+ },
+ // Wrong arg for set-results (should not count as it was called).
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id6",
+ Name: "set-results",
+ Args: map[string]any{
+ "WrongArg": 0,
+ },
+ },
+ },
+ }}}}}, nil
+ case 2:
+ require.Equal(t, len(req), 3)
+ assert.Equal(t, req[2], &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id1",
+ Name: "tool1",
+ Response: map[string]any{},
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id2",
+ Name: "tool2",
+ Response: map[string]any{
+ "error": "argument \"Tool2Arg\" has wrong type: got string, want int",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id3",
+ Name: "tool2",
+ Response: map[string]any{
+ "error": "missing argument \"Tool2Arg\"",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id4",
+ Name: "tool2",
+ Response: map[string]any{
+ "Result": 42,
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id5",
+ Name: "tool3",
+ Response: map[string]any{
+ "error": "tool \"tool3\" does not exist, please correct the name",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id6",
+ Name: "set-results",
+ Response: map[string]any{
+ "error": "missing argument \"AdditionalOutput\"",
+ },
+ },
+ },
+ }})
+ // Now it tries to provide the final result w/o calling set-results (successfully).
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("I am done")},
+ }}}}, nil
+ case 3:
+ // Reply that set-results wasn't called.
+ require.Equal(t, len(req), 5)
+ assert.Equal(t, req[4], genai.NewContentFromText(llmMissingOutputs, genai.RoleUser))
+ // Now call it twice.
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id1",
+ Name: "set-results",
+ Args: map[string]any{
+ "AdditionalOutput": 1,
+ },
+ },
+ },
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id2",
+ Name: "set-results",
+ Args: map[string]any{
+ "AdditionalOutput": 2,
+ },
+ },
+ },
+ }}}}}, nil
+ case 4:
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("Finally done")},
+ }}}}, nil
+ default:
+ t.Fatal("unexpected LLM calls")
+ return nil, nil
+ }
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ onEvent := func(span *trajectory.Span) error { return nil }
+ res, err := flows["test-flow"].Execute(ctx, "", workdir, map[string]any{}, cache, onEvent)
+ require.NoError(t, err)
+ require.Equal(t, replySeq, 4)
+ require.Equal(t, res, map[string]any{
+ "Reply": "Finally done",
+ "AdditionalOutput": 2,
+ })
+}