aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/csource/syscall_generation_test.go
blob: c84cb703bcb93cfd8c7c2e42619429a71e4679d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Copyright 2025 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package csource

import (
	"bufio"
	"flag"
	"fmt"
	"os"
	"path"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/syzkaller/prog"
	"github.com/google/syzkaller/sys/targets"
	"github.com/stretchr/testify/assert"
)

// flagUpdate asks TestGenerateSyscalls to rewrite every test file whose stored
// expectations differ from the currently generated output.
var flagUpdate = flag.Bool("update", false, "update test files accordingly to current results")

type (
	// testData is a single parsed file from the testdata directory.
	testData struct {
		// Path of the file this test case was read from (rewritten on -update).
		filepath string
		// Serialized program used as input: every line before the first blank
		// line of the file, e.g. bind$netlink(r0, &(0x7f0000514ff4)={0x10, 0x0, 0x0, 0x2ffffffff}, 0xc).
		input string
		// Expected output, in the order the calls appear in the program.
		calls []annotatedCall
	}

	// annotatedCall is the expected generated output for one call.
	annotatedCall struct {
		// Expected comment lines preceding the call, joined with "\n".
		comment string
		// Expected C syscall expression.
		syscall string
	}
)

func TestGenerateSyscalls(t *testing.T) {
	flag.Parse()

	testCases, err := readTestCases("./testdata")
	assert.NoError(t, err)

	target, err := prog.GetTarget(targets.Linux, targets.AMD64)
	if err != nil {
		t.Fatal(err)
	}

	for _, tc := range testCases {
		newData, equal := testGenerationImpl(t, tc, target)
		if *flagUpdate && !equal {
			t.Logf("writing updated contents to %s", tc.filepath)
			err = os.WriteFile(tc.filepath, []byte(newData), 0640)
			assert.NoError(t, err)
		}
	}
}

// readTestCases parses every non-directory entry of dir as a test file, in the
// order returned by os.ReadDir, and gives up on the first error it meets.
func readTestCases(dir string) ([]testData, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	var cases []testData
	for _, entry := range entries {
		// Subdirectories are not test files.
		if entry.IsDir() {
			continue
		}
		tc, err := readTestData(path.Join(dir, entry.Name()))
		if err != nil {
			return nil, err
		}
		cases = append(cases, tc)
	}
	return cases, nil
}

// readTestData parses a single test file.
//
// The file has two sections separated by its first empty line:
//   - the input program: every line up to the blank line, stored in
//     td.input with each line terminated by "\n";
//   - the expectations: for each call, zero or more lines starting with
//     commentPrefix (joined with "\n" into the call's comment), followed by
//     one line holding the expected syscall expression.
//
// It returns an error if the file cannot be read, or if comment lines at the
// end of the file are not followed by a syscall line. Errors produced while
// parsing are prefixed with the file path.
func readTestData(filepath string) (testData, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return testData{}, err
	}
	// The file is only read, so a Close error carries no information worth
	// reporting; the important part is not leaking the descriptor.
	defer file.Close()

	td := testData{filepath: filepath}
	scanner := bufio.NewScanner(file)

	// Input section: stops at the first empty line (or at EOF).
	var inputBuilder strings.Builder
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			break
		}
		inputBuilder.WriteString(line)
		inputBuilder.WriteString("\n")
	}
	td.input = inputBuilder.String()

	// Expectation section: comment lines accumulate until the next
	// non-comment line, which closes the current call.
	var commentBuilder strings.Builder
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, commentPrefix) {
			if commentBuilder.Len() > 0 {
				commentBuilder.WriteString("\n")
			}
			commentBuilder.WriteString(line)
			continue
		}
		td.calls = append(td.calls, annotatedCall{
			comment: commentBuilder.String(),
			syscall: line,
		})
		commentBuilder.Reset()
	}

	if err := scanner.Err(); err != nil {
		return testData{}, fmt.Errorf("%s: %w", filepath, err)
	}
	if commentBuilder.Len() != 0 {
		return testData{}, fmt.Errorf("%s: expected a syscall expression but got EOF", filepath)
	}
	return td, nil
}

// testGenerationImpl deserializes test.input, generates the formatted comment
// and the C syscall expression for every call, and compares them with
// test.calls. Mismatches are reported with t.Errorf; failures that make the
// output meaningless (bad input, formatting or exec serialization errors)
// abort with t.Fatal.
//
// It returns the file contents that would make the test pass, and whether the
// generated output matched the expectations. The contents are the input, a
// blank line, then each call's comment line(s) followed by its syscall line,
// with no trailing newline.
func testGenerationImpl(t *testing.T, test testData, target *prog.Target) (string, bool) {
	t.Helper()

	p, err := target.Deserialize([]byte(test.input), prog.Strict)
	if err != nil {
		t.Fatal(err)
	}

	// Generate the actual comments.
	var actualComments []string
	for _, call := range p.Calls {
		comment := generateComment(call)
		// Formatted comments make comparison easier.
		formatted, err := Format([]byte(comment))
		if err != nil {
			t.Fatal(err)
		}
		actualComments = append(actualComments, string(formatted))
	}

	// Minimal options as we are just testing syscall output.
	opts := Options{
		Slowdown: 1,
	}
	ctx := &context{
		p:         p,
		opts:      opts,
		target:    p.Target,
		sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
		calls:     make(map[string]uint64),
	}

	// Partially replicate the flow from csource.go.
	exec, err := p.SerializeForExec()
	if err != nil {
		t.Fatal(err)
	}
	decoded, err := ctx.target.DeserializeExec(exec, nil)
	if err != nil {
		t.Fatal(err)
	}
	var actualSyscalls []string
	for _, execCall := range decoded.Calls {
		actualSyscalls = append(actualSyscalls, ctx.fmtCallBody(execCall))
	}

	// Both lists are derived from the same program; if their lengths differ,
	// the output cannot be paired up, so abort.
	if len(actualSyscalls) != len(actualComments) {
		t.Fatalf("generated %d syscalls but %d comments", len(actualSyscalls), len(actualComments))
	}

	// A wrong number of expected calls is an ordinary mismatch rather than a
	// fatal error. The regenerated contents are still returned, so -update can
	// fix the file, or fill in one that so far only contains an input program.
	areEqual := true
	if len(actualSyscalls) != len(test.calls) {
		t.Errorf("%s: generated %d calls, test file expects %d",
			test.filepath, len(actualSyscalls), len(test.calls))
		areEqual = false
	}
	for i := 0; i < len(actualSyscalls) && i < len(test.calls); i++ {
		if diff := cmp.Diff(test.calls[i].syscall, actualSyscalls[i]); diff != "" {
			t.Errorf("%s: call %d syscall mismatch (-want +got):\n%s", test.filepath, i, diff)
			areEqual = false
		}
		if diff := cmp.Diff(test.calls[i].comment, actualComments[i]); diff != "" {
			t.Errorf("%s: call %d comment mismatch (-want +got):\n%s", test.filepath, i, diff)
			areEqual = false
		}
	}

	var outputBuilder strings.Builder
	outputBuilder.WriteString(test.input + "\n")
	for i := range actualSyscalls {
		// Newlines go between calls only, so the file has no trailing newline.
		if i > 0 {
			outputBuilder.WriteString("\n")
		}
		outputBuilder.WriteString(actualComments[i] + "\n")
		outputBuilder.WriteString(actualSyscalls[i])
	}

	return outputBuilder.String(), areEqual
}