From 8cb16e665dbb5f87aa58856049c1ad6067dc6293 Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Wed, 17 Jan 2024 13:16:29 +0100 Subject: prog: handle multiple matching union fields If conditions of several union fields are satisfied, select one randomly. This would be a more logical semantics. When conditional struct fields are translated to unions, negate the condition for the union alternative. --- prog/expr.go | 80 ++++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 48 insertions(+), 32 deletions(-) (limited to 'prog/expr.go') diff --git a/prog/expr.go b/prog/expr.go index dc2279cac..5ecbef8eb 100644 --- a/prog/expr.go +++ b/prog/expr.go @@ -90,11 +90,12 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan replace := map[Arg]Arg{} makeArgFinder := argFinderConstructor(r.target, c) forEachStaleUnion(r.target, c, makeArgFinder, - func(unionArg *UnionArg, unionType *UnionType, needIdx int) { - newType, newDir := unionType.Fields[needIdx].Type, - unionType.Fields[needIdx].Dir(unionArg.Dir()) + func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { + idx := okIndices[r.Intn(len(okIndices))] + newType, newDir := unionType.Fields[idx].Type, + unionType.Fields[idx].Dir(unionArg.Dir()) newTypeArg, newCalls := r.generateArg(s, newType, newDir) - replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, needIdx) + replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx) extraCalls = append(extraCalls, newCalls...) anyPatched = true }) @@ -112,7 +113,7 @@ func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, chan } func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) ArgFinder, - cb func(*UnionArg, *UnionType, int)) { + cb func(*UnionArg, *UnionType, []int)) { ForeachArg(c, func(arg Arg, argCtx *ArgCtx) { if target.isAnyPtr(arg.Type()) { argCtx.Stop = true @@ -127,37 +128,48 @@ func forEachStaleUnion(target *Target, c *Call, makeArgFinder func(*UnionArg) Ar return } argFinder := makeArgFinder(unionArg) - needIdx, ok := calculateUnionArg(unionArg, unionType, argFinder) - if !ok { + ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder) + if !calculated { // Let it stay as is. return } - if unionArg.Index == needIdx { - // No changes are needed. + if !unionArg.transient && ok { return } - cb(unionArg, unionType, needIdx) - argCtx.Stop = true + matchingIndices := matchingUnionArgs(unionType, argFinder) + if len(matchingIndices) == 0 { + // Conditional fields are transformed in such a way + // that one field always matches. + // For unions we demand that there's a field w/o conditions. + panic(fmt.Sprintf("no matching union fields: %#v", unionType)) + } + cb(unionArg, unionType, matchingIndices) }) } -func calculateUnionArg(arg *UnionArg, typ *UnionType, finder ArgFinder) (int, bool) { - defaultIdx := typ.defaultField() - for i, field := range typ.Fields { - if field.Condition == nil { - continue - } - val, ok := field.Condition.Evaluate(finder) - if !ok { - // We could not calculate the expression. - // Let the union stay as it was. - return defaultIdx, false - } - if val != 0 { - return i, true +func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) { + field := typ.Fields[idx] + if field.Condition == nil { + return true, true + } + val, ok := field.Condition.Evaluate(finder) + if !ok { + // We could not calculate the expression. + // Let the union stay as it was. + return true, false + } + return val != 0, true +} + +func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int { + var ret []int + for i := range typ.Fields { + ok, _ := checkUnionArg(i, typ, finder) + if ok { + ret = append(ret, i) } } - return defaultIdx, true + return ret } func (p *Prog) checkConditions() error { @@ -174,11 +186,13 @@ var ErrViolatedConditions = errors.New("conditional fields rules violation") func (c *Call) checkConditions(target *Target) error { var ret error + makeArgFinder := argFinderConstructor(target, c) forEachStaleUnion(target, c, makeArgFinder, - func(a *UnionArg, t *UnionType, need int) { - ret = fmt.Errorf("%w union %s field is %s, but %s satisfies conditions", - ErrViolatedConditions, t.Name(), t.Fields[a.Index].Name, t.Fields[need].Name) + func(a *UnionArg, t *UnionType, okIndices []int) { + ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions", + ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name, + okIndices) }) return ret } @@ -190,12 +204,14 @@ func (c *Call) setDefaultConditions(target *Target) bool { replace := map[Arg]Arg{} makeArgFinder := argFinderConstructor(target, c) forEachStaleUnion(target, c, makeArgFinder, - func(unionArg *UnionArg, unionType *UnionType, needIdx int) { - field := unionType.Fields[needIdx] + func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { + // If several union options match, take the first one. + idx := okIndices[0] + field := unionType.Fields[idx] replace[unionArg] = MakeUnionArg(unionType, unionArg.Dir(), field.DefaultArg(field.Dir(unionArg.Dir())), - needIdx) + idx) }) for old, new := range replace { anyReplaced = true -- cgit mrf-deployment