aboutsummaryrefslogtreecommitdiffstats
path: root/prog/expr.go
diff options
context:
space:
mode:
Diffstat (limited to 'prog/expr.go')
-rw-r--r--prog/expr.go80
1 files changed, 48 insertions, 32 deletions
diff --git a/prog/expr.go b/prog/expr.go
index dc2279cac..5ecbef8eb 100644
--- a/prog/expr.go
+++ b/prog/expr.go
@@ -90,11 +90,12 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan
replace := map[Arg]Arg{}
makeArgFinder := argFinderConstructor(r.target, c)
forEachStaleUnion(r.target, c, makeArgFinder,
- func(unionArg *UnionArg, unionType *UnionType, needIdx int) {
- newType, newDir := unionType.Fields[needIdx].Type,
- unionType.Fields[needIdx].Dir(unionArg.Dir())
+ func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
+ idx := okIndices[r.Intn(len(okIndices))]
+ newType, newDir := unionType.Fields[idx].Type,
+ unionType.Fields[idx].Dir(unionArg.Dir())
newTypeArg, newCalls := r.generateArg(s, newType, newDir)
- replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, needIdx)
+ replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx)
extraCalls = append(extraCalls, newCalls...)
anyPatched = true
})
@@ -112,7 +113,7 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan
}
func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) ArgFinder,
- cb func(*UnionArg, *UnionType, int)) {
+ cb func(*UnionArg, *UnionType, []int)) {
ForeachArg(c, func(arg Arg, argCtx *ArgCtx) {
if target.isAnyPtr(arg.Type()) {
argCtx.Stop = true
@@ -127,37 +128,48 @@ func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) Ar
return
}
argFinder := makeArgFinder(unionArg)
- needIdx, ok := calculateUnionArg(unionArg, unionType, argFinder)
- if !ok {
+ ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder)
+ if !calculated {
// Let it stay as is.
return
}
- if unionArg.Index == needIdx {
- // No changes are needed.
+ if !unionArg.transient && ok {
return
}
- cb(unionArg, unionType, needIdx)
- argCtx.Stop = true
+ matchingIndices := matchingUnionArgs(unionType, argFinder)
+ if len(matchingIndices) == 0 {
+ // Conditional fields are transformed in such a way
+ // that one field always matches.
+ // For unions we demand that there's a field w/o conditions.
+ panic(fmt.Sprintf("no matching union fields: %#v", unionType))
+ }
+ cb(unionArg, unionType, matchingIndices)
})
}
-func calculateUnionArg(arg *UnionArg, typ *UnionType, finder ArgFinder) (int, bool) {
- defaultIdx := typ.defaultField()
- for i, field := range typ.Fields {
- if field.Condition == nil {
- continue
- }
- val, ok := field.Condition.Evaluate(finder)
- if !ok {
- // We could not calculate the expression.
- // Let the union stay as it was.
- return defaultIdx, false
- }
- if val != 0 {
- return i, true
+func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) {
+ field := typ.Fields[idx]
+ if field.Condition == nil {
+ return true, true
+ }
+ val, ok := field.Condition.Evaluate(finder)
+ if !ok {
+ // We could not calculate the expression.
+ // Let the union stay as it was.
+ return true, false
+ }
+ return val != 0, true
+}
+
+func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int {
+ var ret []int
+ for i := range typ.Fields {
+ ok, _ := checkUnionArg(i, typ, finder)
+ if ok {
+ ret = append(ret, i)
}
}
- return defaultIdx, true
+ return ret
}
func (p *Prog) checkConditions() error {
@@ -174,11 +186,13 @@ var ErrViolatedConditions = errors.New("conditional fields rules violation")
func (c *Call) checkConditions(target *Target) error {
var ret error
+
makeArgFinder := argFinderConstructor(target, c)
forEachStaleUnion(target, c, makeArgFinder,
- func(a *UnionArg, t *UnionType, need int) {
- ret = fmt.Errorf("%w union %s field is %s, but %s satisfies conditions",
- ErrViolatedConditions, t.Name(), t.Fields[a.Index].Name, t.Fields[need].Name)
+ func(a *UnionArg, t *UnionType, okIndices []int) {
+ ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions",
+ ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name,
+ okIndices)
})
return ret
}
@@ -190,12 +204,14 @@ func (c *Call) setDefaultConditions(target *Target) bool {
replace := map[Arg]Arg{}
makeArgFinder := argFinderConstructor(target, c)
forEachStaleUnion(target, c, makeArgFinder,
- func(unionArg *UnionArg, unionType *UnionType, needIdx int) {
- field := unionType.Fields[needIdx]
+ func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
+ // If several union options match, take the first one.
+ idx := okIndices[0]
+ field := unionType.Fields[idx]
replace[unionArg] = MakeUnionArg(unionType,
unionArg.Dir(),
field.DefaultArg(field.Dir(unionArg.Dir())),
- needIdx)
+ idx)
})
for old, new := range replace {
anyReplaced = true