diff options
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/expr.go | 5 | ||||
| -rw-r--r-- | prog/expr_test.go | 14 | ||||
| -rw-r--r-- | prog/types.go | 28 |
3 files changed, 38 insertions, 9 deletions
diff --git a/prog/expr.go b/prog/expr.go index fbffd3578..699adff3a 100644 --- a/prog/expr.go +++ b/prog/expr.go @@ -205,8 +205,11 @@ func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool { if transientOnly && !unionArg.transient { return } - // If several union options match, take the first one. idx := okIndices[0] + if defIdx, ok := unionType.defaultField(); ok { + // If there's a default value available, use it. + idx = defIdx + } field := unionType.Fields[idx] replace[unionArg] = MakeUnionArg(unionType, unionArg.Dir(), diff --git a/prog/expr_test.go b/prog/expr_test.go index 1558248f5..b536d05e2 100644 --- a/prog/expr_test.go +++ b/prog/expr_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGenerateConditionalFields(t *testing.T) { @@ -309,3 +310,16 @@ func TestNestedConditionalCall(t *testing.T) { } } } + +func TestDefaultConditionalSerialize(t *testing.T) { + // Serialize() omits default-valued fields for a more compact representation, + // but that shouldn't mess with the selected option (see #6105). + const rawProg = "test$parent_conditions(&(0x7f0000000000)={0xa3})\n" + target := initTargetTest(t, "test", "64") + prog, err := target.Deserialize([]byte(rawProg), NonStrict) + require.NoError(t, err) + serialized := prog.Serialize() + prog2, err := target.Deserialize(serialized, NonStrict) + require.NoError(t, err) + assert.Equal(t, rawProg, string(prog2.Serialize())) +} diff --git a/prog/types.go b/prog/types.go index 3a8d478ce..5526eb2a6 100644 --- a/prog/types.go +++ b/prog/types.go @@ -815,31 +815,43 @@ func (t *UnionType) String() string { } func (t *UnionType) DefaultArg(dir Dir) Arg { - idx := t.defaultField() + idx, _ := t.defaultField() f := t.Fields[idx] arg := MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), idx) arg.transient = t.isConditional() return arg } -func (t *UnionType) defaultField() int { - // If it's a conditional union, the last field will be the default value. +func (t *UnionType) defaultField() (int, bool) { + // If it's a conditional union, the last field is usually a safe choice for the default value as + // it must have no condition. + // Auto-generated wrappers for conditional fields are an exception since both fields will have + // conditions, and, moreover, these conditions will be mutually exclusive. if t.isConditional() { - return len(t.Fields) - 1 + if t.Fields[len(t.Fields)-1].Condition != nil { + // There's no correct default index. + return 0, false + } + return len(t.Fields) - 1, true } // Otherwise, just take the first. - return 0 + return 0, true } func (t *UnionType) isConditional() bool { - // In pkg/compiler, we ensure that either none of the fields have conditions, - // or all except the last one. + // Either all fields will have a conditions, or all except the last one, or none. + // So checking for the first one is always enough. return t.Fields[0].Condition != nil } func (t *UnionType) isDefaultArg(arg Arg) bool { a := arg.(*UnionArg) - return a.Index == t.defaultField() && isDefault(a.Option) + defIdx, ok := t.defaultField() + if !ok { + // Any value is the only possible option. + return isDefault(a.Option) + } + return a.Index == defIdx && isDefault(a.Option) } type ConstValue struct { |
