aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/osutil
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-09-14 08:31:24 +0200
committerDmitry Vyukov <dvyukov@google.com>2020-09-14 10:38:36 +0200
commitbf14d79b03223b0a9178c69b033355a73b5ed7b2 (patch)
tree3019582f607d4b141b17ffb9839d160eef752ab1 /pkg/osutil
parentfab7609913c9787bdb79602ff716f5e0d1598c98 (diff)
pkg/osutil: support glob patterns in CopyFiles
Follow up to #2053
Diffstat (limited to 'pkg/osutil')
-rw-r--r--pkg/osutil/osutil.go33
-rw-r--r--pkg/osutil/osutil_test.go97
2 files changed, 120 insertions, 10 deletions
diff --git a/pkg/osutil/osutil.go b/pkg/osutil/osutil.go
index 1eafafcb9..340c51de5 100644
--- a/pkg/osutil/osutil.go
+++ b/pkg/osutil/osutil.go
@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "strings"
"syscall"
"time"
)
@@ -141,7 +142,7 @@ func FilesExist(dir string, files map[string]bool) bool {
}
// CopyFiles copies files from srcDir to dstDir as atomically as possible.
-// Files are assumed to be relative names in slash notation.
+// Files are assumed to be relative glob patterns in slash notation in srcDir.
// All other files in dstDir are removed.
func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
// Linux does not support atomic dir replace, so we copy to tmp dir first.
@@ -153,17 +154,29 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
if err := MkdirAll(tmpDir); err != nil {
return err
}
- for f, required := range files {
- src := filepath.Join(srcDir, filepath.FromSlash(f))
- if !required && !IsExist(src) {
- continue
- }
- dst := filepath.Join(tmpDir, filepath.FromSlash(f))
- if err := MkdirAll(filepath.Dir(dst)); err != nil {
+ srcDir = filepath.Clean(srcDir)
+ for pattern, required := range files {
+ files, err := filepath.Glob(filepath.Join(srcDir, filepath.FromSlash(pattern)))
+ if err != nil {
return err
}
- if err := CopyFile(src, dst); err != nil {
- return err
+ if len(files) == 0 {
+ if !required {
+ continue
+ }
+ return fmt.Errorf("file %v does not exist", pattern)
+ }
+ for _, file := range files {
+ if !strings.HasPrefix(file, srcDir) {
+ return fmt.Errorf("file %q matched from %q in %q doesn't have src prefix", file, pattern, srcDir)
+ }
+ dst := filepath.Join(tmpDir, strings.TrimPrefix(file, srcDir))
+ if err := MkdirAll(filepath.Dir(dst)); err != nil {
+ return err
+ }
+ if err := CopyFile(file, dst); err != nil {
+ return err
+ }
}
}
if err := os.RemoveAll(dstDir); err != nil {
diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go
index ee4cb6705..3d367a0ca 100644
--- a/pkg/osutil/osutil_test.go
+++ b/pkg/osutil/osutil_test.go
@@ -4,7 +4,11 @@
package osutil
import (
+ "fmt"
+ "io/ioutil"
"os"
+ "path/filepath"
+ "strings"
"testing"
)
@@ -16,3 +20,96 @@ func TestIsExist(t *testing.T) {
t.Fatalf("file %v exists", f)
}
}
+
+func TestCopyFiles(t *testing.T) {
+ type Test struct {
+ files []string
+ patterns map[string]bool
+ err string
+ }
+ tests := []Test{
+ {
+ files: []string{
+ "foo",
+ "bar",
+ "baz/foo",
+ "baz/bar",
+ },
+ patterns: map[string]bool{
+ "foo": true,
+ "bar": false,
+ "qux": false,
+ "baz/foo": true,
+ "baz/bar": false,
+ },
+ },
+ {
+ files: []string{
+ "foo",
+ },
+ patterns: map[string]bool{
+ "bar": true,
+ },
+ err: "file bar does not exist",
+ },
+ {
+ files: []string{
+ "baz/foo",
+ "baz/bar",
+ },
+ patterns: map[string]bool{
+ "baz/*": true,
+ },
+ },
+ {
+ files: []string{
+ "qux/foo/foo",
+ "qux/foo/bar",
+ "qux/bar/foo",
+ "qux/bar/bar",
+ },
+ patterns: map[string]bool{
+ "qux/*/*": false,
+ },
+ },
+ }
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ dir, err := ioutil.TempDir("", "syz-osutil-test")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(dir)
+ src := filepath.Join(dir, "src")
+ dst := filepath.Join(dir, "dst")
+ for _, file := range test.files {
+ file = filepath.Join(src, filepath.FromSlash(file))
+ if err := MkdirAll(filepath.Dir(file)); err != nil {
+ t.Fatal(err)
+ }
+ if err := WriteFile(file, []byte{'a'}); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := CopyFiles(src, dst, test.patterns); err != nil {
+ if test.err != "" {
+ if strings.Contains(err.Error(), test.err) {
+ return
+ }
+ t.Fatalf("got err %q, want %q", err, test.err)
+ }
+ t.Fatal(err)
+ } else if test.err != "" {
+ t.Fatalf("got no err, want %q", test.err)
+ }
+ if err := os.RemoveAll(src); err != nil {
+ t.Fatal(err)
+ }
+ for _, file := range test.files {
+ if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) {
+ t.Fatalf("%v does not exist in dst", file)
+ }
+ }
+ })
+ }
+}