diff options
| author | Andrey Konovalov <andreyknvl@google.com> | 2017-02-02 19:19:32 +0100 |
|---|---|---|
| committer | Andrey Konovalov <andreyknvl@google.com> | 2017-02-02 19:19:32 +0100 |
| commit | 13266cc0b604fed3d5f9fc73e9f804091e5c1ac6 (patch) | |
| tree | 676de3f880ebdc151643c8236801dce532b1f7e7 /prog/checksum.go | |
| parent | 1a85b51165c0be64dc8651245c2b235a4c1928ec (diff) | |
prog, sys: add udp description and checksum
Diffstat (limited to 'prog/checksum.go')
| -rw-r--r-- | prog/checksum.go | 83 |
1 files changed, 55 insertions, 28 deletions
diff --git a/prog/checksum.go b/prog/checksum.go index 94df4630a..fdc14a513 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -130,8 +130,8 @@ func getFieldByName(arg *Arg, name string) *Arg { panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name())) } -func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { - csumField := getFieldByName(arg, "csum") +func findCsumFieldIPv4(packet *Arg, pid int) *Arg { + csumField := getFieldByName(packet, "csum") if typ, ok := csumField.Type.(*sys.CsumType); !ok { panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) } else if typ.Kind != sys.CsumIPv4 { @@ -140,11 +140,15 @@ func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) { if csumField.Value(pid) != 0 { panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField)) } - bytes := encodeArg(arg, pid) + return csumField +} + +func calcChecksumIPv4(packet, csumField *Arg, pid int) *Arg { + bytes := encodeArg(packet, pid) csum := ipChecksum(bytes) newCsumField := *csumField newCsumField.Val = uintptr(csum) - return csumField, &newCsumField + return &newCsumField } func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) { @@ -171,29 +175,39 @@ func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) { return srcAddr, dstAddr } -func composeTCPPseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, pid int) []byte { +func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { header := []byte{} header = append(header, encodeArg(srcAddr, pid)...) header = append(header, encodeArg(dstAddr, pid)...) - header = append(header, []byte{0, 6}...) // IPPROTO_TCP == 6 + header = append(header, []byte{0, protocol}...) length := []byte{0, 0} binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size())) header = append(header, length...) return header } -func composeTCPPseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, pid int) []byte { +func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { header := []byte{} header = append(header, encodeArg(srcAddr, pid)...) header = append(header, encodeArg(dstAddr, pid)...) length := []byte{0, 0, 0, 0} binary.BigEndian.PutUint32(length, uint32(tcpPacket.Size())) header = append(header, length...) - header = append(header, []byte{0, 0, 0, 6}...) // IPPROTO_TCP == 6 + header = append(header, []byte{0, 0, 0, protocol}...) return header } -func calcChecksumTCP(tcpPacket *Arg, pseudoHeader []byte, pid int) (*Arg, *Arg) { +func findCsumFieldUDP(udpPacket *Arg) *Arg { + csumField := getFieldByName(udpPacket, "csum") + if typ, ok := csumField.Type.(*sys.CsumType); !ok { + panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField)) + } else if typ.Kind != sys.CsumUDP { + panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) + } + return csumField +} + +func findCsumFieldTCP(tcpPacket *Arg) *Arg { tcpHeaderField := getFieldByName(tcpPacket, "header") csumField := getFieldByName(tcpHeaderField, "csum") if typ, ok := csumField.Type.(*sys.CsumType); !ok { @@ -201,14 +215,16 @@ func calcChecksumTCP(tcpPacket *Arg, pseudoHeader []byte, pid int) (*Arg, *Arg) } else if typ.Kind != sys.CsumTCP { panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField)) } + return csumField +} +func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg { var csum IPChecksum csum.Update(pseudoHeader) - csum.Update(encodeArg(tcpPacket, pid)) - + csum.Update(encodeArg(packet, pid)) newCsumField := *csumField newCsumField.Val = uintptr(csum.Digest()) - return csumField, &newCsumField + return &newCsumField } func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { @@ -217,37 +233,48 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { ipv6HeaderParsed := false var ipSrcAddr *Arg var ipDstAddr *Arg + tcp := false foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) { - // syz_csum_ipv4_header struct is used in tests - if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" { + // syz_csum_* structs are used in tests + switch arg.Type.Name() { + case "ipv4_header", "syz_csum_ipv4_header": if csumMap == nil { csumMap = make(map[*Arg]*Arg) } - csumField, newCsumField := calcChecksumIPv4(arg, pid) + csumField := findCsumFieldIPv4(arg, pid) + newCsumField := calcChecksumIPv4(arg, csumField, pid) csumMap[csumField] = newCsumField ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg) ipv4HeaderParsed = true - } - // syz_csum_ipv6_header struct is used in tests - if arg.Type.Name() == "ipv6_packet" || arg.Type.Name() == "syz_csum_ipv6_header" { + case "ipv6_packet", "syz_csum_ipv6_header": ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg) ipv6HeaderParsed = true - } - // syz_csum_tcp_packet struct is used in tests - if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" { - if csumMap == nil { - csumMap = make(map[*Arg]*Arg) - } + case "tcp_packet", "syz_csum_tcp_packet": + tcp = true + fallthrough + case "udp_packet", "syz_csum_udp_packet": if !ipv4HeaderParsed && !ipv6HeaderParsed { - panic("tcp packet is being parsed before ipv4 or ipv6 header") + panic(fmt.Sprintf("%s is being parsed before ipv4 or ipv6 header", arg.Type.Name())) + } + var csumField *Arg + var protocol uint8 + if tcp { + csumField = findCsumFieldTCP(arg) + protocol = 6 // IPPROTO_TCP + } else { + csumField = findCsumFieldUDP(arg) + protocol = 17 // IPPROTO_UDP } var pseudoHeader []byte if ipv4HeaderParsed { - pseudoHeader = composeTCPPseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, pid) + pseudoHeader = composePseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, protocol, pid) } else { - pseudoHeader = composeTCPPseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, pid) + pseudoHeader = composePseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, protocol, pid) + } + if csumMap == nil { + csumMap = make(map[*Arg]*Arg) } - csumField, newCsumField := calcChecksumTCP(arg, pseudoHeader, pid) + newCsumField := calcChecksumTCPUDP(arg, csumField, pseudoHeader, pid) csumMap[csumField] = newCsumField } }) |
