diff options
| author | Andrey Konovalov <andreyknvl@google.com> | 2017-01-30 19:17:38 +0100 |
|---|---|---|
| committer | Andrey Konovalov <andreyknvl@google.com> | 2017-01-30 21:00:45 +0100 |
| commit | 1f7f5daef8b9d4665e463bd842c701b4db27f56b (patch) | |
| tree | 3b8af0bd368cf9ccc1ebc130c50a4f41dabf320c /prog | |
| parent | 4ee789185bc215d62e9cfa92e23a8de2760789cb (diff) | |
prog, sys: add tcp packets descriptions
Also embed tcp checksums into packets.
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/checksum.go | 93 | ||||
| -rw-r--r-- | prog/checksum_test.go | 41 |
2 files changed, 115 insertions, 19 deletions
diff --git a/prog/checksum.go b/prog/checksum.go index 3806c59e0..9df541dc1 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -4,6 +4,7 @@ package prog import ( + "encoding/binary" "fmt" "unsafe" @@ -90,7 +91,7 @@ func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) { } } -func encodeStruct(arg *Arg, pid int) []byte { +func encodeArg(arg *Arg, pid int) []byte { bytes := make([]byte, arg.Size()) foreachSubargOffset(arg, func(arg *Arg, offset uintptr) { switch arg.Kind { @@ -120,38 +121,96 @@ func encodeStruct(arg *Arg, pid int) []byte { return bytes } -func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { - var csumField *Arg +func getFieldByName(arg *Arg, name string) *Arg { for _, field := range arg.Inner { - if _, ok := field.Type.(*sys.CsumType); ok { - csumField = field - break + if field.Type.FieldName() == name { + return field } } - if csumField == nil { - panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name())) + panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name())) +} + +func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { + csumField := getFieldByName(arg, "csum") + if typ, ok := csumField.Type.(*sys.CsumType); !ok { + panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) + } else if typ.Kind != sys.CsumIPv4 { + panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) } if csumField.Value(pid) != 0 { panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField)) } - bytes := encodeStruct(arg, pid) + bytes := encodeArg(arg, pid) csum := ipChecksum(bytes) newCsumField := *csumField newCsumField.Val = uintptr(csum) return csumField, &newCsumField } +func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg, *Arg) { + srcAddr := getFieldByName(arg, "src_ip") + if srcAddr.Size() != 4 { + panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name())) + } + dstAddr := getFieldByName(arg, "dst_ip") + if dstAddr.Size() != 4 { + panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name())) + } + protocol := getFieldByName(arg, "protocol") + if protocol.Size() != 1 { + panic(fmt.Sprintf("protocol field in %v must be 1 byte", arg.Type.Name())) + } + return srcAddr, dstAddr, protocol +} + +func calcChecksumTCP(tcpPacket, srcAddr, dstAddr, protocol *Arg, pid int) (*Arg, *Arg) { + tcpHeaderField := getFieldByName(tcpPacket, "header") + csumField := getFieldByName(tcpHeaderField, "csum") + if typ, ok := csumField.Type.(*sys.CsumType); !ok { + panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) + } else if typ.Kind != sys.CsumTCP { + panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) + } + + var csum IPChecksum + csum.Update(encodeArg(srcAddr, pid)) + csum.Update(encodeArg(dstAddr, pid)) + csum.Update([]byte{0, byte(protocol.Value(pid))}) + length := []byte{0, 0} + binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size())) + csum.Update(length) + csum.Update(encodeArg(tcpPacket, pid)) + + newCsumField := *csumField + newCsumField.Val = uintptr(csum.Digest()) + return csumField, &newCsumField +} + func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { - var m map[*Arg]*Arg + var csumMap map[*Arg]*Arg + ipv4HeaderParsed := false + var ipv4SrcAddr *Arg + var ipv4DstAddr *Arg + var ipv4Protocol *Arg foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { - // syz_csum_ipv4 struct is used in tests - if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" { - if m == nil { - m = make(map[*Arg]*Arg) + // syz_csum_ipv4_header struct is used in tests + if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" { + if csumMap == nil { + csumMap = make(map[*Arg]*Arg) + } + csumField, newCsumField := calcChecksumIPv4(arg, pid) + csumMap[csumField] = newCsumField + ipv4SrcAddr, ipv4DstAddr, ipv4Protocol = extractHeaderParamsIPv4(arg) + ipv4HeaderParsed = true + } + // syz_csum_tcp_packet struct is used in tests + if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" { + if !ipv4HeaderParsed { + panic("tcp_packet is being parsed before ipv4_header") } - k, v := calcChecksumIPv4(arg, pid) - m[k] = v + csumField, newCsumField := calcChecksumTCP(arg, ipv4SrcAddr, ipv4DstAddr, ipv4Protocol, pid) + csumMap[csumField] = newCsumField } }) - return m + return csumMap } diff --git a/prog/checksum_test.go b/prog/checksum_test.go index bade7f724..56561d4e0 100644 --- a/prog/checksum_test.go +++ b/prog/checksum_test.go @@ -6,6 +6,8 @@ package prog import ( "bytes" "testing" + + "github.com/google/syzkaller/sys" ) func TestChecksumIP(t *testing.T) { @@ -53,6 +55,10 @@ func TestChecksumIP(t *testing.T) { "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", 0xe143, }, + { + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\xab\xcd", + 0x542e, + }, } for _, test := range tests { @@ -102,7 +108,7 @@ func TestChecksumEncode(t *testing.T) { if err != nil { t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) } - encoded := encodeStruct(p.Calls[0].Args[0].Res, 0) + encoded := encodeArg(p.Calls[0].Args[0].Res, 0) if !bytes.Equal(encoded, []byte(test.encoded)) { t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded)) } @@ -115,7 +121,7 @@ func TestChecksumIPv4Calc(t *testing.T) { csum uint16 }{ { - "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}})", + "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0})", 0xe143, }, } @@ -133,6 +139,37 @@ func TestChecksumIPv4Calc(t *testing.T) { } } +func TestChecksumTCPCalc(t *testing.T) { + tests := []struct { + prog string + csum uint16 + }{ + { + "syz_test$csum_ipv4_tcp(&(0x7f0000000000)={{0x0, {0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"}, 0x0, 0x0, 0x0}, {{0x0}, \"abcd\"}})", + 0x542e, + }, + } + for i, test := range tests { + p, err := Deserialize([]byte(test.prog)) + if err != nil { + t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) + } + csumMap := calcChecksumsCall(p.Calls[0], i % 32) + for oldField, newField := range csumMap { + if typ, ok := newField.Type.(*sys.CsumType); ok { + if typ.Kind == sys.CsumTCP { + csum := newField.Value(i % 32) + if csum != uintptr(test.csum) { + t.Fatalf("failed to calc tcp checksum, got %x, want %x, prog: '%v'", csum, test.csum, test.prog) + } + } + } else { + t.Fatalf("non csum key %+v in csum map %+v", oldField, csumMap) + } + } + } +} + func TestChecksumCalcRandom(t *testing.T) { rs, iters := initTest(t) for i := 0; i < iters; i++ { |
