diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2020-12-16 08:57:04 +0100 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2020-12-25 10:12:41 +0100 |
| commit | 3bcdec13657598f6a6163c7ddecff58c2d3a2a71 (patch) | |
| tree | 3ff27333aeecc6eb7be333bc4407647792968ff1 /pkg/tool/flags.go | |
| parent | 80795712865ca86bb21ebb9841598ccbcbd375c9 (diff) | |
pkg/tool: add package
Package tool contains various helper utilitites useful for implementation of command line tools.
Currently it contains Fail/Failf functions that we commonly use
and new support for optional command line flags.
Diffstat (limited to 'pkg/tool/flags.go')
| -rw-r--r-- | pkg/tool/flags.go | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/pkg/tool/flags.go b/pkg/tool/flags.go new file mode 100644 index 000000000..197494823 --- /dev/null +++ b/pkg/tool/flags.go @@ -0,0 +1,132 @@ +// Copyright 2020 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package tool + +import ( + "bytes" + "encoding/hex" + "flag" + "fmt" + "os" + "strings" + + "github.com/google/syzkaller/pkg/log" +) + +// ParseFlags parses command line flags with flag.Parse and then applies optional flags created with OptionalFlags. +// This is intended for programmatic use only when we invoke older versions of binaries with new unsupported flags. +func ParseFlags() { + if err := parseFlags(flag.CommandLine, os.Args[1:]); err != nil { + Fail(err) + } +} + +type Flag struct { + Name string + Value string +} + +// OptionalFlags produces command line flag value that encapsulates the given flags as optional. +// Use ParseFlags to support optional flags in the binary. +// The format keeps flags reasonably readable ("-optional=foo=bar:baz=123"), not subject to accidental splitting +// into multiple arguments due to spaces and supports bool/non-bool flags. +func OptionalFlags(flags ...Flag) string { + return fmt.Sprintf("-%v=%v", optionalFlag, serializeFlags(flags)) +} + +func parseFlags(set *flag.FlagSet, args []string) error { + optional := set.String(optionalFlag, "", "optional flags for programmatic use only") + if err := set.Parse(args); err != nil { + return err + } + flags, err := deserializeFlags(*optional) + if err != nil { + return err + } + for _, f := range flags { + ff := set.Lookup(f.Name) + if ff == nil { + log.Logf(0, "ignoring optional flag %q=%q", f.Name, f.Value) + continue + } + if err := ff.Value.Set(f.Value); err != nil { + return err + } + } + return nil +} + +const optionalFlag = "optional" + +func serializeFlags(flags []Flag) string { + if len(flags) == 0 { + return "" + } + buf := new(bytes.Buffer) + for _, f := range flags { + fmt.Fprintf(buf, ":%v=%v", flagEscape(f.Name), flagEscape(f.Value)) + } + return buf.String()[1:] +} + +func deserializeFlags(value string) ([]Flag, error) { + if value == "" { + return nil, nil + } + var flags []Flag + for _, arg := range strings.Split(value, ":") { + eq := strings.IndexByte(arg, '=') + if eq == -1 { + return nil, fmt.Errorf("failed to parse flags %q: no eq", value) + } + name, err := flagUnescape(arg[:eq]) + if err != nil { + return nil, fmt.Errorf("failed to parse flags %q: %v", value, err) + } + value, err := flagUnescape(arg[eq+1:]) + if err != nil { + return nil, fmt.Errorf("failed to parse flags %q: %v", value, err) + } + flags = append(flags, Flag{name, value}) + } + return flags, nil +} + +func flagEscape(s string) string { + buf := new(bytes.Buffer) + for i := 0; i < len(s); i++ { + ch := s[i] + if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' || ch == '\\' { + buf.Write([]byte{'\\', 'x'}) + buf.WriteString(hex.EncodeToString([]byte{ch})) + continue + } + buf.WriteByte(ch) + } + return buf.String() +} + +func flagUnescape(s string) (string, error) { + buf := new(bytes.Buffer) + for i := 0; i < len(s); i++ { + ch := s[i] + if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' { + return "", fmt.Errorf("unescaped char %v", ch) + } + if ch == '\\' { + if i+4 > len(s) || s[i+1] != 'x' { + return "", fmt.Errorf("truncated escape sequence") + } + res, err := hex.DecodeString(s[i+2 : i+4]) + if err != nil { + return "", err + } + buf.WriteByte(res[0]) + i += 3 + continue + } + buf.WriteByte(ch) + } + return buf.String(), nil +} |
