From 3e9d168edd9e40138b295bb7d8adc92fa9430e78 Mon Sep 17 00:00:00 2001 From: Paul Chaignon Date: Fri, 1 Dec 2023 14:46:41 +0100 Subject: compiler: refactor recurFlattenFlags This commit refactors recurFlattenFlags using Go generics and new interfaces so that it also applies to a different set of flags types. In a subsequent commit, we will use that to perform the same recursive flattening for string flags. Signed-off-by: Paul Chaignon --- pkg/compiler/compiler.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) (limited to 'pkg/compiler') diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 2e13988b7..66da5e2b4 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -360,11 +360,7 @@ func arrayContains(a []string, v string) bool { } func (comp *compiler) flattenFlags() { - for name, flags := range comp.intFlags { - if err := comp.recurFlattenFlags(name, flags, map[string]bool{}); err != nil { - comp.error(flags.Pos, "%v", err) - } - } + comp.flattenIntFlags() for _, n := range comp.desc.Nodes { switch flags := n.(type) { @@ -379,19 +375,29 @@ func (comp *compiler) flattenFlags() { } } -func (comp *compiler) recurFlattenFlags(name string, flags *ast.IntFlags, visitedFlags map[string]bool) error { +func (comp *compiler) flattenIntFlags() { + for name, flags := range comp.intFlags { + if err := recurFlattenFlags[*ast.IntFlags, *ast.Int](comp, name, flags, comp.intFlags, + map[string]bool{}); err != nil { + comp.error(flags.Pos, "%v", err) + } + } +} + +func recurFlattenFlags[F ast.Flags[V], V ast.FlagValue](comp *compiler, name string, flags F, + allFlags map[string]F, visitedFlags map[string]bool) error { if _, visited := visitedFlags[name]; visited { return fmt.Errorf("flags %v used twice or circular dependency on %v", name, name) } visitedFlags[name] = true - var values []*ast.Int - for _, flag := range flags.Values { - if f, isFlags := comp.intFlags[flag.Ident]; isFlags { - if err := comp.recurFlattenFlags(flag.Ident, f, visitedFlags); err != nil { + var values []V + for _, flag := range flags.GetValues() { + if f, isFlags := allFlags[flag.GetName()]; isFlags { + if err := recurFlattenFlags[F, V](comp, flag.GetName(), f, allFlags, visitedFlags); err != nil { return err } - values = append(values, comp.intFlags[flag.Ident].Values...) + values = append(values, allFlags[flag.GetName()].GetValues()...) } else { values = append(values, flag) } @@ -399,6 +405,6 @@ func (comp *compiler) recurFlattenFlags(name string, flags *ast.IntFlags, visite if len(values) > 2000 { return fmt.Errorf("%v has more than 2000 values", name) } - flags.Values = values + flags.SetValues(values) return nil } -- cgit mrf-deployment