aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/osutil
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-09-15 09:22:26 +0200
committerDmitry Vyukov <dvyukov@google.com>2020-09-15 09:37:22 +0200
commit9e681632f52e946a9adebdf8d12fbb814b6d2653 (patch)
treef686cb7f5c6d2a30bec4d367d66bb1c625137eff /pkg/osutil
parenta2360d0742f01e40bf4fb1714de4503f8a82aa3f (diff)
pkg/osutil: fix LinkFiles/FilesExist for the new pattern format
Diffstat (limited to 'pkg/osutil')
-rw-r--r--pkg/osutil/osutil.go40
-rw-r--r--pkg/osutil/osutil_test.go81
2 files changed, 64 insertions, 57 deletions
diff --git a/pkg/osutil/osutil.go b/pkg/osutil/osutil.go
index cd6a1cce7..b50b6e45b 100644
--- a/pkg/osutil/osutil.go
+++ b/pkg/osutil/osutil.go
@@ -130,11 +130,12 @@ func IsAccessible(name string) error {
// FilesExist returns true if all files exist in dir.
// Files are assumed to be relative names in slash notation.
func FilesExist(dir string, files map[string]bool) bool {
- for f, required := range files {
+ for pattern, required := range files {
if !required {
continue
}
- if !IsExist(filepath.Join(dir, filepath.FromSlash(f))) {
+ files, err := filepath.Glob(filepath.Join(dir, filepath.FromSlash(pattern)))
+ if err != nil || len(files) == 0 {
return false
}
}
@@ -154,7 +155,18 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
if err := MkdirAll(tmpDir); err != nil {
return err
}
+ if err := foreachPatternFile(srcDir, tmpDir, files, CopyFile); err != nil {
+ return err
+ }
+ if err := os.RemoveAll(dstDir); err != nil {
+ return err
+ }
+ return os.Rename(tmpDir, dstDir)
+}
+
+func foreachPatternFile(srcDir, dstDir string, files map[string]bool, fn func(src, dst string) error) error {
srcDir = filepath.Clean(srcDir)
+ dstDir = filepath.Clean(dstDir)
for pattern, required := range files {
files, err := filepath.Glob(filepath.Join(srcDir, filepath.FromSlash(pattern)))
if err != nil {
@@ -170,19 +182,16 @@ func CopyFiles(srcDir, dstDir string, files map[string]bool) error {
if !strings.HasPrefix(file, srcDir) {
return fmt.Errorf("file %q matched from %q in %q doesn't have src prefix", file, pattern, srcDir)
}
- dst := filepath.Join(tmpDir, strings.TrimPrefix(file, srcDir))
+ dst := filepath.Join(dstDir, strings.TrimPrefix(file, srcDir))
if err := MkdirAll(filepath.Dir(dst)); err != nil {
return err
}
- if err := CopyFile(file, dst); err != nil {
+ if err := fn(file, dst); err != nil {
return err
}
}
}
- if err := os.RemoveAll(dstDir); err != nil {
- return err
- }
- return os.Rename(tmpDir, dstDir)
+ return nil
}
func CopyDirRecursively(srcDir, dstDir string) error {
@@ -219,20 +228,7 @@ func LinkFiles(srcDir, dstDir string, files map[string]bool) error {
if err := MkdirAll(dstDir); err != nil {
return err
}
- for f, required := range files {
- src := filepath.Join(srcDir, filepath.FromSlash(f))
- if !required && !IsExist(src) {
- continue
- }
- dst := filepath.Join(dstDir, filepath.FromSlash(f))
- if err := MkdirAll(filepath.Dir(dst)); err != nil {
- return err
- }
- if err := os.Link(src, dst); err != nil {
- return err
- }
- }
- return nil
+ return foreachPatternFile(srcDir, dstDir, files, os.Link)
}
func MkdirAll(dir string) error {
diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go
index 3d367a0ca..168e7ba36 100644
--- a/pkg/osutil/osutil_test.go
+++ b/pkg/osutil/osutil_test.go
@@ -73,42 +73,53 @@ func TestCopyFiles(t *testing.T) {
},
},
}
- for i, test := range tests {
- t.Run(fmt.Sprint(i), func(t *testing.T) {
- dir, err := ioutil.TempDir("", "syz-osutil-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(dir)
- src := filepath.Join(dir, "src")
- dst := filepath.Join(dir, "dst")
- for _, file := range test.files {
- file = filepath.Join(src, filepath.FromSlash(file))
- if err := MkdirAll(filepath.Dir(file)); err != nil {
- t.Fatal(err)
- }
- if err := WriteFile(file, []byte{'a'}); err != nil {
- t.Fatal(err)
- }
- }
- if err := CopyFiles(src, dst, test.patterns); err != nil {
- if test.err != "" {
- if strings.Contains(err.Error(), test.err) {
- return
+ for _, link := range []bool{false, true} {
+ fn, fnName := CopyFiles, "CopyFiles"
+ if link {
+ fn, fnName = LinkFiles, "LinkFiles"
+ }
+ t.Run(fnName, func(t *testing.T) {
+ for i, test := range tests {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ dir, err := ioutil.TempDir("", "syz-osutil-test")
+ if err != nil {
+ t.Fatal(err)
}
- t.Fatalf("got err %q, want %q", err, test.err)
- }
- t.Fatal(err)
- } else if test.err != "" {
- t.Fatalf("got no err, want %q", test.err)
- }
- if err := os.RemoveAll(src); err != nil {
- t.Fatal(err)
- }
- for _, file := range test.files {
- if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) {
- t.Fatalf("%v does not exist in dst", file)
- }
+ defer os.RemoveAll(dir)
+ src := filepath.Join(dir, "src")
+ dst := filepath.Join(dir, "dst")
+ for _, file := range test.files {
+ file = filepath.Join(src, filepath.FromSlash(file))
+ if err := MkdirAll(filepath.Dir(file)); err != nil {
+ t.Fatal(err)
+ }
+ if err := WriteFile(file, []byte{'a'}); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := fn(src, dst, test.patterns); err != nil {
+ if test.err != "" {
+ if strings.Contains(err.Error(), test.err) {
+ return
+ }
+ t.Fatalf("got err %q, want %q", err, test.err)
+ }
+ t.Fatal(err)
+ } else if test.err != "" {
+ t.Fatalf("got no err, want %q", test.err)
+ }
+ if err := os.RemoveAll(src); err != nil {
+ t.Fatal(err)
+ }
+ for _, file := range test.files {
+ if !IsExist(filepath.Join(dst, filepath.FromSlash(file))) {
+ t.Fatalf("%v does not exist in dst", file)
+ }
+ }
+ if !FilesExist(dst, test.patterns) {
+ t.Fatalf("dst files don't exist after copy")
+ }
+ })
}
})
}