aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/aflow
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2026-01-02 17:03:40 +0100
committerDmitry Vyukov <dvyukov@google.com>2026-01-09 12:51:45 +0000
commit45d8f079628d0d9c0214c07e1abe9e8cb26057d6 (patch)
treec7b6e95f040cbbf1322de719360bfe573740272c /pkg/aflow
parentce25ef79a77633ecbd0042eb35c9432dd582d448 (diff)
pkg/aflow: add package for agentic workflows
Diffstat (limited to 'pkg/aflow')
-rw-r--r--pkg/aflow/action.go37
-rw-r--r--pkg/aflow/cache.go201
-rw-r--r--pkg/aflow/cache_test.go144
-rw-r--r--pkg/aflow/execute.go186
-rw-r--r--pkg/aflow/flow.go99
-rw-r--r--pkg/aflow/flow_test.go554
-rw-r--r--pkg/aflow/func_action.go46
-rw-r--r--pkg/aflow/func_tool.go71
-rw-r--r--pkg/aflow/llm_agent.go254
-rw-r--r--pkg/aflow/schema.go102
-rw-r--r--pkg/aflow/schema_test.go50
-rw-r--r--pkg/aflow/template.go100
-rw-r--r--pkg/aflow/template_test.go68
-rw-r--r--pkg/aflow/verify.go98
14 files changed, 2010 insertions, 0 deletions
diff --git a/pkg/aflow/action.go b/pkg/aflow/action.go
new file mode 100644
index 000000000..cd09466a4
--- /dev/null
+++ b/pkg/aflow/action.go
@@ -0,0 +1,37 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+// Action is a single step of a workflow, e.g. a function action (FuncAction),
+// an LLM agent invocation (LLMAgent), or a Pipeline of other actions.
+type Action interface {
+ // verify statically checks the action's variable usage at registration time.
+ verify(*verifyContext)
+ // execute runs the action, reading/writing dataflow values via the Context.
+ execute(*Context) error
+}
+
+// Pipeline is a composite Action that runs its sub-actions one after another.
+type Pipeline struct {
+ // These actions are invoked sequentially,
+ // but dataflow across actions is specified by their use
+ // of variables in args/instructions/prompts.
+ Actions []Action
+}
+
+// NewPipeline returns a Pipeline that executes the given actions in order.
+func NewPipeline(actions ...Action) *Pipeline {
+ return &Pipeline{
+ Actions: actions,
+ }
+}
+
+// execute runs the sub-actions sequentially and stops at the first error.
+func (p *Pipeline) execute(ctx *Context) error {
+ for _, sub := range p.Actions {
+ if err := sub.execute(ctx); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// verify delegates verification to every sub-action in order.
+func (p *Pipeline) verify(ctx *verifyContext) {
+ for _, a := range p.Actions {
+ a.verify(ctx)
+ }
+}
diff --git a/pkg/aflow/cache.go b/pkg/aflow/cache.go
new file mode 100644
index 000000000..fe60e5358
--- /dev/null
+++ b/pkg/aflow/cache.go
@@ -0,0 +1,201 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "maps"
+ "os"
+ "path/filepath"
+ "slices"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/syzkaller/pkg/hash"
+ "github.com/google/syzkaller/pkg/osutil"
+)
+
+// Cache maintains on-disk cache with directories with arbitrary contents (kernel checkouts, builds, etc).
+// Create method is used to either create a new directory, if it's not cached yet, or returns a previously
+// cached directory. Old unused directories are incrementally removed if the total disk usage grows
+// over the specified limit.
+type Cache struct {
+ dir string // absolute root dir that holds all cached entries
+ maxSize uint64 // soft limit on the total disk usage of all entries
+ timeNow func() time.Time // injectable clock (time.Now in production)
+ t *testing.T // non-nil only in tests, enables logging via logf
+ mu sync.Mutex // protects all fields below
+ currentSize uint64 // sum of sizes of all known entries
+ entries map[string]*cacheEntry // keyed by the entry's absolute dir
+}
+
+// cacheEntry describes a single cached directory.
+type cacheEntry struct {
+ dir string // absolute path of the entry
+ size uint64 // disk usage of the dir measured at creation/init time
+ usageCount int // number of outstanding Create calls without a matching Release
+ lastUsed time.Time // last use time, used for LRU-style eviction in purge
+}
+
+// NewCache opens (or creates) the cache rooted at dir with the given soft size limit.
+func NewCache(dir string, maxSize uint64) (*Cache, error) {
+ return newTestCache(nil, dir, maxSize, time.Now)
+}
+
+// newTestCache is the test-friendly constructor: it additionally accepts
+// a testing.T for logging and an injectable clock.
+func newTestCache(t *testing.T, dir string, maxSize uint64, timeNow func() time.Time) (*Cache, error) {
+ if dir == "" {
+ return nil, fmt.Errorf("cache workdir is empty")
+ }
+ c := &Cache{
+ dir: osutil.Abs(dir),
+ maxSize: maxSize,
+ timeNow: timeNow,
+ t: t,
+ entries: make(map[string]*cacheEntry),
+ }
+ // Load pre-existing entries from disk and purge if already over the limit.
+ if err := c.init(); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+// Create creates/returns a cached directory with contents created by the populate callback.
+// The populate callback receives a dir it needs to populate with cached files.
+// The typ must be a short descriptive name of the contents (e.g. "build", "source", etc).
+// The desc is used to identify cached entries and must fully describe the cached contents
+// (the second invocation with the same typ+desc will return dir created by the first
+// invocation with the same typ+desc).
+func (c *Cache) Create(typ, desc string, populate func(string) error) (string, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ // Note: we don't populate a temp dir and then atomically rename it to the final destination,
+ // because at least kernel builds encode the current path in debug info/compile commands,
+ // so moving the dir later would break all that. Instead we rely on the presence of the meta file
+ // to denote valid cache entries. Modification time of the file says when it was last used.
+ id := hash.String(desc)
+ dir := filepath.Join(c.dir, typ, id)
+ metaFile := filepath.Join(dir, cacheMetaFile)
+ if c.entries[dir] == nil {
+ os.RemoveAll(dir)
+ if err := osutil.MkdirAll(dir); err != nil {
+ return "", err
+ }
+ if err := populate(dir); err != nil {
+ os.RemoveAll(dir)
+ return "", err
+ }
+ size, err := osutil.DiskUsage(dir)
+ if err != nil {
+ // Remove the partially created dir so it does not linger on disk
+ // until the next restart (consistent with the other error paths here).
+ os.RemoveAll(dir)
+ return "", err
+ }
+ // The meta file is written last: its presence marks a complete entry.
+ if err := osutil.WriteFile(metaFile, []byte(desc)); err != nil {
+ os.RemoveAll(dir)
+ return "", err
+ }
+ c.entries[dir] = &cacheEntry{
+ dir: dir,
+ size: size,
+ }
+ c.currentSize += size
+ c.logf("created entry %v, size %v, current size %v", dir, size, c.currentSize)
+ }
+ // Note the entry was used now.
+ now := c.timeNow()
+ if err := os.Chtimes(metaFile, now, now); err != nil {
+ return "", err
+ }
+ entry := c.entries[dir]
+ entry.usageCount++
+ entry.lastUsed = now
+ c.logf("using entry %v, usage count %v", dir, entry.usageCount)
+ if err := c.purge(); err != nil {
+ // Undo the usage count bump since we don't return the dir to the caller.
+ entry.usageCount--
+ return "", err
+ }
+ return dir, nil
+}
+
+// Release must be called for every directory returned by Create method when the directory is not used anymore.
+func (c *Cache) Release(dir string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ entry := c.entries[dir]
+ if entry == nil {
+ // The dir was never created by this cache, or was already fully released
+ // (and possibly purged). Panic with a clear message instead of
+ // a nil pointer dereference.
+ panic(fmt.Sprintf("releasing unknown cache entry %v", dir))
+ }
+ entry.usageCount--
+ c.logf("release entry %v, usage count %v", dir, entry.usageCount)
+ if entry.usageCount < 0 {
+ panic("negative usageCount")
+ }
+}
+
+// init reads the cached dirs (disk usage, last use time) from disk when the cache is created.
+func (c *Cache) init() error {
+ // Entries live at <root>/<typ>/<id>, hence the two-level glob.
+ dirs, err := filepath.Glob(filepath.Join(c.dir, "*", "*"))
+ if err != nil {
+ return err
+ }
+ // No locking needed during the scan: the cache is not yet visible to other goroutines.
+ for _, dir := range dirs {
+ metaFile := filepath.Join(dir, cacheMetaFile)
+ if !osutil.IsExist(metaFile) {
+ // The meta file is written last during Create, so its absence
+ // means an incomplete/aborted entry: remove it.
+ if err := osutil.RemoveAll(dir); err != nil {
+ return err
+ }
+ continue
+ }
+ stat, err := os.Stat(metaFile)
+ if err != nil {
+ return err
+ }
+ size, err := osutil.DiskUsage(dir)
+ if err != nil {
+ return err
+ }
+ c.entries[dir] = &cacheEntry{
+ dir: dir,
+ size: size,
+ // The meta file mtime is bumped on every use (see Create).
+ lastUsed: stat.ModTime(),
+ }
+ c.currentSize += size
+ }
+ // purge asserts that mu is held, so take it here.
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.purge()
+}
+
+// purge removes oldest unused directories if the cache is over maxSize.
+// The caller must hold c.mu.
+func (c *Cache) purge() error {
+ // If TryLock succeeds, the caller did not hold the mutex (programmer bug).
+ if c.mu.TryLock() {
+ panic("c.mu is not locked")
+ }
+ if c.currentSize < c.maxSize {
+ return nil
+ }
+ // Sort eviction candidates: lowest usage count first (only entries with
+ // usageCount == 0 are actually removed), then oldest lastUsed first.
+ list := slices.Collect(maps.Values(c.entries))
+ slices.SortFunc(list, func(a, b *cacheEntry) int {
+ if a.usageCount != b.usageCount {
+ return a.usageCount - b.usageCount
+ }
+ return a.lastUsed.Compare(b.lastUsed)
+ })
+ for _, entry := range list {
+ // Stop at the first in-use entry (all remaining ones are in use as well),
+ // or once we are back under the limit.
+ if entry.usageCount != 0 || c.currentSize < c.maxSize {
+ break
+ }
+ if err := os.RemoveAll(entry.dir); err != nil {
+ return err
+ }
+ delete(c.entries, entry.dir)
+ if c.currentSize < entry.size {
+ panic(fmt.Sprintf("negative currentSize: %v %v", c.currentSize, entry.size))
+ }
+ c.currentSize -= entry.size
+ }
+ return nil
+}
+
+// logf logs cache operations to the test log (no-op outside of tests).
+func (c *Cache) logf(msg string, args ...any) {
+ if c.t != nil {
+ c.t.Logf("cache: "+msg, args...)
+ }
+}
+
+// cacheMetaFile marks a fully populated cache entry; the file stores the
+// entry's desc, and its mtime records the last use time.
+const cacheMetaFile = "aflow-meta"
diff --git a/pkg/aflow/cache_test.go b/pkg/aflow/cache_test.go
new file mode 100644
index 000000000..244defdd3
--- /dev/null
+++ b/pkg/aflow/cache_test.go
@@ -0,0 +1,144 @@
+// Copyright 2026 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/google/syzkaller/pkg/osutil"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCache(t *testing.T) {
+ var mockedTime time.Time
+ timeNow := func() time.Time {
+ return mockedTime
+ }
+ tempDir := t.TempDir()
+ c, err := newTestCache(t, tempDir, 1<<40, timeNow)
+ require.NoError(t, err)
+ dir1, err := c.Create("foo", "1", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "bar"), []byte("abc"))
+ })
+ require.NoError(t, err)
+ data, err := os.ReadFile(filepath.Join(dir1, "bar"))
+ require.NoError(t, err)
+ require.Equal(t, data, []byte("abc"))
+ c.Release(dir1)
+
+ dir2, err := c.Create("foo", "1", func(dir string) error {
+ t.Fatal("must not be called")
+ return nil
+ })
+ require.NoError(t, err)
+ require.Equal(t, dir2, dir1)
+ data, err = os.ReadFile(filepath.Join(dir2, "bar"))
+ require.NoError(t, err)
+ require.Equal(t, data, []byte("abc"))
+ c.Release(dir2)
+
+ dir3, err := c.Create("foo", "2", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "baz"), []byte("def"))
+ })
+ require.NoError(t, err)
+ require.NotEqual(t, dir3, dir1)
+ data, err = os.ReadFile(filepath.Join(dir3, "baz"))
+ require.NoError(t, err)
+ require.Equal(t, data, []byte("def"))
+ c.Release(dir3)
+
+ failedDir := ""
+ dir4, err := c.Create("foo", "3", func(dir string) error {
+ failedDir = dir
+ return fmt.Errorf("failed")
+ })
+ require.Error(t, err)
+ require.Empty(t, dir4)
+ require.False(t, osutil.IsExist(failedDir))
+
+ // Create a new cache, it should pick up the state from disk.
+ c, err = newTestCache(t, tempDir, 1<<40, timeNow)
+ require.NoError(t, err)
+
+ dir5, err := c.Create("foo", "1", func(dir string) error {
+ t.Fatal("must not be called")
+ return nil
+ })
+ require.NoError(t, err)
+ require.Equal(t, dir5, dir1)
+ data, err = os.ReadFile(filepath.Join(dir5, "bar"))
+ require.NoError(t, err)
+ require.Equal(t, data, []byte("abc"))
+ c.Release(dir5)
+
+ // Model an incomplete dir without metadata, it should be removed.
+ strayDir := filepath.Join(tempDir, "a", "b")
+ require.NoError(t, osutil.MkdirAll(strayDir))
+ require.NoError(t, osutil.WriteFile(filepath.Join(strayDir, "foo"), []byte("foo")))
+
+ // With 0 max size everything unused should be purged.
+ _, err = newTestCache(t, tempDir, 0, timeNow)
+ require.NoError(t, err)
+ require.False(t, osutil.IsExist(dir1))
+ require.False(t, osutil.IsExist(dir3))
+ require.False(t, osutil.IsExist(strayDir))
+
+ // Test incremental purging of files.
+ c, err = newTestCache(t, tempDir, 100<<10, timeNow)
+ require.NoError(t, err)
+
+ mockedTime = mockedTime.Add(time.Minute)
+ dir6, err := c.Create("foo", "1", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 5<<10))
+ })
+ require.NoError(t, err)
+ c.Release(dir6)
+
+ mockedTime = mockedTime.Add(time.Minute)
+ dir7, err := c.Create("foo", "2", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 5<<10))
+ })
+ require.NoError(t, err)
+ c.Release(dir7)
+
+ mockedTime = mockedTime.Add(time.Minute)
+ dir8, err := c.Create("foo", "3", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 60<<10))
+ })
+ require.NoError(t, err)
+ c.Release(dir8)
+
+ // Force update of the last access time for the first dir.
+ mockedTime = mockedTime.Add(time.Minute)
+ dir9, err := c.Create("foo", "1", func(dir string) error {
+ t.Fatal("must not be called")
+ return nil
+ })
+ require.NoError(t, err)
+ require.Equal(t, dir6, dir9)
+ c.Release(dir9)
+
+ // Both dirs should exist since they should fit into cache size.
+ require.True(t, osutil.IsExist(dir6))
+ require.True(t, osutil.IsExist(dir7))
+ require.True(t, osutil.IsExist(dir8))
+
+ mockedTime = mockedTime.Add(time.Minute)
+ dir10, err := c.Create("foo", "4", func(dir string) error {
+ return osutil.WriteFile(filepath.Join(dir, "bar"), bytes.Repeat([]byte{'a'}, 60<<10))
+ })
+ require.NoError(t, err)
+ c.Release(dir10)
+
+ // Two oldest dirs should be purged.
+ require.True(t, osutil.IsExist(dir6))
+ require.False(t, osutil.IsExist(dir7))
+ require.False(t, osutil.IsExist(dir8))
+ require.True(t, osutil.IsExist(dir10))
+}
diff --git a/pkg/aflow/execute.go b/pkg/aflow/execute.go
new file mode 100644
index 000000000..6e724988e
--- /dev/null
+++ b/pkg/aflow/execute.go
@@ -0,0 +1,186 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "context"
+ "fmt"
+ "maps"
+ "os"
+ "slices"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "github.com/google/syzkaller/pkg/osutil"
+ "google.golang.org/genai"
+)
+
+// https://ai.google.dev/gemini-api/docs/models
+const DefaultModel = "gemini-3-pro-preview"
+
+// Execute runs the flow with the given inputs and returns the flow outputs.
+// Every span is reported via onEvent twice: once when started, once when finished.
+// workdir is a scratch dir for actions; cache holds reusable dirs across runs.
+func (flow *Flow) Execute(c context.Context, model, workdir string, inputs map[string]any,
+ cache *Cache, onEvent onEvent) (map[string]any, error) {
+ if err := flow.checkInputs(inputs); err != nil {
+ return nil, fmt.Errorf("flow inputs are missing: %w", err)
+ }
+ ctx := &Context{
+ Context: c,
+ Workdir: osutil.Abs(workdir),
+ cache: cache,
+ // Clone so that actions don't mutate the caller's map.
+ state: maps.Clone(inputs),
+ onEvent: onEvent,
+ }
+ defer ctx.close()
+ // Tests can stub out the clock and the LLM via a context value.
+ if s := c.Value(stubContextKey); s != nil {
+ ctx.stubContext = *s.(*stubContext)
+ }
+ if ctx.timeNow == nil {
+ ctx.timeNow = time.Now
+ }
+ if ctx.generateContent == nil {
+ var err error
+ ctx.generateContent, err = contentGenerator(c, model)
+ if err != nil {
+ return nil, err
+ }
+ }
+ // The whole flow execution is wrapped into a single top-level span.
+ span := &trajectory.Span{
+ Type: trajectory.SpanFlow,
+ Name: flow.Name,
+ }
+ if err := ctx.startSpan(span); err != nil {
+ return nil, err
+ }
+ flowErr := flow.Root.execute(ctx)
+ if flowErr == nil {
+ span.Results = flow.extractOutputs(ctx.state)
+ }
+ if err := ctx.finishSpan(span, flowErr); err != nil {
+ return nil, err
+ }
+ if ctx.spanNesting != 0 {
+ // Since we finish all spans, even on errors, we should end up at 0.
+ panic(fmt.Sprintf("unbalanced spans (%v)", ctx.spanNesting))
+ }
+ return span.Results, nil
+}
+
+type (
+ // onEvent receives every span twice: once when started and once when finished.
+ onEvent func(*trajectory.Span) error
+ // generateContentFunc abstracts a single LLM call so that tests can stub it.
+ generateContentFunc func(*genai.GenerateContentConfig, []*genai.Content) (
+ *genai.GenerateContentResponse, error)
+ // contextKeyType is a private type for context.Context keys to avoid collisions.
+ contextKeyType int
+)
+
+var (
+ // The genai client is created lazily once and shared by all flows in the process.
+ createClientOnce sync.Once
+ createClientErr error
+ client *genai.Client
+ // modelList maps model name -> whether the model supports thinking.
+ modelList = make(map[string]bool)
+ // stubContextKey is the context key under which tests install a stubContext.
+ stubContextKey = contextKeyType(1)
+)
+
+// contentGenerator returns a function that performs a single GenerateContent
+// call against the given model. The client is created once per process.
+// NOTE(review): the first initialization error (e.g. missing GOOGLE_API_KEY)
+// is sticky for the process lifetime due to sync.Once — presumably intended.
+func contentGenerator(ctx context.Context, model string) (generateContentFunc, error) {
+ const modelPrefix = "models/"
+ createClientOnce.Do(func() {
+ if os.Getenv("GOOGLE_API_KEY") == "" {
+ createClientErr = fmt.Errorf("set GOOGLE_API_KEY env var to use with Gemini" +
+ " (see https://ai.google.dev/gemini-api/docs/api-key)")
+ return
+ }
+ client, createClientErr = genai.NewClient(ctx, nil)
+ if createClientErr != nil {
+ return
+ }
+ // Enumerate available models to validate the requested one
+ // and learn which models support thinking.
+ for m, err := range client.Models.All(ctx) {
+ if err != nil {
+ createClientErr = err
+ return
+ }
+ modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
+ }
+ })
+ if createClientErr != nil {
+ return nil, createClientErr
+ }
+ thinking, ok := modelList[model]
+ if !ok {
+ models := slices.Collect(maps.Keys(modelList))
+ slices.Sort(models)
+ return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
+ }
+ // Note: the returned closure captures ctx from this call, not a per-request context.
+ return func(cfg *genai.GenerateContentConfig, req []*genai.Content) (*genai.GenerateContentResponse, error) {
+ if thinking {
+ cfg.ThinkingConfig = &genai.ThinkingConfig{
+ // We capture them in the trajectory for analysis.
+ IncludeThoughts: true,
+ // Enable "dynamic thinking" ("the model will adjust the budget based on the complexity of the request").
+ // See https://ai.google.dev/gemini-api/docs/thinking#set-budget
+ // However, thoughts output also consumes total output token budget.
+ // We may consider adjusting ThinkingLevel parameter.
+ ThinkingBudget: genai.Ptr[int32](-1),
+ }
+ }
+ return client.Models.GenerateContent(ctx, modelPrefix+model, req, cfg)
+ }, nil
+}
+
+// Context carries per-execution state and is passed to every action and tool.
+type Context struct {
+ Context context.Context
+ Workdir string // scratch dir for this execution
+ cache *Cache
+ cachedDirs []string // dirs obtained from cache, released in close
+ state map[string]any // named dataflow values (flow inputs + action outputs)
+ onEvent onEvent
+ spanSeq int // next span sequence number
+ spanNesting int // current span nesting depth
+ stubContext
+}
+
+// stubContext holds the dependencies tests are allowed to stub out
+// (installed via a context value keyed by stubContextKey).
+type stubContext struct {
+ timeNow func() time.Time
+ generateContent generateContentFunc
+}
+
+// Cache creates/returns a cached dir (see Cache.Create) and remembers it
+// so that it's automatically released when the execution finishes.
+func (ctx *Context) Cache(typ, desc string, populate func(string) error) (string, error) {
+ dir, err := ctx.cache.Create(typ, desc, populate)
+ if err != nil {
+ return "", err
+ }
+ ctx.cachedDirs = append(ctx.cachedDirs, dir)
+ return dir, nil
+}
+
+// close releases all cached dirs acquired during the execution.
+func (ctx *Context) close() {
+ for _, dir := range ctx.cachedDirs {
+ ctx.cache.Release(dir)
+ }
+}
+
+// startSpan assigns the span a sequence number and nesting level,
+// stamps the start time, and reports the span via onEvent.
+func (ctx *Context) startSpan(span *trajectory.Span) error {
+ span.Seq = ctx.spanSeq
+ ctx.spanSeq++
+ span.Nesting = ctx.spanNesting
+ ctx.spanNesting++
+ span.Started = ctx.timeNow()
+ return ctx.onEvent(span)
+}
+
+// finishSpan stamps the finish time/error and reports the span again via onEvent.
+// The span's own error (if any) takes precedence over the onEvent error.
+func (ctx *Context) finishSpan(span *trajectory.Span, spanErr error) error {
+ ctx.spanNesting--
+ if ctx.spanNesting < 0 {
+ panic("unbalanced spans")
+ }
+ span.Finished = ctx.timeNow()
+ if spanErr != nil {
+ span.Error = spanErr.Error()
+ }
+ err := ctx.onEvent(span)
+ if spanErr != nil {
+ err = spanErr
+ }
+ return err
+}
diff --git a/pkg/aflow/flow.go b/pkg/aflow/flow.go
new file mode 100644
index 000000000..6325b2fd2
--- /dev/null
+++ b/pkg/aflow/flow.go
@@ -0,0 +1,99 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+
+ "github.com/google/syzkaller/pkg/aflow/ai"
+)
+
+// Flow describes a single agentic workflow.
+// A workflow takes some inputs, and produces some outputs in the end
+// (specified as fields of the Inputs/Outputs struct types, correspondingly).
+// A workflow consists of one or more actions that do the actual computation
+// and produce the outputs. Actions can be based on an arbitrary Go function
+// (FuncAction), or an LLM agent invocation (LLMAgent). Actions can produce
+// final output fields, and/or intermediate inputs for subsequent actions.
+// LLMAgent can also use tools that can accept workflow inputs, or outputs
+// or preceding actions.
+// A workflow is executed sequentially, but it can be thought of as a dataflow graph.
+// Actions are nodes of the graph, and they consume/produce some named values
+// (input/output fields, and intermediate values consumed by other actions).
+type Flow struct {
+ // Name identifies the flow. Before registration it's empty for the main
+ // implementation of the workflow type; Register expands it to the full
+ // name ("<type>" or "<type>-<name>").
+ Name string
+ Root Action // top-level action of the flow (often a Pipeline)
+
+ // FlowType is shared by all implementations of the same workflow type.
+ *FlowType
+}
+
+// FlowType describes a workflow type: its identity and inputs/outputs contract.
+type FlowType struct {
+ Type ai.WorkflowType
+ Description string
+ // checkInputs validates that a given inputs map contains all flow inputs.
+ checkInputs func(map[string]any) error
+ // extractOutputs extracts (and validates presence of) the flow outputs
+ // from the final execution state.
+ extractOutputs func(map[string]any) map[string]any
+}
+
+// Flows contains all registered flows, keyed by the full flow name.
+var Flows = make(map[string]*Flow)
+
+// Register a workflow type (characterized by Inputs and Outputs),
+// and one or more implementations of the workflow type (actual workflows).
+// All workflows for the same type consume the same inputs and produce the same outputs.
+// There should be the "main" implementation for the workflow type with an empty name,
+// and zero or more secondary implementations with non-empty names.
+func Register[Inputs, Outputs any](typ ai.WorkflowType, description string, flows ...*Flow) {
+ if err := register[Inputs, Outputs](typ, description, Flows, flows); err != nil {
+ // Registration happens at program startup, so any error is a programmer bug.
+ panic(err)
+ }
+}
+
+// register implements Register against an arbitrary registry map (for testability).
+func register[Inputs, Outputs any](typ ai.WorkflowType, description string,
+ all map[string]*Flow, flows []*Flow) error {
+ t := &FlowType{
+ Type: typ,
+ Description: description,
+ checkInputs: func(inputs map[string]any) error {
+ _, err := convertFromMap[Inputs](inputs, false)
+ return err
+ },
+ extractOutputs: func(state map[string]any) map[string]any {
+ // Ensure that we actually have all outputs.
+ tmp, err := convertFromMap[Outputs](state, false)
+ if err != nil {
+ // Verification at registration time guarantees all outputs
+ // are produced, so a missing output here is a programmer bug.
+ panic(err)
+ }
+ return convertToMap(tmp)
+ },
+ }
+ for _, flow := range flows {
+ // The full flow name is "<type>" for the main implementation,
+ // and "<type>-<name>" for secondary ones.
+ if flow.Name == "" {
+ flow.Name = string(typ)
+ } else {
+ flow.Name = string(typ) + "-" + flow.Name
+ }
+ flow.FlowType = t
+ if err := registerOne[Inputs, Outputs](all, flow); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// registerOne statically verifies the flow's dataflow (every action input is
+// produced by the flow inputs or a preceding action, and all flow outputs are
+// produced by some action) and adds the flow to the registry.
+func registerOne[Inputs, Outputs any](all map[string]*Flow, flow *Flow) error {
+ if all[flow.Name] != nil {
+ return fmt.Errorf("flow %v is already registered", flow.Name)
+ }
+ ctx := &verifyContext{
+ actions: make(map[string]bool),
+ state: make(map[string]*varState),
+ }
+ // Flow inputs are available to the first action; flow outputs must be
+ // produced by the time the root action finishes.
+ provideOutputs[Inputs](ctx, "flow inputs")
+ flow.Root.verify(ctx)
+ requireInputs[Outputs](ctx, "flow outputs")
+ if err := ctx.finalize(); err != nil {
+ return fmt.Errorf("flow %v: %w", flow.Name, err)
+ }
+ all[flow.Name] = flow
+ return nil
+}
diff --git a/pkg/aflow/flow_test.go b/pkg/aflow/flow_test.go
new file mode 100644
index 000000000..8ab8016f3
--- /dev/null
+++ b/pkg/aflow/flow_test.go
@@ -0,0 +1,554 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/genai"
+)
+
+func TestWorkflow(t *testing.T) {
+ type flowInputs struct {
+ InFoo int
+ InBar string
+ InBaz string
+ }
+ type flowOutputs struct {
+ OutFoo string
+ OutBar int
+ OutBaz string
+ AgentFoo int
+ }
+ type firstFuncInputs struct {
+ InFoo int
+ InBar string
+ }
+ type firstFuncOutputs struct {
+ TmpFuncOutput string
+ OutBar int
+ }
+ type secondFuncInputs struct {
+ AgentBar string
+ TmpFuncOutput string
+ InFoo int
+ }
+ type secondFuncOutputs struct {
+ OutBaz string
+ }
+ type agentOutputs struct {
+ AgentFoo int `jsonschema:"foo"`
+ AgentBar string `jsonschema:"bar"`
+ }
+ type tool1State struct {
+ InFoo int
+ TmpFuncOutput string
+ }
+ type tool1Args struct {
+ ArgFoo string `jsonschema:"foo"`
+ ArgBar int `jsonschema:"bar"`
+ }
+ type tool1Results struct {
+ ResFoo int `jsonschema:"foo"`
+ ResString string `jsonschema:"string"`
+ }
+ type tool2State struct {
+ InFoo int
+ }
+ type tool2Args struct {
+ ArgBaz int `jsonschema:"baz"`
+ }
+ type tool2Results struct {
+ ResBaz int `jsonschema:"baz"`
+ }
+ inputs := map[string]any{
+ "InFoo": 10,
+ "InBar": "bar",
+ "InBaz": "baz",
+ }
+ flows := make(map[string]*Flow)
+ err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Name: "flow",
+ Root: NewPipeline(
+ NewFuncAction("func-action",
+ func(ctx *Context, args firstFuncInputs) (firstFuncOutputs, error) {
+ assert.Equal(t, args.InFoo, 10)
+ assert.Equal(t, args.InBar, "bar")
+ return firstFuncOutputs{
+ TmpFuncOutput: "func-output",
+ OutBar: 142,
+ }, nil
+ }),
+ &LLMAgent{
+ Name: "smarty",
+ Reply: "OutFoo",
+ Outputs: LLMOutputs[agentOutputs](),
+ Temperature: 0,
+ Instruction: "You are smarty. {{.InBaz}}",
+ Prompt: "Prompt: {{.InBaz}} {{.TmpFuncOutput}}",
+ Tools: []Tool{
+ NewFuncTool("tool1", func(ctx *Context, state tool1State, args tool1Args) (tool1Results, error) {
+ assert.Equal(t, state.InFoo, 10)
+ assert.Equal(t, state.TmpFuncOutput, "func-output")
+ assert.Equal(t, args.ArgFoo, "arg-foo")
+ assert.Equal(t, args.ArgBar, 100)
+ return tool1Results{
+ ResFoo: 200,
+ ResString: "res-string",
+ }, nil
+ }, "tool 1 description"),
+ NewFuncTool("tool2", func(ctx *Context, state tool2State, args tool2Args) (tool2Results, error) {
+ assert.Equal(t, state.InFoo, 10)
+ assert.Equal(t, args.ArgBaz, 101)
+ return tool2Results{
+ ResBaz: 300,
+ }, nil
+ }, "tool 2 description"),
+ },
+ },
+ NewFuncAction("another-action",
+ func(ctx *Context, args secondFuncInputs) (secondFuncOutputs, error) {
+ assert.Equal(t, args.AgentBar, "agent-bar")
+ assert.Equal(t, args.TmpFuncOutput, "func-output")
+ assert.Equal(t, args.InFoo, 10)
+ return secondFuncOutputs{
+ OutBaz: "baz",
+ }, nil
+ }),
+ ),
+ },
+ })
+ require.NoError(t, err)
+ var startTime time.Time
+ stubTime := startTime
+ replySeq := 0
+ stub := &stubContext{
+ timeNow: func() time.Time {
+ stubTime = stubTime.Add(time.Second)
+ return stubTime
+ },
+ generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ assert.Equal(t, cfg.SystemInstruction, genai.NewContentFromText(`You are smarty. baz
+
+Use set-results tool to provide results of the analysis.
+It must be called exactly once before the final reply.
+Ignore results of this tool.
+`, genai.RoleUser))
+ assert.Equal(t, cfg.Temperature, genai.Ptr[float32](0))
+ assert.Equal(t, len(cfg.Tools), 3)
+ assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Name, "tool1")
+ assert.Equal(t, cfg.Tools[0].FunctionDeclarations[0].Description, "tool 1 description")
+ assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Name, "tool2")
+ assert.Equal(t, cfg.Tools[1].FunctionDeclarations[0].Description, "tool 2 description")
+ assert.Equal(t, cfg.Tools[2].FunctionDeclarations[0].Name, "set-results")
+
+ reply1 := &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id0",
+ Name: "tool1",
+ Args: map[string]any{
+ "ArgFoo": "arg-foo",
+ "ArgBar": 100,
+ },
+ },
+ },
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id1",
+ Name: "tool2",
+ Args: map[string]any{
+ "ArgBaz": 101,
+ },
+ },
+ },
+ {
+ Text: "I am thinking I need to call some tools",
+ Thought: true,
+ },
+ }}
+ resp1 := &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id0",
+ Name: "tool1",
+ Response: map[string]any{
+ "ResFoo": 200,
+ "ResString": "res-string",
+ },
+ },
+ },
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id1",
+ Name: "tool2",
+ Response: map[string]any{
+ "ResBaz": 300,
+ },
+ },
+ },
+ }}
+ reply2 := &genai.Content{
+ Role: string(genai.RoleModel),
+ Parts: []*genai.Part{
+ {
+ FunctionCall: &genai.FunctionCall{
+ ID: "id2",
+ Name: "set-results",
+ Args: map[string]any{
+ "AgentFoo": 42,
+ "AgentBar": "agent-bar",
+ },
+ },
+ },
+ {
+ Text: "Completly blank.",
+ Thought: true,
+ },
+ {
+ Text: "Whatever.",
+ Thought: true,
+ },
+ }}
+ resp2 := &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ {
+ FunctionResponse: &genai.FunctionResponse{
+ ID: "id2",
+ Name: "set-results",
+ Response: map[string]any{
+ "AgentFoo": 42,
+ "AgentBar": "agent-bar",
+ },
+ },
+ },
+ }}
+
+ replySeq++
+ switch replySeq {
+ case 1:
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser),
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: reply1}}}, nil
+ case 2:
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser),
+ reply1,
+ resp1,
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{{Content: reply2}}}, nil
+ case 3:
+ assert.Equal(t, req, []*genai.Content{
+ genai.NewContentFromText("Prompt: baz func-output", genai.RoleUser),
+ reply1,
+ resp1,
+ reply2,
+ resp2,
+ })
+ return &genai.GenerateContentResponse{
+ Candidates: []*genai.Candidate{
+ {Content: &genai.Content{
+ Role: string(genai.RoleUser),
+ Parts: []*genai.Part{
+ genai.NewPartFromText("hello, world!")},
+ }}}}, nil
+
+ default:
+ t.Fatal("unexpected LLM calls")
+ return nil, nil
+ }
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ // nolint: dupl
+ expected := []*trajectory.Span{
+ {
+ Seq: 0,
+ Nesting: 0,
+ Type: trajectory.SpanFlow,
+ Name: "test-flow",
+ Started: startTime.Add(1 * time.Second),
+ },
+ {
+ Seq: 1,
+ Nesting: 1,
+ Type: trajectory.SpanAction,
+ Name: "func-action",
+ Started: startTime.Add(2 * time.Second),
+ },
+ {
+ Seq: 1,
+ Nesting: 1,
+ Type: trajectory.SpanAction,
+ Name: "func-action",
+ Started: startTime.Add(2 * time.Second),
+ Finished: startTime.Add(3 * time.Second),
+ Results: map[string]any{
+ "TmpFuncOutput": "func-output",
+ "OutBar": 142,
+ },
+ },
+ {
+ Seq: 2,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "smarty",
+ Started: startTime.Add(4 * time.Second),
+ Instruction: `You are smarty. baz
+
+Use set-results tool to provide results of the analysis.
+It must be called exactly once before the final reply.
+Ignore results of this tool.
+`,
+ Prompt: "Prompt: baz func-output",
+ },
+ {
+ Seq: 3,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(5 * time.Second),
+ },
+ {
+ Seq: 3,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(5 * time.Second),
+ Finished: startTime.Add(6 * time.Second),
+ Thoughts: "I am thinking I need to call some tools",
+ },
+ {
+ Seq: 4,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "tool1",
+ Started: startTime.Add(7 * time.Second),
+ Args: map[string]any{
+ "ArgBar": 100,
+ "ArgFoo": "arg-foo",
+ },
+ },
+ {
+ Seq: 4,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "tool1",
+ Started: startTime.Add(7 * time.Second),
+ Finished: startTime.Add(8 * time.Second),
+ Args: map[string]any{
+ "ArgBar": 100,
+ "ArgFoo": "arg-foo",
+ },
+ Results: map[string]any{
+ "ResFoo": 200,
+ "ResString": "res-string",
+ },
+ },
+ {
+ Seq: 5,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "tool2",
+ Started: startTime.Add(9 * time.Second),
+ Args: map[string]any{
+ "ArgBaz": 101,
+ },
+ },
+ {
+ Seq: 5,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "tool2",
+ Started: startTime.Add(9 * time.Second),
+ Finished: startTime.Add(10 * time.Second),
+ Args: map[string]any{
+ "ArgBaz": 101,
+ },
+ Results: map[string]any{
+ "ResBaz": 300,
+ },
+ },
+ {
+ Seq: 6,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(11 * time.Second),
+ },
+ {
+ Seq: 6,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(11 * time.Second),
+ Finished: startTime.Add(12 * time.Second),
+ Thoughts: "Completly blank.Whatever.",
+ },
+ {
+ Seq: 7,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(13 * time.Second),
+ Args: map[string]any{
+ "AgentBar": "agent-bar",
+ "AgentFoo": 42,
+ },
+ },
+ {
+ Seq: 7,
+ Nesting: 2,
+ Type: trajectory.SpanTool,
+ Name: "set-results",
+ Started: startTime.Add(13 * time.Second),
+ Finished: startTime.Add(14 * time.Second),
+ Args: map[string]any{
+ "AgentBar": "agent-bar",
+ "AgentFoo": 42,
+ },
+ Results: map[string]any{
+ "AgentBar": "agent-bar",
+ "AgentFoo": 42,
+ },
+ },
+ {
+ Seq: 8,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(15 * time.Second),
+ },
+ {
+ Seq: 8,
+ Nesting: 2,
+ Type: trajectory.SpanLLM,
+ Name: "smarty",
+ Started: startTime.Add(15 * time.Second),
+ Finished: startTime.Add(16 * time.Second),
+ },
+ {
+ Seq: 2,
+ Nesting: 1,
+ Type: trajectory.SpanAgent,
+ Name: "smarty",
+ Started: startTime.Add(4 * time.Second),
+ Finished: startTime.Add(17 * time.Second),
+ Instruction: `You are smarty. baz
+
+Use set-results tool to provide results of the analysis.
+It must be called exactly once before the final reply.
+Ignore results of this tool.
+`,
+ Prompt: "Prompt: baz func-output",
+ Reply: "hello, world!",
+ Results: map[string]any{
+ "AgentBar": "agent-bar",
+ "AgentFoo": 42,
+ },
+ },
+ {
+ Seq: 9,
+ Nesting: 1,
+ Type: trajectory.SpanAction,
+ Name: "another-action",
+ Started: startTime.Add(18 * time.Second),
+ },
+ {
+ Seq: 9,
+ Nesting: 1,
+ Type: trajectory.SpanAction,
+ Name: "another-action",
+ Started: startTime.Add(18 * time.Second),
+ Finished: startTime.Add(19 * time.Second),
+ Results: map[string]any{
+ "OutBaz": "baz",
+ },
+ },
+ {
+ Seq: 0,
+ Nesting: 0,
+ Type: trajectory.SpanFlow,
+ Name: "test-flow",
+ Started: startTime.Add(1 * time.Second),
+ Finished: startTime.Add(20 * time.Second),
+ Results: map[string]any{
+ "AgentFoo": 42,
+ "OutBar": 142,
+ "OutBaz": "baz",
+ "OutFoo": "hello, world!",
+ },
+ },
+ }
+ onEvent := func(span *trajectory.Span) error {
+ require.NotEmpty(t, expected)
+ require.Equal(t, span, expected[0])
+ expected = expected[1:]
+ return nil
+ }
+ res, err := flows["test-flow"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ require.NoError(t, err)
+ require.Equal(t, res, map[string]any{
+ "OutFoo": "hello, world!",
+ "OutBar": 142,
+ "OutBaz": "baz",
+ "AgentFoo": 42,
+ })
+ require.Empty(t, expected)
+}
+
+// TestNoInputs checks that Execute fails with a descriptive error
+// when one of the required flow inputs is missing from the inputs map.
+func TestNoInputs(t *testing.T) {
+ type flowInputs struct {
+ InFoo int
+ InBar string
+ }
+ type flowOutputs struct {
+ }
+ inputs := map[string]any{
+ "InFoo": 10,
+ }
+ flows := make(map[string]*Flow)
+ err := register[flowInputs, flowOutputs]("test", "description", flows, []*Flow{
+ {
+ Root: NewFuncAction("func-action",
+ func(ctx *Context, args flowInputs) (flowOutputs, error) {
+ return flowOutputs{}, nil
+ }),
+ },
+ })
+ require.NoError(t, err)
+ stub := &stubContext{
+ generateContent: func(cfg *genai.GenerateContentConfig, req []*genai.Content) (
+ *genai.GenerateContentResponse, error) {
+ return nil, nil
+ },
+ }
+ ctx := context.WithValue(context.Background(), stubContextKey, stub)
+ workdir := t.TempDir()
+ cache, err := newTestCache(t, filepath.Join(workdir, "cache"), 0, stub.timeNow)
+ require.NoError(t, err)
+ onEvent := func(span *trajectory.Span) error { return nil }
+ _, err = flows["test"].Execute(ctx, "model", workdir, inputs, cache, onEvent)
+ require.Equal(t, err.Error(), "flow inputs are missing:"+
+ " field InBar is not present when converting map to aflow.flowInputs")
+}
diff --git a/pkg/aflow/func_action.go b/pkg/aflow/func_action.go
new file mode 100644
index 000000000..a54579320
--- /dev/null
+++ b/pkg/aflow/func_action.go
@@ -0,0 +1,46 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "maps"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+)
+
// NewFuncAction creates an Action that invokes fn with typed arguments
// taken from the current execution state and merges the returned results
// back into the state. Args and Results must be structs; their field names
// name the state variables that are consumed and produced.
func NewFuncAction[Args, Results any](name string, fn func(*Context, Args) (Results, error)) Action {
	return &funcAction[Args, Results]{
		name: name,
		fn:   fn,
	}
}

