diff options
Diffstat (limited to 'pkg/csource/common.go')
| -rw-r--r-- | pkg/csource/common.go | 36 |
1 files changed, 27 insertions, 9 deletions
diff --git a/pkg/csource/common.go b/pkg/csource/common.go index 5f0cc4221..ebc085c7a 100644 --- a/pkg/csource/common.go +++ b/pkg/csource/common.go @@ -25,7 +25,7 @@ const ( sandboxNamespace = "namespace" ) -func createCommonHeader(p, mmapProg *prog.Prog, opts Options) ([]byte, error) { +func createCommonHeader(p, mmapProg *prog.Prog, replacements map[string]string, opts Options) ([]byte, error) { defines, err := defineList(p, mmapProg, opts) if err != nil { return nil, err @@ -49,6 +49,10 @@ func createCommonHeader(p, mmapProg *prog.Prog, opts Options) ([]byte, error) { return nil, err } + for from, to := range replacements { + src = bytes.Replace(src, []byte("[["+from+"]]"), []byte(to), -1) + } + for from, to := range map[string]string{ "uint64": "uint64_t", "uint32": "uint32_t", @@ -91,6 +95,9 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) { if opts.Repeat { defines = append(defines, "SYZ_REPEAT") } + if opts.Procs > 1 { + defines = append(defines, "SYZ_PROCS") + } if opts.Fault { defines = append(defines, "SYZ_FAULT_INJECTION") } @@ -112,6 +119,9 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) { if opts.HandleSegv { defines = append(defines, "SYZ_HANDLE_SEGV") } + if opts.Repro { + defines = append(defines, "SYZ_REPRO") + } for _, c := range p.Calls { defines = append(defines, "__NR_"+c.Meta.CallName) } @@ -131,14 +141,22 @@ func defineList(p, mmapProg *prog.Prog, opts Options) ([]string, error) { } func removeSystemDefines(src []byte, defines []string) ([]byte, error) { - remove := append(defines, []string{ - "__STDC__", - "__STDC_HOSTED__", - "__STDC_UTF_16__", - "__STDC_UTF_32__", - }...) - for _, def := range remove { - src = bytes.Replace(src, []byte("#define "+def+" 1\n"), nil, -1) + remove := map[string]string{ + "__STDC__": "1", + "__STDC_HOSTED__": "1", + "__STDC_UTF_16__": "1", + "__STDC_UTF_32__": "1", + } + for _, def := range defines { + eq := strings.IndexByte(def, '=') + if eq == -1 { + remove[def] = "1" + } else { + remove[def[:eq]] = def[eq+1:] + } + } + for def, val := range remove { + src = bytes.Replace(src, []byte("#define "+def+" "+val+"\n"), nil, -1) } // strip: #define __STDC_VERSION__ 201112L for _, def := range []string{"__STDC_VERSION__"} { |
