aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-07-11 16:49:08 +0200
committerAndrey Konovalov <andreyknvl@google.com>2017-07-17 14:34:09 +0200
commitcfc46d9d0bea72865ba75e0e4063a1a558262df8 (patch)
tree80455a77ab10d09154bb1ae66a12de002b7cd030 /prog
parent8d1e7095528712971312a83e067cdd803aaccc47 (diff)
prog: split Arg into smaller structs
Right now Arg is a huge struct (160 bytes), which has many different fields used for different arg kinds. Since most of the args we see in a typical corpus are ArgConst, this results in a significant memory overuse. This change: - makes Arg an interface instead of a struct - adds a SomethingArg struct for each arg kind we have - converts all *Arg pointers into just Arg, since an interface variable itself contains a pointer to the actual data - removes ArgPageSize, now ConstArg is used instead - consolidates correspondence between arg kinds and types, see comments before each SomethingArg struct definition - now LenType args that denote the length of VmaType args are serialized as "0x1000" instead of "(0x1000)"; to preserve backwards compatibility syzkaller is able to parse the old format for now - multiple small changes all over to make the above work After this change syzkaller uses half as much memory after deserializing a typical corpus.
Diffstat (limited to 'prog')
-rw-r--r--prog/analysis.go166
-rw-r--r--prog/checksum.go64
-rw-r--r--prog/clone.go79
-rw-r--r--prog/encoding.go132
-rw-r--r--prog/encodingexec.go106
-rw-r--r--prog/mutation.go223
-rw-r--r--prog/mutation_test.go32
-rw-r--r--prog/prog.go363
-rw-r--r--prog/prog_test.go20
-rw-r--r--prog/rand.go120
-rw-r--r--prog/size.go58
-rw-r--r--prog/size_test.go8
-rw-r--r--prog/validation.go248
13 files changed, 921 insertions, 698 deletions
diff --git a/prog/analysis.go b/prog/analysis.go
index df9a0599d..a932cc253 100644
--- a/prog/analysis.go
+++ b/prog/analysis.go
@@ -21,7 +21,7 @@ const (
type state struct {
ct *ChoiceTable
files map[string]bool
- resources map[string][]*Arg
+ resources map[string][]Arg
strings map[string]bool
pages [maxPages]bool
}
@@ -42,27 +42,28 @@ func newState(ct *ChoiceTable) *state {
s := &state{
ct: ct,
files: make(map[string]bool),
- resources: make(map[string][]*Arg),
+ resources: make(map[string][]Arg),
strings: make(map[string]bool),
}
return s
}
func (s *state) analyze(c *Call) {
- foreachArgArray(&c.Args, c.Ret, func(arg, base *Arg, _ *[]*Arg) {
- switch typ := arg.Type.(type) {
+ foreachArgArray(&c.Args, c.Ret, func(arg, base Arg, _ *[]Arg) {
+ switch typ := arg.Type().(type) {
case *sys.ResourceType:
- if arg.Type.Dir() != sys.DirIn {
+ if arg.Type().Dir() != sys.DirIn {
s.resources[typ.Desc.Name] = append(s.resources[typ.Desc.Name], arg)
// TODO: negative PIDs and add them as well (that's process groups).
}
case *sys.BufferType:
- if arg.Type.Dir() != sys.DirOut && arg.Kind == ArgData && len(arg.Data) != 0 {
+ a := arg.(*DataArg)
+ if a.Type().Dir() != sys.DirOut && len(a.Data) != 0 {
switch typ.Kind {
case sys.BufferString:
- s.strings[string(arg.Data)] = true
+ s.strings[string(a.Data)] = true
case sys.BufferFilename:
- s.files[string(arg.Data)] = true
+ s.files[string(a.Data)] = true
}
}
}
@@ -70,74 +71,72 @@ func (s *state) analyze(c *Call) {
switch c.Meta.Name {
case "mmap":
// Filter out only very wrong arguments.
- length := c.Args[1]
- if length.AddrPage == 0 && length.AddrOffset == 0 {
+ length := c.Args[1].(*ConstArg)
+ if length.Val == 0 {
break
}
- if flags, fd := c.Args[4], c.Args[3]; flags.Val&sys.MAP_ANONYMOUS == 0 && fd.Kind == ArgConst && fd.Val == sys.InvalidFD {
+ flags := c.Args[3].(*ConstArg)
+ fd := c.Args[4].(*ResultArg)
+ if flags.Val&sys.MAP_ANONYMOUS == 0 && fd.Val == sys.InvalidFD {
break
}
- s.addressable(c.Args[0], length, true)
+ s.addressable(c.Args[0].(*PointerArg), length, true)
case "munmap":
- s.addressable(c.Args[0], c.Args[1], false)
+ s.addressable(c.Args[0].(*PointerArg), c.Args[1].(*ConstArg), false)
case "mremap":
- s.addressable(c.Args[4], c.Args[2], true)
+ s.addressable(c.Args[4].(*PointerArg), c.Args[2].(*ConstArg), true)
case "io_submit":
- if arr := c.Args[2].Res; arr != nil {
- for _, ptr := range arr.Inner {
- if ptr.Kind == ArgPointer {
- if ptr.Res != nil && ptr.Res.Type.Name() == "iocb" {
- s.resources["iocbptr"] = append(s.resources["iocbptr"], ptr)
- }
+ if arr := c.Args[2].(*PointerArg).Res; arr != nil {
+ for _, ptr := range arr.(*GroupArg).Inner {
+ p := ptr.(*PointerArg)
+ if p.Res != nil && p.Res.Type().Name() == "iocb" {
+ s.resources["iocbptr"] = append(s.resources["iocbptr"], ptr)
}
}
}
}
}
-func (s *state) addressable(addr, size *Arg, ok bool) {
- if addr.Kind != ArgPointer || size.Kind != ArgPageSize {
- panic("mmap/munmap/mremap args are not pages")
+func (s *state) addressable(addr *PointerArg, size *ConstArg, ok bool) {
+ sizePages := size.Val / pageSize
+ if addr.PageIndex+sizePages > uintptr(len(s.pages)) {
+ panic(fmt.Sprintf("address is out of bounds: page=%v len=%v bound=%v\naddr: %+v\nsize: %+v",
+ addr.PageIndex, sizePages, len(s.pages), addr, size))
}
- n := size.AddrPage
- if size.AddrOffset != 0 {
- n++
- }
- if addr.AddrPage+n > uintptr(len(s.pages)) {
- panic(fmt.Sprintf("address is out of bounds: page=%v len=%v (%v, %v) bound=%v, addr: %+v, size: %+v",
- addr.AddrPage, n, size.AddrPage, size.AddrOffset, len(s.pages), addr, size))
- }
- for i := uintptr(0); i < n; i++ {
- s.pages[addr.AddrPage+i] = ok
+ for i := uintptr(0); i < sizePages; i++ {
+ s.pages[addr.PageIndex+i] = ok
}
}
-func foreachSubargImpl(arg *Arg, parent *[]*Arg, f func(arg, base *Arg, parent *[]*Arg)) {
- var rec func(arg, base *Arg, parent *[]*Arg)
- rec = func(arg, base *Arg, parent *[]*Arg) {
+func foreachSubargImpl(arg Arg, parent *[]Arg, f func(arg, base Arg, parent *[]Arg)) {
+ var rec func(arg, base Arg, parent *[]Arg)
+ rec = func(arg, base Arg, parent *[]Arg) {
f(arg, base, parent)
- for _, arg1 := range arg.Inner {
- parent1 := parent
- if _, ok := arg.Type.(*sys.StructType); ok {
- parent1 = &arg.Inner
+ switch a := arg.(type) {
+ case *GroupArg:
+ for _, arg1 := range a.Inner {
+ parent1 := parent
+ if _, ok := arg.Type().(*sys.StructType); ok {
+ parent1 = &a.Inner
+ }
+ rec(arg1, base, parent1)
}
- rec(arg1, base, parent1)
- }
- if arg.Kind == ArgPointer && arg.Res != nil {
- rec(arg.Res, arg, parent)
- }
- if arg.Kind == ArgUnion {
- rec(arg.Option, base, parent)
+ case *PointerArg:
+ if a.Res != nil {
+ rec(a.Res, arg, parent)
+ }
+ case *UnionArg:
+ rec(a.Option, base, parent)
}
}
rec(arg, nil, parent)
}
-func foreachSubarg(arg *Arg, f func(arg, base *Arg, parent *[]*Arg)) {
+func foreachSubarg(arg Arg, f func(arg, base Arg, parent *[]Arg)) {
foreachSubargImpl(arg, nil, f)
}
-func foreachArgArray(args *[]*Arg, ret *Arg, f func(arg, base *Arg, parent *[]*Arg)) {
+func foreachArgArray(args *[]Arg, ret Arg, f func(arg, base Arg, parent *[]Arg)) {
for _, arg := range *args {
foreachSubargImpl(arg, args, f)
}
@@ -146,20 +145,20 @@ func foreachArgArray(args *[]*Arg, ret *Arg, f func(arg, base *Arg, parent *[]*A
}
}
-func foreachArg(c *Call, f func(arg, base *Arg, parent *[]*Arg)) {
+func foreachArg(c *Call, f func(arg, base Arg, parent *[]Arg)) {
foreachArgArray(&c.Args, nil, f)
}
-func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) {
- var rec func(*Arg, uintptr) uintptr
- rec = func(arg1 *Arg, offset uintptr) uintptr {
- switch arg1.Kind {
- case ArgGroup:
+func foreachSubargOffset(arg Arg, f func(arg Arg, offset uintptr)) {
+ var rec func(Arg, uintptr) uintptr
+ rec = func(arg1 Arg, offset uintptr) uintptr {
+ switch a := arg1.(type) {
+ case *GroupArg:
f(arg1, offset)
var totalSize uintptr
- for _, arg2 := range arg1.Inner {
+ for _, arg2 := range a.Inner {
size := rec(arg2, offset)
- if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() {
+ if arg2.Type().BitfieldLength() == 0 || arg2.Type().BitfieldLast() {
offset += size
totalSize += size
}
@@ -167,12 +166,12 @@ func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) {
if totalSize > arg1.Size() {
panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1))
}
- case ArgUnion:
+ case *UnionArg:
f(arg1, offset)
- size := rec(arg1.Option, offset)
+ size := rec(a.Option, offset)
offset += size
if size > arg1.Size() {
- panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type))
+ panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type()))
}
default:
f(arg1, offset)
@@ -186,36 +185,36 @@ func sanitizeCall(c *Call) {
switch c.Meta.CallName {
case "mmap":
// Add MAP_FIXED flag, otherwise it produces non-deterministic results.
- addr := c.Args[0]
- if addr.Kind != ArgPointer {
+ _, ok := c.Args[0].(*PointerArg)
+ if !ok {
panic("mmap address is not ArgPointer")
}
- length := c.Args[1]
- if length.Kind != ArgPageSize {
+ _, ok = c.Args[1].(*ConstArg)
+ if !ok {
panic("mmap length is not ArgPageSize")
}
- flags := c.Args[3]
- if flags.Kind != ArgConst {
+ flags, ok := c.Args[3].(*ConstArg)
+ if !ok {
panic("mmap flag arg is not const")
}
flags.Val |= sys.MAP_FIXED
case "mremap":
// Add MREMAP_FIXED flag, otherwise it produces non-deterministic results.
- flags := c.Args[3]
- if flags.Kind != ArgConst {
+ flags, ok := c.Args[3].(*ConstArg)
+ if !ok {
panic("mremap flag arg is not const")
}
if flags.Val&sys.MREMAP_MAYMOVE != 0 {
flags.Val |= sys.MREMAP_FIXED
}
case "mknod", "mknodat":
- mode := c.Args[1]
- dev := c.Args[2]
+ mode, ok1 := c.Args[1].(*ConstArg)
+ dev, ok2 := c.Args[2].(*ConstArg)
if c.Meta.CallName == "mknodat" {
- mode = c.Args[2]
- dev = c.Args[3]
+ mode, ok1 = c.Args[2].(*ConstArg)
+ dev, ok2 = c.Args[3].(*ConstArg)
}
- if mode.Kind != ArgConst || dev.Kind != ArgConst {
+ if !ok1 || !ok2 {
panic("mknod mode is not const")
}
// Char and block devices read/write io ports, kernel memory and do other nasty things.
@@ -233,13 +232,13 @@ func sanitizeCall(c *Call) {
mode.Val |= sys.S_IFREG
}
case "syslog":
- cmd := c.Args[0]
+ cmd := c.Args[0].(*ConstArg)
// These disable console output, but we need it.
if cmd.Val == sys.SYSLOG_ACTION_CONSOLE_OFF || cmd.Val == sys.SYSLOG_ACTION_CONSOLE_ON {
cmd.Val = sys.SYSLOG_ACTION_SIZE_UNREAD
}
case "ioctl":
- cmd := c.Args[1]
+ cmd := c.Args[1].(*ConstArg)
// Freeze kills machine. Though, it is an interesting functions,
// so we need to test it somehow.
// TODO: not required if executor drops privileges.
@@ -247,13 +246,14 @@ func sanitizeCall(c *Call) {
cmd.Val = sys.FITHAW
}
case "ptrace":
+ req := c.Args[0].(*ConstArg)
// PTRACE_TRACEME leads to unkillable processes, see:
// https://groups.google.com/forum/#!topic/syzkaller/uGzwvhlCXAw
- if c.Args[0].Val == sys.PTRACE_TRACEME {
- c.Args[0].Val = ^uintptr(0)
+ if req.Val == sys.PTRACE_TRACEME {
+ req.Val = ^uintptr(0)
}
case "exit", "exit_group":
- code := c.Args[0]
+ code := c.Args[0].(*ConstArg)
// These codes are reserved by executor.
if code.Val%128 == 67 || code.Val%128 == 68 {
code.Val = 1
@@ -264,9 +264,9 @@ func sanitizeCall(c *Call) {
func RequiresBitmasks(p *Prog) bool {
result := false
for _, c := range p.Calls {
- foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
- if arg.Kind == ArgConst {
- if arg.Type.BitfieldOffset() != 0 || arg.Type.BitfieldLength() != 0 {
+ foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
+ if a, ok := arg.(*ConstArg); ok {
+ if a.Type().BitfieldOffset() != 0 || a.Type().BitfieldLength() != 0 {
result = true
}
}
@@ -278,8 +278,8 @@ func RequiresBitmasks(p *Prog) bool {
func RequiresChecksums(p *Prog) bool {
result := false
for _, c := range p.Calls {
- foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
- if _, ok := arg.Type.(*sys.CsumType); ok {
+ foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
+ if _, ok := arg.Type().(*sys.CsumType); ok {
result = true
}
})
diff --git a/prog/checksum.go b/prog/checksum.go
index 5db10d26f..361d83f6f 100644
--- a/prog/checksum.go
+++ b/prog/checksum.go
@@ -29,45 +29,45 @@ type CsumInfo struct {
type CsumChunk struct {
Kind CsumChunkKind
- Arg *Arg // for CsumChunkArg
+ Arg Arg // for CsumChunkArg
Value uintptr // for CsumChunkConst
Size uintptr // for CsumChunkConst
}
-func getFieldByName(arg *Arg, name string) *Arg {
- for _, field := range arg.Inner {
- if field.Type.FieldName() == name {
+func getFieldByName(arg Arg, name string) Arg {
+ for _, field := range arg.(*GroupArg).Inner {
+ if field.Type().FieldName() == name {
return field
}
}
- panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
+ panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name()))
}
-func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) {
+func extractHeaderParamsIPv4(arg Arg) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 4 {
- panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
+ panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
}
dstAddr := getFieldByName(arg, "dst_ip")
if dstAddr.Size() != 4 {
- panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
+ panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type().Name()))
}
return srcAddr, dstAddr
}
-func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) {
+func extractHeaderParamsIPv6(arg Arg) (Arg, Arg) {
srcAddr := getFieldByName(arg, "src_ip")
if srcAddr.Size() != 16 {
- panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
+ panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type().Name()))
}
dstAddr := getFieldByName(arg, "dst_ip")
if dstAddr.Size() != 16 {
- panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
+ panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type().Name()))
}
return srcAddr, dstAddr
}
-func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo {
+func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8, pid int) CsumInfo {
info := CsumInfo{Kind: CsumInet}
info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
@@ -77,7 +77,7 @@ func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid
return info
}
-func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo {
+func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8, pid int) CsumInfo {
info := CsumInfo{Kind: CsumInet}
info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0})
info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0})
@@ -87,7 +87,7 @@ func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid
return info
}
-func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg {
+func findCsummedArg(arg Arg, typ *sys.CsumType, parentsMap map[Arg]Arg) Arg {
if typ.Buf == "parent" {
if csummedArg, ok := parentsMap[arg]; ok {
return csummedArg
@@ -95,7 +95,7 @@ func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg
panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name()))
} else {
for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
- if typ.Buf == parent.Type.Name() {
+ if typ.Buf == parent.Type().Name() {
return parent
}
}
@@ -103,13 +103,13 @@ func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg
panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf))
}
-func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo {
- var inetCsumFields []*Arg
- var pseudoCsumFields []*Arg
+func calcChecksumsCall(c *Call, pid int) map[Arg]CsumInfo {
+ var inetCsumFields []Arg
+ var pseudoCsumFields []Arg
// Find all csum fields.
- foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
- if typ, ok := arg.Type.(*sys.CsumType); ok {
+ foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
+ if typ, ok := arg.Type().(*sys.CsumType); ok {
switch typ.Kind {
case sys.CsumInet:
inetCsumFields = append(inetCsumFields, arg)
@@ -127,20 +127,20 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo {
}
// Build map of each field to its parent struct.
- parentsMap := make(map[*Arg]*Arg)
- foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
- if _, ok := arg.Type.(*sys.StructType); ok {
- for _, field := range arg.Inner {
- parentsMap[field.InnerArg()] = arg
+ parentsMap := make(map[Arg]Arg)
+ foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
+ if _, ok := arg.Type().(*sys.StructType); ok {
+ for _, field := range arg.(*GroupArg).Inner {
+ parentsMap[InnerArg(field)] = arg
}
}
})
- csumMap := make(map[*Arg]CsumInfo)
+ csumMap := make(map[Arg]CsumInfo)
// Calculate generic inet checksums.
for _, arg := range inetCsumFields {
- typ, _ := arg.Type.(*sys.CsumType)
+ typ, _ := arg.Type().(*sys.CsumType)
csummedArg := findCsummedArg(arg, typ, parentsMap)
chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0}
info := CsumInfo{Kind: CsumInet, Chunks: make([]CsumChunk, 0)}
@@ -156,11 +156,11 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo {
// Extract ipv4 or ipv6 source and destination addresses.
ipv4HeaderParsed := false
ipv6HeaderParsed := false
- var ipSrcAddr *Arg
- var ipDstAddr *Arg
- foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
+ var ipSrcAddr Arg
+ var ipDstAddr Arg
+ foreachArgArray(&c.Args, nil, func(arg, base Arg, _ *[]Arg) {
// syz_csum_* structs are used in tests
- switch arg.Type.Name() {
+ switch arg.Type().Name() {
case "ipv4_header", "syz_csum_ipv4_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg)
ipv4HeaderParsed = true
@@ -175,7 +175,7 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo {
// Calculate pseudo checksums.
for _, arg := range pseudoCsumFields {
- typ, _ := arg.Type.(*sys.CsumType)
+ typ, _ := arg.Type().(*sys.CsumType)
csummedArg := findCsummedArg(arg, typ, parentsMap)
protocol := uint8(typ.Protocol)
var info CsumInfo
diff --git a/prog/clone.go b/prog/clone.go
index c085b765f..fcd651845 100644
--- a/prog/clone.go
+++ b/prog/clone.go
@@ -1,17 +1,17 @@
-// Copyright 2015 syzkaller project authors. All rights reserved.
+// Copyright 2017 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
func (p *Prog) Clone() *Prog {
p1 := new(Prog)
- newargs := make(map[*Arg]*Arg)
+ newargs := make(map[Arg]Arg)
for _, c := range p.Calls {
c1 := new(Call)
c1.Meta = c.Meta
- c1.Ret = c.Ret.clone(c1, newargs)
+ c1.Ret = clone(c.Ret, newargs)
for _, arg := range c.Args {
- c1.Args = append(c1.Args, arg.clone(c1, newargs))
+ c1.Args = append(c1.Args, clone(arg, newargs))
}
p1.Calls = append(p1.Calls, c1)
}
@@ -23,31 +23,60 @@ func (p *Prog) Clone() *Prog {
return p1
}
-func (arg *Arg) clone(c *Call, newargs map[*Arg]*Arg) *Arg {
- arg1 := new(Arg)
- *arg1 = *arg
- arg1.Data = append([]byte{}, arg.Data...)
- switch arg.Kind {
- case ArgPointer:
- if arg.Res != nil {
- arg1.Res = arg.Res.clone(c, newargs)
+func clone(arg Arg, newargs map[Arg]Arg) Arg {
+ var arg1 Arg
+ switch a := arg.(type) {
+ case *ConstArg:
+ a1 := new(ConstArg)
+ *a1 = *a
+ arg1 = a1
+ case *PointerArg:
+ a1 := new(PointerArg)
+ *a1 = *a
+ arg1 = a1
+ if a.Res != nil {
+ a1.Res = clone(a.Res, newargs)
}
- case ArgUnion:
- arg1.Option = arg.Option.clone(c, newargs)
- case ArgResult:
- r := newargs[arg.Res]
- arg1.Res = r
- if r.Uses == nil {
- r.Uses = make(map[*Arg]bool)
+ case *DataArg:
+ a1 := new(DataArg)
+ *a1 = *a
+ a1.Data = append([]byte{}, a.Data...)
+ arg1 = a1
+ case *GroupArg:
+ a1 := new(GroupArg)
+ *a1 = *a
+ arg1 = a1
+ a1.Inner = nil
+ for _, arg2 := range a.Inner {
+ a1.Inner = append(a1.Inner, clone(arg2, newargs))
}
- r.Uses[arg1] = true
+ case *UnionArg:
+ a1 := new(UnionArg)
+ *a1 = *a
+ arg1 = a1
+ a1.Option = clone(a.Option, newargs)
+ case *ResultArg:
+ a1 := new(ResultArg)
+ *a1 = *a
+ arg1 = a1
+ case *ReturnArg:
+ a1 := new(ReturnArg)
+ *a1 = *a
+ arg1 = a1
+ default:
+ panic("bad arg kind")
}
- arg1.Inner = nil
- for _, arg2 := range arg.Inner {
- arg1.Inner = append(arg1.Inner, arg2.clone(c, newargs))
+ if user, ok := arg1.(ArgUser); ok && *user.Uses() != nil {
+ r := newargs[*user.Uses()]
+ *user.Uses() = r
+ used := r.(ArgUsed)
+ if *used.Used() == nil {
+ *used.Used() = make(map[Arg]bool)
+ }
+ (*used.Used())[arg1] = true
}
- if len(arg1.Uses) != 0 {
- arg1.Uses = nil // filled when we clone the referent
+ if used, ok := arg1.(ArgUsed); ok {
+ *used.Used() = nil // filled when we clone the referent
newargs[arg] = arg1
}
return arg1
diff --git a/prog/encoding.go b/prog/encoding.go
index 151ff1935..93af09457 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -33,64 +33,54 @@ func (p *Prog) Serialize() []byte {
}
}
buf := new(bytes.Buffer)
- vars := make(map[*Arg]int)
+ vars := make(map[Arg]int)
varSeq := 0
for _, c := range p.Calls {
- if len(c.Ret.Uses) != 0 {
+ if len(*c.Ret.(ArgUsed).Used()) != 0 {
fmt.Fprintf(buf, "r%v = ", varSeq)
vars[c.Ret] = varSeq
varSeq++
}
fmt.Fprintf(buf, "%v(", c.Meta.Name)
for i, a := range c.Args {
- if sys.IsPad(a.Type) {
+ if sys.IsPad(a.Type()) {
continue
}
if i != 0 {
fmt.Fprintf(buf, ", ")
}
- a.serialize(buf, vars, &varSeq)
+ serialize(a, buf, vars, &varSeq)
}
fmt.Fprintf(buf, ")\n")
}
return buf.Bytes()
}
-func (a *Arg) serialize(buf io.Writer, vars map[*Arg]int, varSeq *int) {
- if a == nil {
+func serialize(arg Arg, buf io.Writer, vars map[Arg]int, varSeq *int) {
+ if arg == nil {
fmt.Fprintf(buf, "nil")
return
}
- if len(a.Uses) != 0 {
+ if used, ok := arg.(ArgUsed); ok && len(*used.Used()) != 0 {
fmt.Fprintf(buf, "<r%v=>", *varSeq)
- vars[a] = *varSeq
+ vars[arg] = *varSeq
*varSeq++
}
- switch a.Kind {
- case ArgConst:
+ switch a := arg.(type) {
+ case *ConstArg:
fmt.Fprintf(buf, "0x%x", a.Val)
- case ArgResult:
- id, ok := vars[a.Res]
- if !ok {
- panic("no result")
- }
- fmt.Fprintf(buf, "r%v", id)
- if a.OpDiv != 0 {
- fmt.Fprintf(buf, "/%v", a.OpDiv)
- }
- if a.OpAdd != 0 {
- fmt.Fprintf(buf, "+%v", a.OpAdd)
- }
- case ArgPointer:
- fmt.Fprintf(buf, "&%v=", serializeAddr(a, true))
- a.Res.serialize(buf, vars, varSeq)
- case ArgPageSize:
- fmt.Fprintf(buf, "%v", serializeAddr(a, false))
- case ArgData:
+ case *PointerArg:
+ if a.Res == nil && a.PagesNum == 0 {
+ fmt.Fprintf(buf, "0x0")
+ break
+ }
+ fmt.Fprintf(buf, "&%v=", serializeAddr(arg))
+ serialize(a.Res, buf, vars, varSeq)
+ case *DataArg:
fmt.Fprintf(buf, "\"%v\"", hex.EncodeToString(a.Data))
- case ArgGroup:
+ case *GroupArg:
var delims []byte
- switch a.Type.(type) {
+ switch arg.Type().(type) {
case *sys.StructType:
delims = []byte{'{', '}'}
case *sys.ArrayType:
@@ -99,19 +89,35 @@ func (a *Arg) serialize(buf io.Writer, vars map[*Arg]int, varSeq *int) {
panic("unknown group type")
}
buf.Write([]byte{delims[0]})
- for i, a1 := range a.Inner {
- if a1 != nil && sys.IsPad(a1.Type) {
+ for i, arg1 := range a.Inner {
+ if arg1 != nil && sys.IsPad(arg1.Type()) {
continue
}
if i != 0 {
fmt.Fprintf(buf, ", ")
}
- a1.serialize(buf, vars, varSeq)
+ serialize(arg1, buf, vars, varSeq)
}
buf.Write([]byte{delims[1]})
- case ArgUnion:
+ case *UnionArg:
fmt.Fprintf(buf, "@%v=", a.OptionType.FieldName())
- a.Option.serialize(buf, vars, varSeq)
+ serialize(a.Option, buf, vars, varSeq)
+ case *ResultArg:
+ if a.Res == nil {
+ fmt.Fprintf(buf, "0x%x", a.Val)
+ break
+ }
+ id, ok := vars[a.Res]
+ if !ok {
+ panic("no result")
+ }
+ fmt.Fprintf(buf, "r%v", id)
+ if a.OpDiv != 0 {
+ fmt.Fprintf(buf, "/%v", a.OpDiv)
+ }
+ if a.OpAdd != 0 {
+ fmt.Fprintf(buf, "+%v", a.OpAdd)
+ }
default:
panic("unknown arg kind")
}
@@ -121,7 +127,7 @@ func Deserialize(data []byte) (prog *Prog, err error) {
prog = new(Prog)
p := &parser{r: bufio.NewScanner(bytes.NewReader(data))}
p.r.Buffer(nil, maxLineLen)
- vars := make(map[string]*Arg)
+ vars := make(map[string]Arg)
for p.Scan() {
if p.EOF() || p.Char() == '#' {
continue
@@ -184,7 +190,7 @@ func Deserialize(data []byte) (prog *Prog, err error) {
return
}
-func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
+func parseArg(typ sys.Type, p *parser, vars map[string]Arg) (Arg, error) {
r := ""
if p.Char() == '<' {
p.Parse('<')
@@ -192,7 +198,7 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
p.Parse('=')
p.Parse('>')
}
- var arg *Arg
+ var arg Arg
switch p.Char() {
case '0':
val := p.Ident()
@@ -200,14 +206,25 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
if err != nil {
return nil, fmt.Errorf("wrong arg value '%v': %v", val, err)
}
- arg = constArg(typ, uintptr(v))
+ switch typ.(type) {
+ case *sys.ConstType, *sys.IntType, *sys.FlagsType, *sys.ProcType, *sys.LenType, *sys.CsumType:
+ arg = constArg(typ, uintptr(v))
+ case *sys.ResourceType:
+ arg = resultArg(typ, nil, uintptr(v))
+ case *sys.PtrType:
+ arg = pointerArg(typ, 0, 0, 0, nil)
+ case *sys.VmaType:
+ arg = pointerArg(typ, 0, 0, 0, nil)
+ default:
+ panic(fmt.Sprintf("bad const type %+v", typ))
+ }
case 'r':
id := p.Ident()
v, ok := vars[id]
if !ok || v == nil {
return nil, fmt.Errorf("result %v references unknown variable (vars=%+v)", id, vars)
}
- arg = resultArg(typ, v)
+ arg = resultArg(typ, v, 0)
if p.Char() == '/' {
p.Parse('/')
op := p.Ident()
@@ -215,7 +232,7 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
if err != nil {
return nil, fmt.Errorf("wrong result div op: '%v'", op)
}
- arg.OpDiv = uintptr(v)
+ arg.(*ResultArg).OpDiv = uintptr(v)
}
if p.Char() == '+' {
p.Parse('+')
@@ -224,7 +241,7 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
if err != nil {
return nil, fmt.Errorf("wrong result add op: '%v'", op)
}
- arg.OpAdd = uintptr(v)
+ arg.(*ResultArg).OpAdd = uintptr(v)
}
case '&':
var typ1 sys.Type
@@ -247,11 +264,13 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
}
arg = pointerArg(typ, page, off, size, inner)
case '(':
- page, off, _, err := parseAddr(p, false)
+ // This used to parse length of VmaType and return ArgPageSize, which is now removed.
+ // Leaving this for now for backwards compatibility.
+ pages, _, _, err := parseAddr(p, false)
if err != nil {
return nil, err
}
- arg = pageSizeArg(typ, page, off)
+ arg = constArg(typ, pages*pageSize)
case '"':
p.Parse('"')
val := ""
@@ -270,7 +289,7 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
return nil, fmt.Errorf("'{' arg is not a struct: %#v", typ)
}
p.Parse('{')
- var inner []*Arg
+ var inner []Arg
for i := 0; p.Char() != '}'; i++ {
if i >= len(t1.Fields) {
return nil, fmt.Errorf("wrong struct arg count: %v, want %v", i+1, len(t1.Fields))
@@ -300,7 +319,7 @@ func parseArg(typ sys.Type, p *parser, vars map[string]*Arg) (*Arg, error) {
return nil, fmt.Errorf("'[' arg is not an array: %#v", typ)
}
p.Parse('[')
- var inner []*Arg
+ var inner []Arg
for i := 0; p.Char() != ']'; i++ {
arg, err := parseArg(t1.Type, p, vars)
if err != nil {
@@ -358,13 +377,22 @@ const (
maxLineLen = 256 << 10
)
-func serializeAddr(a *Arg, base bool) string {
- page := a.AddrPage * encodingPageSize
- if base {
- page += encodingAddrBase
+func serializeAddr(arg Arg) string {
+ var pageIndex uintptr
+ var pageOffset int
+ var pagesNum uintptr
+ switch a := arg.(type) {
+ case *PointerArg:
+ pageIndex = a.PageIndex
+ pageOffset = a.PageOffset
+ pagesNum = a.PagesNum
+ default:
+ panic("bad addr arg")
}
+ page := pageIndex * encodingPageSize
+ page += encodingAddrBase
soff := ""
- if off := a.AddrOffset; off != 0 {
+ if off := pageOffset; off != 0 {
sign := "+"
if off < 0 {
sign = "-"
@@ -374,7 +402,7 @@ func serializeAddr(a *Arg, base bool) string {
soff = fmt.Sprintf("%v0x%x", sign, off)
}
ssize := ""
- if size := a.AddrPagesNum; size != 0 {
+ if size := pagesNum; size != 0 {
size *= encodingPageSize
ssize = fmt.Sprintf("/0x%x", size)
}
diff --git a/prog/encodingexec.go b/prog/encodingexec.go
index 6695836ca..ba2efcd37 100644
--- a/prog/encodingexec.go
+++ b/prog/encodingexec.go
@@ -43,7 +43,7 @@ const (
dataOffset = 512 << 20
)
-type Args []*Arg
+type Args []Arg
func (s Args) Len() int {
return len(s)
@@ -74,14 +74,14 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
w := &execContext{
buf: buffer,
eof: false,
- args: make(map[*Arg]argInfo),
+ args: make(map[Arg]argInfo),
}
for _, c := range p.Calls {
// Calculate checksums.
csumMap := calcChecksumsCall(c, pid)
- var csumUses map[*Arg]bool
+ var csumUses map[Arg]bool
if csumMap != nil {
- csumUses = make(map[*Arg]bool)
+ csumUses = make(map[Arg]bool)
for arg, info := range csumMap {
csumUses[arg] = true
if info.Kind == CsumInet {
@@ -95,18 +95,23 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
}
// Calculate arg offsets within structs.
// Generate copyin instructions that fill in data into pointer arguments.
- foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
- if arg.Kind == ArgPointer && arg.Res != nil {
- foreachSubargOffset(arg.Res, func(arg1 *Arg, offset uintptr) {
- if len(arg1.Uses) != 0 || csumUses[arg1] {
+ foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
+ if a, ok := arg.(*PointerArg); ok && a.Res != nil {
+ foreachSubargOffset(a.Res, func(arg1 Arg, offset uintptr) {
+ used, ok := arg1.(ArgUsed)
+ if (ok && len(*used.Used()) != 0) || csumUses[arg1] {
w.args[arg1] = argInfo{Addr: physicalAddr(arg) + offset}
}
- if arg1.Kind == ArgGroup || arg1.Kind == ArgUnion {
+ if _, ok := arg1.(*GroupArg); ok {
return
}
- if !sys.IsPad(arg1.Type) &&
- !(arg1.Kind == ArgData && len(arg1.Data) == 0) &&
- arg1.Type.Dir() != sys.DirOut {
+ if _, ok := arg1.(*UnionArg); ok {
+ return
+ }
+ if a1, ok := arg1.(*DataArg); ok && len(a1.Data) == 0 {
+ return
+ }
+ if !sys.IsPad(arg1.Type()) && arg1.Type().Dir() != sys.DirOut {
w.write(ExecInstrCopyin)
w.write(physicalAddr(arg) + offset)
w.writeArg(arg1, pid, csumMap)
@@ -118,14 +123,14 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
// Generate checksum calculation instructions starting from the last one,
// since checksum values can depend on values of the latter ones
if csumMap != nil {
- var csumArgs []*Arg
+ var csumArgs []Arg
for arg, _ := range csumMap {
csumArgs = append(csumArgs, arg)
}
sort.Sort(ByPhysicalAddr{Args: csumArgs, Context: w})
for i := len(csumArgs) - 1; i >= 0; i-- {
arg := csumArgs[i]
- if _, ok := arg.Type.(*sys.CsumType); !ok {
+ if _, ok := arg.Type().(*sys.CsumType); !ok {
panic("csum arg is not csum type")
}
w.write(ExecInstrCopyin)
@@ -162,21 +167,21 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
for _, arg := range c.Args {
w.writeArg(arg, pid, csumMap)
}
- if len(c.Ret.Uses) != 0 {
+ if len(*c.Ret.(ArgUsed).Used()) != 0 {
w.args[c.Ret] = argInfo{Idx: instrSeq}
}
instrSeq++
// Generate copyout instructions that persist interesting return values.
- foreachArg(c, func(arg, base *Arg, _ *[]*Arg) {
- if len(arg.Uses) == 0 {
+ foreachArg(c, func(arg, base Arg, _ *[]Arg) {
+ if used, ok := arg.(ArgUsed); !ok || len(*used.Used()) == 0 {
return
}
- switch arg.Kind {
- case ArgReturn:
+ switch arg.(type) {
+ case *ReturnArg:
// Idx is already assigned above.
- case ArgConst, ArgResult:
+ case *ConstArg, *ResultArg:
// Create a separate copyout instruction that has own Idx.
- if base.Kind != ArgPointer {
+ if _, ok := base.(*PointerArg); !ok {
panic("arg base is not a pointer")
}
info := w.args[arg]
@@ -198,15 +203,16 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
return nil
}
-func physicalAddr(arg *Arg) uintptr {
- if arg.Kind != ArgPointer {
+func physicalAddr(arg Arg) uintptr {
+ a, ok := arg.(*PointerArg)
+ if !ok {
panic("physicalAddr: bad arg kind")
}
- addr := arg.AddrPage*pageSize + dataOffset
- if arg.AddrOffset >= 0 {
- addr += uintptr(arg.AddrOffset)
+ addr := a.PageIndex*pageSize + dataOffset
+ if a.PageOffset >= 0 {
+ addr += uintptr(a.PageOffset)
} else {
- addr += pageSize - uintptr(-arg.AddrOffset)
+ addr += pageSize - uintptr(-a.PageOffset)
}
return addr
}
@@ -214,7 +220,7 @@ func physicalAddr(arg *Arg) uintptr {
type execContext struct {
buf []byte
eof bool
- args map[*Arg]argInfo
+ args map[Arg]argInfo
}
type argInfo struct {
@@ -238,43 +244,37 @@ func (w *execContext) write(v uintptr) {
w.buf = w.buf[8:]
}
-func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]CsumInfo) {
- switch arg.Kind {
- case ArgConst:
+func (w *execContext) writeArg(arg Arg, pid int, csumMap map[Arg]CsumInfo) {
+ switch a := arg.(type) {
+ case *ConstArg:
w.write(ExecArgConst)
- w.write(arg.Size())
- w.write(arg.Value(pid))
- w.write(arg.Type.BitfieldOffset())
- w.write(arg.Type.BitfieldLength())
- case ArgResult:
+ w.write(a.Size())
+ w.write(a.Value(pid))
+ w.write(a.Type().BitfieldOffset())
+ w.write(a.Type().BitfieldLength())
+ case *ResultArg:
w.write(ExecArgResult)
- w.write(arg.Size())
- w.write(w.args[arg.Res].Idx)
- w.write(arg.OpDiv)
- w.write(arg.OpAdd)
- case ArgPointer:
+ w.write(a.Size())
+ w.write(w.args[a.Res].Idx)
+ w.write(a.OpDiv)
+ w.write(a.OpAdd)
+ case *PointerArg:
w.write(ExecArgConst)
- w.write(arg.Size())
+ w.write(a.Size())
w.write(physicalAddr(arg))
w.write(0) // bit field offset
w.write(0) // bit field length
- case ArgPageSize:
- w.write(ExecArgConst)
- w.write(arg.Size())
- w.write(arg.AddrPage * pageSize)
- w.write(0) // bit field offset
- w.write(0) // bit field length
- case ArgData:
+ case *DataArg:
w.write(ExecArgData)
- w.write(uintptr(len(arg.Data)))
- padded := len(arg.Data)
- if pad := 8 - len(arg.Data)%8; pad != 8 {
+ w.write(uintptr(len(a.Data)))
+ padded := len(a.Data)
+ if pad := 8 - len(a.Data)%8; pad != 8 {
padded += pad
}
if len(w.buf) < padded {
w.eof = true
} else {
- copy(w.buf, arg.Data)
+ copy(w.buf, a.Data)
w.buf = w.buf[padded:]
}
default:
diff --git a/prog/mutation.go b/prog/mutation.go
index cbf1f53d5..d6a9380ae 100644
--- a/prog/mutation.go
+++ b/prog/mutation.go
@@ -68,90 +68,84 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro
arg, base := args[idx], bases[idx]
var baseSize uintptr
if base != nil {
- if base.Kind != ArgPointer || base.Res == nil {
+ b, ok := base.(*PointerArg)
+ if !ok || b.Res == nil {
panic("bad base arg")
}
- baseSize = base.Res.Size()
+ baseSize = b.Res.Size()
}
- switch a := arg.Type.(type) {
+ switch t := arg.Type().(type) {
case *sys.IntType, *sys.FlagsType:
+ a := arg.(*ConstArg)
if r.bin() {
- arg1, calls1 := r.generateArg(s, arg.Type)
+ arg1, calls1 := r.generateArg(s, arg.Type())
p.replaceArg(c, arg, arg1, calls1)
} else {
switch {
case r.nOutOf(1, 3):
- arg.Val += uintptr(r.Intn(4)) + 1
+ a.Val += uintptr(r.Intn(4)) + 1
case r.nOutOf(1, 2):
- arg.Val -= uintptr(r.Intn(4)) + 1
+ a.Val -= uintptr(r.Intn(4)) + 1
default:
- arg.Val ^= 1 << uintptr(r.Intn(64))
+ a.Val ^= 1 << uintptr(r.Intn(64))
}
}
case *sys.ResourceType, *sys.VmaType, *sys.ProcType:
- arg1, calls1 := r.generateArg(s, arg.Type)
+ arg1, calls1 := r.generateArg(s, arg.Type())
p.replaceArg(c, arg, arg1, calls1)
case *sys.BufferType:
- switch a.Kind {
+ a := arg.(*DataArg)
+ switch t.Kind {
case sys.BufferBlobRand, sys.BufferBlobRange:
var data []byte
- switch arg.Kind {
- case ArgData:
- data = append([]byte{}, arg.Data...)
- case ArgConst:
- // 0 is OK for optional args.
- if arg.Val != 0 {
- panic(fmt.Sprintf("BufferType has non-zero const value: %v", arg.Val))
- }
- default:
- panic(fmt.Sprintf("bad arg kind for BufferType: %v", arg.Kind))
- }
+ data = append([]byte{}, a.Data...)
minLen := int(0)
maxLen := math.MaxInt32
- if a.Kind == sys.BufferBlobRange {
- minLen = int(a.RangeBegin)
- maxLen = int(a.RangeEnd)
+ if t.Kind == sys.BufferBlobRange {
+ minLen = int(t.RangeBegin)
+ maxLen = int(t.RangeEnd)
}
- arg.Data = mutateData(r, data, minLen, maxLen)
+ a.Data = mutateData(r, data, minLen, maxLen)
case sys.BufferString:
if r.bin() {
minLen := int(0)
maxLen := math.MaxInt32
- if a.Length != 0 {
- minLen = int(a.Length)
- maxLen = int(a.Length)
+ if t.Length != 0 {
+ minLen = int(t.Length)
+ maxLen = int(t.Length)
}
- arg.Data = mutateData(r, append([]byte{}, arg.Data...), minLen, maxLen)
+ a.Data = mutateData(r, append([]byte{}, a.Data...), minLen, maxLen)
} else {
- arg.Data = r.randString(s, a.Values, a.Dir())
+ a.Data = r.randString(s, t.Values, t.Dir())
}
case sys.BufferFilename:
- arg.Data = []byte(r.filename(s))
+ a.Data = []byte(r.filename(s))
case sys.BufferText:
- arg.Data = r.mutateText(a.Text, arg.Data)
+ a.Data = r.mutateText(t.Text, a.Data)
default:
panic("unknown buffer kind")
}
case *sys.ArrayType:
+ a := arg.(*GroupArg)
count := uintptr(0)
- switch a.Kind {
+ switch t.Kind {
case sys.ArrayRandLen:
- for count == uintptr(len(arg.Inner)) {
+ for count == uintptr(len(a.Inner)) {
count = r.randArrayLen()
}
case sys.ArrayRangeLen:
- if a.RangeBegin == a.RangeEnd {
+ if t.RangeBegin == t.RangeEnd {
panic("trying to mutate fixed length array")
}
- for count == uintptr(len(arg.Inner)) {
- count = r.randRange(int(a.RangeBegin), int(a.RangeEnd))
+ for count == uintptr(len(a.Inner)) {
+ count = r.randRange(int(t.RangeBegin), int(t.RangeEnd))
}
}
- if count > uintptr(len(arg.Inner)) {
+ if count > uintptr(len(a.Inner)) {
var calls []*Call
- for count > uintptr(len(arg.Inner)) {
- arg1, calls1 := r.generateArg(s, a.Type)
- arg.Inner = append(arg.Inner, arg1)
+ for count > uintptr(len(a.Inner)) {
+ arg1, calls1 := r.generateArg(s, t.Type)
+ a.Inner = append(a.Inner, arg1)
for _, c1 := range calls1 {
calls = append(calls, c1)
s.analyze(c1)
@@ -162,43 +156,48 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro
}
sanitizeCall(c)
p.insertBefore(c, calls)
- } else if count < uintptr(len(arg.Inner)) {
- for _, arg := range arg.Inner[count:] {
+ } else if count < uintptr(len(a.Inner)) {
+ for _, arg := range a.Inner[count:] {
p.removeArg(c, arg)
}
- arg.Inner = arg.Inner[:count]
+ a.Inner = a.Inner[:count]
}
// TODO: swap elements of the array
case *sys.PtrType:
+ a, ok := arg.(*PointerArg)
+ if !ok {
+ break
+ }
// TODO: we don't know size for out args
size := uintptr(1)
- if arg.Res != nil {
- size = arg.Res.Size()
+ if a.Res != nil {
+ size = a.Res.Size()
}
- arg1, calls1 := r.addr(s, a, size, arg.Res)
+ arg1, calls1 := r.addr(s, t, size, a.Res)
p.replaceArg(c, arg, arg1, calls1)
case *sys.StructType:
- ctor := isSpecialStruct(a)
+ ctor := isSpecialStruct(t)
if ctor == nil {
panic("bad arg returned by mutationArgs: StructType")
}
arg1, calls1 := ctor(r, s)
- for i, f := range arg1.Inner {
- p.replaceArg(c, arg.Inner[i], f, calls1)
+ for i, f := range arg1.(*GroupArg).Inner {
+ p.replaceArg(c, arg.(*GroupArg).Inner[i], f, calls1)
calls1 = nil
}
case *sys.UnionType:
- optType := a.Options[r.Intn(len(a.Options))]
+ a := arg.(*UnionArg)
+ optType := t.Options[r.Intn(len(t.Options))]
maxIters := 1000
- for i := 0; optType.FieldName() == arg.OptionType.FieldName(); i++ {
- optType = a.Options[r.Intn(len(a.Options))]
+ for i := 0; optType.FieldName() == a.OptionType.FieldName(); i++ {
+ optType = t.Options[r.Intn(len(t.Options))]
if i >= maxIters {
- panic(fmt.Sprintf("couldn't generate a different union option after %v iterations, type: %+v", maxIters, a))
+ panic(fmt.Sprintf("couldn't generate a different union option after %v iterations, type: %+v", maxIters, t))
}
}
- p.removeArg(c, arg.Option)
+ p.removeArg(c, a.Option)
opt, calls := r.generateArg(s, optType)
- arg1 := unionArg(a, opt, optType)
+ arg1 := unionArg(t, opt, optType)
p.replaceArg(c, arg, arg1, calls)
case *sys.LenType:
panic("bad arg returned by mutationArgs: LenType")
@@ -207,19 +206,23 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro
case *sys.ConstType:
panic("bad arg returned by mutationArgs: ConstType")
default:
- panic(fmt.Sprintf("bad arg returned by mutationArgs: %#v, type=%#v", *arg, arg.Type))
+ panic(fmt.Sprintf("bad arg returned by mutationArgs: %#v, type=%#v", arg, arg.Type()))
}
// Update base pointer if size has increased.
- if base != nil && baseSize < base.Res.Size() {
- arg1, calls1 := r.addr(s, base.Type, base.Res.Size(), base.Res)
- for _, c1 := range calls1 {
- sanitizeCall(c1)
+ if base != nil {
+ b := base.(*PointerArg)
+ if baseSize < b.Res.Size() {
+ arg1, calls1 := r.addr(s, b.Type(), b.Res.Size(), b.Res)
+ for _, c1 := range calls1 {
+ sanitizeCall(c1)
+ }
+ p.insertBefore(c, calls1)
+ a1 := arg1.(*PointerArg)
+ b.PageIndex = a1.PageIndex
+ b.PageOffset = a1.PageOffset
+ b.PagesNum = a1.PagesNum
}
- p.insertBefore(c, calls1)
- arg.AddrPage = arg1.AddrPage
- arg.AddrOffset = arg1.AddrOffset
- arg.AddrPagesNum = arg1.AddrPagesNum
}
// Update all len fields.
@@ -313,33 +316,41 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
var triedPaths map[string]bool
- var rec func(p *Prog, call *Call, arg *Arg, path string) bool
- rec = func(p *Prog, call *Call, arg *Arg, path string) bool {
- path += fmt.Sprintf("-%v", arg.Type.FieldName())
- switch typ := arg.Type.(type) {
+ var rec func(p *Prog, call *Call, arg Arg, path string) bool
+ rec = func(p *Prog, call *Call, arg Arg, path string) bool {
+ path += fmt.Sprintf("-%v", arg.Type().FieldName())
+ switch typ := arg.Type().(type) {
case *sys.StructType:
- for _, innerArg := range arg.Inner {
+ a := arg.(*GroupArg)
+ for _, innerArg := range a.Inner {
if rec(p, call, innerArg, path) {
return true
}
}
case *sys.UnionType:
- if rec(p, call, arg.Option, path) {
+ a := arg.(*UnionArg)
+ if rec(p, call, a.Option, path) {
return true
}
case *sys.PtrType:
// TODO: try to remove optional ptrs
- if arg.Res != nil {
- return rec(p, call, arg.Res, path)
+ a, ok := arg.(*PointerArg)
+ if !ok {
+ // Can also be *ConstArg.
+ return false
+ }
+ if a.Res != nil {
+ return rec(p, call, a.Res, path)
}
case *sys.ArrayType:
- for i, innerArg := range arg.Inner {
+ a := arg.(*GroupArg)
+ for i, innerArg := range a.Inner {
innerPath := fmt.Sprintf("%v-%v", path, i)
if !triedPaths[innerPath] && !crash {
- if (typ.Kind == sys.ArrayRangeLen && len(arg.Inner) > int(typ.RangeBegin)) ||
+ if (typ.Kind == sys.ArrayRangeLen && len(a.Inner) > int(typ.RangeBegin)) ||
(typ.Kind == sys.ArrayRandLen) {
- copy(arg.Inner[i:], arg.Inner[i+1:])
- arg.Inner = arg.Inner[:len(arg.Inner)-1]
+ copy(a.Inner[i:], a.Inner[i+1:])
+ a.Inner = a.Inner[:len(a.Inner)-1]
p.removeArg(call, innerArg)
assignSizesCall(call)
@@ -356,7 +367,7 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
return true
}
}
- case *sys.IntType, *sys.FlagsType, *sys.ResourceType, *sys.ProcType:
+ case *sys.IntType, *sys.FlagsType, *sys.ProcType:
// TODO: try to reset bits in ints
// TODO: try to set separate flags
if crash {
@@ -366,16 +377,39 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
return false
}
triedPaths[path] = true
- if arg.Val == typ.Default() {
+ a := arg.(*ConstArg)
+ if a.Val == typ.Default() {
+ return false
+ }
+ v0 := a.Val
+ a.Val = typ.Default()
+ if pred(p, callIndex0) {
+ p0 = p
+ return true
+ } else {
+ a.Val = v0
+ }
+ case *sys.ResourceType:
+ if crash {
+ return false
+ }
+ if triedPaths[path] {
+ return false
+ }
+ triedPaths[path] = true
+ a := arg.(*ResultArg)
+ if a.Res == nil {
return false
}
- v0 := arg.Val
- arg.Val = typ.Default()
+ r0 := a.Res
+ a.Res = nil
+ a.Val = typ.Default()
if pred(p, callIndex0) {
p0 = p
return true
} else {
- arg.Val = v0
+ a.Res = r0
+ a.Val = 0
}
case *sys.BufferType:
// TODO: try to set individual bytes to 0
@@ -386,15 +420,16 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
if typ.Kind != sys.BufferBlobRand && typ.Kind != sys.BufferBlobRange {
return false
}
+ a := arg.(*DataArg)
minLen := int(typ.RangeBegin)
- for step := len(arg.Data) - minLen; len(arg.Data) > minLen && step > 0; {
- if len(arg.Data)-step >= minLen {
- arg.Data = arg.Data[:len(arg.Data)-step]
+ for step := len(a.Data) - minLen; len(a.Data) > minLen && step > 0; {
+ if len(a.Data)-step >= minLen {
+ a.Data = a.Data[:len(a.Data)-step]
assignSizesCall(call)
if pred(p, callIndex0) {
continue
}
- arg.Data = arg.Data[:len(arg.Data)+step]
+ a.Data = a.Data[:len(a.Data)+step]
assignSizesCall(call)
}
step /= 2
@@ -440,18 +475,20 @@ func (p *Prog) TrimAfter(idx int) {
}
for i := len(p.Calls) - 1; i > idx; i-- {
c := p.Calls[i]
- foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
- if arg.Kind == ArgResult {
- delete(arg.Res.Uses, arg)
+ foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
+ if a, ok := arg.(*ResultArg); ok && a.Res != nil {
+ if used, ok := a.Res.(ArgUsed); ok {
+ delete(*used.Used(), arg)
+ }
}
})
}
p.Calls = p.Calls[:idx+1]
}
-func mutationArgs(c *Call) (args, bases []*Arg) {
- foreachArg(c, func(arg, base *Arg, _ *[]*Arg) {
- switch typ := arg.Type.(type) {
+func mutationArgs(c *Call) (args, bases []Arg) {
+ foreachArg(c, func(arg, base Arg, _ *[]Arg) {
+ switch typ := arg.Type().(type) {
case *sys.StructType:
if isSpecialStruct(typ) == nil {
// For structs only individual fields are updated.
@@ -477,11 +514,11 @@ func mutationArgs(c *Call) (args, bases []*Arg) {
return // string const
}
}
- if arg.Type.Dir() == sys.DirOut {
+ if arg.Type().Dir() == sys.DirOut {
return
}
if base != nil {
- if _, ok := base.Type.(*sys.StructType); ok && isSpecialStruct(base.Type) != nil {
+ if _, ok := base.Type().(*sys.StructType); ok && isSpecialStruct(base.Type()) != nil {
// These special structs are mutated as a whole.
return
}
diff --git a/prog/mutation_test.go b/prog/mutation_test.go
index 69f2fd974..57b049376 100644
--- a/prog/mutation_test.go
+++ b/prog/mutation_test.go
@@ -50,10 +50,10 @@ func TestMutateTable(t *testing.T) {
tests := [][2]string{
// Insert calls.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"pipe2(&(0x7f0000000000)={0x0, 0x0}, 0x0)\n",
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"sched_yield()\n" +
"pipe2(&(0x7f0000000000)={0x0, 0x0}, 0x0)\n",
},
@@ -119,7 +119,7 @@ func TestMutateTable(t *testing.T) {
"r0 = open(&(0x7f0000001000)=\"2e2f66696c653000\", 0x22c0, 0x1)\n" +
"readv(r0, &(0x7f0000000000)=[{&(0x7f0000001000)=\"00\", 0x1}, {&(0x7f0000002000)=\"00\", 0x2}], 0x2)\n",
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"r0 = open(&(0x7f0000001000)=\"2e2f66696c653000\", 0x22c0, 0x1)\n" +
"readv(r0, &(0x7f0000000000)=[{&(0x7f0000001000)=\"00\", 0x1}, {&(0x7f0000002000)=\"00\", 0x2}, {&(0x7f0000000000)=\"00\", 0x3}], 0x3)\n",
},
@@ -158,7 +158,7 @@ func TestMinimize(t *testing.T) {
}{
// Predicate always returns false, so must get the same program.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"sched_yield()\n" +
"pipe2(&(0x7f0000000000)={0x0, 0x0}, 0x0)\n",
2,
@@ -171,14 +171,14 @@ func TestMinimize(t *testing.T) {
}
return false
},
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"sched_yield()\n" +
"pipe2(&(0x7f0000000000)={0x0, 0x0}, 0x0)\n",
2,
},
// Remove a call.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"sched_yield()\n" +
"pipe2(&(0x7f0000000000)={0xffffffffffffffff, 0xffffffffffffffff}, 0x0)\n",
2,
@@ -186,13 +186,13 @@ func TestMinimize(t *testing.T) {
// Aim at removal of sched_yield.
return len(p.Calls) == 2 && p.Calls[0].Meta.Name == "mmap" && p.Calls[1].Meta.Name == "pipe2"
},
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
"pipe2(&(0x7f0000000000)={0xffffffffffffffff, 0xffffffffffffffff}, 0x0)\n",
1,
},
// Remove two dependent calls.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"pipe2(&(0x7f0000000000)={0x0, 0x0}, 0x0)\n" +
"sched_yield()\n",
2,
@@ -211,7 +211,7 @@ func TestMinimize(t *testing.T) {
},
// Remove a call and replace results.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"pipe2(&(0x7f0000000000)={<r0=>0x0, 0x0}, 0x0)\n" +
"write(r0, &(0x7f0000000000)=\"1155\", 0x2)\n" +
"sched_yield()\n",
@@ -219,14 +219,14 @@ func TestMinimize(t *testing.T) {
func(p *Prog, callIndex int) bool {
return p.String() == "mmap-write-sched_yield"
},
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
"write(0xffffffffffffffff, &(0x7f0000000000)=\"\", 0x0)\n" +
"sched_yield()\n",
2,
},
// Remove a call and replace results.
{
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"r0=open(&(0x7f0000000000)=\"1155\", 0x0, 0x0)\n" +
"write(r0, &(0x7f0000000000)=\"1155\", 0x2)\n" +
"sched_yield()\n",
@@ -234,7 +234,7 @@ func TestMinimize(t *testing.T) {
func(p *Prog, callIndex int) bool {
return p.String() == "mmap-write-sched_yield"
},
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
"write(0xffffffffffffffff, &(0x7f0000000000)=\"\", 0x0)\n" +
"sched_yield()\n",
-1,
@@ -242,15 +242,15 @@ func TestMinimize(t *testing.T) {
// Glue several mmaps together.
{
"sched_yield()\n" +
- "mmap(&(0x7f0000000000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
- "mmap(&(0x7f0000001000/0x1000)=nil, (0x1000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000001000/0x1000)=nil, 0x1000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n" +
"getpid()\n" +
- "mmap(&(0x7f0000005000/0x5000)=nil, (0x2000), 0x3, 0x32, 0xffffffffffffffff, 0x0)\n",
+ "mmap(&(0x7f0000005000/0x5000)=nil, 0x2000, 0x3, 0x32, 0xffffffffffffffff, 0x0)\n",
3,
func(p *Prog, callIndex int) bool {
return p.String() == "mmap-sched_yield-getpid"
},
- "mmap(&(0x7f0000000000/0x7000)=nil, (0x7000), 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
+ "mmap(&(0x7f0000000000/0x7000)=nil, 0x7000, 0x0, 0x0, 0xffffffffffffffff, 0x0)\n" +
"sched_yield()\n" +
"getpid()\n",
2,
diff --git a/prog/prog.go b/prog/prog.go
index 57c5648ca..98121e781 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -15,113 +15,98 @@ type Prog struct {
type Call struct {
Meta *sys.Call
- Args []*Arg
- Ret *Arg
-}
-
-type Arg struct {
- Type sys.Type
- Kind ArgKind
- Val uintptr // value of ArgConst
- AddrPage uintptr // page index for ArgPointer address, page count for ArgPageSize
- AddrOffset int // page offset for ArgPointer address
- AddrPagesNum uintptr // number of available pages for ArgPointer
- Data []byte // data of ArgData
- Inner []*Arg // subargs of ArgGroup
- Res *Arg // target of ArgResult, pointee for ArgPointer
- Uses map[*Arg]bool // this arg is used by those ArgResult args
- OpDiv uintptr // divide result for ArgResult (executed before OpAdd)
- OpAdd uintptr // add to result for ArgResult
-
- // ArgUnion/UnionType
- Option *Arg
- OptionType sys.Type
+ Args []Arg
+ Ret Arg
}
-type ArgKind int
+type Arg interface {
+ Type() sys.Type
+ Size() uintptr
+}
-const (
- ArgConst ArgKind = iota
- ArgResult
- ArgPointer // even if these are always constant (for reproducibility), we use a separate type because they are represented in an abstract (base+page+offset) form
- ArgPageSize // same as ArgPointer but base is not added, so it represents "lengths" in pages
- ArgData
- ArgGroup // logical group of args (struct or array)
- ArgUnion
- ArgReturn // fake value denoting syscall return value
-)
+type ArgCommon struct {
+ typ sys.Type
+}
-// Returns inner arg for PtrType args
-func (a *Arg) InnerArg() *Arg {
- switch typ := a.Type.(type) {
- case *sys.PtrType:
- if a.Res == nil {
- if !typ.Optional() {
- panic(fmt.Sprintf("non-optional pointer is nil\narg: %+v\ntype: %+v", a, typ))
- }
- return nil
- } else {
- return a.Res.InnerArg()
- }
- default:
- return a
- }
+func (arg *ArgCommon) Type() sys.Type {
+ return arg.typ
}
-func encodeValue(value, size uintptr, bigEndian bool) uintptr {
- if !bigEndian {
- return value
- }
- switch size {
- case 2:
- return uintptr(swap16(uint16(value)))
- case 4:
- return uintptr(swap32(uint32(value)))
- case 8:
- return uintptr(swap64(uint64(value)))
- default:
- panic(fmt.Sprintf("bad size %v for value %v", size, value))
- }
+// Used for ConstType, IntType, FlagsType, LenType, ProcType and CsumType.
+type ConstArg struct {
+ ArgCommon
+ Val uintptr
}
-// Returns value taking endianness into consideration.
-func (a *Arg) Value(pid int) uintptr {
- switch typ := a.Type.(type) {
+func (arg *ConstArg) Size() uintptr {
+ return arg.typ.Size()
+}
+
+// Returns value taking endianness and executor pid into consideration.
+func (arg *ConstArg) Value(pid int) uintptr {
+ switch typ := (*arg).Type().(type) {
case *sys.IntType:
- return encodeValue(a.Val, typ.Size(), typ.BigEndian)
+ return encodeValue(arg.Val, typ.Size(), typ.BigEndian)
case *sys.ConstType:
- return encodeValue(a.Val, typ.Size(), typ.BigEndian)
+ return encodeValue(arg.Val, typ.Size(), typ.BigEndian)
case *sys.FlagsType:
- return encodeValue(a.Val, typ.Size(), typ.BigEndian)
+ return encodeValue(arg.Val, typ.Size(), typ.BigEndian)
case *sys.LenType:
- return encodeValue(a.Val, typ.Size(), typ.BigEndian)
+ return encodeValue(arg.Val, typ.Size(), typ.BigEndian)
case *sys.CsumType:
// Checksums are computed dynamically in executor.
return 0
case *sys.ResourceType:
if t, ok := typ.Desc.Type.(*sys.IntType); ok {
- return encodeValue(a.Val, t.Size(), t.BigEndian)
+ return encodeValue(arg.Val, t.Size(), t.BigEndian)
} else {
panic(fmt.Sprintf("bad base type for a resource: %v", t))
}
case *sys.ProcType:
- val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + a.Val
+ val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + arg.Val
return encodeValue(val, typ.Size(), typ.BigEndian)
}
- return a.Val
+ return arg.Val
+}
+
+// Used for PtrType and VmaType.
+// Even if these are always constant (for reproducibility), we use a separate
+// type because they are represented in an abstract (base+page+offset) form.
+type PointerArg struct {
+ ArgCommon
+ PageIndex uintptr
+ PageOffset int // offset within a page
+ PagesNum uintptr // number of available pages
+ Res Arg // pointee
+}
+
+func (arg *PointerArg) Size() uintptr {
+ return arg.typ.Size()
+}
+
+// Used for BufferType.
+type DataArg struct {
+ ArgCommon
+ Data []byte
+}
+
+func (arg *DataArg) Size() uintptr {
+ return uintptr(len(arg.Data))
+}
+
+// Used for StructType and ArrayType.
+// Logical group of args (struct or array).
+type GroupArg struct {
+ ArgCommon
+ Inner []Arg
}
-func (a *Arg) Size() uintptr {
- switch typ := a.Type.(type) {
- case *sys.IntType, *sys.LenType, *sys.FlagsType, *sys.ConstType,
- *sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType, *sys.CsumType:
- return typ.Size()
- case *sys.BufferType:
- return uintptr(len(a.Data))
+func (arg *GroupArg) Size() uintptr {
+ switch typ := (*arg).Type().(type) {
case *sys.StructType:
var size uintptr
- for _, fld := range a.Inner {
- if fld.Type.BitfieldLength() == 0 || fld.Type.BitfieldLast() {
+ for _, fld := range arg.Inner {
+ if fld.Type().BitfieldLength() == 0 || fld.Type().BitfieldLast() {
size += fld.Size()
}
}
@@ -130,68 +115,155 @@ func (a *Arg) Size() uintptr {
if typ.Varlen() {
size += align - size%align
} else {
- panic(fmt.Sprintf("struct %+v with type %+v has static size %v, which isn't aligned to %v", a, typ, size, align))
+ panic(fmt.Sprintf("struct %+v with type %+v has static size %v, which isn't aligned to %v", arg, typ, size, align))
}
}
return size
- case *sys.UnionType:
- if !typ.Varlen() {
- return typ.Size()
- } else {
- return a.Option.Size()
- }
case *sys.ArrayType:
var size uintptr
- for _, in := range a.Inner {
+ for _, in := range arg.Inner {
size += in.Size()
}
return size
default:
- panic("unknown arg type")
+ panic(fmt.Sprintf("bad group arg type %v", typ))
}
}
-func constArg(t sys.Type, v uintptr) *Arg {
- return &Arg{Type: t, Kind: ArgConst, Val: v}
+// Used for UnionType.
+type UnionArg struct {
+ ArgCommon
+ Option Arg
+ OptionType sys.Type
}
-func resultArg(t sys.Type, r *Arg) *Arg {
- arg := &Arg{Type: t, Kind: ArgResult, Res: r}
- if r.Uses == nil {
- r.Uses = make(map[*Arg]bool)
+func (arg *UnionArg) Size() uintptr {
+ if !arg.Type().Varlen() {
+ return arg.Type().Size()
+ } else {
+ return arg.Option.Size()
}
- if r.Uses[arg] {
- panic("already used")
- }
- r.Uses[arg] = true
- return arg
}
-func dataArg(t sys.Type, data []byte) *Arg {
- return &Arg{Type: t, Kind: ArgData, Data: append([]byte{}, data...)}
+// Used for ResourceType.
+// Either holds constant value or reference another ResultArg or ReturnArg.
+type ResultArg struct {
+ ArgCommon
+ Res Arg // reference to arg which we use
+ OpDiv uintptr // divide result (executed before OpAdd)
+ OpAdd uintptr // add to result
+ Val uintptr // value used if Res is nil
+ uses map[Arg]bool // ArgResult args that use this arg
+}
+
+func (arg *ResultArg) Size() uintptr {
+ return arg.typ.Size()
+}
+
+// Used for ResourceType and VmaType.
+// This argument denotes syscall return value.
+type ReturnArg struct {
+ ArgCommon
+ uses map[Arg]bool // ArgResult args that use this arg
+}
+
+func (arg *ReturnArg) Size() uintptr {
+ panic("not called")
}
-func pointerArg(t sys.Type, page uintptr, off int, npages uintptr, obj *Arg) *Arg {
- return &Arg{Type: t, Kind: ArgPointer, AddrPage: page, AddrOffset: off, AddrPagesNum: npages, Res: obj}
+type ArgUsed interface {
+ Used() *map[Arg]bool
}
-func pageSizeArg(t sys.Type, npages uintptr, off int) *Arg {
- return &Arg{Type: t, Kind: ArgPageSize, AddrPage: npages, AddrOffset: off}
+func (arg *ResultArg) Used() *map[Arg]bool {
+ return &arg.uses
}
-func groupArg(t sys.Type, inner []*Arg) *Arg {
- return &Arg{Type: t, Kind: ArgGroup, Inner: inner}
+func (arg *ReturnArg) Used() *map[Arg]bool {
+ return &arg.uses
}
-func unionArg(t sys.Type, opt *Arg, typ sys.Type) *Arg {
- return &Arg{Type: t, Kind: ArgUnion, Option: opt, OptionType: typ}
+type ArgUser interface {
+ Uses() *Arg
}
-func returnArg(t sys.Type) *Arg {
- if t != nil {
- return &Arg{Type: t, Kind: ArgReturn, Val: t.Default()}
+func (arg *ResultArg) Uses() *Arg {
+ return &arg.Res
+}
+
+// Returns inner arg for pointer args.
+func InnerArg(arg Arg) Arg {
+ if t, ok := arg.Type().(*sys.PtrType); ok {
+ if a, ok := arg.(*PointerArg); ok {
+ if a.Res == nil {
+ if !t.Optional() {
+ panic(fmt.Sprintf("non-optional pointer is nil\narg: %+v\ntype: %+v", a, t))
+ }
+ return nil
+ } else {
+ return InnerArg(a.Res)
+ }
+ }
+ return nil // *ConstArg.
+ }
+ return arg // Not a pointer.
+}
+
+func encodeValue(value, size uintptr, bigEndian bool) uintptr {
+ if !bigEndian {
+ return value
+ }
+ switch size {
+ case 2:
+ return uintptr(swap16(uint16(value)))
+ case 4:
+ return uintptr(swap32(uint32(value)))
+ case 8:
+ return uintptr(swap64(uint64(value)))
+ default:
+ panic(fmt.Sprintf("bad size %v for value %v", size, value))
+ }
+}
+
+func constArg(t sys.Type, v uintptr) Arg {
+ return &ConstArg{ArgCommon: ArgCommon{typ: t}, Val: v}
+}
+
+func resultArg(t sys.Type, r Arg, v uintptr) Arg {
+ arg := &ResultArg{ArgCommon: ArgCommon{typ: t}, Res: r, Val: v}
+ if r == nil {
+ return arg
+ }
+ if used, ok := r.(ArgUsed); ok {
+ if *used.Used() == nil {
+ *used.Used() = make(map[Arg]bool)
+ }
+ if (*used.Used())[arg] {
+ panic("already used")
+ }
+ (*used.Used())[arg] = true
}
- return &Arg{Type: t, Kind: ArgReturn}
+ return arg
+}
+
+func dataArg(t sys.Type, data []byte) Arg {
+ return &DataArg{ArgCommon: ArgCommon{typ: t}, Data: append([]byte{}, data...)}
+}
+
+func pointerArg(t sys.Type, page uintptr, off int, npages uintptr, obj Arg) Arg {
+ return &PointerArg{ArgCommon: ArgCommon{typ: t}, PageIndex: page, PageOffset: off, PagesNum: npages, Res: obj}
+}
+
+func groupArg(t sys.Type, inner []Arg) Arg {
+ return &GroupArg{ArgCommon: ArgCommon{typ: t}, Inner: inner}
+}
+
+func unionArg(t sys.Type, opt Arg, typ sys.Type) Arg {
+ return &UnionArg{ArgCommon: ArgCommon{typ: t}, Option: opt, OptionType: typ}
+}
+
+func returnArg(t sys.Type) Arg {
+ return &ReturnArg{ArgCommon: ArgCommon{typ: t}}
}
func (p *Prog) insertBefore(c *Call, calls []*Call) {
@@ -212,46 +284,55 @@ func (p *Prog) insertBefore(c *Call, calls []*Call) {
}
// replaceArg replaces arg with arg1 in call c in program p, and inserts calls before arg call.
-func (p *Prog) replaceArg(c *Call, arg, arg1 *Arg, calls []*Call) {
- if arg.Kind != ArgConst && arg.Kind != ArgResult && arg.Kind != ArgPointer && arg.Kind != ArgUnion {
- panic(fmt.Sprintf("replaceArg: bad arg kind %v", arg.Kind))
- }
- if arg1.Kind != ArgConst && arg1.Kind != ArgResult && arg1.Kind != ArgPointer && arg.Kind != ArgUnion {
- panic(fmt.Sprintf("replaceArg: bad arg1 kind %v", arg1.Kind))
- }
- if arg.Kind == ArgResult {
- delete(arg.Res.Uses, arg)
- }
+func (p *Prog) replaceArg(c *Call, arg, arg1 Arg, calls []*Call) {
for _, c := range calls {
sanitizeCall(c)
}
p.insertBefore(c, calls)
- // Somewhat hacky, but safe and preserves references to arg.
- uses := arg.Uses
- *arg = *arg1
- arg.Uses = uses
- if arg.Kind == ArgResult {
- delete(arg.Res.Uses, arg1)
- arg.Res.Uses[arg] = true
+ switch a := arg.(type) {
+ case *ConstArg:
+ *a = *arg1.(*ConstArg)
+ case *ResultArg:
+ // Remove link from `a.Res` to `arg`.
+ if a.Res != nil {
+ delete(*a.Res.(ArgUsed).Used(), arg)
+ }
+ // Copy all fields from `arg1` to `arg` except for the list of args that use `arg`.
+ used := *arg.(ArgUsed).Used()
+ *a = *arg1.(*ResultArg)
+ *arg.(ArgUsed).Used() = used
+ // Make the link in `a.Res` (which is now `Res` of `arg1`) to point to `arg` instead of `arg1`.
+ if a.Res != nil {
+ delete(*a.Res.(ArgUsed).Used(), arg1)
+ (*a.Res.(ArgUsed).Used())[arg] = true
+ }
+ case *PointerArg:
+ *a = *arg1.(*PointerArg)
+ case *UnionArg:
+ *a = *arg1.(*UnionArg)
+ default:
+ panic(fmt.Sprintf("replaceArg: bad arg kind %v", arg))
}
sanitizeCall(c)
}
// removeArg removes all references to/from arg0 of call c from p.
-func (p *Prog) removeArg(c *Call, arg0 *Arg) {
- foreachSubarg(arg0, func(arg, _ *Arg, _ *[]*Arg) {
- if arg.Kind == ArgResult {
- if _, ok := arg.Res.Uses[arg]; !ok {
+func (p *Prog) removeArg(c *Call, arg0 Arg) {
+ foreachSubarg(arg0, func(arg, _ Arg, _ *[]Arg) {
+ if a, ok := arg.(*ResultArg); ok && a.Res != nil {
+ if _, ok := (*a.Res.(ArgUsed).Used())[arg]; !ok {
panic("broken tree")
}
- delete(arg.Res.Uses, arg)
+ delete(*a.Res.(ArgUsed).Used(), arg)
}
- for arg1 := range arg.Uses {
- if arg1.Kind != ArgResult {
- panic("use references not ArgResult")
+ if used, ok := arg.(ArgUsed); ok {
+ for arg1 := range *used.Used() {
+ if _, ok := arg1.(*ResultArg); !ok {
+ panic("use references not ArgResult")
+ }
+ arg2 := resultArg(arg1.Type(), nil, arg1.Type().Default())
+ p.replaceArg(c, arg1, arg2, nil)
}
- arg2 := constArg(arg1.Type, arg1.Type.Default())
- p.replaceArg(c, arg1, arg2, nil)
}
})
}
diff --git a/prog/prog_test.go b/prog/prog_test.go
index b2655ff25..707c7133a 100644
--- a/prog/prog_test.go
+++ b/prog/prog_test.go
@@ -71,18 +71,20 @@ func TestVmaType(t *testing.T) {
if len(c.Args) != 6 {
t.Fatalf("generated wrong number of args %v", len(c.Args))
}
- check := func(v, l *Arg, min, max uintptr) {
- if v.Kind != ArgPointer {
- t.Fatalf("vma has bad type: %v, want %v", v.Kind, ArgPointer)
+ check := func(v, l Arg, min, max uintptr) {
+ va, ok := v.(*PointerArg)
+ if !ok {
+ t.Fatalf("vma has bad type: %v", v)
}
- if l.Kind != ArgPageSize {
- t.Fatalf("len has bad type: %v, want %v", l.Kind, ArgPageSize)
+ la, ok := l.(*ConstArg)
+ if !ok {
+ t.Fatalf("len has bad type: %v", l)
}
- if v.AddrPagesNum < min || v.AddrPagesNum > max {
- t.Fatalf("vma has bad number of pages: %v, want [%v-%v]", v.AddrPagesNum, min, max)
+ if va.PagesNum < min || va.PagesNum > max {
+ t.Fatalf("vma has bad number of pages: %v, want [%v-%v]", va.PagesNum, min, max)
}
- if l.AddrPage < min || l.AddrPage > max {
- t.Fatalf("len has bad number of pages: %v, want [%v-%v]", l.AddrPage, min, max)
+ if la.Val/pageSize < min || la.Val/pageSize > max {
+ t.Fatalf("len has bad number of pages: %v, want [%v-%v]", la.Val/pageSize, min, max)
}
}
check(c.Args[0], c.Args[1], 1, 1e5)
diff --git a/prog/rand.go b/prog/rand.go
index ebebc79f9..d6bcecb70 100644
--- a/prog/rand.go
+++ b/prog/rand.go
@@ -223,34 +223,34 @@ func (r *randGen) randStringImpl(s *state, vals []string) []byte {
return buf.Bytes()
}
-func isSpecialStruct(typ sys.Type) func(r *randGen, s *state) (*Arg, []*Call) {
+func isSpecialStruct(typ sys.Type) func(r *randGen, s *state) (Arg, []*Call) {
a, ok := typ.(*sys.StructType)
if !ok {
panic("must be a struct")
}
switch typ.Name() {
case "timespec":
- return func(r *randGen, s *state) (*Arg, []*Call) {
+ return func(r *randGen, s *state) (Arg, []*Call) {
return r.timespec(s, a, false)
}
case "timeval":
- return func(r *randGen, s *state) (*Arg, []*Call) {
+ return func(r *randGen, s *state) (Arg, []*Call) {
return r.timespec(s, a, true)
}
}
return nil
}
-func (r *randGen) timespec(s *state, typ *sys.StructType, usec bool) (arg *Arg, calls []*Call) {
+func (r *randGen) timespec(s *state, typ *sys.StructType, usec bool) (arg Arg, calls []*Call) {
// We need to generate timespec/timeval that are either (1) definitely in the past,
// or (2) definitely in unreachable fututre, or (3) few ms ahead of now.
// Note timespec/timeval can be absolute or relative to now.
switch {
case r.nOutOf(1, 4):
// now for relative, past for absolute
- arg = groupArg(typ, []*Arg{
- constArg(typ.Fields[0], 0),
- constArg(typ.Fields[1], 0),
+ arg = groupArg(typ, []Arg{
+ resultArg(typ.Fields[0], nil, 0),
+ resultArg(typ.Fields[1], nil, 0),
})
case r.nOutOf(1, 3):
// few ms ahead for relative, past for absolute
@@ -258,45 +258,45 @@ func (r *randGen) timespec(s *state, typ *sys.StructType, usec bool) (arg *Arg,
if usec {
nsec /= 1e3
}
- arg = groupArg(typ, []*Arg{
- constArg(typ.Fields[0], 0),
- constArg(typ.Fields[1], nsec),
+ arg = groupArg(typ, []Arg{
+ resultArg(typ.Fields[0], nil, 0),
+ resultArg(typ.Fields[1], nil, nsec),
})
case r.nOutOf(1, 2):
// unreachable fututre for both relative and absolute
- arg = groupArg(typ, []*Arg{
- constArg(typ.Fields[0], 2e9),
- constArg(typ.Fields[1], 0),
+ arg = groupArg(typ, []Arg{
+ resultArg(typ.Fields[0], nil, 2e9),
+ resultArg(typ.Fields[1], nil, 0),
})
default:
// few ms ahead for absolute
meta := sys.CallMap["clock_gettime"]
ptrArgType := meta.Args[1].(*sys.PtrType)
argType := ptrArgType.Type.(*sys.StructType)
- tp := groupArg(argType, []*Arg{
- constArg(argType.Fields[0], 0),
- constArg(argType.Fields[1], 0),
+ tp := groupArg(argType, []Arg{
+ resultArg(argType.Fields[0], nil, 0),
+ resultArg(argType.Fields[1], nil, 0),
})
- var tpaddr *Arg
+ var tpaddr Arg
tpaddr, calls = r.addr(s, ptrArgType, 2*ptrSize, tp)
gettime := &Call{
Meta: meta,
- Args: []*Arg{
+ Args: []Arg{
constArg(meta.Args[0], sys.CLOCK_REALTIME),
tpaddr,
},
Ret: returnArg(meta.Ret),
}
calls = append(calls, gettime)
- sec := resultArg(typ.Fields[0], tp.Inner[0])
- nsec := resultArg(typ.Fields[1], tp.Inner[1])
+ sec := resultArg(typ.Fields[0], tp.(*GroupArg).Inner[0], 0)
+ nsec := resultArg(typ.Fields[1], tp.(*GroupArg).Inner[1], 0)
if usec {
- nsec.OpDiv = 1e3
- nsec.OpAdd = 10 * 1e3
+ nsec.(*ResultArg).OpDiv = 1e3
+ nsec.(*ResultArg).OpAdd = 10 * 1e3
} else {
- nsec.OpAdd = 10 * 1e6
+ nsec.(*ResultArg).OpAdd = 10 * 1e6
}
- arg = groupArg(typ, []*Arg{sec, nsec})
+ arg = groupArg(typ, []Arg{sec, nsec})
}
return
}
@@ -306,12 +306,12 @@ func createMmapCall(start, npages uintptr) *Call {
meta := sys.CallMap["mmap"]
mmap := &Call{
Meta: meta,
- Args: []*Arg{
+ Args: []Arg{
pointerArg(meta.Args[0], start, 0, npages, nil),
- pageSizeArg(meta.Args[1], npages, 0),
+ constArg(meta.Args[1], npages*pageSize),
constArg(meta.Args[2], sys.PROT_READ|sys.PROT_WRITE),
constArg(meta.Args[3], sys.MAP_ANONYMOUS|sys.MAP_PRIVATE|sys.MAP_FIXED),
- constArg(meta.Args[4], sys.InvalidFD),
+ resultArg(meta.Args[4], nil, sys.InvalidFD),
constArg(meta.Args[5], 0),
},
Ret: returnArg(meta.Ret),
@@ -319,7 +319,7 @@ func createMmapCall(start, npages uintptr) *Call {
return mmap
}
-func (r *randGen) addr1(s *state, typ sys.Type, size uintptr, data *Arg) (*Arg, []*Call) {
+func (r *randGen) addr1(s *state, typ sys.Type, size uintptr, data Arg) (Arg, []*Call) {
npages := (size + pageSize - 1) / pageSize
if npages == 0 {
npages = 1
@@ -344,27 +344,28 @@ func (r *randGen) addr1(s *state, typ sys.Type, size uintptr, data *Arg) (*Arg,
return r.randPageAddr(s, typ, npages, data, false), nil
}
-func (r *randGen) addr(s *state, typ sys.Type, size uintptr, data *Arg) (*Arg, []*Call) {
+func (r *randGen) addr(s *state, typ sys.Type, size uintptr, data Arg) (Arg, []*Call) {
arg, calls := r.addr1(s, typ, size, data)
- if arg.Kind != ArgPointer {
+ a, ok := arg.(*PointerArg)
+ if !ok {
panic("bad")
}
// Patch offset of the address.
switch {
case r.nOutOf(50, 102):
case r.nOutOf(50, 52):
- arg.AddrOffset = -int(size)
+ a.PageOffset = -int(size)
case r.nOutOf(1, 2):
- arg.AddrOffset = r.Intn(pageSize)
+ a.PageOffset = r.Intn(pageSize)
default:
if size > 0 {
- arg.AddrOffset = -r.Intn(int(size))
+ a.PageOffset = -r.Intn(int(size))
}
}
return arg, calls
}
-func (r *randGen) randPageAddr(s *state, typ sys.Type, npages uintptr, data *Arg, vma bool) *Arg {
+func (r *randGen) randPageAddr(s *state, typ sys.Type, npages uintptr, data Arg, vma bool) Arg {
poolPtr := pageStartPool.Get().(*[]uintptr)
starts := (*poolPtr)[:0]
for i := uintptr(0); i < maxPages-npages; i++ {
@@ -396,10 +397,10 @@ func (r *randGen) randPageAddr(s *state, typ sys.Type, npages uintptr, data *Arg
return pointerArg(typ, page, 0, npages, data)
}
-func (r *randGen) createResource(s *state, res *sys.ResourceType) (arg *Arg, calls []*Call) {
+func (r *randGen) createResource(s *state, res *sys.ResourceType) (arg Arg, calls []*Call) {
if r.inCreateResource {
special := res.SpecialValues()
- return constArg(res, special[r.Intn(len(special))]), nil
+ return resultArg(res, nil, special[r.Intn(len(special))]), nil
}
r.inCreateResource = true
defer func() { r.inCreateResource = false }()
@@ -426,7 +427,7 @@ func (r *randGen) createResource(s *state, res *sys.ResourceType) (arg *Arg, cal
metas = append(metas, meta)
}
if len(metas) == 0 {
- return constArg(res, res.Default()), nil
+ return resultArg(res, nil, res.Default()), nil
}
// Now we have a set of candidate calls that can create the necessary resource.
@@ -437,7 +438,7 @@ func (r *randGen) createResource(s *state, res *sys.ResourceType) (arg *Arg, cal
s1 := newState(s.ct)
s1.analyze(calls[len(calls)-1])
// Now see if we have what we want.
- var allres []*Arg
+ var allres []Arg
for kind1, res1 := range s1.resources {
if sys.IsCompatibleResource(kind, kind1) {
allres = append(allres, res1...)
@@ -445,14 +446,14 @@ func (r *randGen) createResource(s *state, res *sys.ResourceType) (arg *Arg, cal
}
if len(allres) != 0 {
// Bingo!
- arg := resultArg(res, allres[r.Intn(len(allres))])
+ arg := resultArg(res, allres[r.Intn(len(allres))], 0)
return arg, calls
}
// Discard unsuccessful calls.
for _, c := range calls {
- foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
- if arg.Kind == ArgResult {
- delete(arg.Res.Uses, arg)
+ foreachArg(c, func(arg, _ Arg, _ *[]Arg) {
+ if a, ok := arg.(*ResultArg); ok && a.Res != nil {
+ delete(*a.Res.(ArgUsed).Used(), arg)
}
})
}
@@ -582,9 +583,9 @@ func GenerateAllSyzProg(rs rand.Source) *Prog {
return p
}
-func (r *randGen) generateArgs(s *state, types []sys.Type) ([]*Arg, []*Call) {
+func (r *randGen) generateArgs(s *state, types []sys.Type) ([]Arg, []*Call) {
var calls []*Call
- args := make([]*Arg, len(types))
+ args := make([]Arg, len(types))
// Generate all args. Size args have the default value 0 for now.
for i, typ := range types {
@@ -599,24 +600,33 @@ func (r *randGen) generateArgs(s *state, types []sys.Type) ([]*Arg, []*Call) {
return args, calls
}
-func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) {
+func (r *randGen) generateArg(s *state, typ sys.Type) (arg Arg, calls []*Call) {
if typ.Dir() == sys.DirOut {
// No need to generate something interesting for output scalar arguments.
// But we still need to generate the argument itself so that it can be referenced
// in subsequent calls. For the same reason we do generate pointer/array/struct
// output arguments (their elements can be referenced in subsequent calls).
switch typ.(type) {
- case *sys.IntType, *sys.FlagsType, *sys.ConstType,
- *sys.ResourceType, *sys.VmaType, *sys.ProcType:
+ case *sys.IntType, *sys.FlagsType, *sys.ConstType, *sys.ProcType:
return constArg(typ, typ.Default()), nil
+ case *sys.VmaType:
+ return pointerArg(typ, 0, 0, 0, nil), nil
+ case *sys.ResourceType:
+ return resultArg(typ, nil, typ.Default()), nil
}
}
if typ.Optional() && r.oneOf(5) {
- if _, ok := typ.(*sys.BufferType); ok {
+ switch typ.(type) {
+ case *sys.PtrType:
+ return pointerArg(typ, 0, 0, 0, nil), nil
+ case *sys.BufferType:
panic("impossible") // parent PtrType must be Optional instead
+ case *sys.VmaType:
+ return pointerArg(typ, 0, 0, 0, nil), nil
+ default:
+ return constArg(typ, typ.Default()), nil
}
- return constArg(typ, typ.Default()), nil
}
switch a := typ.(type) {
@@ -624,7 +634,7 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
switch {
case r.nOutOf(1000, 1011):
// Get an existing resource.
- var allres []*Arg
+ var allres []Arg
for name1, res1 := range s.resources {
if sys.IsCompatibleResource(a.Desc.Name, name1) ||
r.oneOf(20) && sys.IsCompatibleResource(a.Desc.Kind[0], name1) {
@@ -632,7 +642,7 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
}
}
if len(allres) != 0 {
- arg = resultArg(a, allres[r.Intn(len(allres))])
+ arg = resultArg(a, allres[r.Intn(len(allres))], 0)
} else {
arg, calls = r.createResource(s, a)
}
@@ -641,7 +651,7 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
arg, calls = r.createResource(s, a)
default:
special := a.SpecialValues()
- arg = constArg(a, special[r.Intn(len(special))])
+ arg = resultArg(a, nil, special[r.Intn(len(special))])
}
return arg, calls
case *sys.BufferType:
@@ -720,7 +730,7 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
case sys.ArrayRangeLen:
count = r.randRange(int(a.RangeBegin), int(a.RangeEnd))
}
- var inner []*Arg
+ var inner []Arg
var calls []*Call
for i := uintptr(0); i < count; i++ {
arg1, calls1 := r.generateArg(s, a.Type)
@@ -746,8 +756,8 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
// It is weird, but these are actually identified by kernel by address.
// So try to reuse a previously used address.
addrs := s.resources["iocbptr"]
- addr := addrs[r.Intn(len(addrs))]
- arg = pointerArg(a, addr.AddrPage, addr.AddrOffset, addr.AddrPagesNum, inner)
+ addr := addrs[r.Intn(len(addrs))].(*PointerArg)
+ arg = pointerArg(a, addr.PageIndex, addr.PageOffset, addr.PagesNum, inner)
return arg, calls
}
arg, calls1 := r.addr(s, a, inner.Size(), inner)
diff --git a/prog/size.go b/prog/size.go
index 60cc7bd57..f00fd56b1 100644
--- a/prog/size.go
+++ b/prog/size.go
@@ -9,20 +9,22 @@ import (
"github.com/google/syzkaller/sys"
)
-func generateSize(arg *Arg, lenType *sys.LenType) *Arg {
+func generateSize(arg Arg, lenType *sys.LenType) Arg {
if arg == nil {
// Arg is an optional pointer, set size to 0.
return constArg(lenType, 0)
}
- switch arg.Type.(type) {
+ switch arg.Type().(type) {
case *sys.VmaType:
- return pageSizeArg(lenType, arg.AddrPagesNum, 0)
+ a := arg.(*PointerArg)
+ return constArg(lenType, a.PagesNum*pageSize)
case *sys.ArrayType:
+ a := arg.(*GroupArg)
if lenType.ByteSize != 0 {
- return constArg(lenType, arg.Size()/lenType.ByteSize)
+ return constArg(lenType, a.Size()/lenType.ByteSize)
} else {
- return constArg(lenType, uintptr(len(arg.Inner)))
+ return constArg(lenType, uintptr(len(a.Inner)))
}
default:
if lenType.ByteSize != 0 {
@@ -33,42 +35,44 @@ func generateSize(arg *Arg, lenType *sys.LenType) *Arg {
}
}
-func assignSizes(args []*Arg, parentsMap map[*Arg]*Arg) {
- // Create a map of args and calculate size of the whole struct.
- argsMap := make(map[string]*Arg)
+func assignSizes(args []Arg, parentsMap map[Arg]Arg) {
+ // Create a map from field names to args.
+ argsMap := make(map[string]Arg)
for _, arg := range args {
- if sys.IsPad(arg.Type) {
+ if sys.IsPad(arg.Type()) {
continue
}
- argsMap[arg.Type.FieldName()] = arg
+ argsMap[arg.Type().FieldName()] = arg
}
// Fill in size arguments.
for _, arg := range args {
- if arg = arg.InnerArg(); arg == nil {
+ if arg = InnerArg(arg); arg == nil {
continue // Pointer to optional len field, no need to fill in value.
}
- if typ, ok := arg.Type.(*sys.LenType); ok {
+ if typ, ok := arg.Type().(*sys.LenType); ok {
+ a := arg.(*ConstArg)
+
buf, ok := argsMap[typ.Buf]
if ok {
- *arg = *generateSize(buf.InnerArg(), typ)
+ *a = *generateSize(InnerArg(buf), typ).(*ConstArg)
continue
}
if typ.Buf == "parent" {
- arg.Val = parentsMap[arg].Size()
+ a.Val = parentsMap[arg].Size()
if typ.ByteSize != 0 {
- arg.Val /= typ.ByteSize
+ a.Val /= typ.ByteSize
}
continue
}
sizeAssigned := false
for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
- if typ.Buf == parent.Type.Name() {
- arg.Val = parent.Size()
+ if typ.Buf == parent.Type().Name() {
+ a.Val = parent.Size()
if typ.ByteSize != 0 {
- arg.Val /= typ.ByteSize
+ a.Val /= typ.ByteSize
}
sizeAssigned = true
break
@@ -84,19 +88,19 @@ func assignSizes(args []*Arg, parentsMap map[*Arg]*Arg) {
}
}
-func assignSizesArray(args []*Arg) {
- parentsMap := make(map[*Arg]*Arg)
- foreachArgArray(&args, nil, func(arg, base *Arg, _ *[]*Arg) {
- if _, ok := arg.Type.(*sys.StructType); ok {
- for _, field := range arg.Inner {
- parentsMap[field.InnerArg()] = arg
+func assignSizesArray(args []Arg) {
+ parentsMap := make(map[Arg]Arg)
+ foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) {
+ if _, ok := arg.Type().(*sys.StructType); ok {
+ for _, field := range arg.(*GroupArg).Inner {
+ parentsMap[InnerArg(field)] = arg
}
}
})
assignSizes(args, parentsMap)
- foreachArgArray(&args, nil, func(arg, base *Arg, _ *[]*Arg) {
- if _, ok := arg.Type.(*sys.StructType); ok {
- assignSizes(arg.Inner, parentsMap)
+ foreachArgArray(&args, nil, func(arg, base Arg, _ *[]Arg) {
+ if _, ok := arg.Type().(*sys.StructType); ok {
+ assignSizes(arg.(*GroupArg).Inner, parentsMap)
}
})
}
diff --git a/prog/size_test.go b/prog/size_test.go
index f6e4ab17c..0f51edc2b 100644
--- a/prog/size_test.go
+++ b/prog/size_test.go
@@ -75,12 +75,12 @@ func TestAssignSize(t *testing.T) {
"syz_test$length8(&(0x7f000001f000)={0x38, {0xff, 0x1, 0x10, [0xff, 0xff, 0xff]}, [{0xff, 0x1, 0x10, [0xff, 0xff, 0xff]}], 0x10, 0x1, [0xff, 0xff]})",
},
{
- "syz_test$length9(&(0x7f000001f000)={&(0x7f0000000000/0x5000)=nil, (0x0000)})",
- "syz_test$length9(&(0x7f000001f000)={&(0x7f0000000000/0x5000)=nil, (0x5000)})",
+ "syz_test$length9(&(0x7f000001f000)={&(0x7f0000000000/0x5000)=nil, 0x0000})",
+ "syz_test$length9(&(0x7f000001f000)={&(0x7f0000000000/0x5000)=nil, 0x5000})",
},
{
- "syz_test$length10(&(0x7f0000000000/0x5000)=nil, (0x0000))",
- "syz_test$length10(&(0x7f0000000000/0x5000)=nil, (0x5000))",
+ "syz_test$length10(&(0x7f0000000000/0x5000)=nil, 0x0000)",
+ "syz_test$length10(&(0x7f0000000000/0x5000)=nil, 0x5000)",
},
{
"syz_test$length11(&(0x7f0000000000)={0xff, 0xff, [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]}, 0x00)",
diff --git a/prog/validation.go b/prog/validation.go
index e806ea4ce..162957068 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -12,12 +12,12 @@ import (
var debug = false // enabled in tests
type validCtx struct {
- args map[*Arg]bool
- uses map[*Arg]*Arg
+ args map[Arg]bool
+ uses map[Arg]Arg
}
func (p *Prog) validate() error {
- ctx := &validCtx{make(map[*Arg]bool), make(map[*Arg]*Arg)}
+ ctx := &validCtx{make(map[Arg]bool), make(map[Arg]Arg)}
for _, c := range p.Calls {
if err := c.validate(ctx); err != nil {
return err
@@ -25,7 +25,7 @@ func (p *Prog) validate() error {
}
for u, orig := range ctx.uses {
if !ctx.args[u] {
- return fmt.Errorf("use of %+v referes to an out-of-tree arg\narg: %#v", *orig, u)
+ return fmt.Errorf("use of %+v refers to an out-of-tree arg\narg: %#v", orig, u)
}
}
return nil
@@ -38,8 +38,8 @@ func (c *Call) validate(ctx *validCtx) error {
if len(c.Args) != len(c.Meta.Args) {
return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", c.Meta.Name, len(c.Meta.Args), len(c.Args))
}
- var checkArg func(arg *Arg) error
- checkArg = func(arg *Arg) error {
+ var checkArg func(arg Arg) error
+ checkArg = func(arg Arg) error {
if arg == nil {
return fmt.Errorf("syscall %v: nil arg", c.Meta.Name)
}
@@ -47,179 +47,211 @@ func (c *Call) validate(ctx *validCtx) error {
return fmt.Errorf("syscall %v: arg is referenced several times in the tree", c.Meta.Name)
}
ctx.args[arg] = true
- for u := range arg.Uses {
- ctx.uses[u] = arg
+ if used, ok := arg.(ArgUsed); ok {
+ for u := range *used.Used() {
+ if u == nil {
+ return fmt.Errorf("syscall %v: nil reference in uses for arg %+v", c.Meta.Name, arg)
+ }
+ ctx.uses[u] = arg
+ }
}
- if arg.Type == nil {
+ if arg.Type() == nil {
return fmt.Errorf("syscall %v: no type", c.Meta.Name)
}
- if arg.Type.Dir() == sys.DirOut {
- if (arg.Val != 0 && arg.Val != arg.Type.Default()) || arg.AddrPage != 0 || arg.AddrOffset != 0 {
- // We generate output len arguments, which makes sense
- // since it can be a length of a variable-length array
- // which is not known otherwise.
- if _, ok := arg.Type.(*sys.LenType); !ok {
- return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'", c.Meta.Name, arg.Type.FieldName(), arg.Type.Name(), *arg)
+ if arg.Type().Dir() == sys.DirOut {
+ switch a := arg.(type) {
+ case *ConstArg:
+ // We generate output len arguments, which makes sense since it can be
+ // a length of a variable-length array which is not known otherwise.
+ if _, ok := a.Type().(*sys.LenType); ok {
+ break
}
- }
- for _, v := range arg.Data {
- if v != 0 {
- return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, arg.Type.Name())
+ if a.Val != 0 && a.Val != a.Type().Default() {
+ return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'", c.Meta.Name, a.Type().FieldName(), a.Type().Name(), a)
+ }
+ case *DataArg:
+ for _, v := range a.Data {
+ if v != 0 {
+ return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, a.Type().Name())
+ }
}
}
}
- switch typ1 := arg.Type.(type) {
+ switch typ1 := arg.Type().(type) {
case *sys.IntType:
- switch arg.Kind {
- case ArgConst:
- case ArgResult:
- case ArgReturn:
- if arg.Type.Dir() == sys.DirOut && (arg.Val != 0 && arg.Val != arg.Type.Default()) {
- return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v", c.Meta.Name, arg.Type.Name(), arg.Val)
+ switch a := arg.(type) {
+ case *ConstArg:
+ if a.Type().Dir() == sys.DirOut && (a.Val != 0 && a.Val != a.Type().Default()) {
+ return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v", c.Meta.Name, a.Type().Name(), a.Val)
}
+ case *ReturnArg:
default:
- return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.ResourceType:
- switch arg.Kind {
- case ArgResult:
- case ArgReturn:
- case ArgConst:
- if arg.Type.Dir() == sys.DirOut && (arg.Val != 0 && arg.Val != arg.Type.Default()) {
- return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, arg.Type.Name(), arg.Val)
+ switch a := arg.(type) {
+ case *ResultArg:
+ if a.Type().Dir() == sys.DirOut && (a.Val != 0 && a.Val != a.Type().Default()) {
+ return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, a.Type().Name(), a.Val)
}
+ case *ReturnArg:
default:
- return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.StructType, *sys.ArrayType:
- switch arg.Kind {
- case ArgGroup:
+ switch arg.(type) {
+ case *GroupArg:
default:
- return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.UnionType:
- switch arg.Kind {
- case ArgUnion:
+ switch arg.(type) {
+ case *UnionArg:
default:
- return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.ProcType:
- if arg.Val >= uintptr(typ1.ValuesPerProc) {
- return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, arg.Type.Name(), arg.Val)
- }
- case *sys.BufferType:
- switch arg.Kind {
- case ArgData:
+ switch a := arg.(type) {
+ case *ConstArg:
+ if a.Val >= uintptr(typ1.ValuesPerProc) {
+ return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, a.Type().Name(), a.Val)
+ }
default:
- return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: proc arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
- switch typ1.Kind {
- case sys.BufferString:
- if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) {
- return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, arg.Type.Name(), len(arg.Data), typ1.Length)
+ case *sys.BufferType:
+ switch a := arg.(type) {
+ case *DataArg:
+ switch typ1.Kind {
+ case sys.BufferString:
+ if typ1.Length != 0 && len(a.Data) != int(typ1.Length) {
+ return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, a.Type().Name(), len(a.Data), typ1.Length)
+ }
}
+ default:
+ return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.CsumType:
- if arg.Val != 0 {
- return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, arg.Type.Name(), arg.Val)
+ switch a := arg.(type) {
+ case *ConstArg:
+ if a.Val != 0 {
+ return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, a.Type().Name(), a.Val)
+ }
+ default:
+ return fmt.Errorf("syscall %v: csum arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
case *sys.PtrType:
- if arg.Type.Dir() == sys.DirOut {
- return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, arg.Type.Name())
- }
- if arg.Res == nil && !arg.Type.Optional() {
- return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, arg.Type.Name())
+ switch a := arg.(type) {
+ case *PointerArg:
+ if a.Type().Dir() == sys.DirOut {
+ return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, a.Type().Name())
+ }
+ if a.Res == nil && !a.Type().Optional() {
+ return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, a.Type().Name())
+ }
+ default:
+ return fmt.Errorf("syscall %v: ptr arg '%v' has bad kind %v", c.Meta.Name, arg.Type().Name(), arg)
}
}
- switch arg.Kind {
- case ArgConst:
- case ArgResult:
- if arg.Res == nil {
- return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, arg.Type.Name())
- }
- if !ctx.args[arg.Res] {
- return fmt.Errorf("syscall %v: result arg '%v' references out-of-tree result: %p%+v -> %p%+v",
- c.Meta.Name, arg.Type.Name(), arg, arg, arg.Res, arg.Res)
- }
- if _, ok := arg.Res.Uses[arg]; !ok {
- return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, arg.Type.Name(), arg.Res.Uses)
- }
- case ArgPointer:
- switch arg.Type.(type) {
+ switch a := arg.(type) {
+ case *ConstArg:
+ case *PointerArg:
+ switch t := a.Type().(type) {
case *sys.VmaType:
- if arg.Res != nil {
- return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, arg.Type.Name())
+ if a.Res != nil {
+ return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, a.Type().Name())
}
- if arg.AddrPagesNum == 0 {
- return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, arg.Type.Name())
+ if a.PagesNum == 0 && t.Dir() != sys.DirOut && !t.Optional() {
+ return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, a.Type().Name())
}
case *sys.PtrType:
- if arg.Res != nil {
- if err := checkArg(arg.Res); err != nil {
+ if a.Res != nil {
+ if err := checkArg(a.Res); err != nil {
return err
}
}
- if arg.AddrPagesNum != 0 {
- return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, arg.Type.Name())
+ if a.PagesNum != 0 {
+ return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, a.Type().Name())
}
default:
- return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type.Name(), arg.Type)
+ return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type().Name(), arg.Type())
}
- case ArgPageSize:
- case ArgData:
- switch typ1 := arg.Type.(type) {
+ case *DataArg:
+ switch typ1 := a.Type().(type) {
case *sys.ArrayType:
if typ2, ok := typ1.Type.(*sys.IntType); !ok || typ2.Size() != 1 {
- return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, arg.Type.Name())
+ return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, a.Type().Name())
}
}
- case ArgGroup:
- switch typ1 := arg.Type.(type) {
+ case *GroupArg:
+ switch typ1 := a.Type().(type) {
case *sys.StructType:
- if len(arg.Inner) != len(typ1.Fields) {
- return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, arg.Type.Name(), len(typ1.Fields), len(arg.Inner))
+ if len(a.Inner) != len(typ1.Fields) {
+ return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, a.Type().Name(), len(typ1.Fields), len(a.Inner))
}
- for _, arg1 := range arg.Inner {
+ for _, arg1 := range a.Inner {
if err := checkArg(arg1); err != nil {
return err
}
}
case *sys.ArrayType:
- for _, arg1 := range arg.Inner {
+ for _, arg1 := range a.Inner {
if err := checkArg(arg1); err != nil {
return err
}
}
default:
- return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, arg.Type.Name(), arg.Type)
+ return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, arg.Type().Name(), arg.Type())
}
- case ArgUnion:
- typ1, ok := arg.Type.(*sys.UnionType)
+ case *UnionArg:
+ typ1, ok := a.Type().(*sys.UnionType)
if !ok {
- return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, arg.Type.Name())
+ return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, a.Type().Name())
}
found := false
for _, typ2 := range typ1.Options {
- if arg.OptionType.Name() == typ2.Name() {
+ if a.OptionType.Name() == typ2.Name() {
found = true
break
}
}
if !found {
- return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, arg.Type.Name())
+ return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, a.Type().Name())
}
- if err := checkArg(arg.Option); err != nil {
+ if err := checkArg(a.Option); err != nil {
return err
}
- case ArgReturn:
+ case *ResultArg:
+ switch a.Type().(type) {
+ case *sys.ResourceType:
+ default:
+ return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type().Name(), arg.Type())
+ }
+ if a.Res == nil {
+ break
+ }
+ if !ctx.args[a.Res] {
+ return fmt.Errorf("syscall %v: result arg '%v' references out-of-tree result: %p%+v -> %p%+v",
+ c.Meta.Name, a.Type().Name(), arg, arg, a.Res, a.Res)
+ }
+ if _, ok := (*a.Res.(ArgUsed).Used())[arg]; !ok {
+ return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, a.Type().Name(), *a.Res.(ArgUsed).Used())
+ }
+ case *ReturnArg:
+ switch a.Type().(type) {
+ case *sys.ResourceType:
+ case *sys.VmaType:
+ default:
+ return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type().Name(), arg.Type())
+ }
default:
- return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, arg.Type.Name())
+ return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, arg.Type().Name())
}
return nil
}
for _, arg := range c.Args {
- if arg.Kind == ArgReturn {
- return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", c.Meta.Name, arg.Type.Name())
+ if _, ok := arg.(*ReturnArg); ok {
+ return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", c.Meta.Name, arg.Type().Name())
}
if err := checkArg(arg); err != nil {
return err
@@ -228,15 +260,15 @@ func (c *Call) validate(ctx *validCtx) error {
if c.Ret == nil {
return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name)
}
- if c.Ret.Kind != ArgReturn {
- return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret.Kind)
+ if _, ok := c.Ret.(*ReturnArg); !ok {
+ return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret)
}
if c.Meta.Ret != nil {
if err := checkArg(c.Ret); err != nil {
return err
}
- } else if c.Ret.Type != nil {
- return fmt.Errorf("syscall %v: return value has spurious type: %+v", c.Meta.Name, c.Ret.Type)
+ } else if c.Ret.Type() != nil {
+ return fmt.Errorf("syscall %v: return value has spurious type: %+v", c.Meta.Name, c.Ret.Type())
}
return nil
}