// Copyright 2024 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. package clangtool import ( "bytes" "crypto/sha256" "encoding/json" "errors" "fmt" "io" "math/rand" "os" "os/exec" "path/filepath" "runtime" "slices" "strings" "time" "github.com/google/syzkaller/pkg/osutil" ) type Config struct { ToolBin string KernelSrc string KernelObj string CacheFile string DebugTrace io.Writer } type OutputDataPtr[T any] interface { *T Merge(*T) SetSourceFile(string, func(filename string) string) Finalize(*Verifier) } // Run runs the clang tool on all files in the compilation database // in the kernel build dir and returns combined output for all files. // It always caches results, and optionally reuses previously cached results. func Run[Output any, OutputPtr OutputDataPtr[Output]](cfg *Config) (OutputPtr, error) { if cfg.CacheFile != "" { out, err := osutil.ReadJSON[OutputPtr](cfg.CacheFile) if err == nil { return out, nil } } dbFile := filepath.Join(cfg.KernelObj, "compile_commands.json") cmds, err := loadCompileCommands(dbFile) if err != nil { return nil, fmt.Errorf("failed to load compile commands: %w", err) } type result struct { out OutputPtr err error } results := make(chan *result, 10) files := make(chan string, len(cmds)) for w := 0; w < runtime.NumCPU(); w++ { go func() { for file := range files { out, err := runTool[Output, OutputPtr](cfg, dbFile, file) results <- &result{out, err} } }() } for _, cmd := range cmds { files <- cmd.File } close(files) out := OutputPtr(new(Output)) for range cmds { res := <-results if res.err != nil { return nil, res.err } out.Merge(res.out) } // Finalize the output (sort, dedup, etc), and let the output verify // that all source file names, line numbers, etc are valid/present. // If there are any bogus entries, it's better to detect them early, // than to crash/error much later when the info is used. // Some of the source files (generated) may be in the obj dir. srcDirs := []string{cfg.KernelSrc, cfg.KernelObj} if err := Finalize(out, srcDirs); err != nil { return nil, err } if cfg.CacheFile != "" { osutil.MkdirAll(filepath.Dir(cfg.CacheFile)) data, err := json.MarshalIndent(out, "", "\t") if err != nil { return nil, fmt.Errorf("failed to marshal output data: %w", err) } if err := osutil.WriteFile(cfg.CacheFile, data); err != nil { return nil, err } } return out, nil } func Finalize[Output any, OutputPtr OutputDataPtr[Output]](out OutputPtr, srcDirs []string) error { v := &Verifier{ srcDirs: srcDirs, fileCache: make(map[string]int), } out.Finalize(v) if v.err.Len() == 0 { return nil } return errors.New(v.err.String()) } type Verifier struct { srcDirs []string fileCache map[string]int // file->line count (-1 is cached for missing files) err strings.Builder } func (v *Verifier) Filename(file string) { if _, ok := v.fileCache[file]; ok { return } for _, srcDir := range v.srcDirs { data, err := os.ReadFile(filepath.Join(srcDir, file)) if err != nil { continue } v.fileCache[file] = len(bytes.Split(data, []byte{'\n'})) return } v.fileCache[file] = -1 fmt.Fprintf(&v.err, "missing file: %v\n", file) } func (v *Verifier) LineRange(file string, start, end int) { v.Filename(file) lines, ok := v.fileCache[file] if !ok || lines < 0 { return } // Line numbers produced by clang are 1-based. if start <= 0 || end < start || end > lines { fmt.Fprintf(&v.err, "bad line range [%v-%v] for file %v with %v lines\n", start, end, file, lines) } } func runTool[Output any, OutputPtr OutputDataPtr[Output]](cfg *Config, dbFile, file string) (OutputPtr, error) { relFile := strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(filepath.Clean(file), cfg.KernelSrc), cfg.KernelObj), "/") // Suppress warning since we may build the tool on a different clang // version that produces more warnings. // Comments are needed for codesearch tool, but may be useful for declextract // in the future if we try to parse them with LLMs. data, err := exec.Command(cfg.ToolBin, "-p", dbFile, "--extra-arg=-w", "--extra-arg=-fparse-all-comments", file).Output() if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { err = fmt.Errorf("%v: %w\n%s", relFile, err, exitErr.Stderr) } return nil, err } out, err := osutil.ParseJSON[OutputPtr](data) if err != nil { return nil, err } // All includes in the tool output are relative to the build dir. // Make them relative to the source dir. out.SetSourceFile(relFile, func(filename string) string { rel, err := filepath.Rel(cfg.KernelSrc, filepath.Join(cfg.KernelObj, filename)) if err == nil && filename != "" { return rel } return filename }) return out, nil } type compileCommand struct { Command string Directory string File string } func loadCompileCommands(dbFile string) ([]compileCommand, error) { data, err := os.ReadFile(dbFile) if err != nil { return nil, err } var cmds []compileCommand if err := json.Unmarshal(data, &cmds); err != nil { return nil, err } // Remove commands that don't relate to the kernel build // (probably some host tools, etc). cmds = slices.DeleteFunc(cmds, func(cmd compileCommand) bool { return !strings.HasSuffix(cmd.File, ".c") || // Files compiled with gcc are not a part of the kernel // (assuming compile commands were generated with make CC=clang). // They are probably a part of some host tool. strings.HasPrefix(cmd.Command, "gcc") || // KBUILD should add this define all kernel files. !strings.Contains(cmd.Command, "-DKBUILD_BASENAME") }) // Shuffle the order to detect any non-determinism caused by the order early. // The result should be the same regardless. rand.New(rand.NewSource(time.Now().UnixNano())).Shuffle(len(cmds), func(i, j int) { cmds[i], cmds[j] = cmds[j], cmds[i] }) if len(cmds) == 0 { return nil, fmt.Errorf("no kernel compile commands in compile_commands.json" + " (was the kernel compiled with gcc?)") } return cmds, nil } func SortAndDedupSlice[Slice ~[]E, E comparable](s Slice) Slice { dedup := make(map[[sha256.Size]byte]E) text := make(map[E][]byte) for _, e := range s { t, _ := json.Marshal(e) dedup[sha256.Sum256(t)] = e text[e] = t } s = make([]E, 0, len(dedup)) for _, e := range dedup { s = append(s, e) } slices.SortFunc(s, func(a, b E) int { return bytes.Compare(text[a], text[b]) }) return s }