aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/csource/csource.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-07-26 17:47:27 +0200
committerDmitry Vyukov <dvyukov@google.com>2018-07-27 10:22:23 +0200
commit9d92841b4e4d0ac0f97f983cd90087323f27c26c (patch)
tree562c5d32f96e010c34b3f122616213110d1b979b /pkg/csource/csource.go
parentc3da5dc5e0d0c6614f48c2d1178d58ff1e47809c (diff)
pkg/csource: tidy generated code
1. Remove unnecessary includes. 2. Remove thunk function in threaded mode. 3. Inline syscalls into main for the simplest case. 4. Define main in common.h rather than form with printfs. 5. Fix generation for repeat mode (we had 2 infinite loops: in main and in loop). 6. Remove unused functions (setup/reset_loop, setup/reset_test, sandbox_namespace, etc).
Diffstat (limited to 'pkg/csource/csource.go')
-rw-r--r--pkg/csource/csource.go208
1 files changed, 61 insertions, 147 deletions
diff --git a/pkg/csource/csource.go b/pkg/csource/csource.go
index 512d9985a..66958f22b 100644
--- a/pkg/csource/csource.go
+++ b/pkg/csource/csource.go
@@ -24,7 +24,6 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) {
opts: opts,
target: p.Target,
sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
- w: new(bytes.Buffer),
calls: make(map[string]uint64),
}
@@ -43,100 +42,43 @@ func Write(p *prog.Prog, opts Options) ([]byte, error) {
ctx.calls[c.Meta.CallName] = c.Meta.NR
}
- ctx.print("// autogenerated by syzkaller (http://github.com/google/syzkaller)\n\n")
-
- hdr, err := createCommonHeader(p, mmapProg, opts)
- if err != nil {
- return nil, err
- }
- ctx.w.Write(hdr)
- ctx.print("\n")
-
- ctx.generateSyscallDefines()
-
+ varsBuf := new(bytes.Buffer)
if len(vars) != 0 {
- ctx.printf("uint64_t r[%v] = {", len(vars))
+ fmt.Fprintf(varsBuf, "uint64 r[%v] = {", len(vars))
for i, v := range vars {
if i != 0 {
- ctx.printf(", ")
+ fmt.Fprintf(varsBuf, ", ")
}
- ctx.printf("0x%x", v)
+ fmt.Fprintf(varsBuf, "0x%x", v)
}
- ctx.printf("};\n")
- }
- needProcID := opts.Procs > 1 || opts.EnableCgroups
- for _, c := range p.Calls {
- if c.Meta.CallName == "syz_mount_image" ||
- c.Meta.CallName == "syz_read_part_table" {
- needProcID = true
- }
- }
- if needProcID {
- ctx.printf("unsigned long long procid;\n")
+ fmt.Fprintf(varsBuf, "};\n")
}
- if opts.Repeat {
- ctx.generateTestFunc(calls, len(vars) != 0, "execute_one")
- } else {
- ctx.generateTestFunc(calls, len(vars) != 0, "loop")
+ sandboxFunc := "loop();"
+ if opts.Sandbox != "" {
+ sandboxFunc = "do_sandbox_" + opts.Sandbox + "();"
}
-
- if ctx.target.OS == "akaros" && opts.Repeat {
- ctx.printf("const char* program_name;\n")
- ctx.print("int main(int argc, char** argv)\n{\n")
- } else {
- ctx.print("int main()\n{\n")
+ replacements := map[string]string{
+ "PROCS": fmt.Sprint(opts.Procs),
+ "NUM_CALLS": fmt.Sprint(len(p.Calls)),
+ "MMAP_DATA": strings.Join(mmapCalls, ""),
+ "SYSCALL_DEFINES": ctx.generateSyscallDefines(),
+ "SANDBOX_FUNC": sandboxFunc,
+ "RESULTS": varsBuf.String(),
+ "SYSCALLS": ctx.generateSyscalls(calls, len(vars) != 0),
}
- for _, c := range mmapCalls {
- ctx.printf("%s", c)
+ if !opts.Threaded && !opts.Repeat && opts.Sandbox == "" {
+ // This inlines syscalls right into main for the simplest case.
+ replacements["SANDBOX_FUNC"] = replacements["SYSCALLS"]
+ replacements["SYSCALLS"] = "unused"
}
- if ctx.target.OS == "akaros" && opts.Repeat {
- ctx.printf("\tprogram_name = argv[0];\n")
- ctx.printf("\tif (argc == 2 && strcmp(argv[1], \"child\") == 0)\n")
- ctx.printf("\t\tchild();\n")
- }
- if opts.HandleSegv {
- ctx.printf("\tinstall_segv_handler();\n")
- }
-
- if !opts.Repeat {
- if opts.UseTmpDir {
- ctx.printf("\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- } else {
- if opts.UseTmpDir {
- ctx.print("\tchar *cwd = get_current_dir_name();\n")
- }
- if opts.Procs <= 1 {
- ctx.print("\tfor (;;) {\n")
- if opts.UseTmpDir {
- ctx.print("\t\tif (chdir(cwd))\n")
- ctx.print("\t\t\tfail(\"failed to chdir\");\n")
- ctx.print("\t\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- ctx.print("\t}\n")
- } else {
- ctx.printf("\tfor (procid = 0; procid < %v; procid++) {\n", opts.Procs)
- ctx.print("\t\tif (fork() == 0) {\n")
- ctx.print("\t\t\tfor (;;) {\n")
- if opts.UseTmpDir {
- ctx.print("\t\t\t\tif (chdir(cwd))\n")
- ctx.print("\t\t\t\t\tfail(\"failed to chdir\");\n")
- ctx.print("\t\t\t\tuse_temporary_dir();\n")
- }
- ctx.writeLoopCall()
- ctx.print("\t\t\t}\n")
- ctx.print("\t\t}\n")
- ctx.print("\t}\n")
- ctx.print("\tsleep(1000000);\n")
- }
+ result, err := createCommonHeader(p, mmapProg, replacements, opts)
+ if err != nil {
+ return nil, err
}
-
- ctx.print("\treturn 0;\n}\n")
-
- result := ctx.postProcess(ctx.w.Bytes())
+ const header = "// autogenerated by syzkaller (https://github.com/google/syzkaller)\n\n"
+ result = append([]byte(header), result...)
+ result = ctx.postProcess(result)
return result, nil
}
@@ -145,92 +87,62 @@ type context struct {
opts Options
target *prog.Target
sysTarget *targets.Target
- w *bytes.Buffer
calls map[string]uint64 // CallName -> NR
}
-func (ctx *context) print(str string) {
- ctx.w.WriteString(str)
-}
-
-func (ctx *context) printf(str string, args ...interface{}) {
- ctx.print(fmt.Sprintf(str, args...))
-}
-
-func (ctx *context) writeLoopCall() {
- if ctx.opts.Sandbox != "" {
- ctx.printf("\tdo_sandbox_%v();\n", ctx.opts.Sandbox)
- return
- }
- if ctx.opts.EnableTun {
- ctx.printf("\tinitialize_tun();\n")
- }
- if ctx.opts.EnableNetdev {
- ctx.printf("\tinitialize_netdevices();\n")
- }
- ctx.print("\tloop();\n")
-}
-
-func (ctx *context) generateTestFunc(calls []string, hasVars bool, name string) {
+func (ctx *context) generateSyscalls(calls []string, hasVars bool) string {
opts := ctx.opts
+ buf := new(bytes.Buffer)
if !opts.Threaded && !opts.Collide {
- ctx.printf("void %v()\n{\n", name)
if hasVars {
- ctx.printf("\tlong res = 0;\n")
+ fmt.Fprintf(buf, "\tlong res = 0;\n")
}
if opts.Repro {
- ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n")
+ fmt.Fprintf(buf, "\tif (write(1, \"executing program\\n\", sizeof(\"executing program\\n\") - 1)) {}\n")
}
for _, c := range calls {
- ctx.printf("%s", c)
+ fmt.Fprintf(buf, "%s", c)
}
- ctx.printf("}\n\n")
} else {
- ctx.printf("void execute_call(int call)\n{\n")
if hasVars {
- ctx.printf("\tlong res;")
+ fmt.Fprintf(buf, "\tlong res;")
}
- ctx.printf("\tswitch (call) {\n")
+ fmt.Fprintf(buf, "\tswitch (call) {\n")
for i, c := range calls {
- ctx.printf("\tcase %v:\n", i)
- ctx.printf("%s", strings.Replace(c, "\t", "\t\t", -1))
- ctx.printf("\t\tbreak;\n")
- }
- ctx.printf("\t}\n")
- ctx.printf("}\n\n")
-
- ctx.printf("void %v()\n{\n", name)
- if opts.Repro {
- ctx.printf("\tif (write(1, \"executing program\\n\", strlen(\"executing program\\n\"))) {}\n")
- }
- ctx.printf("\texecute(%v);\n", len(calls))
- if opts.Collide {
- ctx.printf("\tcollide = 1;\n")
- ctx.printf("\texecute(%v);\n", len(calls))
+ fmt.Fprintf(buf, "\tcase %v:\n", i)
+ fmt.Fprintf(buf, "%s", strings.Replace(c, "\t", "\t\t", -1))
+ fmt.Fprintf(buf, "\t\tbreak;\n")
}
- ctx.printf("}\n\n")
+ fmt.Fprintf(buf, "\t}\n")
}
+ return buf.String()
}
-func (ctx *context) generateSyscallDefines() {
- prefix := ctx.sysTarget.SyscallPrefix
+func (ctx *context) generateSyscallDefines() string {
+ var calls []string
for name, nr := range ctx.calls {
if !ctx.sysTarget.SyscallNumbers ||
strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) {
continue
}
- ctx.printf("#ifndef %v%v\n", prefix, name)
- ctx.printf("#define %v%v %v\n", prefix, name, nr)
- ctx.printf("#endif\n")
+ calls = append(calls, name)
+ }
+ sort.Strings(calls)
+ buf := new(bytes.Buffer)
+ prefix := ctx.sysTarget.SyscallPrefix
+ for _, name := range calls {
+ fmt.Fprintf(buf, "#ifndef %v%v\n", prefix, name)
+ fmt.Fprintf(buf, "#define %v%v %v\n", prefix, name, ctx.calls[name])
+ fmt.Fprintf(buf, "#endif\n")
}
if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 {
// This is a dirty hack.
// On 32-bit linux mmap translated to old_mmap syscall which has a different signature.
// mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here.
- ctx.printf("#undef __NR_mmap\n")
- ctx.printf("#define __NR_mmap __NR_mmap2\n")
+ fmt.Fprintf(buf, "#undef __NR_mmap\n")
+ fmt.Fprintf(buf, "#define __NR_mmap __NR_mmap2\n")
}
- ctx.printf("\n")
+ return buf.String()
}
func (ctx *context) generateProgCalls(p *prog.Prog) ([]string, []uint64, error) {
@@ -332,7 +244,7 @@ func (ctx *context) generateCalls(p prog.ExecProg) ([]string, []uint64) {
fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index)
}
for _, copyout := range call.Copyout {
- fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v_t*)0x%x);\n",
+ fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v*)0x%x);\n",
copyout.Index, copyout.Size*8, copyout.Addr)
}
if copyoutMultiple {
@@ -350,18 +262,18 @@ func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.Exec
for i, chunk := range arg.Chunks {
switch chunk.Kind {
case prog.ExecArgCsumChunkData:
- fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n",
+ fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8*)0x%x, %d));\n",
csumSeq, chunk.Value, chunk.Size)
case prog.ExecArgCsumChunkConst:
- fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n",
+ fmt.Fprintf(w, "\tuint%d csum_%d_chunk_%d = 0x%x;\n",
chunk.Size*8, csumSeq, i, chunk.Value)
- fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n",
+ fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8*)&csum_%d_chunk_%d, %d);\n",
csumSeq, csumSeq, i, chunk.Size)
default:
panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind))
}
}
- fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n",
+ fmt.Fprintf(w, "\tNONFAILING(*(uint16*)0x%x = csum_inet_digest(&csum_%d));\n",
addr, csumSeq)
}
@@ -374,7 +286,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin
if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian {
panic("bitfield+string format")
}
- fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v_t, 0x%x, %v, %v, %v));\n",
+ fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v, 0x%x, %v, %v, %v));\n",
arg.Size*8, copyin.Addr, ctx.constArgToStr(arg),
arg.BitfieldOffset, arg.BitfieldLength)
}
@@ -399,7 +311,7 @@ func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin
func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) {
switch bf {
case prog.FormatNative, prog.FormatBigEndian:
- fmt.Fprintf(w, "\tNONFAILING(*(uint%v_t*)0x%x = %v);\n", size*8, addr, val)
+ fmt.Fprintf(w, "\tNONFAILING(*(uint%v*)0x%x = %v);\n", size*8, addr, val)
case prog.FormatStrDec:
if size != 20 {
panic("bad strdec size")
@@ -506,6 +418,8 @@ func (ctx *context) hoistIncludes(result []byte) []byte {
func (ctx *context) removeEmptyLines(result []byte) []byte {
for {
newResult := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1)
+ newResult = bytes.Replace(newResult, []byte{'\n', '\n', '\t'}, []byte{'\n', '\t'}, -1)
+ newResult = bytes.Replace(newResult, []byte{'\n', '\n', ' '}, []byte{'\n', ' '}, -1)
if len(newResult) == len(result) {
return result
}