diff options
| author | Andrey Konovalov <andreyknvl@gmail.com> | 2017-05-17 01:56:40 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-05-17 01:56:40 +0200 |
| commit | 64cd235dcf60d3ab0dec3d4b0f168297782417a7 (patch) | |
| tree | 67c41279da959d253aab080b873284e6e2e5186a | |
| parent | ffb66e506d9083140196d9b4b2a46e15bc4a5d1e (diff) | |
| parent | ac0c70f74a5badbebec721c2be0602ea98c0437b (diff) | |
Merge pull request #169 from xairy/executor-csum
prog, executor: move checksum computation to executor
| -rw-r--r-- | csource/common.go | 30 | ||||
| -rw-r--r-- | csource/csource.go | 27 | ||||
| -rw-r--r-- | executor/common.h | 30 | ||||
| -rw-r--r-- | executor/executor.cc | 54 | ||||
| -rw-r--r-- | executor/test.go | 10 | ||||
| -rw-r--r-- | executor/test_executor.cc (renamed from executor/test_kvm.cc) | 168 | ||||
| -rw-r--r-- | executor/test_test.go | 24 | ||||
| -rw-r--r-- | prog/analysis.go | 2 | ||||
| -rw-r--r-- | prog/checksum.go | 193 | ||||
| -rw-r--r-- | prog/checksum_test.go | 195 | ||||
| -rw-r--r-- | prog/encodingexec.go | 105 | ||||
| -rw-r--r-- | prog/prog.go | 3 | ||||
| -rw-r--r-- | prog/rand.go | 6 |
13 files changed, 479 insertions, 368 deletions
diff --git a/csource/common.go b/csource/common.go index f09cbf24d..2551cbbb9 100644 --- a/csource/common.go +++ b/csource/common.go @@ -275,6 +275,36 @@ static uintptr_t syz_emit_ethernet(uintptr_t a0, uintptr_t a1) } #endif +struct csum_inet { + uint32_t acc; +}; + +void csum_inet_init(struct csum_inet* csum) +{ + csum->acc = 0; +} + +void csum_inet_update(struct csum_inet* csum, const uint8_t* data, size_t length) +{ + if (length == 0) + return; + + size_t i; + for (i = 0; i < length - 1; i += 2) + csum->acc += *(uint16_t*)&data[i]; + + if (length & 1) + csum->acc += (uint16_t)data[length - 1]; + + while (csum->acc > 0xffff) + csum->acc = (csum->acc & 0xffff) + (csum->acc >> 16); +} + +uint16_t csum_inet_digest(struct csum_inet* csum) +{ + return ~csum->acc; +} + #ifdef __NR_syz_open_dev static uintptr_t syz_open_dev(uintptr_t a0, uintptr_t a1, uintptr_t a2) { diff --git a/csource/csource.go b/csource/csource.go index 1bf3fc7a1..6a21f54e6 100644 --- a/csource/csource.go +++ b/csource/csource.go @@ -226,8 +226,33 @@ loop: esc = append(esc, '\\', 'x', hex(v>>4), hex(v<<4>>4)) } fmt.Fprintf(w, "\tNONFAILING(memcpy((void*)0x%x, \"%s\", %v));\n", addr, esc, size) + case prog.ExecArgCsum: + csum_kind := read() + switch csum_kind { + case prog.ExecArgCsumInet: + fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", n) + fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", n) + csum_chunks_num := read() + for i := uintptr(0); i < csum_chunks_num; i++ { + chunk_kind := read() + chunk_value := read() + chunk_size := read() + switch chunk_kind { + case prog.ExecArgCsumChunkData: + fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8_t*)0x%x, %d));\n", n, chunk_value, chunk_size) + case prog.ExecArgCsumChunkConst: + fmt.Fprintf(w, "\tuint%d_t csum_%d_chunk_%d = 0x%x;\n", chunk_size*8, n, i, chunk_value) + fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8_t*)&csum_%d_chunk_%d, %d);\n", n, n, i, chunk_size) + default: + panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk_kind)) + } + } + fmt.Fprintf(w, "\tNONFAILING(*(uint16_t*)0x%x = csum_inet_digest(&csum_%d));\n", addr, n) + default: + panic(fmt.Sprintf("unknown csum kind %v", csum_kind)) + } default: - panic("bad argument type") + panic(fmt.Sprintf("bad argument type %v", instr)) } case prog.ExecInstrCopyout: addr := read() diff --git a/executor/common.h b/executor/common.h index 4983802f2..eb56c8ec1 100644 --- a/executor/common.h +++ b/executor/common.h @@ -298,6 +298,36 @@ static uintptr_t syz_emit_ethernet(uintptr_t a0, uintptr_t a1) } #endif // __NR_syz_emit_ethernet +struct csum_inet { + uint32_t acc; +}; + +void csum_inet_init(struct csum_inet* csum) +{ + csum->acc = 0; +} + +void csum_inet_update(struct csum_inet* csum, const uint8_t* data, size_t length) +{ + if (length == 0) + return; + + size_t i; + for (i = 0; i < length - 1; i += 2) + csum->acc += *(uint16_t*)&data[i]; + + if (length & 1) + csum->acc += (uint16_t)data[length - 1]; + + while (csum->acc > 0xffff) + csum->acc = (csum->acc & 0xffff) + (csum->acc >> 16); +} + +uint16_t csum_inet_digest(struct csum_inet* csum) +{ + return ~csum->acc; +} + #ifdef __NR_syz_open_dev static uintptr_t syz_open_dev(uintptr_t a0, uintptr_t a1, uintptr_t a2) { diff --git a/executor/executor.cc b/executor/executor.cc index 63d66d435..52da3a16b 100644 --- a/executor/executor.cc +++ b/executor/executor.cc @@ -57,6 +57,7 @@ const uint64_t instr_copyout = -3; const uint64_t arg_const = 0; const uint64_t arg_result = 1; const uint64_t arg_data = 2; +const uint64_t arg_csum = 3; // We use the default value instead of results of failed syscalls. // -1 is an invalid fd and an invalid address and deterministic, @@ -115,6 +116,13 @@ struct thread_t { thread_t threads[kMaxThreads]; +// Checksum kinds. +const uint64_t arg_csum_inet = 0; + +// Checksum chunk kinds. +const uint64_t arg_csum_chunk_data = 0; +const uint64_t arg_csum_chunk_const = 1; + void execute_one(); uint64_t read_input(uint64_t** input_posp, bool peek = false); uint64_t read_arg(uint64_t** input_posp); @@ -354,6 +362,52 @@ retry: read_input(&input_pos); break; } + case arg_csum: { + debug("checksum found at %llx\n", addr); + char* csum_addr = addr; + uint64_t csum_size = size; + uint64_t csum_kind = read_input(&input_pos); + switch (csum_kind) { + case arg_csum_inet: { + if (csum_size != 2) { + fail("inet checksum must be 2 bytes, not %lu", size); + } + debug("calculating checksum for %llx\n", csum_addr); + struct csum_inet csum; + csum_inet_init(&csum); + uint64_t chunks_num = read_input(&input_pos); + uint64_t chunk; + for (chunk = 0; chunk < chunks_num; chunk++) { + uint64_t chunk_kind = read_input(&input_pos); + uint64_t chunk_value = read_input(&input_pos); + uint64_t chunk_size = read_input(&input_pos); + switch (chunk_kind) { + case arg_csum_chunk_data: + debug("#%d: data chunk, addr: %llx, size: %llu\n", chunk, chunk_value, chunk_size); + NONFAILING(csum_inet_update(&csum, (const uint8_t*)chunk_value, chunk_size)); + break; + case arg_csum_chunk_const: + if (chunk_size != 2 && chunk_size != 4 && chunk_size != 8) { + fail("bad checksum const chunk size %lld\n", chunk_size); + } + // Here we assume that const values come to us big endian. + debug("#%d: const chunk, value: %llx, size: %llu\n", chunk, chunk_value, chunk_size); + csum_inet_update(&csum, (const uint8_t*)&chunk_value, chunk_size); + break; + default: + fail("bad checksum chunk kind %lu", chunk_kind); + } + } + int16_t csum_value = csum_inet_digest(&csum); + debug("writing inet checksum %hx to %llx\n", csum_value, csum_addr); + NONFAILING(copyin(csum_addr, csum_value, 2, 0, 0)); + break; + } + default: + fail("bad checksum kind %lu", csum_kind); + } + break; + } default: fail("bad argument type %lu", typ); } diff --git a/executor/test.go b/executor/test.go index e9709b8d0..47a4e0388 100644 --- a/executor/test.go +++ b/executor/test.go @@ -6,6 +6,8 @@ package executor // int test_copyin(); +// int test_csum_inet(); +// int test_csum_inet_acc(); // int test_kvm(); import "C" @@ -13,6 +15,14 @@ func testCopyin() int { return int(C.test_copyin()) } +func testCsumInet() int { + return int(C.test_csum_inet()) +} + +func testCsumInetAcc() int { + return int(C.test_csum_inet_acc()) +} + func testKVM() int { return int(C.test_kvm()) } diff --git a/executor/test_kvm.cc b/executor/test_executor.cc index 79b951f62..5cf404675 100644 --- a/executor/test_kvm.cc +++ b/executor/test_executor.cc @@ -24,6 +24,174 @@ extern "C" int test_copyin() return 0; } +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +struct csum_inet_test { + const char* data; + size_t length; + uint16_t csum; +}; + +extern "C" int test_csum_inet() +{ + struct csum_inet_test tests[] = { + {// 0 + "", + 0, + 0xffff}, + { + // 1 + "\x00", + 1, + 0xffff, + }, + { + // 2 + "\x00\x00", + 2, + 0xffff, + }, + { + // 3 + "\x00\x00\xff\xff", + 4, + 0x0000, + }, + { + // 4 + "\xfc", + 1, + 0xff03, + }, + { + // 5 + "\xfc\x12", + 2, + 0xed03, + }, + { + // 6 + "\xfc\x12\x3e", + 3, + 0xecc5, + }, + { + // 7 + "\xfc\x12\x3e\x00\xc5\xec", + 6, + 0x0000, + }, + { + // 8 + "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", + 17, + 0x43e1, + }, + { + // 9 + "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd\x00", + 18, + 0x43e1, + }, + { + // 10 + "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", + 19, + 0x43e1, + }, + { + // 11 + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\xab\xcd", + 15, + 0x5032, + }, + { + // 12 + "\x00\x00\x12\x34\x56\x78", + 6, + 0x5397, + }, + { + // 13 + "\x00\x00\x12\x34\x00\x00\x56\x78\x00\x06\x00\x04\xab\xcd", + 14, + 0x7beb, + }, + { + // 14 + "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x06\x00\x00\xab\xcd", + 44, + 0x2854, + }, + { + // 15 + "\x00\x00\x12\x34\x00\x00\x56\x78\x00\x11\x00\x04\xab\xcd", + 14, + 0x70eb, + }, + { + // 16 + "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x11\x00\x00\xab\xcd", + 44, + 0x1d54, + }, + { + // 17 + "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x3a\x00\x00\xab\xcd", + 44, + 0xf453, + }}; + + int i; + for (i = 0; i < ARRAY_SIZE(tests); i++) { + struct csum_inet csum; + csum_inet_init(&csum); + csum_inet_update(&csum, (const uint8_t*)tests[i].data, tests[i].length); + if (csum_inet_digest(&csum) != tests[i].csum) { + fprintf(stderr, "bad checksum in test #%d, want: %hx, got: %hx\n", i, tests[i].csum, csum_inet_digest(&csum)); + return 1; + } + } + + return 0; +} + +int randInt(int start, int end) +{ + return rand() % (end + 1 - start) + start; +} + +extern "C" int test_csum_inet_acc() +{ + uint8_t buffer[128]; + + int test; + for (test = 0; test < 256; test++) { + int size = randInt(1, 128); + int step = randInt(1, 8) * 2; + + int i; + for (i = 0; i < size; i++) + buffer[i] = randInt(0, 255); + + struct csum_inet csum_acc; + csum_inet_init(&csum_acc); + + for (i = 0; i < size / step; i++) + csum_inet_update(&csum_acc, &buffer[i * step], step); + if (size % step != 0) + csum_inet_update(&csum_acc, &buffer[size - size % step], size % step); + + struct csum_inet csum; + csum_inet_init(&csum); + csum_inet_update(&csum, &buffer[0], size); + + if (csum_inet_digest(&csum_acc) != csum_inet_digest(&csum)) + return 1; + return 0; + } +} + static unsigned host_kernel_version(); static void dump_cpu_state(int cpufd, char* vm_mem); diff --git a/executor/test_test.go b/executor/test_test.go index f197dbf51..368454ac3 100644 --- a/executor/test_test.go +++ b/executor/test_test.go @@ -5,8 +5,8 @@ package executor import "testing" -func TestCopyin(t *testing.T) { - switch res := testCopyin(); { +func testWrapper(t *testing.T, f func() int) { + switch res := f(); { case res < 0: t.Skip() case res > 0: @@ -15,12 +15,18 @@ func TestCopyin(t *testing.T) { } } +func TestCopyin(t *testing.T) { + testWrapper(t, testCopyin) +} + +func TestCsumInet(t *testing.T) { + testWrapper(t, testCsumInet) +} + +func TestCsumInetAcc(t *testing.T) { + testWrapper(t, testCsumInetAcc) +} + func TestKVM(t *testing.T) { - switch res := testKVM(); { - case res < 0: - t.Skip() - case res > 0: - t.Fail() - default: - } + testWrapper(t, testKVM) } diff --git a/prog/analysis.go b/prog/analysis.go index 83419006a..5b786c753 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -155,6 +155,7 @@ func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) { rec = func(arg1 *Arg, offset uintptr) uintptr { switch arg1.Kind { case ArgGroup: + f(arg1, offset) var totalSize uintptr for _, arg2 := range arg1.Inner { size := rec(arg2, offset) @@ -167,6 +168,7 @@ func foreachSubargOffset(arg *Arg, f func(arg *Arg, offset uintptr)) { panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %+v", totalSize, arg1.Size(), arg1)) } case ArgUnion: + f(arg1, offset) size := rec(arg1.Option, offset) offset += size if size > arg1.Size() { diff --git a/prog/checksum.go b/prog/checksum.go index 8ee00bc24..5db10d26f 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -4,121 +4,34 @@ package prog import ( - "encoding/binary" "fmt" - "unsafe" "github.com/google/syzkaller/sys" ) -type IPChecksum struct { - acc uint32 -} - -func (csum *IPChecksum) Update(data []byte) { - length := len(data) - 1 - for i := 0; i < length; i += 2 { - csum.acc += uint32(data[i]) << 8 - csum.acc += uint32(data[i+1]) - } - if len(data)%2 == 1 { - csum.acc += uint32(data[length]) << 8 - } - for csum.acc > 0xffff { - csum.acc = (csum.acc >> 16) + (csum.acc & 0xffff) - } -} - -func (csum *IPChecksum) Digest() uint16 { - return ^uint16(csum.acc) -} - -func ipChecksum(data []byte) uint16 { - var csum IPChecksum - csum.Update(data) - return csum.Digest() -} +type CsumKind int -func bitmaskLen(bfLen uint64) uint64 { - return (1 << bfLen) - 1 -} +const ( + CsumInet CsumKind = iota +) -func bitmaskLenOff(bfOff, bfLen uint64) uint64 { - return bitmaskLen(bfLen) << bfOff -} +type CsumChunkKind int -func storeByBitmask8(addr *uint8, value uint8, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint8(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint8(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} - -func storeByBitmask16(addr *uint16, value uint16, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint16(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint16(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} +const ( + CsumChunkArg CsumChunkKind = iota + CsumChunkConst +) -func storeByBitmask32(addr *uint32, value uint32, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint32(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint32(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } +type CsumInfo struct { + Kind CsumKind + Chunks []CsumChunk } -func storeByBitmask64(addr *uint64, value uint64, bfOff uint64, bfLen uint64) { - if bfOff == 0 && bfLen == 0 { - *addr = value - } else { - newValue := *addr - newValue &= ^uint64(bitmaskLenOff(bfOff, bfLen)) - newValue |= (value & uint64(bitmaskLen(bfLen))) << bfOff - *addr = newValue - } -} - -func encodeArg(arg *Arg, pid int) []byte { - bytes := make([]byte, arg.Size()) - foreachSubargOffset(arg, func(arg *Arg, offset uintptr) { - switch arg.Kind { - case ArgConst: - addr := unsafe.Pointer(&bytes[offset]) - val := arg.Value(pid) - bfOff := uint64(arg.Type.BitfieldOffset()) - bfLen := uint64(arg.Type.BitfieldLength()) - switch arg.Size() { - case 1: - storeByBitmask8((*uint8)(addr), uint8(val), bfOff, bfLen) - case 2: - storeByBitmask16((*uint16)(addr), uint16(val), bfOff, bfLen) - case 4: - storeByBitmask32((*uint32)(addr), uint32(val), bfOff, bfLen) - case 8: - storeByBitmask64((*uint64)(addr), uint64(val), bfOff, bfLen) - default: - panic(fmt.Sprintf("bad arg size %v, arg: %+v\n", arg.Size(), arg)) - } - case ArgData: - copy(bytes[offset:], arg.Data) - default: - panic(fmt.Sprintf("bad arg kind %v, arg: %+v, type: %+v", arg.Kind, arg, arg.Type)) - } - }) - return bytes +type CsumChunk struct { + Kind CsumChunkKind + Arg *Arg // for CsumChunkArg + Value uintptr // for CsumChunkConst + Size uintptr // for CsumChunkConst } func getFieldByName(arg *Arg, name string) *Arg { @@ -130,14 +43,6 @@ func getFieldByName(arg *Arg, name string) *Arg { panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type.Name())) } -func calcChecksumInet(packet, csumField *Arg, pid int) *Arg { - bytes := encodeArg(packet, pid) - csum := ipChecksum(bytes) - newCsumField := *csumField - newCsumField.Val = uintptr(csum) - return &newCsumField -} - func extractHeaderParamsIPv4(arg *Arg) (*Arg, *Arg) { srcAddr := getFieldByName(arg, "src_ip") if srcAddr.Size() != 4 { @@ -162,35 +67,24 @@ func extractHeaderParamsIPv6(arg *Arg) (*Arg, *Arg) { return srcAddr, dstAddr } -func composePseudoHeaderIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { - header := []byte{} - header = append(header, encodeArg(srcAddr, pid)...) - header = append(header, encodeArg(dstAddr, pid)...) - header = append(header, []byte{0, protocol}...) - length := []byte{0, 0} - binary.BigEndian.PutUint16(length, uint16(tcpPacket.Size())) - header = append(header, length...) - return header -} - -func composePseudoHeaderIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) []byte { - header := []byte{} - header = append(header, encodeArg(srcAddr, pid)...) - header = append(header, encodeArg(dstAddr, pid)...) - length := []byte{0, 0, 0, 0} - binary.BigEndian.PutUint32(length, uint32(tcpPacket.Size())) - header = append(header, length...) - header = append(header, []byte{0, 0, 0, protocol}...) - return header +func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo { + info := CsumInfo{Kind: CsumInet} + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap16(uint16(protocol))), 2}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap16(uint16(tcpPacket.Size()))), 2}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) + return info } -func calcChecksumPseudo(packet, csumField *Arg, pseudoHeader []byte, pid int) *Arg { - var csum IPChecksum - csum.Update(pseudoHeader) - csum.Update(encodeArg(packet, pid)) - newCsumField := *csumField - newCsumField.Val = uintptr(csum.Digest()) - return &newCsumField +func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr *Arg, protocol uint8, pid int) CsumInfo { + info := CsumInfo{Kind: CsumInet} + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap32(uint32(tcpPacket.Size()))), 4}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uintptr(swap32(uint32(protocol))), 4}) + info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) + return info } func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg { @@ -209,7 +103,7 @@ func findCsummedArg(arg *Arg, typ *sys.CsumType, parentsMap map[*Arg]*Arg) *Arg panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) } -func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { +func calcChecksumsCall(c *Call, pid int) map[*Arg]CsumInfo { var inetCsumFields []*Arg var pseudoCsumFields []*Arg @@ -242,14 +136,16 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { } }) - csumMap := make(map[*Arg]*Arg) + csumMap := make(map[*Arg]CsumInfo) - // Calculate inet checksums. + // Calculate generic inet checksums. for _, arg := range inetCsumFields { typ, _ := arg.Type.(*sys.CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) - newCsumField := calcChecksumInet(csummedArg, arg, pid) - csumMap[arg] = newCsumField + chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0} + info := CsumInfo{Kind: CsumInet, Chunks: make([]CsumChunk, 0)} + info.Chunks = append(info.Chunks, chunk) + csumMap[arg] = info } // No need to continue if there are no pseudo csum fields. @@ -282,14 +178,13 @@ func calcChecksumsCall(c *Call, pid int) map[*Arg]*Arg { typ, _ := arg.Type.(*sys.CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) protocol := uint8(typ.Protocol) - var pseudoHeader []byte + var info CsumInfo if ipv4HeaderParsed { - pseudoHeader = composePseudoHeaderIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) } else { - pseudoHeader = composePseudoHeaderIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) + info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol, pid) } - newCsumField := calcChecksumPseudo(csummedArg, arg, pseudoHeader, pid) - csumMap[arg] = newCsumField + csumMap[arg] = info } return csumMap diff --git a/prog/checksum_test.go b/prog/checksum_test.go index dfb6e04a5..205450255 100644 --- a/prog/checksum_test.go +++ b/prog/checksum_test.go @@ -4,204 +4,9 @@ package prog import ( - "bytes" "testing" - - "github.com/google/syzkaller/sys" ) -func TestChecksumIP(t *testing.T) { - tests := []struct { - data string - csum uint16 - }{ - { - "", - 0xffff, - }, - { - "\x00", - 0xffff, - }, - { - "\x00\x00", - 0xffff, - }, - { - "\x00\x00\xff\xff", - 0x0000, - }, - { - "\xfc", - 0x03ff, - }, - { - "\xfc\x12", - 0x03ed, - }, - { - "\xfc\x12\x3e", - 0xc5ec, - }, - { - "\xfc\x12\x3e\x00\xc5\xec", - 0x0000, - }, - { - "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", - 0xe143, - }, - { - "\x00\x00\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", - 0xe143, - }, - { - "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\xab\xcd", - 0x3250, - }, - { - "\x00\x00\x12\x34\x56\x78", - 0x9753, - }, - { - "\x00\x00\x12\x34\x00\x00\x56\x78\x00\x06\x00\x04\xab\xcd", - 0xeb7b, - }, - { - "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x06\x00\x00\xab\xcd", - 0x5428, - }, - { - "\x00\x00\x12\x34\x00\x00\x56\x78\x00\x11\x00\x04\xab\xcd", - 0xeb70, - }, - { - "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x11\x00\x00\xab\xcd", - 0x541d, - }, - { - "\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\xff\xee\xdd\xcc\xbb\xaa\x99\x88\x77\x66\x55\x44\x33\x22\x11\x00\x00\x00\x00\x04\x00\x00\x00\x3a\x00\x00\xab\xcd", - 0x53f4, - }, - } - - for _, test := range tests { - csum := ipChecksum([]byte(test.data)) - if csum != test.csum { - t.Fatalf("incorrect ip checksum, got: %x, want: %x, data: %+v", csum, test.csum, []byte(test.data)) - } - } -} - -func TestChecksumIPAcc(t *testing.T) { - rs, iters := initTest(t) - r := newRand(rs) - - for i := 0; i < iters; i++ { - bytes := make([]byte, r.Intn(256)) - for i := 0; i < len(bytes); i++ { - bytes[i] = byte(r.Intn(256)) - } - step := int(r.randRange(1, 8)) * 2 - var csumAcc IPChecksum - for i := 0; i < len(bytes)/step; i++ { - csumAcc.Update(bytes[i*step : (i+1)*step]) - } - if len(bytes)%step != 0 { - csumAcc.Update(bytes[len(bytes)-(len(bytes)%step) : len(bytes)]) - } - csum := ipChecksum(bytes) - if csum != csumAcc.Digest() { - t.Fatalf("inconsistent ip checksum: %x vs %x, step: %v, data: %+v", csum, csumAcc.Digest(), step, bytes) - } - } -} - -func TestChecksumEncode(t *testing.T) { - tests := []struct { - prog string - encoded string - }{ - { - "syz_test$csum_encode(&(0x7f0000000000)={0x42, 0x43, [0x44, 0x45], 0xa, 0xb, \"aabbccdd\"})", - "\x42\x00\x00\x43\x44\x00\x00\x00\x45\x00\x00\x00\xba\xaa\xbb\xcc\xdd", - }, - } - for i, test := range tests { - p, err := Deserialize([]byte(test.prog)) - if err != nil { - t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) - } - encoded := encodeArg(p.Calls[0].Args[0].Res, 0) - if !bytes.Equal(encoded, []byte(test.encoded)) { - t.Fatalf("incorrect encoding for prog #%v, got: %+v, want: %+v", i, encoded, []byte(test.encoded)) - } - } -} - -func TestChecksumCalc(t *testing.T) { - tests := []struct { - prog string - kind sys.CsumKind - csum uint16 - }{ - { - "syz_test$csum_ipv4(&(0x7f0000000000)={0x0, 0x1234, 0x5678})", - sys.CsumInet, - 0x9753, - }, - { - "syz_test$csum_ipv4_tcp(&(0x7f0000000000)={{0x0, 0x1234, 0x5678}, {{0x0}, \"abcd\"}})", - sys.CsumPseudo, - 0xeb7b, - }, - { - "syz_test$csum_ipv6_tcp(&(0x7f0000000000)={{\"00112233445566778899aabbccddeeff\", \"ffeeddccbbaa99887766554433221100\"}, {{0x0}, \"abcd\"}})", - sys.CsumPseudo, - 0x5428, - }, - { - "syz_test$csum_ipv4_udp(&(0x7f0000000000)={{0x0, 0x1234, 0x5678}, {0x0, \"abcd\"}})", - sys.CsumPseudo, - 0xeb70, - }, - { - "syz_test$csum_ipv6_udp(&(0x7f0000000000)={{\"00112233445566778899aabbccddeeff\", \"ffeeddccbbaa99887766554433221100\"}, {0x0, \"abcd\"}})", - sys.CsumPseudo, - 0x541d, - }, - { - "syz_test$csum_ipv6_icmp(&(0x7f0000000000)={{\"00112233445566778899aabbccddeeff\", \"ffeeddccbbaa99887766554433221100\"}, {0x0, \"abcd\"}})", - sys.CsumPseudo, - 0x53f4, - }, - } - for i, test := range tests { - p, err := Deserialize([]byte(test.prog)) - if err != nil { - t.Fatalf("failed to deserialize prog %v: %v", test.prog, err) - } - csumMap := calcChecksumsCall(p.Calls[0], i%32) - found := false - for oldField, newField := range csumMap { - if typ, ok := newField.Type.(*sys.CsumType); ok { - if typ.Kind == test.kind { - found = true - csum := newField.Value(i % 32) - if csum != uintptr(test.csum) { - t.Fatalf("failed to calc checksum, got %x, want %x, kind %v, prog '%v'", csum, test.csum, test.kind, test.prog) - } - } - } else { - t.Fatalf("non csum key %+v in csum map %+v", oldField, csumMap) - } - } - if !found { - t.Fatalf("csum field not found, kind %v, prog '%v'", test.kind, test.prog) - } - } -} - func TestChecksumCalcRandom(t *testing.T) { rs, iters := initTest(t) for i := 0; i < iters; i++ { diff --git a/prog/encodingexec.go b/prog/encodingexec.go index 9a9cc4b48..6695836ca 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -8,6 +8,7 @@ package prog import ( "fmt" + "sort" "github.com/google/syzkaller/sys" ) @@ -22,6 +23,16 @@ const ( ExecArgConst = uintptr(iota) ExecArgResult ExecArgData + ExecArgCsum +) + +const ( + ExecArgCsumInet = uintptr(iota) +) + +const ( + ExecArgCsumChunkData = uintptr(iota) + ExecArgCsumChunkConst ) const ( @@ -32,6 +43,25 @@ const ( dataOffset = 512 << 20 ) +type Args []*Arg + +func (s Args) Len() int { + return len(s) +} + +func (s Args) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +type ByPhysicalAddr struct { + Args + Context *execContext +} + +func (s ByPhysicalAddr) Less(i, j int) bool { + return s.Context.args[s.Args[i]].Addr < s.Context.args[s.Args[j]].Addr +} + // SerializeForExec serializes program p for execution by process pid into the provided buffer. // If the provided buffer is too small for the program an error is returned. func (p *Prog) SerializeForExec(buffer []byte, pid int) error { @@ -49,13 +79,30 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error { for _, c := range p.Calls { // Calculate checksums. csumMap := calcChecksumsCall(c, pid) + var csumUses map[*Arg]bool + if csumMap != nil { + csumUses = make(map[*Arg]bool) + for arg, info := range csumMap { + csumUses[arg] = true + if info.Kind == CsumInet { + for _, chunk := range info.Chunks { + if chunk.Kind == CsumChunkArg { + csumUses[chunk.Arg] = true + } + } + } + } + } // Calculate arg offsets within structs. // Generate copyin instructions that fill in data into pointer arguments. foreachArg(c, func(arg, _ *Arg, _ *[]*Arg) { if arg.Kind == ArgPointer && arg.Res != nil { foreachSubargOffset(arg.Res, func(arg1 *Arg, offset uintptr) { - if len(arg1.Uses) != 0 { - w.args[arg1] = argInfo{Offset: offset} + if len(arg1.Uses) != 0 || csumUses[arg1] { + w.args[arg1] = argInfo{Addr: physicalAddr(arg) + offset} + } + if arg1.Kind == ArgGroup || arg1.Kind == ArgUnion { + return } if !sys.IsPad(arg1.Type) && !(arg1.Kind == ArgData && len(arg1.Data) == 0) && @@ -68,6 +115,47 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error { }) } }) + // Generate checksum calculation instructions starting from the last one, + // since checksum values can depend on values of the latter ones + if csumMap != nil { + var csumArgs []*Arg + for arg, _ := range csumMap { + csumArgs = append(csumArgs, arg) + } + sort.Sort(ByPhysicalAddr{Args: csumArgs, Context: w}) + for i := len(csumArgs) - 1; i >= 0; i-- { + arg := csumArgs[i] + if _, ok := arg.Type.(*sys.CsumType); !ok { + panic("csum arg is not csum type") + } + w.write(ExecInstrCopyin) + w.write(w.args[arg].Addr) + w.write(ExecArgCsum) + w.write(arg.Size()) + switch csumMap[arg].Kind { + case CsumInet: + w.write(ExecArgCsumInet) + w.write(uintptr(len(csumMap[arg].Chunks))) + for _, chunk := range csumMap[arg].Chunks { + switch chunk.Kind { + case CsumChunkArg: + w.write(ExecArgCsumChunkData) + w.write(w.args[chunk.Arg].Addr) + w.write(chunk.Arg.Size()) + case CsumChunkConst: + w.write(ExecArgCsumChunkConst) + w.write(chunk.Value) + w.write(chunk.Size) + default: + panic(fmt.Sprintf("csum chunk has unknown kind %v", chunk.Kind)) + } + } + default: + panic(fmt.Sprintf("csum arg has unknown kind %v", csumMap[arg].Kind)) + } + instrSeq++ + } + } // Generate the call itself. w.write(uintptr(c.Meta.ID)) w.write(uintptr(len(c.Args))) @@ -96,7 +184,7 @@ func (p *Prog) SerializeForExec(buffer []byte, pid int) error { instrSeq++ w.args[arg] = info w.write(ExecInstrCopyout) - w.write(physicalAddr(base) + info.Offset) + w.write(info.Addr) w.write(arg.Size()) default: panic("bad arg kind in copyout") @@ -130,8 +218,8 @@ type execContext struct { } type argInfo struct { - Offset uintptr // from base pointer - Idx uintptr // instruction index + Addr uintptr // physical addr + Idx uintptr // instruction index } func (w *execContext) write(v uintptr) { @@ -150,14 +238,9 @@ func (w *execContext) write(v uintptr) { w.buf = w.buf[8:] } -func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]*Arg) { +func (w *execContext) writeArg(arg *Arg, pid int, csumMap map[*Arg]CsumInfo) { switch arg.Kind { case ArgConst: - if _, ok := arg.Type.(*sys.CsumType); ok { - if arg, ok = csumMap[arg]; !ok { - panic("csum arg is not in csum map") - } - } w.write(ExecArgConst) w.write(arg.Size()) w.write(arg.Value(pid)) diff --git a/prog/prog.go b/prog/prog.go index fbd8507c6..db696ec65 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -96,7 +96,8 @@ func (a *Arg) Value(pid int) uintptr { case *sys.LenType: return encodeValue(a.Val, typ.Size(), typ.BigEndian) case *sys.CsumType: - return encodeValue(a.Val, typ.Size(), typ.BigEndian) + // Checksums are computed dynamically in executor. + return 0 case *sys.ProcType: val := uintptr(typ.ValuesStart) + uintptr(typ.ValuesPerProc)*uintptr(pid) + a.Val return encodeValue(val, typ.Size(), typ.BigEndian) diff --git a/prog/rand.go b/prog/rand.go index 3eebe4b8d..0aaf9ab06 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -765,8 +765,10 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) arg, calls1 := r.addr(s, a, inner.Size(), inner) calls = append(calls, calls1...) return arg, calls - case *sys.LenType, *sys.CsumType: - // Return placeholder value of 0 while generating len and csum args. + case *sys.LenType: + // Return placeholder value of 0 while generating len arg. + return constArg(a, 0), nil + case *sys.CsumType: return constArg(a, 0), nil default: panic("unknown argument type") |
