aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-12-09 16:01:06 +0100
committerDmitry Vyukov <dvyukov@google.com>2018-12-10 16:37:01 +0100
commita5efea3ec3e302da3fa01ca44604fe62aec49a79 (patch)
tree230039ec1f8183b2d888f494a3b2abdb13a18d04
parentceeb374637117c0ff8c2369521df9d98d09a5930 (diff)
prog: refactor deserialization code
Move target and vars into parser and make all parsing functions methods of the parser. This reduces number of args that we need to pass around and eases adding more state that needs to be passed around.
-rw-r--r--prog/encoding.go99
-rw-r--r--prog/encoding_test.go4
2 files changed, 55 insertions, 48 deletions
diff --git a/prog/encoding.go b/prog/encoding.go
index 4478bc834..6e3203984 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -183,8 +183,7 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) {
prog = &Prog{
Target: target,
}
- p := newParser(data)
- vars := make(map[string]*ResultArg)
+ p := newParser(target, data)
comment := ""
for p.Scan() {
if p.EOF() {
@@ -222,14 +221,14 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) {
p.Parse('(')
for i := 0; p.Char() != ')'; i++ {
if i >= len(meta.Args) {
- eatExcessive(p, false)
+ p.eatExcessive(false)
break
}
typ := meta.Args[i]
if IsPad(typ) {
return nil, fmt.Errorf("padding in syscall %v arguments", name)
}
- arg, err := target.parseArg(typ, p, vars)
+ arg, err := p.parseArg(typ)
if err != nil {
return nil, err
}
@@ -256,7 +255,7 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
}
if r != "" && c.Ret != nil {
- vars[r] = c.Ret
+ p.vars[r] = c.Ret
}
comment = ""
}
@@ -278,7 +277,7 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) {
return
}
-func (target *Target) parseArg(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArg(typ Type) (Arg, error) {
r := ""
if p.Char() == '<' {
p.Parse('<')
@@ -286,7 +285,7 @@ func (target *Target) parseArg(typ Type, p *parser, vars map[string]*ResultArg)
p.Parse('=')
p.Parse('>')
}
- arg, err := target.parseArgImpl(typ, p, vars)
+ arg, err := p.parseArgImpl(typ)
if err != nil {
return nil, err
}
@@ -299,28 +298,28 @@ func (target *Target) parseArg(typ Type, p *parser, vars map[string]*ResultArg)
}
if r != "" {
if res, ok := arg.(*ResultArg); ok {
- vars[r] = res
+ p.vars[r] = res
}
}
return arg, nil
}
-func (target *Target) parseArgImpl(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgImpl(typ Type) (Arg, error) {
switch p.Char() {
case '0':
- return target.parseArgInt(typ, p)
+ return p.parseArgInt(typ)
case 'r':
- return target.parseArgRes(typ, p, vars)
+ return p.parseArgRes(typ)
case '&':
- return target.parseArgAddr(typ, p, vars)
+ return p.parseArgAddr(typ)
case '"', '\'':
- return target.parseArgString(typ, p)
+ return p.parseArgString(typ)
case '{':
- return target.parseArgStruct(typ, p, vars)
+ return p.parseArgStruct(typ)
case '[':
- return target.parseArgArray(typ, p, vars)
+ return p.parseArgArray(typ)
case '@':
- return target.parseArgUnion(typ, p, vars)
+ return p.parseArgUnion(typ)
case 'n':
p.Parse('n')
p.Parse('i')
@@ -333,7 +332,7 @@ func (target *Target) parseArgImpl(typ Type, p *parser, vars map[string]*ResultA
}
}
-func (target *Target) parseArgInt(typ Type, p *parser) (Arg, error) {
+func (p *parser) parseArgInt(typ Type) (Arg, error) {
val := p.Ident()
v, err := strconv.ParseUint(val, 0, 64)
if err != nil {
@@ -345,15 +344,15 @@ func (target *Target) parseArgInt(typ Type, p *parser) (Arg, error) {
case *ResourceType:
return MakeResultArg(typ, nil, v), nil
case *PtrType, *VmaType:
- index := -v % uint64(len(target.SpecialPointers))
+ index := -v % uint64(len(p.target.SpecialPointers))
return MakeSpecialPointerArg(typ, index), nil
default:
- eatExcessive(p, true)
+ p.eatExcessive(true)
return typ.DefaultArg(), nil
}
}
-func (target *Target) parseArgRes(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgRes(typ Type) (Arg, error) {
id := p.Ident()
var div, add uint64
if p.Char() == '/' {
@@ -374,7 +373,7 @@ func (target *Target) parseArgRes(typ Type, p *parser, vars map[string]*ResultAr
}
add = v
}
- v := vars[id]
+ v := p.vars[id]
if v == nil {
return typ.DefaultArg(), nil
}
@@ -384,18 +383,18 @@ func (target *Target) parseArgRes(typ Type, p *parser, vars map[string]*ResultAr
return arg, nil
}
-func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgAddr(typ Type) (Arg, error) {
var typ1 Type
switch t1 := typ.(type) {
case *PtrType:
typ1 = t1.Type
case *VmaType:
default:
- eatExcessive(p, true)
+ p.eatExcessive(true)
return typ.DefaultArg(), nil
}
p.Parse('&')
- addr, vmaSize, err := target.parseAddr(p)
+ addr, vmaSize, err := p.parseAddr()
if err != nil {
return nil, err
}
@@ -407,10 +406,10 @@ func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultA
p.Parse('N')
p.Parse('Y')
p.Parse('=')
- typ = target.makeAnyPtrType(typ.Size(), typ.FieldName())
- typ1 = target.any.array
+ typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName())
+ typ1 = p.target.any.array
}
- inner, err = target.parseArg(typ1, p, vars)
+ inner, err = p.parseArg(typ1)
if err != nil {
return nil, err
}
@@ -424,12 +423,12 @@ func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]*ResultA
return MakePointerArg(typ, addr, inner), nil
}
-func (target *Target) parseArgString(typ Type, p *parser) (Arg, error) {
+func (p *parser) parseArgString(typ Type) (Arg, error) {
if _, ok := typ.(*BufferType); !ok {
- eatExcessive(p, true)
+ p.eatExcessive(true)
return typ.DefaultArg(), nil
}
- data, err := deserializeData(p)
+ data, err := p.deserializeData()
if err != nil {
return nil, err
}
@@ -457,25 +456,25 @@ func (target *Target) parseArgString(typ Type, p *parser) (Arg, error) {
return MakeDataArg(typ, data), nil
}
-func (target *Target) parseArgStruct(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgStruct(typ Type) (Arg, error) {
p.Parse('{')
t1, ok := typ.(*StructType)
if !ok {
- eatExcessive(p, false)
+ p.eatExcessive(false)
p.Parse('}')
return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != '}'; i++ {
if i >= len(t1.Fields) {
- eatExcessive(p, false)
+ p.eatExcessive(false)
break
}
fld := t1.Fields[i]
if IsPad(fld) {
inner = append(inner, MakeConstArg(fld, 0))
} else {
- arg, err := target.parseArg(fld, p, vars)
+ arg, err := p.parseArg(fld)
if err != nil {
return nil, err
}
@@ -492,17 +491,17 @@ func (target *Target) parseArgStruct(typ Type, p *parser, vars map[string]*Resul
return MakeGroupArg(typ, inner), nil
}
-func (target *Target) parseArgArray(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgArray(typ Type) (Arg, error) {
p.Parse('[')
t1, ok := typ.(*ArrayType)
if !ok {
- eatExcessive(p, false)
+ p.eatExcessive(false)
p.Parse(']')
return typ.DefaultArg(), nil
}
var inner []Arg
for i := 0; p.Char() != ']'; i++ {
- arg, err := target.parseArg(t1.Type, p, vars)
+ arg, err := p.parseArg(t1.Type)
if err != nil {
return nil, err
}
@@ -521,10 +520,10 @@ func (target *Target) parseArgArray(typ Type, p *parser, vars map[string]*Result
return MakeGroupArg(typ, inner), nil
}
-func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*ResultArg) (Arg, error) {
+func (p *parser) parseArgUnion(typ Type) (Arg, error) {
t1, ok := typ.(*UnionType)
if !ok {
- eatExcessive(p, true)
+ p.eatExcessive(true)
return typ.DefaultArg(), nil
}
p.Parse('@')
@@ -537,14 +536,14 @@ func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*Result
}
}
if optType == nil {
- eatExcessive(p, true)
+ p.eatExcessive(true)
return typ.DefaultArg(), nil
}
var opt Arg
if p.Char() == '=' {
p.Parse('=')
var err error
- opt, err = target.parseArg(optType, p, vars)
+ opt, err = p.parseArg(optType)
if err != nil {
return nil, err
}
@@ -555,7 +554,7 @@ func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]*Result
}
// Eats excessive call arguments and struct fields to recover after description changes.
-func eatExcessive(p *parser, stopAtComma bool) {
+func (p *parser) eatExcessive(stopAtComma bool) {
paren, brack, brace := 0, 0, 0
for !p.EOF() && p.e == nil {
ch := p.Char()
@@ -611,7 +610,7 @@ func (target *Target) serializeAddr(arg *PointerArg) string {
return fmt.Sprintf("(0x%x%v)", encodingAddrBase+arg.Address, ssize)
}
-func (target *Target) parseAddr(p *parser) (uint64, uint64, error) {
+func (p *parser) parseAddr() (uint64, uint64, error) {
p.Parse('(')
pstr := p.Ident()
addr, err := strconv.ParseUint(pstr, 0, 64)
@@ -641,6 +640,7 @@ func (target *Target) parseAddr(p *parser) (uint64, uint64, error) {
}
addr += off
}
+ target := p.target
maxMem := target.NumPages * target.PageSize
var vmaSize uint64
if p.Char() == '/' {
@@ -713,7 +713,7 @@ func serializeData(buf *bytes.Buffer, data []byte) {
buf.WriteByte('\'')
}
-func deserializeData(p *parser) ([]byte, error) {
+func (p *parser) deserializeData() ([]byte, error) {
var data []byte
if p.Char() == '"' {
p.Parse('"')
@@ -775,6 +775,9 @@ func deserializeData(p *parser) ([]byte, error) {
}
type parser struct {
+ target *Target
+ vars map[string]*ResultArg
+
r *bufio.Scanner
s string
i int
@@ -782,8 +785,12 @@ type parser struct {
e error
}
-func newParser(data []byte) *parser {
- p := &parser{r: bufio.NewScanner(bytes.NewReader(data))}
+func newParser(target *Target, data []byte) *parser {
+ p := &parser{
+ target: target,
+ vars: make(map[string]*ResultArg),
+ r: bufio.NewScanner(bytes.NewReader(data)),
+ }
p.r.Buffer(nil, maxLineLen)
return p
}
diff --git a/prog/encoding_test.go b/prog/encoding_test.go
index 1bb22228f..e2951666d 100644
--- a/prog/encoding_test.go
+++ b/prog/encoding_test.go
@@ -35,11 +35,11 @@ func TestSerializeData(t *testing.T) {
}
buf := new(bytes.Buffer)
serializeData(buf, data)
- p := newParser(buf.Bytes())
+ p := newParser(nil, buf.Bytes())
if !p.Scan() {
t.Fatalf("parser does not scan")
}
- data1, err := deserializeData(p)
+ data1, err := p.deserializeData()
if err != nil {
t.Fatalf("failed to deserialize %q -> %s: %v", data, buf.Bytes(), err)
}