diff options
Diffstat (limited to 'pkg/csource/syscall_generation_test.go')
| -rw-r--r-- | pkg/csource/syscall_generation_test.go | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/pkg/csource/syscall_generation_test.go b/pkg/csource/syscall_generation_test.go new file mode 100644 index 000000000..c84cb703b --- /dev/null +++ b/pkg/csource/syscall_generation_test.go @@ -0,0 +1,201 @@ +// Copyright 2025 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. +package csource + +import ( + "bufio" + "flag" + "fmt" + "os" + "path" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/syzkaller/prog" + "github.com/google/syzkaller/sys/targets" + "github.com/stretchr/testify/assert" +) + +var flagUpdate = flag.Bool("update", false, "update test files accordingly to current results") + +type testData struct { + filepath string + // The input syscall description, e.g. bind$netlink(r0, &(0x7f0000514ff4)={0x10, 0x0, 0x0, 0x2ffffffff}, 0xc). + input string + calls []annotatedCall +} + +type annotatedCall struct { + comment string + syscall string +} + +func TestGenerateSyscalls(t *testing.T) { + flag.Parse() + + testCases, err := readTestCases("./testdata") + assert.NoError(t, err) + + target, err := prog.GetTarget(targets.Linux, targets.AMD64) + if err != nil { + t.Fatal(err) + } + + for _, tc := range testCases { + newData, equal := testGenerationImpl(t, tc, target) + if *flagUpdate && !equal { + t.Logf("writing updated contents to %s", tc.filepath) + err = os.WriteFile(tc.filepath, []byte(newData), 0640) + assert.NoError(t, err) + } + } +} + +func readTestCases(dir string) ([]testData, error) { + var testCases []testData + + testFiles, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + for _, testFile := range testFiles { + if testFile.IsDir() { + continue + } + + testCase, err := readTestData(path.Join(dir, testFile.Name())) + if err != nil { + return nil, err + } + testCases = append(testCases, testCase) + } + + return testCases, nil +} + +func readTestData(filepath string) (testData, error) { + var td testData + td.filepath = filepath + + file, err := os.Open(filepath) + if err != nil { + return testData{}, err + } + + scanner := bufio.NewScanner(file) + + var inputBuilder strings.Builder + for scanner.Scan() { + line := scanner.Text() + if line == "" { + break + } + inputBuilder.WriteString(line + "\n") + } + td.input = inputBuilder.String() + + var commentBuilder strings.Builder + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, commentPrefix) { + if commentBuilder.Len() > 0 { + commentBuilder.WriteString("\n") + } + commentBuilder.WriteString(line) + } else { + td.calls = append(td.calls, annotatedCall{ + comment: commentBuilder.String(), + syscall: line, + }) + commentBuilder.Reset() + } + } + + if err := scanner.Err(); err != nil { + return testData{}, err + } + + if commentBuilder.Len() != 0 { + return testData{}, fmt.Errorf("expected a syscall expression but got EOF") + } + return td, nil +} + +// Returns the generated content, and whether or not they were equal. +func testGenerationImpl(t *testing.T, test testData, target *prog.Target) (string, bool) { + p, err := target.Deserialize([]byte(test.input), prog.Strict) + if err != nil { + t.Fatal(err) + } + + // Generate the actual comments. + var actualComments []string + for _, call := range p.Calls { + comment := generateComment(call) + // Formatted comments make comparison easier. + formatted, err := Format([]byte(comment)) + if err != nil { + t.Fatal(err) + } + actualComments = append(actualComments, string(formatted)) + } + + // Minimal options as we are just testing syscall output. + opts := Options{ + Slowdown: 1, + } + ctx := &context{ + p: p, + opts: opts, + target: p.Target, + sysTarget: targets.Get(p.Target.OS, p.Target.Arch), + calls: make(map[string]uint64), + } + + // Partially replicate the flow from csource.go. + exec, err := p.SerializeForExec() + if err != nil { + t.Fatal(err) + } + decoded, err := ctx.target.DeserializeExec(exec, nil) + if err != nil { + t.Fatal(err) + } + var actualSyscalls []string + for _, execCall := range decoded.Calls { + actualSyscalls = append(actualSyscalls, ctx.fmtCallBody(execCall)) + } + + if len(actualSyscalls) != len(test.calls) || len(actualSyscalls) != len(actualComments) { + t.Fatal("Generated inconsistent syscalls or comments.") + } + + areEqual := true + for i := range actualSyscalls { + if diffSyscalls := cmp.Diff(actualSyscalls[i], test.calls[i].syscall); diffSyscalls != "" { + fmt.Print(diffSyscalls) + t.Fail() + areEqual = false + } + if diffComments := cmp.Diff(actualComments[i], test.calls[i].comment); diffComments != "" { + fmt.Print(diffComments) + t.Fail() + areEqual = false + } + } + + var outputBuilder strings.Builder + outputBuilder.WriteString(test.input + "\n") + for i := range actualSyscalls { + outputBuilder.WriteString(actualComments[i] + "\n") + outputBuilder.WriteString(actualSyscalls[i]) + // Avoid trailing newline. + if i != len(test.calls)-1 { + outputBuilder.WriteString("\n") + } + } + + return outputBuilder.String(), areEqual +} |
