diff options
Diffstat (limited to 'prog/encoding.go')
| -rw-r--r-- | prog/encoding.go | 118 |
1 files changed, 60 insertions, 58 deletions
diff --git a/prog/encoding.go b/prog/encoding.go index f99ff9d84..c7f8ba56a 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -113,7 +113,7 @@ func (a *PointerArg) serialize(ctx *serializer) { func (a *DataArg) serialize(ctx *serializer) { typ := a.Type().(*BufferType) - if typ.Dir() == DirOut { + if a.Dir() == DirOut { ctx.printf("\"\"/%v", a.Size()) return } @@ -280,7 +280,7 @@ func (p *parser) parseProg() (*Prog, error) { if IsPad(typ) { return nil, fmt.Errorf("padding in syscall %v arguments", name) } - arg, err := p.parseArg(typ) + arg, err := p.parseArg(typ, DirIn) if err != nil { return nil, err } @@ -302,7 +302,7 @@ func (p *parser) parseProg() (*Prog, error) { } for i := len(c.Args); i < len(meta.Args); i++ { p.strictFailf("missing syscall args") - c.Args = append(c.Args, meta.Args[i].DefaultArg()) + c.Args = append(c.Args, meta.Args[i].DefaultArg(DirIn)) } if len(c.Args) != len(meta.Args) { return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args)) @@ -318,7 +318,7 @@ func (p *parser) parseProg() (*Prog, error) { return prog, nil } -func (p *parser) parseArg(typ Type) (Arg, error) { +func (p *parser) parseArg(typ Type, dir Dir) (Arg, error) { r := "" if p.Char() == '<' { p.Parse('<') @@ -326,13 +326,13 @@ func (p *parser) parseArg(typ Type) (Arg, error) { p.Parse('=') p.Parse('>') } - arg, err := p.parseArgImpl(typ) + arg, err := p.parseArgImpl(typ, dir) if err != nil { return nil, err } if arg == nil { if typ != nil { - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } else if r != "" { return nil, fmt.Errorf("named nil argument") } @@ -345,26 +345,26 @@ func (p *parser) parseArg(typ Type) (Arg, error) { return arg, nil } -func (p *parser) parseArgImpl(typ Type) (Arg, error) { +func (p *parser) parseArgImpl(typ Type, dir Dir) (Arg, error) { if typ == nil && p.Char() != 'n' { p.eatExcessive(true, "non-nil argument for nil type") return nil, nil } switch p.Char() { case '0': - return p.parseArgInt(typ) + return p.parseArgInt(typ, dir) case 'r': - return p.parseArgRes(typ) + return p.parseArgRes(typ, dir) case '&': - return p.parseArgAddr(typ) + return p.parseArgAddr(typ, dir) case '"', '\'': - return p.parseArgString(typ) + return p.parseArgString(typ, dir) case '{': - return p.parseArgStruct(typ) + return p.parseArgStruct(typ, dir) case '[': - return p.parseArgArray(typ) + return p.parseArgArray(typ, dir) case '@': - return p.parseArgUnion(typ) + return p.parseArgUnion(typ, dir) case 'n': p.Parse('n') p.Parse('i') @@ -375,14 +375,14 @@ func (p *parser) parseArgImpl(typ Type) (Arg, error) { p.Parse('U') p.Parse('T') p.Parse('O') - return p.parseAuto(typ) + return p.parseAuto(typ, dir) default: return nil, fmt.Errorf("failed to parse argument at '%c' (line #%v/%v: %v)", p.Char(), p.l, p.i, p.s) } } -func (p *parser) parseArgInt(typ Type) (Arg, error) { +func (p *parser) parseArgInt(typ Type, dir Dir) (Arg, error) { val := p.Ident() v, err := strconv.ParseUint(val, 0, 64) if err != nil { @@ -390,35 +390,35 @@ func (p *parser) parseArgInt(typ Type) (Arg, error) { } switch typ.(type) { case *ConstType, *IntType, *FlagsType, *ProcType, *CsumType: - arg := Arg(MakeConstArg(typ, v)) - if typ.Dir() == DirOut && !typ.isDefaultArg(arg) { + arg := Arg(MakeConstArg(typ, dir, v)) + if dir == DirOut && !typ.isDefaultArg(arg) { p.strictFailf("out arg %v has non-default value: %v", typ, v) - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } return arg, nil case *LenType: - return MakeConstArg(typ, v), nil + return MakeConstArg(typ, dir, v), nil case *ResourceType: - return MakeResultArg(typ, nil, v), nil + return MakeResultArg(typ, dir, nil, v), nil case *PtrType, *VmaType: index := -v % uint64(len(p.target.SpecialPointers)) - return MakeSpecialPointerArg(typ, index), nil + return MakeSpecialPointerArg(typ, dir, index), nil default: p.eatExcessive(true, "wrong int arg %T", typ) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } } -func (p *parser) parseAuto(typ Type) (Arg, error) { +func (p *parser) parseAuto(typ Type, dir Dir) (Arg, error) { switch typ.(type) { case *ConstType, *LenType, *CsumType: - return p.auto(MakeConstArg(typ, 0)), nil + return p.auto(MakeConstArg(typ, dir, 0)), nil default: return nil, fmt.Errorf("wrong type %T for AUTO", typ) } } -func (p *parser) parseArgRes(typ Type) (Arg, error) { +func (p *parser) parseArgRes(typ Type, dir Dir) (Arg, error) { id := p.Ident() var div, add uint64 if p.Char() == '/' { @@ -442,23 +442,25 @@ func (p *parser) parseArgRes(typ Type) (Arg, error) { v := p.vars[id] if v == nil { p.strictFailf("undeclared variable %v", id) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } - arg := MakeResultArg(typ, v, 0) + arg := MakeResultArg(typ, dir, v, 0) arg.OpDiv = div arg.OpAdd = add return arg, nil } -func (p *parser) parseArgAddr(typ Type) (Arg, error) { +func (p *parser) parseArgAddr(typ Type, dir Dir) (Arg, error) { var typ1 Type + elemDir := DirInOut switch t1 := typ.(type) { case *PtrType: typ1 = t1.Type + elemDir = t1.ElemDir case *VmaType: default: p.eatExcessive(true, "wrong addr arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('&') auto := false @@ -487,11 +489,11 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.Parse('N') p.Parse('Y') p.Parse('=') - typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) - typ1 = p.target.any.array + anyPtr := p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) + typ, typ1, elemDir = anyPtr, anyPtr.Type, anyPtr.ElemDir } var err error - inner, err = p.parseArg(typ1) + inner, err = p.parseArg(typ1, elemDir) if err != nil { return nil, err } @@ -501,23 +503,23 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.strictFailf("unaligned vma address 0x%x", addr) addr &= ^(p.target.PageSize - 1) } - return MakeVmaPointerArg(typ, addr, vmaSize), nil + return MakeVmaPointerArg(typ, dir, addr, vmaSize), nil } if inner == nil { - inner = typ1.DefaultArg() + inner = typ1.DefaultArg(elemDir) } - arg := MakePointerArg(typ, addr, inner) + arg := MakePointerArg(typ, dir, addr, inner) if auto { p.auto(arg) } return arg, nil } -func (p *parser) parseArgString(t Type) (Arg, error) { +func (p *parser) parseArgString(t Type, dir Dir) (Arg, error) { typ, ok := t.(*BufferType) if !ok { p.eatExcessive(true, "wrong string arg") - return t.DefaultArg(), nil + return t.DefaultArg(dir), nil } data, err := p.deserializeData() if err != nil { @@ -542,8 +544,8 @@ func (p *parser) parseArgString(t Type) (Arg, error) { } else if size == ^uint64(0) { size = uint64(len(data)) } - if typ.Dir() == DirOut { - return MakeOutDataArg(typ, size), nil + if dir == DirOut { + return MakeOutDataArg(typ, dir, size), nil } if diff := int(size) - len(data); diff > 0 { data = append(data, make([]byte, diff)...) @@ -564,16 +566,16 @@ func (p *parser) parseArgString(t Type) (Arg, error) { data = []byte(typ.Values[0]) } } - return MakeDataArg(typ, data), nil + return MakeDataArg(typ, dir, data), nil } -func (p *parser) parseArgStruct(typ Type) (Arg, error) { +func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) { p.Parse('{') t1, ok := typ.(*StructType) if !ok { p.eatExcessive(false, "wrong struct arg") p.Parse('}') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != '}'; i++ { @@ -583,9 +585,9 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { } fld := t1.Fields[i] if IsPad(fld) { - inner = append(inner, MakeConstArg(fld, 0)) + inner = append(inner, MakeConstArg(fld, dir, 0)) } else { - arg, err := p.parseArg(fld) + arg, err := p.parseArg(fld, dir) if err != nil { return nil, err } @@ -601,22 +603,22 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { if !IsPad(fld) { p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields)) } - inner = append(inner, fld.DefaultArg()) + inner = append(inner, fld.DefaultArg(dir)) } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgArray(typ Type) (Arg, error) { +func (p *parser) parseArgArray(typ Type, dir Dir) (Arg, error) { p.Parse('[') t1, ok := typ.(*ArrayType) if !ok { p.eatExcessive(false, "wrong array arg %T", typ) p.Parse(']') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != ']'; i++ { - arg, err := p.parseArg(t1.Type) + arg, err := p.parseArg(t1.Type, dir) if err != nil { return nil, err } @@ -629,18 +631,18 @@ func (p *parser) parseArgArray(typ Type) (Arg, error) { if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd { for uint64(len(inner)) < t1.RangeBegin { p.strictFailf("missing array elements") - inner = append(inner, t1.Type.DefaultArg()) + inner = append(inner, t1.Type.DefaultArg(dir)) } inner = inner[:t1.RangeBegin] } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgUnion(typ Type) (Arg, error) { +func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) { t1, ok := typ.(*UnionType) if !ok { p.eatExcessive(true, "wrong union arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('@') name := p.Ident() @@ -653,20 +655,20 @@ func (p *parser) parseArgUnion(typ Type) (Arg, error) { } if optType == nil { p.eatExcessive(true, "wrong union option") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var opt Arg if p.Char() == '=' { p.Parse('=') var err error - opt, err = p.parseArg(optType) + opt, err = p.parseArg(optType, dir) if err != nil { return nil, err } } else { - opt = optType.DefaultArg() + opt = optType.DefaultArg(dir) } - return MakeUnionArg(typ, opt), nil + return MakeUnionArg(typ, dir, opt), nil } // Eats excessive call arguments and struct fields to recover after description changes. |
