aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/csource/common.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/csource/common.go')
-rw-r--r--pkg/csource/common.go36
1 files changed, 27 insertions, 9 deletions
diff --git a/pkg/csource/common.go b/pkg/csource/common.go
index 5f0cc4221..ebc085c7a 100644
--- a/pkg/csource/common.go
+++ b/pkg/csource/common.go
@@ -25,7 +25,7 @@ const (
sandboxNamespace = "namespace"
)
-func createCommonHeader(p, mmapProg *prog.Prog, opts Options) ([]byte, error) {
+func createCommonHeader(p, mmapProg *prog.Prog, replacements map[string]string, opts Options) ([]byte, error) {
defines, err := defineList(p, mmapProg, opts)
if err != nil {
return nil, err
@@ -49,6 +49,10 @@ func createCommonHeader(p, mmapProg *prog.Prog, opts Options) ([]byte, error) {
return nil, err
}
+ for from, to := range replacements {
+ src = bytes.Replace(src, []byte("[["+from+"]]"), []byte(to), -1)
+ }
+
for from, to := range map[string]string{
"uint64": "uint64_t",
"uint32": "uint32_t",
@@ -91,6 +95,9 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) {
if opts.Repeat {
defines = append(defines, "SYZ_REPEAT")
}
+ if opts.Procs > 1 {
+ defines = append(defines, "SYZ_PROCS")
+ }
if opts.Fault {
defines = append(defines, "SYZ_FAULT_INJECTION")
}
@@ -112,6 +119,9 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) {
if opts.HandleSegv {
defines = append(defines, "SYZ_HANDLE_SEGV")
}
+ if opts.Repro {
+ defines = append(defines, "SYZ_REPRO")
+ }
for _, c := range p.Calls {
defines = append(defines, "__NR_"+c.Meta.CallName)
}
@@ -131,14 +141,22 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) {
}
func removeSystemDefines(src []byte, defines []string) ([]byte, error) {
- remove := append(defines, []string{
- "__STDC__",
- "__STDC_HOSTED__",
- "__STDC_UTF_16__",
- "__STDC_UTF_32__",
- }...)
- for _, def := range remove {
- src = bytes.Replace(src, []byte("#define "+def+" 1\n"), nil, -1)
+ remove := map[string]string{
+ "__STDC__": "1",
+ "__STDC_HOSTED__": "1",
+ "__STDC_UTF_16__": "1",
+ "__STDC_UTF_32__": "1",
+ }
+ for _, def := range defines {
+ eq := strings.IndexByte(def, '=')
+ if eq == -1 {
+ remove[def] = "1"
+ } else {
+ remove[def[:eq]] = def[eq+1:]
+ }
+ }
+ for def, val := range remove {
+ src = bytes.Replace(src, []byte("#define "+def+" "+val+"\n"), nil, -1)
}
// strip: #define __STDC_VERSION__ 201112L
for _, def := range []string{"__STDC_VERSION__"} {