// Copyright 2025 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. package tooltest import ( "encoding/json" "flag" "fmt" "os" "path/filepath" "testing" "github.com/google/go-cmp/cmp" "github.com/google/syzkaller/pkg/clangtool" "github.com/google/syzkaller/pkg/osutil" "github.com/google/syzkaller/pkg/testutil" ) var ( FlagBin = flag.String("bin", "", "path to the clang tool binary to use") FlagUpdate = flag.Bool("update", false, "update golden files") ) func TestClangTool[Output any, OutputPtr clangtool.OutputDataPtr[Output]](t *testing.T) { if *FlagBin == "" { t.Skipf("clang tool path is not specified, run with -bin=clangtool flag") } ForEachTestFile(t, func(t *testing.T, cfg *clangtool.Config, file string) { out, err := clangtool.Run[Output, OutputPtr](cfg) if err != nil { t.Fatal(err) } got, err := json.MarshalIndent(out, "", "\t") if err != nil { t.Fatal(err) } CompareGoldenData(t, file+".json", got) }) } func LoadOutput[Output any, OutputPtr clangtool.OutputDataPtr[Output]](t *testing.T) OutputPtr { out := OutputPtr(new(Output)) forEachTestFile(t, func(t *testing.T, file string) { tmp, err := osutil.ReadJSON[OutputPtr](file + ".json") if err != nil { t.Fatal(err) } out.Merge(tmp) }) if err := clangtool.Finalize(out, []string{"testdata"}); err != nil { t.Fatal(err) } return out } func ForEachTestFile(t *testing.T, fn func(t *testing.T, cfg *clangtool.Config, file string)) { forEachTestFile(t, func(t *testing.T, file string) { t.Run(filepath.Base(file), func(t *testing.T) { t.Parallel() buildDir := t.TempDir() commands := fmt.Sprintf(`[{ "file": "%s", "directory": "%s", "command": "clang -c %s -DKBUILD_BASENAME=foo" }]`, file, buildDir, file) dbFile := filepath.Join(buildDir, "compile_commands.json") if err := os.WriteFile(dbFile, []byte(commands), 0600); err != nil { t.Fatal(err) } cfg := &clangtool.Config{ ToolBin: *FlagBin, KernelSrc: osutil.Abs("testdata"), KernelObj: buildDir, CacheFile: filepath.Join(buildDir, filepath.Base(file)+".json"), DebugTrace: &testutil.Writer{TB: t}, } fn(t, cfg, file) }) }) } func forEachTestFile(t *testing.T, fn func(t *testing.T, file string)) { files, err := filepath.Glob(filepath.Join(osutil.Abs("testdata"), "*.c")) if err != nil { t.Fatal(err) } if len(files) == 0 { t.Fatal("found no source files") } for _, file := range files { fn(t, file) } } func CompareGoldenFile(t *testing.T, goldenFile, gotFile string) { got, err := os.ReadFile(gotFile) if err != nil { t.Fatal(err) } CompareGoldenData(t, goldenFile, got) } func CompareGoldenData(t *testing.T, goldenFile string, got []byte) { if *FlagUpdate { if err := os.WriteFile(goldenFile, got, 0644); err != nil { t.Fatal(err) } } want, err := os.ReadFile(goldenFile) if err != nil { t.Fatal(err) } if diff := cmp.Diff(want, got); diff != "" { t.Fatal(diff) } }