// Copyright 2015 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. // Package csource generates [almost] equivalent C programs from syzkaller programs. package csource import ( "bytes" "fmt" "regexp" "strings" "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" ) func Write(p *prog.Prog, opts Options) ([]byte, error) { if err := opts.Check(p.Target.OS); err != nil { return nil, fmt.Errorf("csource: invalid opts: %v", err) } ctx := &context{ p: p, opts: opts, target: p.Target, sysTarget: targets.Get(p.Target.OS, p.Target.Arch), w: new(bytes.Buffer), calls: make(map[string]uint64), } calls, vars, err := ctx.generateProgCalls(ctx.p) if err != nil { return nil, err } mmapProg := p.Target.GenerateUberMmapProg() mmapCalls, _, err := ctx.generateProgCalls(mmapProg) if err != nil { return nil, err } for _, c := range append(mmapProg.Calls, p.Calls...) { ctx.calls[c.Meta.CallName] = c.Meta.NR } ctx.print("// autogenerated by syzkaller (http://github.com/google/syzkaller)\n\n") hdr, err := createCommonHeader(p, mmapProg, opts) if err != nil { return nil, err } ctx.w.Write(hdr) ctx.print("\n") ctx.generateSyscallDefines() if len(vars) != 0 { ctx.printf("uint64_t r[%v] = {", len(vars)) for i, v := range vars { if i != 0 { ctx.printf(", ") } ctx.printf("0x%x", v) } ctx.printf("};\n") } needProcID := opts.Procs > 1 || opts.EnableCgroups for _, c := range p.Calls { if c.Meta.CallName == "syz_mount_image" || c.Meta.CallName == "syz_read_part_table" { needProcID = true } } if needProcID { ctx.printf("unsigned long long procid;\n") } if !opts.Repeat { ctx.generateTestFunc(calls, len(vars) != 0, "loop") ctx.print("int main()\n{\n") for _, c := range mmapCalls { ctx.printf("%s", c) } if opts.HandleSegv { ctx.printf("\tinstall_segv_handler();\n") } if opts.UseTmpDir { ctx.printf("\tuse_temporary_dir();\n") } ctx.writeLoopCall() ctx.print("\treturn 0;\n}\n") } else { ctx.generateTestFunc(calls, len(vars) != 0, "execute_one") if opts.Procs <= 1 { ctx.print("int main()\n{\n") for _, c := range mmapCalls { ctx.printf("%s", c) } if opts.HandleSegv { ctx.print("\tinstall_segv_handler();\n") } if opts.UseTmpDir { ctx.print("\tchar *cwd = get_current_dir_name();\n") } ctx.print("\tfor (;;) {\n") if opts.UseTmpDir { ctx.print("\t\tif (chdir(cwd))\n") ctx.print("\t\t\tfail(\"failed to chdir\");\n") ctx.print("\t\tuse_temporary_dir();\n") } ctx.writeLoopCall() ctx.print("\t}\n}\n") } else { ctx.print("int main()\n{\n") for _, c := range mmapCalls { ctx.printf("%s", c) } if opts.UseTmpDir { ctx.print("\tchar *cwd = get_current_dir_name();\n") } ctx.printf("\tfor (procid = 0; procid < %v; procid++) {\n", opts.Procs) ctx.print("\t\tif (fork() == 0) {\n") if opts.HandleSegv { ctx.print("\t\t\tinstall_segv_handler();\n") } ctx.print("\t\t\tfor (;;) {\n") if opts.UseTmpDir { ctx.print("\t\t\t\tif (chdir(cwd))\n") ctx.print("\t\t\t\t\tfail(\"failed to chdir\");\n") ctx.print("\t\t\t\tuse_temporary_dir();\n") } ctx.writeLoopCall() ctx.print("\t\t\t}\n") ctx.print("\t\t}\n") ctx.print("\t}\n") ctx.print("\tsleep(1000000);\n") ctx.print("\treturn 0;\n}\n") } } // Remove NONFAILING and debug calls. result := ctx.w.Bytes() if !opts.HandleSegv { re := regexp.MustCompile(`\t*NONFAILING\((.*)\);\n`) result = re.ReplaceAll(result, []byte("$1;\n")) } if !opts.Debug { re := regexp.MustCompile(`\t*debug\((.*\n)*?.*\);\n`) result = re.ReplaceAll(result, nil) re = regexp.MustCompile(`\t*debug_dump_data\((.*\n)*?.*\);\n`) result = re.ReplaceAll(result, nil) } result = bytes.Replace(result, []byte("NORETURN"), nil, -1) result = bytes.Replace(result, []byte("PRINTF"), nil, -1) // Remove duplicate new lines. for { result1 := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1) result1 = bytes.Replace(result1, []byte("\n\n#include"), []byte("\n#include"), -1) if len(result1) == len(result) { break } result = result1 } return result, nil } type context struct { p *prog.Prog opts Options target *prog.Target sysTarget *targets.Target w *bytes.Buffer calls map[string]uint64 // CallName -> NR } func (ctx *context) print(str string) { ctx.w.WriteString(str) } func (ctx *context) printf(str string, args ...interface{}) { ctx.print(fmt.Sprintf(str, args...)) } func (ctx *context) writeLoopCall() { if ctx.opts.Sandbox != "" { ctx.printf("\tdo_sandbox_%v();\n", ctx.opts.Sandbox) return } if ctx.opts.EnableTun { ctx.printf("\tinitialize_tun();\n") } if ctx.opts.EnableNetdev { ctx.printf("\tinitialize_netdevices();\n") } ctx.print("\tloop();\n") } func (ctx *context) generateTestFunc(calls []string, hasVars bool, name string) { opts := ctx.opts if !opts.Threaded && !opts.Collide { ctx.printf("void %v()\n{\n", name) if hasVars { ctx.printf("\tlong res = 0;\n") } if opts.Debug { // Use debug to avoid: error: ‘debug’ defined but not used. ctx.printf("\tdebug(\"%v\\n\");\n", name) } if opts.Repro { ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n") } for _, c := range calls { ctx.printf("%s", c) } ctx.printf("}\n\n") } else { ctx.printf("void execute_call(int call)\n{\n") if hasVars { ctx.printf("\tlong res;") } ctx.printf("\tswitch (call) {\n") for i, c := range calls { ctx.printf("\tcase %v:\n", i) ctx.printf("%s", strings.Replace(c, "\t", "\t\t", -1)) ctx.printf("\t\tbreak;\n") } ctx.printf("\t}\n") ctx.printf("}\n\n") ctx.printf("void %v()\n{\n", name) if opts.Debug { // Use debug to avoid: error: ‘debug’ defined but not used. ctx.printf("\tdebug(\"%v\\n\");\n", name) } if opts.Repro { ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n") } ctx.printf("\texecute(%v);\n", len(calls)) if opts.Collide { ctx.printf("\tcollide = 1;\n") ctx.printf("\texecute(%v);\n", len(calls)) } ctx.printf("}\n\n") } } func (ctx *context) generateSyscallDefines() { prefix := ctx.sysTarget.SyscallPrefix for name, nr := range ctx.calls { if !ctx.sysTarget.SyscallNumbers || strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) { continue } ctx.printf("#ifndef %v%v\n", prefix, name) ctx.printf("#define %v%v %v\n", prefix, name, nr) ctx.printf("#endif\n") } if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 { // This is a dirty hack. // On 32-bit linux mmap translated to old_mmap syscall which has a different signature. // mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here. ctx.printf("#undef __NR_mmap\n") ctx.printf("#define __NR_mmap __NR_mmap2\n") } ctx.printf("\n") } func (ctx *context) generateProgCalls(p *prog.Prog) ([]string, []uint64, error) { exec := make([]byte, prog.ExecBufferSize) progSize, err := p.SerializeForExec(exec) if err != nil { return nil, nil, fmt.Errorf("failed to serialize program: %v", err) } decoded, err := ctx.target.DeserializeExec(exec[:progSize]) if err != nil { return nil, nil, err } calls, vars := ctx.generateCalls(decoded) return calls, vars, nil } func (ctx *context) generateCalls(p prog.ExecProg) ([]string, []uint64) { var calls []string csumSeq := 0 for ci, call := range p.Calls { w := new(bytes.Buffer) // Copyin. for _, copyin := range call.Copyin { ctx.copyin(w, &csumSeq, copyin) } // Call itself. if ctx.opts.Fault && ctx.opts.FaultCall == ci { fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/failslab/ignore-gfp-wait\", \"N\");\n") fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/fail_futex/ignore-private\", \"N\");\n") fmt.Fprintf(w, "\tinject_fault(%v);\n", ctx.opts.FaultNth) } callName := call.Meta.CallName resCopyout := call.Index != prog.ExecNoCopyout argCopyout := len(call.Copyout) != 0 emitCall := ctx.opts.EnableTun || callName != "syz_emit_ethernet" && callName != "syz_extract_tcp_res" // TODO: if we don't emit the call we must also not emit copyin, copyout and fault injection. // However, simply skipping whole iteration breaks tests due to unused static functions. if emitCall { native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_") fmt.Fprintf(w, "\t") if resCopyout || argCopyout { fmt.Fprintf(w, "res = ") } if native { fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName) } else if strings.HasPrefix(callName, "syz_") { fmt.Fprintf(w, "%v(", callName) } else { args := strings.Repeat(",long", len(call.Args)) if args != "" { args = args[1:] } fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName) } for ai, arg := range call.Args { if native || ai > 0 { fmt.Fprintf(w, ", ") } switch arg := arg.(type) { case prog.ExecArgConst: if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { panic("sring format in syscall argument") } fmt.Fprintf(w, "%v", ctx.constArgToStr(arg)) case prog.ExecArgResult: if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { panic("sring format in syscall argument") } fmt.Fprintf(w, "%v", ctx.resultArgToStr(arg)) default: panic(fmt.Sprintf("unknown arg type: %+v", arg)) } } fmt.Fprintf(w, ");\n") } // Copyout. if resCopyout || argCopyout { if ctx.sysTarget.OS == "fuchsia" { // On fuchsia we have real system calls that return ZX_OK on success, // and libc calls that are casted to function returning long, // as the result int -1 is returned as 0x00000000ffffffff rather than full -1. if strings.HasPrefix(callName, "zx_") { fmt.Fprintf(w, "\tif (res == ZX_OK)") } else { fmt.Fprintf(w, "\tif ((int)res != -1)") } } else { fmt.Fprintf(w, "\tif (res != -1)") } copyoutMultiple := len(call.Copyout) > 1 || resCopyout && len(call.Copyout) > 0 if copyoutMultiple { fmt.Fprintf(w, " {") } fmt.Fprintf(w, "\n") if resCopyout { fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index) } for _, copyout := range call.Copyout { fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v_t*)0x%x);\n", copyout.Index, copyout.Size*8, copyout.Addr) } if copyoutMultiple { fmt.Fprintf(w, "\t}\n") } } calls = append(calls, w.String()) } return calls, p.Vars } func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.ExecArgCsum, csumSeq int) { fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", csumSeq) fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", csumSeq) for i, chunk := range arg.Chunks { switch chunk.Kind { case prog.ExecArgCsumChunkData: fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n", csumSeq, chunk.Value, chunk.Size) case prog.ExecArgCsumChunkConst: fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n", chunk.Size*8, csumSeq, i, chunk.Value) fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n", csumSeq, csumSeq, i, chunk.Size) default: panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind)) } } fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n", addr, csumSeq) } func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin) { switch arg := copyin.Arg.(type) { case prog.ExecArgConst: if arg.BitfieldOffset == 0 && arg.BitfieldLength == 0 { ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.constArgToStr(arg), arg.Format) } else { if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { panic("bitfield+string format") } fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v_t, 0x%x, %v, %v, %v));\n", arg.Size*8, copyin.Addr, ctx.constArgToStr(arg), arg.BitfieldOffset, arg.BitfieldLength) } case prog.ExecArgResult: ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.resultArgToStr(arg), arg.Format) case prog.ExecArgData: fmt.Fprintf(w, "\tNONFAILING(memcpy((void*)0x%x, \"%s\", %v));\n", copyin.Addr, toCString(arg.Data), len(arg.Data)) case prog.ExecArgCsum: switch arg.Kind { case prog.ExecArgCsumInet: *csumSeq++ ctx.generateCsumInet(w, copyin.Addr, arg, *csumSeq) default: panic(fmt.Sprintf("unknown csum kind %v", arg.Kind)) } default: panic(fmt.Sprintf("bad argument type: %+v", arg)) } } func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) { switch bf { case prog.FormatNative, prog.FormatBigEndian: fmt.Fprintf(w, "\tNONFAILING(*(uint%v_t*)0x%x = %v);\n", size*8, addr, val) case prog.FormatStrDec: if size != 20 { panic("bad strdec size") } fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%020llu\", (long long)%v));\n", addr, val) case prog.FormatStrHex: if size != 18 { panic("bad strdec size") } fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"0x%%016llx\", (long long)%v));\n", addr, val) case prog.FormatStrOct: if size != 23 { panic("bad strdec size") } fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%023llo\", (long long)%v));\n", addr, val) default: panic("unknown binary format") } } func (ctx *context) constArgToStr(arg prog.ExecArgConst) string { mask := (uint64(1) << (arg.Size * 8)) - 1 v := arg.Value & mask val := fmt.Sprintf("%v", v) if v == ^uint64(0)&mask { val = "-1" } else if v >= 10 { val = fmt.Sprintf("0x%x", v) } if ctx.opts.Procs > 1 && arg.PidStride != 0 { val += fmt.Sprintf(" + procid*%v", arg.PidStride) } if arg.Format == prog.FormatBigEndian { val = fmt.Sprintf("htobe%v(%v)", arg.Size*8, val) } return val } func (ctx *context) resultArgToStr(arg prog.ExecArgResult) string { res := fmt.Sprintf("r[%v]", arg.Index) if arg.DivOp != 0 { res = fmt.Sprintf("%v/%v", res, arg.DivOp) } if arg.AddOp != 0 { res = fmt.Sprintf("%v+%v", res, arg.AddOp) } if arg.Format == prog.FormatBigEndian { res = fmt.Sprintf("htobe%v(%v)", arg.Size*8, res) } return res } func toCString(data []byte) []byte { if len(data) == 0 { return nil } readable := true for i, v := range data { // Allow 0 only as last byte. if !isReadable(v) && (i != len(data)-1 || v != 0) { readable = false break } } if !readable { buf := new(bytes.Buffer) for _, v := range data { buf.Write([]byte{'\\', 'x', toHex(v >> 4), toHex(v << 4 >> 4)}) } return buf.Bytes() } if data[len(data)-1] == 0 { // Don't serialize last 0, C strings are 0-terminated anyway. data = data[:len(data)-1] } buf := new(bytes.Buffer) for _, v := range data { switch v { case '\t': buf.Write([]byte{'\\', 't'}) case '\r': buf.Write([]byte{'\\', 'r'}) case '\n': buf.Write([]byte{'\\', 'n'}) case '\\': buf.Write([]byte{'\\', '\\'}) case '"': buf.Write([]byte{'\\', '"'}) default: if v < 0x20 || v >= 0x7f { panic("unexpected char during data serialization") } buf.WriteByte(v) } } return buf.Bytes() } func isReadable(v byte) bool { return v >= 0x20 && v < 0x7f || v == '\t' || v == '\r' || v == '\n' } func toHex(v byte) byte { if v < 10 { return '0' + v } return 'a' + v - 10 }