diff options
| -rw-r--r-- | prog/resources.go | 2 | ||||
| -rw-r--r-- | prog/resources_test.go | 2 | ||||
| -rw-r--r-- | prog/types.go | 30 |
3 files changed, 19 insertions, 15 deletions
diff --git a/prog/resources.go b/prog/resources.go index 95480a8a0..61b30581f 100644 --- a/prog/resources.go +++ b/prog/resources.go @@ -151,7 +151,7 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { } switch typ1 := typ.(type) { case *ResourceType: - if !typ1.IsOptional && !dedup[typ1.Desc] { + if !ctx.Optional && !dedup[typ1.Desc] { dedup[typ1.Desc] = true resources = append(resources, typ1.Desc) } diff --git a/prog/resources_test.go b/prog/resources_test.go index e48efa65c..f0ba81adf 100644 --- a/prog/resources_test.go +++ b/prog/resources_test.go @@ -153,7 +153,7 @@ func testCreateResource(t *testing.T, target *Target, calls map[*Syscall]bool, r if res, ok := typ.(*ResourceType); ok && ctx.Dir != DirOut { s := newState(target, ct, nil) arg, calls := r.createResource(s, res, DirIn) - if arg == nil && !res.Optional() { + if arg == nil && !ctx.Optional { t.Fatalf("failed to create resource %v", res.Name()) } if arg != nil && len(calls) == 0 { diff --git a/prog/types.go b/prog/types.go index 4fd6a9daa..45136f079 100644 --- a/prog/types.go +++ b/prog/types.go @@ -680,10 +680,11 @@ type ConstValue struct { } type TypeCtx struct { - Meta *Syscall - Dir Dir - Ptr *Type - Stop bool // If set by the callback, subtypes of this type are not visited. + Meta *Syscall + Dir Dir + Ptr *Type + Optional bool + Stop bool // If set by the callback, subtypes of this type are not visited. } func ForeachType(syscalls []*Syscall, f func(t Type, ctx *TypeCtx)) { @@ -707,9 +708,12 @@ func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx *TypeCtx)) // It would prune recursion more (across syscalls), but lots of users need to // visit each struct per-syscall (e.g. prio, used resources). seen := make(map[Type]bool) - var rec func(*Type, Dir) - rec = func(ptr *Type, dir Dir) { - ctx := &TypeCtx{Meta: meta, Dir: dir, Ptr: ptr} + var rec func(*Type, Dir, bool) + rec = func(ptr *Type, dir Dir, optional bool) { + if _, ref := (*ptr).(Ref); !ref { + optional = optional || (*ptr).Optional() + } + ctx := &TypeCtx{Meta: meta, Dir: dir, Ptr: ptr, Optional: optional} if preorder { f(*ptr, ctx) if ctx.Stop { @@ -718,16 +722,16 @@ func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx *TypeCtx)) } switch a := (*ptr).(type) { case *PtrType: - rec(&a.Elem, a.ElemDir) + rec(&a.Elem, a.ElemDir, optional) case *ArrayType: - rec(&a.Elem, dir) + rec(&a.Elem, dir, optional) case *StructType: if seen[a] { break // prune recursion via pointers to structs/unions } seen[a] = true for i, f := range a.Fields { - rec(&a.Fields[i].Type, f.Dir(dir)) + rec(&a.Fields[i].Type, f.Dir(dir), optional) } case *UnionType: if seen[a] { @@ -735,7 +739,7 @@ func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx *TypeCtx)) } seen[a] = true for i, f := range a.Fields { - rec(&a.Fields[i].Type, f.Dir(dir)) + rec(&a.Fields[i].Type, f.Dir(dir), optional) } case *ResourceType, *BufferType, *VmaType, *LenType, *FlagsType, *ConstType, *IntType, *ProcType, *CsumType: @@ -752,10 +756,10 @@ func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx *TypeCtx)) } } for i := range meta.Args { - rec(&meta.Args[i].Type, DirIn) + rec(&meta.Args[i].Type, DirIn, false) } if meta.Ret != nil { - rec(&meta.Ret, DirOut) + rec(&meta.Ret, DirOut, false) } } |