// funcAction is the Action implementation backing NewFuncAction.
type funcAction[Args, Results any] struct {
	name string // action name used in trajectory spans and verification errors
	fn   func(*Context, Args) (Results, error)
}
+
+func (a *funcAction[Args, Results]) execute(ctx *Context) error {
+ args, err := convertFromMap[Args](ctx.state, false)
+ if err != nil {
+ return err
+ }
+ span := &trajectory.Span{
+ Type: trajectory.SpanAction,
+ Name: a.name,
+ }
+ if err := ctx.startSpan(span); err != nil {
+ return err
+ }
+ res, fnErr := a.fn(ctx, args)
+ span.Results = convertToMap(res)
+ maps.Insert(ctx.state, maps.All(span.Results))
+ return ctx.finishSpan(span, fnErr)
+}
+
// verify checks the action's static configuration and registers its
// dataflow: Args fields must be provided by inputs or preceding actions,
// Results fields become available to subsequent actions.
func (a *funcAction[Args, Results]) verify(vctx *verifyContext) {
	vctx.requireNotEmpty(a.name, "Name", a.name)
	requireInputs[Args](vctx, a.name)
	provideOutputs[Results](vctx, a.name)
}
diff --git a/pkg/aflow/func_tool.go b/pkg/aflow/func_tool.go
new file mode 100644
index 000000000..cd069db84
--- /dev/null
+++ b/pkg/aflow/func_tool.go
@@ -0,0 +1,71 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "google.golang.org/genai"
+)
+
// NewFuncTool creates a new tool based on a custom function that an LLM agent can use.
// Name and description are important since they are passed to an LLM agent.
// Args and Results must be structs with fields annotated with a jsonschema tag
// containing a description (enforced by schemaFor), the descriptions are also
// important since they are passed to the LLM agent.
// Args are accepted from the LLM agent on the tool invocation, Results are returned
// to the LLM agent. State fields are taken from the current execution state
// (they are not exposed to the LLM agent).
func NewFuncTool[State, Args, Results any](name string, fn func(*Context, State, Args) (Results, error),
	description string) Tool {
	return &funcTool[State, Args, Results]{
		Name:        name,
		Description: description,
		Func:        fn,
	}
}

