aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
Diffstat (limited to 'prog')
-rw-r--r--prog/expr.go5
-rw-r--r--prog/expr_test.go14
-rw-r--r--prog/types.go28
3 files changed, 38 insertions, 9 deletions
diff --git a/prog/expr.go b/prog/expr.go
index fbffd3578..699adff3a 100644
--- a/prog/expr.go
+++ b/prog/expr.go
@@ -205,8 +205,11 @@ func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool {
if transientOnly && !unionArg.transient {
return
}
- // If several union options match, take the first one.
idx := okIndices[0]
+ if defIdx, ok := unionType.defaultField(); ok {
+ // If there's a default value available, use it.
+ idx = defIdx
+ }
field := unionType.Fields[idx]
replace[unionArg] = MakeUnionArg(unionType,
unionArg.Dir(),
diff --git a/prog/expr_test.go b/prog/expr_test.go
index 1558248f5..b536d05e2 100644
--- a/prog/expr_test.go
+++ b/prog/expr_test.go
@@ -10,6 +10,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestGenerateConditionalFields(t *testing.T) {
@@ -309,3 +310,16 @@ func TestNestedConditionalCall(t *testing.T) {
}
}
}
+
+func TestDefaultConditionalSerialize(t *testing.T) {
+ // Serialize() omits default-valued fields for a more compact representation,
+ // but that shouldn't mess with the selected option (see #6105).
+ const rawProg = "test$parent_conditions(&(0x7f0000000000)={0xa3})\n"
+ target := initTargetTest(t, "test", "64")
+ prog, err := target.Deserialize([]byte(rawProg), NonStrict)
+ require.NoError(t, err)
+ serialized := prog.Serialize()
+ prog2, err := target.Deserialize(serialized, NonStrict)
+ require.NoError(t, err)
+ assert.Equal(t, rawProg, string(prog2.Serialize()))
+}
diff --git a/prog/types.go b/prog/types.go
index 3a8d478ce..5526eb2a6 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -815,31 +815,43 @@ func (t *UnionType) String() string {
}
func (t *UnionType) DefaultArg(dir Dir) Arg {
- idx := t.defaultField()
+ idx, _ := t.defaultField()
f := t.Fields[idx]
arg := MakeUnionArg(t, dir, f.DefaultArg(f.Dir(dir)), idx)
arg.transient = t.isConditional()
return arg
}
-func (t *UnionType) defaultField() int {
- // If it's a conditional union, the last field will be the default value.
+func (t *UnionType) defaultField() (int, bool) {
+ // If it's a conditional union, the last field is usually a safe choice for the default value as
+ // it must have no condition.
+ // Auto-generated wrappers for conditional fields are an exception since both fields will have
+ // conditions, and, moreover, these conditions will be mutually exclusive.
if t.isConditional() {
- return len(t.Fields) - 1
+ if t.Fields[len(t.Fields)-1].Condition != nil {
+ // There's no correct default index.
+ return 0, false
+ }
+ return len(t.Fields) - 1, true
}
// Otherwise, just take the first.
- return 0
+ return 0, true
}
func (t *UnionType) isConditional() bool {
- // In pkg/compiler, we ensure that either none of the fields have conditions,
- // or all except the last one.
+ // Either all fields will have a conditions, or all except the last one, or none.
+ // So checking for the first one is always enough.
return t.Fields[0].Condition != nil
}
func (t *UnionType) isDefaultArg(arg Arg) bool {
a := arg.(*UnionArg)
- return a.Index == t.defaultField() && isDefault(a.Option)
+ defIdx, ok := t.defaultField()
+ if !ok {
+ // Any value is the only possible option.
+ return isDefault(a.Option)
+ }
+ return a.Index == defIdx && isDefault(a.Option)
}
type ConstValue struct {