proc: support calls through function pointers

This commit is contained in:
aarzilli 2018-07-31 18:32:30 +02:00 committed by Derek Parker
parent 7c42fc51d7
commit 19ba86c0c9
11 changed files with 132 additions and 33 deletions

@ -57,6 +57,14 @@ type VRcvrable interface {
var zero = 0
func makeclos(pa *astruct) func(int) string {
i := 0
return func(x int) string {
i++
return fmt.Sprintf("%d + %d + %d = %d", i, pa.X, x, i+pa.X+x)
}
}
func main() {
one, two := 1, 2
intslice := []int{1, 2, 3}
@ -69,7 +77,14 @@ func main() {
var vable_pa VRcvrable = pa
var pable_pa PRcvrable = pa
fn2clos := makeclos(pa)
fn2glob := call1
fn2valmeth := pa.VRcvr
fn2ptrmeth := pa.PRcvr
var fn2nil func()
runtime.Breakpoint()
call1(one, two)
fmt.Println(one, two, zero, callpanic, callstacktrace, stringsJoin, intslice, stringslice, comma, a.VRcvr, a.PRcvr, pa, vable_a, vable_pa, pable_pa)
fn2clos(2)
fmt.Println(one, two, zero, callpanic, callstacktrace, stringsJoin, intslice, stringslice, comma, a.VRcvr, a.PRcvr, pa, vable_a, vable_pa, pable_pa, fn2clos, fn2glob, fn2valmeth, fn2ptrmeth, fn2nil)
}

@ -275,6 +275,10 @@ func (t *Thread) SetSP(uint64) error {
return errors.New("not supported")
}
func (t *Thread) SetDX(uint64) error {
return errors.New("not supported")
}
func (p *Process) Breakpoints() *proc.BreakpointMap {
return &p.breakpoints
}

@ -65,6 +65,8 @@ type functionCallState struct {
err error
// fn is the function that is being called
fn *Function
// closureAddr is the address of the closure being called
closureAddr uint64
// argmem contains the argument frame of this function call
argmem []byte
// retvars contains the return variables after the function call terminates without panic'ing
@ -118,7 +120,7 @@ func CallFunction(p Process, expr string, retLoadCfg *LoadConfig) error {
return ErrFuncCallUnsupportedBackend
}
fn, argvars, err := funcCallEvalExpr(p, expr)
fn, closureAddr, argvars, err := funcCallEvalExpr(p, expr)
if err != nil {
return err
}
@ -140,6 +142,7 @@ func CallFunction(p Process, expr string, retLoadCfg *LoadConfig) error {
fncall.savedRegs = regs.Save()
fncall.expr = expr
fncall.fn = fn
fncall.closureAddr = closureAddr
fncall.argmem = argmem
fncall.retLoadCfg = retLoadCfg
@ -191,35 +194,42 @@ func callOP(bi *BinaryInfo, thread Thread, regs Registers, callAddr uint64) erro
// funcCallEvalExpr evaluates expr, which must be a function call, returns
// the function being called and its arguments.
func funcCallEvalExpr(p Process, expr string) (fn *Function, argvars []*Variable, err error) {
func funcCallEvalExpr(p Process, expr string) (fn *Function, closureAddr uint64, argvars []*Variable, err error) {
bi := p.BinInfo()
scope, err := GoroutineScope(p.CurrentThread())
if err != nil {
return nil, nil, err
return nil, 0, nil, err
}
t, err := parser.ParseExpr(expr)
if err != nil {
return nil, nil, err
return nil, 0, nil, err
}
callexpr, iscall := t.(*ast.CallExpr)
if !iscall {
return nil, nil, ErrNotACallExpr
return nil, 0, nil, ErrNotACallExpr
}
fnvar, err := scope.evalAST(callexpr.Fun)
if err != nil {
return nil, nil, err
return nil, 0, nil, err
}
if fnvar.Kind != reflect.Func {
return nil, nil, fmt.Errorf("expression %q is not a function", exprToString(callexpr.Fun))
return nil, 0, nil, fmt.Errorf("expression %q is not a function", exprToString(callexpr.Fun))
}
fnvar.loadValue(LoadConfig{false, 0, 0, 0, 0})
if fnvar.Unreadable != nil {
return nil, 0, nil, fnvar.Unreadable
}
if fnvar.Base == 0 {
return nil, 0, nil, errors.New("nil pointer dereference")
}
fn = bi.PCToFunc(uint64(fnvar.Base))
if fn == nil {
return nil, nil, fmt.Errorf("could not find DIE for function %q", exprToString(callexpr.Fun))
return nil, 0, nil, fmt.Errorf("could not find DIE for function %q", exprToString(callexpr.Fun))
}
if !fn.cu.isgo {
return nil, nil, ErrNotAGoFunction
return nil, 0, nil, ErrNotAGoFunction
}
argvars = make([]*Variable, 0, len(callexpr.Args)+1)
@ -230,13 +240,13 @@ func funcCallEvalExpr(p Process, expr string) (fn *Function, argvars []*Variable
for i := range callexpr.Args {
argvar, err := scope.evalAST(callexpr.Args[i])
if err != nil {
return nil, nil, err
return nil, 0, nil, err
}
argvar.Name = exprToString(callexpr.Args[i])
argvars = append(argvars, argvar)
}
return fn, argvars, nil
return fn, fnvar.funcvalAddr(), argvars, nil
}
type funcCallArg struct {
@ -357,7 +367,9 @@ func escapeCheck(v *Variable, name string, g *G) error {
}
}
case reflect.Func:
//TODO(aarzilli): check closure argument?
if err := escapeCheckPointer(uintptr(v.funcvalAddr()), name, g); err != nil {
return err
}
}
return nil
@ -365,7 +377,7 @@ func escapeCheck(v *Variable, name string, g *G) error {
func escapeCheckPointer(addr uintptr, name string, g *G) error {
if uint64(addr) >= g.stacklo && uint64(addr) < g.stackhi {
return fmt.Errorf("stack object passed to escaping pointer %s", name)
return fmt.Errorf("stack object passed to escaping pointer: %s", name)
}
return nil
}
@ -425,7 +437,11 @@ func (fncall *functionCallState) step(p Process) {
if n != len(fncall.argmem) {
fncall.err = fmt.Errorf("short argument write: %d %d", n, len(fncall.argmem))
}
//TODO(aarzilli): if fncall.fn is a function closure CX needs to be set here
if fncall.closureAddr != 0 {
// When calling a function pointer we must set the DX register to the
// address of the function pointer itself.
thread.SetDX(fncall.closureAddr)
}
callOP(bi, thread, regs, fncall.fn.Entry)
case debugCallAXRestoreRegisters:

@ -1533,6 +1533,10 @@ func (regs *gdbRegisters) setSP(value uint64) {
binary.LittleEndian.PutUint64(regs.regs[regnameSP].value, value)
}
func (regs *gdbRegisters) setDX(value uint64) {
binary.LittleEndian.PutUint64(regs.regs[regnameDX].value, value)
}
func (regs *gdbRegisters) BP() uint64 {
return binary.LittleEndian.Uint64(regs.regs[regnameBP].value)
}
@ -1736,6 +1740,15 @@ func (t *Thread) SetSP(sp uint64) error {
return t.p.conn.writeRegister(t.strID, reg.regnum, reg.value)
}
func (t *Thread) SetDX(dx uint64) error {
t.regs.setDX(dx)
if t.p.gcmdok {
return t.p.conn.writeRegisters(t.strID, t.regs.buf)
}
reg := t.regs.regs[regnameDX]
return t.p.conn.writeRegister(t.strID, reg.regnum, reg.value)
}
func (regs *gdbRegisters) Slice() []proc.Register {
r := make([]proc.Register, 0, len(regs.regsInfo))
for _, reginfo := range regs.regsInfo {

@ -52,6 +52,7 @@ const (
regnamePC = "rip"
regnameCX = "rcx"
regnameSP = "rsp"
regnameDX = "rdx"
regnameBP = "rbp"
regnameFsBase = "fs_base"
regnameGsBase = "gs_base"

@ -126,6 +126,10 @@ func (thread *Thread) SetSP(sp uint64) error {
return errors.New("not implemented")
}
func (thread *Thread) SetDX(dx uint64) error {
return errors.New("not implemented")
}
func (r *Regs) Get(n int) (uint64, error) {
reg := x86asm.Reg(n)
const (

@ -115,6 +115,18 @@ func (thread *Thread) SetSP(sp uint64) (err error) {
return
}
func (thread *Thread) SetDX(dx uint64) (err error) {
var ir proc.Registers
ir, err = registers(thread, false)
if err != nil {
return err
}
r := ir.(*Regs)
r.regs.Rdx = dx
thread.dbp.execPtraceFunc(func() { err = sys.PtraceSetRegs(thread.ID, r.regs) })
return
}
func (r *Regs) Get(n int) (uint64, error) {
reg := x86asm.Reg(n)
const (

@ -160,6 +160,20 @@ func (thread *Thread) SetSP(sp uint64) error {
return _SetThreadContext(thread.os.hThread, context)
}
func (thread *Thread) SetDX(dx uint64) error {
context := newCONTEXT()
context.ContextFlags = _CONTEXT_ALL
err := _GetThreadContext(thread.os.hThread, context)
if err != nil {
return err
}
context.Rdx = dx
return _SetThreadContext(thread.os.hThread, context)
}
func (r *Regs) Get(n int) (uint64, error) {
reg := x86asm.Reg(n)
const (

@ -37,6 +37,7 @@ type Thread interface {
SetPC(uint64) error
SetSP(uint64) error
SetDX(uint64) error
}
// Location represents the location of a thread.

@ -1490,22 +1490,19 @@ func (dstv *Variable) writeCopy(srcv *Variable) error {
}
func (v *Variable) readFunctionPtr() {
val := make([]byte, v.bi.Arch.PtrSize())
_, err := v.mem.ReadMemory(val, v.Addr)
if err != nil {
v.Unreadable = err
// dereference pointer to find function pc
fnaddr := v.funcvalAddr()
if v.Unreadable != nil {
return
}
// dereference pointer to find function pc
fnaddr := uintptr(binary.LittleEndian.Uint64(val))
if fnaddr == 0 {
v.Base = 0
v.Value = constant.MakeString("")
return
}
_, err = v.mem.ReadMemory(val, fnaddr)
val := make([]byte, v.bi.Arch.PtrSize())
_, err := v.mem.ReadMemory(val, uintptr(fnaddr))
if err != nil {
v.Unreadable = err
return
@ -1521,6 +1518,17 @@ func (v *Variable) readFunctionPtr() {
v.Value = constant.MakeString(fn.Name)
}
// funcvalAddr reads the address of the funcval contained in a function variable.
func (v *Variable) funcvalAddr() uint64 {
val := make([]byte, v.bi.Arch.PtrSize())
_, err := v.mem.ReadMemory(val, v.Addr)
if err != nil {
v.Unreadable = err
return 0
}
return binary.LittleEndian.Uint64(val)
}
func (v *Variable) loadMap(recurseLevel int, cfg LoadConfig) {
it := v.mapIterator()
if it == nil {

@ -1105,10 +1105,13 @@ func TestCallFunction(t *testing.T) {
{`vable_a.nonexistent()`, nil, errors.New("vable_a has no member nonexistent")},
{`pable_pa.nonexistent()`, nil, errors.New("pable_pa has no member nonexistent")},
//TODO(aarzilli): indirect call of func value / set to top-level func
//TODO(aarzilli): indirect call of func value / set to func literal
//TODO(aarzilli): indirect call of func value / set to value method
//TODO(aarzilli): indirect call of func value / set to pointer method
{`fn2glob(10, 20)`, []string{":int:30"}, nil}, // indirect call of func value / set to top-level func
{`fn2clos(11)`, []string{`:string:"1 + 6 + 11 = 18"`}, nil}, // indirect call of func value / set to func literal
{`fn2clos(12)`, []string{`:string:"2 + 6 + 12 = 20"`}, nil},
{`fn2valmeth(13)`, []string{`:string:"13 + 6 = 19"`}, nil}, // indirect call of func value / set to value method
{`fn2ptrmeth(14)`, []string{`:string:"14 - 6 = 8"`}, nil}, // indirect call of func value / set to pointer method
{"fn2nil()", nil, errors.New("nil pointer dereference")},
}
withTestProcess("fncall", t, func(p proc.Process, fixture protest.Fixture) {
@ -1133,7 +1136,17 @@ func TestCallFunction(t *testing.T) {
t.Fatalf("call %q: error %q", tc.expr, err.Error())
}
retvals := p.CurrentThread().Common().ReturnValues(pnormalLoadConfig)
retvalsVar := p.CurrentThread().Common().ReturnValues(pnormalLoadConfig)
retvals := make([]*api.Variable, len(retvalsVar))
for i := range retvals {
retvals[i] = api.ConvertVar(retvalsVar[i])
}
t.Logf("call %q", tc.expr)
for i := range retvals {
t.Logf("\t%s = %s", retvals[i].Name, retvals[i].SinglelineString())
}
if len(retvals) != len(tc.outs) {
t.Fatalf("call %q: wrong number of return parameters", tc.expr)
@ -1147,12 +1160,10 @@ func TestCallFunction(t *testing.T) {
t.Fatalf("call %q output parameter %d: expected name %q, got %q", tc.expr, i, tgtName, retvals[i].Name)
}
cv := api.ConvertVar(retvals[i])
if cv.Type != tgtType {
t.Fatalf("call %q, output parameter %d: expected type %q, got %q", tc.expr, i, tgtType, cv.Type)
if retvals[i].Type != tgtType {
t.Fatalf("call %q, output parameter %d: expected type %q, got %q", tc.expr, i, tgtType, retvals[i].Type)
}
if cvs := cv.SinglelineString(); cvs != tgtValue {
if cvs := retvals[i].SinglelineString(); cvs != tgtValue {
t.Fatalf("call %q, output parameter %d: expected value %q, got %q", tc.expr, i, tgtValue, cvs)
}
}