diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2023-12-28 21:31:00 +0100 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2024-02-19 11:54:01 +0000 |
| commit | 31179bc75602cbe8f0421b44f19ff1b960039644 (patch) | |
| tree | 9804abe8e2ca0218da0e2c71b61a8b411f26e800 | |
| parent | ed571339c6ff5ed764283737a0aa68451085e84d (diff) | |
prog: support conditional fields
pkg/compiler restructures conditional fields in structures into unions,
so we only have to implement the support for unions.
Semantics is as follows:
If a union has conditions, syzkaller picks the first field whose
condition matches. Since we require the last union field to have no
conditions, we can always construct an object.
Changes from this commit aim at ensuring that the selected union fields
always follow the rule above.
| -rw-r--r-- | prog/expr.go | 209 | ||||
| -rw-r--r-- | prog/expr_test.go | 226 | ||||
| -rw-r--r-- | prog/hints.go | 7 | ||||
| -rw-r--r-- | prog/minimization.go | 10 | ||||
| -rw-r--r-- | prog/mutation.go | 4 | ||||
| -rw-r--r-- | prog/rand.go | 9 | ||||
| -rw-r--r-- | prog/size.go | 64 | ||||
| -rw-r--r-- | prog/types.go | 28 | ||||
| -rw-r--r-- | prog/validation.go | 3 | ||||
| -rw-r--r-- | sys/test/exec.txt | 3 | ||||
| -rw-r--r-- | sys/test/expressions.txt | 58 | ||||
| -rw-r--r-- | sys/test/expressions.txt.const | 3 | ||||
| -rw-r--r-- | sys/test/test/expressions | 10 | ||||
| -rw-r--r-- | sys/test/test/expressions_be | 6 |
14 files changed, 611 insertions, 29 deletions
diff --git a/prog/expr.go b/prog/expr.go new file mode 100644 index 000000000..dc2279cac --- /dev/null +++ b/prog/expr.go @@ -0,0 +1,209 @@ +// Copyright 2023 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package prog + +import ( + "errors" + "fmt" +) + +func (bo *BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) { + left, ok := bo.Left.Evaluate(finder) + if !ok { + return 0, false + } + right, ok := bo.Right.Evaluate(finder) + if !ok { + return 0, false + } + switch bo.Operator { + case OperatorCompareEq: + if left == right { + return 1, true + } + return 0, true + case OperatorCompareNeq: + if left != right { + return 1, true + } + return 0, true + case OperatorBinaryAnd: + return left & right, true + } + panic(fmt.Sprintf("unknown operator %q", bo.Operator)) +} + +func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) { + if len(v.Path) == 0 { + return v.Value, true + } + found := finder(v.Path) + if found == SquashedArgFound { + // This is expectable. + return 0, false + } + if found == nil { + panic(fmt.Sprintf("no argument was found by %v", v.Path)) + } + constArg, ok := found.(*ConstArg) + if !ok { + panic("value expressions must only rely on int fields") + } + return constArg.Val, true +} + +func argFinderConstructor(t *Target, c *Call) func(*UnionArg) ArgFinder { + parentsMap := callParentsMap(c) + return func(unionArg *UnionArg) ArgFinder { + return func(path []string) Arg { + f := t.findArg(unionArg.Option, path, nil, nil, parentsMap, 0) + if f == nil { + return nil + } + if f.isAnyPtr { + return SquashedArgFound + } + return f.arg + } + } +} + +func callParentsMap(c *Call) map[Arg]Arg { + parentsMap := map[Arg]Arg{} + ForeachArg(c, func(arg Arg, _ *ArgCtx) { + saveToParentsMap(arg, parentsMap) + }) + return parentsMap +} + +func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) { + if r.inPatchConditional { + return nil, false + } + r.inPatchConditional = true + defer func() { r.inPatchConditional = false }() + + var extraCalls []*Call + var anyPatched bool + for { + replace := map[Arg]Arg{} + makeArgFinder := argFinderConstructor(r.target, c) + forEachStaleUnion(r.target, c, makeArgFinder, + func(unionArg *UnionArg, unionType *UnionType, needIdx int) { + newType, newDir := unionType.Fields[needIdx].Type, + unionType.Fields[needIdx].Dir(unionArg.Dir()) + newTypeArg, newCalls := r.generateArg(s, newType, newDir) + replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, needIdx) + extraCalls = append(extraCalls, newCalls...) + anyPatched = true + }) + for old, new := range replace { + replaceArg(old, new) + } + // The newly inserted argument might contain more arguments we need + // to patch. + // Repeat until we have to change nothing. + if len(replace) == 0 { + break + } + } + return extraCalls, anyPatched +} + +func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) ArgFinder, + cb func(*UnionArg, *UnionType, int)) { + ForeachArg(c, func(arg Arg, argCtx *ArgCtx) { + if target.isAnyPtr(arg.Type()) { + argCtx.Stop = true + return + } + unionArg, ok := arg.(*UnionArg) + if !ok { + return + } + unionType, ok := arg.Type().(*UnionType) + if !ok || !unionType.isConditional() { + return + } + argFinder := makeArgFinder(unionArg) + needIdx, ok := calculateUnionArg(unionArg, unionType, argFinder) + if !ok { + // Let it stay as is. + return + } + if unionArg.Index == needIdx { + // No changes are needed. + return + } + cb(unionArg, unionType, needIdx) + argCtx.Stop = true + }) +} + +func calculateUnionArg(arg *UnionArg, typ *UnionType, finder ArgFinder) (int, bool) { + defaultIdx := typ.defaultField() + for i, field := range typ.Fields { + if field.Condition == nil { + continue + } + val, ok := field.Condition.Evaluate(finder) + if !ok { + // We could not calculate the expression. + // Let the union stay as it was. + return defaultIdx, false + } + if val != 0 { + return i, true + } + } + return defaultIdx, true +} + +func (p *Prog) checkConditions() error { + for _, c := range p.Calls { + err := c.checkConditions(p.Target) + if err != nil { + return err + } + } + return nil +} + +var ErrViolatedConditions = errors.New("conditional fields rules violation") + +func (c *Call) checkConditions(target *Target) error { + var ret error + makeArgFinder := argFinderConstructor(target, c) + forEachStaleUnion(target, c, makeArgFinder, + func(a *UnionArg, t *UnionType, need int) { + ret = fmt.Errorf("%w union %s field is %s, but %s satisfies conditions", + ErrViolatedConditions, t.Name(), t.Fields[a.Index].Name, t.Fields[need].Name) + }) + return ret +} + +func (c *Call) setDefaultConditions(target *Target) bool { + var anyReplaced bool + // Replace stale conditions with the default values of their correct types. + for { + replace := map[Arg]Arg{} + makeArgFinder := argFinderConstructor(target, c) + forEachStaleUnion(target, c, makeArgFinder, + func(unionArg *UnionArg, unionType *UnionType, needIdx int) { + field := unionType.Fields[needIdx] + replace[unionArg] = MakeUnionArg(unionType, + unionArg.Dir(), + field.DefaultArg(field.Dir(unionArg.Dir())), + needIdx) + }) + for old, new := range replace { + anyReplaced = true + replaceArg(old, new) + } + if len(replace) == 0 { + break + } + } + return anyReplaced +} diff --git a/prog/expr_test.go b/prog/expr_test.go new file mode 100644 index 000000000..69bd790db --- /dev/null +++ b/prog/expr_test.go @@ -0,0 +1,226 @@ +// Copyright 2023 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package prog + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateConditionalFields(t *testing.T) { + // Ensure that we reach different combinations of conditional fields. + target, rs, _ := initRandomTargetTest(t, "test", "64") + ct := target.DefaultChoiceTable() + r := newRand(target, rs) + + combinations := [][]bool{ + {false, false}, + {false, false}, + } + b2i := func(b bool) int { + if b { + return 1 + } + return 0 + } + for i := 0; i < 150; i++ { + p := genConditionalFieldProg(target, ct, r) + f1, f2 := parseConditionalStructCall(t, p.Calls[len(p.Calls)-1]) + combinations[b2i(f1)][b2i(f2)] = true + } + for _, first := range []int{0, 1} { + for _, second := range []int{0, 1} { + if !combinations[first][second] { + t.Fatalf("Did not generate a combination f1=%v f2=%v", first, second) + } + } + } +} + +func TestMutateConditionalFields(t *testing.T) { + target, rs, _ := initRandomTargetTest(t, "test", "64") + ct := target.DefaultChoiceTable() + r := newRand(target, rs) + iters := 500 + if testing.Short() { + iters /= 10 + } + nonAny := 0 + for i := 0; i < iters; i++ { + prog := genConditionalFieldProg(target, ct, r) + for j := 0; j < 5; j++ { + prog.Mutate(rs, 10, ct, nil, nil) + hasAny := bytes.Contains(prog.Serialize(), []byte("ANY=")) + if hasAny { + // No sense to verify these. + break + } + nonAny++ + validateConditionalProg(t, prog) + } + } + assert.Greater(t, nonAny, 10) // Just in case. +} + +func TestEvaluateConditionalFields(t *testing.T) { + target := InitTargetTest(t, "test", "64") + tests := []struct { + good []string + bad []string + }{ + { + good: []string{ + `test$conditional_struct(&AUTO={0x0, @void, @void})`, + `test$conditional_struct(&AUTO={0x4, @void, @value=0x123})`, + `test$conditional_struct(&AUTO={0x6, @value={AUTO}, @value=0x123})`, + }, + bad: []string{ + `test$conditional_struct(&AUTO={0x0, @void, @value=0x123})`, + `test$conditional_struct(&AUTO={0x0, @value={AUTO}, @value=0x123})`, + }, + }, + { + good: []string{ + `test$parent_conditions(&AUTO={0x0, @without_flag1=0x123, {0x0, @void}})`, + `test$parent_conditions(&AUTO={0x2, @with_flag1=0x123, {0x0, @void}})`, + `test$parent_conditions(&AUTO={0x4, @without_flag1=0x123, {0x0, @value=0x0}})`, + `test$parent_conditions(&AUTO={0x6, @with_flag1=0x123, {0x0, @value=0x0}})`, + }, + bad: []string{ + `test$parent_conditions(&AUTO={0x0, @with_flag1=0x123, {0x0, @void}})`, + `test$parent_conditions(&AUTO={0x2, @without_flag1=0x123, {0x0, @void}})`, + `test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @void}})`, + `test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @value=0x0}})`, + }, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(tt *testing.T) { + for _, good := range test.good { + _, err := target.Deserialize([]byte(good), Strict) + assert.NoError(tt, err) + } + for _, bad := range test.bad { + _, err := target.Deserialize([]byte(bad), Strict) + assert.ErrorIs(tt, err, ErrViolatedConditions) + } + }) + } +} + +func TestConditionalMinimize(t *testing.T) { + tests := []struct { + input string + pred func(*Prog, int) bool + output string + }{ + { + input: `test$conditional_struct(&AUTO={0x6, @value={AUTO}, @value=0x123})`, + pred: func(p *Prog, _ int) bool { + return len(p.Calls) == 1 && p.Calls[0].Meta.Name == `test$conditional_struct` + }, + output: `test$conditional_struct(0x0)`, + }, + { + input: `test$conditional_struct(&(0x7f0000000040)={0x6, @value, @value=0x123})`, + pred: func(p *Prog, _ int) bool { + return bytes.Contains(p.Serialize(), []byte("0x123")) + }, + // We don't drop individual bits from integers, so there's no chance + // to turn 0x6 into 0x4. + output: `test$conditional_struct(&(0x7f0000000040)={0x6, @value, @value=0x123})`, + }, + { + input: `test$conditional_struct_minimize(&(0x7f0000000040)={0x1, @value=0xaa, 0x1, @value=0xbb})`, + pred: func(p *Prog, _ int) bool { + return bytes.Contains(p.Serialize(), []byte("0xaa")) + }, + output: `test$conditional_struct_minimize(&(0x7f0000000040)={0x1, @value=0xaa})`, + }, + { + input: `test$conditional_struct_minimize(&(0x7f0000000040)={0x1, @value=0xaa, 0x1, @value=0xbb})`, + pred: func(p *Prog, _ int) bool { + return bytes.Contains(p.Serialize(), []byte("0xbb")) + }, + output: `test$conditional_struct_minimize(&(0x7f0000000040)={0x0, @void, 0x1, @value=0xbb})`, + }, + { + input: `test$conditional_struct_minimize(&(0x7f0000000040)={0x1, @value=0xaa, 0x1, @value=0xbb})`, + pred: func(p *Prog, _ int) bool { + serialized := p.Serialize() + return bytes.Contains(serialized, []byte("0xaa")) && + bytes.Contains(serialized, []byte("0xbb")) + }, + output: `test$conditional_struct_minimize(&(0x7f0000000040)={0x1, @value=0xaa, 0x1, @value=0xbb})`, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(tt *testing.T) { + target, err := GetTarget("test", "64") + assert.NoError(tt, err) + p, err := target.Deserialize([]byte(test.input), Strict) + assert.NoError(tt, err) + p1, _ := Minimize(p, 0, false, test.pred) + res := p1.Serialize() + assert.Equal(tt, test.output, strings.TrimSpace(string(res))) + }) + } +} + +func genConditionalFieldProg(target *Target, ct *ChoiceTable, r *randGen) *Prog { + s := newState(target, ct, nil) + calls := r.generateParticularCall(s, target.SyscallMap["test$conditional_struct"]) + return &Prog{ + Target: target, + Calls: calls, + } +} + +const FLAG1 = 2 +const FLAG2 = 4 + +func validateConditionalProg(t *testing.T, p *Prog) { + for _, call := range p.Calls { + if call.Meta.Name == "test$conditional_struct" { + parseConditionalStructCall(t, call) + } + } +} + +// Validates a test$conditional_struct call. +func parseConditionalStructCall(t *testing.T, c *Call) (bool, bool) { + if c.Meta.Name != "test$conditional_struct" { + t.Fatalf("generated wrong call %v", c.Meta.Name) + } + if len(c.Args) != 1 { + t.Fatalf("generated wrong number of args %v", len(c.Args)) + } + va, ok := c.Args[0].(*PointerArg) + if !ok { + t.Fatalf("expected PointerArg: %v", c.Args[0]) + } + if va.Res == nil { + // Cannot validate. + return false, false + } + ga, ok := va.Res.(*GroupArg) + if !ok { + t.Fatalf("expected GroupArg: %v", va.Res) + } + if len(ga.Inner) != 3 { + t.Fatalf("wrong number of struct args %v", len(ga.Inner)) + } + mask := ga.Inner[0].(*ConstArg).Val + f1 := ga.Inner[1].(*UnionArg).Index == 0 + f2 := ga.Inner[2].(*UnionArg).Index == 0 + assert.Equal(t, mask&FLAG1 != 0, f1, "flag1 must only be set if mask&FLAG1") + assert.Equal(t, mask&FLAG2 != 0, f2, "flag2 must only be set if mask&FLAG2") + return f1, f2 +} diff --git a/prog/hints.go b/prog/hints.go index 9fa80547c..98aeccce2 100644 --- a/prog/hints.go +++ b/prog/hints.go @@ -75,6 +75,13 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog)) if p.Target.sanitize(c, false) != nil { return } + if p.checkConditions() != nil { + // Patching unions that no longer satisfy conditions would + // require much deeped changes to prog arguments than + // generateHints() expects. + // Let's just ignore such mutations. + return + } p.debugValidate() exec(p) } diff --git a/prog/minimization.go b/prog/minimization.go index 0beaef17b..a9864aca2 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -278,13 +278,21 @@ func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool { } v0 := a.Val a.Val = def.Val + + // By mutating an integer, we risk violating conditional fields. + // If the fields are patched, the minimization process must be restarted. + patched := ctx.call.setDefaultConditions(ctx.p.Target) if ctx.pred(ctx.p, ctx.callIndex0) { *ctx.p0 = ctx.p ctx.triedPaths[path] = true return true } a.Val = v0 - return false + if patched { + // No sense to return here. + ctx.triedPaths[path] = true + } + return patched } func (typ *ResourceType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { diff --git a/prog/mutation.go b/prog/mutation.go index 47b682342..cf93ce4eb 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -198,6 +198,8 @@ func (ctx *mutator) mutateArg() bool { ok = false continue } + moreCalls, fieldsPatched := r.patchConditionalFields(c, s) + calls = append(calls, moreCalls...) p.insertBefore(c, calls) idx += len(calls) for len(p.Calls) > ctx.ncalls { @@ -208,7 +210,7 @@ func (ctx *mutator) mutateArg() bool { panic(fmt.Sprintf("wrong call index: idx=%v calls=%v p.Calls=%v ncalls=%v", idx, len(calls), len(p.Calls), ctx.ncalls)) } - if updateSizes { + if updateSizes || fieldsPatched { p.Target.assignSizesCall(c) } } diff --git a/prog/rand.go b/prog/rand.go index c7f9053aa..742dbaa7c 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -27,6 +27,7 @@ type randGen struct { *rand.Rand target *Target inGenerateResource bool + inPatchConditional bool recDepth map[string]int } @@ -598,8 +599,9 @@ func (r *randGen) generateParticularCall(s *state, meta *Syscall) (calls []*Call } c := MakeCall(meta, nil) c.Args, calls = r.generateArgs(s, meta.Args, DirIn) + moreCalls, _ := r.patchConditionalFields(c, s) r.target.assignSizesCall(c) - return append(calls, c) + return append(append(calls, moreCalls...), c) } // GenerateAllSyzProg generates a program that contains all pseudo syz_ calls for testing. @@ -881,6 +883,11 @@ func (a *StructType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []* } func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + if a.isConditional() { + // Conditions may reference other fields that may not have already + // been generated. We'll fill them in later. + return a.DefaultArg(dir), nil + } index := r.Intn(len(a.Fields)) optType, optDir := a.Fields[index].Type, a.Fields[index].Dir(dir) opt, calls := r.generateArg(s, optType, optDir) diff --git a/prog/size.go b/prog/size.go index 70ec678b6..6463821e0 100644 --- a/prog/size.go +++ b/prog/size.go @@ -45,20 +45,34 @@ func (target *Target) assignArgSize(arg Arg, args []Arg, fields []Field, parents } } -func (target *Target) assignSizeStruct(dst *ConstArg, buf Arg, path []string, parentsMap map[Arg]Arg) { +func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []Arg, + fields []Field, parentsMap map[Arg]Arg, overlayField int) { + found := target.findArg(pos, path, args, fields, parentsMap, overlayField) + if found != nil && !found.isAnyPtr { + dst.Val = target.computeSize(found.arg, found.offset, dst.Type().(*LenType)) + } +} + +type foundArg struct { + arg Arg + offset uint64 + isAnyPtr bool +} + +func (target *Target) findFieldStruct(buf Arg, path []string, parentsMap map[Arg]Arg) *foundArg { switch arg := buf.(type) { case *GroupArg: typ := arg.Type().(*StructType) - target.assignSize(dst, buf, path, arg.Inner, typ.Fields, parentsMap, typ.OverlayField) + return target.findArg(buf, path, arg.Inner, typ.Fields, parentsMap, typ.OverlayField) case *UnionArg: - target.assignSize(dst, buf, path, nil, nil, parentsMap, 0) + return target.findArg(buf, path, nil, nil, parentsMap, 0) default: - panic(fmt.Sprintf("unexpected arg type %v", arg)) + panic(fmt.Sprintf("unexpected arg type %#v", arg)) } } -func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []Arg, - fields []Field, parentsMap map[Arg]Arg, overlayField int) { +func (target *Target) findArg(pos Arg, path []string, args []Arg, fields []Field, + parentsMap map[Arg]Arg, overlayField int) *foundArg { elem := path[0] path = path[1:] var offset uint64 @@ -66,6 +80,9 @@ func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []A if i == overlayField { offset = 0 } + if buf == nil { + continue + } if elem != fields[i].Name { offset += buf.Size() continue @@ -74,46 +91,43 @@ func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []A // If path points into squashed argument, we don't have the target argument. // In such case we simply leave size argument as is. It can't happen during generation, // only during mutation and mutation can set size to random values, so it should be fine. - return + return &foundArg{buf, offset, true} } buf = InnerArg(buf) if buf == nil { - dst.Val = 0 // target is an optional pointer - return + return &foundArg{nil, offset, false} } if len(path) != 0 { - target.assignSizeStruct(dst, buf, path, parentsMap) - return + return target.findFieldStruct(buf, path, parentsMap) } - dst.Val = target.computeSize(buf, offset, dst.Type().(*LenType)) - return + return &foundArg{buf, offset, false} } if elem == ParentRef { buf := parentsMap[pos] if len(path) != 0 { - target.assignSizeStruct(dst, buf, path, parentsMap) - return + return target.findFieldStruct(buf, path, parentsMap) } - dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) - return + return &foundArg{buf, noOffset, false} } for buf := parentsMap[pos]; buf != nil; buf = parentsMap[buf] { if elem != buf.Type().TemplateName() { continue } if len(path) != 0 { - target.assignSizeStruct(dst, buf, path, parentsMap) - return + return target.findFieldStruct(buf, path, parentsMap) } - dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) - return + return &foundArg{buf, noOffset, false} } var fieldNames []string for _, field := range fields { fieldNames = append(fieldNames, field.Name) } - panic(fmt.Sprintf("len field %q references non existent field %q, pos=%q, argsMap: %v, path: %v", - dst.Type().Name(), elem, pos.Type().Name(), fieldNames, path)) + posName := "nil" + if pos != nil { + posName = pos.Type().Name() + } + panic(fmt.Sprintf("path references non existent field %q, pos=%q, argsMap: %v, path: %v", + elem, posName, fieldNames, path)) } const noOffset = ^uint64(0) @@ -125,6 +139,10 @@ func (target *Target) computeSize(arg Arg, offset uint64, lenType *LenType) uint } return offset * 8 / lenType.BitSize } + if arg == nil { + // For e.g. optional pointers. + return 0 + } bitSize := lenType.BitSize if bitSize == 0 { bitSize = 8 diff --git a/prog/types.go b/prog/types.go index 73fa3a972..56e999f9e 100644 --- a/prog/types.go +++ b/prog/types.go @@ -82,10 +82,16 @@ func (f *Field) Dir(def Dir) Dir { return def } +type ArgFinder func(path []string) Arg + +// Special case reply of ArgFinder. +var SquashedArgFound = &DataArg{} + type Expression interface { fmt.GoStringer ForEachValue(func(*Value)) Clone() Expression + Evaluate(ArgFinder) (uint64, bool) } type BinaryOperator int @@ -727,13 +733,29 @@ func (t *UnionType) String() string { } func (t *UnionType) DefaultArg(dir Dir) Arg { - f := t.Fields[0] - return MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), 0) + idx := t.defaultField() + f := t.Fields[idx] + return MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), idx) +} + +func (t *UnionType) defaultField() int { + // If it's a conditional union, the last field will be the default value. + if t.isConditional() { + return len(t.Fields) - 1 + } + // Otherwise, just take the first. + return 0 +} + +func (t *UnionType) isConditional() bool { + // In pkg/compiler, we ensure that either none of the fields have conditions, + // or all except the last one. + return t.Fields[0].Condition != nil } func (t *UnionType) isDefaultArg(arg Arg) bool { a := arg.(*UnionArg) - return a.Index == 0 && isDefault(a.Option) + return a.Index == t.defaultField() && isDefault(a.Option) } type ConstValue struct { diff --git a/prog/validation.go b/prog/validation.go index 696ea265d..13b7c33db 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -65,6 +65,9 @@ func (ctx *validCtx) validateCall(c *Call) error { return err } } + if err := c.checkConditions(ctx.target); err != nil { + return err + } return ctx.validateRet(c) } diff --git a/sys/test/exec.txt b/sys/test/exec.txt index 913251a7d..67c0c70c0 100644 --- a/sys/test/exec.txt +++ b/sys/test/exec.txt @@ -51,6 +51,9 @@ compare_data [ overlay0 overlay0 overlay1 overlay1 overlay2 overlay2 + conditional conditional_struct + conditional2 condition_and_align + conditional3 condition_parent_align ] [varlen] flags_with_one_value = 0 diff --git a/sys/test/expressions.txt b/sys/test/expressions.txt new file mode 100644 index 000000000..c99c1364b --- /dev/null +++ b/sys/test/expressions.txt @@ -0,0 +1,58 @@ +# Copyright 2023 syzkaller project authors. All rights reserved. +# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +# Syscalls used for testing conditional expressions. + +define FIELD_FLAG1 2 +define FIELD_FLAG2 4 + +field1 { + f1 const[0xffffffff, int32] +} + +conditional_struct { + mask int32 + f1 field1 (if[value[mask] & FIELD_FLAG1]) + f2 int64 (if[value[mask] & FIELD_FLAG2]) +} [packed] + +test$conditional_struct(a ptr[in, conditional_struct]) + +parent_conditions { + mask int32 + u parent_conditions_nested_union + s parent_conditions_nested_struct +} [packed] + +parent_conditions_nested_union [ + with_flag1 int32 (if[value[parent:parent:mask] & FIELD_FLAG1]) + without_flag1 int64 +] + +parent_conditions_nested_struct { + f0 int64 + f1 int32 (if[value[parent_conditions:mask] & FIELD_FLAG2]) +} + +test$parent_conditions(a ptr[in, parent_conditions]) + +condition_and_align { + f0 int8 + f1 int32 (if[value[f0] == 1]) + f2 int8 +} [packed, align[4]] + +condition_parent_align { + f0 int8 + f1 condition_and_align + f2 int8 +} [packed, align[4]] + +conditional_struct_minimize { + havef0 int8 + f0 int8 (if[value[havef0] == 1]) + havef1 int8 + f1 int8 (if[value[havef1] == 1]) +} [packed] + +test$conditional_struct_minimize(a ptr[in, conditional_struct_minimize]) diff --git a/sys/test/expressions.txt.const b/sys/test/expressions.txt.const new file mode 100644 index 000000000..8b3a2dae5 --- /dev/null +++ b/sys/test/expressions.txt.const @@ -0,0 +1,3 @@ +arches = 32_fork_shmem, 32_shmem, 64, 64_fork +FIELD_FLAG1 = 2 +FIELD_FLAG2 = 4
\ No newline at end of file diff --git a/sys/test/test/expressions b/sys/test/test/expressions new file mode 100644 index 000000000..9ecaaf30c --- /dev/null +++ b/sys/test/test/expressions @@ -0,0 +1,10 @@ +# requires: littleendian + +syz_compare(&AUTO="00000000", 0x4, &AUTO=@conditional={0x0, @void, @void}, AUTO) +syz_compare(&AUTO="02000000ffffffff", 0x8, &AUTO=@conditional={0x2, @value={AUTO}, @void}, AUTO) +syz_compare(&AUTO="04000000aaaa000000000000", 0xc, &AUTO=@conditional={0x4, @void, @value=0xaaaa}, AUTO) +syz_compare(&AUTO="06000000ffffffffaaaa000000000000", 0x10, &AUTO=@conditional={0x6, @value={AUTO}, @value=0xaaaa}, AUTO) +syz_compare(&AUTO="00ff0000", 0x4, &AUTO=@conditional2={0x0, @void, 0xff}, AUTO) +syz_compare(&AUTO="0134120000ff0000", 0x8, &AUTO=@conditional2={0x1, @value=0x1234, 0xff}, AUTO) +syz_compare(&AUTO="1100220000330000", 0x8, &AUTO=@conditional3={0x11, {0x0, @void, 0x22}, 0x33}, AUTO) +syz_compare(&AUTO="1101ddccbbaa220000330000", 0xc, &AUTO=@conditional3={0x11, {0x1, @value=0xaabbccdd, 0x22}, 0x33}, AUTO)
\ No newline at end of file diff --git a/sys/test/test/expressions_be b/sys/test/test/expressions_be new file mode 100644 index 000000000..fde4342ab --- /dev/null +++ b/sys/test/test/expressions_be @@ -0,0 +1,6 @@ +# requires: bigendian + +syz_compare(&AUTO="00ff0000", 0x4, &AUTO=@conditional2={0x0, @void, 0xff}, AUTO) +syz_compare(&AUTO="01001234ff0000", 0x8, &AUTO=@conditional2={0x1, @value=0x1234, 0xff}, AUTO) +syz_compare(&AUTO="1100220000330000", 0x8, &AUTO=@conditional3={0x11, {0x0, @void, 0x22}, 0x33}, AUTO) +syz_compare(&AUTO="1101aabbccdd220000330000", 0xc, &AUTO=@conditional3={0x11, {0x1, @value=0xaabbccdd, 0x22}, 0x33}, AUTO)
\ No newline at end of file |
