diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2018-12-07 14:23:58 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2018-12-07 14:23:58 +0100 |
| commit | 276faf74b2136cc280b921f4c6e9080a61563fec (patch) | |
| tree | f488e914f0ec251ed594fc404665330a85d465bb /tools/syz-trace2syz | |
| parent | c9f43ce69883dd0b9829b27e55eb7b3bb8f8603e (diff) | |
tools/syz-trace2syz/proggen: unexport and refactor Context
1. Unexport Context, it's not meant for callers.
2. Unexport all Context fields.
3. Make all function Context methods.
Diffstat (limited to 'tools/syz-trace2syz')
| -rw-r--r-- | tools/syz-trace2syz/proggen/call_selector.go | 28 | ||||
| -rw-r--r-- | tools/syz-trace2syz/proggen/context.go | 29 | ||||
| -rw-r--r-- | tools/syz-trace2syz/proggen/generate_unions.go | 32 | ||||
| -rw-r--r-- | tools/syz-trace2syz/proggen/proggen.go | 123 |
4 files changed, 103 insertions, 109 deletions
diff --git a/tools/syz-trace2syz/proggen/call_selector.go b/tools/syz-trace2syz/proggen/call_selector.go index 0735a0f1b..00f2e3ea9 100644 --- a/tools/syz-trace2syz/proggen/call_selector.go +++ b/tools/syz-trace2syz/proggen/call_selector.go @@ -25,23 +25,29 @@ var discriminatorArgs = map[string][]int{ } type callSelector struct { - callCache map[string][]*prog.Syscall + target *prog.Target + returnCache returnCache + callCache map[string][]*prog.Syscall } -func newCallSelector() *callSelector { - return &callSelector{callCache: make(map[string][]*prog.Syscall)} +func newCallSelector(target *prog.Target, returnCache returnCache) *callSelector { + return &callSelector{ + target: target, + returnCache: returnCache, + callCache: make(map[string][]*prog.Syscall), + } } // Select returns the best matching descrimination for this syscall. -func (cs *callSelector) Select(ctx *Context, call *parser.Syscall) *prog.Syscall { - match := ctx.Target.SyscallMap[call.CallName] +func (cs *callSelector) Select(call *parser.Syscall) *prog.Syscall { + match := cs.target.SyscallMap[call.CallName] discriminators := discriminatorArgs[call.CallName] if len(discriminators) == 0 { return match } score := 0 - for _, meta := range cs.callSet(ctx, call.CallName) { - if score1 := matchCall(ctx, meta, call, discriminators); score1 > score { + for _, meta := range cs.callSet(call.CallName) { + if score1 := cs.matchCall(meta, call, discriminators); score1 > score { match, score = meta, score1 } } @@ -49,12 +55,12 @@ func (cs *callSelector) Select(ctx *Context, call *parser.Syscall) *prog.Syscall } // callSet returns all syscalls with the given name. -func (cs *callSelector) callSet(ctx *Context, callName string) []*prog.Syscall { +func (cs *callSelector) callSet(callName string) []*prog.Syscall { calls, ok := cs.callCache[callName] if ok { return calls } - for _, call := range ctx.Target.Syscalls { + for _, call := range cs.target.Syscalls { if call.CallName == callName { calls = append(calls, call) } @@ -65,7 +71,7 @@ func (cs *callSelector) callSet(ctx *Context, callName string) []*prog.Syscall { // matchCall returns match score between meta and call. // Higher score means better match, -1 if they are not matching at all. -func matchCall(ctx *Context, meta *prog.Syscall, call *parser.Syscall, discriminators []int) int { +func (cs *callSelector) matchCall(meta *prog.Syscall, call *parser.Syscall, discriminators []int) int { score := 0 for _, i := range discriminators { if i >= len(meta.Args) || i >= len(call.Args) { @@ -97,7 +103,7 @@ func matchCall(ctx *Context, meta *prog.Syscall, call *parser.Syscall, discrimin case *prog.ResourceType: // Resources must match one of subtypes, // the more precise match, the higher the score. - retArg := ctx.ReturnCache.get(t, arg) + retArg := cs.returnCache.get(t, arg) if retArg == nil { return -1 } diff --git a/tools/syz-trace2syz/proggen/context.go b/tools/syz-trace2syz/proggen/context.go deleted file mode 100644 index 8283cfd8f..000000000 --- a/tools/syz-trace2syz/proggen/context.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 syzkaller project authors. All rights reserved. -// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. - -package proggen - -import ( - "github.com/google/syzkaller/prog" - "github.com/google/syzkaller/tools/syz-trace2syz/parser" -) - -// Context stores metadata related to a syzkaller program -type Context struct { - pg *prog.ProgGen - ReturnCache returnCache - CurrentStraceCall *parser.Syscall - CurrentSyzCall *prog.Call - CurrentStraceArg parser.IrType - Target *prog.Target - callSelector *callSelector -} - -func newContext(target *prog.Target) *Context { - return &Context{ - pg: prog.MakeProgGen(target), - ReturnCache: newRCache(), - Target: target, - callSelector: newCallSelector(), - } -} diff --git a/tools/syz-trace2syz/proggen/generate_unions.go b/tools/syz-trace2syz/proggen/generate_unions.go index 1465b19d1..da8b26cea 100644 --- a/tools/syz-trace2syz/proggen/generate_unions.go +++ b/tools/syz-trace2syz/proggen/generate_unions.go @@ -9,14 +9,14 @@ import ( "github.com/google/syzkaller/tools/syz-trace2syz/parser" ) -func genSockaddrStorage(syzType *prog.UnionType, straceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genSockaddrStorage(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { field2Opt := make(map[string]int) for i, field := range syzType.Fields { field2Opt[field.FieldName()] = i } // We currently look at the first argument of the system call // To determine which option of the union we select. - call := ctx.CurrentStraceCall + call := ctx.currentStraceCall var straceArg parser.IrType switch call.CallName { // May need to handle special cases. @@ -39,35 +39,35 @@ func genSockaddrStorage(syzType *prog.UnionType, straceType parser.IrType, ctx * "expected constant got: %#v", strType.Elems[0]) } switch socketFamily.Val() { - case ctx.Target.ConstMap["AF_INET6"]: + case ctx.target.ConstMap["AF_INET6"]: idx = field2Opt["in6"] - case ctx.Target.ConstMap["AF_INET"]: + case ctx.target.ConstMap["AF_INET"]: idx = field2Opt["in"] - case ctx.Target.ConstMap["AF_UNIX"]: + case ctx.target.ConstMap["AF_UNIX"]: idx = field2Opt["un"] - case ctx.Target.ConstMap["AF_UNSPEC"]: + case ctx.target.ConstMap["AF_UNSPEC"]: idx = field2Opt["nl"] - case ctx.Target.ConstMap["AF_NETLINK"]: + case ctx.target.ConstMap["AF_NETLINK"]: idx = field2Opt["nl"] - case ctx.Target.ConstMap["AF_NFC"]: + case ctx.target.ConstMap["AF_NFC"]: idx = field2Opt["nfc"] - case ctx.Target.ConstMap["AF_PACKET"]: + case ctx.target.ConstMap["AF_PACKET"]: idx = field2Opt["ll"] } default: log.Fatalf("unable to parse sockaddr_storage. Unsupported type: %#v", strType) } - return prog.MakeUnionArg(syzType, genArgs(syzType.Fields[idx], straceType, ctx)) + return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) } -func genSockaddrNetlink(syzType *prog.UnionType, straceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genSockaddrNetlink(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { var idx = 2 field2Opt := make(map[string]int) for i, field := range syzType.Fields { field2Opt[field.FieldName()] = i } - switch a := ctx.CurrentStraceArg.(type) { + switch a := ctx.currentStraceArg.(type) { case *parser.GroupType: if len(a.Elems) > 2 { switch b := a.Elems[1].(type) { @@ -88,14 +88,14 @@ func genSockaddrNetlink(syzType *prog.UnionType, straceType parser.IrType, ctx * } } } - return prog.MakeUnionArg(syzType, genArgs(syzType.Fields[idx], straceType, ctx)) + return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) } -func genIfrIfru(syzType *prog.UnionType, straceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genIfrIfru(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { idx := 0 - switch ctx.CurrentStraceArg.(type) { + switch ctx.currentStraceArg.(type) { case parser.Constant: idx = 2 } - return prog.MakeUnionArg(syzType, genArgs(syzType.Fields[idx], straceType, ctx)) + return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) } diff --git a/tools/syz-trace2syz/proggen/proggen.go b/tools/syz-trace2syz/proggen/proggen.go index 091068ebd..541c1b4eb 100644 --- a/tools/syz-trace2syz/proggen/proggen.go +++ b/tools/syz-trace2syz/proggen/proggen.go @@ -49,9 +49,26 @@ func parseTree(tree *parser.TraceTree, pid int64, target *prog.Target, progs *[] } } +// Context stores metadata related to a syzkaller program +type context struct { + pg *prog.ProgGen + target *prog.Target + callSelector *callSelector + returnCache returnCache + currentStraceCall *parser.Syscall + currentSyzCall *prog.Call + currentStraceArg parser.IrType +} + // genProg converts a trace to one of our programs. func genProg(trace *parser.Trace, target *prog.Target) *prog.Prog { - ctx := newContext(target) + returnCache := newRCache() + ctx := &context{ + pg: prog.MakeProgGen(target), + target: target, + callSelector: newCallSelector(target, returnCache), + returnCache: returnCache, + } for _, sCall := range trace.Calls { if sCall.Paused { // Probably a case where the call was killed by a signal like the following @@ -64,8 +81,8 @@ func genProg(trace *parser.Trace, target *prog.Target) *prog.Prog { log.Logf(2, "skipping call: %s", sCall.CallName) continue } - ctx.CurrentStraceCall = sCall - call := genCall(ctx) + ctx.currentStraceCall = sCall + call := ctx.genCall() if call == nil { continue } @@ -80,14 +97,14 @@ func genProg(trace *parser.Trace, target *prog.Target) *prog.Prog { return p } -func genCall(ctx *Context) *prog.Call { - log.Logf(3, "parsing call: %s", ctx.CurrentStraceCall.CallName) - straceCall := ctx.CurrentStraceCall - ctx.CurrentSyzCall = new(prog.Call) - ctx.CurrentSyzCall.Meta = ctx.callSelector.Select(ctx, straceCall) - syzCall := ctx.CurrentSyzCall - if ctx.CurrentSyzCall.Meta == nil { - log.Logf(2, "skipping call: %s which has no matching description", ctx.CurrentStraceCall.CallName) +func (ctx *context) genCall() *prog.Call { + log.Logf(3, "parsing call: %s", ctx.currentStraceCall.CallName) + straceCall := ctx.currentStraceCall + ctx.currentSyzCall = new(prog.Call) + ctx.currentSyzCall.Meta = ctx.callSelector.Select(straceCall) + syzCall := ctx.currentSyzCall + if ctx.currentSyzCall.Meta == nil { + log.Logf(2, "skipping call: %s which has no matching description", ctx.currentStraceCall.CallName) return nil } syzCall.Ret = prog.MakeReturnArg(syzCall.Meta.Ret) @@ -97,31 +114,31 @@ func genCall(ctx *Context) *prog.Call { if i < len(straceCall.Args) { strArg = straceCall.Args[i] } - res := genArgs(syzCall.Meta.Args[i], strArg, ctx) + res := ctx.genArgs(syzCall.Meta.Args[i], strArg) syzCall.Args = append(syzCall.Args, res) } - genResult(syzCall.Meta.Ret, straceCall.Ret, ctx) + ctx.genResult(syzCall.Meta.Ret, straceCall.Ret) return syzCall } -func genResult(syzType prog.Type, straceRet int64, ctx *Context) { +func (ctx *context) genResult(syzType prog.Type, straceRet int64) { if straceRet > 0 { straceExpr := parser.Constant(uint64(straceRet)) switch syzType.(type) { case *prog.ResourceType: log.Logf(2, "call: %s returned a resource type with val: %s", - ctx.CurrentStraceCall.CallName, straceExpr.String()) - ctx.ReturnCache.cache(syzType, straceExpr, ctx.CurrentSyzCall.Ret) + ctx.currentStraceCall.CallName, straceExpr.String()) + ctx.returnCache.cache(syzType, straceExpr, ctx.currentSyzCall.Ret) } } } -func genArgs(syzType prog.Type, traceArg parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genArgs(syzType prog.Type, traceArg parser.IrType) prog.Arg { if traceArg == nil { log.Logf(3, "parsing syzType: %s, traceArg is nil. generating default arg...", syzType.Name()) return syzType.DefaultArg() } - ctx.CurrentStraceArg = traceArg + ctx.currentStraceArg = traceArg log.Logf(3, "parsing arg of syz type: %s, ir type: %#v", syzType.Name(), traceArg) if syzType.Dir() == prog.DirOut { @@ -136,32 +153,32 @@ func genArgs(syzType prog.Type, traceArg parser.IrType, ctx *Context) prog.Arg { switch a := syzType.(type) { case *prog.IntType, *prog.ConstType, *prog.FlagsType, *prog.CsumType: - return genConst(a, traceArg, ctx) + return ctx.genConst(a, traceArg) case *prog.LenType: return syzType.DefaultArg() case *prog.ProcType: - return parseProc(a, traceArg, ctx) + return ctx.parseProc(a, traceArg) case *prog.ResourceType: - return genResource(a, traceArg, ctx) + return ctx.genResource(a, traceArg) case *prog.PtrType: - return genPtr(a, traceArg, ctx) + return ctx.genPtr(a, traceArg) case *prog.BufferType: - return genBuffer(a, traceArg, ctx) + return ctx.genBuffer(a, traceArg) case *prog.StructType: - return genStruct(a, traceArg, ctx) + return ctx.genStruct(a, traceArg) case *prog.ArrayType: - return genArray(a, traceArg, ctx) + return ctx.genArray(a, traceArg) case *prog.UnionType: - return genUnionArg(a, traceArg, ctx) + return ctx.genUnionArg(a, traceArg) case *prog.VmaType: - return genVma(a, traceArg, ctx) + return ctx.genVma(a, traceArg) default: log.Fatalf("unsupported type: %#v", syzType) } return nil } -func genVma(syzType *prog.VmaType, _ parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genVma(syzType *prog.VmaType, _ parser.IrType) prog.Arg { npages := uint64(1) if syzType.RangeBegin != 0 || syzType.RangeEnd != 0 { npages = syzType.RangeEnd @@ -169,12 +186,12 @@ func genVma(syzType *prog.VmaType, _ parser.IrType, ctx *Context) prog.Arg { return prog.MakeVmaPointerArg(syzType, ctx.pg.AllocateVMA(npages), npages) } -func genArray(syzType *prog.ArrayType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genArray(syzType *prog.ArrayType, traceType parser.IrType) prog.Arg { var args []prog.Arg switch a := traceType.(type) { case *parser.GroupType: for i := 0; i < len(a.Elems); i++ { - args = append(args, genArgs(syzType.Type, a.Elems[i], ctx)) + args = append(args, ctx.genArgs(syzType.Type, a.Elems[i])) } default: log.Fatalf("unsupported type for array: %#v", traceType) @@ -182,12 +199,12 @@ func genArray(syzType *prog.ArrayType, traceType parser.IrType, ctx *Context) pr return prog.MakeGroupArg(syzType, args) } -func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genStruct(syzType *prog.StructType, traceType parser.IrType) prog.Arg { var args []prog.Arg switch a := traceType.(type) { case *parser.GroupType: j := 0 - reorderStructFields(syzType, a, ctx) + ctx.reorderStructFields(syzType, a) for i := range syzType.Fields { if prog.IsPad(syzType.Fields[i]) { args = append(args, syzType.Fields[i].DefaultArg()) @@ -199,7 +216,7 @@ func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context) if j >= len(a.Elems) { args = append(args, syzType.Fields[i].DefaultArg()) } else { - args = append(args, genArgs(syzType.Fields[i], a.Elems[j], ctx)) + args = append(args, ctx.genArgs(syzType.Fields[i], a.Elems[j])) } j++ } @@ -214,7 +231,7 @@ func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context) return prog.MakeGroupArg(syzType, args) } -func genUnionArg(syzType *prog.UnionType, straceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genUnionArg(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { if straceType == nil { log.Logf(1, "generating union arg. straceType is nil") return syzType.DefaultArg() @@ -227,16 +244,16 @@ func genUnionArg(syzType *prog.UnionType, straceType parser.IrType, ctx *Context switch syzType.TypeName { case "sockaddr_storage": - return genSockaddrStorage(syzType, straceType, ctx) + return ctx.genSockaddrStorage(syzType, straceType) case "sockaddr_nl": - return genSockaddrNetlink(syzType, straceType, ctx) + return ctx.genSockaddrNetlink(syzType, straceType) case "ifr_ifru": - return genIfrIfru(syzType, straceType, ctx) + return ctx.genIfrIfru(syzType, straceType) } - return prog.MakeUnionArg(syzType, genArgs(syzType.Fields[0], straceType, ctx)) + return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[0], straceType)) } -func genBuffer(syzType *prog.BufferType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genBuffer(syzType *prog.BufferType, traceType parser.IrType) prog.Arg { if syzType.Dir() == prog.DirOut { if !syzType.Varlen() { return prog.MakeOutDataArg(syzType, syzType.Size()) @@ -255,7 +272,7 @@ func genBuffer(syzType *prog.BufferType, traceType parser.IrType, ctx *Context) size := max + int(syzType.RangeBegin) return prog.MakeOutDataArg(syzType, uint64(size)) default: - log.Fatalf("unexpected buffer type kind: %v. call %v arg %#v", syzType.Kind, ctx.CurrentSyzCall, traceType) + log.Fatalf("unexpected buffer type kind: %v. call %v arg %#v", syzType.Kind, ctx.currentSyzCall, traceType) } } } @@ -281,7 +298,7 @@ func genBuffer(syzType *prog.BufferType, traceType parser.IrType, ctx *Context) return prog.MakeDataArg(syzType, bufVal) } -func genPtr(syzType *prog.PtrType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genPtr(syzType *prog.PtrType, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: if a.Val() == 0 { @@ -289,14 +306,14 @@ func genPtr(syzType *prog.PtrType, traceType parser.IrType, ctx *Context) prog.A } // Likely have a type of the form bind(3, 0xfffffffff, [3]); res := syzType.Type.DefaultArg() - return addr(ctx, syzType, res.Size(), res) + return ctx.addr(syzType, res.Size(), res) default: - res := genArgs(syzType.Type, a, ctx) - return addr(ctx, syzType, res.Size(), res) + res := ctx.genArgs(syzType.Type, a) + return ctx.addr(syzType, res.Size(), res) } } -func genConst(syzType prog.Type, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genConst(syzType prog.Type, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: return prog.MakeConstArg(syzType, a.Val()) @@ -310,7 +327,7 @@ func genConst(syzType prog.Type, traceType parser.IrType, ctx *Context) prog.Arg log.Logf(2, "parsing const type, got array type with len 0") return syzType.DefaultArg() } - return genConst(syzType, a.Elems[0], ctx) + return ctx.genConst(syzType, a.Elems[0]) case *parser.BufferType: // The call almost certainly returned an errno return syzType.DefaultArg() @@ -320,17 +337,17 @@ func genConst(syzType prog.Type, traceType parser.IrType, ctx *Context) prog.Arg return nil } -func genResource(syzType *prog.ResourceType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) genResource(syzType *prog.ResourceType, traceType parser.IrType) prog.Arg { if syzType.Dir() == prog.DirOut { log.Logf(2, "resource returned by call argument: %s", traceType.String()) res := prog.MakeResultArg(syzType, nil, syzType.Default()) - ctx.ReturnCache.cache(syzType, traceType, res) + ctx.returnCache.cache(syzType, traceType, res) return res } switch a := traceType.(type) { case parser.Constant: val := a.Val() - if arg := ctx.ReturnCache.get(syzType, traceType); arg != nil { + if arg := ctx.returnCache.get(syzType, traceType); arg != nil { res := prog.MakeResultArg(syzType, arg.(*prog.ResultArg), syzType.Default()) return res } @@ -342,7 +359,7 @@ func genResource(syzType *prog.ResourceType, traceType parser.IrType, ctx *Conte // last argument is a pointer to a resource. Strace will output a pointer to // a number x as [x]. res := prog.MakeResultArg(syzType, nil, syzType.Default()) - ctx.ReturnCache.cache(syzType, a.Elems[0], res) + ctx.returnCache.cache(syzType, a.Elems[0], res) return res } log.Fatalf("generating resource type from GroupType with %d elements", len(a.Elems)) @@ -352,7 +369,7 @@ func genResource(syzType *prog.ResourceType, traceType parser.IrType, ctx *Conte return nil } -func parseProc(syzType *prog.ProcType, traceType parser.IrType, ctx *Context) prog.Arg { +func (ctx *context) parseProc(syzType *prog.ProcType, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: val := a.Val() @@ -371,11 +388,11 @@ func parseProc(syzType *prog.ProcType, traceType parser.IrType, ctx *Context) pr return nil } -func addr(ctx *Context, syzType prog.Type, size uint64, data prog.Arg) prog.Arg { +func (ctx *context) addr(syzType prog.Type, size uint64, data prog.Arg) prog.Arg { return prog.MakePointerArg(syzType, ctx.pg.Allocate(size), data) } -func reorderStructFields(syzType *prog.StructType, traceType *parser.GroupType, ctx *Context) { +func (ctx *context) reorderStructFields(syzType *prog.StructType, traceType *parser.GroupType) { // Sometimes strace reports struct fields out of order compared to our descriptions // Example: 5704 bind(3, {sa_family=AF_INET6, // sin6_port=htons(8888), |
