From 19ba86c0c957ce26f0148ce96da660e3fb965889 Mon Sep 17 00:00:00 2001 From: aarzilli Date: Tue, 31 Jul 2018 18:32:30 +0200 Subject: [PATCH] proc: support calls through function pointers --- _fixtures/fncall.go | 17 ++++++++- pkg/proc/core/core.go | 4 ++ pkg/proc/fncall.go | 44 +++++++++++++++------- pkg/proc/gdbserial/gdbserver.go | 13 +++++++ pkg/proc/gdbserial/gdbserver_conn.go | 1 + pkg/proc/native/registers_darwin_amd64.go | 4 ++ pkg/proc/native/registers_linux_amd64.go | 12 ++++++ pkg/proc/native/registers_windows_amd64.go | 14 +++++++ pkg/proc/threads.go | 1 + pkg/proc/variables.go | 24 ++++++++---- service/test/variables_test.go | 31 ++++++++++----- 11 files changed, 132 insertions(+), 33 deletions(-) diff --git a/_fixtures/fncall.go b/_fixtures/fncall.go index 44298608..d451ec15 100644 --- a/_fixtures/fncall.go +++ b/_fixtures/fncall.go @@ -57,6 +57,14 @@ type VRcvrable interface { var zero = 0 +func makeclos(pa *astruct) func(int) string { + i := 0 + return func(x int) string { + i++ + return fmt.Sprintf("%d + %d + %d = %d", i, pa.X, x, i+pa.X+x) + } +} + func main() { one, two := 1, 2 intslice := []int{1, 2, 3} @@ -69,7 +77,14 @@ func main() { var vable_pa VRcvrable = pa var pable_pa PRcvrable = pa + fn2clos := makeclos(pa) + fn2glob := call1 + fn2valmeth := pa.VRcvr + fn2ptrmeth := pa.PRcvr + var fn2nil func() + runtime.Breakpoint() call1(one, two) - fmt.Println(one, two, zero, callpanic, callstacktrace, stringsJoin, intslice, stringslice, comma, a.VRcvr, a.PRcvr, pa, vable_a, vable_pa, pable_pa) + fn2clos(2) + fmt.Println(one, two, zero, callpanic, callstacktrace, stringsJoin, intslice, stringslice, comma, a.VRcvr, a.PRcvr, pa, vable_a, vable_pa, pable_pa, fn2clos, fn2glob, fn2valmeth, fn2ptrmeth, fn2nil) } diff --git a/pkg/proc/core/core.go b/pkg/proc/core/core.go index 38428357..5a3f0f13 100644 --- a/pkg/proc/core/core.go +++ b/pkg/proc/core/core.go @@ -275,6 +275,10 @@ func (t *Thread) SetSP(uint64) error { return errors.New("not supported") } +func (t *Thread) SetDX(uint64) error { + return errors.New("not supported") +} + func (p *Process) Breakpoints() *proc.BreakpointMap { return &p.breakpoints } diff --git a/pkg/proc/fncall.go b/pkg/proc/fncall.go index 87a78786..ae117dcb 100644 --- a/pkg/proc/fncall.go +++ b/pkg/proc/fncall.go @@ -65,6 +65,8 @@ type functionCallState struct { err error // fn is the function that is being called fn *Function + // closureAddr is the address of the closure being called + closureAddr uint64 // argmem contains the argument frame of this function call argmem []byte // retvars contains the return variables after the function call terminates without panic'ing @@ -118,7 +120,7 @@ func CallFunction(p Process, expr string, retLoadCfg *LoadConfig) error { return ErrFuncCallUnsupportedBackend } - fn, argvars, err := funcCallEvalExpr(p, expr) + fn, closureAddr, argvars, err := funcCallEvalExpr(p, expr) if err != nil { return err } @@ -140,6 +142,7 @@ func CallFunction(p Process, expr string, retLoadCfg *LoadConfig) error { fncall.savedRegs = regs.Save() fncall.expr = expr fncall.fn = fn + fncall.closureAddr = closureAddr fncall.argmem = argmem fncall.retLoadCfg = retLoadCfg @@ -191,35 +194,42 @@ func callOP(bi *BinaryInfo, thread Thread, regs Registers, callAddr uint64) erro // funcCallEvalExpr evaluates expr, which must be a function call, returns // the function being called and its arguments. -func funcCallEvalExpr(p Process, expr string) (fn *Function, argvars []*Variable, err error) { +func funcCallEvalExpr(p Process, expr string) (fn *Function, closureAddr uint64, argvars []*Variable, err error) { bi := p.BinInfo() scope, err := GoroutineScope(p.CurrentThread()) if err != nil { - return nil, nil, err + return nil, 0, nil, err } t, err := parser.ParseExpr(expr) if err != nil { - return nil, nil, err + return nil, 0, nil, err } callexpr, iscall := t.(*ast.CallExpr) if !iscall { - return nil, nil, ErrNotACallExpr + return nil, 0, nil, ErrNotACallExpr } fnvar, err := scope.evalAST(callexpr.Fun) if err != nil { - return nil, nil, err + return nil, 0, nil, err } if fnvar.Kind != reflect.Func { - return nil, nil, fmt.Errorf("expression %q is not a function", exprToString(callexpr.Fun)) + return nil, 0, nil, fmt.Errorf("expression %q is not a function", exprToString(callexpr.Fun)) + } + fnvar.loadValue(LoadConfig{false, 0, 0, 0, 0}) + if fnvar.Unreadable != nil { + return nil, 0, nil, fnvar.Unreadable + } + if fnvar.Base == 0 { + return nil, 0, nil, errors.New("nil pointer dereference") } fn = bi.PCToFunc(uint64(fnvar.Base)) if fn == nil { - return nil, nil, fmt.Errorf("could not find DIE for function %q", exprToString(callexpr.Fun)) + return nil, 0, nil, fmt.Errorf("could not find DIE for function %q", exprToString(callexpr.Fun)) } if !fn.cu.isgo { - return nil, nil, ErrNotAGoFunction + return nil, 0, nil, ErrNotAGoFunction } argvars = make([]*Variable, 0, len(callexpr.Args)+1) @@ -230,13 +240,13 @@ func funcCallEvalExpr(p Process, expr string) (fn *Function, argvars []*Variable for i := range callexpr.Args { argvar, err := scope.evalAST(callexpr.Args[i]) if err != nil { - return nil, nil, err + return nil, 0, nil, err } argvar.Name = exprToString(callexpr.Args[i]) argvars = append(argvars, argvar) } - return fn, argvars, nil + return fn, fnvar.funcvalAddr(), argvars, nil } type funcCallArg struct { @@ -357,7 +367,9 @@ func escapeCheck(v *Variable, name string, g *G) error { } } case reflect.Func: - //TODO(aarzilli): check closure argument? + if err := escapeCheckPointer(uintptr(v.funcvalAddr()), name, g); err != nil { + return err + } } return nil @@ -365,7 +377,7 @@ func escapeCheck(v *Variable, name string, g *G) error { func escapeCheckPointer(addr uintptr, name string, g *G) error { if uint64(addr) >= g.stacklo && uint64(addr) < g.stackhi { - return fmt.Errorf("stack object passed to escaping pointer %s", name) + return fmt.Errorf("stack object passed to escaping pointer: %s", name) } return nil } @@ -425,7 +437,11 @@ func (fncall *functionCallState) step(p Process) { if n != len(fncall.argmem) { fncall.err = fmt.Errorf("short argument write: %d %d", n, len(fncall.argmem)) } - //TODO(aarzilli): if fncall.fn is a function closure CX needs to be set here + if fncall.closureAddr != 0 { + // When calling a function pointer we must set the DX register to the + // address of the function pointer itself. + thread.SetDX(fncall.closureAddr) + } callOP(bi, thread, regs, fncall.fn.Entry) case debugCallAXRestoreRegisters: diff --git a/pkg/proc/gdbserial/gdbserver.go b/pkg/proc/gdbserial/gdbserver.go index 1c75faa2..b147edab 100644 --- a/pkg/proc/gdbserial/gdbserver.go +++ b/pkg/proc/gdbserial/gdbserver.go @@ -1533,6 +1533,10 @@ func (regs *gdbRegisters) setSP(value uint64) { binary.LittleEndian.PutUint64(regs.regs[regnameSP].value, value) } +func (regs *gdbRegisters) setDX(value uint64) { + binary.LittleEndian.PutUint64(regs.regs[regnameDX].value, value) +} + func (regs *gdbRegisters) BP() uint64 { return binary.LittleEndian.Uint64(regs.regs[regnameBP].value) } @@ -1736,6 +1740,15 @@ func (t *Thread) SetSP(sp uint64) error { return t.p.conn.writeRegister(t.strID, reg.regnum, reg.value) } +func (t *Thread) SetDX(dx uint64) error { + t.regs.setDX(dx) + if t.p.gcmdok { + return t.p.conn.writeRegisters(t.strID, t.regs.buf) + } + reg := t.regs.regs[regnameDX] + return t.p.conn.writeRegister(t.strID, reg.regnum, reg.value) +} + func (regs *gdbRegisters) Slice() []proc.Register { r := make([]proc.Register, 0, len(regs.regsInfo)) for _, reginfo := range regs.regsInfo { diff --git a/pkg/proc/gdbserial/gdbserver_conn.go b/pkg/proc/gdbserial/gdbserver_conn.go index 9b434013..41d2f40c 100644 --- a/pkg/proc/gdbserial/gdbserver_conn.go +++ b/pkg/proc/gdbserial/gdbserver_conn.go @@ -52,6 +52,7 @@ const ( regnamePC = "rip" regnameCX = "rcx" regnameSP = "rsp" + regnameDX = "rdx" regnameBP = "rbp" regnameFsBase = "fs_base" regnameGsBase = "gs_base" diff --git a/pkg/proc/native/registers_darwin_amd64.go b/pkg/proc/native/registers_darwin_amd64.go index e6dd80e9..a040b0b5 100644 --- a/pkg/proc/native/registers_darwin_amd64.go +++ b/pkg/proc/native/registers_darwin_amd64.go @@ -126,6 +126,10 @@ func (thread *Thread) SetSP(sp uint64) error { return errors.New("not implemented") } +func (thread *Thread) SetDX(dx uint64) error { + return errors.New("not implemented") +} + func (r *Regs) Get(n int) (uint64, error) { reg := x86asm.Reg(n) const ( diff --git a/pkg/proc/native/registers_linux_amd64.go b/pkg/proc/native/registers_linux_amd64.go index 5ac6b205..4e3036ca 100644 --- a/pkg/proc/native/registers_linux_amd64.go +++ b/pkg/proc/native/registers_linux_amd64.go @@ -115,6 +115,18 @@ func (thread *Thread) SetSP(sp uint64) (err error) { return } +func (thread *Thread) SetDX(dx uint64) (err error) { + var ir proc.Registers + ir, err = registers(thread, false) + if err != nil { + return err + } + r := ir.(*Regs) + r.regs.Rdx = dx + thread.dbp.execPtraceFunc(func() { err = sys.PtraceSetRegs(thread.ID, r.regs) }) + return +} + func (r *Regs) Get(n int) (uint64, error) { reg := x86asm.Reg(n) const ( diff --git a/pkg/proc/native/registers_windows_amd64.go b/pkg/proc/native/registers_windows_amd64.go index e03f4b83..0e29ab34 100644 --- a/pkg/proc/native/registers_windows_amd64.go +++ b/pkg/proc/native/registers_windows_amd64.go @@ -160,6 +160,20 @@ func (thread *Thread) SetSP(sp uint64) error { return _SetThreadContext(thread.os.hThread, context) } +func (thread *Thread) SetDX(dx uint64) error { + context := newCONTEXT() + context.ContextFlags = _CONTEXT_ALL + + err := _GetThreadContext(thread.os.hThread, context) + if err != nil { + return err + } + + context.Rdx = dx + + return _SetThreadContext(thread.os.hThread, context) +} + func (r *Regs) Get(n int) (uint64, error) { reg := x86asm.Reg(n) const ( diff --git a/pkg/proc/threads.go b/pkg/proc/threads.go index bee6d9f0..27b27ac9 100644 --- a/pkg/proc/threads.go +++ b/pkg/proc/threads.go @@ -37,6 +37,7 @@ type Thread interface { SetPC(uint64) error SetSP(uint64) error + SetDX(uint64) error } // Location represents the location of a thread. diff --git a/pkg/proc/variables.go b/pkg/proc/variables.go index fcd2dce1..ec89a044 100644 --- a/pkg/proc/variables.go +++ b/pkg/proc/variables.go @@ -1490,22 +1490,19 @@ func (dstv *Variable) writeCopy(srcv *Variable) error { } func (v *Variable) readFunctionPtr() { - val := make([]byte, v.bi.Arch.PtrSize()) - _, err := v.mem.ReadMemory(val, v.Addr) - if err != nil { - v.Unreadable = err + // dereference pointer to find function pc + fnaddr := v.funcvalAddr() + if v.Unreadable != nil { return } - - // dereference pointer to find function pc - fnaddr := uintptr(binary.LittleEndian.Uint64(val)) if fnaddr == 0 { v.Base = 0 v.Value = constant.MakeString("") return } - _, err = v.mem.ReadMemory(val, fnaddr) + val := make([]byte, v.bi.Arch.PtrSize()) + _, err := v.mem.ReadMemory(val, uintptr(fnaddr)) if err != nil { v.Unreadable = err return @@ -1521,6 +1518,17 @@ func (v *Variable) readFunctionPtr() { v.Value = constant.MakeString(fn.Name) } +// funcvalAddr reads the address of the funcval contained in a function variable. +func (v *Variable) funcvalAddr() uint64 { + val := make([]byte, v.bi.Arch.PtrSize()) + _, err := v.mem.ReadMemory(val, v.Addr) + if err != nil { + v.Unreadable = err + return 0 + } + return binary.LittleEndian.Uint64(val) +} + func (v *Variable) loadMap(recurseLevel int, cfg LoadConfig) { it := v.mapIterator() if it == nil { diff --git a/service/test/variables_test.go b/service/test/variables_test.go index 19e5e1fa..411af69e 100644 --- a/service/test/variables_test.go +++ b/service/test/variables_test.go @@ -1105,10 +1105,13 @@ func TestCallFunction(t *testing.T) { {`vable_a.nonexistent()`, nil, errors.New("vable_a has no member nonexistent")}, {`pable_pa.nonexistent()`, nil, errors.New("pable_pa has no member nonexistent")}, - //TODO(aarzilli): indirect call of func value / set to top-level func - //TODO(aarzilli): indirect call of func value / set to func literal - //TODO(aarzilli): indirect call of func value / set to value method - //TODO(aarzilli): indirect call of func value / set to pointer method + {`fn2glob(10, 20)`, []string{":int:30"}, nil}, // indirect call of func value / set to top-level func + {`fn2clos(11)`, []string{`:string:"1 + 6 + 11 = 18"`}, nil}, // indirect call of func value / set to func literal + {`fn2clos(12)`, []string{`:string:"2 + 6 + 12 = 20"`}, nil}, + {`fn2valmeth(13)`, []string{`:string:"13 + 6 = 19"`}, nil}, // indirect call of func value / set to value method + {`fn2ptrmeth(14)`, []string{`:string:"14 - 6 = 8"`}, nil}, // indirect call of func value / set to pointer method + + {"fn2nil()", nil, errors.New("nil pointer dereference")}, } withTestProcess("fncall", t, func(p proc.Process, fixture protest.Fixture) { @@ -1133,7 +1136,17 @@ func TestCallFunction(t *testing.T) { t.Fatalf("call %q: error %q", tc.expr, err.Error()) } - retvals := p.CurrentThread().Common().ReturnValues(pnormalLoadConfig) + retvalsVar := p.CurrentThread().Common().ReturnValues(pnormalLoadConfig) + retvals := make([]*api.Variable, len(retvalsVar)) + + for i := range retvals { + retvals[i] = api.ConvertVar(retvalsVar[i]) + } + + t.Logf("call %q", tc.expr) + for i := range retvals { + t.Logf("\t%s = %s", retvals[i].Name, retvals[i].SinglelineString()) + } if len(retvals) != len(tc.outs) { t.Fatalf("call %q: wrong number of return parameters", tc.expr) @@ -1147,12 +1160,10 @@ func TestCallFunction(t *testing.T) { t.Fatalf("call %q output parameter %d: expected name %q, got %q", tc.expr, i, tgtName, retvals[i].Name) } - cv := api.ConvertVar(retvals[i]) - - if cv.Type != tgtType { - t.Fatalf("call %q, output parameter %d: expected type %q, got %q", tc.expr, i, tgtType, cv.Type) + if retvals[i].Type != tgtType { + t.Fatalf("call %q, output parameter %d: expected type %q, got %q", tc.expr, i, tgtType, retvals[i].Type) } - if cvs := cv.SinglelineString(); cvs != tgtValue { + if cvs := retvals[i].SinglelineString(); cvs != tgtValue { t.Fatalf("call %q, output parameter %d: expected value %q, got %q", tc.expr, i, tgtValue, cvs) } }