aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-01-25 16:18:05 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-01-25 20:31:13 +0100
commit63b16a5d5cfd3b41f596daccd56d32b2548ec119 (patch)
tree3dfa93b07083b7ee4c21aa430aeedc92b9a16bb4 /prog
parentc8d03a05f3acd375badcde94264909d149784778 (diff)
prog, sys: add csum type, embed checksums for ipv4 packets
This change adds a `csum[kind, type]` type. The only available kind right now is `ipv4`. Using `csum[ipv4, int16be]` in `ipv4_header` makes syzkaller calculate and embed correct checksums into ipv4 packets.
Diffstat (limited to 'prog')
-rw-r--r--prog/analysis.go30
-rw-r--r--prog/checksum.go157
-rw-r--r--prog/checksum_test.go150
-rw-r--r--prog/encodingexec.go42
-rw-r--r--prog/mutation.go7
-rw-r--r--prog/prog.go4
-rw-r--r--prog/rand.go4
-rw-r--r--prog/validation.go6
8 files changed, 365 insertions, 35 deletions
diff --git a/prog/analysis.go b/prog/analysis.go
index d008f9c48..a267f7d15 100644
--- a/prog/analysis.go
+++ b/prog/analysis.go
@@ -150,6 +150,36 @@ func foreachArg(c *Call, f func(arg, base *Arg, parent *[]*Arg)) {
foreachArgArray(&c.Args, nil, f)
}
+func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) {
+ var rec func(*Arg, uintptr) uintptr
+ rec = func(arg1 *Arg, offset uintptr) uintptr {
+ switch arg1.Kind {
+ case ArgGroup:
+ var totalSize uintptr
+ for _, arg2 := range arg1.Inner {
+ size := rec(arg2, offset)
+ if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() {
+ offset += size
+ totalSize += size
+ }
+ }
+ if totalSize > arg1.Size() {
+ panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1))
+ }
+ case ArgUnion:
+ size := rec(arg1.Option, offset)
+ offset += size
+ if size > arg1.Size() {
+ panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type))
+ }
+ default:
+ f(arg1, offset)
+ }
+ return arg1.Size()
+ }
+ rec(arg, 0)
+}
+
func sanitizeCall(c *Call) {
switch c.Meta.CallName {
case "mmap":
diff --git a/prog/checksum.go b/prog/checksum.go
new file mode 100644
index 000000000..3806c59e0
--- /dev/null
+++ b/prog/checksum.go
@@ -0,0 +1,157 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package prog
+
+import (
+ "fmt"
+ "unsafe"
+
+ "github.com/google/syzkaller/sys"
+)
+
+type IPChecksum struct {
+ acc uint32
+}
+
+func (csum *IPChecksum) Update(data []byte) {
+ length := len(data) - 1
+ for i := 0; i < length; i += 2 {
+ csum.acc += uint32(data[i]) << 8
+ csum.acc += uint32(data[i+1])
+ }
+ if len(data)%2 == 1 {
+ csum.acc += uint32(data[length]) << 8
+ }
+ for csum.acc > 0xffff {
+ csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff)
+ }
+}
+
+func (csum *IPChecksum) Digest() uint16 {
+ return ^uint16(csum.acc)
+}
+
+func ipChecksum(data []byte) uint16 {
+ var csum IPChecksum
+ csum.Update(data)
+ return csum.Digest()
+}
+
+func bitmaskLen(bfLen uint64) uint64 {
+ return (1 << bfLen) - 1
+}
+
+func bitmaskLenOff(bfOff, bfLen uint64) uint64 {
+ return bitmaskLen(bfLen) << bfOff
+}
+
+func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) {
+ if bfOff == 0 && bfLen == 0 {
+ *addr = value
+ } else {
+ newValue := *addr
+ newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen))
+ newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff
+ *addr = newValue
+ }
+}
+
+func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) {
+ if bfOff == 0 && bfLen == 0 {
+ *addr = value
+ } else {
+ newValue := *addr
+ newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen))
+ newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff
+ *addr = newValue
+ }
+}
+
+func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) {
+ if bfOff == 0 && bfLen == 0 {
+ *addr = value
+ } else {
+ newValue := *addr
+ newValue &= ^uint32(bitmaskLenOff(bfOff, bfLen))
+ newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff
+ *addr = newValue
+ }
+}
+
+func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
+ if bfOff == 0 && bfLen == 0 {
+ *addr = value
+ } else {
+ newValue := *addr
+ newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen))
+ newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff
+ *addr = newValue
+ }
+}
+
+func encodeStruct(arg *Arg, pid int) []byte {
+ bytes := make([]byte, arg.Size())
+ foreachSubargOffset(arg, func(arg *Arg, offset uintptr) {
+ switch arg.Kind {
+ case ArgConst:
+ addr := unsafe.Pointer(&bytes[offset])
+ val := arg.Value(pid)
+ bfOff := uint64(arg.Type.BitfieldOffset())
+ bfLen := uint64(arg.Type.BitfieldLength())
+ switch arg.Size() {
+ case 1:
+ storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen)
+ case 2:
+ storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen)
+ case 4:
+ storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen)
+ case 8:
+ storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen)
+ default:
+ panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg))
+ }
+ case ArgData:
+ copy(bytes[offset:], arg.Data)
+ default:
+ panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type))
+ }
+ })
+ return bytes
+}
+
+func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
+ var csumField *Arg
+ for _, field := range arg.Inner {
+ if _, ok := field.Type.(*sys.CsumType); ok {
+ csumField = field
+ break
+ }
+ }
+ if csumField == nil {
+ panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name()))
+ }
+ if csumField.Value(pid) != 0 {
+ panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField))
+ }
+ bytes := encodeStruct(arg, pid)
+ csum := ipChecksum(bytes)
+ newCsumField := *csumField
+ newCsumField.Val = uintptr(csum)
+ return csumField, &newCsumField
+}
+
+func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
+ var m map[*Arg]*Arg
+ foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
+ // syz_csum_ipv4 struct is used in tests
+ if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" {
+ if m == nil {
+ m = make(map[*Arg]*Arg)
+ }
+ k, v := calcChecksumIPv4(arg, pid)
+ m[k] = v
+ }
+ })
+ return m
+}
diff --git a/prog/checksum_test.go b/prog/checksum_test.go
new file mode 100644
index 000000000..bade7f724
--- /dev/null
+++ b/prog/checksum_test.go
@@ -0,0 +1,150 @@
+// Copyright 2016 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package prog
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestChecksumIP(t *testing.T) {
+ tests := []struct {
+ data string
+ csum uint16
+ }{
+ {
+ "",
+ 0xffff,
+ },
+ {
+ "\x00",
+ 0xffff,
+ },
+ {
+ "\x00\x00",
+ 0xffff,
+ },
+ {
+ "\x00\x00\xff\xff",
+ 0x0000,
+ },
+ {
+ "\xfc",
+ 0x03ff,
+ },
+ {
+ "\xfc\x12",
+ 0x03ed,
+ },
+ {
+ "\xfc\x12\x3e",
+ 0xc5ec,
+ },
+ {
+ "\xfc\x12\x3e\x00\xc5\xec",
+ 0x0000,
+ },
+ {
+ "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
+ 0xe143,
+ },
+ {
+ "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
+ 0xe143,
+ },
+ }
+
+ for _, test := range tests {
+ csum := ipChecksum([]byte(test.data))
+ if csum != test.csum {
+ t.Fatalf("incorrect ip checksum, got: %x, want: %x, data: %+v", csum, test.csum, []byte(test.data))
+ }
+ }
+}
+
+func TestChecksumIPAcc(t *testing.T) {
+ rs, iters := initTest(t)
+ r := newRand(rs)
+
+ for i := 0; i < iters; i++ {
+ bytes := make([]byte, r.Intn(256))
+ for i := 0; i < len(bytes); i++ {
+ bytes[i] = byte(r.Intn(256))
+ }
+ step := int(r.randRange(1, 8)) * 2
+ var csumAcc IPChecksum
+ for i := 0; i < len(bytes)/step; i++ {
+ csumAcc.Update(bytes[i*step : (i+1)*step])
+ }
+ if len(bytes)%step != 0 {
+ csumAcc.Update(bytes[len(bytes)-(len(bytes)%step) : len(bytes)])
+ }
+ csum := ipChecksum(bytes)
+ if csum != csumAcc.Digest() {
+ t.Fatalf("inconsistent ip checksum: %x vs %x, step: %v, data: %+v", csum, csumAcc.Digest(), step, bytes)
+ }
+ }
+}
+
+func TestChecksumEncode(t *testing.T) {
+ tests := []struct {
+ prog string
+ encoded string
+ }{
+ {
+ "syz_test$csum_encode(&(0x7f0000000000)={0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"})",
+ "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
+ },
+ }
+ for i, test := range tests {
+ p, err := Deserialize([]byte(test.prog))
+ if err != nil {
+ t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
+ }
+ encoded := encodeStruct(p.Calls[0].Args[0].Res, 0)
+ if !bytes.Equal(encoded, []byte(test.encoded)) {
+ t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded))
+ }
+ }
+}
+
+func TestChecksumIPv4Calc(t *testing.T) {
+ tests := []struct {
+ prog string
+ csum uint16
+ }{
+ {
+ "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})",
+ 0xe143,
+ },
+ }
+ for i, test := range tests {
+ p, err := Deserialize([]byte(test.prog))
+ if err != nil {
+ t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
+ }
+ _, csumField := calcChecksumIPv4(p.Calls[0].Args[0].Res, i%32)
+ // Can't compare serialized progs, since checksums are zerod on serialization.
+ csum := csumField.Value(i % 32)
+ if csum != uintptr(test.csum) {
+ t.Fatalf("failed to calc ipv4 checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog)
+ }
+ }
+}
+
+func TestChecksumCalcRandom(t *testing.T) {
+ rs, iters := initTest(t)
+ for i := 0; i < iters; i++ {
+ p := Generate(rs, 10, nil)
+ for _, call := range p.Calls {
+ calcChecksumsCall(call, i%32)
+ }
+ for try := 0; try <= 10; try++ {
+ p.Mutate(rs, 10, nil, nil)
+ for _, call := range p.Calls {
+ calcChecksumsCall(call, i%32)
+ }
+ }
+ }
+}
diff --git a/prog/encodingexec.go b/prog/encodingexec.go
index 304440d0e..9a9cc4b48 100644
--- a/prog/encodingexec.go
+++ b/prog/encodingexec.go
@@ -47,55 +47,32 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error {
args: make(map[*Arg]argInfo),
}
for _, c := range p.Calls {
+ // Calculate checksums.
+ csumMap := calcChecksumsCall(c, pid)
// Calculate arg offsets within structs.
// Generate copyin instructions that fill in data into pointer arguments.
foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) {
if arg.Kind == ArgPointer && arg.Res != nil {
- var rec func(*Arg, uintptr) uintptr
- rec = func(arg1 *Arg, offset uintptr) uintptr {
+ foreachSubargOffset(arg.Res, func(arg1 *Arg, offset uintptr) {
if len(arg1.Uses) != 0 {
w.args[arg1] = argInfo{Offset: offset}
}
- if arg1.Kind == ArgGroup {
- var totalSize uintptr
- for _, arg2 := range arg1.Inner {
- size := rec(arg2, offset)
- if arg2.Type.BitfieldLength() == 0 || arg2.Type.BitfieldLast() {
- offset += size
- totalSize += size
- }
- }
- if totalSize > arg1.Size() {
- panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1))
- }
- return arg1.Size()
- }
- if arg1.Kind == ArgUnion {
- size := rec(arg1.Option, offset)
- offset += size
- if size > arg1.Size() {
- panic(fmt.Sprintf("bad union arg size %v, should be <= %v for arg %+v with type %+v", size, arg1.Size(), arg1, arg1.Type))
- }
- return arg1.Size()
- }
if !sys.IsPad(arg1.Type) &&
!(arg1.Kind == ArgData && len(arg1.Data) == 0) &&
arg1.Type.Dir() != sys.DirOut {
w.write(ExecInstrCopyin)
w.write(physicalAddr(arg) + offset)
- w.writeArg(arg1, pid)
+ w.writeArg(arg1, pid, csumMap)
instrSeq++
}
- return arg1.Size()
- }
- rec(arg.Res, 0)
+ })
}
})
// Generate the call itself.
w.write(uintptr(c.Meta.ID))
w.write(uintptr(len(c.Args)))
for _, arg := range c.Args {
- w.writeArg(arg, pid)
+ w.writeArg(arg, pid, csumMap)
}
if len(c.Ret.Uses) != 0 {
w.args[c.Ret] = argInfo{Idx: instrSeq}
@@ -173,9 +150,14 @@ func (w *execContext) write(v uintptr) {
w.buf = w.buf[8:]
}
-func (w *execContext) writeArg(arg *Arg, pid int) {
+func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]*Arg) {
switch arg.Kind {
case ArgConst:
+ if _, ok := arg.Type.(*sys.CsumType); ok {
+ if arg, ok = csumMap[arg]; !ok {
+ panic("csum arg is not in csum map")
+ }
+ }
w.write(ExecArgConst)
w.write(arg.Size())
w.write(arg.Value(pid))
diff --git a/prog/mutation.go b/prog/mutation.go
index 04465a424..358a2b104 100644
--- a/prog/mutation.go
+++ b/prog/mutation.go
@@ -197,6 +197,8 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro
p.replaceArg(c, arg, arg1, calls)
case *sys.LenType:
panic("bad arg returned by mutationArgs: LenType")
+ case *sys.CsumType:
+ panic("bad arg returned by mutationArgs: CsumType")
case *sys.ConstType:
panic("bad arg returned by mutationArgs: ConstType")
default:
@@ -397,7 +399,7 @@ func Minimize(p0 *Prog, callIndex0 int, pred func(*Prog, int) bool, crash bool)
}
}
p0 = p
- case *sys.VmaType, *sys.LenType, *sys.ConstType:
+ case *sys.VmaType, *sys.LenType, *sys.CsumType, *sys.ConstType:
// TODO: try to remove offset from vma
return false
default:
@@ -460,6 +462,9 @@ func mutationArgs(c *Call) (args, bases []*Arg) {
case *sys.LenType:
// Size is updated when the size-of arg change.
return
+ case *sys.CsumType:
+ // Checksum is updated when the checksummed data changes.
+ return
case *sys.ConstType:
// Well, this is const.
return
diff --git a/prog/prog.go b/prog/prog.go
index 13265e44c..fbd8507c6 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -95,6 +95,8 @@ func (a *Arg) Value(pid int) uintptr {
return encodeValue(a.Val, typ.Size(), typ.BigEndian)
case *sys.LenType:
return encodeValue(a.Val, typ.Size(), typ.BigEndian)
+ case *sys.CsumType:
+ return encodeValue(a.Val, typ.Size(), typ.BigEndian)
case *sys.ProcType:
val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + a.Val
return encodeValue(val, typ.Size(), typ.BigEndian)
@@ -105,7 +107,7 @@ func (a *Arg) Value(pid int) uintptr {
func (a *Arg) Size() uintptr {
switch typ := a.Type.(type) {
case *sys.IntType, *sys.LenType, *sys.FlagsType, *sys.ConstType,
- *sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType:
+ *sys.ResourceType, *sys.VmaType, *sys.PtrType, *sys.ProcType, *sys.CsumType:
return typ.Size()
case *sys.BufferType:
return uintptr(len(a.Data))
diff --git a/prog/rand.go b/prog/rand.go
index 4a7bbe08d..3eebe4b8d 100644
--- a/prog/rand.go
+++ b/prog/rand.go
@@ -765,8 +765,8 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call)
arg, calls1 := r.addr(s, a, inner.Size(), inner)
calls = append(calls, calls1...)
return arg, calls
- case *sys.LenType:
- // Return placeholder value of 0 while generating len args.
+ case *sys.LenType, *sys.CsumType:
+ // Return placeholder value of 0 while generating len and csum args.
return constArg(a, 0), nil
default:
panic("unknown argument type")
diff --git a/prog/validation.go b/prog/validation.go
index 564c0b060..28c619802 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -106,9 +106,13 @@ func (c *Call) validate(ctx *validCtx) error {
switch typ1.Kind {
case sys.BufferString:
if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) {
- return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, len(arg.Data), typ1.Length)
+ return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length)
}
}
+ case *sys.CsumType:
+ if arg.Val != 0 {
+ return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, typ.Name(), arg.Val)
+ }
}
switch arg.Kind {
case ArgConst: