1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
|
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package aflow
import (
"encoding/json"
"fmt"
"iter"
"maps"
"reflect"
"github.com/google/jsonschema-go/jsonschema"
)
// schemaFor infers a JSON schema for the struct type T and returns it in resolved form.
// Every exported field of T (including fields promoted from embedded structs) must carry
// a `jsonschema` tag with its description. An error is returned if a tag is missing,
// if T is not a struct, or if schema inference or resolution fails.
func schemaFor[T any]() (*jsonschema.Schema, error) {
	typ := reflect.TypeFor[T]()
	if typ.Kind() != reflect.Struct {
		// Format the type itself: typ.Name() is empty for unnamed types such as []int.
		return nil, fmt.Errorf("%v is not a struct", typ)
	}
	for _, field := range reflect.VisibleFields(typ) {
		// Unexported fields are not part of the JSON representation (encoding/json
		// ignores them), so they do not need a description.
		if !field.IsExported() {
			continue
		}
		if field.Tag.Get("jsonschema") == "" {
			return nil, fmt.Errorf("%v.%v does not have a jsonschema tag with description",
				typ, field.Name)
		}
	}
	schema, err := jsonschema.For[T](nil)
	if err != nil {
		return nil, fmt.Errorf("inferring schema for %v: %w", typ, err)
	}
	resolved, err := schema.Resolve(nil)
	if err != nil {
		return nil, fmt.Errorf("resolving schema for %v: %w", typ, err)
	}
	return resolved.Schema(), nil
}
// mustSchemaFor returns the schema produced by schemaFor for T,
// panicking with schemaFor's error if it fails.
func mustSchemaFor[T any]() *jsonschema.Schema {
	res, err := schemaFor[T]()
	if err == nil {
		return res
	}
	panic(err)
}
// convertToMap converts a struct value into an untyped map that has one entry
// per field visited by foreachField, keyed by the field's Go name.
func convertToMap[T any](val T) map[string]any {
	out := map[string]any{}
	for fieldName, fieldVal := range foreachField(&val) {
		out[fieldName] = fieldVal.Interface()
	}
	return out
}
// convertFromMap converts an untyped map to a struct of type T.
// Each field visited by foreachField must be present in m under the field's
// Go name, and its value must be assignable to the field's type. As a special
// case, a map[string]any value is accepted for a json.RawMessage field and is
// stored as its JSON encoding.
// In the strict mode it also checks that the map does not contain any other unused elements.
// If tool is set, missing or mistyped fields are reported as BadCallError,
// a form suitable to return back to LLM during tool arguments conversion.
// The input map is not modified. On error the zero value of T is returned.
func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
	var zero T
	rawMessageType := reflect.TypeFor[json.RawMessage]()
	// Consumed entries are deleted from this copy, so whatever remains afterwards
	// is exactly the set of unused elements for the strict check.
	m = maps.Clone(m)
	var val T
	for name, field := range foreachField(&val) {
		f, ok := m[name]
		if !ok {
			if tool {
				return zero, BadCallError(fmt.Sprintf("missing argument %q", name))
			}
			return zero, fmt.Errorf("field %q is not present when converting map to %T", name, val)
		}
		delete(m, name)
		if mm, isMap := f.(map[string]any); isMap && field.Type() == rawMessageType {
			raw, err := json.Marshal(mm)
			if err != nil {
				return zero, fmt.Errorf("marshaling field %q to JSON: %w", name, err)
			}
			field.Set(reflect.ValueOf(json.RawMessage(raw)))
			continue
		}
		// Accept exactly what reflect.Value.Set accepts (identical types, any value
		// for an interface-typed field, etc.). A nil f yields an invalid Value and
		// is rejected below.
		if src := reflect.ValueOf(f); src.IsValid() && src.Type().AssignableTo(field.Type()) {
			field.Set(src)
			continue
		}
		// field.Type() is formatted with %v rather than .Name() so that unnamed
		// types such as []string or map[string]any are reported correctly.
		if tool {
			return zero, BadCallError(fmt.Sprintf("argument %q has wrong type: got %T, want %v",
				name, f, field.Type()))
		}
		return zero, fmt.Errorf("field %q has wrong type: got %T, want %v",
			name, f, field.Type())
	}
	if strict && len(m) != 0 {
		return zero, fmt.Errorf("unused fields when converting map to %T: %v", val, m)
	}
	return val, nil
}
// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
return func(yield func(string, reflect.Value) bool) {
v := reflect.ValueOf(data).Elem()
for _, field := range reflect.VisibleFields(v.Type()) {
if !yield(field.Name, v.FieldByIndex(field.Index)) {
break
}
}
}
}
// foreachFieldOf iterates over the fields that foreachField reports for a
// freshly allocated zero value of T, yielding each field's name and static type.
func foreachFieldOf[T any]() iter.Seq2[string, reflect.Type] {
	return func(yield func(string, reflect.Type) bool) {
		for fieldName, fieldVal := range foreachField(new(T)) {
			if cont := yield(fieldName, fieldVal.Type()); !cont {
				return
			}
		}
	}
}
|