diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2020-09-14 08:31:24 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2020-09-14 10:38:36 +0200 |
| commit | bf14d79b03223b0a9178c69b033355a73b5ed7b2 (patch) | |
| tree | 3019582f607d4b141b17ffb9839d160eef752ab1 /pkg/osutil | |
| parent | fab7609913c9787bdb79602ff716f5e0d1598c98 (diff) | |
pkg/osutil: support glob patterns in CopyFiles
Follow up to #2053
Diffstat (limited to 'pkg/osutil')
| -rw-r--r-- | pkg/osutil/osutil.go | 33 | ||||
| -rw-r--r-- | pkg/osutil/osutil_test.go | 97 |
2 files changed, 120 insertions, 10 deletions
diff --git a/pkg/osutil/osutil.go b/pkg/osutil/osutil.go index 1eafafcb9..340c51de5 100644 --- a/pkg/osutil/osutil.go +++ b/pkg/osutil/osutil.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "syscall" "time" ) @@ -141,7 +142,7 @@ func FilesExist(dir string, files map[string]bool) bool { } // CopyFiles copies files from srcDir to dstDir as atomically as possible. -// Files are assumed to be relative names in slash notation. +// Files are assumed to be relative glob patterns in slash notation in srcDir. // All other files in dstDir are removed. func CopyFiles(srcDir, dstDir string, files map[string]bool) error { // Linux does not support atomic dir replace, so we copy to tmp dir first. @@ -153,17 +154,29 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error { if err := MkdirAll(tmpDir); err != nil { return err } - for f, required := range files { - src := filepath.Join(srcDir, filepath.FromSlash(f)) - if !required && !IsExist(src) { - continue - } - dst := filepath.Join(tmpDir, filepath.FromSlash(f)) - if err := MkdirAll(filepath.Dir(dst)); err != nil { + srcDir = filepath.Clean(srcDir) + for pattern, required := range files { + files, err := filepath.Glob(filepath.Join(srcDir, filepath.FromSlash(pattern))) + if err != nil { return err } - if err := CopyFile(src, dst); err != nil { - return err + if len(files) == 0 { + if !required { + continue + } + return fmt.Errorf("file %v does not exist", pattern) + } + for _, file := range files { + if !strings.HasPrefix(file, srcDir) { + return fmt.Errorf("file %q matched from %q in %q doesn't have src prefix", file, pattern, srcDir) + } + dst := filepath.Join(tmpDir, strings.TrimPrefix(file, srcDir)) + if err := MkdirAll(filepath.Dir(dst)); err != nil { + return err + } + if err := CopyFile(file, dst); err != nil { + return err + } } } if err := os.RemoveAll(dstDir); err != nil { diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go index ee4cb6705..3d367a0ca 100644 --- a/pkg/osutil/osutil_test.go +++ b/pkg/osutil/osutil_test.go @@ -4,7 +4,11 @@ package osutil import ( + "fmt" + "io/ioutil" "os" + "path/filepath" + "strings" "testing" ) @@ -16,3 +20,96 @@ func TestIsExist(t *testing.T) { t.Fatalf("file %v exists", f) } } + +func TestCopyFiles(t *testing.T) { + type Test struct { + files []string + patterns map[string]bool + err string + } + tests := []Test{ + { + files: []string{ + "foo", + "bar", + "baz/foo", + "baz/bar", + }, + patterns: map[string]bool{ + "foo": true, + "bar": false, + "qux": false, + "baz/foo": true, + "baz/bar": false, + }, + }, + { + files: []string{ + "foo", + }, + patterns: map[string]bool{ + "bar": true, + }, + err: "file bar does not exist", + }, + { + files: []string{ + "baz/foo", + "baz/bar", + }, + patterns: map[string]bool{ + "baz/*": true, + }, + }, + { + files: []string{ + "qux/foo/foo", + "qux/foo/bar", + "qux/bar/foo", + "qux/bar/bar", + }, + patterns: map[string]bool{ + "qux/*/*": false, + }, + }, + } + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + dir, err := ioutil.TempDir("", "syz-osutil-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + src := filepath.Join(dir, "src") + dst := filepath.Join(dir, "dst") + for _, file := range test.files { + file = filepath.Join(src, filepath.FromSlash(file)) + if err := MkdirAll(filepath.Dir(file)); err != nil { + t.Fatal(err) + } + if err := WriteFile(file, []byte{'a'}); err != nil { + t.Fatal(err) + } + } + if err := CopyFiles(src, dst, test.patterns); err != nil { + if test.err != "" { + if strings.Contains(err.Error(), test.err) { + return + } + t.Fatalf("got err %q, want %q", err, test.err) + } + t.Fatal(err) + } else if test.err != "" { + t.Fatalf("got no err, want %q", test.err) + } + if err := os.RemoveAll(src); err != nil { + t.Fatal(err) + } + for _, file := range test.files { + if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) { + t.Fatalf("%v does not exist in dst", file) + } + } + }) + } +} |
