aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pkg/ast/ast.go33
-rw-r--r--pkg/compiler/compiler.go30
2 files changed, 51 insertions, 12 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index e225c69b2..2e8a7a015 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -27,6 +27,15 @@ type Node interface {
walk(cb func(Node))
}
+type Flags[T FlagValue] interface {
+ SetValues(values []T)
+ GetValues() []T
+}
+
+type FlagValue interface {
+ GetName() string
+}
+
// Top-level AST nodes.
type NewLine struct {
@@ -135,6 +144,14 @@ func (n *IntFlags) Info() (Pos, string, string) {
return n.Pos, "flags", n.Name.Name
}
+func (n *IntFlags) SetValues(values []*Int) {
+ n.Values = values
+}
+
+func (n *IntFlags) GetValues() []*Int {
+ return n.Values
+}
+
type StrFlags struct {
Pos Pos
Name *Ident
@@ -145,6 +162,14 @@ func (n *StrFlags) Info() (Pos, string, string) {
return n.Pos, "string flags", n.Name.Name
}
+func (n *StrFlags) SetValues(values []*String) {
+ n.Values = values
+}
+
+func (n *StrFlags) GetValues() []*String {
+ return n.Values
+}
+
type TypeDef struct {
Pos Pos
Name *Ident
@@ -180,6 +205,10 @@ func (n *String) Info() (Pos, string, string) {
return n.Pos, tok2str[tokString], ""
}
+func (n *String) GetName() string {
+ return n.Value
+}
+
type IntFmt int
const (
@@ -210,6 +239,10 @@ func (n *Int) Info() (Pos, string, string) {
return n.Pos, tok2str[tokInt], ""
}
+func (n *Int) GetName() string {
+ return n.Ident
+}
+
type Type struct {
Pos Pos
// Only one of Value, Ident, String is filled.
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go
index 2e13988b7..66da5e2b4 100644
--- a/pkg/compiler/compiler.go
+++ b/pkg/compiler/compiler.go
@@ -360,11 +360,7 @@ func arrayContains(a []string, v string) bool {
}
func (comp *compiler) flattenFlags() {
- for name, flags := range comp.intFlags {
- if err := comp.recurFlattenFlags(name, flags, map[string]bool{}); err != nil {
- comp.error(flags.Pos, "%v", err)
- }
- }
+ comp.flattenIntFlags()
for _, n := range comp.desc.Nodes {
switch flags := n.(type) {
@@ -379,19 +375,29 @@ func (comp *compiler) flattenFlags() {
}
}
-func (comp *compiler) recurFlattenFlags(name string, flags *ast.IntFlags, visitedFlags map[string]bool) error {
+func (comp *compiler) flattenIntFlags() {
+ for name, flags := range comp.intFlags {
+ if err := recurFlattenFlags[*ast.IntFlags, *ast.Int](comp, name, flags, comp.intFlags,
+ map[string]bool{}); err != nil {
+ comp.error(flags.Pos, "%v", err)
+ }
+ }
+}
+
+func recurFlattenFlags[F ast.Flags[V], V ast.FlagValue](comp *compiler, name string, flags F,
+ allFlags map[string]F, visitedFlags map[string]bool) error {
if _, visited := visitedFlags[name]; visited {
return fmt.Errorf("flags %v used twice or circular dependency on %v", name, name)
}
visitedFlags[name] = true
- var values []*ast.Int
- for _, flag := range flags.Values {
- if f, isFlags := comp.intFlags[flag.Ident]; isFlags {
- if err := comp.recurFlattenFlags(flag.Ident, f, visitedFlags); err != nil {
+ var values []V
+ for _, flag := range flags.GetValues() {
+ if f, isFlags := allFlags[flag.GetName()]; isFlags {
+ if err := recurFlattenFlags[F, V](comp, flag.GetName(), f, allFlags, visitedFlags); err != nil {
return err
}
- values = append(values, comp.intFlags[flag.Ident].Values...)
+ values = append(values, allFlags[flag.GetName()].GetValues()...)
} else {
values = append(values, flag)
}
@@ -399,6 +405,6 @@ func (comp *compiler) recurFlattenFlags(name string, flags *ast.IntFlags, visite
if len(values) > 2000 {
return fmt.Errorf("%v has more than 2000 values", name)
}
- flags.Values = values
+ flags.SetValues(values)
return nil
}