// funcTool adapts a typed Go function to the Tool interface.
type funcTool[State, Args, Results any] struct {
	Name        string // tool name exposed to the LLM
	Description string // tool description exposed to the LLM
	Func        func(*Context, State, Args) (Results, error)
}
+
+func (t *funcTool[State, Args, Results]) declaration() *genai.FunctionDeclaration {
+ return &genai.FunctionDeclaration{
+ Name: t.Name,
+ Description: t.Description,
+ ParametersJsonSchema: mustSchemaFor[Args](),
+ ResponseJsonSchema: mustSchemaFor[Results](),
+ }
+}
+
+func (t *funcTool[State, Args, Results]) execute(ctx *Context, args map[string]any) (map[string]any, error) {
+ state, err := convertFromMap[State](ctx.state, false)
+ if err != nil {
+ return nil, err
+ }
+ a, err := convertFromMap[Args](args, true)
+ if err != nil {
+ return nil, err
+ }
+ span := &trajectory.Span{
+ Type: trajectory.SpanTool,
+ Name: t.Name,
+ Args: args,
+ }
+ if err := ctx.startSpan(span); err != nil {
+ return nil, err
+ }
+ res, err := t.Func(ctx, state, a)
+ span.Results = convertToMap(res)
+ err = ctx.finishSpan(span, err)
+ return span.Results, err
+}
+
// verify checks the tool's static configuration: name/description must be
// set, Args/Results must produce valid JSON schemas (needed for the LLM
// declaration), and State fields must be provided by inputs or preceding
// actions.
func (t *funcTool[State, Args, Results]) verify(ctx *verifyContext) {
	ctx.requireNotEmpty(t.Name, "Name", t.Name)
	ctx.requireNotEmpty(t.Name, "Description", t.Description)
	requireSchema[Args](ctx, t.Name, "Args")
	requireSchema[Results](ctx, t.Name, "Results")
	requireInputs[State](ctx, t.Name)
}
diff --git a/pkg/aflow/llm_agent.go b/pkg/aflow/llm_agent.go
new file mode 100644
index 000000000..76661add6
--- /dev/null
+++ b/pkg/aflow/llm_agent.go
@@ -0,0 +1,254 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "maps"
+ "reflect"
+
+ "github.com/google/syzkaller/pkg/aflow/trajectory"
+ "google.golang.org/genai"
+)
+
// LLMAgent is an Action that runs a single LLM conversation,
// optionally letting the model call tools and produce structured outputs.
type LLMAgent struct {
	// For logging/debugging.
	Name string
	// Name of the state variable to store the final reply of the agent.
	// These names can be used in subsequent action instructions/prompts,
	// and as final workflow outputs.
	Reply string
	// Optional additional structured outputs besides the final text reply.
	// Use LLMOutputs function to create it.
	Outputs *llmOutputs
	// Value that controls the degree of randomness in token selection.
	// Lower temperatures are good for prompts that require a less open-ended or creative response,
	// while higher temperatures can lead to more diverse or creative results.
	// Must be assigned a float32 value in the range [0, 2]
	// (an int value is also accepted and normalized to float32 during verify).
	Temperature any
	// Instructions for the agent.
	// Formatted as text/template, can use "{{.Variable}}" as placeholders for dynamic content.
	// Variables can come from the workflow inputs, or from preceding actions outputs.
	Instruction string
	// Prompt for the agent. The same format as Instruction.
	Prompt string
	// Set of tools for the agent to use.
	Tools []Tool
}
+
// Tool represents a custom tool an LLMAgent can invoke.
// Use NewFuncTool to create function-based tools.
type Tool interface {
	// verify checks the tool's static configuration and dataflow.
	verify(*verifyContext)
	// declaration describes the tool to the LLM (name, description, schemas).
	declaration() *genai.FunctionDeclaration
	// execute runs the tool with the arguments provided by the LLM
	// and returns the results to be sent back to it.
	execute(*Context, map[string]any) (map[string]any, error)
}
+
// LLMOutputs creates a special tool that can be used by LLM to provide structured outputs.
// Args must be a struct with jsonschema-tagged fields; its fields become
// state variables once the LLM calls the tool.
func LLMOutputs[Args any]() *llmOutputs {
	return &llmOutputs{
		// The tool simply echoes its arguments back; callTools captures
		// the arguments as the agent's structured outputs.
		tool: NewFuncTool("set-results", func(ctx *Context, state struct{}, args Args) (Args, error) {
			return args, nil
		}, "Use this tool to provide results of the analysis."),
		provideOutputs: func(ctx *verifyContext, who string) {
			provideOutputs[Args](ctx, who)
		},
		instruction: `

Use set-results tool to provide results of the analysis.
It must be called exactly once before the final reply.
Ignore results of this tool.
`,
	}
}

