aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/tool
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-12-16 08:57:04 +0100
committerDmitry Vyukov <dvyukov@google.com>2020-12-25 10:12:41 +0100
commit3bcdec13657598f6a6163c7ddecff58c2d3a2a71 (patch)
tree3ff27333aeecc6eb7be333bc4407647792968ff1 /pkg/tool
parent80795712865ca86bb21ebb9841598ccbcbd375c9 (diff)
pkg/tool: add package
Package tool contains various helper utilitites useful for implementation of command line tools. Currently it contains Fail/Failf functions that we commonly use and new support for optional command line flags.
Diffstat (limited to 'pkg/tool')
-rw-r--r--pkg/tool/flags.go132
-rw-r--r--pkg/tool/flags_fuzz.go28
-rw-r--r--pkg/tool/flags_test.go62
-rw-r--r--pkg/tool/tool.go19
4 files changed, 241 insertions, 0 deletions
diff --git a/pkg/tool/flags.go b/pkg/tool/flags.go
new file mode 100644
index 000000000..197494823
--- /dev/null
+++ b/pkg/tool/flags.go
@@ -0,0 +1,132 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package tool
+
+import (
+ "bytes"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/google/syzkaller/pkg/log"
+)
+
+// ParseFlags parses command line flags with flag.Parse and then applies optional flags created with OptionalFlags.
+// This is intended for programmatic use only when we invoke older versions of binaries with new unsupported flags.
+func ParseFlags() {
+ if err := parseFlags(flag.CommandLine, os.Args[1:]); err != nil {
+ Fail(err)
+ }
+}
+
+type Flag struct {
+ Name string
+ Value string
+}
+
+// OptionalFlags produces command line flag value that encapsulates the given flags as optional.
+// Use ParseFlags to support optional flags in the binary.
+// The format keeps flags reasonably readable ("-optional=foo=bar:baz=123"), not subject to accidental splitting
+// into multiple arguments due to spaces and supports bool/non-bool flags.
+func OptionalFlags(flags ...Flag) string {
+ return fmt.Sprintf("-%v=%v", optionalFlag, serializeFlags(flags))
+}
+
+func parseFlags(set *flag.FlagSet, args []string) error {
+ optional := set.String(optionalFlag, "", "optional flags for programmatic use only")
+ if err := set.Parse(args); err != nil {
+ return err
+ }
+ flags, err := deserializeFlags(*optional)
+ if err != nil {
+ return err
+ }
+ for _, f := range flags {
+ ff := set.Lookup(f.Name)
+ if ff == nil {
+ log.Logf(0, "ignoring optional flag %q=%q", f.Name, f.Value)
+ continue
+ }
+ if err := ff.Value.Set(f.Value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+const optionalFlag = "optional"
+
+func serializeFlags(flags []Flag) string {
+ if len(flags) == 0 {
+ return ""
+ }
+ buf := new(bytes.Buffer)
+ for _, f := range flags {
+ fmt.Fprintf(buf, ":%v=%v", flagEscape(f.Name), flagEscape(f.Value))
+ }
+ return buf.String()[1:]
+}
+
+func deserializeFlags(value string) ([]Flag, error) {
+ if value == "" {
+ return nil, nil
+ }
+ var flags []Flag
+ for _, arg := range strings.Split(value, ":") {
+ eq := strings.IndexByte(arg, '=')
+ if eq == -1 {
+ return nil, fmt.Errorf("failed to parse flags %q: no eq", value)
+ }
+ name, err := flagUnescape(arg[:eq])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse flags %q: %v", value, err)
+ }
+ value, err := flagUnescape(arg[eq+1:])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse flags %q: %v", value, err)
+ }
+ flags = append(flags, Flag{name, value})
+ }
+ return flags, nil
+}
+
+func flagEscape(s string) string {
+ buf := new(bytes.Buffer)
+ for i := 0; i < len(s); i++ {
+ ch := s[i]
+ if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' || ch == '\\' {
+ buf.Write([]byte{'\\', 'x'})
+ buf.WriteString(hex.EncodeToString([]byte{ch}))
+ continue
+ }
+ buf.WriteByte(ch)
+ }
+ return buf.String()
+}
+
+func flagUnescape(s string) (string, error) {
+ buf := new(bytes.Buffer)
+ for i := 0; i < len(s); i++ {
+ ch := s[i]
+ if ch <= 0x20 || ch >= 0x7f || ch == ':' || ch == '=' {
+ return "", fmt.Errorf("unescaped char %v", ch)
+ }
+ if ch == '\\' {
+ if i+4 > len(s) || s[i+1] != 'x' {
+ return "", fmt.Errorf("truncated escape sequence")
+ }
+ res, err := hex.DecodeString(s[i+2 : i+4])
+ if err != nil {
+ return "", err
+ }
+ buf.WriteByte(res[0])
+ i += 3
+ continue
+ }
+ buf.WriteByte(ch)
+ }
+ return buf.String(), nil
+}
diff --git a/pkg/tool/flags_fuzz.go b/pkg/tool/flags_fuzz.go
new file mode 100644
index 000000000..3f0bb010a
--- /dev/null
+++ b/pkg/tool/flags_fuzz.go
@@ -0,0 +1,28 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package tool
+
+import (
+ "reflect"
+ "strings"
+)
+
+func FuzzParseFlags(data []byte) int {
+ flags, err := deserializeFlags(string(data))
+ if err != nil {
+ return 0
+ }
+ value := serializeFlags(flags)
+ if strings.IndexByte(value, ' ') != -1 {
+ panic("flags contain space")
+ }
+ flags1, err := deserializeFlags(value)
+ if err != nil {
+ panic(err)
+ }
+ if !reflect.DeepEqual(flags, flags1) {
+ panic("changed")
+ }
+ return 1
+}
diff --git a/pkg/tool/flags_test.go b/pkg/tool/flags_test.go
new file mode 100644
index 000000000..6d52a4780
--- /dev/null
+++ b/pkg/tool/flags_test.go
@@ -0,0 +1,62 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package tool
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestParseFlags(t *testing.T) {
+ type Values struct {
+ Foo bool
+ Bar int
+ Baz string
+ }
+ type Test struct {
+ args string
+ vals *Values
+ }
+ tests := []Test{
+ {"", &Values{false, 1, "baz"}},
+ {"-foo -bar=2", &Values{true, 2, "baz"}},
+ {"-foo -bar=2 -qux", nil},
+ {"-foo -bar=2 " + OptionalFlags(Flag{"qux", ""}), &Values{true, 2, "baz"}},
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ vals := new(Values)
+ flags := flag.NewFlagSet("", flag.ContinueOnError)
+ flags.SetOutput(ioutil.Discard)
+ flags.BoolVar(&vals.Foo, "foo", false, "")
+ flags.IntVar(&vals.Bar, "bar", 1, "")
+ flags.StringVar(&vals.Baz, "baz", "baz", "")
+ args := append(strings.Split(test.args, " "), "arg0", "arg1")
+ if args[0] == "" {
+ args = args[1:]
+ }
+ err := parseFlags(flags, args)
+ if test.vals == nil {
+ if err == nil {
+ t.Fatalf("parsing did not fail")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("parsing failed: %v", err)
+ }
+ if diff := cmp.Diff(test.vals, vals); diff != "" {
+ t.Fatal(diff)
+ }
+ if flags.NArg() != 2 || flags.Arg(0) != "arg0" || flags.Arg(1) != "arg1" {
+ t.Fatalf("bad args: %q", flags.Args())
+ }
+ })
+ }
+}
diff --git a/pkg/tool/tool.go b/pkg/tool/tool.go
new file mode 100644
index 000000000..bb1f436cc
--- /dev/null
+++ b/pkg/tool/tool.go
@@ -0,0 +1,19 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+// Package tool contains various helper utilitites useful for implementation of command line tools.
+package tool
+
+import (
+ "fmt"
+ "os"
+)
+
+func Failf(msg string, args ...interface{}) {
+ fmt.Fprintf(os.Stderr, msg+"\n", args...)
+ os.Exit(1)
+}
+
+func Fail(err error) {
+ Failf("%v", err)
+}