diff options
| -rw-r--r-- | pkg/compiler/gen.go | 5 | ||||
| -rw-r--r-- | prog/expr.go | 80 | ||||
| -rw-r--r-- | prog/expr_test.go | 40 | ||||
| -rw-r--r-- | prog/prog.go | 4 | ||||
| -rw-r--r-- | prog/rand.go | 2 | ||||
| -rw-r--r-- | prog/types.go | 6 | ||||
| -rw-r--r-- | sys/test/expressions.txt | 12 |
7 files changed, 114 insertions, 35 deletions
diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go index d274877a2..f91e2502c 100644 --- a/pkg/compiler/gen.go +++ b/pkg/compiler/gen.go @@ -556,6 +556,11 @@ func (comp *compiler) wrapConditionalField(name string, field prog.Field) prog.F { Name: "void", Type: voidType, + Condition: &prog.BinaryExpression{ + Operator: prog.OperatorCompareEq, + Left: newCondition, + Right: &prog.Value{Value: 0x0, Path: nil}, + }, }, }, }, diff --git a/prog/expr.go b/prog/expr.go index dc2279cac..5ecbef8eb 100644 --- a/prog/expr.go +++ b/prog/expr.go @@ -90,11 +90,12 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan replace := map[Arg]Arg{} makeArgFinder := argFinderConstructor(r.target, c) forEachStaleUnion(r.target, c, makeArgFinder, - func(unionArg *UnionArg, unionType *UnionType, needIdx int) { - newType, newDir := unionType.Fields[needIdx].Type, - unionType.Fields[needIdx].Dir(unionArg.Dir()) + func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { + idx := okIndices[r.Intn(len(okIndices))] + newType, newDir := unionType.Fields[idx].Type, + unionType.Fields[idx].Dir(unionArg.Dir()) newTypeArg, newCalls := r.generateArg(s, newType, newDir) - replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, needIdx) + replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx) extraCalls = append(extraCalls, newCalls...) anyPatched = true }) @@ -112,7 +113,7 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan } func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) ArgFinder, - cb func(*UnionArg, *UnionType, int)) { + cb func(*UnionArg, *UnionType, []int)) { ForeachArg(c, func(arg Arg, argCtx *ArgCtx) { if target.isAnyPtr(arg.Type()) { argCtx.Stop = true @@ -127,37 +128,48 @@ func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) Ar return } argFinder := makeArgFinder(unionArg) - needIdx, ok := calculateUnionArg(unionArg, unionType, argFinder) - if !ok { + ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder) + if !calculated { // Let it stay as is. return } - if unionArg.Index == needIdx { - // No changes are needed. + if !unionArg.transient && ok { return } - cb(unionArg, unionType, needIdx) - argCtx.Stop = true + matchingIndices := matchingUnionArgs(unionType, argFinder) + if len(matchingIndices) == 0 { + // Conditional fields are transformed in such a way + // that one field always matches. + // For unions we demand that there's a field w/o conditions. + panic(fmt.Sprintf("no matching union fields: %#v", unionType)) + } + cb(unionArg, unionType, matchingIndices) }) } -func calculateUnionArg(arg *UnionArg, typ *UnionType, finder ArgFinder) (int, bool) { - defaultIdx := typ.defaultField() - for i, field := range typ.Fields { - if field.Condition == nil { - continue - } - val, ok := field.Condition.Evaluate(finder) - if !ok { - // We could not calculate the expression. - // Let the union stay as it was. - return defaultIdx, false - } - if val != 0 { - return i, true +func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) { + field := typ.Fields[idx] + if field.Condition == nil { + return true, true + } + val, ok := field.Condition.Evaluate(finder) + if !ok { + // We could not calculate the expression. + // Let the union stay as it was. + return true, false + } + return val != 0, true +} + +func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int { + var ret []int + for i := range typ.Fields { + ok, _ := checkUnionArg(i, typ, finder) + if ok { + ret = append(ret, i) } } - return defaultIdx, true + return ret } func (p *Prog) checkConditions() error { @@ -174,11 +186,13 @@ var ErrViolatedConditions = errors.New("conditional fields rules violation") func (c *Call) checkConditions(target *Target) error { var ret error + makeArgFinder := argFinderConstructor(target, c) forEachStaleUnion(target, c, makeArgFinder, - func(a *UnionArg, t *UnionType, need int) { - ret = fmt.Errorf("%w union %s field is %s, but %s satisfies conditions", - ErrViolatedConditions, t.Name(), t.Fields[a.Index].Name, t.Fields[need].Name) + func(a *UnionArg, t *UnionType, okIndices []int) { + ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions", + ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name, + okIndices) }) return ret } @@ -190,12 +204,14 @@ func (c *Call) setDefaultConditions(target *Target) bool { replace := map[Arg]Arg{} makeArgFinder := argFinderConstructor(target, c) forEachStaleUnion(target, c, makeArgFinder, - func(unionArg *UnionArg, unionType *UnionType, needIdx int) { - field := unionType.Fields[needIdx] + func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { + // If several union options match, take the first one. + idx := okIndices[0] + field := unionType.Fields[idx] replace[unionArg] = MakeUnionArg(unionType, unionArg.Dir(), field.DefaultArg(field.Dir(unionArg.Dir())), - needIdx) + idx) }) for old, new := range replace { anyReplaced = true diff --git a/prog/expr_test.go b/prog/expr_test.go index 69bd790db..f6c8c8603 100644 --- a/prog/expr_test.go +++ b/prog/expr_test.go @@ -90,10 +90,11 @@ func TestEvaluateConditionalFields(t *testing.T) { `test$parent_conditions(&AUTO={0x2, @with_flag1=0x123, {0x0, @void}})`, `test$parent_conditions(&AUTO={0x4, @without_flag1=0x123, {0x0, @value=0x0}})`, `test$parent_conditions(&AUTO={0x6, @with_flag1=0x123, {0x0, @value=0x0}})`, + // The @without_flag1 option is still possible. + `test$parent_conditions(&AUTO={0x2, @without_flag1=0x123, {0x0, @void}})`, }, bad: []string{ `test$parent_conditions(&AUTO={0x0, @with_flag1=0x123, {0x0, @void}})`, - `test$parent_conditions(&AUTO={0x2, @without_flag1=0x123, {0x0, @void}})`, `test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @void}})`, `test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @value=0x0}})`, }, @@ -108,7 +109,8 @@ func TestEvaluateConditionalFields(t *testing.T) { } for _, bad := range test.bad { _, err := target.Deserialize([]byte(bad), Strict) - assert.ErrorIs(tt, err, ErrViolatedConditions) + assert.ErrorIs(tt, err, ErrViolatedConditions, + "prog: %s", bad) } }) } @@ -224,3 +226,37 @@ func parseConditionalStructCall(t *testing.T, c *Call) (bool, bool) { assert.Equal(t, mask&FLAG2 != 0, f2, "flag2 must only be set if mask&FLAG2") return f1, f2 } + +func TestConditionalUnionFields(t *testing.T) { + // Ensure that we reach different combinations of conditional fields. + target, rs, _ := initRandomTargetTest(t, "test", "64") + ct := target.DefaultChoiceTable() + r := newRand(target, rs) + + var zeroU1, zeroU2 int + var nonzeroU2 int + for i := 0; i < 100; i++ { + s := newState(target, ct, nil) + p := &Prog{ + Target: target, + Calls: r.generateParticularCall(s, target.SyscallMap["test$conditional_union"]), + } + if len(p.Calls) > 1 { + continue + } + text := string(p.SerializeVerbose()) + if strings.Contains(text, "{0x0,") { + if strings.Contains(text, "@u1") { + zeroU1++ + } else if strings.Contains(text, "@u2") { + zeroU2++ + } + } else { + assert.NotContains(t, text, "@u1") + nonzeroU2++ + } + } + assert.Greater(t, zeroU1, 0) + assert.Greater(t, zeroU2, 0) + assert.Greater(t, nonzeroU2, 0) +} diff --git a/prog/prog.go b/prog/prog.go index 1453e6aea..34fcfa6e8 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -267,6 +267,10 @@ type UnionArg struct { ArgCommon Option Arg Index int // Index of the selected option in the union type. + // Used for unions with conditional fields. + // We first create a dummy arg with transient=True and then + // patch them. + transient bool } func MakeUnionArg(t Type, dir Dir, opt Arg, index int) *UnionArg { diff --git a/prog/rand.go b/prog/rand.go index 742dbaa7c..aca3163bd 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -886,7 +886,7 @@ func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*C if a.isConditional() { // Conditions may reference other fields that may not have already // been generated. We'll fill them in later. - return a.DefaultArg(dir), nil + return a.DefaultTransientArg(dir), nil } index := r.Intn(len(a.Fields)) optType, optDir := a.Fields[index].Type, a.Fields[index].Dir(dir) diff --git a/prog/types.go b/prog/types.go index 56e999f9e..00440163b 100644 --- a/prog/types.go +++ b/prog/types.go @@ -738,6 +738,12 @@ func (t *UnionType) DefaultArg(dir Dir) Arg { return MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), idx) } +func (t *UnionType) DefaultTransientArg(dir Dir) Arg { + unionArg := t.DefaultArg(dir).(*UnionArg) + unionArg.transient = true + return unionArg +} + func (t *UnionType) defaultField() int { // If it's a conditional union, the last field will be the default value. if t.isConditional() { diff --git a/sys/test/expressions.txt b/sys/test/expressions.txt index c99c1364b..16d5b96a7 100644 --- a/sys/test/expressions.txt +++ b/sys/test/expressions.txt @@ -56,3 +56,15 @@ conditional_struct_minimize { } [packed] test$conditional_struct_minimize(a ptr[in, conditional_struct_minimize]) + +conditional_union [ + u1 int8 (if[value[conditional_union_parent:f1] == 0]) + u2 int8 +] + +conditional_union_parent { + f1 int8:1 + f2 conditional_union +} + +test$conditional_union(a ptr[in, conditional_union_parent]) |
