diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/compiler/check.go | 15 | ||||
| -rw-r--r-- | pkg/compiler/compiler.go | 6 | ||||
| -rw-r--r-- | pkg/compiler/compiler_test.go | 89 |
3 files changed, 104 insertions, 6 deletions
diff --git a/pkg/compiler/check.go b/pkg/compiler/check.go index 087075a12..6ec2fa9e2 100644 --- a/pkg/compiler/check.go +++ b/pkg/compiler/check.go @@ -6,6 +6,7 @@ package compiler import ( + "errors" "fmt" "strings" @@ -372,10 +373,18 @@ func (comp *compiler) checkLenTarget(t *ast.Type, name, target string, fields [] comp.error(t.Pos, "%v target %v does not exist", t.Ident, target) } -func CollectUnused(desc *ast.Description, target *targets.Target) []ast.Node { - comp := createCompiler(desc, target, nil) +func CollectUnused(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) ([]ast.Node, error) { + comp := createCompiler(desc, target, eh) comp.typecheck() - return comp.collectUnused() + if comp.errors > 0 { + return nil, errors.New("typecheck failed") + } + + nodes := comp.collectUnused() + if comp.errors > 0 { + return nil, errors.New("collectUnused failed") + } + return nodes, nil } func (comp *compiler) collectUnused() []ast.Node { diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 2c4e81b2b..5abc8e2a8 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -44,6 +44,9 @@ type Prog struct { } func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) *compiler { + if eh == nil { + eh = ast.LoggingHandler + } comp := &compiler{ desc: desc, target: target, @@ -73,9 +76,6 @@ func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorH // Compile compiles sys description. func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Target, eh ast.ErrorHandler) *Prog { - if eh == nil { - eh = ast.LoggingHandler - } comp := createCompiler(desc.Clone(), target, eh) comp.typecheck() // The subsequent, more complex, checks expect basic validity of the tree, diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 65c064e97..7bb787f2c 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -9,6 +9,8 @@ import ( "fmt" "io/ioutil" "path/filepath" + "reflect" + "sort" "testing" "github.com/google/syzkaller/pkg/ast" @@ -205,3 +207,90 @@ s2 { got := p.StructDescs[0].Desc t.Logf("got: %#v", got) } + +func TestCollectUnusedError(t *testing.T) { + t.Parallel() + const input = ` + s0 { + f0 fidl_string + } + ` + nopErrorHandler := func(pos ast.Pos, msg string) {} + desc := ast.Parse([]byte(input), "input", nopErrorHandler) + if desc == nil { + t.Fatal("failed to parse") + } + + _, err := CollectUnused(desc, targets.List["test"]["64"], nopErrorHandler) + if err == nil { + t.Fatal("CollectUnused should have failed but didn't") + } +} + +func TestCollectUnused(t *testing.T) { + t.Parallel() + inputs := []struct { + text string + names []string + }{ + { + text: ` + s0 { + f0 string + } + `, + names: []string{"s0"}, + }, + { + text: ` + foo$0(a ptr[in, s0]) + s0 { + f0 int8 + f1 int16 + } + `, + names: []string{}, + }, + { + text: ` + s0 { + f0 int8 + f1 int16 + } + s1 { + f2 int32 + } + foo$0(a ptr[in, s0]) + `, + names: []string{"s1"}, + }, + } + + for i, input := range inputs { + desc := ast.Parse([]byte(input.text), "input", nil) + if desc == nil { + t.Fatalf("Test %d: failed to parse", i) + } + + nodes, err := CollectUnused(desc, targets.List["test"]["64"], nil) + if err != nil { + t.Fatalf("Test %d: CollectUnused failed: %v", i, err) + } + + if len(input.names) != len(nodes) { + t.Errorf("Test %d: want %d nodes, got %d", i, len(input.names), len(nodes)) + } + + names := make([]string, len(nodes)) + for i := range nodes { + _, _, names[i] = nodes[i].Info() + } + + sort.Strings(names) + sort.Strings(input.names) + + if !reflect.DeepEqual(names, input.names) { + t.Errorf("Test %d: Unused nodes differ. Want %v, Got %v", i, input.names, names) + } + } +} |
