aboutsummaryrefslogtreecommitdiffstats
path: root/prog/checksum.go
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-01-30 19:17:38 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-01-30 21:00:45 +0100
commit1f7f5daef8b9d4665e463bd842c701b4db27f56b (patch)
tree3b8af0bd368cf9ccc1ebc130c50a4f41dabf320c /prog/checksum.go
parent4ee789185bc215d62e9cfa92e23a8de2760789cb (diff)
prog, sys: add tcp packets descriptions
Also embed tcp checksums into packets.
Diffstat (limited to 'prog/checksum.go')
-rw-r--r--prog/checksum.go93
1 files changed, 76 insertions, 17 deletions
diff --git a/prog/checksum.go b/prog/checksum.go
index 3806c59e0..9df541dc1 100644
--- a/prog/checksum.go
+++ b/prog/checksum.go
@@ -4,6 +4,7 @@
package prog
import (
+ "encoding/binary"
"fmt"
"unsafe"
@@ -90,7 +91,7 @@ func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) {
}
}
-func encodeStruct(arg *Arg, pid int) []byte {
+func encodeArg(arg *Arg, pid int) []byte {
bytes := make([]byte, arg.Size())
foreachSubargOffset(arg, func(arg *Arg, offset uintptr) {
switch arg.Kind {
@@ -120,38 +121,96 @@ func encodeStruct(arg *Arg, pid int) []byte {
return bytes
}
-func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
- var csumField *Arg
+func getFieldByName(arg *Arg, name string) *Arg {
for _, field := range arg.Inner {
- if _, ok := field.Type.(*sys.CsumType); ok {
- csumField = field
- break
+ if field.Type.FieldName() == name {
+ return field
}
}
- if csumField == nil {
- panic(fmt.Sprintf("failed to find csum field in %v", arg.Type.Name()))
+ panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
+}
+
+func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
+ csumField := getFieldByName(arg, "csum")
+ if typ, ok := csumField.Type.(*sys.CsumType); !ok {
+ panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
+ } else if typ.Kind != sys.CsumIPv4 {
+ panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
}
if csumField.Value(pid) != 0 {
panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField))
}
- bytes := encodeStruct(arg, pid)
+ bytes := encodeArg(arg, pid)
csum := ipChecksum(bytes)
newCsumField := *csumField
newCsumField.Val = uintptr(csum)
return csumField, &newCsumField
}
+func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg, *Arg) {
+ srcAddr := getFieldByName(arg, "src_ip")
+ if srcAddr.Size() != 4 {
+ panic(fmt.Sprintf("src_ip field in %v must be 4 bytes", arg.Type.Name()))
+ }
+ dstAddr := getFieldByName(arg, "dst_ip")
+ if dstAddr.Size() != 4 {
+ panic(fmt.Sprintf("dst_ip field in %v must be 4 bytes", arg.Type.Name()))
+ }
+ protocol := getFieldByName(arg, "protocol")
+ if protocol.Size() != 1 {
+ panic(fmt.Sprintf("protocol field in %v must be 1 byte", arg.Type.Name()))
+ }
+ return srcAddr, dstAddr, protocol
+}
+
+func calcChecksumTCP(tcpPacket, srcAddr, dstAddr, protocol *Arg, pid int) (*Arg, *Arg) {
+ tcpHeaderField := getFieldByName(tcpPacket, "header")
+ csumField := getFieldByName(tcpHeaderField, "csum")
+ if typ, ok := csumField.Type.(*sys.CsumType); !ok {
+ panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
+ } else if typ.Kind != sys.CsumTCP {
+ panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
+ }
+
+ var csum IPChecksum
+ csum.Update(encodeArg(srcAddr, pid))
+ csum.Update(encodeArg(dstAddr, pid))
+ csum.Update([]byte{0, byte(protocol.Value(pid))})
+ length := []byte{0, 0}
+ binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size()))
+ csum.Update(length)
+ csum.Update(encodeArg(tcpPacket, pid))
+
+ newCsumField := *csumField
+ newCsumField.Val = uintptr(csum.Digest())
+ return csumField, &newCsumField
+}
+
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
- var m map[*Arg]*Arg
+ var csumMap map[*Arg]*Arg
+ ipv4HeaderParsed := false
+ var ipv4SrcAddr *Arg
+ var ipv4DstAddr *Arg
+ var ipv4Protocol *Arg
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
- // syz_csum_ipv4 struct is used in tests
- if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4" {
- if m == nil {
- m = make(map[*Arg]*Arg)
+ // syz_csum_ipv4_header struct is used in tests
+ if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" {
+ if csumMap == nil {
+ csumMap = make(map[*Arg]*Arg)
+ }
+ csumField, newCsumField := calcChecksumIPv4(arg, pid)
+ csumMap[csumField] = newCsumField
+ ipv4SrcAddr, ipv4DstAddr, ipv4Protocol = extractHeaderParamsIPv4(arg)
+ ipv4HeaderParsed = true
+ }
+ // syz_csum_tcp_packet struct is used in tests
+ if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" {
+ if !ipv4HeaderParsed {
+ panic("tcp_packet is being parsed before ipv4_header")
}
- k, v := calcChecksumIPv4(arg, pid)
- m[k] = v
+ csumField, newCsumField := calcChecksumTCP(arg, ipv4SrcAddr, ipv4DstAddr, ipv4Protocol, pid)
+ csumMap[csumField] = newCsumField
}
})
- return m
+ return csumMap
}