aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/aflow/tool/codeeditor/codeeditor.go94
-rw-r--r--pkg/aflow/tool/codeeditor/codeeditor_test.go176
2 files changed, 252 insertions, 18 deletions
diff --git a/pkg/aflow/tool/codeeditor/codeeditor.go b/pkg/aflow/tool/codeeditor/codeeditor.go
index ce2d7afb7..750dba5d8 100644
--- a/pkg/aflow/tool/codeeditor/codeeditor.go
+++ b/pkg/aflow/tool/codeeditor/codeeditor.go
@@ -4,15 +4,21 @@
package codeeditor
import (
+ "bytes"
+ "os"
"path/filepath"
+ "slices"
"strings"
"github.com/google/syzkaller/pkg/aflow"
+ "github.com/google/syzkaller/pkg/codesearch"
"github.com/google/syzkaller/pkg/osutil"
)
var Tool = aflow.NewFuncTool("codeeditor", codeeditor, `
-The tool does one code edit to form the final patch.
+The tool does one source code edit to form the final patch by replacing full lines
+with new provided lines. If new code is empty, current lines will be deleted.
+Provide full lines of code including new line characters.
The tool should be called mutiple times to do all required changes one-by-one,
but avoid changing the same lines multiple times.
Note: You will not see your edits via the codesearch tool.
@@ -24,9 +30,9 @@ type state struct {
}
type args struct {
- SourceFile string `jsonschema:"Full source file path."`
- CurrentCode string `jsonschema:"The current code to replace verbatim with new lines, but without line numbers."`
- NewCode string `jsonschema:"New code to replace the current code snippet."`
+ SourceFile string `jsonschema:"Full source file path to edit."`
+ CurrentCode string `jsonschema:"The current code lines to be replaced."`
+ NewCode string `jsonschema:"New code lines to replace the current code lines."`
}
func codeeditor(ctx *aflow.Context, state state, args args) (struct{}, error) {
@@ -34,17 +40,81 @@ func codeeditor(ctx *aflow.Context, state state, args args) (struct{}, error) {
return struct{}{}, aflow.BadCallError("SourceFile %q is outside of the source tree", args.SourceFile)
}
file := filepath.Join(state.KernelScratchSrc, args.SourceFile)
- if !osutil.IsExist(file) {
+ // Also filter out non-source files (e.g. anything under .git):
+ // the LLM has not seen them and should not be modifying them.
+ if !osutil.IsExist(file) || !codesearch.IsSourceFile(file) {
return struct{}{}, aflow.BadCallError("SourceFile %q does not exist", args.SourceFile)
}
if strings.TrimSpace(args.CurrentCode) == "" {
return struct{}{}, aflow.BadCallError("CurrentCode snippet is empty")
}
- // If SourceFile is incorrect, or CurrentCode is not matched, return aflow.BadCallError
- // with an explanation. Say that it needs to increase context if CurrentCode is not matched.
- // Try to do as fuzzy match for CurrentCode as possible (strip line numbers,
- // ignore white-spaces, etc).
- // Should we accept a reference line number, or function name to disambiguate in the case
- // of multiple matches?
- return struct{}{}, nil
+ fileData, err := os.ReadFile(file)
+ if err != nil {
+ return struct{}{}, err
+ }
+ if len(fileData) == 0 || fileData[len(fileData)-1] != '\n' {
+ // Generally shouldn't happen, but just in case.
+ fileData = append(fileData, '\n')
+ }
+ if args.CurrentCode[len(args.CurrentCode)-1] != '\n' {
+ args.CurrentCode += "\n"
+ }
+ if args.NewCode != "" && args.NewCode[len(args.NewCode)-1] != '\n' {
+ args.NewCode += "\n"
+ }
+ lines := slices.Collect(bytes.Lines(fileData))
+ src := slices.Collect(bytes.Lines([]byte(args.CurrentCode)))
+ dst := slices.Collect(bytes.Lines([]byte(args.NewCode)))
+ // First, try to match as is. If that fails, try a more permissive matching
+ // that ignores whitespaces, empty lines, etc.
+ newLines, matches := replace(lines, src, dst, false)
+ if matches == 0 {
+ newLines, matches = replace(lines, src, dst, true)
+ }
+ if matches == 0 {
+ return struct{}{}, aflow.BadCallError("CurrentCode snippet does not match anything in the source file," +
+ " provide more precise CurrentCode snippet")
+ }
+ if matches > 1 {
+ return struct{}{}, aflow.BadCallError("CurrentCode snippet matched %v places,"+
+ " increase context in CurrentCode to avoid ambiguity", matches)
+ }
+ err = osutil.WriteFile(file, slices.Concat(newLines...))
+ return struct{}{}, err
+}
+
// replace substitutes every occurrence of the src line sequence in lines with
// the dst lines.
//
// All three slices hold whole lines including their trailing '\n'. With
// fuzzy == false an occurrence is a run of consecutive file lines that are
// byte-equal to src. With fuzzy == true leading/trailing whitespace of every
// line is ignored and blank lines are skipped on both sides.
//
// It returns the rewritten file lines and the number of occurrences that were
// replaced. File lines that are not part of an occurrence are copied through
// unchanged. An empty src (or, in fuzzy mode, a src consisting only of blank
// lines) never matches.
func replace(lines, src, dst [][]byte, fuzzy bool) (newLines [][]byte, matches int) {
	for i := 0; i < len(lines); i++ {
		end, ok := matchAt(lines, src, i, fuzzy)
		if !ok {
			newLines = append(newLines, lines[i])
			continue
		}
		matches++
		newLines = append(newLines, dst...)
		// Resume right after the last consumed file line (the loop adds 1).
		i = end - 1
	}
	return newLines, matches
}

// matchAt reports whether src matches lines starting exactly at lines[start].
// On success end is the index one past the last file line consumed by the match.
func matchAt(lines, src [][]byte, start int, fuzzy bool) (end int, ok bool) {
	if fuzzy && len(bytes.TrimSpace(lines[start])) == 0 {
		// Never anchor a fuzzy match on a blank file line: the match would
		// swallow the blank lines that precede the real snippet. The same
		// match is found later, anchored on its first non-blank line.
		return 0, false
	}
	li, si := start, 0
	for li < len(lines) && si < len(src) {
		l, s := lines[li], src[si]
		if fuzzy {
			// Ignore whitespaces and empty lines.
			l, s = bytes.TrimSpace(l), bytes.TrimSpace(s)
			// Potentially we can remove line numbers from s here if they are present,
			// or use them to disambiguate in the case of multiple matches.
			if len(s) == 0 {
				si++
				continue
			}
			if len(l) == 0 {
				li++
				continue
			}
		}
		if !bytes.Equal(l, s) {
			return 0, false
		}
		li++
		si++
	}
	if fuzzy {
		// The file may run out while src still has trailing blank lines;
		// those are insignificant in fuzzy mode.
		for si < len(src) && len(bytes.TrimSpace(src[si])) == 0 {
			si++
		}
	}
	// li == start means no file line was consumed (empty or all-blank src).
	// Treating that as a match would make the caller loop forever.
	if si != len(src) || li == start {
		return 0, false
	}
	return li, true
}
diff --git a/pkg/aflow/tool/codeeditor/codeeditor_test.go b/pkg/aflow/tool/codeeditor/codeeditor_test.go
index 4ba556f1b..06a97d7af 100644
--- a/pkg/aflow/tool/codeeditor/codeeditor_test.go
+++ b/pkg/aflow/tool/codeeditor/codeeditor_test.go
@@ -4,6 +4,8 @@
package codeeditor
import (
+ "fmt"
+ "os"
"path/filepath"
"testing"
@@ -15,7 +17,7 @@ import (
func TestCodeeditorEscapingPath(t *testing.T) {
aflow.TestTool(t, Tool,
state{
- KernelScratchSrc: "whatever",
+ KernelScratchSrc: t.TempDir(),
},
args{
SourceFile: "../../passwd",
@@ -38,25 +40,187 @@ func TestCodeeditorMissingPath(t *testing.T) {
)
}
+func TestCodeeditorNonSourceFile(t *testing.T) {
+ dir := writeTestFile(t, "src", "data")
+ aflow.TestTool(t, Tool,
+ state{
+ KernelScratchSrc: dir,
+ },
+ args{
+ SourceFile: "src",
+ },
+ struct{}{},
+ `SourceFile "src" does not exist`,
+ )
+}
+
func TestCodeeditorEmptyCurrentCode(t *testing.T) {
- dir := writeTestFile(t, "foo", "data")
+ dir := writeTestFile(t, "src.c", "data")
aflow.TestTool(t, Tool,
state{
KernelScratchSrc: dir,
},
args{
- SourceFile: "foo",
+ SourceFile: "src.c",
},
struct{}{},
`CurrentCode snippet is empty`,
)
}
+func TestCodeeditorNoMatches(t *testing.T) {
+ dir := writeTestFile(t, "src.c", "foo")
+ aflow.TestTool(t, Tool,
+ state{
+ KernelScratchSrc: dir,
+ },
+ args{
+ SourceFile: "src.c",
+ CurrentCode: "foobar",
+ },
+ struct{}{},
+ `CurrentCode snippet does not match anything in the source file, provide more precise CurrentCode snippet`,
+ )
+}
+
+func TestCodeeditorMultipleMatches(t *testing.T) {
+ dir := writeTestFile(t, "src.c", `
+linefoo
+bar
+foo
+bar
+foo
+fooline
+foo`)
+ aflow.TestTool(t, Tool,
+ state{
+ KernelScratchSrc: dir,
+ },
+ args{
+ SourceFile: "src.c",
+ CurrentCode: "foo",
+ },
+ struct{}{},
+ `CurrentCode snippet matched 3 places, increase context in CurrentCode to avoid ambiguity`,
+ )
+}
+
+func TestCodeeditorReplacement(t *testing.T) {
+ type Test struct {
+ curFile string
+ curCode string
+ newCode string
+ newFile string
+ }
+ tests := []Test{
+ {
+ curFile: `
+line0
+line1
+lineee2
+lin3
+last line
+`,
+ curCode: `line1
+lineee2
+lin3`,
+ newCode: `replaced line`,
+ newFile: `
+line0
+replaced line
+last line
+`,
+ },
+ {
+ curFile: `
+line0
+line1
+last line
+`,
+ curCode: `line1
+`,
+ newCode: `replaced line 1
+replaced line 2
+replaced line 3`,
+ newFile: `
+line0
+replaced line 1
+replaced line 2
+replaced line 3
+last line
+`,
+ },
+ {
+ curFile: `
+line0
+line1
+line2
+`,
+ curCode: `line2
+`,
+ newCode: ``,
+ newFile: `
+line0
+line1
+`,
+ },
+ {
+ curFile: `that's it`,
+ curCode: `that's it`,
+ newCode: `that's that`,
+ newFile: `that's that
+`,
+ },
+ {
+ curFile: `
+ line0
+ line1
+
+ line2
+line3
+
+line4
+`,
+ curCode: `
+line1
+ line2
+
+
+ line3 `,
+ newCode: ` replacement`,
+ newFile: `
+ line0
+ replacement
+
+line4
+`,
+ },
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ const filename = "src.c"
+ dir := writeTestFile(t, filename, test.curFile)
+ aflow.TestTool(t, Tool,
+ state{
+ KernelScratchSrc: dir,
+ },
+ args{
+ SourceFile: filename,
+ CurrentCode: test.curCode,
+ NewCode: test.newCode,
+ },
+ struct{}{},
+ "")
+ data, err := os.ReadFile(filepath.Join(dir, filename))
+ require.NoError(t, err)
+ require.Equal(t, test.newFile, string(data))
+ })
+ }
+}
+
func writeTestFile(t *testing.T, filename, data string) string {
dir := t.TempDir()
- if err := osutil.WriteFile(filepath.Join(dir, filename), []byte(data)); err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, osutil.WriteFile(filepath.Join(dir, filename), []byte(data)))
return dir
}