// llmOutputs bundles the outputs tool with its verification hook and the
// instruction text appended to the agent's system instruction.
type llmOutputs struct {
	tool           Tool
	provideOutputs func(*verifyContext, string)
	instruction    string
}
+
+func (a *LLMAgent) execute(ctx *Context) error {
+ cfg, instruction, tools := a.config(ctx)
+ span := &trajectory.Span{
+ Type: trajectory.SpanAgent,
+ Name: a.Name,
+ Instruction: instruction,
+ Prompt: formatTemplate(a.Prompt, ctx.state),
+ }
+ if err := ctx.startSpan(span); err != nil {
+ return err
+ }
+ reply, outputs, err := a.chat(ctx, cfg, tools, span.Prompt)
+ span.Reply = reply
+ span.Results = outputs
+ return ctx.finishSpan(span, err)
+}
+
+func (a *LLMAgent) chat(ctx *Context, cfg *genai.GenerateContentConfig, tools map[string]Tool, prompt string) (
+ string, map[string]any, error) {
+ var outputs map[string]any
+ req := []*genai.Content{genai.NewContentFromText(prompt, genai.RoleUser)}
+ for {
+ reqSpan := &trajectory.Span{
+ Type: trajectory.SpanLLM,
+ Name: a.Name,
+ }
+ if err := ctx.startSpan(reqSpan); err != nil {
+ return "", nil, err
+ }
+ resp, err := ctx.generateContent(cfg, req)
+ if err != nil {
+ return "", nil, ctx.finishSpan(reqSpan, err)
+ }
+ reply, thoughts, calls, respErr := a.parseResponse(resp)
+ reqSpan.Thoughts = thoughts
+ if err := ctx.finishSpan(reqSpan, respErr); err != nil {
+ return "", nil, err
+ }
+ if len(calls) == 0 {
+ // This is the final reply.
+ if a.Outputs != nil && outputs == nil {
+ return "", nil, fmt.Errorf("LLM did not call tool to set outputs")
+ }
+ ctx.state[a.Reply] = reply
+ maps.Insert(ctx.state, maps.All(outputs))
+ return reply, outputs, nil
+ }
+ // This is not the final reply, LLM asked to execute some tools.
+ // Append the current reply, and tool responses to the next request.
+ responses, outputs1, err := a.callTools(ctx, tools, calls)
+ if err != nil {
+ return "", nil, err
+ }
+ if outputs != nil && outputs1 != nil {
+ return "", nil, fmt.Errorf("LLM called outputs tool twice")
+ }
+ outputs = outputs1
+ req = append(req, resp.Candidates[0].Content, responses)
+ }
+}
+
+func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, map[string]Tool) {
+ instruction := formatTemplate(a.Instruction, ctx.state)
+ toolList := a.Tools
+ if a.Outputs != nil {
+ instruction += a.Outputs.instruction
+ toolList = append(toolList, a.Outputs.tool)
+ }
+ toolMap := make(map[string]Tool)
+ var tools []*genai.Tool
+ for _, tool := range toolList {
+ decl := tool.declaration()
+ toolMap[decl.Name] = tool
+ tools = append(tools, &genai.Tool{
+ FunctionDeclarations: []*genai.FunctionDeclaration{decl}})
+ }
+ return &genai.GenerateContentConfig{
+ ResponseModalities: []string{"TEXT"},
+ Temperature: genai.Ptr(a.Temperature.(float32)),
+ SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
+ Tools: tools,
+ }, instruction, toolMap
+}
+
+func (a *LLMAgent) callTools(ctx *Context, tools map[string]Tool, calls []*genai.FunctionCall) (
+ *genai.Content, map[string]any, error) {
+ responses := &genai.Content{
+ Role: string(genai.RoleUser),
+ }
+ var outputs map[string]any
+ for _, call := range calls {
+ tool := tools[call.Name]
+ if tool == nil {
+ return nil, nil, fmt.Errorf("no tool %q", call.Name)
+ }
+ results, err := tool.execute(ctx, call.Args)
+ if err != nil {
+ return nil, nil, err
+ }
+ responses.Parts = append(responses.Parts, genai.NewPartFromFunctionResponse(call.Name, results))
+ responses.Parts[len(responses.Parts)-1].FunctionResponse.ID = call.ID
+ if a.Outputs != nil && tool == a.Outputs.tool {
+ outputs = results
+ }
+ }
+ return responses, outputs, nil
+}
+
// parseResponse splits a model response into its text reply, thought text,
// and requested function calls. Errors are returned for blocked/empty
// responses and for response features this package does not support yet
// (grounding, citations, inline data, etc).
func (a *LLMAgent) parseResponse(resp *genai.GenerateContentResponse) (
	reply, thoughts string, calls []*genai.FunctionCall, err error) {
	if len(resp.Candidates) == 0 || resp.Candidates[0] == nil {
		err = fmt.Errorf("empty model response")
		if resp.PromptFeedback != nil {
			// The prompt itself was rejected (e.g. by safety filters).
			err = fmt.Errorf("request blocked: %v", resp.PromptFeedback.BlockReasonMessage)
		}
		return
	}
	// NOTE(review): only the first candidate is inspected.
	candidate := resp.Candidates[0]
	if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
		err = fmt.Errorf("%v (%v)", candidate.FinishMessage, candidate.FinishReason)
		return
	}
	// We don't expect to receive these now.
	if candidate.GroundingMetadata != nil || candidate.CitationMetadata != nil ||
		candidate.LogprobsResult != nil {
		err = fmt.Errorf("unexpected reply fields (%+v)", *candidate)
		return
	}
	for _, part := range candidate.Content.Parts {
		// We don't expect to receive these now.
		if part.VideoMetadata != nil || part.InlineData != nil ||
			part.FileData != nil || part.FunctionResponse != nil ||
			part.CodeExecutionResult != nil || part.ExecutableCode != nil {
			err = fmt.Errorf("unexpected reply part (%+v)", *part)
			return
		}
		// Parts are accumulated in order: tool calls, model thoughts,
		// and plain text of the reply.
		if part.FunctionCall != nil {
			calls = append(calls, part.FunctionCall)
		} else if part.Thought {
			thoughts += part.Text
		} else {
			reply += part.Text
		}
	}
	return
}
+
// verify checks the agent's static configuration and registers its dataflow.
// As a side effect it normalizes an int Temperature value to float32
// (config later asserts that the value is a float32).
func (a *LLMAgent) verify(vctx *verifyContext) {
	vctx.requireNotEmpty(a.Name, "Name", a.Name)
	vctx.requireNotEmpty(a.Name, "Reply", a.Reply)
	if temp, ok := a.Temperature.(int); ok {
		a.Temperature = float32(temp)
	}
	if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
		vctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
	}
	// Verify dataflow. All dynamic variables must be provided by inputs,
	// or preceding actions.
	a.verifyTemplate(vctx, "Instruction", a.Instruction)
	a.verifyTemplate(vctx, "Prompt", a.Prompt)
	for _, tool := range a.Tools {
		tool.verify(vctx)
	}
	// The final reply becomes a state variable available downstream.
	vctx.provideOutput(a.Name, a.Reply, reflect.TypeFor[string](), true)
	if a.Outputs != nil {
		a.Outputs.provideOutputs(vctx, a.Name)
	}
}
+
+func (a *LLMAgent) verifyTemplate(vctx *verifyContext, what, text string) {
+ vctx.requireNotEmpty(a.Name, what, text)
+ vars := make(map[string]reflect.Type)
+ for name, state := range vctx.state {
+ vars[name] = state.typ
+ }
+ used, err := verifyTemplate(text, vars)
+ if err != nil {
+ vctx.errorf(a.Name, "%v: %v", what, err)
+ }
+ for name := range used {
+ vctx.state[name].used = true
+ }
+}
diff --git a/pkg/aflow/schema.go b/pkg/aflow/schema.go
new file mode 100644
index 000000000..e34d465ea
--- /dev/null
+++ b/pkg/aflow/schema.go
@@ -0,0 +1,102 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "encoding/json"
+ "fmt"
+ "iter"
+ "maps"
+ "reflect"
+
+ "github.com/google/jsonschema-go/jsonschema"
+)
+
+func schemaFor[T any]() (*jsonschema.Schema, error) {
+ typ := reflect.TypeFor[T]()
+ if typ.Kind() != reflect.Struct {
+ return nil, fmt.Errorf("%v is not a struct", typ.Name())
+ }
+ for _, field := range reflect.VisibleFields(typ) {
+ if field.Tag.Get("jsonschema") == "" {
+ return nil, fmt.Errorf("%v.%v does not have a jsonschema tag with description",
+ typ.Name(), field.Name)
+ }
+ }
+ schema, err := jsonschema.For[T](nil)
+ if err != nil {
+ return nil, err
+ }
+ resolved, err := schema.Resolve(nil)
+ if err != nil {
+ return nil, err
+ }
+ return resolved.Schema(), nil
+}
+
// mustSchemaFor is like schemaFor, but panics on error.
// Intended for tool declarations whose types are checked during verify.
func mustSchemaFor[T any]() *jsonschema.Schema {
	schema, err := schemaFor[T]()
	if err != nil {
		panic(err)
	}
	return schema
}
+
+func convertToMap[T any](val T) map[string]any {
+ res := make(map[string]any)
+ for name, val := range foreachField(&val) {
+ res[name] = val.Interface()
+ }
+ return res
+}
+
+// convertFromMap converts an untyped map to a struct.
+// It always ensures that all struct fields are present in the map.
+// In the strict mode it also checks that the map does not contain any other unused elements.
+func convertFromMap[T any](m map[string]any, strict bool) (T, error) {
+ m = maps.Clone(m)
+ var val T
+ for name, field := range foreachField(&val) {
+ f, ok := m[name]
+ if !ok {
+ return val, fmt.Errorf("field %v is not present when converting map to %T", name, val)
+ }
+ delete(m, name)
+ if mm, ok := f.(map[string]any); ok && field.Type() == reflect.TypeFor[json.RawMessage]() {
+ raw, err := json.Marshal(mm)
+ if err != nil {
+ return val, err
+ }
+ field.Set(reflect.ValueOf(json.RawMessage(raw)))
+ } else {
+ field.Set(reflect.ValueOf(f))
+ }
+ }
+ if strict && len(m) != 0 {
+ return val, fmt.Errorf("unused fields when converting map to %T: %v", val, m)
+ }
+ return val, nil
+}
+
+// foreachField iterates over all public fields of the struct provided in data.
+func foreachField(data any) iter.Seq2[string, reflect.Value] {
+ return func(yield func(string, reflect.Value) bool) {
+ v := reflect.ValueOf(data).Elem()
+ for _, field := range reflect.VisibleFields(v.Type()) {
+ if !yield(field.Name, v.FieldByIndex(field.Index)) {
+ break
+ }
+ }
+ }
+}
+
+func foreachFieldOf[T any]() iter.Seq2[string, reflect.Type] {
+ return func(yield func(string, reflect.Type) bool) {
+ for name, val := range foreachField(new(T)) {
+ if !yield(name, val.Type()) {
+ break
+ }
+ }
+ }
+}
diff --git a/pkg/aflow/schema_test.go b/pkg/aflow/schema_test.go
new file mode 100644
index 000000000..ac441f7d6
--- /dev/null
+++ b/pkg/aflow/schema_test.go
@@ -0,0 +1,50 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/jsonschema-go/jsonschema"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSchema(t *testing.T) {
+ type Test struct {
+ fn func() (*jsonschema.Schema, error)
+ err string
+ }
+ type structWithNoTags struct {
+ A int
+ }
+ type structWithTags struct {
+ A int `jsonschema:"aaa"`
+ B string `jsonschema:"bbb"`
+ }
+ tests := []Test{
+ {
+ fn: schemaFor[int],
+ err: "int is not a struct",
+ },
+ {
+ fn: schemaFor[structWithNoTags],
+ err: "structWithNoTags.A does not have a jsonschema tag with description",
+ },
+ {
+ fn: schemaFor[structWithTags],
+ },
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ _, err := test.fn()
+ if err != nil {
+ assert.Equal(t, err.Error(), test.err)
+ return
+ }
+ require.Empty(t, test.err)
+ })
+ }
+}
diff --git a/pkg/aflow/template.go b/pkg/aflow/template.go
new file mode 100644
index 000000000..7b0efd194
--- /dev/null
+++ b/pkg/aflow/template.go
@@ -0,0 +1,100 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+ "text/template"
+ "text/template/parse"
+)
+
+// formatTemplate formats template 'text' using the standard text/template logic.
+// Panics on any errors, but these panics shouldn't happen if verifyTemplate
+// was called for the template before.
+func formatTemplate(text string, state map[string]any) string {
+ templ, err := parseTemplate(text)
+ if err != nil {
+ panic(err)
+ }
+ w := new(bytes.Buffer)
+ if err := templ.Execute(w, state); err != nil {
+ panic(err)
+ }
+ return w.String()
+}
+
// verifyTemplate checks that the template 'text' can be executed with the variables 'vars'.
// Returns the set of variables that are actually used in the template.
func verifyTemplate(text string, vars map[string]reflect.Type) (map[string]bool, error) {
	templ, err := parseTemplate(text)
	if err != nil {
		return nil, err
	}
	// Statically collect the top-level fields referenced by the template.
	used := make(map[string]bool)
	walkTemplate(templ.Root, used, &err)
	if err != nil {
		return nil, err
	}
	vals := make(map[string]any)
	for name := range used {
		typ, ok := vars[name]
		if !ok {
			return nil, fmt.Errorf("input %v is not provided", name)
		}
		// Zero values of the declared types are enough for a dry run.
		vals[name] = reflect.Zero(typ).Interface()
	}
	// Execute once just to make sure.
	if err := templ.Execute(io.Discard, vals); err != nil {
		return nil, err
	}
	return used, nil
}
+
// walkTemplate recursively walks template nodes collecting used variables.
// It does not handle all node types, but enough to support reasonably simple templates.
// Unsupported constructs are reported through errp (first error wins, see noteError).
func walkTemplate(n parse.Node, used map[string]bool, errp *error) {
	// Branch fields such as IfNode.ElseList may be typed-nil pointers,
	// which do not compare equal to nil through the parse.Node interface,
	// hence the reflect-based check.
	// NOTE(review): an untyped nil n would make IsNil panic; callers are
	// assumed to never pass one — confirm if new call sites are added.
	if reflect.ValueOf(n).IsNil() {
		return
	}
	switch n := n.(type) {
	case *parse.ListNode:
		for _, c := range n.Nodes {
			walkTemplate(c, used, errp)
		}
	case *parse.IfNode:
		walkTemplate(n.Pipe, used, errp)
		walkTemplate(n.List, used, errp)
		walkTemplate(n.ElseList, used, errp)
	case *parse.ActionNode:
		walkTemplate(n.Pipe, used, errp)
	case *parse.PipeNode:
		for _, c := range n.Decl {
			walkTemplate(c, used, errp)
		}
		for _, c := range n.Cmds {
			walkTemplate(c, used, errp)
		}
	case *parse.CommandNode:
		for _, c := range n.Args {
			walkTemplate(c, used, errp)
		}
	case *parse.FieldNode:
		// Only simple top-level fields like {{.Var}} are supported.
		if len(n.Ident) != 1 {
			noteError(errp, "compound values are not supported: .%v", strings.Join(n.Ident, "."))
		}
		used[n.Ident[0]] = true
	case *parse.VariableNode:
		// Local template variables ($x) are not state variables; ignore.
	case *parse.TextNode:
		// Plain text uses no variables.
	default:
		noteError(errp, "unhandled node type %T", n)
	}
}
+
// parseTemplate parses prompt as a text/template that errors out on
// references to missing map keys.
func parseTemplate(prompt string) (*template.Template, error) {
	tmpl := template.New("").Option("missingkey=error")
	return tmpl.Parse(prompt)
}
diff --git a/pkg/aflow/template_test.go b/pkg/aflow/template_test.go
new file mode 100644
index 000000000..e42ddd2c3
--- /dev/null
+++ b/pkg/aflow/template_test.go
@@ -0,0 +1,68 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "maps"
+ "reflect"
+ "slices"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTemplate(t *testing.T) {
+ type Test struct {
+ template string
+ vars map[string]reflect.Type
+ used []string
+ err string
+ }
+ tests := []Test{
+ {
+ template: `just text`,
+ },
+ {
+ template: `
+ {{if .bar}}
+ {{.foo}}
+ {{end}}
+
+ {{if $local := .bar}}
+ {{$local}}
+ {{end}}
+ `,
+ vars: map[string]reflect.Type{
+ "bar": reflect.TypeFor[bool](),
+ "foo": reflect.TypeFor[int](),
+ "baz": reflect.TypeFor[int](),
+ },
+ used: []string{"bar", "foo"},
+ },
+ {
+ template: `
+ {{if .bar}}
+ {{.foo}}
+ {{end}}
+ `,
+ vars: map[string]reflect.Type{
+ "bar": reflect.TypeFor[bool](),
+ },
+ err: "input foo is not provided",
+ },
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ used, err := verifyTemplate(test.template, test.vars)
+ if err != nil {
+ assert.Equal(t, err.Error(), test.err)
+ return
+ }
+ require.Empty(t, test.err)
+ assert.ElementsMatch(t, slices.Collect(maps.Keys(used)), test.used)
+ })
+ }
+}
diff --git a/pkg/aflow/verify.go b/pkg/aflow/verify.go
new file mode 100644
index 000000000..d7ccbd124
--- /dev/null
+++ b/pkg/aflow/verify.go
@@ -0,0 +1,98 @@
+// Copyright 2025 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package aflow
+
+import (
+ "fmt"
+ "maps"
+ "reflect"
+ "slices"
+)
+
// verifyContext accumulates static verification state while walking a
// flow's actions: which actions exist, which state variables are
// provided/consumed, and the first error encountered.
type verifyContext struct {
	actions map[string]bool
	state   map[string]*varState
	err     error
}

// varState describes a single state variable: which action provides it,
// its static type, and whether any later action consumes it.
type varState struct {
	action string
	typ    reflect.Type
	used   bool
}
+
// errorf records a verification error attributed to action 'who'.
// Only the first error is kept (see noteError).
func (ctx *verifyContext) errorf(who, msg string, args ...any) {
	noteError(&ctx.err, fmt.Sprintf("action %v: %v", who, msg), args...)
}

// requireNotEmpty reports an error if a mandatory string field is empty.
func (ctx *verifyContext) requireNotEmpty(who, name, value string) {
	if value == "" {
		ctx.errorf(who, "%v must not be empty", name)
	}
}

// requireInput checks that state variable 'name' of type 'typ' is provided
// by inputs or a preceding action, and marks it as consumed.
func (ctx *verifyContext) requireInput(who, name string, typ reflect.Type) {
	state := ctx.state[name]
	if state == nil {
		ctx.errorf(who, "no input %v, available inputs: %v",
			name, slices.Collect(maps.Keys(ctx.state)))
		return
	}
	if typ != state.typ {
		ctx.errorf(who, "input %v has wrong type: want %v, has %v",
			name, typ, state.typ)
	}
	state.used = true
}
+
// provideOutput registers that action 'who' produces state variable 'name'.
// If the variable already exists: with unique=true any redefinition is an
// error; otherwise the overwrite must keep the same type, and the previous
// value must have been consumed (overwriting an unused output is reported
// against the action that produced it).
func (ctx *verifyContext) provideOutput(who, name string, typ reflect.Type, unique bool) {
	state := ctx.state[name]
	if state != nil {
		if unique {
			ctx.errorf(who, "output %v is already set", name)
		} else if typ != state.typ {
			ctx.errorf(who, "output %v changes type: %v -> %v",
				name, state.typ, typ)
		} else if !state.used {
			ctx.errorf(state.action, "output %v is unused", name)
		}
	}
	ctx.state[name] = &varState{
		action: who,
		typ:    typ,
	}
}
+
+func (ctx *verifyContext) finalize() error {
+ for name, state := range ctx.state {
+ if !state.used {
+ ctx.errorf(state.action, "output %v is unused", name)
+ }
+ }
+ return ctx.err
+}
+
// noteError records the formatted error into *errp, keeping only the
// first error ever recorded and ignoring subsequent ones.
func noteError(errp *error, msg string, args ...any) {
	if *errp != nil {
		return
	}
	*errp = fmt.Errorf(msg, args...)
}
+
// requireInputs marks every field of struct T as a required input of 'who'.
func requireInputs[T any](ctx *verifyContext, who string) {
	for name, typ := range foreachFieldOf[T]() {
		ctx.requireInput(who, name, typ)
	}
}

// provideOutputs registers every field of struct T as a unique output of 'who'.
func provideOutputs[T any](ctx *verifyContext, who string) {
	for name, typ := range foreachFieldOf[T]() {
		ctx.provideOutput(who, name, typ, true)
	}
}

// requireSchema checks that a JSON schema can be generated for T
// (needed for LLM tool declarations).
func requireSchema[T any](ctx *verifyContext, who, what string) {
	if _, err := schemaFor[T](); err != nil {
		ctx.errorf(who, "%v: %v", what, err)
	}
}