aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow/schema.go
blob: 2b2d77f766ee29c5bb34eb83ee2ca9bdd1b4c2d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"encoding/json"
	"fmt"
	"iter"
	"maps"
	"reflect"

	"github.com/google/jsonschema-go/jsonschema"
)

// schemaFor builds a resolved JSON schema for the type T.
// T must be a struct type, and every visible field of T must carry a non-empty
// `jsonschema` struct tag; otherwise an error is returned before the schema is generated.
// Errors from schema generation and from schema resolution are returned unchanged.
func schemaFor[T any]() (*jsonschema.Schema, error) {
	structType := reflect.TypeFor[T]()
	if kind := structType.Kind(); kind != reflect.Struct {
		return nil, fmt.Errorf("%v is not a struct", structType.Name())
	}
	// Insist on a description for every field before generating anything.
	for _, f := range reflect.VisibleFields(structType) {
		if desc := f.Tag.Get("jsonschema"); desc != "" {
			continue
		}
		return nil, fmt.Errorf("%v.%v does not have a jsonschema tag with description",
			structType.Name(), f.Name)
	}
	generated, err := jsonschema.For[T](nil)
	if err != nil {
		return nil, err
	}
	res, err := generated.Resolve(nil)
	if err != nil {
		return nil, err
	}
	return res.Schema(), nil
}

// mustSchemaFor is like schemaFor, but panics with the returned error
// instead of handing it back to the caller.
func mustSchemaFor[T any]() *jsonschema.Schema {
	result, err := schemaFor[T]()
	if err == nil {
		return result
	}
	panic(err)
}

// convertToMap converts a struct value to an untyped map keyed by field name.
// The fields are enumerated with foreachField on a pointer to the local copy of val.
func convertToMap[T any](val T) map[string]any {
	fields := map[string]any{}
	for key, fieldVal := range foreachField(&val) {
		fields[key] = fieldVal.Interface()
	}
	return fields
}

// convertFromMap converts an untyped map to a struct of type T.
// It always ensures that all struct fields (as enumerated by foreachField) are present in the map.
// A map element is accepted only if its dynamic type is identical to the field type,
// with one exception: a map[string]any element is marshaled to JSON if the field is a json.RawMessage.
// In the strict mode it also checks that the map does not contain any other unused elements.
// If tool is set, missing/mistyped argument errors are returned as *toolArgsError,
// in the form suitable to return back to LLM during tool arguments conversion.
// The strict mode error is returned as a plain error regardless of tool.
// The input map is not modified.
func convertFromMap[T any](m map[string]any, strict, tool bool) (T, error) {
	// Consumed elements are deleted from the map to detect unused ones below,
	// so work on a copy to leave the caller's map intact.
	m = maps.Clone(m)
	var val T
	rawMessageType := reflect.TypeFor[json.RawMessage]()
	for name, field := range foreachField(&val) {
		f, ok := m[name]
		if !ok {
			if tool {
				return val, &toolArgsError{fmt.Errorf("missing argument %q", name)}
			}
			return val, fmt.Errorf("field %q is not present when converting map to %T", name, val)
		}
		delete(m, name)
		fieldType := field.Type()
		if mm, ok := f.(map[string]any); ok && fieldType == rawMessageType {
			raw, err := json.Marshal(mm)
			if err != nil {
				return val, err
			}
			field.Set(reflect.ValueOf(json.RawMessage(raw)))
			continue
		}
		if fieldType == reflect.TypeOf(f) {
			field.Set(reflect.ValueOf(f))
			continue
		}
		// Type.Name is empty for unnamed types (slices, maps, pointers, etc),
		// fall back to the full type string so that the error always says what is wanted.
		want := fieldType.Name()
		if want == "" {
			want = fieldType.String()
		}
		if tool {
			return val, &toolArgsError{fmt.Errorf("argument %q has wrong type: got %T, want %v",
				name, f, want)}
		}
		return val, fmt.Errorf("field %q has wrong type: got %T, want %v",
			name, f, want)
	}
	if strict && len(m) != 0 {
		return val, fmt.Errorf("unused fields when converting map to %T: %v", val, m)
	}
	return val, nil
}

// toolArgsError wraps an error about tool arguments (a missing or mistyped argument).
// convertFromMap returns it when called with tool set.
type toolArgsError struct {
	error
}

// foreachField iterates over all public fields of the struct provided in data.
func foreachField(data any) iter.Seq2[string, reflect.Value] {
	return func(yield func(string, reflect.Value) bool) {
		v := reflect.ValueOf(data).Elem()
		for _, field := range reflect.VisibleFields(v.Type()) {
			if !yield(field.Name, v.FieldByIndex(field.Index)) {
				break
			}
		}
	}
}

// foreachFieldOf iterates over the names and types of the fields of the struct type T.
// The fields are enumerated with foreachField on a freshly allocated zero value of T.
func foreachFieldOf[T any]() iter.Seq2[string, reflect.Type] {
	return func(yield func(string, reflect.Type) bool) {
		zero := new(T)
		for fieldName, fieldVal := range foreachField(zero) {
			if more := yield(fieldName, fieldVal.Type()); !more {
				return
			}
		}
	}
}