diff options
Diffstat (limited to 'pkg/csource/csource.go')
| -rw-r--r-- | pkg/csource/csource.go | 208 |
1 files changed, 61 insertions, 147 deletions
diff --git a/pkg/csource/csource.go b/pkg/csource/csource.go index 512d9985a..66958f22b 100644 --- a/pkg/csource/csource.go +++ b/pkg/csource/csource.go @@ -24,7 +24,6 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) { opts: opts, target: p.Target, sysTarget: targets.Get(p.Target.OS, p.Target.Arch), - w: new(bytes.Buffer), calls: make(map[string]uint64), } @@ -43,100 +42,43 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) { ctx.calls[c.Meta.CallName] = c.Meta.NR } - ctx.print("// autogenerated by syzkaller (http://github.com/google/syzkaller)\n\n") - - hdr, err := createCommonHeader(p, mmapProg, opts) - if err != nil { - return nil, err - } - ctx.w.Write(hdr) - ctx.print("\n") - - ctx.generateSyscallDefines() - + varsBuf := new(bytes.Buffer) if len(vars) != 0 { - ctx.printf("uint64_t r[%v] = {", len(vars)) + fmt.Fprintf(varsBuf, "uint64 r[%v] = {", len(vars)) for i, v := range vars { if i != 0 { - ctx.printf(", ") + fmt.Fprintf(varsBuf, ", ") } - ctx.printf("0x%x", v) + fmt.Fprintf(varsBuf, "0x%x", v) } - ctx.printf("};\n") - } - needProcID := opts.Procs > 1 || opts.EnableCgroups - for _, c := range p.Calls { - if c.Meta.CallName == "syz_mount_image" || - c.Meta.CallName == "syz_read_part_table" { - needProcID = true - } - } - if needProcID { - ctx.printf("unsigned long long procid;\n") + fmt.Fprintf(varsBuf, "};\n") } - if opts.Repeat { - ctx.generateTestFunc(calls, len(vars) != 0, "execute_one") - } else { - ctx.generateTestFunc(calls, len(vars) != 0, "loop") + sandboxFunc := "loop();" + if opts.Sandbox != "" { + sandboxFunc = "do_sandbox_" + opts.Sandbox + "();" } - - if ctx.target.OS == "akaros" && opts.Repeat { - ctx.printf("const char* program_name;\n") - ctx.print("int main(int argc, char** argv)\n{\n") - } else { - ctx.print("int main()\n{\n") + replacements := map[string]string{ + "PROCS": fmt.Sprint(opts.Procs), + "NUM_CALLS": fmt.Sprint(len(p.Calls)), + "MMAP_DATA": strings.Join(mmapCalls, ""), + "SYSCALL_DEFINES": ctx.generateSyscallDefines(), + "SANDBOX_FUNC": sandboxFunc, + "RESULTS": varsBuf.String(), + "SYSCALLS": ctx.generateSyscalls(calls, len(vars) != 0), } - for _, c := range mmapCalls { - ctx.printf("%s", c) + if !opts.Threaded && !opts.Repeat && opts.Sandbox == "" { + // This inlines syscalls right into main for the simplest case. + replacements["SANDBOX_FUNC"] = replacements["SYSCALLS"] + replacements["SYSCALLS"] = "unused" } - if ctx.target.OS == "akaros" && opts.Repeat { - ctx.printf("\tprogram_name = argv[0];\n") - ctx.printf("\tif (argc == 2 && strcmp(argv[1], \"child\") == 0)\n") - ctx.printf("\t\tchild();\n") - } - if opts.HandleSegv { - ctx.printf("\tinstall_segv_handler();\n") - } - - if !opts.Repeat { - if opts.UseTmpDir { - ctx.printf("\tuse_temporary_dir();\n") - } - ctx.writeLoopCall() - } else { - if opts.UseTmpDir { - ctx.print("\tchar *cwd = get_current_dir_name();\n") - } - if opts.Procs <= 1 { - ctx.print("\tfor (;;) {\n") - if opts.UseTmpDir { - ctx.print("\t\tif (chdir(cwd))\n") - ctx.print("\t\t\tfail(\"failed to chdir\");\n") - ctx.print("\t\tuse_temporary_dir();\n") - } - ctx.writeLoopCall() - ctx.print("\t}\n") - } else { - ctx.printf("\tfor (procid = 0; procid < %v; procid++) {\n", opts.Procs) - ctx.print("\t\tif (fork() == 0) {\n") - ctx.print("\t\t\tfor (;;) {\n") - if opts.UseTmpDir { - ctx.print("\t\t\t\tif (chdir(cwd))\n") - ctx.print("\t\t\t\t\tfail(\"failed to chdir\");\n") - ctx.print("\t\t\t\tuse_temporary_dir();\n") - } - ctx.writeLoopCall() - ctx.print("\t\t\t}\n") - ctx.print("\t\t}\n") - ctx.print("\t}\n") - ctx.print("\tsleep(1000000);\n") - } + result, err := createCommonHeader(p, mmapProg, replacements, opts) + if err != nil { + return nil, err } - - ctx.print("\treturn 0;\n}\n") - - result := ctx.postProcess(ctx.w.Bytes()) + const header = "// autogenerated by syzkaller (https://github.com/google/syzkaller)\n\n" + result = append([]byte(header), result...) + result = ctx.postProcess(result) return result, nil } @@ -145,92 +87,62 @@ type context struct { opts Options target *prog.Target sysTarget *targets.Target - w *bytes.Buffer calls map[string]uint64 // CallName -> NR } -func (ctx *context) print(str string) { - ctx.w.WriteString(str) -} - -func (ctx *context) printf(str string, args ...interface{}) { - ctx.print(fmt.Sprintf(str, args...)) -} - -func (ctx *context) writeLoopCall() { - if ctx.opts.Sandbox != "" { - ctx.printf("\tdo_sandbox_%v();\n", ctx.opts.Sandbox) - return - } - if ctx.opts.EnableTun { - ctx.printf("\tinitialize_tun();\n") - } - if ctx.opts.EnableNetdev { - ctx.printf("\tinitialize_netdevices();\n") - } - ctx.print("\tloop();\n") -} - -func (ctx *context) generateTestFunc(calls []string, hasVars bool, name string) { +func (ctx *context) generateSyscalls(calls []string, hasVars bool) string { opts := ctx.opts + buf := new(bytes.Buffer) if !opts.Threaded && !opts.Collide { - ctx.printf("void %v()\n{\n", name) if hasVars { - ctx.printf("\tlong res = 0;\n") + fmt.Fprintf(buf, "\tlong res = 0;\n") } if opts.Repro { - ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n") + fmt.Fprintf(buf, "\tif (write(1, \"executing program\\n\", sizeof(\"executing program\\n\") - 1)) {}\n") } for _, c := range calls { - ctx.printf("%s", c) + fmt.Fprintf(buf, "%s", c) } - ctx.printf("}\n\n") } else { - ctx.printf("void execute_call(int call)\n{\n") if hasVars { - ctx.printf("\tlong res;") + fmt.Fprintf(buf, "\tlong res;") } - ctx.printf("\tswitch (call) {\n") + fmt.Fprintf(buf, "\tswitch (call) {\n") for i, c := range calls { - ctx.printf("\tcase %v:\n", i) - ctx.printf("%s", strings.Replace(c, "\t", "\t\t", -1)) - ctx.printf("\t\tbreak;\n") - } - ctx.printf("\t}\n") - ctx.printf("}\n\n") - - ctx.printf("void %v()\n{\n", name) - if opts.Repro { - ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n") - } - ctx.printf("\texecute(%v);\n", len(calls)) - if opts.Collide { - ctx.printf("\tcollide = 1;\n") - ctx.printf("\texecute(%v);\n", len(calls)) + fmt.Fprintf(buf, "\tcase %v:\n", i) + fmt.Fprintf(buf, "%s", strings.Replace(c, "\t", "\t\t", -1)) + fmt.Fprintf(buf, "\t\tbreak;\n") } - ctx.printf("}\n\n") + fmt.Fprintf(buf, "\t}\n") } + return buf.String() } -func (ctx *context) generateSyscallDefines() { - prefix := ctx.sysTarget.SyscallPrefix +func (ctx *context) generateSyscallDefines() string { + var calls []string for name, nr := range ctx.calls { if !ctx.sysTarget.SyscallNumbers || strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) { continue } - ctx.printf("#ifndef %v%v\n", prefix, name) - ctx.printf("#define %v%v %v\n", prefix, name, nr) - ctx.printf("#endif\n") + calls = append(calls, name) + } + sort.Strings(calls) + buf := new(bytes.Buffer) + prefix := ctx.sysTarget.SyscallPrefix + for _, name := range calls { + fmt.Fprintf(buf, "#ifndef %v%v\n", prefix, name) + fmt.Fprintf(buf, "#define %v%v %v\n", prefix, name, ctx.calls[name]) + fmt.Fprintf(buf, "#endif\n") } if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 { // This is a dirty hack. // On 32-bit linux mmap translated to old_mmap syscall which has a different signature. // mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here. - ctx.printf("#undef __NR_mmap\n") - ctx.printf("#define __NR_mmap __NR_mmap2\n") + fmt.Fprintf(buf, "#undef __NR_mmap\n") + fmt.Fprintf(buf, "#define __NR_mmap __NR_mmap2\n") } - ctx.printf("\n") + return buf.String() } func (ctx *context) generateProgCalls(p *prog.Prog) ([]string, []uint64, error) { @@ -332,7 +244,7 @@ func (ctx *context) generateCalls(p prog.ExecProg) ([]string, []uint64) { fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index) } for _, copyout := range call.Copyout { - fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v_t*)0x%x);\n", + fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v*)0x%x);\n", copyout.Index, copyout.Size*8, copyout.Addr) } if copyoutMultiple { @@ -350,18 +262,18 @@ func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.Exec for i, chunk := range arg.Chunks { switch chunk.Kind { case prog.ExecArgCsumChunkData: - fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n", + fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8*)0x%x, %d));\n", csumSeq, chunk.Value, chunk.Size) case prog.ExecArgCsumChunkConst: - fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n", + fmt.Fprintf(w, "\tuint%d csum_%d_chunk_%d = 0x%x;\n", chunk.Size*8, csumSeq, i, chunk.Value) - fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n", + fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8*)&csum_%d_chunk_%d, %d);\n", csumSeq, csumSeq, i, chunk.Size) default: panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind)) } } - fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n", + fmt.Fprintf(w, "\tNONFAILING(*(uint16*)0x%x = csum_inet_digest(&csum_%d));\n", addr, csumSeq) } @@ -374,7 +286,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { panic("bitfield+string format") } - fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v_t, 0x%x, %v, %v, %v));\n", + fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v, 0x%x, %v, %v, %v));\n", arg.Size*8, copyin.Addr, ctx.constArgToStr(arg), arg.BitfieldOffset, arg.BitfieldLength) } @@ -399,7 +311,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) { switch bf { case prog.FormatNative, prog.FormatBigEndian: - fmt.Fprintf(w, "\tNONFAILING(*(uint%v_t*)0x%x = %v);\n", size*8, addr, val) + fmt.Fprintf(w, "\tNONFAILING(*(uint%v*)0x%x = %v);\n", size*8, addr, val) case prog.FormatStrDec: if size != 20 { panic("bad strdec size") @@ -506,6 +418,8 @@ func (ctx *context) hoistIncludes(result []byte) []byte { func (ctx *context) removeEmptyLines(result []byte) []byte { for { newResult := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1) + newResult = bytes.Replace(newResult, []byte{'\n', '\n', '\t'}, []byte{'\n', '\t'}, -1) + newResult = bytes.Replace(newResult, []byte{'\n', '\n', ' '}, []byte{'\n', ' '}, -1) if len(newResult) == len(result) { return result } |
