aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-01-30 19:17:38 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-01-30 21:00:45 +0100
commit1f7f5daef8b9d4665e463bd842c701b4db27f56b (patch)
tree3b8af0bd368cf9ccc1ebc130c50a4f41dabf320c /prog
parent4ee789185bc215d62e9cfa92e23a8de2760789cb (diff)
prog, sys: add tcp packets descriptions
Also embed tcp checksums into packets.
Diffstat (limited to 'prog')
-rw-r--r--prog/checksum.go93
-rw-r--r--prog/checksum_test.go41
2 files changed, 115 insertions, 19 deletions
diff --git a/prog/checksum.go b/prog/checksum.go
index 3806c59e0..9df541dc1 100644
--- a/prog/checksum.go
+++ b/prog/checksum.go
@@ -4,6 +4,7 @@
package prog
import (
+ "encoding/binary"
"fmt"
"unsafe"
@@ -90,7 +91,7 @@ func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
}
}
-func encodeStruct(arg *Arg, pid int) []byte {
+func encodeArg(arg *Arg, pid int) []byte {
bytes := make([]byte, arg.Size())
foreachSubargOffset(arg, func(arg *Arg, offset uintptr) {
switch arg.Kind {
@@ -120,38 +121,96 @@ func encodeStruct(arg *Arg, pid int) []byte {
return bytes
}
-func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
- var csumField *Arg
+func getFieldByName(arg *Arg, name string) *Arg {
for _, field := range arg.Inner {
- if _, ok := field.Type.(*sys.CsumType); ok {
- csumField = field
- break
+ if field.Type.FieldName() == name {
+ return field
}
}
- if csumField == nil {
- panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name()))
+ panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
+}
+
+func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
+ csumField := getFieldByName(arg, "csum")
+ if typ, ok := csumField.Type.(*sys.CsumType); !ok {
+ panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
+ } else if typ.Kind != sys.CsumIPv4 {
+ panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
}
if csumField.Value(pid) != 0 {
panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField))
}
- bytes := encodeStruct(arg, pid)
+ bytes := encodeArg(arg, pid)
csum := ipChecksum(bytes)
newCsumField := *csumField
newCsumField.Val = uintptr(csum)
return csumField, &newCsumField
}
+func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg, *Arg) {
+ srcAddr := getFieldByName(arg, "src_ip")
+ if srcAddr.Size() != 4 {
+ panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
+ }
+ dstAddr := getFieldByName(arg, "dst_ip")
+ if dstAddr.Size() != 4 {
+ panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
+ }
+ protocol := getFieldByName(arg, "protocol")
+ if protocol.Size() != 1 {
+ panic(fmt.Sprintf("protocol field in %v must be 1 byte", arg.Type.Name()))
+ }
+ return srcAddr, dstAddr, protocol
+}
+
+func calcChecksumTCP(tcpPacket, srcAddr, dstAddr, protocol *Arg, pid int) (*Arg, *Arg) {
+ tcpHeaderField := getFieldByName(tcpPacket, "header")
+ csumField := getFieldByName(tcpHeaderField, "csum")
+ if typ, ok := csumField.Type.(*sys.CsumType); !ok {
+ panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
+ } else if typ.Kind != sys.CsumTCP {
+ panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
+ }
+
+ var csum IPChecksum
+ csum.Update(encodeArg(srcAddr, pid))
+ csum.Update(encodeArg(dstAddr, pid))
+ csum.Update([]byte{0, byte(protocol.Value(pid))})
+ length := []byte{0, 0}
+ binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size()))
+ csum.Update(length)
+ csum.Update(encodeArg(tcpPacket, pid))
+
+ newCsumField := *csumField
+ newCsumField.Val = uintptr(csum.Digest())
+ return csumField, &newCsumField
+}
+
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
- var m map[*Arg]*Arg
+ var csumMap map[*Arg]*Arg
+ ipv4HeaderParsed := false
+ var ipv4SrcAddr *Arg
+ var ipv4DstAddr *Arg
+ var ipv4Protocol *Arg
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
- // syz_csum_ipv4 struct is used in tests
- if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" {
- if m == nil {
- m = make(map[*Arg]*Arg)
+ // syz_csum_ipv4_header struct is used in tests
+ if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" {
+ if csumMap == nil {
+ csumMap = make(map[*Arg]*Arg)
+ }
+ csumField, newCsumField := calcChecksumIPv4(arg, pid)
+ csumMap[csumField] = newCsumField
+ ipv4SrcAddr, ipv4DstAddr, ipv4Protocol = extractHeaderParamsIPv4(arg)
+ ipv4HeaderParsed = true
+ }
+ // syz_csum_tcp_packet struct is used in tests
+ if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" {
+ if !ipv4HeaderParsed {
+ panic("tcp_packet is being parsed before ipv4_header")
}
- k, v := calcChecksumIPv4(arg, pid)
- m[k] = v
+ csumField, newCsumField := calcChecksumTCP(arg, ipv4SrcAddr, ipv4DstAddr, ipv4Protocol, pid)
+ csumMap[csumField] = newCsumField
}
})
- return m
+ return csumMap
}
diff --git a/prog/checksum_test.go b/prog/checksum_test.go
index bade7f724..56561d4e0 100644
--- a/prog/checksum_test.go
+++ b/prog/checksum_test.go
@@ -6,6 +6,8 @@ package prog
import (
"bytes"
"testing"
+
+ "github.com/google/syzkaller/sys"
)
func TestChecksumIP(t *testing.T) {
@@ -53,6 +55,10 @@ func TestChecksumIP(t *testing.T) {
"\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd",
0xe143,
},
+ {
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\xab\xcd",
+ 0x542e,
+ },
}
for _, test := range tests {
@@ -102,7 +108,7 @@ func TestChecksumEncode(t *testing.T) {
if err != nil {
t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
}
- encoded := encodeStruct(p.Calls[0].Args[0].Res, 0)
+ encoded := encodeArg(p.Calls[0].Args[0].Res, 0)
if !bytes.Equal(encoded, []byte(test.encoded)) {
t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded))
}
@@ -115,7 +121,7 @@ func TestChecksumIPv4Calc(t *testing.T) {
csum uint16
}{
{
- "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})",
+ "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0})",
0xe143,
},
}
@@ -133,6 +139,37 @@ func TestChecksumIPv4Calc(t *testing.T) {
}
}
+func TestChecksumTCPCalc(t *testing.T) {
+ tests := []struct {
+ prog string
+ csum uint16
+ }{
+ {
+ "syz_test$csum_ipv4_tcp(&(0x7f0000000000)={{0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0}, {{0x0}, \"abcd\"}})",
+ 0x542e,
+ },
+ }
+ for i, test := range tests {
+ p, err := Deserialize([]byte(test.prog))
+ if err != nil {
+ t.Fatalf("failed to deserialize prog %v: %v", test.prog, err)
+ }
+ csumMap := calcChecksumsCall(p.Calls[0], i % 32)
+ for oldField, newField := range csumMap {
+ if typ, ok := newField.Type.(*sys.CsumType); ok {
+ if typ.Kind == sys.CsumTCP {
+ csum := newField.Value(i % 32)
+ if csum != uintptr(test.csum) {
+ t.Fatalf("failed to calc tcp checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog)
+ }
+ }
+ } else {
+ t.Fatalf("non csum key %+v in csum map %+v", oldField, csumMap)
+ }
+ }
+ }
+}
+
func TestChecksumCalcRandom(t *testing.T) {
rs, iters := initTest(t)
for i := 0; i < iters; i++ {