diff options
| author | Andrey Konovalov <andreyknvl@google.com> | 2017-02-08 15:49:03 +0100 |
|---|---|---|
| committer | Andrey Konovalov <andreyknvl@google.com> | 2017-02-08 17:11:54 +0100 |
| commit | 0130c7b34e9e4e831c2794f00a0d017040a967a9 (patch) | |
| tree | d004778543ebef9d348e491f428d0de17d332b28 /prog/checksum.go | |
| parent | 8792b9237970ec1e93423de538bbe646a8ecff90 (diff) | |
prog, sys: add icmpv6 packet descriptions and checksums
Also generalize checksums into the two kinds: inet and pseudo.
Inet checksums is just the Internet checksum of a packet.
Pseudo checksum is the Internet checksum of a packet with a pseudo header.
Diffstat (limited to 'prog/checksum.go')
| -rw-r--r-- | prog/checksum.go | 146 |
1 files changed, 80 insertions, 66 deletions
diff --git a/prog/checksum.go b/prog/checksum.go index 03eb81f7b..8ee00bc24 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -184,28 +184,7 @@ func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, p return header } -func findCsumFieldUDP(udpPacket *Arg) *Arg { - csumField := getFieldByName(udpPacket, "csum") - if typ, ok := csumField.Type.(*sys.CsumType); !ok { - panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) - } else if typ.Kind != sys.CsumUDP { - panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) - } - return csumField -} - -func findCsumFieldTCP(tcpPacket *Arg) *Arg { - tcpHeaderField := getFieldByName(tcpPacket, "header") - csumField := getFieldByName(tcpHeaderField, "csum") - if typ, ok := csumField.Type.(*sys.CsumType); !ok { - panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) - } else if typ.Kind != sys.CsumTCP { - panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) - } - return csumField -} - -func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg { +func calcChecksumPseudo(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg { var csum IPChecksum csum.Update(pseudoHeader) csum.Update(encodeArg(packet, pid)) @@ -214,32 +193,75 @@ func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *A return &newCsumField } +func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg { + if typ.Buf == "parent" { + if csummedArg, ok := parentsMap[arg]; ok { + return csummedArg + } + panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name())) + } else { + for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { + if typ.Buf == parent.Type.Name() { + return parent + } + } + } + panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) +} + func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { - var csumMap map[*Arg]*Arg - ipv4HeaderParsed := false - ipv6HeaderParsed := false - var ipSrcAddr *Arg - var ipDstAddr *Arg - tcp := false + var inetCsumFields []*Arg + var pseudoCsumFields []*Arg - // Calculate inet checksums. + // Find all csum fields. + foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { + if typ, ok := arg.Type.(*sys.CsumType); ok { + switch typ.Kind { + case sys.CsumInet: + inetCsumFields = append(inetCsumFields, arg) + case sys.CsumPseudo: + pseudoCsumFields = append(pseudoCsumFields, arg) + default: + panic(fmt.Sprintf("unknown csum kind %v\n", typ.Kind)) + } + } + }) + + // Return if no csum fields found. + if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 { + return nil + } + + // Build map of each field to its parent struct. + parentsMap := make(map[*Arg]*Arg) foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { if _, ok := arg.Type.(*sys.StructType); ok { for _, field := range arg.Inner { - if typ, ok1 := field.Type.(*sys.CsumType); ok1 { - if typ.Kind == sys.CsumInet { - newCsumField := calcChecksumInet(arg, field, pid) - if csumMap == nil { - csumMap = make(map[*Arg]*Arg) - } - csumMap[field] = newCsumField - } - } + parentsMap[field.InnerArg()] = arg } } }) - // Calculate tcp and udp checksums. + csumMap := make(map[*Arg]*Arg) + + // Calculate inet checksums. + for _, arg := range inetCsumFields { + typ, _ := arg.Type.(*sys.CsumType) + csummedArg := findCsummedArg(arg, typ, parentsMap) + newCsumField := calcChecksumInet(csummedArg, arg, pid) + csumMap[arg] = newCsumField + } + + // No need to continue if there are no pseudo csum fields. + if len(pseudoCsumFields) == 0 { + return csumMap + } + + // Extract ipv4 or ipv6 source and destination addresses. + ipv4HeaderParsed := false + ipv6HeaderParsed := false + var ipSrcAddr *Arg + var ipDstAddr *Arg foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { // syz_csum_* structs are used in tests switch arg.Type.Name() { @@ -249,34 +271,26 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { case "ipv6_packet", "syz_csum_ipv6_header": ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg) ipv6HeaderParsed = true - case "tcp_packet", "syz_csum_tcp_packet": - tcp = true - fallthrough - case "udp_packet", "syz_csum_udp_packet": - if !ipv4HeaderParsed && !ipv6HeaderParsed { - panic(fmt.Sprintf("%s is being parsed before ipv4 or ipv6 header", arg.Type.Name())) - } - var csumField *Arg - var protocol uint8 - if tcp { - csumField = findCsumFieldTCP(arg) - protocol = 6 // IPPROTO_TCP - } else { - csumField = findCsumFieldUDP(arg) - protocol = 17 // IPPROTO_UDP - } - var pseudoHeader []byte - if ipv4HeaderParsed { - pseudoHeader = composePseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, protocol, pid) - } else { - pseudoHeader = composePseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, protocol, pid) - } - if csumMap == nil { - csumMap = make(map[*Arg]*Arg) - } - newCsumField := calcChecksumTCPUDP(arg, csumField, pseudoHeader, pid) - csumMap[csumField] = newCsumField } }) + if !ipv4HeaderParsed && !ipv6HeaderParsed { + panic("no ipv4 nor ipv6 header found") + } + + // Calculate pseudo checksums. + for _, arg := range pseudoCsumFields { + typ, _ := arg.Type.(*sys.CsumType) + csummedArg := findCsummedArg(arg, typ, parentsMap) + protocol := uint8(typ.Protocol) + var pseudoHeader []byte + if ipv4HeaderParsed { + pseudoHeader = composePseudoHeaderIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + } else { + pseudoHeader = composePseudoHeaderIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + } + newCsumField := calcChecksumPseudo(csummedArg, arg, pseudoHeader, pid) + csumMap[arg] = newCsumField + } + return csumMap } |
