diff --git a/proctl/arch.go b/proctl/arch.go index 62cdcba9..99be3af0 100644 --- a/proctl/arch.go +++ b/proctl/arch.go @@ -1,22 +1,45 @@ package proctl +import "runtime" + type Arch interface { PtrSize() int BreakpointInstruction() []byte BreakpointSize() int + CurgInstructions() []byte } type AMD64 struct { ptrSize int breakInstruction []byte breakInstructionLen int + curgInstructions []byte } func AMD64Arch() *AMD64 { + var ( + curg []byte + breakInstr = []byte{0xCC} + ) + + switch runtime.GOOS { + case "darwin": + curg = []byte{ + 0x65, 0x48, 0x8b, 0x0C, 0x25, 0xA0, 0x08, // mov %gs:0x8a0,%rcx + 0x0, 0x0, + } + case "linux": + curg = []byte{ + 0x64, 0x48, 0x8b, 0x0c, 0x25, 0xf0, 0xff, 0xff, 0xff, // mov %fs:0xfffffffffffffff0,%rcx + } + } + curg = append(curg, breakInstr[0]) + return &AMD64{ ptrSize: 8, - breakInstruction: []byte{0xCC}, + breakInstruction: breakInstr, breakInstructionLen: 1, + curgInstructions: curg, } } @@ -31,3 +54,7 @@ func (a *AMD64) BreakpointInstruction() []byte { func (a *AMD64) BreakpointSize() int { return a.breakInstructionLen } + +func (a *AMD64) CurgInstructions() []byte { + return a.curgInstructions +} diff --git a/proctl/proctl.go b/proctl/proctl.go index ead2e1f5..accbee9c 100644 --- a/proctl/proctl.go +++ b/proctl/proctl.go @@ -248,13 +248,13 @@ func (dbp *DebuggedProcess) next() error { return err } - curg, err := dbp.CurrentThread.curG() + g, err := dbp.CurrentThread.getG() if err != nil { return err } - if curg.DeferPC != 0 { - _, err = dbp.TempBreak(curg.DeferPC) + if g.DeferPC != 0 { + _, err = dbp.TempBreak(g.DeferPC) if err != nil { return err } @@ -273,7 +273,7 @@ func (dbp *DebuggedProcess) next() error { if err = th.Next(); err != nil { if err, ok := err.(GoroutineExitingError); ok { waitCount = waitCount - 1 + chanRecvCount - if err.goid == curg.Id { + if err.goid == g.Id { goroutineExiting = true } if err := th.Continue(); err != nil { @@ -290,12 +290,12 @@ func (dbp *DebuggedProcess) next() error { if err != nil { 
return err } - tg, err := thread.curG() + tg, err := thread.getG() if err != nil { return err } // Make sure we're on the same goroutine, unless it has exited. - if tg.Id == curg.Id || goroutineExiting { + if tg.Id == g.Id || goroutineExiting { if dbp.CurrentThread != thread { dbp.SwitchThread(thread.Id) } @@ -427,7 +427,7 @@ func (dbp *DebuggedProcess) GoroutinesInfo() ([]*G, error) { allgptr := binary.LittleEndian.Uint64(faddr) for i := uint64(0); i < allglen; i++ { - g, err := parseG(dbp.CurrentThread, allgptr+(i*uint64(dbp.arch.PtrSize()))) + g, err := parseG(dbp.CurrentThread, allgptr+(i*uint64(dbp.arch.PtrSize())), true) if err != nil { return nil, err } @@ -467,10 +467,6 @@ func (dbp *DebuggedProcess) EvalSymbol(name string) (*Variable, error) { return dbp.CurrentThread.EvalSymbol(name) } -func (dbp *DebuggedProcess) CallFn(name string, fn func() error) error { - return dbp.CurrentThread.CallFn(name, fn) -} - // Returns a reader for the dwarf data func (dbp *DebuggedProcess) DwarfReader() *reader.Reader { return reader.New(dbp.dwarf) diff --git a/proctl/proctl_linux.go b/proctl/proctl_linux.go index 9ff87017..93787be9 100644 --- a/proctl/proctl_linux.go +++ b/proctl/proctl_linux.go @@ -234,7 +234,6 @@ func (dbp *DebuggedProcess) trapWait(pid int) (*ThreadContext, error) { if ok { th.Status = status } - if status.Exited() { if wpid == dbp.Pid { dbp.exited = true @@ -249,30 +248,29 @@ func (dbp *DebuggedProcess) trapWait(pid int) (*ThreadContext, error) { if err != nil { return nil, fmt.Errorf("could not get event message: %s", err) } - th, err = dbp.addThread(int(cloned), false) if err != nil { return nil, err } - - err = th.Continue() - if err != nil { + if err = th.Continue(); err != nil { return nil, fmt.Errorf("could not continue new thread %d %s", cloned, err) } - - err = dbp.Threads[int(wpid)].Continue() - if err != nil { - return nil, fmt.Errorf("could not continue new thread %d %s", cloned, err) + if err = dbp.Threads[int(wpid)].Continue(); err != nil 
{ + return nil, fmt.Errorf("could not continue existing thread %d %s", wpid, err) } continue } if status.StopSignal() == sys.SIGTRAP { return dbp.handleBreakpointOnThread(wpid) } + if status.StopSignal() == sys.SIGTRAP && dbp.halt { // NOTE(review): unreachable as written, the SIGTRAP branch above always returns; confirm intended order. + return th, nil + } if status.StopSignal() == sys.SIGSTOP && dbp.halt { return nil, ManualStopError{} } if th != nil { + // TODO(dp) alert user about unexpected signals here. if err := th.Continue(); err != nil { return nil, err } diff --git a/proctl/proctl_test.go b/proctl/proctl_test.go index 3b6198fc..f9bc49fd 100644 --- a/proctl/proctl_test.go +++ b/proctl/proctl_test.go @@ -411,59 +411,3 @@ func TestSwitchThread(t *testing.T) { } }) } - -func TestFunctionCall(t *testing.T) { - withTestProcess("testprog", t, func(p *DebuggedProcess, fixture protest.Fixture) { - pc, err := p.FindLocation("main.main") - if err != nil { - t.Fatal(err) - } - _, err = p.Break(pc) - if err != nil { - t.Fatal(err) - } - err = p.Continue() - if err != nil { - t.Fatal(err) - } - pc, err = p.PC() - if err != nil { - t.Fatal(err) - } - fn := p.goSymTable.PCToFunc(pc) - if fn == nil { - t.Fatalf("Could not find func for PC: %#v", pc) - } - if fn.Name != "main.main" { - t.Fatal("Program stopped at incorrect place") - } - if err = p.CallFn("runtime.getg", func() error { - th := p.CurrentThread - pc, err := th.PC() - if err != nil { - t.Fatal(err) - } - f := th.dbp.goSymTable.LookupFunc("runtime.getg") - if f == nil { - t.Fatalf("could not find function %s", "runtime.getg") - } - if pc-1 != f.End-2 && pc != f.End-2 { - t.Fatalf("wrong pc expected %#v got %#v", f.End-2, pc-1) - } - return nil - }); err != nil { - t.Fatal(err) - } - pc, err = p.PC() - if err != nil { - t.Fatal(err) - } - fn = p.goSymTable.PCToFunc(pc) - if fn == nil { - t.Fatalf("Could not find func for PC: %#v", pc) - } - if fn.Name != "main.main" { - t.Fatal("Program stopped at incorrect place") - } - }) -} diff --git a/proctl/registers.go b/proctl/registers.go index ab9f2700..522870ca 100644
--- a/proctl/registers.go +++ b/proctl/registers.go @@ -9,6 +9,7 @@ import "fmt" type Registers interface { PC() uint64 SP() uint64 + CX() uint64 SetPC(*ThreadContext, uint64) error } diff --git a/proctl/registers_darwin_amd64.go b/proctl/registers_darwin_amd64.go index 80feb3da..29630277 100644 --- a/proctl/registers_darwin_amd64.go +++ b/proctl/registers_darwin_amd64.go @@ -5,7 +5,7 @@ import "C" import "fmt" type Regs struct { - pc, sp uint64 + pc, sp, cx uint64 } func (r *Regs) PC() uint64 { @@ -16,6 +16,10 @@ func (r *Regs) SP() uint64 { return r.sp } +func (r *Regs) CX() uint64 { + return r.cx +} + func (r *Regs) SetPC(thread *ThreadContext, pc uint64) error { kret := C.set_pc(thread.os.thread_act, C.uint64_t(pc)) if kret != C.KERN_SUCCESS { @@ -30,6 +34,6 @@ func registers(thread *ThreadContext) (Registers, error) { if kret != C.KERN_SUCCESS { return nil, fmt.Errorf("could not get registers") } - regs := &Regs{pc: uint64(state.__rip), sp: uint64(state.__rsp)} + regs := &Regs{pc: uint64(state.__rip), sp: uint64(state.__rsp), cx: uint64(state.__rcx)} return regs, nil } diff --git a/proctl/registers_linux_amd64.go b/proctl/registers_linux_amd64.go index 2a2793b6..0cb0a487 100644 --- a/proctl/registers_linux_amd64.go +++ b/proctl/registers_linux_amd64.go @@ -14,6 +14,10 @@ func (r *Regs) SP() uint64 { return r.regs.Rsp } +func (r *Regs) CX() uint64 { + return r.regs.Rcx +} + func (r *Regs) SetPC(thread *ThreadContext, pc uint64) error { r.regs.SetPC(pc) return sys.PtraceSetRegs(thread.Id, r.regs) diff --git a/proctl/threads.go b/proctl/threads.go index 3840f136..bc991b13 100644 --- a/proctl/threads.go +++ b/proctl/threads.go @@ -85,47 +85,6 @@ func (thread *ThreadContext) Step() (err error) { return nil } -// Call a function named `name`. This is currently _NOT_ safe. 
-func (thread *ThreadContext) CallFn(name string, fn func() error) error { - f := thread.dbp.goSymTable.LookupFunc(name) - if f == nil { - return fmt.Errorf("could not find function %s", name) - } - - // Set breakpoint at the end of the function (before it returns). - bp, err := thread.TempBreak(f.End - 2) - if err != nil { - return err - } - defer thread.dbp.Clear(bp.Addr) - - regs, err := thread.saveRegisters() - if err != nil { - return err - } - - previousFrame := make([]byte, f.FrameSize) - frameSize := uintptr(regs.SP() + uint64(f.FrameSize)) - if _, err := readMemory(thread, frameSize, previousFrame); err != nil { - return err - } - defer func() { writeMemory(thread, frameSize, previousFrame) }() - - if err = thread.SetPC(f.Entry); err != nil { - return err - } - defer thread.restoreRegisters() - if err = thread.Continue(); err != nil { - return err - } - th, err := thread.dbp.trapWait(-1) - if err != nil { - return err - } - th.CurrentBreakpoint = nil - return fn() -} - // Set breakpoint using this thread. func (thread *ThreadContext) Break(addr uint64) (*BreakPoint, error) { return thread.dbp.setBreakpoint(thread.Id, addr, false) @@ -228,7 +187,7 @@ func (thread *ThreadContext) next(curpc uint64, fde *frame.FrameDescriptionEntry if !covered { fn := thread.dbp.goSymTable.PCToFunc(ret) if fn != nil && fn.Name == "runtime.goexit" { - g, err := thread.curG() + g, err := thread.getG() if err != nil { return err } @@ -275,15 +234,77 @@ func (thread *ThreadContext) SetPC(pc uint64) error { return regs.SetPC(thread, pc) } -func (thread *ThreadContext) curG() (*G, error) { - var g *G - err := thread.CallFn("runtime.getg", func() error { - regs, err := thread.Registers() - if err != nil { - return err +// Returns information on the G (goroutine) that is executing on this thread. +// +// The G structure for a thread is stored in thread local memory. 
Execute instructions +// that move the *G structure into a CPU register (we use rcx here), and then grab +// the new registers and parse the G structure. +// +// We cannot simply use the allg linked list in order to find the M that represents +// the given OS thread and follow its G pointer because on Darwin mach ports are not +// universal, so our port for this thread would not map to the `id` attribute of the M +// structure. Also, when linked against libc, Go prefers the libc version of clone as +// opposed to the runtime version. This has the consequence of not setting M.id for +// any thread, regardless of OS. +// +// In order to get around all this craziness, we write the instructions to retrieve the G +// structure running on this thread (which is stored in thread local memory) into the +// current instruction stream. The instructions are obviously arch/os dependent, as they +// vary on how thread local storage is implemented, which segment register is used and +// what the offset into thread local storage is. +func (thread *ThreadContext) getG() (g *G, err error) { + var pcInt uint64 + pcInt, err = thread.PC() + if err != nil { + return + } + pc := uintptr(pcInt) + // Read original instructions. + originalInstructions := make([]byte, len(thread.dbp.arch.CurgInstructions())) + if _, err = readMemory(thread, pc, originalInstructions); err != nil { + return + } + // The injected code intentionally clobbers registers, so save them + // off so we can restore them later. Do this before touching memory: + // a failure here must not leave the instruction stream patched. + if _, err = thread.saveRegisters(); err != nil { + return + } + // Write new instructions. + if _, err = writeMemory(thread, pc, thread.dbp.arch.CurgInstructions()); err != nil { + return + } + // Ensure original instructions and PC are both restored. + defer func() { + // Do not shadow previous error, if there was one. + originalErr := err + // Restore the original instructions and register contents.
+ if _, err = writeMemory(thread, pc, originalInstructions); err != nil { + return } - g, err = parseG(thread, regs.SP()+uint64(thread.dbp.arch.PtrSize())) - return err - }) - return g, err + if err = thread.restoreRegisters(); err != nil { + return + } + err = originalErr + return + }() + // Execute new instructions. + if err = thread.resume(); err != nil { + return + } + // Set the halt flag so that trapWait will ignore the fact that + // we hit a breakpoint that isn't captured in our list of + // known breakpoints. + thread.dbp.halt = true + defer func(dbp *DebuggedProcess) { dbp.halt = false }(thread.dbp) + if _, err = thread.dbp.trapWait(-1); err != nil { + return + } + // Grab *G from RCX. + regs, err := thread.Registers() + if err != nil { + return nil, err + } + g, err = parseG(thread, regs.CX(), false) + return } diff --git a/proctl/threads_darwin.c b/proctl/threads_darwin.c index 4ae4bfe5..3879ac8e 100644 --- a/proctl/threads_darwin.c +++ b/proctl/threads_darwin.c @@ -3,9 +3,6 @@ int write_memory(mach_port_name_t task, mach_vm_address_t addr, void *d, mach_msg_type_number_t len) { kern_return_t kret; - pointer_t data; - memcpy((void *)&data, d, len); - vm_region_submap_short_info_data_64_t info; mach_msg_type_number_t count = VM_REGION_SUBMAP_SHORT_INFO_COUNT_64; mach_vm_size_t l = len; @@ -19,7 +16,7 @@ write_memory(mach_port_name_t task, mach_vm_address_t addr, void *d, mach_msg_ty kret = mach_vm_protect(task, addr, len, FALSE, VM_PROT_WRITE|VM_PROT_COPY|VM_PROT_READ); if (kret != KERN_SUCCESS) return -1; - kret = mach_vm_write((vm_map_t)task, addr, (vm_offset_t)&data, len); + kret = mach_vm_write((vm_map_t)task, addr, (vm_offset_t)d, len); if (kret != KERN_SUCCESS) return -1; // Restore virtual memory permissions diff --git a/proctl/threads_darwin.go b/proctl/threads_darwin.go index f6aee81b..7cd18453 100644 --- a/proctl/threads_darwin.go +++ b/proctl/threads_darwin.go @@ -57,12 +57,14 @@ func (t *ThreadContext) blocked() bool { } func 
writeMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } var ( vm_data = unsafe.Pointer(&data[0]) vm_addr = C.mach_vm_address_t(addr) length = C.mach_msg_type_number_t(len(data)) ) - if ret := C.write_memory(thread.dbp.os.task, vm_addr, vm_data, length); ret < 0 { return 0, fmt.Errorf("could not write memory") } @@ -70,6 +72,9 @@ func writeMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) } func readMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } var ( vm_data = unsafe.Pointer(&data[0]) vm_addr = C.mach_vm_address_t(addr) diff --git a/proctl/threads_darwin.h b/proctl/threads_darwin.h index eb4d6b44..93310f71 100644 --- a/proctl/threads_darwin.h +++ b/proctl/threads_darwin.h @@ -1,3 +1,4 @@ +#include #include #include #include diff --git a/proctl/threads_linux.go b/proctl/threads_linux.go index f7e5dca7..aa4dfd05 100644 --- a/proctl/threads_linux.go +++ b/proctl/threads_linux.go @@ -62,9 +62,15 @@ func (thread *ThreadContext) restoreRegisters() error { } func writeMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } return sys.PtracePokeData(thread.Id, addr, data) } func readMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { + if len(data) == 0 { + return 0, nil + } return sys.PtracePeekData(thread.Id, addr, data) } diff --git a/proctl/variables.go b/proctl/variables.go index 93e2c166..36852b02 100644 --- a/proctl/variables.go +++ b/proctl/variables.go @@ -79,15 +79,20 @@ func (ng NoGError) Error() string { return fmt.Sprintf("no G executing on thread %d", ng.tid) } -func parseG(thread *ThreadContext, addr uint64) (*G, error) { - gaddrbytes, err := thread.readMemory(uintptr(addr), thread.dbp.arch.PtrSize()) - if err != nil { - return nil, fmt.Errorf("error derefing *G %s", err) - } - initialInstructions := append([]byte{op.DW_OP_addr}, 
gaddrbytes...) - gaddr := binary.LittleEndian.Uint64(gaddrbytes) - if gaddr == 0 { - return nil, NoGError{tid: thread.Id} +func parseG(thread *ThreadContext, gaddr uint64, deref bool) (*G, error) { + initialInstructions := make([]byte, thread.dbp.arch.PtrSize()+1) + initialInstructions[0] = op.DW_OP_addr + binary.LittleEndian.PutUint64(initialInstructions[1:], gaddr) + if deref { + gaddrbytes, err := thread.readMemory(uintptr(gaddr), thread.dbp.arch.PtrSize()) + if err != nil { + return nil, fmt.Errorf("error derefing *G %s", err) + } + initialInstructions = append([]byte{op.DW_OP_addr}, gaddrbytes...) + gaddr = binary.LittleEndian.Uint64(gaddrbytes) + if gaddr == 0 { + return nil, NoGError{tid: thread.Id} + } } rdr := thread.dbp.DwarfReader() @@ -97,9 +102,6 @@ func parseG(thread *ThreadContext, addr uint64) (*G, error) { return nil, err } - // Let's parse all of the members we care about in order so that - // we don't have to spend any extra time seeking. - // Parse defer deferAddr, err := rdr.AddrForMember("_defer", initialInstructions) if err != nil { @@ -109,7 +111,7 @@ func parseG(thread *ThreadContext, addr uint64) (*G, error) { // Dereference *defer pointer deferAddrBytes, err := thread.readMemory(uintptr(deferAddr), thread.dbp.arch.PtrSize()) if err != nil { - return nil, fmt.Errorf("error derefing *G %s", err) + return nil, fmt.Errorf("error derefing defer %s", err) } if binary.LittleEndian.Uint64(deferAddrBytes) != 0 { initialDeferInstructions := append([]byte{op.DW_OP_addr}, deferAddrBytes...) @@ -126,11 +128,16 @@ func parseG(thread *ThreadContext, addr uint64) (*G, error) { if err != nil { return nil, err } - err = rdr.SeekToEntry(entry) - if err != nil { - return nil, err - } } + + // Let's parse all of the members we care about in order so that + // we don't have to spend any extra time seeking. 
+ + err = rdr.SeekToEntry(entry) + if err != nil { + return nil, err + } + // Parse sched schedAddr, err := rdr.AddrForMember("sched", initialInstructions) if err != nil { @@ -530,14 +537,14 @@ func (thread *ThreadContext) readString(addr uintptr) (string, error) { // read len val, err := thread.readMemory(addr+uintptr(thread.dbp.arch.PtrSize()), thread.dbp.arch.PtrSize()) if err != nil { - return "", err + return "", fmt.Errorf("could not read string len %s", err) } strlen := int(binary.LittleEndian.Uint64(val)) // read addr val, err = thread.readMemory(addr, thread.dbp.arch.PtrSize()) if err != nil { - return "", err + return "", fmt.Errorf("could not read string pointer %s", err) } addr = uintptr(binary.LittleEndian.Uint64(val)) if addr == 0 { @@ -546,7 +553,7 @@ func (thread *ThreadContext) readString(addr uintptr) (string, error) { val, err = thread.readMemory(addr, strlen) if err != nil { - return "", err + return "", fmt.Errorf("could not read string at %#v due to %s", addr, err) } return *(*string)(unsafe.Pointer(&val)), nil