From 1baf121c2fc6a3b92f01a48633d59290ff183476 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Thu, 2 Aug 2018 16:50:43 +0200 Subject: pkg/csource: refactor generateCalls Move call generation into a separate function. Update #538 --- pkg/csource/csource.go | 90 +++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 42 deletions(-) (limited to 'pkg/csource') diff --git a/pkg/csource/csource.go b/pkg/csource/csource.go index 51234fb2d..06bfd5d40 100644 --- a/pkg/csource/csource.go +++ b/pkg/csource/csource.go @@ -173,59 +173,22 @@ func (ctx *context) generateCalls(p prog.ExecProg, trace bool) ([]string, []uint ctx.copyin(w, &csumSeq, copyin) } - // Call itself. if ctx.opts.Fault && ctx.opts.FaultCall == ci { fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/failslab/ignore-gfp-wait\", \"N\");\n") fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/fail_futex/ignore-private\", \"N\");\n") fmt.Fprintf(w, "\tinject_fault(%v);\n", ctx.opts.FaultNth) } + // Call itself. callName := call.Meta.CallName resCopyout := call.Index != prog.ExecNoCopyout argCopyout := len(call.Copyout) != 0 - emitCall := ctx.opts.EnableTun || callName != "syz_emit_ethernet" && - callName != "syz_extract_tcp_res" + emitCall := ctx.opts.EnableTun || + callName != "syz_emit_ethernet" && + callName != "syz_extract_tcp_res" // TODO: if we don't emit the call we must also not emit copyin, copyout and fault injection. // However, simply skipping whole iteration breaks tests due to unused static functions. if emitCall { - native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_") - fmt.Fprintf(w, "\t") - if resCopyout || argCopyout || trace { - fmt.Fprintf(w, "res = ") - } - if native { - fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName) - } else if strings.HasPrefix(callName, "syz_") { - fmt.Fprintf(w, "%v(", callName) - } else { - args := strings.Repeat(",long", len(call.Args)) - if args != "" { - args = args[1:] - } - fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName) - } - for ai, arg := range call.Args { - if native || ai > 0 { - fmt.Fprintf(w, ", ") - } - switch arg := arg.(type) { - case prog.ExecArgConst: - if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { - panic("sring format in syscall argument") - } - fmt.Fprintf(w, "%v", ctx.constArgToStr(arg)) - case prog.ExecArgResult: - if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { - panic("sring format in syscall argument") - } - fmt.Fprintf(w, "%v", ctx.resultArgToStr(arg)) - default: - panic(fmt.Sprintf("unknown arg type: %+v", arg)) - } - } - fmt.Fprintf(w, ");\n") - if trace { - fmt.Fprintf(w, "\tprintf(\"### call=%v errno=%%d\\n\", res == -1 ? errno : 0);\n", ci) - } + ctx.emitCall(w, call, ci, resCopyout || argCopyout, trace) } // Copyout. @@ -237,6 +200,49 @@ func (ctx *context) generateCalls(p prog.ExecProg, trace bool) ([]string, []uint return calls, p.Vars } +func (ctx *context) emitCall(w *bytes.Buffer, call prog.ExecCall, ci int, haveCopyout, trace bool) { + callName := call.Meta.CallName + native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_") + fmt.Fprintf(w, "\t") + if haveCopyout || trace { + fmt.Fprintf(w, "res = ") + } + if native { + fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName) + } else if strings.HasPrefix(callName, "syz_") { + fmt.Fprintf(w, "%v(", callName) + } else { + args := strings.Repeat(",long", len(call.Args)) + if args != "" { + args = args[1:] + } + fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName) + } + for ai, arg := range call.Args { + if native || ai > 0 { + fmt.Fprintf(w, ", ") + } + switch arg := arg.(type) { + case prog.ExecArgConst: + if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { + panic("sring format in syscall argument") + } + fmt.Fprintf(w, "%v", ctx.constArgToStr(arg)) + case prog.ExecArgResult: + if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { + panic("sring format in syscall argument") + } + fmt.Fprintf(w, "%v", ctx.resultArgToStr(arg)) + default: + panic(fmt.Sprintf("unknown arg type: %+v", arg)) + } + } + fmt.Fprintf(w, ");\n") + if trace { + fmt.Fprintf(w, "\tprintf(\"### call=%v errno=%%d\\n\", res == -1 ? errno : 0);\n", ci) + } +} + func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.ExecArgCsum, csumSeq int) { fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", csumSeq) fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", csumSeq) -- cgit mrf-deployment