aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-12-06 18:44:26 +0100
committerDmitry Vyukov <dvyukov@google.com>2018-12-06 18:55:46 +0100
commitceaec61a833bf78e2aa2a1feb964e1915d3f465b (patch)
treecfeaf07cbd501e76eb957d0eca14f0fc297b2eb4
parentf40330afce14e3dab61ec823f106b7fb396f070d (diff)
prog: export Type.DefaultArg
It's effectively exported anyway. So export it the proper way.
-rw-r--r--prog/encoding.go28
-rw-r--r--prog/minimization.go2
-rw-r--r--prog/prog.go2
-rw-r--r--prog/prog_test.go2
-rw-r--r--prog/rand.go6
-rw-r--r--prog/types.go40
-rw-r--r--tools/syz-trace2syz/proggen/proggen.go22
7 files changed, 49 insertions, 53 deletions
diff --git a/prog/encoding.go b/prog/encoding.go
index c968fa0b7..4478bc834 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -250,7 +250,7 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) {
c.Comment = strings.TrimSpace(p.s[p.i+1:])
}
for i := len(c.Args); i < len(meta.Args); i++ {
- c.Args = append(c.Args, meta.Args[i].makeDefaultArg())
+ c.Args = append(c.Args, meta.Args[i].DefaultArg())
}
if len(c.Args) != len(meta.Args) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
@@ -292,7 +292,7 @@ func (target *Target) parseArg(typ Type, p *parser, vars map[string]*ResultArg)
}
if arg == nil {
if typ != nil {
- arg = typ.makeDefaultArg()
+ arg = typ.DefaultArg()
} else if r != "" {
return nil, fmt.Errorf("named nil argument")
}
@@ -349,7 +349,7 @@ func (target *Target) parseArgInt(typ Type, p *parser) (Arg, error) {
return MakeSpecialPointerArg(typ, index), nil
default:
eatExcessive(p, true)
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
}
@@ -376,7 +376,7 @@ func (target *Target) parseArgRes(typ Type, p *parser, vars map[string]*ResultAr
}
v := vars[id]
if v == nil {
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
arg := MakeResultArg(typ, v, 0)
arg.OpDiv = div
@@ -392,7 +392,7 @@ func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultA
case *VmaType:
default:
eatExcessive(p, true)
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
p.Parse('&')
addr, vmaSize, err := target.parseAddr(p)
@@ -419,7 +419,7 @@ func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultA
return MakeVmaPointerArg(typ, addr, vmaSize), nil
}
if inner == nil {
- inner = typ1.makeDefaultArg()
+ inner = typ1.DefaultArg()
}
return MakePointerArg(typ, addr, inner), nil
}
@@ -427,7 +427,7 @@ func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultA
func (target *Target) parseArgString(typ Type, p *parser) (Arg, error) {
if _, ok := typ.(*BufferType); !ok {
eatExcessive(p, true)
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
data, err := deserializeData(p)
if err != nil {
@@ -463,7 +463,7 @@ func (target *Target) parseArgStruct(typ Type, p *parser, vars map[string]*Resul
if !ok {
eatExcessive(p, false)
p.Parse('}')
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != '}'; i++ {
@@ -487,7 +487,7 @@ func (target *Target) parseArgStruct(typ Type, p *parser, vars map[string]*Resul
}
p.Parse('}')
for len(inner) < len(t1.Fields) {
- inner = append(inner, t1.Fields[len(inner)].makeDefaultArg())
+ inner = append(inner, t1.Fields[len(inner)].DefaultArg())
}
return MakeGroupArg(typ, inner), nil
}
@@ -498,7 +498,7 @@ func (target *Target) parseArgArray(typ Type, p *parser, vars map[string]*Result
if !ok {
eatExcessive(p, false)
p.Parse(']')
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != ']'; i++ {
@@ -514,7 +514,7 @@ func (target *Target) parseArgArray(typ Type, p *parser, vars map[string]*Result
p.Parse(']')
if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd {
for uint64(len(inner)) < t1.RangeBegin {
- inner = append(inner, t1.Type.makeDefaultArg())
+ inner = append(inner, t1.Type.DefaultArg())
}
inner = inner[:t1.RangeBegin]
}
@@ -525,7 +525,7 @@ func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*Result
t1, ok := typ.(*UnionType)
if !ok {
eatExcessive(p, true)
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
p.Parse('@')
name := p.Ident()
@@ -538,7 +538,7 @@ func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*Result
}
if optType == nil {
eatExcessive(p, true)
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
var opt Arg
if p.Char() == '=' {
@@ -549,7 +549,7 @@ func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*Result
return nil, err
}
} else {
- opt = optType.makeDefaultArg()
+ opt = optType.DefaultArg()
}
return MakeUnionArg(typ, opt), nil
}
diff --git a/prog/minimization.go b/prog/minimization.go
index 1762bbbae..5f3e9f604 100644
--- a/prog/minimization.go
+++ b/prog/minimization.go
@@ -183,7 +183,7 @@ func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool {
return false
}
a := arg.(*ConstArg)
- def := arg.Type().makeDefaultArg().(*ConstArg)
+ def := arg.Type().DefaultArg().(*ConstArg)
if a.Val == def.Val {
return false
}
diff --git a/prog/prog.go b/prog/prog.go
index b7304ba5d..575680882 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -364,7 +364,7 @@ func removeArg(arg0 Arg) {
delete(uses, a)
}
for arg1 := range a.uses {
- arg2 := arg1.Type().makeDefaultArg().(*ResultArg)
+ arg2 := arg1.Type().DefaultArg().(*ResultArg)
replaceResultArg(arg1, arg2)
}
})
diff --git a/prog/prog_test.go b/prog/prog_test.go
index 1bc9eb4d5..9363754d1 100644
--- a/prog/prog_test.go
+++ b/prog/prog_test.go
@@ -23,7 +23,7 @@ func TestDefault(t *testing.T) {
target, _, _ := initTest(t)
for _, meta := range target.Syscalls {
ForeachType(meta, func(typ Type) {
- arg := typ.makeDefaultArg()
+ arg := typ.DefaultArg()
if !isDefault(arg) {
t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
typ, typ, arg)
diff --git a/prog/rand.go b/prog/rand.go
index b479d1e8c..c3b7c2352 100644
--- a/prog/rand.go
+++ b/prog/rand.go
@@ -288,7 +288,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []
metas = append(metas, meta)
}
if len(metas) == 0 {
- return res.makeDefaultArg(), nil
+ return res.DefaultArg(), nil
}
// Now we have a set of candidate calls that can create the necessary resource.
@@ -539,7 +539,7 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A
switch typ.(type) {
case *IntType, *FlagsType, *ConstType, *ProcType,
*VmaType, *ResourceType:
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
}
@@ -548,7 +548,7 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A
v := res.Desc.Values[r.Intn(len(res.Desc.Values))]
return MakeResultArg(typ, nil, v), nil
}
- return typ.makeDefaultArg(), nil
+ return typ.DefaultArg(), nil
}
// Allow infinite recursion for optional pointers.
diff --git a/prog/types.go b/prog/types.go
index 6536056b8..e926d3227 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -60,17 +60,13 @@ type Type interface {
BitfieldLength() uint64
BitfieldMiddle() bool // returns true for all but last bitfield in a group
- makeDefaultArg() Arg
+ DefaultArg() Arg
isDefaultArg(arg Arg) bool
generate(r *randGen, s *state) (arg Arg, calls []*Call)
mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool)
minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool
}
-func DefaultArg(t Type) Arg {
- return t.makeDefaultArg()
-}
-
func IsPad(t Type) bool {
if ct, ok := t.(*ConstType); ok && ct.IsPad {
return true
@@ -147,7 +143,7 @@ func (t *ResourceType) String() string {
return t.Name()
}
-func (t *ResourceType) makeDefaultArg() Arg {
+func (t *ResourceType) DefaultArg() Arg {
return MakeResultArg(t, nil, t.Default())
}
@@ -203,7 +199,7 @@ type ConstType struct {
IsPad bool
}
-func (t *ConstType) makeDefaultArg() Arg {
+func (t *ConstType) DefaultArg() Arg {
return MakeConstArg(t, t.Val)
}
@@ -233,7 +229,7 @@ type IntType struct {
RangeEnd uint64
}
-func (t *IntType) makeDefaultArg() Arg {
+func (t *IntType) DefaultArg() Arg {
return MakeConstArg(t, 0)
}
@@ -247,7 +243,7 @@ type FlagsType struct {
BitMask bool
}
-func (t *FlagsType) makeDefaultArg() Arg {
+func (t *FlagsType) DefaultArg() Arg {
return MakeConstArg(t, 0)
}
@@ -261,7 +257,7 @@ type LenType struct {
Buf string
}
-func (t *LenType) makeDefaultArg() Arg {
+func (t *LenType) DefaultArg() Arg {
return MakeConstArg(t, 0)
}
@@ -280,7 +276,7 @@ const (
procDefaultValue = 0xffffffffffffffff // special value denoting 0 for all procs
)
-func (t *ProcType) makeDefaultArg() Arg {
+func (t *ProcType) DefaultArg() Arg {
return MakeConstArg(t, procDefaultValue)
}
@@ -306,7 +302,7 @@ func (t *CsumType) String() string {
return "csum"
}
-func (t *CsumType) makeDefaultArg() Arg {
+func (t *CsumType) DefaultArg() Arg {
return MakeConstArg(t, 0)
}
@@ -324,7 +320,7 @@ func (t *VmaType) String() string {
return "vma"
}
-func (t *VmaType) makeDefaultArg() Arg {
+func (t *VmaType) DefaultArg() Arg {
return MakeSpecialPointerArg(t, 0)
}
@@ -369,7 +365,7 @@ func (t *BufferType) String() string {
return "buffer"
}
-func (t *BufferType) makeDefaultArg() Arg {
+func (t *BufferType) DefaultArg() Arg {
if t.Dir() == DirOut {
var sz uint64
if !t.Varlen() {
@@ -422,11 +418,11 @@ func (t *ArrayType) String() string {
return fmt.Sprintf("array[%v]", t.Type.String())
}
-func (t *ArrayType) makeDefaultArg() Arg {
+func (t *ArrayType) DefaultArg() Arg {
var elems []Arg
if t.Kind == ArrayRangeLen && t.RangeBegin == t.RangeEnd {
for i := uint64(0); i < t.RangeBegin; i++ {
- elems = append(elems, t.Type.makeDefaultArg())
+ elems = append(elems, t.Type.DefaultArg())
}
}
return MakeGroupArg(t, elems)
@@ -454,11 +450,11 @@ func (t *PtrType) String() string {
return fmt.Sprintf("ptr[%v, %v]", t.Dir(), t.Type.String())
}
-func (t *PtrType) makeDefaultArg() Arg {
+func (t *PtrType) DefaultArg() Arg {
if t.Optional() {
return MakeSpecialPointerArg(t, 0)
}
- return MakePointerArg(t, 0, t.Type.makeDefaultArg())
+ return MakePointerArg(t, 0, t.Type.DefaultArg())
}
func (t *PtrType) isDefaultArg(arg Arg) bool {
@@ -483,10 +479,10 @@ func (t *StructType) FieldName() string {
return t.FldName
}
-func (t *StructType) makeDefaultArg() Arg {
+func (t *StructType) DefaultArg() Arg {
inner := make([]Arg, len(t.Fields))
for i, field := range t.Fields {
- inner[i] = field.makeDefaultArg()
+ inner[i] = field.DefaultArg()
}
return MakeGroupArg(t, inner)
}
@@ -515,8 +511,8 @@ func (t *UnionType) FieldName() string {
return t.FldName
}
-func (t *UnionType) makeDefaultArg() Arg {
- return MakeUnionArg(t, t.Fields[0].makeDefaultArg())
+func (t *UnionType) DefaultArg() Arg {
+ return MakeUnionArg(t, t.Fields[0].DefaultArg())
}
func (t *UnionType) isDefaultArg(arg Arg) bool {
diff --git a/tools/syz-trace2syz/proggen/proggen.go b/tools/syz-trace2syz/proggen/proggen.go
index eec72ca70..4ba81d5b5 100644
--- a/tools/syz-trace2syz/proggen/proggen.go
+++ b/tools/syz-trace2syz/proggen/proggen.go
@@ -79,7 +79,7 @@ func genResult(syzType prog.Type, straceRet int64, ctx *Context) {
func genArgs(syzType prog.Type, traceArg parser.IrType, ctx *Context) prog.Arg {
if traceArg == nil {
log.Logf(3, "parsing syzType: %s, traceArg is nil. generating default arg...", syzType.Name())
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
}
ctx.CurrentStraceArg = traceArg
log.Logf(3, "parsing arg of syz type: %s, ir type: %#v", syzType.Name(), traceArg)
@@ -90,7 +90,7 @@ func genArgs(syzType prog.Type, traceArg parser.IrType, ctx *Context) prog.Arg {
// Resource Types need special care. Pointers, Structs can have resource fields e.g. pipe, socketpair
// Buffer may need special care in out direction
default:
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
}
}
@@ -98,7 +98,7 @@ func genArgs(syzType prog.Type, traceArg parser.IrType, ctx *Context) prog.Arg {
case *prog.IntType, *prog.ConstType, *prog.FlagsType, *prog.CsumType:
return genConst(a, traceArg, ctx)
case *prog.LenType:
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
case *prog.ProcType:
return parseProc(a, traceArg, ctx)
case *prog.ResourceType:
@@ -152,14 +152,14 @@ func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context)
reorderStructFields(syzType, a, ctx)
for i := range syzType.Fields {
if prog.IsPad(syzType.Fields[i]) {
- args = append(args, prog.DefaultArg(syzType.Fields[i]))
+ args = append(args, syzType.Fields[i].DefaultArg())
continue
}
// If the last n fields of a struct are zero or NULL, strace will occasionally omit those values
// this creates a mismatch in the number of elements in the ir type and in
// our descriptions. We generate default values for omitted fields
if j >= len(a.Elems) {
- args = append(args, prog.DefaultArg(syzType.Fields[i]))
+ args = append(args, syzType.Fields[i].DefaultArg())
} else {
args = append(args, genArgs(syzType.Fields[i], a.Elems[j], ctx))
}
@@ -169,7 +169,7 @@ func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context)
// We could have a case like the following:
// ioctl(3, 35111, {ifr_name="\x6c\x6f", ifr_hwaddr=00:00:00:00:00:00}) = 0
// if_hwaddr gets parsed as a BufferType but our syscall descriptions have it as a struct type
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
default:
log.Fatalf("unsupported type for struct: %#v", a)
}
@@ -179,7 +179,7 @@ func genStruct(syzType *prog.StructType, traceType parser.IrType, ctx *Context)
func genUnionArg(syzType *prog.UnionType, straceType parser.IrType, ctx *Context) prog.Arg {
if straceType == nil {
log.Logf(1, "generating union arg. straceType is nil")
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
}
log.Logf(4, "generating union arg: %s %#v", syzType.TypeName, straceType)
@@ -250,7 +250,7 @@ func genPtr(syzType *prog.PtrType, traceType parser.IrType, ctx *Context) prog.A
return prog.MakeSpecialPointerArg(syzType, 0)
}
// Likely have a type of the form bind(3, 0xfffffffff, [3]);
- res := prog.DefaultArg(syzType.Type)
+ res := syzType.Type.DefaultArg()
return addr(ctx, syzType, res.Size(), res)
default:
res := genArgs(syzType.Type, a, ctx)
@@ -270,12 +270,12 @@ func genConst(syzType prog.Type, traceType parser.IrType, ctx *Context) prog.Arg
// For now we choose the first option
if len(a.Elems) == 0 {
log.Logf(2, "parsing const type, got array type with len 0")
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
}
return genConst(syzType, a.Elems[0], ctx)
case *parser.BufferType:
// The call almost certainly returned an errno
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
default:
log.Fatalf("unsupported type for const: %#v", traceType)
}
@@ -326,7 +326,7 @@ func parseProc(syzType *prog.ProcType, traceType parser.IrType, ctx *Context) pr
// Again probably an error case
// Something like the following will trigger this
// bind(3, {sa_family=AF_INET, sa_data="\xac"}, 3) = -1 EINVAL(Invalid argument)
- return prog.DefaultArg(syzType)
+ return syzType.DefaultArg()
default:
log.Fatalf("unsupported type for proc: %#v", traceType)
}