From 97bce4e2ceff482b90b658dc9031cc0c5705cad1 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Tue, 31 Jul 2018 18:37:43 +0200 Subject: prog: refactor program serialization Make argument serialization Arg method. This eliminates a very long function that serializes all arguments. Update #538 --- prog/encoding.go | 242 +++++++++++++++++++++++++++++++------------------------ prog/prog.go | 2 + 2 files changed, 139 insertions(+), 105 deletions(-) (limited to 'prog') diff --git a/prog/encoding.go b/prog/encoding.go index 8c853fa4a..f8866b573 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -30,124 +30,156 @@ func (p *Prog) Serialize() []byte { panic("serializing invalid program") } } - buf := new(bytes.Buffer) - vars := make(map[*ResultArg]int) - varSeq := 0 + ctx := &serializer{ + target: p.Target, + buf: new(bytes.Buffer), + vars: make(map[*ResultArg]int), + } for _, c := range p.Calls { - if c.Ret != nil && len(c.Ret.uses) != 0 { - fmt.Fprintf(buf, "r%v = ", varSeq) - vars[c.Ret] = varSeq - varSeq++ - } - fmt.Fprintf(buf, "%v(", c.Meta.Name) - for i, a := range c.Args { - if IsPad(a.Type()) { - continue - } - if i != 0 { - fmt.Fprintf(buf, ", ") - } - p.Target.serialize(a, buf, vars, &varSeq) + ctx.call(c) + } + return ctx.buf.Bytes() +} + +type serializer struct { + target *Target + buf *bytes.Buffer + vars map[*ResultArg]int + varSeq int +} + +func (ctx *serializer) printf(text string, args ...interface{}) { + fmt.Fprintf(ctx.buf, text, args...) +} + +func (ctx *serializer) allocVarID(arg *ResultArg) int { + id := ctx.varSeq + ctx.varSeq++ + ctx.vars[arg] = id + return id +} + +func (ctx *serializer) call(c *Call) { + if c.Ret != nil && len(c.Ret.uses) != 0 { + ctx.printf("r%v = ", ctx.allocVarID(c.Ret)) + } + ctx.printf("%v(", c.Meta.Name) + for i, a := range c.Args { + if IsPad(a.Type()) { + continue + } + if i != 0 { + ctx.printf(", ") } - fmt.Fprintf(buf, ")\n") + ctx.arg(a) } - return buf.Bytes() + ctx.printf(")\n") } -func (target *Target) serialize(arg Arg, buf *bytes.Buffer, vars map[*ResultArg]int, varSeq *int) { +func (ctx *serializer) arg(arg Arg) { if arg == nil { - fmt.Fprintf(buf, "nil") + ctx.printf("nil") return } - switch a := arg.(type) { - case *ConstArg: - fmt.Fprintf(buf, "0x%x", a.Val) - case *PointerArg: - if a.IsNull() { - fmt.Fprintf(buf, "0x0") - break - } - fmt.Fprintf(buf, "&%v", target.serializeAddr(a)) - if a.Res == nil || !target.isDefaultArg(a.Res) || target.isAnyPtr(a.Type()) { - fmt.Fprintf(buf, "=") - if target.isAnyPtr(a.Type()) { - fmt.Fprintf(buf, "ANY=") - } - target.serialize(a.Res, buf, vars, varSeq) - } - case *DataArg: - if a.Type().Dir() == DirOut { - fmt.Fprintf(buf, "\"\"/%v", a.Size()) - } else { - data := a.Data() - if !arg.Type().Varlen() { - // Statically typed data will be padded with 0s during - // deserialization, so we can strip them here for readability. - for len(data) >= 2 && data[len(data)-1] == 0 && data[len(data)-2] == 0 { - data = data[:len(data)-1] - } - } - serializeData(buf, data) - } - case *GroupArg: - var delims []byte - switch arg.Type().(type) { - case *StructType: - delims = []byte{'{', '}'} - case *ArrayType: - delims = []byte{'[', ']'} - default: - panic("unknown group type") - } - buf.Write([]byte{delims[0]}) - lastNonDefault := len(a.Inner) - 1 - if a.fixedInnerSize() { - for ; lastNonDefault >= 0; lastNonDefault-- { - if !target.isDefaultArg(a.Inner[lastNonDefault]) { - break - } - } + arg.serialize(ctx) +} + +func (a *ConstArg) serialize(ctx *serializer) { + ctx.printf("0x%x", a.Val) +} + +func (a *PointerArg) serialize(ctx *serializer) { + if a.IsNull() { + ctx.printf("0x0") + return + } + target := ctx.target + ctx.printf("&%v", target.serializeAddr(a)) + if a.Res != nil && target.isDefaultArg(a.Res) && !target.isAnyPtr(a.Type()) { + return + } + ctx.printf("=") + if target.isAnyPtr(a.Type()) { + ctx.printf("ANY=") + } + ctx.arg(a.Res) +} + +func (a *DataArg) serialize(ctx *serializer) { + if a.Type().Dir() == DirOut { + ctx.printf("\"\"/%v", a.Size()) + return + } + data := a.Data() + if !a.Type().Varlen() { + // Statically typed data will be padded with 0s during + // deserialization, so we can strip them here for readability. + for len(data) >= 2 && data[len(data)-1] == 0 && data[len(data)-2] == 0 { + data = data[:len(data)-1] } - for i := 0; i <= lastNonDefault; i++ { - arg1 := a.Inner[i] - if arg1 != nil && IsPad(arg1.Type()) { - continue - } - if i != 0 { - fmt.Fprintf(buf, ", ") + } + serializeData(ctx.buf, data) +} + +func (a *GroupArg) serialize(ctx *serializer) { + var delims []byte + switch a.Type().(type) { + case *StructType: + delims = []byte{'{', '}'} + case *ArrayType: + delims = []byte{'[', ']'} + default: + panic("unknown group type") + } + ctx.buf.WriteByte(delims[0]) + lastNonDefault := len(a.Inner) - 1 + if a.fixedInnerSize() { + for ; lastNonDefault >= 0; lastNonDefault-- { + if !ctx.target.isDefaultArg(a.Inner[lastNonDefault]) { + break } - target.serialize(arg1, buf, vars, varSeq) - } - buf.Write([]byte{delims[1]}) - case *UnionArg: - fmt.Fprintf(buf, "@%v", a.Option.Type().FieldName()) - if !target.isDefaultArg(a.Option) { - fmt.Fprintf(buf, "=") - target.serialize(a.Option, buf, vars, varSeq) - } - case *ResultArg: - if len(a.uses) != 0 { - fmt.Fprintf(buf, "", *varSeq) - vars[a] = *varSeq - *varSeq++ - } - if a.Res == nil { - fmt.Fprintf(buf, "0x%x", a.Val) - break - } - id, ok := vars[a.Res] - if !ok { - panic("no result") } - fmt.Fprintf(buf, "r%v", id) - if a.OpDiv != 0 { - fmt.Fprintf(buf, "/%v", a.OpDiv) + } + for i := 0; i <= lastNonDefault; i++ { + arg1 := a.Inner[i] + if arg1 != nil && IsPad(arg1.Type()) { + continue } - if a.OpAdd != 0 { - fmt.Fprintf(buf, "+%v", a.OpAdd) + if i != 0 { + ctx.printf(", ") } - default: - panic("unknown arg kind") + ctx.arg(arg1) + } + ctx.buf.WriteByte(delims[1]) +} + +func (a *UnionArg) serialize(ctx *serializer) { + ctx.printf("@%v", a.Option.Type().FieldName()) + if ctx.target.isDefaultArg(a.Option) { + return + } + ctx.printf("=") + ctx.arg(a.Option) +} + +func (a *ResultArg) serialize(ctx *serializer) { + if len(a.uses) != 0 { + ctx.printf("", ctx.allocVarID(a)) + } + if a.Res == nil { + ctx.printf("0x%x", a.Val) + return + } + id, ok := ctx.vars[a.Res] + if !ok { + panic("no result") + } + ctx.printf("r%v", id) + if a.OpDiv != 0 { + ctx.printf("/%v", a.OpDiv) + } + if a.OpAdd != 0 { + ctx.printf("+%v", a.OpAdd) } } diff --git a/prog/prog.go b/prog/prog.go index 95f95ad27..53e0fbe33 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -22,7 +22,9 @@ type Call struct { type Arg interface { Type() Type Size() uint64 + validate(ctx *validCtx) error + serialize(ctx *serializer) } type ArgCommon struct { -- cgit mrf-deployment