aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2024-01-17 13:16:29 +0100
committerAleksandr Nogikh <nogikh@google.com>2024-02-19 11:54:01 +0000
commit8cb16e665dbb5f87aa58856049c1ad6067dc6293 (patch)
treecb8eacf9619ce1ea3df3aa5891cf6ab44ebf83b8
parent0936819b9f980bde731cb6191677f9aa2cbfd9aa (diff)
prog: handle multiple matching union fields
If conditions of several union fields are satisfied, select one randomly. This would be a more logical semantics. When conditional struct fields are translated to unions, negate the condition for the union alternative.
-rw-r--r--pkg/compiler/gen.go5
-rw-r--r--prog/expr.go80
-rw-r--r--prog/expr_test.go40
-rw-r--r--prog/prog.go4
-rw-r--r--prog/rand.go2
-rw-r--r--prog/types.go6
-rw-r--r--sys/test/expressions.txt12
7 files changed, 114 insertions, 35 deletions
diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go
index d274877a2..f91e2502c 100644
--- a/pkg/compiler/gen.go
+++ b/pkg/compiler/gen.go
@@ -556,6 +556,11 @@ func (comp *compiler) wrapConditionalField(name string, field prog.Field) prog.F
{
Name: "void",
Type: voidType,
+ Condition: &prog.BinaryExpression{
+ Operator: prog.OperatorCompareEq,
+ Left: newCondition,
+ Right: &prog.Value{Value: 0x0, Path: nil},
+ },
},
},
},
diff --git a/prog/expr.go b/prog/expr.go
index dc2279cac..5ecbef8eb 100644
--- a/prog/expr.go
+++ b/prog/expr.go
@@ -90,11 +90,12 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan
replace := map[Arg]Arg{}
makeArgFinder := argFinderConstructor(r.target, c)
forEachStaleUnion(r.target, c, makeArgFinder,
- func(unionArg *UnionArg, unionType *UnionType, needIdx int) {
- newType, newDir := unionType.Fields[needIdx].Type,
- unionType.Fields[needIdx].Dir(unionArg.Dir())
+ func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
+ idx := okIndices[r.Intn(len(okIndices))]
+ newType, newDir := unionType.Fields[idx].Type,
+ unionType.Fields[idx].Dir(unionArg.Dir())
newTypeArg, newCalls := r.generateArg(s, newType, newDir)
- replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, needIdx)
+ replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx)
extraCalls = append(extraCalls, newCalls...)
anyPatched = true
})
@@ -112,7 +113,7 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan
}
func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) ArgFinder,
- cb func(*UnionArg, *UnionType, int)) {
+ cb func(*UnionArg, *UnionType, []int)) {
ForeachArg(c, func(arg Arg, argCtx *ArgCtx) {
if target.isAnyPtr(arg.Type()) {
argCtx.Stop = true
@@ -127,37 +128,48 @@ func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) Ar
return
}
argFinder := makeArgFinder(unionArg)
- needIdx, ok := calculateUnionArg(unionArg, unionType, argFinder)
- if !ok {
+ ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder)
+ if !calculated {
// Let it stay as is.
return
}
- if unionArg.Index == needIdx {
- // No changes are needed.
+ if !unionArg.transient && ok {
return
}
- cb(unionArg, unionType, needIdx)
- argCtx.Stop = true
+ matchingIndices := matchingUnionArgs(unionType, argFinder)
+ if len(matchingIndices) == 0 {
+ // Conditional fields are transformed in such a way
+ // that one field always matches.
+ // For unions we demand that there's a field w/o conditions.
+ panic(fmt.Sprintf("no matching union fields: %#v", unionType))
+ }
+ cb(unionArg, unionType, matchingIndices)
})
}
-func calculateUnionArg(arg *UnionArg, typ *UnionType, finder ArgFinder) (int, bool) {
- defaultIdx := typ.defaultField()
- for i, field := range typ.Fields {
- if field.Condition == nil {
- continue
- }
- val, ok := field.Condition.Evaluate(finder)
- if !ok {
- // We could not calculate the expression.
- // Let the union stay as it was.
- return defaultIdx, false
- }
- if val != 0 {
- return i, true
+func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) {
+ field := typ.Fields[idx]
+ if field.Condition == nil {
+ return true, true
+ }
+ val, ok := field.Condition.Evaluate(finder)
+ if !ok {
+ // We could not calculate the expression.
+ // Let the union stay as it was.
+ return true, false
+ }
+ return val != 0, true
+}
+
+func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int {
+ var ret []int
+ for i := range typ.Fields {
+ ok, _ := checkUnionArg(i, typ, finder)
+ if ok {
+ ret = append(ret, i)
}
}
- return defaultIdx, true
+ return ret
}
func (p *Prog) checkConditions() error {
@@ -174,11 +186,13 @@ var ErrViolatedConditions = errors.New("conditional fields rules violation")
func (c *Call) checkConditions(target *Target) error {
var ret error
+
makeArgFinder := argFinderConstructor(target, c)
forEachStaleUnion(target, c, makeArgFinder,
- func(a *UnionArg, t *UnionType, need int) {
- ret = fmt.Errorf("%w union %s field is %s, but %s satisfies conditions",
- ErrViolatedConditions, t.Name(), t.Fields[a.Index].Name, t.Fields[need].Name)
+ func(a *UnionArg, t *UnionType, okIndices []int) {
+ ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions",
+ ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name,
+ okIndices)
})
return ret
}
@@ -190,12 +204,14 @@ func (c *Call) setDefaultConditions(target *Target) bool {
replace := map[Arg]Arg{}
makeArgFinder := argFinderConstructor(target, c)
forEachStaleUnion(target, c, makeArgFinder,
- func(unionArg *UnionArg, unionType *UnionType, needIdx int) {
- field := unionType.Fields[needIdx]
+ func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
+ // If several union options match, take the first one.
+ idx := okIndices[0]
+ field := unionType.Fields[idx]
replace[unionArg] = MakeUnionArg(unionType,
unionArg.Dir(),
field.DefaultArg(field.Dir(unionArg.Dir())),
- needIdx)
+ idx)
})
for old, new := range replace {
anyReplaced = true
diff --git a/prog/expr_test.go b/prog/expr_test.go
index 69bd790db..f6c8c8603 100644
--- a/prog/expr_test.go
+++ b/prog/expr_test.go
@@ -90,10 +90,11 @@ func TestEvaluateConditionalFields(t *testing.T) {
`test$parent_conditions(&AUTO={0x2, @with_flag1=0x123, {0x0, @void}})`,
`test$parent_conditions(&AUTO={0x4, @without_flag1=0x123, {0x0, @value=0x0}})`,
`test$parent_conditions(&AUTO={0x6, @with_flag1=0x123, {0x0, @value=0x0}})`,
+ // The @without_flag1 option is still possible.
+ `test$parent_conditions(&AUTO={0x2, @without_flag1=0x123, {0x0, @void}})`,
},
bad: []string{
`test$parent_conditions(&AUTO={0x0, @with_flag1=0x123, {0x0, @void}})`,
- `test$parent_conditions(&AUTO={0x2, @without_flag1=0x123, {0x0, @void}})`,
`test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @void}})`,
`test$parent_conditions(&AUTO={0x4, @with_flag1=0x123, {0x0, @value=0x0}})`,
},
@@ -108,7 +109,8 @@ func TestEvaluateConditionalFields(t *testing.T) {
}
for _, bad := range test.bad {
_, err := target.Deserialize([]byte(bad), Strict)
- assert.ErrorIs(tt, err, ErrViolatedConditions)
+ assert.ErrorIs(tt, err, ErrViolatedConditions,
+ "prog: %s", bad)
}
})
}
@@ -224,3 +226,37 @@ func parseConditionalStructCall(t *testing.T, c *Call) (bool, bool) {
assert.Equal(t, mask&FLAG2 != 0, f2, "flag2 must only be set if mask&FLAG2")
return f1, f2
}
+
+func TestConditionalUnionFields(t *testing.T) {
+ // Ensure that we reach different combinations of conditional fields.
+ target, rs, _ := initRandomTargetTest(t, "test", "64")
+ ct := target.DefaultChoiceTable()
+ r := newRand(target, rs)
+
+ var zeroU1, zeroU2 int
+ var nonzeroU2 int
+ for i := 0; i < 100; i++ {
+ s := newState(target, ct, nil)
+ p := &Prog{
+ Target: target,
+ Calls: r.generateParticularCall(s, target.SyscallMap["test$conditional_union"]),
+ }
+ if len(p.Calls) > 1 {
+ continue
+ }
+ text := string(p.SerializeVerbose())
+ if strings.Contains(text, "{0x0,") {
+ if strings.Contains(text, "@u1") {
+ zeroU1++
+ } else if strings.Contains(text, "@u2") {
+ zeroU2++
+ }
+ } else {
+ assert.NotContains(t, text, "@u1")
+ nonzeroU2++
+ }
+ }
+ assert.Greater(t, zeroU1, 0)
+ assert.Greater(t, zeroU2, 0)
+ assert.Greater(t, nonzeroU2, 0)
+}
diff --git a/prog/prog.go b/prog/prog.go
index 1453e6aea..34fcfa6e8 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -267,6 +267,10 @@ type UnionArg struct {
ArgCommon
Option Arg
Index int // Index of the selected option in the union type.
+ // Used for unions with conditional fields.
+ // We first create a dummy arg with transient=True and then
+ // patch them.
+ transient bool
}
func MakeUnionArg(t Type, dir Dir, opt Arg, index int) *UnionArg {
diff --git a/prog/rand.go b/prog/rand.go
index 742dbaa7c..aca3163bd 100644
--- a/prog/rand.go
+++ b/prog/rand.go
@@ -886,7 +886,7 @@ func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*C
if a.isConditional() {
// Conditions may reference other fields that may not have already
// been generated. We'll fill them in later.
- return a.DefaultArg(dir), nil
+ return a.DefaultTransientArg(dir), nil
}
index := r.Intn(len(a.Fields))
optType, optDir := a.Fields[index].Type, a.Fields[index].Dir(dir)
diff --git a/prog/types.go b/prog/types.go
index 56e999f9e..00440163b 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -738,6 +738,12 @@ func (t *UnionType) DefaultArg(dir Dir) Arg {
return MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), idx)
}
+func (t *UnionType) DefaultTransientArg(dir Dir) Arg {
+ unionArg := t.DefaultArg(dir).(*UnionArg)
+ unionArg.transient = true
+ return unionArg
+}
+
func (t *UnionType) defaultField() int {
// If it's a conditional union, the last field will be the default value.
if t.isConditional() {
diff --git a/sys/test/expressions.txt b/sys/test/expressions.txt
index c99c1364b..16d5b96a7 100644
--- a/sys/test/expressions.txt
+++ b/sys/test/expressions.txt
@@ -56,3 +56,15 @@ conditional_struct_minimize {
} [packed]
test$conditional_struct_minimize(a ptr[in, conditional_struct_minimize])
+
+conditional_union [
+ u1 int8 (if[value[conditional_union_parent:f1] == 0])
+ u2 int8
+]
+
+conditional_union_parent {
+ f1 int8:1
+ f2 conditional_union
+}
+
+test$conditional_union(a ptr[in, conditional_union_parent])