aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/spf13/cobra/flag_groups.go
diff options
context:
space:
mode:
authorTaras Madan <tarasmadan@google.com>2024-09-10 12:16:33 +0200
committerTaras Madan <tarasmadan@google.com>2024-09-10 14:05:26 +0000
commitc97c816133b42257d0bcf1ee4bd178bb2a7a2b9e (patch)
tree0bcbc2e540bbf8f62f6c17887cdd53b8c2cee637 /vendor/github.com/spf13/cobra/flag_groups.go
parent54e657429ab892ad06c90cd7c1a4eb33ba93a3dc (diff)
vendor: update
Diffstat (limited to 'vendor/github.com/spf13/cobra/flag_groups.go')
-rw-r--r--vendor/github.com/spf13/cobra/flag_groups.go86
1 files changed, 76 insertions, 10 deletions
diff --git a/vendor/github.com/spf13/cobra/flag_groups.go b/vendor/github.com/spf13/cobra/flag_groups.go
index b35fde155..560612fd3 100644
--- a/vendor/github.com/spf13/cobra/flag_groups.go
+++ b/vendor/github.com/spf13/cobra/flag_groups.go
@@ -23,8 +23,9 @@ import (
)
const (
- requiredAsGroup = "cobra_annotation_required_if_others_set"
- mutuallyExclusive = "cobra_annotation_mutually_exclusive"
+ requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
+ oneRequiredAnnotation = "cobra_annotation_one_required"
+ mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
@@ -36,7 +37,23 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
- if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
+ if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
+ // Only errs if the flag isn't found.
+ panic(err)
+ }
+ }
+}
+
+// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
+// if the command is invoked without at least one flag from the given set of flags.
+func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
+ c.mergePersistentFlags()
+ for _, v := range flagNames {
+ f := c.Flags().Lookup(v)
+ if f == nil {
+ panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
+ }
+ if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
@@ -53,13 +70,13 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
- if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
+ if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}
-// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
+// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
if c.DisableFlagParsing {
@@ -71,15 +88,20 @@ func (c *Command) ValidateFlagGroups() error {
// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
+ oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
- processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
- processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
+ processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
+ processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
+ processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
if err := validateRequiredFlagGroups(groupStatus); err != nil {
return err
}
+ if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
+ return err
+ }
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
@@ -108,7 +130,7 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota
continue
}
- groupStatus[group] = map[string]bool{}
+ groupStatus[group] = make(map[string]bool, len(flagnames))
for _, name := range flagnames {
groupStatus[group][name] = false
}
@@ -142,6 +164,27 @@ func validateRequiredFlagGroups(data map[string]map[string]bool) error {
return nil
}
+func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
+ keys := sortedKeys(data)
+ for _, flagList := range keys {
+ flagnameAndStatus := data[flagList]
+ var set []string
+ for flagname, isSet := range flagnameAndStatus {
+ if isSet {
+ set = append(set, flagname)
+ }
+ }
+ if len(set) >= 1 {
+ continue
+ }
+
+ // Sort values, so they can be tested/scripted against consistently.
+ sort.Strings(set)
+ return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
+ }
+ return nil
+}
+
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
@@ -176,6 +219,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
// enforceFlagGroupsForCompletion will do the following:
// - when a flag in a group is present, other flags in the group will be marked required
+// - when none of the flags in a one-required group are present, all flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
@@ -185,10 +229,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
flags := c.Flags()
groupStatus := map[string]map[string]bool{}
+ oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
- processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
- processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
+ processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
+ processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
+ processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
// If a flag that is part of a group is present, we make all the other flags
@@ -204,6 +250,26 @@ func (c *Command) enforceFlagGroupsForCompletion() {
}
}
+ // If none of the flags of a one-required group are present, we make all the flags
+ // of that group required so that the shell completion suggests them automatically
+ for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
+ isSet := false
+
+ for _, isSet = range flagnameAndStatus {
+ if isSet {
+ break
+ }
+ }
+
+ // None of the flags of the group are set, mark all flags in the group
+ // as required
+ if !isSet {
+ for _, fName := range strings.Split(flagList, " ") {
+ _ = c.MarkFlagRequired(fName)
+ }
+ }
+ }
+
// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {