aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/schema.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/aflow/schema.go')
-rw-r--r--pkg/aflow/schema.go22
1 files changed, 19 insertions, 3 deletions
diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go
index e34d465ea..2b2d77f76 100644
--- a/pkg/aflow/schema.go
+++ b/pkg/aflow/schema.go
@@ -54,13 +54,19 @@ func convertToMap[T any](val T) map[string]any {
// convertFromMap converts an untyped map to a struct.
// It always ensures that all struct fields are present in the map.
// In the strict mode it also checks that the map does not contain any other unused elements.
-func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
+// If tool is set, return errors in the form suitable to return back to LLM
+// during tool arguments conversion.
+func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
m = maps.Clone(m)
var val T
for name, field := range foreachField(&val) {
f, ok := m[name]
if !ok {
- return val, fmt.Errorf("field %v is not present when converting map to %T", name, val)
+ if tool {
+ return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
+ } else {
+ return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
+ }
}
delete(m, name)
if mm, ok := f.(map[string]any); ok && field.Type() == reflect.TypeFor[json.RawMessage]() {
@@ -69,8 +75,16 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
return val, err
}
field.Set(reflect.ValueOf(json.RawMessage(raw)))
- } else {
+ } else if field.Type() == reflect.TypeOf(f) {
field.Set(reflect.ValueOf(f))
+ } else {
+ if tool {
+ return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name())}
+ } else {
+ return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
+ name, f, field.Type().Name())
+ }
}
}
if strict && len(m) != 0 {
@@ -79,6 +93,8 @@ func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
return val, nil
}
+type toolArgsError struct{ error }
+
// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
return func(yield func(string, reflect.Value) bool) {