aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/tool/flags.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-12-16 08:57:04 +0100
committerDmitry Vyukov <dvyukov@google.com>2020-12-25 10:12:41 +0100
commit3bcdec13657598f6a6163c7ddecff58c2d3a2a71 (patch)
tree3ff27333aeecc6eb7be333bc4407647792968ff1 /pkg/tool/flags.go
parent80795712865ca86bb21ebb9841598ccbcbd375c9 (diff)
pkg/tool: add package
Package tool contains various helper utilitites useful for implementation of command line tools. Currently it contains Fail/Failf functions that we commonly use and new support for optional command line flags.
Diffstat (limited to 'pkg/tool/flags.go')
-rw-r--r--pkg/tool/flags.go132
1 files changed, 132 insertions, 0 deletions
diff --git a/pkg/tool/flags.go b/pkg/tool/flags.go
new file mode 100644
index 000000000..197494823
--- /dev/null
+++ b/pkg/tool/flags.go
@@ -0,0 +1,132 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package tool
+
+import (
+ "bytes"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/google/syzkaller/pkg/log"
+)
+
+// ParseFlags parses command line flags with flag.Parse and then applies optional flags created with OptionalFlags.
+// This is intended for programmatic use only when we invoke older versions of binaries with new unsupported flags.
+func ParseFlags() {
+ if err := parseFlags(flag.CommandLine, os.Args[1:]); err != nil {
+ Fail(err)
+ }
+}
+
+type Flag struct {
+ Name string
+ Value string
+}
+
+// OptionalFlags produces command line flag value that encapsulates the given flags as optional.
+// Use ParseFlags to support optional flags in the binary.
+// The format keeps flags reasonably readable ("-optional=foo=bar:baz=123"), not subject to accidental splitting
+// into multiple arguments due to spaces and supports bool/non-bool flags.
+func OptionalFlags(flags ...Flag) string {
+ return fmt.Sprintf("-%v=%v", optionalFlag, serializeFlags(flags))
+}
+
+func parseFlags(set *flag.FlagSet, args []string) error {
+ optional := set.String(optionalFlag, "", "optional flags for programmatic use only")
+ if err := set.Parse(args); err != nil {
+ return err
+ }
+ flags, err := deserializeFlags(*optional)
+ if err != nil {
+ return err
+ }
+ for _, f := range flags {
+ ff := set.Lookup(f.Name)
+ if ff == nil {
+ log.Logf(0, "ignoring optional flag %q=%q", f.Name, f.Value)
+ continue
+ }
+ if err := ff.Value.Set(f.Value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+const optionalFlag = "optional"
+
+func serializeFlags(flags []Flag) string {
+ if len(flags) == 0 {
+ return ""
+ }
+ buf := new(bytes.Buffer)
+ for _, f := range flags {
+ fmt.Fprintf(buf, ":%v=%v", flagEscape(f.Name), flagEscape(f.Value))
+ }
+ return buf.String()[1:]
+}
+
+func deserializeFlags(value string) ([]Flag, error) {
+ if value == "" {
+ return nil, nil
+ }
+ var flags []Flag
+ for _, arg := range strings.Split(value, ":") {
+ eq := strings.IndexByte(arg, '=')
+ if eq == -1 {
+ return nil, fmt.Errorf("failed to parse flags %q: no eq", value)
+ }
+ name, err := flagUnescape(arg[:eq])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse flags %q: %v", value, err)
+ }
+ value, err := flagUnescape(arg[eq+1:])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse flags %q: %v", value, err)
+ }
+ flags = append(flags, Flag{name, value})
+ }
+ return flags, nil
+}
+
+func flagEscape(s string) string {
+ buf := new(bytes.Buffer)
+ for i := 0; i < len(s); i++ {
+ ch := s[i]
+ if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' || ch == '\\' {
+ buf.Write([]byte{'\\', 'x'})
+ buf.WriteString(hex.EncodeToString([]byte{ch}))
+ continue
+ }
+ buf.WriteByte(ch)
+ }
+ return buf.String()
+}
+
+func flagUnescape(s string) (string, error) {
+ buf := new(bytes.Buffer)
+ for i := 0; i < len(s); i++ {
+ ch := s[i]
+ if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' {
+ return "", fmt.Errorf("unescaped char %v", ch)
+ }
+ if ch == '\\' {
+ if i+4 > len(s) || s[i+1] != 'x' {
+ return "", fmt.Errorf("truncated escape sequence")
+ }
+ res, err := hex.DecodeString(s[i+2 : i+4])
+ if err != nil {
+ return "", err
+ }
+ buf.WriteByte(res[0])
+ i += 3
+ continue
+ }
+ buf.WriteByte(ch)
+ }
+ return buf.String(), nil
+}