diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2024-05-29 16:51:01 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2024-06-04 14:58:31 +0000 |
| commit | e1e2c66edd2e6bbef9c942acf1f59719c482c0d9 (patch) | |
| tree | 614aa21c3f4bccdd7e84b1cba4b61a15f136d82b /tools/syz-gemini-seed | |
| parent | ff43a057f559daec032a1e8e18791e7e05c6676e (diff) | |
tools/syz-gemini-seed: add tool
syz-gemini-seed generates program seeds based on existing
programs in the corpus using Gemini API.
Diffstat (limited to 'tools/syz-gemini-seed')
| -rw-r--r-- | tools/syz-gemini-seed/gemini-seed.go | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/tools/syz-gemini-seed/gemini-seed.go b/tools/syz-gemini-seed/gemini-seed.go new file mode 100644 index 000000000..09d5e06ca --- /dev/null +++ b/tools/syz-gemini-seed/gemini-seed.go @@ -0,0 +1,110 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +// syz-gemini-seed generates program seeds based on existing programs in the corpus using Gemini API. +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "runtime" + + "github.com/google/generative-ai-go/genai" + "github.com/google/syzkaller/pkg/db" + "github.com/google/syzkaller/pkg/tool" + "github.com/google/syzkaller/prog" + _ "github.com/google/syzkaller/sys" + "google.golang.org/api/option" +) + +func main() { + var ( + flagOS = flag.String("os", runtime.GOOS, "target OS") + flagArch = flag.String("arch", runtime.GOARCH, "target arch") + flagCorpus = flag.String("corpus", "", "wxisting corpus.db file to use as examples") + flagCount = flag.Int("count", 1, "number of programs to generate") + flagAPIKey = flag.String("key", "", "gemini API key to use") + ) + tool.Init() + + target, err := prog.GetTarget(*flagOS, *flagArch) + if err != nil { + tool.Failf("failed to find target: %v", err) + } + + db, err := db.Open(*flagCorpus, false) + if err != nil { + tool.Failf("failed to open database: %v", err) + } + + ctx := context.Background() + client, err := genai.NewClient(ctx, option.WithAPIKey(*flagAPIKey)) + if err != nil { + tool.Fail(err) + } + defer client.Close() + + for i := 0; i < *flagCount; i++ { + model := client.GenerativeModel("gemini-1.5-pro") + model.SetTemperature(0.9) + // This does not work (fails with "Only one candidate can be specified"). + // model.SetCandidateCount(3) + // TODO: tune TopP/TopK. + // model.SetTopP(0.5) + // model.SetTopK(20) + // TODO: do we need any system instructions? + // model.SystemInstruction = &genai.Content{ + // Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")}, + // } + + // In some cases it thinks it generates unsafe content, so disable safety. + // TODO: this fails with some cryptic error. + if false { + for cat := genai.HarmCategoryDerogatory; cat <= genai.HarmCategoryDangerousContent; cat++ { + model.SafetySettings = append(model.SafetySettings, &genai.SafetySetting{ + Category: cat, + Threshold: genai.HarmBlockNone, + }) + } + } + + prompt := new(bytes.Buffer) + prompt.WriteString("Below are examples of test programs in a special notation.\n\n") + // TODO: select a subset of related programs (using the same syscall). + n := 0 + for _, rec := range db.Records { + prompt.WriteString("\n\nHere is an example:\n\n") + prompt.Write(rec.Val) + n++ + if len(prompt.Bytes()) > 50<<10 || n >= 20 { + break + } + } + prompt.WriteString("\n\nPlease generate a similar but different test program with 5 lines.\n") + prompt.WriteString("Output just the program.\n") + resp, err := model.GenerateContent(ctx, genai.Text(prompt.String())) + if err != nil { + tool.Fail(err) + } + + for _, cand := range resp.Candidates { + reply := new(bytes.Buffer) + if cand.Content != nil { + for _, part := range cand.Content.Parts { + if text, ok := part.(genai.Text); ok { + reply.WriteString(string(text)) + } + } + } + fmt.Printf("REPLY:\n%s\n\n", reply) + p, err := target.Deserialize(reply.Bytes(), prog.NonStrict) + if err != nil { + fmt.Printf("failed to parse: %v\n\n", err) + } else { + fmt.Printf("PARSED:\n%s\n\n", p.Serialize()) + } + } + } +} |
