aboutsummaryrefslogtreecommitdiffstats
path: root/prog/checksum.go
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-02-02 19:19:32 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-02-02 19:19:32 +0100
commit13266cc0b604fed3d5f9fc73e9f804091e5c1ac6 (patch)
tree676de3f880ebdc151643c8236801dce532b1f7e7 /prog/checksum.go
parent1a85b51165c0be64dc8651245c2b235a4c1928ec (diff)
prog, sys: add udp description and checksum
Diffstat (limited to 'prog/checksum.go')
-rw-r--r--prog/checksum.go83
1 files changed, 55 insertions, 28 deletions
diff --git a/prog/checksum.go b/prog/checksum.go
index 94df4630a..fdc14a513 100644
--- a/prog/checksum.go
+++ b/prog/checksum.go
@@ -130,8 +130,8 @@ func getFieldByName(arg *Arg, name string) *Arg {
panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name()))
}
-func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
- csumField := getFieldByName(arg, "csum")
+func findCsumFieldIPv4(packet *Arg, pid int) *Arg {
+ csumField := getFieldByName(packet, "csum")
if typ, ok := csumField.Type.(*sys.CsumType); !ok {
panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
} else if typ.Kind != sys.CsumIPv4 {
@@ -140,11 +140,15 @@ func calcChecksumIPv4(arg *Arg, pid int) (*Arg, *Arg) {
if csumField.Value(pid) != 0 {
panic(fmt.Sprintf("checksum field has nonzero value %v, arg: %+v", csumField.Value(pid), csumField))
}
- bytes := encodeArg(arg, pid)
+ return csumField
+}
+
+func calcChecksumIPv4(packet, csumField *Arg, pid int) *Arg {
+ bytes := encodeArg(packet, pid)
csum := ipChecksum(bytes)
newCsumField := *csumField
newCsumField.Val = uintptr(csum)
- return csumField, &newCsumField
+ return &newCsumField
}
func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) {
@@ -171,29 +175,39 @@ func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) {
return srcAddr, dstAddr
}
-func composeTCPPseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, pid int) []byte {
+func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
header := []byte{}
header = append(header, encodeArg(srcAddr, pid)...)
header = append(header, encodeArg(dstAddr, pid)...)
- header = append(header, []byte{0, 6}...) // IPPROTO_TCP == 6
+ header = append(header, []byte{0, protocol}...)
length := []byte{0, 0}
binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size()))
header = append(header, length...)
return header
}
-func composeTCPPseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, pid int) []byte {
+func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte {
header := []byte{}
header = append(header, encodeArg(srcAddr, pid)...)
header = append(header, encodeArg(dstAddr, pid)...)
length := []byte{0, 0, 0, 0}
binary.BigEndian.PutUint32(length, uint32(tcpPacket.Size()))
header = append(header, length...)
- header = append(header, []byte{0, 0, 0, 6}...) // IPPROTO_TCP == 6
+ header = append(header, []byte{0, 0, 0, protocol}...)
return header
}
-func calcChecksumTCP(tcpPacket *Arg, pseudoHeader []byte, pid int) (*Arg, *Arg) {
+func findCsumFieldUDP(udpPacket *Arg) *Arg {
+ csumField := getFieldByName(udpPacket, "csum")
+ if typ, ok := csumField.Type.(*sys.CsumType); !ok {
+ panic(fmt.Sprintf("checksum field has bad type %v, arg: %+v", csumField.Type, csumField))
+ } else if typ.Kind != sys.CsumUDP {
+ panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
+ }
+ return csumField
+}
+
+func findCsumFieldTCP(tcpPacket *Arg) *Arg {
tcpHeaderField := getFieldByName(tcpPacket, "header")
csumField := getFieldByName(tcpHeaderField, "csum")
if typ, ok := csumField.Type.(*sys.CsumType); !ok {
@@ -201,14 +215,16 @@ func calcChecksumTCP(tcpPacket *Arg, pseudoHeader []byte, pid int) (*Arg, *Arg)
} else if typ.Kind != sys.CsumTCP {
panic(fmt.Sprintf("checksum field has bad kind %v, arg: %+v", typ.Kind, csumField))
}
+ return csumField
+}
+func calcChecksumTCPUDP(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg {
var csum IPChecksum
csum.Update(pseudoHeader)
- csum.Update(encodeArg(tcpPacket, pid))
-
+ csum.Update(encodeArg(packet, pid))
newCsumField := *csumField
newCsumField.Val = uintptr(csum.Digest())
- return csumField, &newCsumField
+ return &newCsumField
}
func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
@@ -217,37 +233,48 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg {
ipv6HeaderParsed := false
var ipSrcAddr *Arg
var ipDstAddr *Arg
+ tcp := false
foreachArgArray(&c.Args, nil, func(arg, base *Arg, _ *[]*Arg) {
- // syz_csum_ipv4_header struct is used in tests
- if arg.Type.Name() == "ipv4_header" || arg.Type.Name() == "syz_csum_ipv4_header" {
+ // syz_csum_* structs are used in tests
+ switch arg.Type.Name() {
+ case "ipv4_header", "syz_csum_ipv4_header":
if csumMap == nil {
csumMap = make(map[*Arg]*Arg)
}
- csumField, newCsumField := calcChecksumIPv4(arg, pid)
+ csumField := findCsumFieldIPv4(arg, pid)
+ newCsumField := calcChecksumIPv4(arg, csumField, pid)
csumMap[csumField] = newCsumField
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv4(arg)
ipv4HeaderParsed = true
- }
- // syz_csum_ipv6_header struct is used in tests
- if arg.Type.Name() == "ipv6_packet" || arg.Type.Name() == "syz_csum_ipv6_header" {
+ case "ipv6_packet", "syz_csum_ipv6_header":
ipSrcAddr, ipDstAddr = extractHeaderParamsIPv6(arg)
ipv6HeaderParsed = true
- }
- // syz_csum_tcp_packet struct is used in tests
- if arg.Type.Name() == "tcp_packet" || arg.Type.Name() == "syz_csum_tcp_packet" {
- if csumMap == nil {
- csumMap = make(map[*Arg]*Arg)
- }
+ case "tcp_packet", "syz_csum_tcp_packet":
+ tcp = true
+ fallthrough
+ case "udp_packet", "syz_csum_udp_packet":
if !ipv4HeaderParsed && !ipv6HeaderParsed {
- panic("tcp packet is being parsed before ipv4 or ipv6 header")
+ panic(fmt.Sprintf("%s is being parsed before ipv4 or ipv6 header", arg.Type.Name()))
+ }
+ var csumField *Arg
+ var protocol uint8
+ if tcp {
+ csumField = findCsumFieldTCP(arg)
+ protocol = 6 // IPPROTO_TCP
+ } else {
+ csumField = findCsumFieldUDP(arg)
+ protocol = 17 // IPPROTO_UDP
}
var pseudoHeader []byte
if ipv4HeaderParsed {
- pseudoHeader = composeTCPPseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, pid)
+ pseudoHeader = composePseudoHeaderIPv4(arg, ipSrcAddr, ipDstAddr, protocol, pid)
} else {
- pseudoHeader = composeTCPPseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, pid)
+ pseudoHeader = composePseudoHeaderIPv6(arg, ipSrcAddr, ipDstAddr, protocol, pid)
+ }
+ if csumMap == nil {
+ csumMap = make(map[*Arg]*Arg)
}
- csumField, newCsumField := calcChecksumTCP(arg, pseudoHeader, pid)
+ newCsumField := calcChecksumTCPUDP(arg, csumField, pseudoHeader, pid)
csumMap[csumField] = newCsumField
}
})