diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2026-01-02 17:03:40 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2026-01-09 12:51:45 +0000 |
| commit | 45d8f079628d0d9c0214c07e1abe9e8cb26057d6 (patch) | |
| tree | c7b6e95f040cbbf1322de719360bfe573740272c /pkg | |
| parent | ce25ef79a77633ecbd0042eb35c9432dd582d448 (diff) | |
pkg/aflow: add package for agentic workflows
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/aflow/action.go | 37 | ||||
| -rw-r--r-- | pkg/aflow/cache.go | 201 | ||||
| -rw-r--r-- | pkg/aflow/cache_test.go | 144 | ||||
| -rw-r--r-- | pkg/aflow/execute.go | 186 | ||||
| -rw-r--r-- | pkg/aflow/flow.go | 99 | ||||
| -rw-r--r-- | pkg/aflow/flow_test.go | 554 | ||||
| -rw-r--r-- | pkg/aflow/func_action.go | 46 | ||||
| -rw-r--r-- | pkg/aflow/func_tool.go | 71 | ||||
| -rw-r--r-- | pkg/aflow/llm_agent.go | 254 | ||||
| -rw-r--r-- | pkg/aflow/schema.go | 102 | ||||
| -rw-r--r-- | pkg/aflow/schema_test.go | 50 | ||||
| -rw-r--r-- | pkg/aflow/template.go | 100 | ||||
| -rw-r--r-- | pkg/aflow/template_test.go | 68 | ||||
| -rw-r--r-- | pkg/aflow/verify.go | 98 |
14 files changed, 2010 insertions, 0 deletions
diff --git a/pkg/aflow/action.go b/pkg/aflow/action.go new file mode 100644 index 000000000..cd09466a4 --- /dev/null +++ b/pkg/aflow/action.go @@ -0,0 +1,37 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +type Action interface { + verify(*verifyContext) + execute(*Context) error +} + +type Pipeline struct { + // These actions are invoked sequentially, + // but dataflow across actions is specified by their use + // of variables in args/instructions/prompts. + Actions []Action +} + +func NewPipeline(actions ...Action) *Pipeline { + return &Pipeline{ + Actions: actions, + } +} + +func (p *Pipeline) execute(ctx *Context) error { + for _, sub := range p.Actions { + if err := sub.execute(ctx); err != nil { + return err + } + } + return nil +} + +func (p *Pipeline) verify(ctx *verifyContext) { + for _, a := range p.Actions { + a.verify(ctx) + } +} diff --git a/pkg/aflow/cache.go b/pkg/aflow/cache.go new file mode 100644 index 000000000..fe60e5358 --- /dev/null +++ b/pkg/aflow/cache.go @@ -0,0 +1,201 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +import ( + "fmt" + "maps" + "os" + "path/filepath" + "slices" + "sync" + "testing" + "time" + + "github.com/google/syzkaller/pkg/hash" + "github.com/google/syzkaller/pkg/osutil" +) + +// Cache maintains on-disk cache with directories with arbitrary contents (kernel checkouts, builds, etc). +// Create method is used to either create a new directory, if it's not cached yet, or returns a previously +// cached directory. Old unused directories are incrementally removed if the total disk usage grows +// over the specified limit. 
+type Cache struct { + dir string + maxSize uint64 + timeNow func() time.Time + t *testing.T + mu sync.Mutex + currentSize uint64 + entries map[string]*cacheEntry +} + +type cacheEntry struct { + dir string + size uint64 + usageCount int + lastUsed time.Time +} + +func NewCache(dir string, maxSize uint64) (*Cache, error) { + return newTestCache(nil, dir, maxSize, time.Now) +} + +func newTestCache(t *testing.T, dir string, maxSize uint64, timeNow func() time.Time) (*Cache, error) { + if dir == "" { + return nil, fmt.Errorf("cache workdir is empty") + } + c := &Cache{ + dir: osutil.Abs(dir), + maxSize: maxSize, + timeNow: timeNow, + t: t, + entries: make(map[string]*cacheEntry), + } + if err := c.init(); err != nil { + return nil, err + } + return c, nil +} + +// Create creates/returns a cached directory with contents created by the populate callback. +// The populate callback receives a dir it needs to populate with cached files. +// The typ must be a short descriptive name of the contents (e.g. "build", "source", etc). +// The desc is used to identify cached entries and must fully describe the cached contents +// (the second invocation with the same typ+desc will return dir created by the first +// invocation with the same typ+desc). +func (c *Cache) Create(typ, desc string, populate func(string) error) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + // Note: we don't populate a temp dir and then atomically rename it to the final destination, + // because at least kernel builds encode the current path in debug info/compile commands, + // so moving the dir later would break all that. Instead we rely on the presence of the meta file + // to denote valid cache entries. Modification time of the file says when it was last used. 
+ id := hash.String(desc) + dir := filepath.Join(c.dir, typ, id) + metaFile := filepath.Join(dir, cacheMetaFile) + if c.entries[dir] == nil { + os.RemoveAll(dir) + if err := osutil.MkdirAll(dir); err != nil { + return "", err + } + if err := populate(dir); err != nil { + os.RemoveAll(dir) + return "", err + } + size, err := osutil.DiskUsage(dir) + if err != nil { + return "", err + } + if err := osutil.WriteFile(metaFile, []byte(desc)); err != nil { + os.RemoveAll(dir) + return "", err + } + c.entries[dir] = &cacheEntry{ + dir: dir, + size: size, + } + c.currentSize += size + c.logf("created entry %v, size %v, current size %v", dir, size, c.currentSize) + } + // Note the entry was used now. + now := c.timeNow() + if err := os.Chtimes(metaFile, now, now); err != nil { + return "", err + } + entry := c.entries[dir] + entry.usageCount++ + entry.lastUsed = now + c.logf("using entry %v, usage count %v", dir, entry.usageCount) + if err := c.purge(); err != nil { + entry.usageCount-- + return "", err + } + return dir, nil +} + +// Release must be called for every directory returned by Create method when the directory is not used anymore. +func (c *Cache) Release(dir string) { + c.mu.Lock() + defer c.mu.Unlock() + entry := c.entries[dir] + entry.usageCount-- + c.logf("release entry %v, usage count %v", dir, entry.usageCount) + if entry.usageCount < 0 { + panic("negative usageCount") + } +} + +// init reads the cached dirs (disk usage, last use time) from disk when the cache is created. 
+func (c *Cache) init() error { + dirs, err := filepath.Glob(filepath.Join(c.dir, "*", "*")) + if err != nil { + return err + } + for _, dir := range dirs { + metaFile := filepath.Join(dir, cacheMetaFile) + if !osutil.IsExist(metaFile) { + if err := osutil.RemoveAll(dir); err != nil { + return err + } + continue + } + stat, err := os.Stat(metaFile) + if err != nil { + return err + } + size, err := osutil.DiskUsage(dir) + if err != nil { + return err + } + c.entries[dir] = &cacheEntry{ + dir: dir, + size: size, + lastUsed: stat.ModTime(), + } + c.currentSize += size + } + c.mu.Lock() + defer c.mu.Unlock() + return c.purge() +} + +// purge removes oldest unused directories if the cache is over maxSize. +func (c *Cache) purge() error { + if c.mu.TryLock() { + panic("c.mu is not locked") + } + if c.currentSize < c.maxSize { + return nil + } + list := slices.Collect(maps.Values(c.entries)) + slices.SortFunc(list, func(a, b *cacheEntry) int { + if a.usageCount != b.usageCount { + return a.usageCount - b.usageCount + } + return a.lastUsed.Compare(b.lastUsed) + }) + for _, entry := range list { + if entry.usageCount != 0 || c.currentSize < c.maxSize { + break + } + if err := os.RemoveAll(entry.dir); err != nil { + return err + } + delete(c.entries, entry.dir) + if c.currentSize < entry.size { + panic(fmt.Sprintf("negative currentSize: %v %v", c.currentSize, entry.size)) + } + c.currentSize -= entry.size + } + return nil +} + +func (c *Cache) logf(msg string, args ...any) { + if c.t != nil { + c.t.Logf("cache: "+msg, args...) + } +} + +const cacheMetaFile = "aflow-meta" diff --git a/pkg/aflow/cache_test.go b/pkg/aflow/cache_test.go new file mode 100644 index 000000000..244defdd3 --- /dev/null +++ b/pkg/aflow/cache_test.go @@ -0,0 +1,144 @@ +// Copyright 2026 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package aflow + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/google/syzkaller/pkg/osutil" + "github.com/stretchr/testify/require" +) + +func TestCache(t *testing.T) { + var mockedTime time.Time + timeNow := func() time.Time { + return mockedTime + } + tempDir := t.TempDir() + c, err := newTestCache(t, tempDir, 1<<40, timeNow) + require.NoError(t, err) + dir1, err := c.Create("foo", "1", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "bar"), []byte("abc")) + }) + require.NoError(t, err) + data, err := os.ReadFile(filepath.Join(dir1, "bar")) + require.NoError(t, err) + require.Equal(t, data, []byte("abc")) + c.Release(dir1) + + dir2, err := c.Create("foo", "1", func(dir string) error { + t.Fatal("must not be called") + return nil + }) + require.NoError(t, err) + require.Equal(t, dir2, dir1) + data, err = os.ReadFile(filepath.Join(dir2, "bar")) + require.NoError(t, err) + require.Equal(t, data, []byte("abc")) + c.Release(dir2) + + dir3, err := c.Create("foo", "2", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "baz"), []byte("def")) + }) + require.NoError(t, err) + require.NotEqual(t, dir3, dir1) + data, err = os.ReadFile(filepath.Join(dir3, "baz")) + require.NoError(t, err) + require.Equal(t, data, []byte("def")) + c.Release(dir3) + + failedDir := "" + dir4, err := c.Create("foo", "3", func(dir string) error { + failedDir = dir + return fmt.Errorf("failed") + }) + require.Error(t, err) + require.Empty(t, dir4) + require.False(t, osutil.IsExist(failedDir)) + + // Create a new cache, it should pick up the state from disk. 
+ c, err = newTestCache(t, tempDir, 1<<40, timeNow) + require.NoError(t, err) + + dir5, err := c.Create("foo", "1", func(dir string) error { + t.Fatal("must not be called") + return nil + }) + require.NoError(t, err) + require.Equal(t, dir5, dir1) + data, err = os.ReadFile(filepath.Join(dir5, "bar")) + require.NoError(t, err) + require.Equal(t, data, []byte("abc")) + c.Release(dir5) + + // Model an incomplete dir without metadata, it should be removed. + strayDir := filepath.Join(tempDir, "a", "b") + require.NoError(t, osutil.MkdirAll(strayDir)) + require.NoError(t, osutil.WriteFile(filepath.Join(strayDir, "foo"), []byte("foo"))) + + // With 0 max size everything unused should be purged. + _, err = newTestCache(t, tempDir, 0, timeNow) + require.NoError(t, err) + require.False(t, osutil.IsExist(dir1)) + require.False(t, osutil.IsExist(dir3)) + require.False(t, osutil.IsExist(strayDir)) + + // Test incremental purging of files. + c, err = newTestCache(t, tempDir, 100<<10, timeNow) + require.NoError(t, err) + + mockedTime = mockedTime.Add(time.Minute) + dir6, err := c.Create("foo", "1", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 5<<10)) + }) + require.NoError(t, err) + c.Release(dir6) + + mockedTime = mockedTime.Add(time.Minute) + dir7, err := c.Create("foo", "2", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 5<<10)) + }) + require.NoError(t, err) + c.Release(dir7) + + mockedTime = mockedTime.Add(time.Minute) + dir8, err := c.Create("foo", "3", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 60<<10)) + }) + require.NoError(t, err) + c.Release(dir8) + + // Force update of the last access time for the first dir. 
+ mockedTime = mockedTime.Add(time.Minute) + dir9, err := c.Create("foo", "1", func(dir string) error { + t.Fatal("must not be called") + return nil + }) + require.NoError(t, err) + require.Equal(t, dir6, dir9) + c.Release(dir9) + + // Both dirs should exist since they should fit into cache size. + require.True(t, osutil.IsExist(dir6)) + require.True(t, osutil.IsExist(dir7)) + require.True(t, osutil.IsExist(dir8)) + + mockedTime = mockedTime.Add(time.Minute) + dir10, err := c.Create("foo", "4", func(dir string) error { + return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 60<<10)) + }) + require.NoError(t, err) + c.Release(dir10) + + // Two oldest dirs should be purged. + require.True(t, osutil.IsExist(dir6)) + require.False(t, osutil.IsExist(dir7)) + require.False(t, osutil.IsExist(dir8)) + require.True(t, osutil.IsExist(dir10)) +} diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go new file mode 100644 index 000000000..6e724988e --- /dev/null +++ b/pkg/aflow/execute.go @@ -0,0 +1,186 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package aflow + +import ( + "context" + "fmt" + "maps" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/google/syzkaller/pkg/osutil" + "google.golang.org/genai" +) + +// https://ai.google.dev/gemini-api/docs/models +const DefaultModel = "gemini-3-pro-preview" + +func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any, + cache *Cache, onEvent onEvent) (map[string]any, error) { + if err := flow.checkInputs(inputs); err != nil { + return nil, fmt.Errorf("flow inputs are missing: %w", err) + } + ctx := &Context{ + Context: c, + Workdir: osutil.Abs(workdir), + cache: cache, + state: maps.Clone(inputs), + onEvent: onEvent, + } + defer ctx.close() + if s := c.Value(stubContextKey); s != nil { + ctx.stubContext = *s.(*stubContext) + } + if ctx.timeNow == nil { + ctx.timeNow = time.Now + } + if ctx.generateContent == nil { + var err error + ctx.generateContent, err = contentGenerator(c, model) + if err != nil { + return nil, err + } + } + span := &trajectory.Span{ + Type: trajectory.SpanFlow, + Name: flow.Name, + } + if err := ctx.startSpan(span); err != nil { + return nil, err + } + flowErr := flow.Root.execute(ctx) + if flowErr == nil { + span.Results = flow.extractOutputs(ctx.state) + } + if err := ctx.finishSpan(span, flowErr); err != nil { + return nil, err + } + if ctx.spanNesting != 0 { + // Since we finish all spans, even on errors, we should end up at 0. 
+ panic(fmt.Sprintf("unbalanced spans (%v)", ctx.spanNesting)) + } + return span.Results, nil +} + +type ( + onEvent func(*trajectory.Span) error + generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) ( + *genai.GenerateContentResponse, error) + contextKeyType int +) + +var ( + createClientOnce sync.Once + createClientErr error + client *genai.Client + modelList = make(map[string]bool) + stubContextKey = contextKeyType(1) +) + +func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) { + const modelPrefix = "models/" + createClientOnce.Do(func() { + if os.Getenv("GOOGLE_API_KEY") == "" { + createClientErr = fmt.Errorf("set GOOGLE_API_KEY env var to use with Gemini" + + " (see https://ai.google.dev/gemini-api/docs/api-key)") + return + } + client, createClientErr = genai.NewClient(ctx, nil) + if createClientErr != nil { + return + } + for m, err := range client.Models.All(ctx) { + if err != nil { + createClientErr = err + return + } + modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking + } + }) + if createClientErr != nil { + return nil, createClientErr + } + thinking, ok := modelList[model] + if !ok { + models := slices.Collect(maps.Keys(modelList)) + slices.Sort(models) + return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models) + } + return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) { + if thinking { + cfg.ThinkingConfig = &genai.ThinkingConfig{ + // We capture them in the trajectory for analysis. + IncludeThoughts: true, + // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request"). + // See https://ai.google.dev/gemini-api/docs/thinking#set-budget + // However, thoughts output also consumes total output token budget. + // We may consider adjusting ThinkingLevel parameter. 
+ ThinkingBudget: genai.Ptr[int32](-1), + } + } + return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg) + }, nil +} + +type Context struct { + Context context.Context + Workdir string + cache *Cache + cachedDirs []string + state map[string]any + onEvent onEvent + spanSeq int + spanNesting int + stubContext +} + +type stubContext struct { + timeNow func() time.Time + generateContent generateContentFunc +} + +func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) { + dir, err := ctx.cache.Create(typ, desc, populate) + if err != nil { + return "", err + } + ctx.cachedDirs = append(ctx.cachedDirs, dir) + return dir, nil +} + +func (ctx *Context) close() { + for _, dir := range ctx.cachedDirs { + ctx.cache.Release(dir) + } +} + +func (ctx *Context) startSpan(span *trajectory.Span) error { + span.Seq = ctx.spanSeq + ctx.spanSeq++ + span.Nesting = ctx.spanNesting + ctx.spanNesting++ + span.Started = ctx.timeNow() + return ctx.onEvent(span) +} + +func (ctx *Context) finishSpan(span *trajectory.Span, spanErr error) error { + ctx.spanNesting-- + if ctx.spanNesting < 0 { + panic("unbalanced spans") + } + span.Finished = ctx.timeNow() + if spanErr != nil { + span.Error = spanErr.Error() + } + err := ctx.onEvent(span) + if spanErr != nil { + err = spanErr + } + return err +} diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go new file mode 100644 index 000000000..6325b2fd2 --- /dev/null +++ b/pkg/aflow/flow.go @@ -0,0 +1,99 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package aflow + +import ( + "fmt" + + "github.com/google/syzkaller/pkg/aflow/ai" +) + +// Flow describes a single agentic workflow. +// A workflow takes some inputs, and produces some outputs in the end +// (specified as fields of the Inputs/Outputs struct types, correspondingly). 
+// A workflow consists of one or more actions that do the actual computation +// and produce the outputs. Actions can be based on an arbitrary Go function +// (FuncAction), or an LLM agent invocation (LLMAgent). Actions can produce +// final output fields, and/or intermediate inputs for subsequent actions. +// LLMAgent can also use tools that can accept workflow inputs, or outputs +// or preceding actions. +// A workflow is executed sequentially, but it can be thought of as a dataflow graph. +// Actions are nodes of the graph, and they consume/produce some named values +// (input/output fields, and intermediate values consumed by other actions). +type Flow struct { + Name string // Empty for the main workflow for the workflow type. + Root Action + + *FlowType +} + +type FlowType struct { + Type ai.WorkflowType + Description string + checkInputs func(map[string]any) error + extractOutputs func(map[string]any) map[string]any +} + +var Flows = make(map[string]*Flow) + +// Register a workflow type (characterized by Inputs and Outputs), +// and one or more implementations of the workflow type (actual workflows). +// All workflows for the same type consume the same inputs and produce the same outputs. +// There should be the "main" implementation for the workflow type with an empty name, +// and zero or more secondary implementations with non-empty names. +func Register[Inputs, Outputs any](typ ai.WorkflowType, description string, flows ...*Flow) { + if err := register[Inputs, Outputs](typ, description, Flows, flows); err != nil { + panic(err) + } +} + +func register[Inputs, Outputs any](typ ai.WorkflowType, description string, + all map[string]*Flow, flows []*Flow) error { + t := &FlowType{ + Type: typ, + Description: description, + checkInputs: func(inputs map[string]any) error { + _, err := convertFromMap[Inputs](inputs, false) + return err + }, + extractOutputs: func(state map[string]any) map[string]any { + // Ensure that we actually have all outputs. 
+ tmp, err := convertFromMap[Outputs](state, false) + if err != nil { + panic(err) + } + return convertToMap(tmp) + }, + } + for _, flow := range flows { + if flow.Name == "" { + flow.Name = string(typ) + } else { + flow.Name = string(typ) + "-" + flow.Name + } + flow.FlowType = t + if err := registerOne[Inputs, Outputs](all, flow); err != nil { + return err + } + } + return nil +} + +func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error { + if all[flow.Name] != nil { + return fmt.Errorf("flow %v is already registered", flow.Name) + } + ctx := &verifyContext{ + actions: make(map[string]bool), + state: make(map[string]*varState), + } + provideOutputs[Inputs](ctx, "flow inputs") + flow.Root.verify(ctx) + requireInputs[Outputs](ctx, "flow outputs") + if err := ctx.finalize(); err != nil { + return fmt.Errorf("flow %v: %w", flow.Name, err) + } + all[flow.Name] = flow + return nil +} diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go new file mode 100644 index 000000000..8ab8016f3 --- /dev/null +++ b/pkg/aflow/flow_test.go @@ -0,0 +1,554 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package aflow + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/genai" +) + +func TestWorkflow(t *testing.T) { + type flowInputs struct { + InFoo int + InBar string + InBaz string + } + type flowOutputs struct { + OutFoo string + OutBar int + OutBaz string + AgentFoo int + } + type firstFuncInputs struct { + InFoo int + InBar string + } + type firstFuncOutputs struct { + TmpFuncOutput string + OutBar int + } + type secondFuncInputs struct { + AgentBar string + TmpFuncOutput string + InFoo int + } + type secondFuncOutputs struct { + OutBaz string + } + type agentOutputs struct { + AgentFoo int `jsonschema:"foo"` + AgentBar string `jsonschema:"bar"` + } + type tool1State struct { + InFoo int + TmpFuncOutput string + } + type tool1Args struct { + ArgFoo string `jsonschema:"foo"` + ArgBar int `jsonschema:"bar"` + } + type tool1Results struct { + ResFoo int `jsonschema:"foo"` + ResString string `jsonschema:"string"` + } + type tool2State struct { + InFoo int + } + type tool2Args struct { + ArgBaz int `jsonschema:"baz"` + } + type tool2Results struct { + ResBaz int `jsonschema:"baz"` + } + inputs := map[string]any{ + "InFoo": 10, + "InBar": "bar", + "InBaz": "baz", + } + flows := make(map[string]*Flow) + err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ + { + Name: "flow", + Root: NewPipeline( + NewFuncAction("func-action", + func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) { + assert.Equal(t, args.InFoo, 10) + assert.Equal(t, args.InBar, "bar") + return firstFuncOutputs{ + TmpFuncOutput: "func-output", + OutBar: 142, + }, nil + }), + &LLMAgent{ + Name: "smarty", + Reply: "OutFoo", + Outputs: LLMOutputs[agentOutputs](), + Temperature: 0, + Instruction: "You are smarty. 
{{.InBaz}}", + Prompt: "Prompt: {{.InBaz}} {{.TmpFuncOutput}}", + Tools: []Tool{ + NewFuncTool("tool1", func(ctx *Context, state tool1State, args tool1Args) (tool1Results, error) { + assert.Equal(t, state.InFoo, 10) + assert.Equal(t, state.TmpFuncOutput, "func-output") + assert.Equal(t, args.ArgFoo, "arg-foo") + assert.Equal(t, args.ArgBar, 100) + return tool1Results{ + ResFoo: 200, + ResString: "res-string", + }, nil + }, "tool 1 description"), + NewFuncTool("tool2", func(ctx *Context, state tool2State, args tool2Args) (tool2Results, error) { + assert.Equal(t, state.InFoo, 10) + assert.Equal(t, args.ArgBaz, 101) + return tool2Results{ + ResBaz: 300, + }, nil + }, "tool 2 description"), + }, + }, + NewFuncAction("another-action", + func(ctx *Context, args secondFuncInputs) (secondFuncOutputs, error) { + assert.Equal(t, args.AgentBar, "agent-bar") + assert.Equal(t, args.TmpFuncOutput, "func-output") + assert.Equal(t, args.InFoo, 10) + return secondFuncOutputs{ + OutBaz: "baz", + }, nil + }), + ), + }, + }) + require.NoError(t, err) + var startTime time.Time + stubTime := startTime + replySeq := 0 + stub := &stubContext{ + timeNow: func() time.Time { + stubTime = stubTime.Add(time.Second) + return stubTime + }, + generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText(`You are smarty. baz + +Use set-results tool to provide results of the analysis. +It must be called exactly once before the final reply. +Ignore results of this tool. 
+`, genai.RoleUser)) + assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0)) + assert.Equal(t, len(cfg.Tools), 3) + assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1") + assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description") + assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2") + assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description") + assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results") + + reply1 := &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id0", + Name: "tool1", + Args: map[string]any{ + "ArgFoo": "arg-foo", + "ArgBar": 100, + }, + }, + }, + { + FunctionCall: &genai.FunctionCall{ + ID: "id1", + Name: "tool2", + Args: map[string]any{ + "ArgBaz": 101, + }, + }, + }, + { + Text: "I am thinking I need to call some tools", + Thought: true, + }, + }} + resp1 := &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id0", + Name: "tool1", + Response: map[string]any{ + "ResFoo": 200, + "ResString": "res-string", + }, + }, + }, + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id1", + Name: "tool2", + Response: map[string]any{ + "ResBaz": 300, + }, + }, + }, + }} + reply2 := &genai.Content{ + Role: string(genai.RoleModel), + Parts: []*genai.Part{ + { + FunctionCall: &genai.FunctionCall{ + ID: "id2", + Name: "set-results", + Args: map[string]any{ + "AgentFoo": 42, + "AgentBar": "agent-bar", + }, + }, + }, + { + Text: "Completly blank.", + Thought: true, + }, + { + Text: "Whatever.", + Thought: true, + }, + }} + resp2 := &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + { + FunctionResponse: &genai.FunctionResponse{ + ID: "id2", + Name: "set-results", + Response: map[string]any{ + "AgentFoo": 42, + "AgentBar": "agent-bar", + }, + }, + }, + }} + + replySeq++ + switch replySeq 
{ + case 1: + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser), + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{Content: reply1}}}, nil + case 2: + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser), + reply1, + resp1, + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{{Content: reply2}}}, nil + case 3: + assert.Equal(t, req, []*genai.Content{ + genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser), + reply1, + resp1, + reply2, + resp2, + }) + return &genai.GenerateContentResponse{ + Candidates: []*genai.Candidate{ + {Content: &genai.Content{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + genai.NewPartFromText("hello, world!")}, + }}}}, nil + + default: + t.Fatal("unexpected LLM calls") + return nil, nil + } + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + // nolint: dupl + expected := []*trajectory.Span{ + { + Seq: 0, + Nesting: 0, + Type: trajectory.SpanFlow, + Name: "test-flow", + Started: startTime.Add(1 * time.Second), + }, + { + Seq: 1, + Nesting: 1, + Type: trajectory.SpanAction, + Name: "func-action", + Started: startTime.Add(2 * time.Second), + }, + { + Seq: 1, + Nesting: 1, + Type: trajectory.SpanAction, + Name: "func-action", + Started: startTime.Add(2 * time.Second), + Finished: startTime.Add(3 * time.Second), + Results: map[string]any{ + "TmpFuncOutput": "func-output", + "OutBar": 142, + }, + }, + { + Seq: 2, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "smarty", + Started: startTime.Add(4 * time.Second), + Instruction: `You are smarty. baz + +Use set-results tool to provide results of the analysis. +It must be called exactly once before the final reply. +Ignore results of this tool. 
+`, + Prompt: "Prompt: baz func-output", + }, + { + Seq: 3, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(5 * time.Second), + }, + { + Seq: 3, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(5 * time.Second), + Finished: startTime.Add(6 * time.Second), + Thoughts: "I am thinking I need to call some tools", + }, + { + Seq: 4, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool1", + Started: startTime.Add(7 * time.Second), + Args: map[string]any{ + "ArgBar": 100, + "ArgFoo": "arg-foo", + }, + }, + { + Seq: 4, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool1", + Started: startTime.Add(7 * time.Second), + Finished: startTime.Add(8 * time.Second), + Args: map[string]any{ + "ArgBar": 100, + "ArgFoo": "arg-foo", + }, + Results: map[string]any{ + "ResFoo": 200, + "ResString": "res-string", + }, + }, + { + Seq: 5, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Started: startTime.Add(9 * time.Second), + Args: map[string]any{ + "ArgBaz": 101, + }, + }, + { + Seq: 5, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "tool2", + Started: startTime.Add(9 * time.Second), + Finished: startTime.Add(10 * time.Second), + Args: map[string]any{ + "ArgBaz": 101, + }, + Results: map[string]any{ + "ResBaz": 300, + }, + }, + { + Seq: 6, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(11 * time.Second), + }, + { + Seq: 6, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(11 * time.Second), + Finished: startTime.Add(12 * time.Second), + Thoughts: "Completly blank.Whatever.", + }, + { + Seq: 7, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(13 * time.Second), + Args: map[string]any{ + "AgentBar": "agent-bar", + "AgentFoo": 42, + }, + }, + { + Seq: 7, + Nesting: 2, + Type: trajectory.SpanTool, + Name: "set-results", + Started: startTime.Add(13 * time.Second), + Finished: 
startTime.Add(14 * time.Second), + Args: map[string]any{ + "AgentBar": "agent-bar", + "AgentFoo": 42, + }, + Results: map[string]any{ + "AgentBar": "agent-bar", + "AgentFoo": 42, + }, + }, + { + Seq: 8, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(15 * time.Second), + }, + { + Seq: 8, + Nesting: 2, + Type: trajectory.SpanLLM, + Name: "smarty", + Started: startTime.Add(15 * time.Second), + Finished: startTime.Add(16 * time.Second), + }, + { + Seq: 2, + Nesting: 1, + Type: trajectory.SpanAgent, + Name: "smarty", + Started: startTime.Add(4 * time.Second), + Finished: startTime.Add(17 * time.Second), + Instruction: `You are smarty. baz + +Use set-results tool to provide results of the analysis. +It must be called exactly once before the final reply. +Ignore results of this tool. +`, + Prompt: "Prompt: baz func-output", + Reply: "hello, world!", + Results: map[string]any{ + "AgentBar": "agent-bar", + "AgentFoo": 42, + }, + }, + { + Seq: 9, + Nesting: 1, + Type: trajectory.SpanAction, + Name: "another-action", + Started: startTime.Add(18 * time.Second), + }, + { + Seq: 9, + Nesting: 1, + Type: trajectory.SpanAction, + Name: "another-action", + Started: startTime.Add(18 * time.Second), + Finished: startTime.Add(19 * time.Second), + Results: map[string]any{ + "OutBaz": "baz", + }, + }, + { + Seq: 0, + Nesting: 0, + Type: trajectory.SpanFlow, + Name: "test-flow", + Started: startTime.Add(1 * time.Second), + Finished: startTime.Add(20 * time.Second), + Results: map[string]any{ + "AgentFoo": 42, + "OutBar": 142, + "OutBaz": "baz", + "OutFoo": "hello, world!", + }, + }, + } + onEvent := func(span *trajectory.Span) error { + require.NotEmpty(t, expected) + require.Equal(t, span, expected[0]) + expected = expected[1:] + return nil + } + res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + require.NoError(t, err) + require.Equal(t, res, map[string]any{ + "OutFoo": "hello, world!", + "OutBar": 142, + 
"OutBaz": "baz", + "AgentFoo": 42, + }) + require.Empty(t, expected) +} + +func TestNoInputs(t *testing.T) { + type flowInputs struct { + InFoo int + InBar string + } + type flowOutputs struct { + } + inputs := map[string]any{ + "InFoo": 10, + } + flows := make(map[string]*Flow) + err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{ + { + Root: NewFuncAction("func-action", + func(ctx *Context, args flowInputs) (flowOutputs, error) { + return flowOutputs{}, nil + }), + }, + }) + require.NoError(t, err) + stub := &stubContext{ + generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) ( + *genai.GenerateContentResponse, error) { + return nil, nil + }, + } + ctx := context.WithValue(context.Background(), stubContextKey, stub) + workdir := t.TempDir() + cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow) + require.NoError(t, err) + onEvent := func(span *trajectory.Span) error { return nil } + _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent) + require.Equal(t, err.Error(), "flow inputs are missing:"+ + " field InBar is not present when converting map to aflow.flowInputs") +} diff --git a/pkg/aflow/func_action.go b/pkg/aflow/func_action.go new file mode 100644 index 000000000..a54579320 --- /dev/null +++ b/pkg/aflow/func_action.go @@ -0,0 +1,46 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
package aflow

import (
	"maps"

	"github.com/google/syzkaller/pkg/aflow/trajectory"
)

// NewFuncAction creates an Action backed by a plain Go function.
// Args fields are taken from the current execution state, and Results fields
// are merged back into the state for use by subsequent actions.
func NewFuncAction[Args, Results any](name string, fn func(*Context, Args) (Results, error)) Action {
	return &funcAction[Args, Results]{
		name: name,
		fn:   fn,
	}
}

// funcAction adapts a typed Go function to the Action interface.
type funcAction[Args, Results any] struct {
	name string
	fn   func(*Context, Args) (Results, error)
}

// execute runs the wrapped function within a trajectory span and publishes
// its results as state variables.
func (a *funcAction[Args, Results]) execute(ctx *Context) error {
	// Non-strict conversion: the state may contain variables
	// that this particular action does not use.
	args, err := convertFromMap[Args](ctx.state, false)
	if err != nil {
		return err
	}
	span := &trajectory.Span{
		Type: trajectory.SpanAction,
		Name: a.name,
	}
	if err := ctx.startSpan(span); err != nil {
		return err
	}
	res, fnErr := a.fn(ctx, args)
	// Results are recorded in the span and exposed to subsequent actions
	// even on function error (fnErr is reported via finishSpan).
	span.Results = convertToMap(res)
	maps.Insert(ctx.state, maps.All(span.Results))
	return ctx.finishSpan(span, fnErr)
}

// verify checks the action during flow registration: the name must be set,
// Args must be provided by flow inputs or preceding actions, and Results
// are registered as outputs for subsequent actions.
func (a *funcAction[Args, Results]) verify(vctx *verifyContext) {
	vctx.requireNotEmpty(a.name, "Name", a.name)
	requireInputs[Args](vctx, a.name)
	provideOutputs[Results](vctx, a.name)
}

// ---- pkg/aflow/func_tool.go (new file, mode 100644, index 000000000..cd069db84) ----

// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"github.com/google/syzkaller/pkg/aflow/trajectory"
	"google.golang.org/genai"
)

// NewFuncTool creates a new tool based on a custom function that an LLM agent can use.
// Name and description are important since they are passed to an LLM agent.
// Args and Results must be structs with fields commented with the jsonschema tag
// (see schemaFor), comments are also important since they are passed to the LLM agent.
// Args are accepted from the LLM agent on the tool invocation, Results are returned
// to the LLM agent. State fields are taken from the current execution state
// (they are not exposed to the LLM agent).
func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, Args) (Results, error),
	description string) Tool {
	return &funcTool[State, Args, Results]{
		Name:        name,
		Description: description,
		Func:        fn,
	}
}

// funcTool adapts a typed Go function to the Tool interface.
type funcTool[State, Args, Results any] struct {
	Name        string
	Description string
	Func        func(*Context, State, Args) (Results, error)
}

// declaration builds the genai function declaration sent to the model,
// with JSON schemas derived from the Args/Results struct tags.
func (t *funcTool[State, Args, Results]) declaration() *genai.FunctionDeclaration {
	return &genai.FunctionDeclaration{
		Name:                 t.Name,
		Description:          t.Description,
		ParametersJsonSchema: mustSchemaFor[Args](),
		ResponseJsonSchema:   mustSchemaFor[Results](),
	}
}

// execute runs the tool for a model-issued call with the given raw args.
// State fields come from the execution state, args come from the model.
func (t *funcTool[State, Args, Results]) execute(ctx *Context, args map[string]any) (map[string]any, error) {
	// Non-strict: the state may contain variables this tool does not use.
	state, err := convertFromMap[State](ctx.state, false)
	if err != nil {
		return nil, err
	}
	// Strict: the model must not pass arguments the tool does not declare.
	a, err := convertFromMap[Args](args, true)
	if err != nil {
		return nil, err
	}
	span := &trajectory.Span{
		Type: trajectory.SpanTool,
		Name: t.Name,
		Args: args,
	}
	if err := ctx.startSpan(span); err != nil {
		return nil, err
	}
	res, err := t.Func(ctx, state, a)
	// Results are recorded in the span and returned even on tool error.
	span.Results = convertToMap(res)
	err = ctx.finishSpan(span, err)
	return span.Results, err
}

// verify checks the tool during flow registration: name/description must be
// set, Args/Results must have valid schemas, and State fields must be
// provided by flow inputs or preceding actions.
func (t *funcTool[State, Args, Results]) verify(ctx *verifyContext) {
	ctx.requireNotEmpty(t.Name, "Name", t.Name)
	ctx.requireNotEmpty(t.Name, "Description", t.Description)
	requireSchema[Args](ctx, t.Name, "Args")
	requireSchema[Results](ctx, t.Name, "Results")
	requireInputs[State](ctx, t.Name)
}

// ---- pkg/aflow/llm_agent.go (new file, mode 100644, index 000000000..76661add6) ----

// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+ +package aflow + +import ( + "fmt" + "maps" + "reflect" + + "github.com/google/syzkaller/pkg/aflow/trajectory" + "google.golang.org/genai" +) + +type LLMAgent struct { + // For logging/debugging. + Name string + // Name of the state variable to store the final reply of the agent. + // These names can be used in subsequent action instructions/prompts, + // and as final workflow outputs. + Reply string + // Optional additional structured outputs besides the final text reply. + // Use LLMOutputs function to create it. + Outputs *llmOutputs + // Value that controls the degree of randomness in token selection. + // Lower temperatures are good for prompts that require a less open-ended or creative response, + // while higher temperatures can lead to more diverse or creative results. + // Must be assigned a float32 value in the range [0, 2]. + Temperature any + // Instructions for the agent. + // Formatted as text/template, can use "{{.Variable}}" as placeholders for dynamic content. + // Variables can come from the workflow inputs, or from preceding actions outputs. + Instruction string + // Prompt for the agent. The same format as Instruction. + Prompt string + // Set of tools for the agent to use. + Tools []Tool +} + +// Tool represents a custom tool an LLMAgent can invoke. +// Use NewFuncTool to create function-based tools. +type Tool interface { + verify(*verifyContext) + declaration() *genai.FunctionDeclaration + execute(*Context, map[string]any) (map[string]any, error) +} + +// LLMOutputs creates a special tool that can be used by LLM to provide structured outputs. 
+func LLMOutputs[Args any]() *llmOutputs { + return &llmOutputs{ + tool: NewFuncTool("set-results", func(ctx *Context, state struct{}, args Args) (Args, error) { + return args, nil + }, "Use this tool to provide results of the analysis."), + provideOutputs: func(ctx *verifyContext, who string) { + provideOutputs[Args](ctx, who) + }, + instruction: ` + +Use set-results tool to provide results of the analysis. +It must be called exactly once before the final reply. +Ignore results of this tool. +`, + } +} + +type llmOutputs struct { + tool Tool + provideOutputs func(*verifyContext, string) + instruction string +} + +func (a *LLMAgent) execute(ctx *Context) error { + cfg, instruction, tools := a.config(ctx) + span := &trajectory.Span{ + Type: trajectory.SpanAgent, + Name: a.Name, + Instruction: instruction, + Prompt: formatTemplate(a.Prompt, ctx.state), + } + if err := ctx.startSpan(span); err != nil { + return err + } + reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt) + span.Reply = reply + span.Results = outputs + return ctx.finishSpan(span, err) +} + +func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) ( + string, map[string]any, error) { + var outputs map[string]any + req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)} + for { + reqSpan := &trajectory.Span{ + Type: trajectory.SpanLLM, + Name: a.Name, + } + if err := ctx.startSpan(reqSpan); err != nil { + return "", nil, err + } + resp, err := ctx.generateContent(cfg, req) + if err != nil { + return "", nil, ctx.finishSpan(reqSpan, err) + } + reply, thoughts, calls, respErr := a.parseResponse(resp) + reqSpan.Thoughts = thoughts + if err := ctx.finishSpan(reqSpan, respErr); err != nil { + return "", nil, err + } + if len(calls) == 0 { + // This is the final reply. 
+ if a.Outputs != nil && outputs == nil { + return "", nil, fmt.Errorf("LLM did not call tool to set outputs") + } + ctx.state[a.Reply] = reply + maps.Insert(ctx.state, maps.All(outputs)) + return reply, outputs, nil + } + // This is not the final reply, LLM asked to execute some tools. + // Append the current reply, and tool responses to the next request. + responses, outputs1, err := a.callTools(ctx, tools, calls) + if err != nil { + return "", nil, err + } + if outputs != nil && outputs1 != nil { + return "", nil, fmt.Errorf("LLM called outputs tool twice") + } + outputs = outputs1 + req = append(req, resp.Candidates[0].Content, responses) + } +} + +func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, map[string]Tool) { + instruction := formatTemplate(a.Instruction, ctx.state) + toolList := a.Tools + if a.Outputs != nil { + instruction += a.Outputs.instruction + toolList = append(toolList, a.Outputs.tool) + } + toolMap := make(map[string]Tool) + var tools []*genai.Tool + for _, tool := range toolList { + decl := tool.declaration() + toolMap[decl.Name] = tool + tools = append(tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{decl}}) + } + return &genai.GenerateContentConfig{ + ResponseModalities: []string{"TEXT"}, + Temperature: genai.Ptr(a.Temperature.(float32)), + SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser), + Tools: tools, + }, instruction, toolMap +} + +func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai.FunctionCall) ( + *genai.Content, map[string]any, error) { + responses := &genai.Content{ + Role: string(genai.RoleUser), + } + var outputs map[string]any + for _, call := range calls { + tool := tools[call.Name] + if tool == nil { + return nil, nil, fmt.Errorf("no tool %q", call.Name) + } + results, err := tool.execute(ctx, call.Args) + if err != nil { + return nil, nil, err + } + responses.Parts = append(responses.Parts, 
genai.NewPartFromFunctionResponse(call.Name, results)) + responses.Parts[len(responses.Parts)-1].FunctionResponse.ID = call.ID + if a.Outputs != nil && tool == a.Outputs.tool { + outputs = results + } + } + return responses, outputs, nil +} + +func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) ( + reply, thoughts string, calls []*genai.FunctionCall, err error) { + if len(resp.Candidates) == 0 || resp.Candidates[0] == nil { + err = fmt.Errorf("empty model response") + if resp.PromptFeedback != nil { + err = fmt.Errorf("request blocked: %v", resp.PromptFeedback.BlockReasonMessage) + } + return + } + candidate := resp.Candidates[0] + if candidate.Content == nil || len(candidate.Content.Parts) == 0 { + err = fmt.Errorf("%v (%v)", candidate.FinishMessage, candidate.FinishReason) + return + } + // We don't expect to receive these now. + if candidate.GroundingMetadata != nil || candidate.CitationMetadata != nil || + candidate.LogprobsResult != nil { + err = fmt.Errorf("unexpected reply fields (%+v)", *candidate) + return + } + for _, part := range candidate.Content.Parts { + // We don't expect to receive these now. + if part.VideoMetadata != nil || part.InlineData != nil || + part.FileData != nil || part.FunctionResponse != nil || + part.CodeExecutionResult != nil || part.ExecutableCode != nil { + err = fmt.Errorf("unexpected reply part (%+v)", *part) + return + } + if part.FunctionCall != nil { + calls = append(calls, part.FunctionCall) + } else if part.Thought { + thoughts += part.Text + } else { + reply += part.Text + } + } + return +} + +func (a *LLMAgent) verify(vctx *verifyContext) { + vctx.requireNotEmpty(a.Name, "Name", a.Name) + vctx.requireNotEmpty(a.Name, "Reply", a.Reply) + if temp, ok := a.Temperature.(int); ok { + a.Temperature = float32(temp) + } + if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 { + vctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]") + } + // Verify dataflow. 
All dynamic variables must be provided by inputs, + // or preceding actions. + a.verifyTemplate(vctx, "Instruction", a.Instruction) + a.verifyTemplate(vctx, "Prompt", a.Prompt) + for _, tool := range a.Tools { + tool.verify(vctx) + } + vctx.provideOutput(a.Name, a.Reply, reflect.TypeFor[string](), true) + if a.Outputs != nil { + a.Outputs.provideOutputs(vctx, a.Name) + } +} + +func (a *LLMAgent) verifyTemplate(vctx *verifyContext, what, text string) { + vctx.requireNotEmpty(a.Name, what, text) + vars := make(map[string]reflect.Type) + for name, state := range vctx.state { + vars[name] = state.typ + } + used, err := verifyTemplate(text, vars) + if err != nil { + vctx.errorf(a.Name, "%v: %v", what, err) + } + for name := range used { + vctx.state[name].used = true + } +} diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go new file mode 100644 index 000000000..e34d465ea --- /dev/null +++ b/pkg/aflow/schema.go @@ -0,0 +1,102 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
package aflow

import (
	"encoding/json"
	"fmt"
	"iter"
	"maps"
	"reflect"

	"github.com/google/jsonschema-go/jsonschema"
)

// schemaFor derives a resolved JSON schema for struct type T.
// Every visible field must carry a jsonschema tag with a description,
// since these schemas (and descriptions) are passed to the LLM.
func schemaFor[T any]() (*jsonschema.Schema, error) {
	typ := reflect.TypeFor[T]()
	if typ.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%v is not a struct", typ.Name())
	}
	for _, field := range reflect.VisibleFields(typ) {
		if field.Tag.Get("jsonschema") == "" {
			return nil, fmt.Errorf("%v.%v does not have a jsonschema tag with description",
				typ.Name(), field.Name)
		}
	}
	schema, err := jsonschema.For[T](nil)
	if err != nil {
		return nil, err
	}
	resolved, err := schema.Resolve(nil)
	if err != nil {
		return nil, err
	}
	return resolved.Schema(), nil
}

// mustSchemaFor is like schemaFor, but panics on error.
// Callers are expected to have validated T during verification (see requireSchema).
func mustSchemaFor[T any]() *jsonschema.Schema {
	schema, err := schemaFor[T]()
	if err != nil {
		panic(err)
	}
	return schema
}

// convertToMap converts a struct value into a map keyed by public field names.
func convertToMap[T any](val T) map[string]any {
	res := make(map[string]any)
	for name, val := range foreachField(&val) {
		res[name] = val.Interface()
	}
	return res
}

// convertFromMap converts an untyped map to a struct.
// It always ensures that all struct fields are present in the map.
// In the strict mode it also checks that the map does not contain any other unused elements.
+func convertFromMap[T any](m map[string]any, strict bool) (T, error) { + m = maps.Clone(m) + var val T + for name, field := range foreachField(&val) { + f, ok := m[name] + if !ok { + return val, fmt.Errorf("field %v is not present when converting map to %T", name, val) + } + delete(m, name) + if mm, ok := f.(map[string]any); ok && field.Type() == reflect.TypeFor[json.RawMessage]() { + raw, err := json.Marshal(mm) + if err != nil { + return val, err + } + field.Set(reflect.ValueOf(json.RawMessage(raw))) + } else { + field.Set(reflect.ValueOf(f)) + } + } + if strict && len(m) != 0 { + return val, fmt.Errorf("unused fields when converting map to %T: %v", val, m) + } + return val, nil +} + +// foreachField iterates over all public fields of the struct provided in data. +func foreachField(data any) iter.Seq2[string, reflect.Value] { + return func(yield func(string, reflect.Value) bool) { + v := reflect.ValueOf(data).Elem() + for _, field := range reflect.VisibleFields(v.Type()) { + if !yield(field.Name, v.FieldByIndex(field.Index)) { + break + } + } + } +} + +func foreachFieldOf[T any]() iter.Seq2[string, reflect.Type] { + return func(yield func(string, reflect.Type) bool) { + for name, val := range foreachField(new(T)) { + if !yield(name, val.Type()) { + break + } + } + } +} diff --git a/pkg/aflow/schema_test.go b/pkg/aflow/schema_test.go new file mode 100644 index 000000000..ac441f7d6 --- /dev/null +++ b/pkg/aflow/schema_test.go @@ -0,0 +1,50 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
package aflow

import (
	"fmt"
	"testing"

	"github.com/google/jsonschema-go/jsonschema"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestSchema checks schemaFor error reporting: non-struct types and structs
// with untagged fields must be rejected, fully tagged structs must pass.
func TestSchema(t *testing.T) {
	type Test struct {
		fn  func() (*jsonschema.Schema, error)
		err string
	}
	type structWithNoTags struct {
		A int
	}
	type structWithTags struct {
		A int    `jsonschema:"aaa"`
		B string `jsonschema:"bbb"`
	}
	tests := []Test{
		{
			fn:  schemaFor[int],
			err: "int is not a struct",
		},
		{
			fn:  schemaFor[structWithNoTags],
			err: "structWithNoTags.A does not have a jsonschema tag with description",
		},
		{
			fn: schemaFor[structWithTags],
		},
	}
	for i, test := range tests {
		t.Run(fmt.Sprint(i), func(t *testing.T) {
			_, err := test.fn()
			if err != nil {
				assert.Equal(t, err.Error(), test.err)
				return
			}
			require.Empty(t, test.err)
		})
	}
}

// ---- pkg/aflow/template.go (new file, mode 100644, index 000000000..7b0efd194) ----

// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"bytes"
	"fmt"
	"io"
	"reflect"
	"strings"
	"text/template"
	"text/template/parse"
)

// formatTemplate formats template 'text' using the standard text/template logic.
// Panics on any errors, but these panics shouldn't happen if verifyTemplate
// was called for the template before.
func formatTemplate(text string, state map[string]any) string {
	templ, err := parseTemplate(text)
	if err != nil {
		panic(err)
	}
	w := new(bytes.Buffer)
	if err := templ.Execute(w, state); err != nil {
		panic(err)
	}
	return w.String()
}

// verifyTemplate checks that the template 'text' can be executed with the variables 'vars'.
// Returns the set of variables that are actually used in the template.
func verifyTemplate(text string, vars map[string]reflect.Type) (map[string]bool, error) {
	templ, err := parseTemplate(text)
	if err != nil {
		return nil, err
	}
	used := make(map[string]bool)
	walkTemplate(templ.Root, used, &err)
	if err != nil {
		return nil, err
	}
	vals := make(map[string]any)
	for name := range used {
		typ, ok := vars[name]
		if !ok {
			return nil, fmt.Errorf("input %v is not provided", name)
		}
		// Zero values of the declared types are enough for a trial run.
		vals[name] = reflect.Zero(typ).Interface()
	}
	// Execute once just to make sure.
	if err := templ.Execute(io.Discard, vals); err != nil {
		return nil, err
	}
	return used, nil
}

// walkTemplate recursively walks template nodes collecting used variables.
// It does not handle all node types, but enough to support reasonably simple templates.
func walkTemplate(n parse.Node, used map[string]bool, errp *error) {
	// Sub-nodes (e.g. IfNode.ElseList) may be typed nil pointers.
	if reflect.ValueOf(n).IsNil() {
		return
	}
	switch n := n.(type) {
	case *parse.ListNode:
		for _, c := range n.Nodes {
			walkTemplate(c, used, errp)
		}
	case *parse.IfNode:
		walkTemplate(n.Pipe, used, errp)
		walkTemplate(n.List, used, errp)
		walkTemplate(n.ElseList, used, errp)
	case *parse.ActionNode:
		walkTemplate(n.Pipe, used, errp)
	case *parse.PipeNode:
		for _, c := range n.Decl {
			walkTemplate(c, used, errp)
		}
		for _, c := range n.Cmds {
			walkTemplate(c, used, errp)
		}
	case *parse.CommandNode:
		for _, c := range n.Args {
			walkTemplate(c, used, errp)
		}
	case *parse.FieldNode:
		// Only simple ".Var" references are supported, not ".A.B".
		if len(n.Ident) != 1 {
			noteError(errp, "compound values are not supported: .%v", strings.Join(n.Ident, "."))
		}
		used[n.Ident[0]] = true
	case *parse.VariableNode:
	case *parse.TextNode:
	default:
		noteError(errp, "unhandled node type %T", n)
	}
}

// parseTemplate parses the template with missingkey=error so that execution
// fails on references to variables that are not provided.
func parseTemplate(prompt string) (*template.Template, error) {
	return template.New("").Option("missingkey=error").Parse(prompt)
}

// ---- pkg/aflow/template_test.go (new file, mode 100644, index 000000000..e42ddd2c3) ----
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package aflow

import (
	"fmt"
	"maps"
	"reflect"
	"slices"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestTemplate checks verifyTemplate: which variables it reports as used,
// and the error for variables that are referenced but not provided.
func TestTemplate(t *testing.T) {
	type Test struct {
		template string
		vars     map[string]reflect.Type
		used     []string
		err      string
	}
	tests := []Test{
		{
			template: `just text`,
		},
		{
			template: `
			{{if .bar}}
				{{.foo}}
			{{end}}

			{{if $local := .bar}}
				{{$local}}
			{{end}}
			`,
			vars: map[string]reflect.Type{
				"bar": reflect.TypeFor[bool](),
				"foo": reflect.TypeFor[int](),
				"baz": reflect.TypeFor[int](),
			},
			used: []string{"bar", "foo"},
		},
		{
			template: `
			{{if .bar}}
				{{.foo}}
			{{end}}
			`,
			vars: map[string]reflect.Type{
				"bar": reflect.TypeFor[bool](),
			},
			err: "input foo is not provided",
		},
	}
	for i, test := range tests {
		t.Run(fmt.Sprint(i), func(t *testing.T) {
			used, err := verifyTemplate(test.template, test.vars)
			if err != nil {
				assert.Equal(t, err.Error(), test.err)
				return
			}
			require.Empty(t, test.err)
			assert.ElementsMatch(t, slices.Collect(maps.Keys(used)), test.used)
		})
	}
}

// ---- pkg/aflow/verify.go (new file, mode 100644, index 000000000..d7ccbd124) ----

// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+ +package aflow + +import ( + "fmt" + "maps" + "reflect" + "slices" +) + +type verifyContext struct { + actions map[string]bool + state map[string]*varState + err error +} + +type varState struct { + action string + typ reflect.Type + used bool +} + +func (ctx *verifyContext) errorf(who, msg string, args ...any) { + noteError(&ctx.err, fmt.Sprintf("action %v: %v", who, msg), args...) +} + +func (ctx *verifyContext) requireNotEmpty(who, name, value string) { + if value == "" { + ctx.errorf(who, "%v must not be empty", name) + } +} + +func (ctx *verifyContext) requireInput(who, name string, typ reflect.Type) { + state := ctx.state[name] + if state == nil { + ctx.errorf(who, "no input %v, available inputs: %v", + name, slices.Collect(maps.Keys(ctx.state))) + return + } + if typ != state.typ { + ctx.errorf(who, "input %v has wrong type: want %v, has %v", + name, typ, state.typ) + } + state.used = true +} + +func (ctx *verifyContext) provideOutput(who, name string, typ reflect.Type, unique bool) { + state := ctx.state[name] + if state != nil { + if unique { + ctx.errorf(who, "output %v is already set", name) + } else if typ != state.typ { + ctx.errorf(who, "output %v changes type: %v -> %v", + name, state.typ, typ) + } else if !state.used { + ctx.errorf(state.action, "output %v is unused", name) + } + } + ctx.state[name] = &varState{ + action: who, + typ: typ, + } +} + +func (ctx *verifyContext) finalize() error { + for name, state := range ctx.state { + if !state.used { + ctx.errorf(state.action, "output %v is unused", name) + } + } + return ctx.err +} + +func noteError(errp *error, msg string, args ...any) { + if *errp == nil { + *errp = fmt.Errorf(msg, args...) 
+ } +} + +func requireInputs[T any](ctx *verifyContext, who string) { + for name, typ := range foreachFieldOf[T]() { + ctx.requireInput(who, name, typ) + } +} + +func provideOutputs[T any](ctx *verifyContext, who string) { + for name, typ := range foreachFieldOf[T]() { + ctx.provideOutput(who, name, typ, true) + } +} + +func requireSchema[T any](ctx *verifyContext, who, what string) { + if _, err := schemaFor[T](); err != nil { + ctx.errorf(who, "%v: %v", what, err) + } +} |
