aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/csource/csource.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/csource/csource.go')
-rw-r--r--pkg/csource/csource.go208
1 files changed, 61 insertions, 147 deletions
diff --git a/pkg/csource/csource.go b/pkg/csource/csource.go
index 512d9985a..66958f22b 100644
--- a/pkg/csource/csource.go
+++ b/pkg/csource/csource.go
@@ -24,7 +24,6 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) {
opts: opts,
target: p.Target,
sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
- w: new(bytes.Buffer),
calls: make(map[string]uint64),
}
@@ -43,100 +42,43 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) {
ctx.calls[c.Meta.CallName] = c.Meta.NR
}
- ctx.print("// autogenerated by syzkaller (http://github.com/google/syzkaller)\n\n")
-
- hdr, err := createCommonHeader(p, mmapProg, opts)
- if err != nil {
- return nil, err
- }
- ctx.w.Write(hdr)
- ctx.print("\n")
-
- ctx.generateSyscallDefines()
-
+ varsBuf := new(bytes.Buffer)
if len(vars) != 0 {
- ctx.printf("uint64_t r[%v] = {", len(vars))
+ fmt.Fprintf(varsBuf, "uint64 r[%v] = {", len(vars))
for i, v := range vars {
if i != 0 {
- ctx.printf(", ")
+ fmt.Fprintf(varsBuf, ", ")
}
- ctx.printf("0x%x", v)
+ fmt.Fprintf(varsBuf, "0x%x", v)
}
- ctx.printf("};\n")
- }
- needProcID := opts.Procs > 1 || opts.EnableCgroups
- for _, c := range p.Calls {
- if c.Meta.CallName == "syz_mount_image" ||
- c.Meta.CallName == "syz_read_part_table" {
- needProcID = true
- }
- }
- if needProcID {
- ctx.printf("unsigned long long procid;\n")
+ fmt.Fprintf(varsBuf, "};\n")
}
- if opts.Repeat {
- ctx.generateTestFunc(calls, len(vars) != 0, "execute_one")
- } else {
- ctx.generateTestFunc(calls, len(vars) != 0, "loop")
+ sandboxFunc := "loop();"
+ if opts.Sandbox != "" {
+ sandboxFunc = "do_sandbox_" + opts.Sandbox + "();"
}
-
- if ctx.target.OS == "akaros" && opts.Repeat {
- ctx.printf("const char* program_name;\n")
- ctx.print("int main(int argc, char** argv)\n{\n")
- } else {
- ctx.print("int main()\n{\n")
+ replacements := map[string]string{
+ "PROCS": fmt.Sprint(opts.Procs),
+ "NUM_CALLS": fmt.Sprint(len(p.Calls)),
+ "MMAP_DATA": strings.Join(mmapCalls, ""),
+ "SYSCALL_DEFINES": ctx.generateSyscallDefines(),
+ "SANDBOX_FUNC": sandboxFunc,
+ "RESULTS": varsBuf.String(),
+ "SYSCALLS": ctx.generateSyscalls(calls, len(vars) != 0),
}
- for _, c := range mmapCalls {
- ctx.printf("%s", c)
+ if !opts.Threaded && !opts.Repeat && opts.Sandbox == "" {
+ // This inlines syscalls right into main for the simplest case.
+ replacements["SANDBOX_FUNC"] = replacements["SYSCALLS"]
+ replacements["SYSCALLS"] = "unused"
}
- if ctx.target.OS == "akaros" && opts.Repeat {
- ctx.printf("\tprogram_name = argv[0];\n")
- ctx.printf("\tif (argc == 2 && strcmp(argv[1], \"child\") == 0)\n")
- ctx.printf("\t\tchild();\n")
- }
- if opts.HandleSegv {
- ctx.printf("\tinstall_segv_handler();\n")
- }
-
- if !opts.Repeat {
- if opts.UseTmpDir {
- ctx.printf("\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- } else {
- if opts.UseTmpDir {
- ctx.print("\tchar *cwd = get_current_dir_name();\n")
- }
- if opts.Procs <= 1 {
- ctx.print("\tfor (;;) {\n")
- if opts.UseTmpDir {
- ctx.print("\t\tif (chdir(cwd))\n")
- ctx.print("\t\t\tfail(\"failed to chdir\");\n")
- ctx.print("\t\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- ctx.print("\t}\n")
- } else {
- ctx.printf("\tfor (procid = 0; procid < %v; procid++) {\n", opts.Procs)
- ctx.print("\t\tif (fork() == 0) {\n")
- ctx.print("\t\t\tfor (;;) {\n")
- if opts.UseTmpDir {
- ctx.print("\t\t\t\tif (chdir(cwd))\n")
- ctx.print("\t\t\t\t\tfail(\"failed to chdir\");\n")
- ctx.print("\t\t\t\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- ctx.print("\t\t\t}\n")
- ctx.print("\t\t}\n")
- ctx.print("\t}\n")
- ctx.print("\tsleep(1000000);\n")
- }
+ result, err := createCommonHeader(p, mmapProg, replacements, opts)
+ if err != nil {
+ return nil, err
}
-
- ctx.print("\treturn 0;\n}\n")
-
- result := ctx.postProcess(ctx.w.Bytes())
+ const header = "// autogenerated by syzkaller (https://github.com/google/syzkaller)\n\n"
+ result = append([]byte(header), result...)
+ result = ctx.postProcess(result)
return result, nil
}
@@ -145,92 +87,62 @@ type context struct {
opts Options
target *prog.Target
sysTarget *targets.Target
- w *bytes.Buffer
calls map[string]uint64 // CallName -> NR
}
-func (ctx *context) print(str string) {
- ctx.w.WriteString(str)
-}
-
-func (ctx *context) printf(str string, args ...interface{}) {
- ctx.print(fmt.Sprintf(str, args...))
-}
-
-func (ctx *context) writeLoopCall() {
- if ctx.opts.Sandbox != "" {
- ctx.printf("\tdo_sandbox_%v();\n", ctx.opts.Sandbox)
- return
- }
- if ctx.opts.EnableTun {
- ctx.printf("\tinitialize_tun();\n")
- }
- if ctx.opts.EnableNetdev {
- ctx.printf("\tinitialize_netdevices();\n")
- }
- ctx.print("\tloop();\n")
-}
-
-func (ctx *context) generateTestFunc(calls []string, hasVars bool, name string) {
+func (ctx *context) generateSyscalls(calls []string, hasVars bool) string {
opts := ctx.opts
+ buf := new(bytes.Buffer)
if !opts.Threaded && !opts.Collide {
- ctx.printf("void %v()\n{\n", name)
if hasVars {
- ctx.printf("\tlong res = 0;\n")
+ fmt.Fprintf(buf, "\tlong res = 0;\n")
}
if opts.Repro {
- ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n")
+ fmt.Fprintf(buf, "\tif (write(1, \"executing program\\n\", sizeof(\"executing program\\n\") - 1)) {}\n")
}
for _, c := range calls {
- ctx.printf("%s", c)
+ fmt.Fprintf(buf, "%s", c)
}
- ctx.printf("}\n\n")
} else {
- ctx.printf("void execute_call(int call)\n{\n")
if hasVars {
- ctx.printf("\tlong res;")
+ fmt.Fprintf(buf, "\tlong res;")
}
- ctx.printf("\tswitch (call) {\n")
+ fmt.Fprintf(buf, "\tswitch (call) {\n")
for i, c := range calls {
- ctx.printf("\tcase %v:\n", i)
- ctx.printf("%s", strings.Replace(c, "\t", "\t\t", -1))
- ctx.printf("\t\tbreak;\n")
- }
- ctx.printf("\t}\n")
- ctx.printf("}\n\n")
-
- ctx.printf("void %v()\n{\n", name)
- if opts.Repro {
- ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n")
- }
- ctx.printf("\texecute(%v);\n", len(calls))
- if opts.Collide {
- ctx.printf("\tcollide = 1;\n")
- ctx.printf("\texecute(%v);\n", len(calls))
+ fmt.Fprintf(buf, "\tcase %v:\n", i)
+ fmt.Fprintf(buf, "%s", strings.Replace(c, "\t", "\t\t", -1))
+ fmt.Fprintf(buf, "\t\tbreak;\n")
}
- ctx.printf("}\n\n")
+ fmt.Fprintf(buf, "\t}\n")
}
+ return buf.String()
}
-func (ctx *context) generateSyscallDefines() {
- prefix := ctx.sysTarget.SyscallPrefix
+func (ctx *context) generateSyscallDefines() string {
+ var calls []string
for name, nr := range ctx.calls {
if !ctx.sysTarget.SyscallNumbers ||
strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) {
continue
}
- ctx.printf("#ifndef %v%v\n", prefix, name)
- ctx.printf("#define %v%v %v\n", prefix, name, nr)
- ctx.printf("#endif\n")
+ calls = append(calls, name)
+ }
+ sort.Strings(calls)
+ buf := new(bytes.Buffer)
+ prefix := ctx.sysTarget.SyscallPrefix
+ for _, name := range calls {
+ fmt.Fprintf(buf, "#ifndef %v%v\n", prefix, name)
+ fmt.Fprintf(buf, "#define %v%v %v\n", prefix, name, ctx.calls[name])
+ fmt.Fprintf(buf, "#endif\n")
}
if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 {
// This is a dirty hack.
// On 32-bit linux mmap translated to old_mmap syscall which has a different signature.
// mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here.
- ctx.printf("#undef __NR_mmap\n")
- ctx.printf("#define __NR_mmap __NR_mmap2\n")
+ fmt.Fprintf(buf, "#undef __NR_mmap\n")
+ fmt.Fprintf(buf, "#define __NR_mmap __NR_mmap2\n")
}
- ctx.printf("\n")
+ return buf.String()
}
func (ctx *context) generateProgCalls(p *prog.Prog) ([]string, []uint64, error) {
@@ -332,7 +244,7 @@ func (ctx *context) generateCalls(p prog.ExecProg) ([]string, []uint64) {
fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index)
}
for _, copyout := range call.Copyout {
- fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v_t*)0x%x);\n",
+ fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v*)0x%x);\n",
copyout.Index, copyout.Size*8, copyout.Addr)
}
if copyoutMultiple {
@@ -350,18 +262,18 @@ func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.Exec
for i, chunk := range arg.Chunks {
switch chunk.Kind {
case prog.ExecArgCsumChunkData:
- fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n",
+ fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8*)0x%x, %d));\n",
csumSeq, chunk.Value, chunk.Size)
case prog.ExecArgCsumChunkConst:
- fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n",
+ fmt.Fprintf(w, "\tuint%d csum_%d_chunk_%d = 0x%x;\n",
chunk.Size*8, csumSeq, i, chunk.Value)
- fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n",
+ fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8*)&csum_%d_chunk_%d, %d);\n",
csumSeq, csumSeq, i, chunk.Size)
default:
panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind))
}
}
- fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n",
+ fmt.Fprintf(w, "\tNONFAILING(*(uint16*)0x%x = csum_inet_digest(&csum_%d));\n",
addr, csumSeq)
}
@@ -374,7 +286,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin
if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
panic("bitfield+string format")
}
- fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v_t, 0x%x, %v, %v, %v));\n",
+ fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v, 0x%x, %v, %v, %v));\n",
arg.Size*8, copyin.Addr, ctx.constArgToStr(arg),
arg.BitfieldOffset, arg.BitfieldLength)
}
@@ -399,7 +311,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin
func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) {
switch bf {
case prog.FormatNative, prog.FormatBigEndian:
- fmt.Fprintf(w, "\tNONFAILING(*(uint%v_t*)0x%x = %v);\n", size*8, addr, val)
+ fmt.Fprintf(w, "\tNONFAILING(*(uint%v*)0x%x = %v);\n", size*8, addr, val)
case prog.FormatStrDec:
if size != 20 {
panic("bad strdec size")
@@ -506,6 +418,8 @@ func (ctx *context) hoistIncludes(result []byte) []byte {
func (ctx *context) removeEmptyLines(result []byte) []byte {
for {
newResult := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1)
+ newResult = bytes.Replace(newResult, []byte{'\n', '\n', '\t'}, []byte{'\n', '\t'}, -1)
+ newResult = bytes.Replace(newResult, []byte{'\n', '\n', ' '}, []byte{'\n', ' '}, -1)
if len(newResult) == len(result) {
return result
}