aboutsummaryrefslogtreecommitdiffstats
path: root/prog/checksum.go
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-02-08 15:49:03 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-02-08 17:11:54 +0100
commit0130c7b34e9e4e831c2794f00a0d017040a967a9 (patch)
treed004778543ebef9d348e491f428d0de17d332b28 /prog/checksum.go
parent8792b9237970ec1e93423de538bbe646a8ecff90 (diff)
prog, sys: add icmpv6 packet descriptions and checksums
Also generalize checksums into the two kinds: inet and pseudo. Inet checksums is just the Internet checksum of a packet. Pseudo checksum is the Internet checksum of a packet with a pseudo header.
Diffstat (limited to 'prog/checksum.go')
-rw-r--r--prog/checksum.go146
1 files changed, 80 insertions, 66 deletions
diff --git a/prog/checksum.go b/prog/checksum.go
index 03eb81f7b..8ee00bc24 100644
--- a/prog/checksum.go
+++ b/prog/checksum.go
@@ -184,28 +184,7 @@ func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, p
return header
}
-func findCsumFieldUDP(udpPacket *Arg) *Arg {
- csumField := getFieldByName(udpPacket, "csum")
- if typ, ok := csumField.Type.(*sys.CsumType); !ok {
- panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
- } else if typ.Kind != sys.CsumUDP {
- panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
- }
- return csumField
-}
-
-func findCsumFieldTCP(tcpPacket *Arg) *Arg {
- tcpHeaderField := getFieldByName(tcpPacket, "header")
- csumField := getFieldByName(tcpHeaderField, "csum")
- if typ, ok := csumField.Type.(*sys.CsumType); !ok {
- panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
- } else if typ.Kind != sys.CsumTCP {
- panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
- }
- return csumField
-}
-
-func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg {
+func calcChecksumPseudo(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg {
var csum IPChecksum
csum.Update(pseudoHeader)
csum.Update(encodeArg(packet, pid))
@@ -214,32 +193,75 @@ func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *A
return &newCsumField
}
+func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg {
+ if typ.Buf == "parent" {
+ if csummedArg, ok := parentsMap[arg]; ok {
+ return csummedArg
+ }
+ panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name()))
+ } else {
+ for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] {
+ if typ.Buf == parent.Type.Name() {
+ return parent
+ }
+ }
+ }
+ panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf))
+}
+
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
- var csumMap map[*Arg]*Arg
- ipv4HeaderParsed := false
- ipv6HeaderParsed := false
- var ipSrcAddr *Arg
- var ipDstAddr *Arg
- tcp := false
+ var inetCsumFields []*Arg
+ var pseudoCsumFields []*Arg
- // Calculate inet checksums.
+ // Find all csum fields.
+ foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
+ if typ, ok := arg.Type.(*sys.CsumType); ok {
+ switch typ.Kind {
+ case sys.CsumInet:
+ inetCsumFields = append(inetCsumFields, arg)
+ case sys.CsumPseudo:
+ pseudoCsumFields = append(pseudoCsumFields, arg)
+ default:
+ panic(fmt.Sprintf("unknown csum kind %v\n", typ.Kind))
+ }
+ }
+ })
+
+ // Return if no csum fields found.
+ if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 {
+ return nil
+ }
+
+ // Build map of each field to its parent struct.
+ parentsMap := make(map[*Arg]*Arg)
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
if _, ok := arg.Type.(*sys.StructType); ok {
for _, field := range arg.Inner {
- if typ, ok1 := field.Type.(*sys.CsumType); ok1 {
- if typ.Kind == sys.CsumInet {
- newCsumField := calcChecksumInet(arg, field, pid)
- if csumMap == nil {
- csumMap = make(map[*Arg]*Arg)
- }
- csumMap[field] = newCsumField
- }
- }
+ parentsMap[field.InnerArg()] = arg
}
}
})
- // Calculate tcp and udp checksums.
+ csumMap := make(map[*Arg]*Arg)
+
+ // Calculate inet checksums.
+ for _, arg := range inetCsumFields {
+ typ, _ := arg.Type.(*sys.CsumType)
+ csummedArg := findCsummedArg(arg, typ, parentsMap)
+ newCsumField := calcChecksumInet(csummedArg, arg, pid)
+ csumMap[arg] = newCsumField
+ }
+
+ // No need to continue if there are no pseudo csum fields.
+ if len(pseudoCsumFields) == 0 {
+ return csumMap
+ }
+
+ // Extract ipv4 or ipv6 source and destination addresses.
+ ipv4HeaderParsed := false
+ ipv6HeaderParsed := false
+ var ipSrcAddr *Arg
+ var ipDstAddr *Arg
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
// syz_csum_* structs are used in tests
switch arg.Type.Name() {
@@ -249,34 +271,26 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
case "ipv6_packet", "syz_csum_ipv6_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg)
ipv6HeaderParsed = true
- case "tcp_packet", "syz_csum_tcp_packet":
- tcp = true
- fallthrough
- case "udp_packet", "syz_csum_udp_packet":
- if !ipv4HeaderParsed && !ipv6HeaderParsed {
- panic(fmt.Sprintf("%s is being parsed before ipv4 or ipv6 header", arg.Type.Name()))
- }
- var csumField *Arg
- var protocol uint8
- if tcp {
- csumField = findCsumFieldTCP(arg)
- protocol = 6 // IPPROTO_TCP
- } else {
- csumField = findCsumFieldUDP(arg)
- protocol = 17 // IPPROTO_UDP
- }
- var pseudoHeader []byte
- if ipv4HeaderParsed {
- pseudoHeader = composePseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, protocol, pid)
- } else {
- pseudoHeader = composePseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, protocol, pid)
- }
- if csumMap == nil {
- csumMap = make(map[*Arg]*Arg)
- }
- newCsumField := calcChecksumTCPUDP(arg, csumField, pseudoHeader, pid)
- csumMap[csumField] = newCsumField
}
})
+ if !ipv4HeaderParsed && !ipv6HeaderParsed {
+ panic("no ipv4 nor ipv6 header found")
+ }
+
+ // Calculate pseudo checksums.
+ for _, arg := range pseudoCsumFields {
+ typ, _ := arg.Type.(*sys.CsumType)
+ csummedArg := findCsummedArg(arg, typ, parentsMap)
+ protocol := uint8(typ.Protocol)
+ var pseudoHeader []byte
+ if ipv4HeaderParsed {
+ pseudoHeader = composePseudoHeaderIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid)
+ } else {
+ pseudoHeader = composePseudoHeaderIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid)
+ }
+ newCsumField := calcChecksumPseudo(csummedArg, arg, pseudoHeader, pid)
+ csumMap[arg] = newCsumField
+ }
+
return csumMap
}