diff options
| author | Andrey Konovalov <andreyknvl@google.com> | 2017-04-27 20:31:00 +0200 |
|---|---|---|
| committer | Andrey Konovalov <andreyknvl@google.com> | 2017-05-12 15:47:59 +0200 |
| commit | ac0c70f74a5badbebec721c2be0602ea98c0437b (patch) | |
| tree | 69278d03cdaae547c1c87a84deb44969081837bd /prog/checksum.go | |
| parent | b2dbb4f4d10f436088ba8a1d2d18437911a83887 (diff) | |
prog, executor: move checksum computation to executor
This commit moves checksum computation to executor. This will allow to embed
dynamically generated values (like TCP sequence numbers) into packets.
Diffstat (limited to 'prog/checksum.go')
| -rw-r--r-- | prog/checksum.go | 193 |
1 files changed, 44 insertions, 149 deletions
diff --git a/prog/checksum.go b/prog/checksum.go index 8ee00bc24..5db10d26f 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -4,121 +4,34 @@ package prog import ( - "encoding/binary" "fmt" - "unsafe" "github.com/google/syzkaller/sys" ) -type IPChecksum struct { - acc uint32 -} - -func (csum *IPChecksum) Update(data []byte) { - length := len(data) - 1 - for i := 0; i < length; i += 2 { - csum.acc += uint32(data[i]) << 8 - csum.acc += uint32(data[i+1]) - } - if len(data)%2 == 1 { - csum.acc += uint32(data[length]) << 8 - } - for csum.acc > 0xffff { - csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff) - } -} - -func (csum *IPChecksum) Digest() uint16 { - return ^uint16(csum.acc) -} - -func ipChecksum(data []byte) uint16 { - var csum IPChecksum - csum.Update(data) - return csum.Digest() -} +type CsumKind int -func bitmaskLen(bfLen uint64) uint64 { - return (1 << bfLen) - 1 -} +const ( + CsumInet CsumKind = iota +) -func bitmaskLenOff(bfOff, bfLen uint64) uint64 { - return bitmaskLen(bfLen) << bfOff -} +type CsumChunkKind int -func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} - -func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} +const ( + CsumChunkArg CsumChunkKind = iota + CsumChunkConst +) -func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint32(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } +type CsumInfo struct { + Kind CsumKind + Chunks []CsumChunk } -func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} - -func encodeArg(arg *Arg, pid int) []byte { - bytes := make([]byte, arg.Size()) - foreachSubargOffset(arg, func(arg *Arg, offset uintptr) { - switch arg.Kind { - case ArgConst: - addr := unsafe.Pointer(&bytes[offset]) - val := arg.Value(pid) - bfOff := uint64(arg.Type.BitfieldOffset()) - bfLen := uint64(arg.Type.BitfieldLength()) - switch arg.Size() { - case 1: - storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen) - case 2: - storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen) - case 4: - storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen) - case 8: - storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen) - default: - panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg)) - } - case ArgData: - copy(bytes[offset:], arg.Data) - default: - panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type)) - } - }) - return bytes +type CsumChunk struct { + Kind CsumChunkKind + Arg *Arg // for CsumChunkArg + Value uintptr // for CsumChunkConst + Size uintptr // for CsumChunkConst } func getFieldByName(arg *Arg, name string) *Arg { @@ -130,14 +43,6 @@ func getFieldByName(arg *Arg, name string) *Arg { panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name())) } -func calcChecksumInet(packet, csumField *Arg, pid int) *Arg { - bytes := encodeArg(packet, pid) - csum := ipChecksum(bytes) - newCsumField := *csumField - newCsumField.Val = uintptr(csum) - return &newCsumField -} - func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) { srcAddr := getFieldByName(arg, "src_ip") if srcAddr.Size() != 4 { @@ -162,35 +67,24 @@ func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) { return srcAddr, dstAddr } -func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { - header := []byte{} - header = append(header, encodeArg(srcAddr, pid)...) - header = append(header, encodeArg(dstAddr, pid)...) - header = append(header, []byte{0, protocol}...) - length := []byte{0, 0} - binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size())) - header = append(header, length...) - return header -} - -func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { - header := []byte{} - header = append(header, encodeArg(srcAddr, pid)...) - header = append(header, encodeArg(dstAddr, pid)...) - length := []byte{0, 0, 0, 0} - binary.BigEndian.PutUint32(length, uint32(tcpPacket.Size())) - header = append(header, length...) - header = append(header, []byte{0, 0, 0, protocol}...) - return header +func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo { + info := CsumInfo{Kind: CsumInet} + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap16(uint16(protocol))), 2}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap16(uint16(tcpPacket.Size()))), 2}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) + return info } -func calcChecksumPseudo(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg { - var csum IPChecksum - csum.Update(pseudoHeader) - csum.Update(encodeArg(packet, pid)) - newCsumField := *csumField - newCsumField.Val = uintptr(csum.Digest()) - return &newCsumField +func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo { + info := CsumInfo{Kind: CsumInet} + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap32(uint32(tcpPacket.Size()))), 4}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap32(uint32(protocol))), 4}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) + return info } func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg { @@ -209,7 +103,7 @@ func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) } -func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { +func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo { var inetCsumFields []*Arg var pseudoCsumFields []*Arg @@ -242,14 +136,16 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { } }) - csumMap := make(map[*Arg]*Arg) + csumMap := make(map[*Arg]CsumInfo) - // Calculate inet checksums. + // Calculate generic inet checksums. for _, arg := range inetCsumFields { typ, _ := arg.Type.(*sys.CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) - newCsumField := calcChecksumInet(csummedArg, arg, pid) - csumMap[arg] = newCsumField + chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0} + info := CsumInfo{Kind: CsumInet, Chunks: make([]CsumChunk, 0)} + info.Chunks = append(info.Chunks, chunk) + csumMap[arg] = info } // No need to continue if there are no pseudo csum fields. @@ -282,14 +178,13 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { typ, _ := arg.Type.(*sys.CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) protocol := uint8(typ.Protocol) - var pseudoHeader []byte + var info CsumInfo if ipv4HeaderParsed { - pseudoHeader = composePseudoHeaderIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) } else { - pseudoHeader = composePseudoHeaderIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) } - newCsumField := calcChecksumPseudo(csummedArg, arg, pseudoHeader, pid) - csumMap[arg] = newCsumField + csumMap[arg] = info } return csumMap |
