aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/func_tool.go
blob: dde359485fe928a068e870cc88b099644899f211 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"errors"

	"github.com/google/syzkaller/pkg/aflow/trajectory"
	"google.golang.org/genai"
)

// NewFuncTool wraps fn into a Tool that an LLM agent can invoke.
// The name and description are forwarded to the LLM agent, so they should be
// chosen carefully. Args and Results must be struct types whose fields are
// annotated with the aflow tag and documented with comments; those comments
// are also forwarded to the LLM agent. Args values come from the LLM agent when
// it calls the tool, and Results values are sent back to it. State fields are
// filled from the current execution state and are never shown to the LLM agent.
func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, Args) (Results, error),
	description string) Tool {
	tool := new(funcTool[State, Args, Results])
	tool.Name = name
	tool.Description = description
	tool.Func = fn
	return tool
}

// BadCallError returns an error signalling that the LLM made an invalid tool
// call. Rather than failing the whole workflow, message is reported back to
// the LLM as the error.
func BadCallError(message string) error {
	cause := errors.New(message)
	return &badCallError{error: cause}
}

// badCallError tags an error as caused by a bad tool call from the LLM.
// The underlying error is embedded, so Error() yields the original message.
type badCallError struct {
	error
}

// funcTool is the Tool implementation returned by NewFuncTool. It pairs a
// user-supplied function with the name and description sent to the LLM agent.
type funcTool[State, Args, Results any] struct {
	// Name is the tool name passed to the LLM agent in the declaration.
	Name        string
	// Description is the tool description passed to the LLM agent.
	Description string
	// Func is called by execute with the decoded State and Args values.
	Func        func(*Context, State, Args) (Results, error)
}

// declaration builds the function declaration presented to the LLM: the tool
// name and description plus JSON schemas derived from the Args and Results types.
func (t *funcTool[State, Args, Results]) declaration() *genai.FunctionDeclaration {
	paramsSchema := mustSchemaFor[Args]()
	resultsSchema := mustSchemaFor[Results]()
	decl := &genai.FunctionDeclaration{
		Name:        t.Name,
		Description: t.Description,
	}
	decl.ParametersJsonSchema = paramsSchema
	decl.ResponseJsonSchema = resultsSchema
	return decl
}

// execute handles one tool invocation from the LLM agent.
// It decodes State from ctx.state and Args from the raw args map, then runs
// t.Func inside a trajectory span of type SpanTool. The converted results are
// stored on the span and returned together with the error produced by
// ctx.finishSpan. Decoding or span-start failures return a nil map and the error.
func (t *funcTool[State, Args, Results]) execute(ctx *Context, args map[string]any) (map[string]any, error) {
	st, err := convertFromMap[State](ctx.state, false, false)
	if err != nil {
		return nil, err
	}
	// Like State, Args are decoded in non-strict mode. The LLM should not send
	// extra args, but LLMs occasionally get things wrong in every possible way,
	// and it is better to tolerate that than to fail the whole workflow.
	// Replying with an error would cost another round-trip, which is probably
	// not worth it when all the real arguments were supplied.
	callArgs, err := convertFromMap[Args](args, false, true)
	if err != nil {
		return nil, err
	}
	span := &trajectory.Span{Type: trajectory.SpanTool, Name: t.Name, Args: args}
	if startErr := ctx.startSpan(span); startErr != nil {
		return nil, startErr
	}
	out, callErr := t.Func(ctx, st, callArgs)
	// NOTE(review): out is converted even when callErr != nil, so a failed call
	// records zero-valued results on the span; confirm finishSpan/callers expect this.
	span.Results = convertToMap(out)
	// finishSpan is called before span.Results is read for the return value,
	// in case it touches the span.
	finishErr := ctx.finishSpan(span, callErr)
	return span.Results, finishErr
}

// verify checks the tool definition through ctx. It requires a non-empty Name
// and Description, runs requireSchema on the Args and Results types, and runs
// requireInputs on State. Presumably requireInputs checks that the workflow
// provides State's fields — TODO confirm.
func (t *funcTool[State, Args, Results]) verify(ctx *verifyContext) {
	for _, field := range []struct{ what, value string }{
		{"Name", t.Name},
		{"Description", t.Description},
	} {
		ctx.requireNotEmpty(t.Name, field.what, field.value)
	}
	requireSchema[Args](ctx, t.Name, "Args")
	requireSchema[Results](ctx, t.Name, "Results")
	requireInputs[State](ctx, t.Name)
}