GP-4209: GhidraTime-MSTTD integration. Type hints for (most) Python agents.

This commit is contained in:
Dan 2025-03-24 18:28:07 +00:00
parent deb49d5322
commit 21a1602579
93 changed files with 6453 additions and 4118 deletions

View file

@ -17,5 +17,6 @@ src/main/py/MANIFEST.in||GHIDRA||||END|
src/main/py/README.md||GHIDRA||||END| src/main/py/README.md||GHIDRA||||END|
src/main/py/pyproject.toml||GHIDRA||||END| src/main/py/pyproject.toml||GHIDRA||||END|
src/main/py/src/ghidradbg/dbgmodel/DbgModel.idl||GHIDRA||||END| src/main/py/src/ghidradbg/dbgmodel/DbgModel.idl||GHIDRA||||END|
src/main/py/src/ghidradbg/py.typed||GHIDRA||||END|
src/main/py/src/ghidradbg/schema.xml||GHIDRA||||END| src/main/py/src/ghidradbg/schema.xml||GHIDRA||||END|
src/main/py/src/ghidradbg/schema_exdi.xml||GHIDRA||||END| src/main/py/src/ghidradbg/schema_exdi.xml||GHIDRA||||END|

View file

@ -57,6 +57,8 @@ def main():
cmd.ghidra_trace_open(target, start_trace=False) cmd.ghidra_trace_open(target, start_trace=False)
# TODO: HACK # TODO: HACK
# Also, the wait() must precede sync_enable() or else PROC_STATE will
# contain the wrong PID, and later events will get snuffed
try: try:
dbg.wait() dbg.wait()
except KeyboardInterrupt as ki: except KeyboardInterrupt as ki:
@ -65,7 +67,8 @@ def main():
cmd.ghidra_trace_start(target) cmd.ghidra_trace_start(target)
cmd.ghidra_trace_sync_enable() cmd.ghidra_trace_sync_enable()
on_state_changed(DbgEng.DEBUG_CES_EXECUTION_STATUS, DbgEng.DEBUG_STATUS_BREAK) on_state_changed(DbgEng.DEBUG_CES_EXECUTION_STATUS,
DbgEng.DEBUG_STATUS_BREAK)
cmd.repl() cmd.repl()

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "ghidradbg" name = "ghidradbg"
version = "11.3" version = "11.4"
authors = [ authors = [
{ name="Ghidra Development Team" }, { name="Ghidra Development Team" },
] ]
@ -17,7 +17,7 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"ghidratrace==11.3", "ghidratrace==11.4",
"pybag>=2.2.12" "pybag>=2.2.12"
] ]
@ -26,7 +26,7 @@ dependencies = [
"Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues" "Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues"
[tool.setuptools.package-data] [tool.setuptools.package-data]
ghidradbg = ["*.tlb"] ghidradbg = ["*.tlb", "py.typed"]
[tool.setuptools] [tool.setuptools]
include-package-data = true include-package-data = true

View file

@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from typing import Dict, List, Optional, Tuple
from ghidratrace.client import Address, RegVal from ghidratrace.client import Address, RegVal
from pybag import pydbg from pybag import pydbg
from . import util from . import util
language_map = { language_map: Dict[str, List[str]] = {
'AARCH64': ['AARCH64:LE:64:AppleSilicon'], 'AARCH64': ['AARCH64:LE:64:AppleSilicon'],
'ARM': ['ARM:LE:32:v8'], 'ARM': ['ARM:LE:32:v8'],
'Itanium': [], 'Itanium': [],
@ -31,25 +33,25 @@ language_map = {
'SH4': ['SuperH4:LE:32:default'], 'SH4': ['SuperH4:LE:32:default'],
} }
data64_compiler_map = { data64_compiler_map: Dict[Optional[str], str] = {
None: 'pointer64', None: 'pointer64',
} }
x86_compiler_map = { x86_compiler_map: Dict[Optional[str], str] = {
'windows': 'windows', 'windows': 'windows',
'Cygwin': 'windows', 'Cygwin': 'windows',
'default': 'windows', 'default': 'windows',
} }
default_compiler_map = { default_compiler_map: Dict[Optional[str], str] = {
'windows': 'default', 'windows': 'default',
} }
windows_compiler_map = { windows_compiler_map: Dict[Optional[str], str] = {
'windows': 'windows', 'windows': 'windows',
} }
compiler_map = { compiler_map : Dict[str, Dict[Optional[str], str]]= {
'DATA:BE:64:default': data64_compiler_map, 'DATA:BE:64:default': data64_compiler_map,
'DATA:LE:64:default': data64_compiler_map, 'DATA:LE:64:default': data64_compiler_map,
'x86:LE:32:default': x86_compiler_map, 'x86:LE:32:default': x86_compiler_map,
@ -62,11 +64,11 @@ compiler_map = {
} }
def get_arch(): def get_arch() -> str:
try: try:
type = util.dbg.get_actual_processor_type() type = util.dbg.get_actual_processor_type()
except Exception: except Exception as e:
print("Error getting actual processor type.") print(f"Error getting actual processor type: {e}")
return "Unknown" return "Unknown"
if type is None: if type is None:
return "x86_64" return "x86_64"
@ -129,14 +131,14 @@ def get_arch():
return "Unknown" return "Unknown"
def get_endian(): def get_endian() -> str:
parm = util.get_convenience_variable('endian') parm = util.get_convenience_variable('endian')
if parm != 'auto': if parm != 'auto':
return parm return parm
return 'little' return 'little'
def get_osabi(): def get_osabi() -> str:
parm = util.get_convenience_variable('osabi') parm = util.get_convenience_variable('osabi')
if not parm in ['auto', 'default']: if not parm in ['auto', 'default']:
return parm return parm
@ -150,7 +152,7 @@ def get_osabi():
return "windows" return "windows"
def compute_ghidra_language(): def compute_ghidra_language() -> str:
# First, check if the parameter is set # First, check if the parameter is set
lang = util.get_convenience_variable('ghidra-language') lang = util.get_convenience_variable('ghidra-language')
if lang != 'auto': if lang != 'auto':
@ -175,7 +177,7 @@ def compute_ghidra_language():
return 'DATA' + lebe + '64:default' return 'DATA' + lebe + '64:default'
def compute_ghidra_compiler(lang): def compute_ghidra_compiler(lang: str) -> str:
# First, check if the parameter is set # First, check if the parameter is set
comp = util.get_convenience_variable('ghidra-compiler') comp = util.get_convenience_variable('ghidra-compiler')
if comp != 'auto': if comp != 'auto':
@ -197,7 +199,7 @@ def compute_ghidra_compiler(lang):
return 'default' return 'default'
def compute_ghidra_lcsp(): def compute_ghidra_lcsp() -> Tuple[str, str]:
lang = compute_ghidra_language() lang = compute_ghidra_language()
comp = compute_ghidra_compiler(lang) comp = compute_ghidra_compiler(lang)
return lang, comp return lang, comp
@ -205,10 +207,10 @@ def compute_ghidra_lcsp():
class DefaultMemoryMapper(object): class DefaultMemoryMapper(object):
def __init__(self, defaultSpace): def __init__(self, defaultSpace: str) -> None:
self.defaultSpace = defaultSpace self.defaultSpace = defaultSpace
def map(self, proc: int, offset: int): def map(self, proc: int, offset: int) -> Tuple[str, Address]:
space = self.defaultSpace space = self.defaultSpace
return self.defaultSpace, Address(space, offset) return self.defaultSpace, Address(space, offset)
@ -220,10 +222,10 @@ class DefaultMemoryMapper(object):
DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram') DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram')
memory_mappers = {} memory_mappers: Dict[str, DefaultMemoryMapper] = {}
def compute_memory_mapper(lang): def compute_memory_mapper(lang: str) -> DefaultMemoryMapper:
if not lang in memory_mappers: if not lang in memory_mappers:
return DEFAULT_MEMORY_MAPPER return DEFAULT_MEMORY_MAPPER
return memory_mappers[lang] return memory_mappers[lang]
@ -231,16 +233,15 @@ def compute_memory_mapper(lang):
class DefaultRegisterMapper(object): class DefaultRegisterMapper(object):
def __init__(self, byte_order): def __init__(self, byte_order: str) -> None:
if not byte_order in ['big', 'little']: if not byte_order in ['big', 'little']:
raise ValueError("Invalid byte_order: {}".format(byte_order)) raise ValueError("Invalid byte_order: {}".format(byte_order))
self.byte_order = byte_order self.byte_order = byte_order
self.union_winners = {}
def map_name(self, proc, name): def map_name(self, proc: int, name: str):
return name return name
def map_value(self, proc, name, value): def map_value(self, proc: int, name: str, value: int):
try: try:
# TODO: this seems half-baked # TODO: this seems half-baked
av = value.to_bytes(8, "big") av = value.to_bytes(8, "big")
@ -249,10 +250,10 @@ class DefaultRegisterMapper(object):
.format(name, value, type(value))) .format(name, value, type(value)))
return RegVal(self.map_name(proc, name), av) return RegVal(self.map_name(proc, name), av)
def map_name_back(self, proc, name): def map_name_back(self, proc: int, name: str) -> str:
return name return name
def map_value_back(self, proc, name, value): def map_value_back(self, proc: int, name: str, value: bytes):
return RegVal(self.map_name_back(proc, name), value) return RegVal(self.map_name_back(proc, name), value)

View file

@ -42,4 +42,3 @@ class ModelMethod(object):
return None return None
return mo.ModelObject(result) return mo.ModelObject(result)

View file

@ -342,7 +342,6 @@ class ModelObject(object):
# print(f"{element} not found") # print(f"{element} not found")
return next return next
def GetValue(self): def GetValue(self):
value = self.GetIntrinsicValue() value = self.GetIntrinsicValue()
if value is None: if value is None:
@ -350,4 +349,3 @@ class ModelObject(object):
if value.vt == 0xd: if value.vt == 0xd:
return None return None
return value.value return value.value

View file

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from ghidradbg import arch, commands, util
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import Client, Address, AddressRange, TraceObject from ghidratrace.client import Client, Address, AddressRange, Trace, TraceObject
PAGE_SIZE = 4096 PAGE_SIZE = 4096
from ghidradbg import arch, commands, util
SESSION_PATH = 'Sessions[0]' SESSION_PATH = 'Sessions[0]'
PROCESSES_PATH = SESSION_PATH + '.ExdiProcesses' PROCESSES_PATH = SESSION_PATH + '.ExdiProcesses'
@ -42,69 +42,60 @@ SECTIONS_ADD_PATTERN = '.Sections'
SECTION_KEY_PATTERN = '[{secname}]' SECTION_KEY_PATTERN = '[{secname}]'
SECTION_ADD_PATTERN = SECTIONS_ADD_PATTERN + SECTION_KEY_PATTERN SECTION_ADD_PATTERN = SECTIONS_ADD_PATTERN + SECTION_KEY_PATTERN
@util.dbg.eng_thread @util.dbg.eng_thread
def ghidra_trace_put_processes_exdi(): def ghidra_trace_put_processes_exdi() -> None:
""" """Put the list of processes into the trace's processes list."""
Put the list of processes into the trace's processes list.
"""
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
commands.STATE.require_tx() trace, tx = commands.STATE.require_tx()
with commands.STATE.client.batch() as b: with trace.client.batch() as b:
put_processes_exdi(commands.STATE, radix) put_processes_exdi(trace, radix)
@util.dbg.eng_thread @util.dbg.eng_thread
def ghidra_trace_put_regions_exdi(): def ghidra_trace_put_regions_exdi() -> None:
""" """Read the memory map, if applicable, and write to the trace's Regions."""
Read the memory map, if applicable, and write to the trace's Regions
"""
commands.STATE.require_tx() trace, tx = commands.STATE.require_tx()
with commands.STATE.client.batch() as b: with trace.client.batch() as b:
put_regions_exdi(commands.STATE) put_regions_exdi(trace)
@util.dbg.eng_thread @util.dbg.eng_thread
def ghidra_trace_put_kmodules_exdi(): def ghidra_trace_put_kmodules_exdi() -> None:
""" """Gather object files, if applicable, and write to the trace's Modules."""
Gather object files, if applicable, and write to the trace's Modules
"""
commands.STATE.require_tx() trace, tx = commands.STATE.require_tx()
with commands.STATE.client.batch() as b: with trace.client.batch() as b:
put_kmodules_exdi(commands.STATE) put_kmodules_exdi(trace)
@util.dbg.eng_thread @util.dbg.eng_thread
def ghidra_trace_put_threads_exdi(pid): def ghidra_trace_put_threads_exdi(pid: int) -> None:
""" """Put the current process's threads into the Ghidra trace."""
Put the current process's threads into the Ghidra trace
"""
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
commands.STATE.require_tx() trace, tx = commands.STATE.require_tx()
with commands.STATE.client.batch() as b: with trace.client.batch() as b:
put_threads_exdi(commands.STATE, pid, radix) put_threads_exdi(trace, pid, radix)
@util.dbg.eng_thread @util.dbg.eng_thread
def ghidra_trace_put_all_exdi(): def ghidra_trace_put_all_exdi() -> None:
""" """Put everything currently selected into the Ghidra trace."""
Put everything currently selected into the Ghidra trace
"""
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
commands.STATE.require_tx() trace, tx = commands.STATE.require_tx()
with commands.STATE.client.batch() as b: with trace.client.batch() as b:
if util.dbg.use_generics == False: if util.dbg.use_generics == False:
put_processes_exdi(commands.STATE, radix) put_processes_exdi(trace, radix)
put_regions_exdi(commands.STATE) put_regions_exdi(trace)
put_kmodules_exdi(commands.STATE) put_kmodules_exdi(trace)
@util.dbg.eng_thread @util.dbg.eng_thread
def put_processes_exdi(state, radix): def put_processes_exdi(trace: Trace, radix: int) -> None:
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
keys = [] keys = []
result = util.dbg._base.cmd("!process 0 0") result = util.dbg._base.cmd("!process 0 0")
@ -118,7 +109,7 @@ def put_processes_exdi(state, radix):
id = int(l2[3], 16) id = int(l2[3], 16)
name = l4[1] name = l4[1]
ppath = PROCESS_PATTERN.format(pid=id) ppath = PROCESS_PATTERN.format(pid=id)
procobj = state.trace.create_object(ppath) procobj = trace.create_object(ppath)
keys.append(PROCESS_KEY_PATTERN.format(pid=id)) keys.append(PROCESS_KEY_PATTERN.format(pid=id))
pidstr = ('0x{:x}' if radix == pidstr = ('0x{:x}' if radix ==
16 else '0{:o}' if radix == 8 else '{}').format(id) 16 else '0{:o}' if radix == 8 else '{}').format(id)
@ -126,22 +117,22 @@ def put_processes_exdi(state, radix):
procobj.set_value('Name', name) procobj.set_value('Name', name)
procobj.set_value('_display', '[{}] {}'.format(pidstr, name)) procobj.set_value('_display', '[{}] {}'.format(pidstr, name))
(base, addr) = commands.map_address(int(l1[1], 16)) (base, addr) = commands.map_address(int(l1[1], 16))
procobj.set_value('EPROCESS', addr, schema="ADDRESS") procobj.set_value('EPROCESS', addr, schema=sch.ADDRESS)
(base, addr) = commands.map_address(int(l2[5], 16)) (base, addr) = commands.map_address(int(l2[5], 16))
procobj.set_value('PEB', addr, schema="ADDRESS") procobj.set_value('PEB', addr, schema=sch.ADDRESS)
(base, addr) = commands.map_address(int(l3[1], 16)) (base, addr) = commands.map_address(int(l3[1], 16))
procobj.set_value('DirBase', addr, schema="ADDRESS") procobj.set_value('DirBase', addr, schema=sch.ADDRESS)
(base, addr) = commands.map_address(int(l3[3], 16)) (base, addr) = commands.map_address(int(l3[3], 16))
procobj.set_value('ObjectTable', addr, schema="ADDRESS") procobj.set_value('ObjectTable', addr, schema=sch.ADDRESS)
# procobj.set_value('ObjectTable', l3[3]) # procobj.set_value('ObjectTable', l3[3])
tcobj = state.trace.create_object(ppath+".Threads") tcobj = trace.create_object(ppath+".Threads")
procobj.insert() procobj.insert()
tcobj.insert() tcobj.insert()
state.trace.proxy_object_path(PROCESSES_PATH).retain_values(keys) trace.proxy_object_path(PROCESSES_PATH).retain_values(keys)
@util.dbg.eng_thread @util.dbg.eng_thread
def put_regions_exdi(state): def put_regions_exdi(trace: Trace) -> None:
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
keys = [] keys = []
result = util.dbg._base.cmd("!address") result = util.dbg._base.cmd("!address")
@ -165,8 +156,8 @@ def put_regions_exdi(state):
rng = saddr.extend(int(length, 16)) rng = saddr.extend(int(length, 16))
rpath = REGION_PATTERN.format(start=start) rpath = REGION_PATTERN.format(start=start)
keys.append(REGION_KEY_PATTERN.format(start=start)) keys.append(REGION_KEY_PATTERN.format(start=start))
regobj = state.trace.create_object(rpath) regobj = trace.create_object(rpath)
regobj.set_value('Range', rng, schema="RANGE") regobj.set_value('Range', rng, schema=sch.RANGE)
regobj.set_value('Size', length) regobj.set_value('Size', length)
regobj.set_value('Type', type) regobj.set_value('Type', type)
regobj.set_value('_readable', True) regobj.set_value('_readable', True)
@ -175,11 +166,11 @@ def put_regions_exdi(state):
regobj.set_value('_display', '[{}] {}'.format( regobj.set_value('_display', '[{}] {}'.format(
start, type)) start, type))
regobj.insert() regobj.insert()
state.trace.proxy_object_path(MEMORY_PATH).retain_values(keys) trace.proxy_object_path(MEMORY_PATH).retain_values(keys)
@util.dbg.eng_thread @util.dbg.eng_thread
def put_kmodules_exdi(state): def put_kmodules_exdi(trace: Trace) -> None:
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
keys = [] keys = []
result = util.dbg._base.cmd("lm") result = util.dbg._base.cmd("lm")
@ -203,19 +194,20 @@ def put_kmodules_exdi(state):
rng = saddr.extend(sz) rng = saddr.extend(sz)
mpath = KMODULE_PATTERN.format(modpath=sname) mpath = KMODULE_PATTERN.format(modpath=sname)
keys.append(KMODULE_KEY_PATTERN.format(modpath=sname)) keys.append(KMODULE_KEY_PATTERN.format(modpath=sname))
modobj = commands.STATE.trace.create_object(mpath) modobj = trace.create_object(mpath)
modobj.set_value('Name', name) modobj.set_value('Name', name)
modobj.set_value('Base', saddr, schema="ADDRESS") modobj.set_value('Base', saddr, schema=sch.ADDRESS)
modobj.set_value('Range', rng, schema="RANGE") modobj.set_value('Range', rng, schema=sch.RANGE)
modobj.set_value('Size', hex(sz)) modobj.set_value('Size', hex(sz))
modobj.insert() modobj.insert()
state.trace.proxy_object_path(KMODULES_PATH).retain_values(keys) trace.proxy_object_path(KMODULES_PATH).retain_values(keys)
@util.dbg.eng_thread @util.dbg.eng_thread
def put_threads_exdi(state, pid, radix): def put_threads_exdi(trace: Trace, pid: int, radix: int) -> None:
radix = util.get_convenience_variable('output-radix') radix = util.get_convenience_variable('output-radix')
pidstr = ('0x{:x}' if radix == 16 else '0{:o}' if radix == 8 else '{}').format(pid) pidstr = ('0x{:x}' if radix == 16 else '0{:o}' if radix ==
8 else '{}').format(pid)
keys = [] keys = []
result = util.dbg._base.cmd("!process "+hex(pid)+" 4") result = util.dbg._base.cmd("!process "+hex(pid)+" 4")
lines = result.split("\n") lines = result.split("\n")
@ -229,9 +221,9 @@ def put_threads_exdi(state, pid, radix):
tidstr = ('0x{:x}' if radix == tidstr = ('0x{:x}' if radix ==
16 else '0{:o}' if radix == 8 else '{}').format(tid) 16 else '0{:o}' if radix == 8 else '{}').format(tid)
tpath = THREAD_PATTERN.format(pid=pid, tnum=tid) tpath = THREAD_PATTERN.format(pid=pid, tnum=tid)
tobj = commands.STATE.trace.create_object(tpath) tobj = trace.create_object(tpath)
keys.append(THREAD_KEY_PATTERN.format(tnum=tidstr)) keys.append(THREAD_KEY_PATTERN.format(tnum=tidstr))
tobj = state.trace.create_object(tpath) tobj = trace.create_object(tpath)
tobj.set_value('PID', pidstr) tobj.set_value('PID', pidstr)
tobj.set_value('TID', tidstr) tobj.set_value('TID', tidstr)
tobj.set_value('_display', '[{}]'.format(tidstr)) tobj.set_value('_display', '[{}]'.format(tidstr))
@ -240,5 +232,5 @@ def put_threads_exdi(state, pid, radix):
tobj.set_value('Win32Thread', fields[7]) tobj.set_value('Win32Thread', fields[7])
tobj.set_value('State', fields[8]) tobj.set_value('State', fields[8])
tobj.insert() tobj.insert()
commands.STATE.trace.proxy_object_path( trace.proxy_object_path(THREADS_PATTERN.format(
THREADS_PATTERN.format(pid=pidstr)).retain_values(keys) pid=pidstr)).retain_values(keys)

View file

@ -16,15 +16,17 @@
import re import re
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import MethodRegistry, ParamDesc, Address, AddressRange from ghidratrace.client import (MethodRegistry, ParamDesc, Address,
AddressRange, TraceObject)
from ghidradbg import util, commands, methods from ghidradbg import util, commands, methods
from ghidradbg.methods import REGISTRY, SESSIONS_PATTERN, SESSION_PATTERN, extre from ghidradbg.methods import REGISTRY, SESSIONS_PATTERN, SESSION_PATTERN, extre
from . import exdi_commands from . import exdi_commands
XPROCESSES_PATTERN = extre(SESSION_PATTERN, '\.ExdiProcesses') XPROCESSES_PATTERN = extre(SESSION_PATTERN, '\\.ExdiProcesses')
XPROCESS_PATTERN = extre(XPROCESSES_PATTERN, '\[(?P<procnum>\\d*)\]') XPROCESS_PATTERN = extre(XPROCESSES_PATTERN, '\\[(?P<procnum>\\d*)\\]')
XTHREADS_PATTERN = extre(XPROCESS_PATTERN, '\.Threads') XTHREADS_PATTERN = extre(XPROCESS_PATTERN, '\\.Threads')
def find_pid_by_pattern(pattern, object, err_msg): def find_pid_by_pattern(pattern, object, err_msg):
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
@ -38,16 +40,23 @@ def find_pid_by_obj(object):
return find_pid_by_pattern(XTHREADS_PATTERN, object, "an ExdiThreadsContainer") return find_pid_by_pattern(XTHREADS_PATTERN, object, "an ExdiThreadsContainer")
class ExdiProcessContainer(TraceObject):
pass
class ExdiThreadContainer(TraceObject):
pass
@REGISTRY.method(action='refresh', display="Refresh Target Processes") @REGISTRY.method(action='refresh', display="Refresh Target Processes")
def refresh_exdi_processes(node: sch.Schema('ExdiProcessContainer')): def refresh_exdi_processes(node: ExdiProcessContainer) -> None:
"""Refresh the list of processes in the target kernel.""" """Refresh the list of processes in the target kernel."""
with commands.open_tracked_tx('Refresh Processes'): with commands.open_tracked_tx('Refresh Processes'):
exdi_commands.ghidra_trace_put_processes_exdi() exdi_commands.ghidra_trace_put_processes_exdi()
@REGISTRY.method(action='refresh', display="Refresh Process Threads") @REGISTRY.method(action='refresh', display="Refresh Process Threads")
def refresh_exdi_threads(node: sch.Schema('ExdiThreadContainer')): def refresh_exdi_threads(node: ExdiThreadContainer) -> None:
"""Refresh the list of threads in the process.""" """Refresh the list of threads in the process."""
pid = find_pid_by_obj(node) pid = find_pid_by_obj(node)
with commands.open_tracked_tx('Refresh Threads'): with commands.open_tracked_tx('Refresh Threads'):

View file

@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from bisect import bisect_left, bisect_right
from dataclasses import dataclass, field
import functools import functools
import sys import sys
import threading import threading
import time import time
import traceback import traceback
from typing import Any, Callable, Collection, Dict, Optional, TypeVar, cast
from comtypes.hresult import S_OK from comtypes.hresult import S_OK
from pybag import pydbg from pybag import pydbg
@ -26,6 +29,8 @@ from pybag.dbgeng import exception
from pybag.dbgeng.callbacks import EventHandler from pybag.dbgeng.callbacks import EventHandler
from pybag.dbgeng.idebugbreakpoint import DebugBreakpoint from pybag.dbgeng.idebugbreakpoint import DebugBreakpoint
from ghidratrace.client import Schedule
from . import commands, util from . import commands, util
from .exdi import exdi_commands from .exdi import exdi_commands
@ -33,36 +38,33 @@ from .exdi import exdi_commands
ALL_EVENTS = 0xFFFF ALL_EVENTS = 0xFFFF
class HookState(object): @dataclass(frozen=False)
__slots__ = ('installed', 'mem_catchpoint') class HookState:
installed = False
def __init__(self): mem_catchpoint = None
self.installed = False
self.mem_catchpoint = None
class ProcessState(object): @dataclass(frozen=False)
__slots__ = ('first', 'regions', 'modules', 'threads', class ProcessState:
'breaks', 'watches', 'visited', 'waiting') first = True
def __init__(self):
self.first = True
# For things we can detect changes to between stops # For things we can detect changes to between stops
self.regions = False regions = False
self.modules = False modules = False
self.threads = False threads = False
self.breaks = False breaks = False
self.watches = False watches = False
# For frames and threads that have already been synced since last stop # For frames and threads that have already been synced since last stop
self.visited = set() visited: set[Any] = field(default_factory=set)
self.waiting = False waiting = False
def record(self, description=None, snap=None): def record(self, description: Optional[str] = None,
time: Optional[Schedule] = None) -> None:
# print("RECORDING") # print("RECORDING")
first = self.first first = self.first
self.first = False self.first = False
trace = commands.STATE.require_trace()
if description is not None: if description is not None:
commands.STATE.trace.snapshot(description, snap=snap) trace.snapshot(description, time=time)
if first: if first:
if util.is_kernel(): if util.is_kernel():
commands.create_generic("Sessions") commands.create_generic("Sessions")
@ -93,46 +95,46 @@ class ProcessState(object):
self.visited.add(hashable_frame) self.visited.add(hashable_frame)
if first or self.regions: if first or self.regions:
if util.is_exdi(): if util.is_exdi():
exdi_commands.put_regions_exdi(commands.STATE) exdi_commands.put_regions_exdi(trace)
commands.put_regions() commands.put_regions()
self.regions = False self.regions = False
if first or self.modules: if first or self.modules:
if util.is_exdi(): if util.is_exdi():
exdi_commands.put_kmodules_exdi(commands.STATE) exdi_commands.put_kmodules_exdi(trace)
commands.put_modules() commands.put_modules()
self.modules = False self.modules = False
if first or self.breaks: if first or self.breaks:
commands.put_breakpoints() commands.put_breakpoints()
self.breaks = False self.breaks = False
def record_continued(self): def record_continued(self) -> None:
commands.put_processes(running=True) commands.put_processes(running=True)
commands.put_threads(running=True) commands.put_threads(running=True)
def record_exited(self, exit_code, description=None, snap=None): def record_exited(self, exit_code: int, description: Optional[str] = None,
time: Optional[Schedule] = None) -> None:
# print("RECORD_EXITED") # print("RECORD_EXITED")
trace = commands.STATE.require_trace()
if description is not None: if description is not None:
commands.STATE.trace.snapshot(description, snap=snap) trace.snapshot(description, time=time)
proc = util.selected_process() proc = util.selected_process()
ipath = commands.PROCESS_PATTERN.format(procnum=proc) ipath = commands.PROCESS_PATTERN.format(procnum=proc)
procobj = commands.STATE.trace.proxy_object_path(ipath) procobj = trace.proxy_object_path(ipath)
procobj.set_value('Exit Code', exit_code) procobj.set_value('Exit Code', exit_code)
procobj.set_value('State', 'TERMINATED') procobj.set_value('State', 'TERMINATED')
class BrkState(object): @dataclass(frozen=False)
__slots__ = ('break_loc_counts',) class BrkState:
break_loc_counts: Dict[int, int] = field(default_factory=dict)
def __init__(self): def update_brkloc_count(self, b: DebugBreakpoint, count: int) -> None:
self.break_loc_counts = {}
def update_brkloc_count(self, b, count):
self.break_loc_counts[b.GetID()] = count self.break_loc_counts[b.GetID()] = count
def get_brkloc_count(self, b): def get_brkloc_count(self, b: DebugBreakpoint) -> int:
return self.break_loc_counts.get(b.GetID(), 0) return self.break_loc_counts.get(b.GetID(), 0)
def del_brkloc_count(self, b): def del_brkloc_count(self, b: DebugBreakpoint) -> int:
if b not in self.break_loc_counts: if b not in self.break_loc_counts:
return 0 # TODO: Print a warning? return 0 # TODO: Print a warning?
count = self.break_loc_counts[b.GetID()] count = self.break_loc_counts[b.GetID()]
@ -142,35 +144,37 @@ class BrkState(object):
HOOK_STATE = HookState() HOOK_STATE = HookState()
BRK_STATE = BrkState() BRK_STATE = BrkState()
PROC_STATE = {} PROC_STATE: Dict[int, ProcessState] = {}
def log_errors(func): C = TypeVar('C', bound=Callable)
'''
Wrap a function in a try-except that prints and reraises the
exception. def log_errors(func: C) -> C:
"""Wrap a function in a try-except that prints and reraises the exception.
This is needed because pybag and/or the COM wrappers do not print This is needed because pybag and/or the COM wrappers do not print
exceptions that occur during event callbacks. exceptions that occur during event callbacks.
''' """
@functools.wraps(func) @functools.wraps(func)
def _func(*args, **kwargs): def _func(*args, **kwargs) -> Any:
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except: except:
traceback.print_exc() traceback.print_exc()
raise raise
return _func return cast(C, _func)
@log_errors @log_errors
def on_state_changed(*args): def on_state_changed(*args) -> int:
# print("ON_STATE_CHANGED") # print(f"---ON_STATE_CHANGED:{args}---")
# print(args)
if args[0] == DbgEng.DEBUG_CES_CURRENT_THREAD: if args[0] == DbgEng.DEBUG_CES_CURRENT_THREAD:
return on_thread_selected(args) on_thread_selected(args)
return S_OK
elif args[0] == DbgEng.DEBUG_CES_BREAKPOINTS: elif args[0] == DbgEng.DEBUG_CES_BREAKPOINTS:
return on_breakpoint_modified(args) on_breakpoint_modified(args)
return S_OK
elif args[0] == DbgEng.DEBUG_CES_RADIX: elif args[0] == DbgEng.DEBUG_CES_RADIX:
util.set_convenience_variable('output-radix', args[1]) util.set_convenience_variable('output-radix', args[1])
return S_OK return S_OK
@ -185,21 +189,24 @@ def on_state_changed(*args):
if proc in PROC_STATE: if proc in PROC_STATE:
# Process may have exited (so deleted) first. # Process may have exited (so deleted) first.
PROC_STATE[proc].waiting = False PROC_STATE[proc].waiting = False
trace = commands.STATE.trace trace = commands.STATE.require_trace()
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("State changed proc {}".format(proc)): with trace.open_tx("State changed proc {}".format(proc)):
commands.put_state(proc) commands.put_state(proc)
if args[1] == DbgEng.DEBUG_STATUS_BREAK: if args[1] == DbgEng.DEBUG_STATUS_BREAK:
return on_stop(args) on_stop(args)
return S_OK
elif args[1] == DbgEng.DEBUG_STATUS_NO_DEBUGGEE: elif args[1] == DbgEng.DEBUG_STATUS_NO_DEBUGGEE:
return on_exited(proc) on_exited(proc)
return S_OK
else: else:
return on_cont(args) on_cont(args)
return S_OK
return S_OK return S_OK
@log_errors @log_errors
def on_debuggee_changed(*args): def on_debuggee_changed(*args) -> int:
# print("ON_DEBUGGEE_CHANGED: args={}".format(args)) # print("ON_DEBUGGEE_CHANGED: args={}".format(args))
# sys.stdout.flush() # sys.stdout.flush()
trace = commands.STATE.trace trace = commands.STATE.trace
@ -213,20 +220,20 @@ def on_debuggee_changed(*args):
@log_errors @log_errors
def on_session_status_changed(*args): def on_session_status_changed(*args) -> None:
# print("ON_STATUS_CHANGED: args={}".format(args)) # print("ON_STATUS_CHANGED: args={}".format(args))
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
if args[0] == DbgEng.DEBUG_SESSION_ACTIVE or args[0] == DbgEng.DEBUG_SESSION_REBOOT: if args[0] == DbgEng.DEBUG_SESSION_ACTIVE or args[0] == DbgEng.DEBUG_SESSION_REBOOT:
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("New Session {}".format(util.selected_process())): with trace.open_tx("New Session {}".format(util.selected_process())):
commands.put_processes() commands.put_processes()
return DbgEng.DEBUG_STATUS_GO return DbgEng.DEBUG_STATUS_GO
@log_errors @log_errors
def on_symbol_state_changed(*args): def on_symbol_state_changed(*args) -> None:
# print("ON_SYMBOL_STATE_CHANGED") # print("ON_SYMBOL_STATE_CHANGED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -240,31 +247,31 @@ def on_symbol_state_changed(*args):
@log_errors @log_errors
def on_system_error(*args): def on_system_error(*args) -> None:
print("ON_SYSTEM_ERROR: args={}".format(args)) print("ON_SYSTEM_ERROR: args={}".format(args))
# print(hex(args[0])) # print(hex(args[0]))
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("System Error {}".format(util.selected_process())): with trace.open_tx("System Error {}".format(util.selected_process())):
commands.put_processes() commands.put_processes()
return DbgEng.DEBUG_STATUS_BREAK return DbgEng.DEBUG_STATUS_BREAK
@log_errors @log_errors
def on_new_process(*args): def on_new_process(*args) -> None:
# print("ON_NEW_PROCESS") # print("ON_NEW_PROCESS")
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("New Process {}".format(util.selected_process())): with trace.open_tx("New Process {}".format(util.selected_process())):
commands.put_processes() commands.put_processes()
return DbgEng.DEBUG_STATUS_BREAK return DbgEng.DEBUG_STATUS_BREAK
def on_process_selected(): def on_process_selected() -> None:
# print("PROCESS_SELECTED") # print("PROCESS_SELECTED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -272,14 +279,14 @@ def on_process_selected():
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Process {} selected".format(proc)): with trace.open_tx("Process {} selected".format(proc)):
PROC_STATE[proc].record() PROC_STATE[proc].record()
commands.activate() commands.activate()
@log_errors @log_errors
def on_process_deleted(*args): def on_process_deleted(*args) -> None:
# print("ON_PROCESS_DELETED") # print("ON_PROCESS_DELETED")
exit_code = args[0] exit_code = args[0]
proc = util.selected_process() proc = util.selected_process()
@ -289,14 +296,14 @@ def on_process_deleted(*args):
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Process {} deleted".format(proc)): with trace.open_tx("Process {} deleted".format(proc)):
commands.put_processes() # TODO: Could just delete the one.... commands.put_processes() # TODO: Could just delete the one....
return DbgEng.DEBUG_STATUS_BREAK return DbgEng.DEBUG_STATUS_BREAK
@log_errors @log_errors
def on_threads_changed(*args): def on_threads_changed(*args) -> None:
# print("ON_THREADS_CHANGED") # print("ON_THREADS_CHANGED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -305,7 +312,7 @@ def on_threads_changed(*args):
return DbgEng.DEBUG_STATUS_GO return DbgEng.DEBUG_STATUS_GO
def on_thread_selected(*args): def on_thread_selected(*args) -> None:
# print("THREAD_SELECTED: args={}".format(args)) # print("THREAD_SELECTED: args={}".format(args))
# sys.stdout.flush() # sys.stdout.flush()
nthrd = args[0][1] nthrd = args[0][1]
@ -315,7 +322,7 @@ def on_thread_selected(*args):
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Thread {}.{} selected".format(nproc, nthrd)): with trace.open_tx("Thread {}.{} selected".format(nproc, nthrd)):
commands.put_state(nproc) commands.put_state(nproc)
state = PROC_STATE[nproc] state = PROC_STATE[nproc]
@ -326,7 +333,7 @@ def on_thread_selected(*args):
commands.activate() commands.activate()
def on_register_changed(regnum): def on_register_changed(regnum) -> None:
# print("REGISTER_CHANGED") # print("REGISTER_CHANGED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -334,13 +341,13 @@ def on_register_changed(regnum):
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Register {} changed".format(regnum)): with trace.open_tx("Register {} changed".format(regnum)):
commands.putreg() commands.putreg()
commands.activate() commands.activate()
def on_memory_changed(space): def on_memory_changed(space) -> None:
if space != DbgEng.DEBUG_DATA_SPACE_VIRTUAL: if space != DbgEng.DEBUG_DATA_SPACE_VIRTUAL:
return return
proc = util.selected_process() proc = util.selected_process()
@ -352,12 +359,12 @@ def on_memory_changed(space):
# Not great, but invalidate the whole space # Not great, but invalidate the whole space
# UI will only re-fetch what it needs # UI will only re-fetch what it needs
# But, some observations will not be recovered # But, some observations will not be recovered
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Memory changed"): with trace.open_tx("Memory changed"):
commands.putmem_state(0, 2**64, 'unknown') commands.putmem_state(0, 2**64, 'unknown')
def on_cont(*args): def on_cont(*args) -> None:
# print("ON CONT") # print("ON CONT")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -366,56 +373,55 @@ def on_cont(*args):
if trace is None: if trace is None:
return return
state = PROC_STATE[proc] state = PROC_STATE[proc]
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Continued"): with trace.open_tx("Continued"):
state.record_continued() state.record_continued()
return DbgEng.DEBUG_STATUS_GO return DbgEng.DEBUG_STATUS_GO
def on_stop(*args): def on_stop(*args) -> None:
# print("ON STOP")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
# print("not in state")
return return
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
# print("no trace")
return return
state = PROC_STATE[proc] state = PROC_STATE[proc]
state.visited.clear() state.visited.clear()
snap = update_position() time = update_position()
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Stopped"): with trace.open_tx("Stopped"):
state.record("Stopped", snap) description = util.compute_description(time, "Stopped")
state.record(description, time)
commands.put_event_thread() commands.put_event_thread()
commands.activate() commands.activate()
def update_position(): def update_position() -> Optional[Schedule]:
"""Update the position""" """Update the position."""
cursor = util.get_cursor() posobj = util.get_object("State.DebuggerVariables.curthread.TTD.Position")
if cursor is None: if posobj is None:
return None return None
pos = cursor.get_position() pos = util.pos2split(posobj)
lpos = util.get_last_position() lpos = util.get_last_position()
rng = range(pos.major, lpos.major) if lpos is None:
if pos.major > lpos.major: return util.split2schedule(pos)
rng = range(lpos.major, pos.major)
for i in rng: minpos, maxpos = (lpos, pos) if lpos < pos else (pos, lpos)
type = util.get_event_type(i) evts = list(util.ttd.evttypes.keys())
if type == "modload" or type == "modunload": minidx = bisect_left(evts, minpos)
maxidx = bisect_right(evts, maxpos)
types = set(util.ttd.evttypes[p] for p in evts[minidx:maxidx])
if "modload" in types or "modunload" in types:
on_modules_changed() on_modules_changed()
break if "threadcreated" in types or "threadterm" in types:
for i in rng:
type = util.get_event_type(i)
if type == "threadcreated" or type == "threadterm":
on_threads_changed() on_threads_changed()
util.set_last_position(pos) util.set_last_position(pos)
return util.pos2snap(pos) return util.split2schedule(pos)
def on_exited(proc): def on_exited(proc) -> None:
# print("ON EXITED") # print("ON EXITED")
if proc not in PROC_STATE: if proc not in PROC_STATE:
# print("not in state") # print("not in state")
@ -427,14 +433,14 @@ def on_exited(proc):
state.visited.clear() state.visited.clear()
exit_code = util.GetExitCode() exit_code = util.GetExitCode()
description = "Exited with code {}".format(exit_code) description = "Exited with code {}".format(exit_code)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx(description): with trace.open_tx(description):
state.record_exited(exit_code, description) state.record_exited(exit_code, description)
commands.activate() commands.activate()
@log_errors @log_errors
def on_modules_changed(*args): def on_modules_changed(*args) -> None:
# print("ON_MODULES_CHANGED") # print("ON_MODULES_CHANGED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -443,7 +449,7 @@ def on_modules_changed(*args):
return DbgEng.DEBUG_STATUS_GO return DbgEng.DEBUG_STATUS_GO
def on_breakpoint_created(bp): def on_breakpoint_created(bp) -> None:
# print("ON_BREAKPOINT_CREATED") # print("ON_BREAKPOINT_CREATED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -453,15 +459,14 @@ def on_breakpoint_created(bp):
if trace is None: if trace is None:
return return
ibpath = commands.PROC_BREAKS_PATTERN.format(procnum=proc) ibpath = commands.PROC_BREAKS_PATTERN.format(procnum=proc)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} created".format(bp.GetId())): with trace.open_tx("Breakpoint {} created".format(bp.GetId())):
ibobj = trace.create_object(ibpath) ibobj = trace.create_object(ibpath)
# Do not use retain_values or it'll remove other locs
commands.put_single_breakpoint(bp, ibobj, proc, []) commands.put_single_breakpoint(bp, ibobj, proc, [])
ibobj.insert() ibobj.insert()
def on_breakpoint_modified(*args): def on_breakpoint_modified(*args) -> None:
# print("BREAKPOINT_MODIFIED") # print("BREAKPOINT_MODIFIED")
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
@ -481,7 +486,7 @@ def on_breakpoint_modified(*args):
return on_breakpoint_created(bp) return on_breakpoint_created(bp)
def on_breakpoint_deleted(bpid): def on_breakpoint_deleted(bpid) -> None:
proc = util.selected_process() proc = util.selected_process()
if proc not in PROC_STATE: if proc not in PROC_STATE:
return return
@ -490,25 +495,25 @@ def on_breakpoint_deleted(bpid):
if trace is None: if trace is None:
return return
bpath = commands.PROC_BREAK_PATTERN.format(procnum=proc, breaknum=bpid) bpath = commands.PROC_BREAK_PATTERN.format(procnum=proc, breaknum=bpid)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} deleted".format(bpid)): with trace.open_tx("Breakpoint {} deleted".format(bpid)):
trace.proxy_object_path(bpath).remove(tree=True) trace.proxy_object_path(bpath).remove(tree=True)
@log_errors @log_errors
def on_breakpoint_hit(*args): def on_breakpoint_hit(*args) -> None:
# print("ON_BREAKPOINT_HIT: args={}".format(args)) # print("ON_BREAKPOINT_HIT: args={}".format(args))
return DbgEng.DEBUG_STATUS_BREAK return DbgEng.DEBUG_STATUS_BREAK
@log_errors @log_errors
def on_exception(*args): def on_exception(*args) -> None:
# print("ON_EXCEPTION: args={}".format(args)) # print("ON_EXCEPTION: args={}".format(args))
return DbgEng.DEBUG_STATUS_BREAK return DbgEng.DEBUG_STATUS_BREAK
@util.dbg.eng_thread @util.dbg.eng_thread
def install_hooks(): def install_hooks() -> None:
# print("Installing hooks") # print("Installing hooks")
if HOOK_STATE.installed: if HOOK_STATE.installed:
return return
@ -551,7 +556,7 @@ def install_hooks():
@util.dbg.eng_thread @util.dbg.eng_thread
def remove_hooks(): def remove_hooks() -> None:
# print("Removing hooks") # print("Removing hooks")
if not HOOK_STATE.installed: if not HOOK_STATE.installed:
return return
@ -559,14 +564,14 @@ def remove_hooks():
util.dbg._base._reset_callbacks() util.dbg._base._reset_callbacks()
def enable_current_process(): def enable_current_process() -> None:
# print("Enable current process") # print("Enable current process")
proc = util.selected_process() proc = util.selected_process()
# print("proc: {}".format(proc)) # print("proc: {}".format(proc))
PROC_STATE[proc] = ProcessState() PROC_STATE[proc] = ProcessState()
def disable_current_process(): def disable_current_process() -> None:
proc = util.selected_process() proc = util.selected_process()
if proc in PROC_STATE: if proc in PROC_STATE:
# Silently ignore already disabled # Silently ignore already disabled
@ -574,56 +579,55 @@ def disable_current_process():
@log_errors @log_errors
def on_state_changed_async(*args): def on_state_changed_async(*args) -> None:
util.dbg.run_async(on_state_changed, *args) util.dbg.run_async(on_state_changed, *args)
@log_errors @log_errors
def on_debuggee_changed_async(*args): def on_debuggee_changed_async(*args) -> None:
util.dbg.run_async(on_debuggee_changed, *args) util.dbg.run_async(on_debuggee_changed, *args)
@log_errors @log_errors
def on_session_status_changed_async(*args): def on_session_status_changed_async(*args) -> None:
util.dbg.run_async(on_session_status_changed, *args) util.dbg.run_async(on_session_status_changed, *args)
@log_errors @log_errors
def on_symbol_state_changed_async(*args): def on_symbol_state_changed_async(*args) -> None:
util.dbg.run_async(on_symbol_state_changed, *args) util.dbg.run_async(on_symbol_state_changed, *args)
@log_errors @log_errors
def on_system_error_async(*args): def on_system_error_async(*args) -> None:
util.dbg.run_async(on_system_error, *args) util.dbg.run_async(on_system_error, *args)
@log_errors @log_errors
def on_new_process_async(*args): def on_new_process_async(*args) -> None:
util.dbg.run_async(on_new_process, *args) util.dbg.run_async(on_new_process, *args)
@log_errors @log_errors
def on_process_deleted_async(*args): def on_process_deleted_async(*args) -> None:
util.dbg.run_async(on_process_deleted, *args) util.dbg.run_async(on_process_deleted, *args)
@log_errors @log_errors
def on_threads_changed_async(*args): def on_threads_changed_async(*args) -> None:
util.dbg.run_async(on_threads_changed, *args) util.dbg.run_async(on_threads_changed, *args)
@log_errors @log_errors
def on_modules_changed_async(*args): def on_modules_changed_async(*args) -> None:
util.dbg.run_async(on_modules_changed, *args) util.dbg.run_async(on_modules_changed, *args)
@log_errors @log_errors
def on_breakpoint_hit_async(*args): def on_breakpoint_hit_async(*args) -> None:
util.dbg.run_async(on_breakpoint_hit, *args) util.dbg.run_async(on_breakpoint_hit, *args)
@log_errors @log_errors
def on_exception_async(*args): def on_exception_async(*args) -> None:
util.dbg.run_async(on_exception, *args) util.dbg.run_async(on_exception, *args)

View file

@ -18,44 +18,48 @@ from contextlib import redirect_stdout
from io import StringIO from io import StringIO
import re import re
import sys import sys
from typing import Annotated, Any, Dict, Optional
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import MethodRegistry, ParamDesc, Address, AddressRange from ghidratrace.client import (MethodRegistry, ParamDesc, Address,
AddressRange, Schedule, TraceObject)
from pybag import pydbg from pybag import pydbg
from pybag.dbgeng import core as DbgEng, exception from pybag.dbgeng import core as DbgEng, exception
from . import util, commands from . import util, commands
REGISTRY = MethodRegistry(ThreadPoolExecutor( REGISTRY = MethodRegistry(ThreadPoolExecutor(
max_workers=1, thread_name_prefix='MethodRegistry')) max_workers=1, thread_name_prefix='MethodRegistry'))
def extre(base, ext): def extre(base: re.Pattern, ext: str) -> re.Pattern:
return re.compile(base.pattern + ext) return re.compile(base.pattern + ext)
AVAILABLE_PATTERN = re.compile('Available\[(?P<pid>\\d*)\]') WATCHPOINT_PATTERN = re.compile('Watchpoints\\[(?P<watchnum>\\d*)\\]')
WATCHPOINT_PATTERN = re.compile('Watchpoints\[(?P<watchnum>\\d*)\]') BREAKPOINT_PATTERN = re.compile('Breakpoints\\[(?P<breaknum>\\d*)\\]')
BREAKPOINT_PATTERN = re.compile('Breakpoints\[(?P<breaknum>\\d*)\]') BREAK_LOC_PATTERN = extre(BREAKPOINT_PATTERN, '\\[(?P<locnum>\\d*)\\]')
BREAK_LOC_PATTERN = extre(BREAKPOINT_PATTERN, '\[(?P<locnum>\\d*)\]')
SESSIONS_PATTERN = re.compile('Sessions') SESSIONS_PATTERN = re.compile('Sessions')
SESSION_PATTERN = extre(SESSIONS_PATTERN, '\[(?P<snum>\\d*)\]') SESSION_PATTERN = extre(SESSIONS_PATTERN, '\\[(?P<snum>\\d*)\\]')
PROCESSES_PATTERN = extre(SESSION_PATTERN, '\.Processes') AVAILABLE_PATTERN = extre(SESSION_PATTERN, '\\.Available\\[(?P<pid>\\d*)\\]')
PROCESS_PATTERN = extre(PROCESSES_PATTERN, '\[(?P<procnum>\\d*)\]') PROCESSES_PATTERN = extre(SESSION_PATTERN, '\\.Processes')
PROC_BREAKS_PATTERN = extre(PROCESS_PATTERN, '\.Debug.Breakpoints') PROCESS_PATTERN = extre(PROCESSES_PATTERN, '\\[(?P<procnum>\\d*)\\]')
PROC_BREAKBPT_PATTERN = extre(PROC_BREAKS_PATTERN, '\[(?P<breaknum>\\d*)\]') PROC_BREAKS_PATTERN = extre(PROCESS_PATTERN, '\\.Debug.Breakpoints')
ENV_PATTERN = extre(PROCESS_PATTERN, '\.Environment') PROC_BREAKBPT_PATTERN = extre(PROC_BREAKS_PATTERN, '\\[(?P<breaknum>\\d*)\\]')
THREADS_PATTERN = extre(PROCESS_PATTERN, '\.Threads') ENV_PATTERN = extre(PROCESS_PATTERN, '\\.Environment')
THREAD_PATTERN = extre(THREADS_PATTERN, '\[(?P<tnum>\\d*)\]') THREADS_PATTERN = extre(PROCESS_PATTERN, '\\.Threads')
STACK_PATTERN = extre(THREAD_PATTERN, '\.Stack.Frames') THREAD_PATTERN = extre(THREADS_PATTERN, '\\[(?P<tnum>\\d*)\\]')
FRAME_PATTERN = extre(STACK_PATTERN, '\[(?P<level>\\d*)\]') STACK_PATTERN = extre(THREAD_PATTERN, '\\.Stack.Frames')
REGS_PATTERN0 = extre(THREAD_PATTERN, '.Registers') FRAME_PATTERN = extre(STACK_PATTERN, '\\[(?P<level>\\d*)\\]')
REGS_PATTERN = extre(FRAME_PATTERN, '.Registers') REGS_PATTERN0 = extre(THREAD_PATTERN, '\\.Registers')
MEMORY_PATTERN = extre(PROCESS_PATTERN, '\.Memory') REGS_PATTERN = extre(FRAME_PATTERN, '\\.Registers')
MODULES_PATTERN = extre(PROCESS_PATTERN, '\.Modules') MEMORY_PATTERN = extre(PROCESS_PATTERN, '\\.Memory')
MODULES_PATTERN = extre(PROCESS_PATTERN, '\\.Modules')
def find_availpid_by_pattern(pattern, object, err_msg): def find_availpid_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> int:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -63,17 +67,18 @@ def find_availpid_by_pattern(pattern, object, err_msg):
return pid return pid
def find_availpid_by_obj(object): def find_availpid_by_obj(object: TraceObject) -> int:
return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Available") return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Attachable")
def find_proc_by_num(id): def find_proc_by_num(id: int) -> int:
if id != util.selected_process(): if id != util.selected_process():
util.select_process(id) util.select_process(id)
return util.selected_process() return util.selected_process()
def find_proc_by_pattern(object, pattern, err_msg): def find_proc_by_pattern(object: TraceObject, pattern: re.Pattern,
err_msg: str) -> int:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -81,43 +86,39 @@ def find_proc_by_pattern(object, pattern, err_msg):
return find_proc_by_num(procnum) return find_proc_by_num(procnum)
def find_proc_by_obj(object): def find_proc_by_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, PROCESS_PATTERN, "an Process") return find_proc_by_pattern(object, PROCESS_PATTERN, "an Process")
def find_proc_by_procbreak_obj(object): def find_proc_by_procbreak_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, PROC_BREAKS_PATTERN, return find_proc_by_pattern(object, PROC_BREAKS_PATTERN,
"a BreakpointLocationContainer") "a BreakpointLocationContainer")
def find_proc_by_procwatch_obj(object): def find_proc_by_env_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, PROC_WATCHES_PATTERN,
"a WatchpointContainer")
def find_proc_by_env_obj(object):
return find_proc_by_pattern(object, ENV_PATTERN, "an Environment") return find_proc_by_pattern(object, ENV_PATTERN, "an Environment")
def find_proc_by_threads_obj(object): def find_proc_by_threads_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, THREADS_PATTERN, "a ThreadContainer") return find_proc_by_pattern(object, THREADS_PATTERN, "a ThreadContainer")
def find_proc_by_mem_obj(object): def find_proc_by_mem_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, MEMORY_PATTERN, "a Memory") return find_proc_by_pattern(object, MEMORY_PATTERN, "a Memory")
def find_proc_by_modules_obj(object): def find_proc_by_modules_obj(object: TraceObject) -> int:
return find_proc_by_pattern(object, MODULES_PATTERN, "a ModuleContainer") return find_proc_by_pattern(object, MODULES_PATTERN, "a ModuleContainer")
def find_thread_by_num(id): def find_thread_by_num(id: int) -> Optional[int]:
if id != util.selected_thread(): if id != util.selected_thread():
util.select_thread(id) util.select_thread(id)
return util.selected_thread() return util.selected_thread()
def find_thread_by_pattern(pattern, object, err_msg): def find_thread_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> Optional[int]:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -127,27 +128,29 @@ def find_thread_by_pattern(pattern, object, err_msg):
return find_thread_by_num(tnum) return find_thread_by_num(tnum)
def find_thread_by_obj(object): def find_thread_by_obj(object: TraceObject) -> Optional[int]:
return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread") return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread")
def find_thread_by_stack_obj(object): def find_thread_by_stack_obj(object: TraceObject) -> Optional[int]:
return find_thread_by_pattern(STACK_PATTERN, object, "a Stack") return find_thread_by_pattern(STACK_PATTERN, object, "a Stack")
def find_thread_by_regs_obj(object): def find_thread_by_regs_obj(object: TraceObject) -> Optional[int]:
return find_thread_by_pattern(REGS_PATTERN0, object, "a RegisterValueContainer") return find_thread_by_pattern(REGS_PATTERN0, object,
"a RegisterValueContainer")
@util.dbg.eng_thread @util.dbg.eng_thread
def find_frame_by_level(level): def find_frame_by_level(level: int) -> DbgEng._DEBUG_STACK_FRAME:
for f in util.dbg._base.backtrace_list(): for f in util.dbg._base.backtrace_list():
if f.FrameNumber == level: if f.FrameNumber == level:
return f return f
# return dbg().backtrace_list()[level] # return dbg().backtrace_list()[level]
def find_frame_by_pattern(pattern, object, err_msg): def find_frame_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> DbgEng._DEBUG_STACK_FRAME:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -159,11 +162,11 @@ def find_frame_by_pattern(pattern, object, err_msg):
return find_frame_by_level(level) return find_frame_by_level(level)
def find_frame_by_obj(object): def find_frame_by_obj(object: TraceObject) -> DbgEng._DEBUG_STACK_FRAME:
return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame") return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame")
def find_bpt_by_number(breaknum): def find_bpt_by_number(breaknum: int) -> DbgEng.IDebugBreakpoint:
try: try:
bp = dbg()._control.GetBreakpointById(breaknum) bp = dbg()._control.GetBreakpointById(breaknum)
return bp return bp
@ -171,7 +174,8 @@ def find_bpt_by_number(breaknum):
raise KeyError(f"Breakpoints[{breaknum}] does not exist") raise KeyError(f"Breakpoints[{breaknum}] does not exist")
def find_bpt_by_pattern(pattern, object, err_msg): def find_bpt_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> DbgEng.IDebugBreakpoint:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -179,14 +183,78 @@ def find_bpt_by_pattern(pattern, object, err_msg):
return find_bpt_by_number(breaknum) return find_bpt_by_number(breaknum)
def find_bpt_by_obj(object): def find_bpt_by_obj(object: TraceObject) -> DbgEng.IDebugBreakpoint:
return find_bpt_by_pattern(PROC_BREAKBPT_PATTERN, object, "a BreakpointSpec") return find_bpt_by_pattern(PROC_BREAKBPT_PATTERN, object, "a BreakpointSpec")
shared_globals = dict() shared_globals: Dict[str, Any] = dict()
@REGISTRY.method class Session(TraceObject):
pass
class AvailableContainer(TraceObject):
pass
class BreakpointContainer(TraceObject):
pass
class ProcessContainer(TraceObject):
pass
class Environment(TraceObject):
pass
class ThreadContainer(TraceObject):
pass
class Stack(TraceObject):
pass
class RegisterValueContainer(TraceObject):
pass
class Memory(TraceObject):
pass
class ModuleContainer(TraceObject):
pass
class State(TraceObject):
pass
class Process(TraceObject):
pass
class Thread(TraceObject):
pass
class StackFrame(TraceObject):
pass
class Attachable(TraceObject):
pass
class BreakpointSpec(TraceObject):
pass
@REGISTRY.method()
# @util.dbg.eng_thread # @util.dbg.eng_thread
def execute(cmd: str, to_string: bool = False): def execute(cmd: str, to_string: bool = False):
"""Execute a Python3 command or script.""" """Execute a Python3 command or script."""
@ -205,59 +273,58 @@ def execute(cmd: str, to_string: bool=False):
@REGISTRY.method(action='evaluate', display='Evaluate') @REGISTRY.method(action='evaluate', display='Evaluate')
# @util.dbg.eng_thread # @util.dbg.eng_thread
def evaluate( def evaluate(
session: sch.Schema('Session'), session: Session,
expr: ParamDesc(str, display='Expr')): expr: Annotated[str, ParamDesc(display='Expr')]) -> str:
"""Evaluate a Python3 expression.""" """Evaluate a Python3 expression."""
return str(eval(expr, shared_globals)) return str(eval(expr, shared_globals))
@REGISTRY.method(action='refresh', display="Refresh", condition=util.dbg.use_generics) @REGISTRY.method(action='refresh', display="Refresh",
def refresh_generic(node: sch.OBJECT): condition=util.dbg.use_generics)
def refresh_generic(node: TraceObject) -> None:
"""List the children for a generic node.""" """List the children for a generic node."""
with commands.open_tracked_tx('Refresh Generic'): with commands.open_tracked_tx('Refresh Generic'):
commands.ghidra_trace_put_generic(node) commands.ghidra_trace_put_generic(node)
@REGISTRY.method(action='refresh', display='Refresh Available') @REGISTRY.method(action='refresh', display='Refresh Available')
def refresh_available(node: sch.Schema('AvailableContainer')): def refresh_available(node: AvailableContainer) -> None:
"""List processes on pydbg's host system.""" """List processes on pydbg's host system."""
with commands.open_tracked_tx('Refresh Available'): with commands.open_tracked_tx('Refresh Available'):
commands.ghidra_trace_put_available() commands.ghidra_trace_put_available()
@REGISTRY.method(action='refresh', display='Refresh Breakpoints') @REGISTRY.method(action='refresh', display='Refresh Breakpoints')
def refresh_breakpoints(node: sch.Schema('BreakpointContainer')): def refresh_breakpoints(node: BreakpointContainer) -> None:
""" """Refresh the list of breakpoints (including locations for the current
Refresh the list of breakpoints (including locations for the current process)."""
process).
"""
with commands.open_tracked_tx('Refresh Breakpoints'): with commands.open_tracked_tx('Refresh Breakpoints'):
commands.ghidra_trace_put_breakpoints() commands.ghidra_trace_put_breakpoints()
@REGISTRY.method(action='refresh', display='Refresh Processes') @REGISTRY.method(action='refresh', display='Refresh Processes')
def refresh_processes(node: sch.Schema('ProcessContainer')): def refresh_processes(node: ProcessContainer) -> None:
"""Refresh the list of processes.""" """Refresh the list of processes."""
with commands.open_tracked_tx('Refresh Processes'): with commands.open_tracked_tx('Refresh Processes'):
commands.ghidra_trace_put_processes() commands.ghidra_trace_put_processes()
@REGISTRY.method(action='refresh', display='Refresh Environment') @REGISTRY.method(action='refresh', display='Refresh Environment')
def refresh_environment(node: sch.Schema('Environment')): def refresh_environment(node: Environment) -> None:
"""Refresh the environment descriptors (arch, os, endian).""" """Refresh the environment descriptors (arch, os, endian)."""
with commands.open_tracked_tx('Refresh Environment'): with commands.open_tracked_tx('Refresh Environment'):
commands.ghidra_trace_put_environment() commands.ghidra_trace_put_environment()
@REGISTRY.method(action='refresh', display='Refresh Threads') @REGISTRY.method(action='refresh', display='Refresh Threads')
def refresh_threads(node: sch.Schema('ThreadContainer')): def refresh_threads(node: ThreadContainer) -> None:
"""Refresh the list of threads in the process.""" """Refresh the list of threads in the process."""
with commands.open_tracked_tx('Refresh Threads'): with commands.open_tracked_tx('Refresh Threads'):
commands.ghidra_trace_put_threads() commands.ghidra_trace_put_threads()
@REGISTRY.method(action='refresh', display='Refresh Stack') @REGISTRY.method(action='refresh', display='Refresh Stack')
def refresh_stack(node: sch.Schema('Stack')): def refresh_stack(node: Stack) -> None:
"""Refresh the backtrace for the thread.""" """Refresh the backtrace for the thread."""
tnum = find_thread_by_stack_obj(node) tnum = find_thread_by_stack_obj(node)
util.reset_frames() util.reset_frames()
@ -268,55 +335,67 @@ def refresh_stack(node: sch.Schema('Stack')):
@REGISTRY.method(action='refresh', display='Refresh Registers') @REGISTRY.method(action='refresh', display='Refresh Registers')
def refresh_registers(node: sch.Schema('RegisterValueContainer')): def refresh_registers(node: RegisterValueContainer) -> None:
"""Refresh the register values for the selected frame""" """Refresh the register values for the selected frame."""
tnum = find_thread_by_regs_obj(node) tnum = find_thread_by_regs_obj(node)
with commands.open_tracked_tx('Refresh Registers'): with commands.open_tracked_tx('Refresh Registers'):
commands.ghidra_trace_putreg() commands.ghidra_trace_putreg()
@REGISTRY.method(action='refresh', display='Refresh Memory') @REGISTRY.method(action='refresh', display='Refresh Memory')
def refresh_mappings(node: sch.Schema('Memory')): def refresh_mappings(node: Memory) -> None:
"""Refresh the list of memory regions for the process.""" """Refresh the list of memory regions for the process."""
with commands.open_tracked_tx('Refresh Memory Regions'): with commands.open_tracked_tx('Refresh Memory Regions'):
commands.ghidra_trace_put_regions() commands.ghidra_trace_put_regions()
@REGISTRY.method(action='refresh', display='Refresh Modules') @REGISTRY.method(action='refresh', display='Refresh Modules')
def refresh_modules(node: sch.Schema('ModuleContainer')): def refresh_modules(node: ModuleContainer) -> None:
""" """Refresh the modules and sections list for the process.
Refresh the modules and sections list for the process.
This will refresh the sections for all modules, not just the selected one. This will refresh the sections for all modules, not just the
selected one.
""" """
with commands.open_tracked_tx('Refresh Modules'): with commands.open_tracked_tx('Refresh Modules'):
commands.ghidra_trace_put_modules() commands.ghidra_trace_put_modules()
@REGISTRY.method(action='refresh', display='Refresh Events') @REGISTRY.method(action='refresh', display='Refresh Events')
def refresh_events(node: sch.Schema('State')): def refresh_events(node: State) -> None:
""" """Refresh the events list for a trace."""
Refresh the events list for a trace.
"""
with commands.open_tracked_tx('Refresh Events'): with commands.open_tracked_tx('Refresh Events'):
commands.ghidra_trace_put_events(node) commands.ghidra_trace_put_events()
@util.dbg.eng_thread
def do_maybe_activate_time(time: Optional[str]) -> None:
if time is not None:
sch: Schedule = Schedule.parse(time)
dbg().cmd(f"!tt " + util.schedule2ss(sch), quiet=False)
dbg().wait()
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_process(process: sch.Schema('Process')): def activate_process(process: Process,
time: Optional[str] = None) -> None:
"""Switch to the process.""" """Switch to the process."""
do_maybe_activate_time(time)
find_proc_by_obj(process) find_proc_by_obj(process)
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_thread(thread: sch.Schema('Thread')): def activate_thread(thread: Thread,
time: Optional[str] = None) -> None:
"""Switch to the thread.""" """Switch to the thread."""
do_maybe_activate_time(time)
find_thread_by_obj(thread) find_thread_by_obj(thread)
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_frame(frame: sch.Schema('StackFrame')): def activate_frame(frame: StackFrame,
time: Optional[str] = None) -> None:
"""Select the frame.""" """Select the frame."""
do_maybe_activate_time(time)
f = find_frame_by_obj(frame) f = find_frame_by_obj(frame)
util.select_frame(f.FrameNumber) util.select_frame(f.FrameNumber)
with commands.open_tracked_tx('Refresh Stack'): with commands.open_tracked_tx('Refresh Stack'):
@ -327,7 +406,7 @@ def activate_frame(frame: sch.Schema('StackFrame')):
@REGISTRY.method(action='delete') @REGISTRY.method(action='delete')
@util.dbg.eng_thread @util.dbg.eng_thread
def remove_process(process: sch.Schema('Process')): def remove_process(process: Process) -> None:
"""Remove the process.""" """Remove the process."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().detach_proc() dbg().detach_proc()
@ -336,15 +415,15 @@ def remove_process(process: sch.Schema('Process')):
@REGISTRY.method(action='connect', display='Connect') @REGISTRY.method(action='connect', display='Connect')
@util.dbg.eng_thread @util.dbg.eng_thread
def target( def target(
session: sch.Schema('Session'), session: Session,
cmd: ParamDesc(str, display='Command')): cmd: Annotated[str, ParamDesc(display='Command')]) -> None:
"""Connect to a target machine or process.""" """Connect to a target machine or process."""
dbg().attach_kernel(cmd) dbg().attach_kernel(cmd)
@REGISTRY.method(action='attach', display='Attach') @REGISTRY.method(action='attach', display='Attach')
@util.dbg.eng_thread @util.dbg.eng_thread
def attach_obj(target: sch.Schema('Attachable')): def attach_obj(target: Attachable) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
pid = find_availpid_by_obj(target) pid = find_availpid_by_obj(target)
dbg().attach_proc(pid) dbg().attach_proc(pid)
@ -353,82 +432,90 @@ def attach_obj(target: sch.Schema('Attachable')):
@REGISTRY.method(action='attach', display='Attach by pid') @REGISTRY.method(action='attach', display='Attach by pid')
@util.dbg.eng_thread @util.dbg.eng_thread
def attach_pid( def attach_pid(
session: sch.Schema('Session'), session: Session,
pid: ParamDesc(str, display='PID')): pid: Annotated[int, ParamDesc(display='PID')]) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
dbg().attach_proc(int(pid)) dbg().attach_proc(pid)
@REGISTRY.method(action='attach', display='Attach by name') @REGISTRY.method(action='attach', display='Attach by name')
@util.dbg.eng_thread @util.dbg.eng_thread
def attach_name( def attach_name(
session: sch.Schema('Session'), session: Session,
name: ParamDesc(str, display='Name')): name: Annotated[str, ParamDesc(display='Name')]) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
dbg().attach_proc(name) dbg().attach_proc(name)
@REGISTRY.method(action='detach', display='Detach') @REGISTRY.method(action='detach', display='Detach')
@util.dbg.eng_thread @util.dbg.eng_thread
def detach(process: sch.Schema('Process')): def detach(process: Process) -> None:
"""Detach the process's target.""" """Detach the process's target."""
dbg().detach_proc() dbg().detach_proc()
@REGISTRY.method(action='launch', display='Launch') @REGISTRY.method(action='launch', display='Launch')
def launch_loader( def launch_loader(
session: sch.Schema('Session'), session: Session,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '',
""" timeout: Annotated[int, ParamDesc(display='Timeout')] = -1,
Start a native process with the given command line, stopping at the ntdll initial breakpoint. wait: Annotated[bool, ParamDesc(
""" display='Wait',
description='Perform the initial WaitForEvents')] = False) -> None:
"""Start a native process with the given command line, stopping at the
ntdll initial breakpoint."""
command = file command = file
if args != None: if args != None:
command += " " + args command += " " + args
commands.ghidra_trace_create(command=file, start_trace=False) commands.ghidra_trace_create(command=command, start_trace=False,
timeout=timeout, wait=wait)
@REGISTRY.method(action='launch', display='LaunchEx') @REGISTRY.method(action='launch', display='LaunchEx')
def launch( def launch(
session: sch.Schema('Session'), session: Session,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')='', args: Annotated[str, ParamDesc(display='Arguments')] = '',
initial_break: ParamDesc(bool, display='Initial Break')=True, initial_break: Annotated[bool, ParamDesc(
timeout: ParamDesc(int, display='Timeout')=-1): display='Initial Break')] = True,
""" timeout: Annotated[int, ParamDesc(display='Timeout')] = -1,
Run a native process with the given command line. wait: Annotated[bool, ParamDesc(
""" display='Wait',
description='Perform the initial WaitForEvents')] = False) -> None:
"""Run a native process with the given command line."""
command = file command = file
if args != None: if args != None:
command += " " + args command += " " + args
commands.ghidra_trace_create( commands.ghidra_trace_create(command=command, start_trace=False,
command, initial_break=initial_break, timeout=timeout, start_trace=False) initial_break=initial_break,
timeout=timeout, wait=wait)
@REGISTRY.method @REGISTRY.method()
@util.dbg.eng_thread @util.dbg.eng_thread
def kill(process: sch.Schema('Process')): def kill(process: Process) -> None:
"""Kill execution of the process.""" """Kill execution of the process."""
commands.ghidra_trace_kill() commands.ghidra_trace_kill()
@REGISTRY.method(action='resume', display="Go") @REGISTRY.method(action='resume', display="Go")
def go(process: sch.Schema('Process')): def go(process: Process) -> None:
"""Continue execution of the process.""" """Continue execution of the process."""
util.dbg.run_async(lambda: dbg().go()) util.dbg.run_async(lambda: dbg().go())
@REGISTRY.method(action='step_ext', display='Go (backwards)', icon='icon.debugger.resume.back', condition=util.dbg.IS_TRACE) @REGISTRY.method(action='step_ext', display='Go (backwards)',
icon='icon.debugger.resume.back', condition=util.dbg.IS_TRACE)
@util.dbg.eng_thread @util.dbg.eng_thread
def go_back(thread: sch.Schema('Process')): def go_back(process: Process) -> None:
"""Continue execution of the process backwards.""" """Continue execution of the process backwards."""
dbg().cmd("g-") dbg().cmd("g-")
dbg().wait() dbg().wait()
@REGISTRY.method @REGISTRY.method()
def interrupt(process: sch.Schema('Process')): def interrupt(process: Process) -> None:
"""Interrupt the execution of the debugged program.""" """Interrupt the execution of the debugged program."""
# SetInterrupt is reentrant, so bypass the thread checks # SetInterrupt is reentrant, so bypass the thread checks
util.dbg._protected_base._control.SetInterrupt( util.dbg._protected_base._control.SetInterrupt(
@ -436,53 +523,64 @@ def interrupt(process: sch.Schema('Process')):
@REGISTRY.method(action='step_into') @REGISTRY.method(action='step_into')
def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction exactly.""" """Step one instruction exactly."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
util.dbg.run_async(lambda: dbg().stepi(n)) util.dbg.run_async(lambda: dbg().stepi(n))
@REGISTRY.method(action='step_over') @REGISTRY.method(action='step_over')
def step_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_over(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction, but proceed through subroutine calls.""" """Step one instruction, but proceed through subroutine calls."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
util.dbg.run_async(lambda: dbg().stepo(n)) util.dbg.run_async(lambda: dbg().stepo(n))
@REGISTRY.method(action='step_ext', display='Step Into (backwards)', icon='icon.debugger.step.back.into', condition=util.dbg.IS_TRACE) @REGISTRY.method(action='step_ext', display='Step Into (backwards)',
icon='icon.debugger.step.back.into',
condition=util.dbg.IS_TRACE)
@util.dbg.eng_thread @util.dbg.eng_thread
def step_back_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_back_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction backward exactly.""" """Step one instruction backward exactly."""
dbg().cmd("t- " + str(n)) dbg().cmd("t- " + str(n))
dbg().wait() dbg().wait()
@REGISTRY.method(action='step_ext', display='Step Over (backwards)', icon='icon.debugger.step.back.over', condition=util.dbg.IS_TRACE) @REGISTRY.method(action='step_ext', display='Step Over (backwards)',
icon='icon.debugger.step.back.over',
condition=util.dbg.IS_TRACE)
@util.dbg.eng_thread @util.dbg.eng_thread
def step_back_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_back_over(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction backward, but proceed through subroutine calls.""" """Step one instruction backward, but proceed through subroutine calls."""
dbg().cmd("p- " + str(n)) dbg().cmd("p- " + str(n))
dbg().wait() dbg().wait()
@REGISTRY.method(action='step_out') @REGISTRY.method(action='step_out')
def step_out(thread: sch.Schema('Thread')): def step_out(thread: Thread) -> None:
"""Execute until the current stack frame returns.""" """Execute until the current stack frame returns."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
util.dbg.run_async(lambda: dbg().stepout()) util.dbg.run_async(lambda: dbg().stepout())
@REGISTRY.method(action='step_to', display='Step To') @REGISTRY.method(action='step_to', display='Step To')
def step_to(thread: sch.Schema('Thread'), address: Address, max=None): def step_to(thread: Thread, address: Address,
max: Optional[int] = None) -> None:
"""Continue execution up to the given address.""" """Continue execution up to the given address."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
# TODO: The address may need mapping. # TODO: The address may need mapping.
util.dbg.run_async(lambda: dbg().stepto(address.offset, max)) util.dbg.run_async(lambda: dbg().stepto(address.offset, max))
@REGISTRY.method(action='go_to_time', display='Go To (event)', condition=util.dbg.IS_TRACE) @REGISTRY.method(action='go_to_time', display='Go To (event)',
condition=util.dbg.IS_TRACE)
@util.dbg.eng_thread @util.dbg.eng_thread
def go_to_time(node: sch.Schema('State'), evt: ParamDesc(str, display='Event')): def go_to_time(node: State,
evt: Annotated[str, ParamDesc(display='Event')]) -> None:
"""Reset the trace to a specific time.""" """Reset the trace to a specific time."""
dbg().cmd("!tt " + evt) dbg().cmd("!tt " + evt)
dbg().wait() dbg().wait()
@ -490,7 +588,7 @@ def go_to_time(node: sch.Schema('State'), evt: ParamDesc(str, display='Event')):
@REGISTRY.method(action='break_sw_execute') @REGISTRY.method(action='break_sw_execute')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_address(process: sch.Schema('Process'), address: Address): def break_address(process: Process, address: Address) -> None:
"""Set a breakpoint.""" """Set a breakpoint."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().bp(expr=address.offset) dbg().bp(expr=address.offset)
@ -498,7 +596,7 @@ def break_address(process: sch.Schema('Process'), address: Address):
@REGISTRY.method(action='break_ext', display='Set Breakpoint') @REGISTRY.method(action='break_ext', display='Set Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_expression(expression: str): def break_expression(expression: str) -> None:
"""Set a breakpoint.""" """Set a breakpoint."""
# TODO: Escape? # TODO: Escape?
dbg().bp(expr=expression) dbg().bp(expr=expression)
@ -506,7 +604,7 @@ def break_expression(expression: str):
@REGISTRY.method(action='break_hw_execute') @REGISTRY.method(action='break_hw_execute')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_hw_address(process: sch.Schema('Process'), address: Address): def break_hw_address(process: Process, address: Address) -> None:
"""Set a hardware-assisted breakpoint.""" """Set a hardware-assisted breakpoint."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().ba(expr=address.offset) dbg().ba(expr=address.offset)
@ -514,44 +612,46 @@ def break_hw_address(process: sch.Schema('Process'), address: Address):
@REGISTRY.method(action='break_ext', display='Set Hardware Breakpoint') @REGISTRY.method(action='break_ext', display='Set Hardware Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_hw_expression(expression: str): def break_hw_expression(expression: str) -> None:
"""Set a hardware-assisted breakpoint.""" """Set a hardware-assisted breakpoint."""
dbg().ba(expr=expression) dbg().ba(expr=expression)
@REGISTRY.method(action='break_read') @REGISTRY.method(action='break_read')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_read_range(process: sch.Schema('Process'), range: AddressRange): def break_read_range(process: Process, range: AddressRange) -> None:
"""Set a read breakpoint.""" """Set a read breakpoint."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().ba(expr=range.min, size=range.length(), access=DbgEng.DEBUG_BREAK_READ) dbg().ba(expr=range.min, size=range.length(),
access=DbgEng.DEBUG_BREAK_READ)
@REGISTRY.method(action='break_ext', display='Set Read Breakpoint') @REGISTRY.method(action='break_ext', display='Set Read Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_read_expression(expression: str): def break_read_expression(expression: str) -> None:
"""Set a read breakpoint.""" """Set a read breakpoint."""
dbg().ba(expr=expression, access=DbgEng.DEBUG_BREAK_READ) dbg().ba(expr=expression, access=DbgEng.DEBUG_BREAK_READ)
@REGISTRY.method(action='break_write') @REGISTRY.method(action='break_write')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_write_range(process: sch.Schema('Process'), range: AddressRange): def break_write_range(process: Process, range: AddressRange) -> None:
"""Set a write breakpoint.""" """Set a write breakpoint."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().ba(expr=range.min, size=range.length(), access=DbgEng.DEBUG_BREAK_WRITE) dbg().ba(expr=range.min, size=range.length(),
access=DbgEng.DEBUG_BREAK_WRITE)
@REGISTRY.method(action='break_ext', display='Set Write Breakpoint') @REGISTRY.method(action='break_ext', display='Set Write Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_write_expression(expression: str): def break_write_expression(expression: str) -> None:
"""Set a write breakpoint.""" """Set a write breakpoint."""
dbg().ba(expr=expression, access=DbgEng.DEBUG_BREAK_WRITE) dbg().ba(expr=expression, access=DbgEng.DEBUG_BREAK_WRITE)
@REGISTRY.method(action='break_access') @REGISTRY.method(action='break_access')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_access_range(process: sch.Schema('Process'), range: AddressRange): def break_access_range(process: Process, range: AddressRange) -> None:
"""Set an access breakpoint.""" """Set an access breakpoint."""
find_proc_by_obj(process) find_proc_by_obj(process)
dbg().ba(expr=range.min, size=range.length(), dbg().ba(expr=range.min, size=range.length(),
@ -560,14 +660,15 @@ def break_access_range(process: sch.Schema('Process'), range: AddressRange):
@REGISTRY.method(action='break_ext', display='Set Access Breakpoint') @REGISTRY.method(action='break_ext', display='Set Access Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def break_access_expression(expression: str): def break_access_expression(expression: str) -> None:
"""Set an access breakpoint.""" """Set an access breakpoint."""
dbg().ba(expr=expression, access=DbgEng.DEBUG_BREAK_READ | DbgEng.DEBUG_BREAK_WRITE) dbg().ba(expr=expression,
access=DbgEng.DEBUG_BREAK_READ | DbgEng.DEBUG_BREAK_WRITE)
@REGISTRY.method(action='toggle', display='Toggle Breakpoint') @REGISTRY.method(action='toggle', display='Toggle Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool): def toggle_breakpoint(breakpoint: BreakpointSpec, enabled: bool) -> None:
"""Toggle a breakpoint.""" """Toggle a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
if enabled: if enabled:
@ -578,50 +679,53 @@ def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool):
@REGISTRY.method(action='delete', display='Delete Breakpoint') @REGISTRY.method(action='delete', display='Delete Breakpoint')
@util.dbg.eng_thread @util.dbg.eng_thread
def delete_breakpoint(breakpoint: sch.Schema('BreakpointSpec')): def delete_breakpoint(breakpoint: BreakpointSpec) -> None:
"""Delete a breakpoint.""" """Delete a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
dbg().cmd("bc {}".format(bpt.GetId())) dbg().cmd("bc {}".format(bpt.GetId()))
@REGISTRY.method @REGISTRY.method()
@util.dbg.eng_thread @util.dbg.eng_thread
def read_mem(process: sch.Schema('Process'), range: AddressRange): def read_mem(process: Process, range: AddressRange) -> None:
"""Read memory.""" """Read memory."""
# print("READ_MEM: process={}, range={}".format(process, range)) # print("READ_MEM: process={}, range={}".format(process, range))
nproc = find_proc_by_obj(process) nproc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
nproc, Address(range.space, range.min)) nproc, Address(range.space, range.min))
with commands.open_tracked_tx('Read Memory'): with commands.open_tracked_tx('Read Memory'):
result = commands.put_bytes( result = commands.put_bytes(
offset_start, offset_start + range.length() - 1, pages=True, display_result=False) offset_start, offset_start + range.length() - 1, pages=True,
display_result=False)
if result['count'] == 0: if result['count'] == 0:
commands.putmem_state( commands.putmem_state(
offset_start, offset_start + range.length() - 1, 'error') offset_start, offset_start + range.length() - 1, 'error')
@REGISTRY.method @REGISTRY.method()
@util.dbg.eng_thread @util.dbg.eng_thread
def write_mem(process: sch.Schema('Process'), address: Address, data: bytes): def write_mem(process: Process, address: Address, data: bytes) -> None:
"""Write memory.""" """Write memory."""
nproc = find_proc_by_obj(process) nproc = find_proc_by_obj(process)
offset = process.trace.memory_mapper.map_back(nproc, address) offset = process.trace.extra.required_mm().map_back(nproc, address)
dbg().write(offset, data) dbg().write(offset, data)
@REGISTRY.method @REGISTRY.method()
@util.dbg.eng_thread @util.dbg.eng_thread
def write_reg(frame: sch.Schema('StackFrame'), name: str, value: bytes): def write_reg(frame: StackFrame, name: str, value: bytes) -> None:
"""Write a register.""" """Write a register."""
util.select_frame() f = find_frame_by_obj(frame)
util.select_frame(f.FrameNumber)
nproc = pydbg.selected_process() nproc = pydbg.selected_process()
dbg().reg._set_register(name, value) dbg().reg._set_register(name, value)
@REGISTRY.method(display='Refresh Events (custom)', condition=util.dbg.IS_TRACE) @REGISTRY.method(display='Refresh Events (custom)', condition=util.dbg.IS_TRACE)
@util.dbg.eng_thread @util.dbg.eng_thread
def refresh_events_custom(node: sch.Schema('State'), cmd: ParamDesc(str, display='Cmd'), def refresh_events_custom(node: State,
prefix: ParamDesc(str, display='Prefix')="dx -r2 @$cursession.TTD"): cmd: Annotated[str, ParamDesc(display='Cmd')],
prefix: Annotated[str, ParamDesc(display='Prefix')] = "dx -r2 @$cursession.TTD") -> None:
"""Parse TTD objects generated from a LINQ command.""" """Parse TTD objects generated from a LINQ command."""
with commands.open_tracked_tx('Put Events (custom)'): with commands.open_tracked_tx('Put Events (custom)'):
commands.ghidra_trace_put_events_custom(prefix, cmd) commands.ghidra_trace_put_events_custom(prefix, cmd)

View file

@ -1,5 +1,6 @@
<context> <context>
<schema name="DbgRoot" canonical="yes" elementResync="NEVER" attributeResync="NEVER"> <schema name="DbgRoot" canonical="yes" elementResync="NEVER" attributeResync="NEVER">
<interface name="EventScope" />
<attribute name="Sessions" schema="SessionContainer" required="yes" fixed="yes" /> <attribute name="Sessions" schema="SessionContainer" required="yes" fixed="yes" />
<attribute name="Settings" schema="ANY" /> <attribute name="Settings" schema="ANY" />
<attribute name="State" schema="State" /> <attribute name="State" schema="State" />
@ -16,7 +17,6 @@
</schema> </schema>
<schema name="Session" elementResync="NEVER" attributeResync="NEVER"> <schema name="Session" elementResync="NEVER" attributeResync="NEVER">
<interface name="Activatable" /> <interface name="Activatable" />
<interface name="EventScope" />
<interface name="FocusScope" /> <interface name="FocusScope" />
<interface name="Aggregate" /> <interface name="Aggregate" />
<element schema="VOID" /> <element schema="VOID" />

View file

@ -1,5 +1,6 @@
<context> <context>
<schema name="DbgRoot" canonical="yes" elementResync="NEVER" attributeResync="NEVER"> <schema name="DbgRoot" canonical="yes" elementResync="NEVER" attributeResync="NEVER">
<interface name="EventScope" />
<attribute name="Sessions" schema="SessionContainer" required="yes" fixed="yes" /> <attribute name="Sessions" schema="SessionContainer" required="yes" fixed="yes" />
<attribute name="Settings" schema="ANY" /> <attribute name="Settings" schema="ANY" />
<attribute name="State" schema="ANY" /> <attribute name="State" schema="ANY" />
@ -16,7 +17,6 @@
</schema> </schema>
<schema name="Session" elementResync="NEVER" attributeResync="NEVER"> <schema name="Session" elementResync="NEVER" attributeResync="NEVER">
<interface name="Activatable" /> <interface name="Activatable" />
<interface name="EventScope" />
<interface name="FocusScope" /> <interface name="FocusScope" />
<interface name="Aggregate" /> <interface name="Aggregate" />
<element schema="VOID" /> <element schema="VOID" />

View file

@ -13,10 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from comtypes.automation import VARIANT
from ghidratrace.client import Schedule
from .dbgmodel.imodelobject import ModelObject
from capstone import CsInsn
from _winapi import STILL_ACTIVE
from collections import namedtuple from collections import namedtuple
from concurrent.futures import Future from concurrent.futures import Future
import concurrent.futures import concurrent.futures
from ctypes import * from ctypes import POINTER, byref, c_ulong, c_ulonglong, create_string_buffer
import functools import functools
import io import io
import os import os
@ -25,11 +31,14 @@ import re
import sys import sys
import threading import threading
import traceback import traceback
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union, cast
from comtypes import CoClass, GUID from comtypes import CoClass, GUID
import comtypes import comtypes
from comtypes.gen import DbgMod from comtypes.gen import DbgMod
from comtypes.hresult import S_OK, S_FALSE from comtypes.hresult import S_OK, S_FALSE
from ghidradbg.dbgmodel.ihostdatamodelaccess import HostDataModelAccess
from ghidradbg.dbgmodel.imodelmethod import ModelMethod
from pybag import pydbg, userdbg, kerneldbg, crashdbg from pybag import pydbg, userdbg, kerneldbg, crashdbg
from pybag.dbgeng import core as DbgEng from pybag.dbgeng import core as DbgEng
from pybag.dbgeng import exception from pybag.dbgeng import exception
@ -37,9 +46,7 @@ from pybag.dbgeng import util as DbgUtil
from pybag.dbgeng.callbacks import DbgEngCallbacks from pybag.dbgeng.callbacks import DbgEngCallbacks
from pybag.dbgeng.idebugclient import DebugClient from pybag.dbgeng.idebugclient import DebugClient
from ghidradbg.dbgmodel.ihostdatamodelaccess import HostDataModelAccess DESCRIPTION_PATTERN = '[{major:X}:{minor:X}] {type}'
from ghidradbg.dbgmodel.imodelmethod import ModelMethod
from _winapi import STILL_ACTIVE
DbgVersion = namedtuple('DbgVersion', ['full', 'name', 'dotted', 'arch']) DbgVersion = namedtuple('DbgVersion', ['full', 'name', 'dotted', 'arch'])
@ -132,23 +139,27 @@ class DebuggeeRunningException(BaseException):
pass pass
T = TypeVar('T')
class DbgExecutor(object): class DbgExecutor(object):
def __init__(self, ghidra_dbg): def __init__(self, ghidra_dbg: 'GhidraDbg') -> None:
self._ghidra_dbg = ghidra_dbg self._ghidra_dbg = ghidra_dbg
self._work_queue = queue.SimpleQueue() self._work_queue: queue.SimpleQueue = queue.SimpleQueue()
self._thread = _Worker(ghidra_dbg._new_base, self._thread = _Worker(ghidra_dbg._new_base,
self._work_queue, ghidra_dbg._dispatch_events) self._work_queue, ghidra_dbg._dispatch_events)
self._thread.start() self._thread.start()
self._executing = False self._executing = False
def submit(self, fn, / , *args, **kwargs): def submit(self, fn: Callable[..., T], /, *args, **kwargs) -> Future[T]:
f = self._submit_no_exit(fn, *args, **kwargs) f = self._submit_no_exit(fn, *args, **kwargs)
self._ghidra_dbg.exit_dispatch() self._ghidra_dbg.exit_dispatch()
return f return f
def _submit_no_exit(self, fn, / , *args, **kwargs): def _submit_no_exit(self, fn: Callable[..., T], /,
f = Future() *args, **kwargs) -> Future[T]:
f: Future[T] = Future()
if self._executing and self._ghidra_dbg.IS_REMOTE == False: if self._executing and self._ghidra_dbg.IS_REMOTE == False:
f.set_exception(DebuggeeRunningException("Debuggee is Running")) f.set_exception(DebuggeeRunningException("Debuggee is Running"))
return f return f
@ -156,7 +167,7 @@ class DbgExecutor(object):
self._work_queue.put(w) self._work_queue.put(w)
return f return f
def _clear_queue(self): def _clear_queue(self) -> None:
while True: while True:
try: try:
work_item = self._work_queue.get_nowait() work_item = self._work_queue.get_nowait()
@ -165,12 +176,12 @@ class DbgExecutor(object):
work_item.future.set_exception( work_item.future.set_exception(
DebuggeeRunningException("Debuggee is Running")) DebuggeeRunningException("Debuggee is Running"))
def _state_execute(self): def _state_execute(self) -> None:
self._executing = True self._executing = True
if self._ghidra_dbg.IS_REMOTE == False: if self._ghidra_dbg.IS_REMOTE == False:
self._clear_queue() self._clear_queue()
def _state_break(self): def _state_break(self) -> None:
self._executing = False self._executing = False
@ -201,9 +212,12 @@ class AllDbg(pydbg.DebuggerBase):
load_dump = crashdbg.CrashDbg.load_dump load_dump = crashdbg.CrashDbg.load_dump
C = TypeVar('C', bound=Callable[..., Any])
class GhidraDbg(object): class GhidraDbg(object):
def __init__(self): def __init__(self) -> None:
self._queue = DbgExecutor(self) self._queue = DbgExecutor(self)
self._thread = self._queue._thread self._thread = self._queue._thread
# Wait for the executor to be operational before getting base # Wait for the executor to be operational before getting base
@ -245,10 +259,10 @@ class GhidraDbg(object):
setattr(self, name, self.eng_thread(getattr(base, name))) setattr(self, name, self.eng_thread(getattr(base, name)))
self.IS_KERNEL = False self.IS_KERNEL = False
self.IS_EXDI = False self.IS_EXDI = False
self.IS_REMOTE = os.getenv('OPT_CONNECT_STRING') is not None self.IS_REMOTE: bool = os.getenv('OPT_CONNECT_STRING') is not None
self.IS_TRACE = os.getenv('USE_TTD') == "true" self.IS_TRACE: bool = os.getenv('USE_TTD') == "true"
def _new_base(self): def _new_base(self) -> None:
remote = os.getenv('OPT_CONNECT_STRING') remote = os.getenv('OPT_CONNECT_STRING')
if remote is not None: if remote is not None:
remote_client = DbgEng.DebugConnect(remote) remote_client = DbgEng.DebugConnect(remote)
@ -257,7 +271,7 @@ class GhidraDbg(object):
else: else:
self._protected_base = AllDbg() self._protected_base = AllDbg()
def _generate_client(self, original): def _generate_client(self, original: DebugClient) -> DebugClient:
cli = POINTER(DbgEng.IDebugClient)() cli = POINTER(DbgEng.IDebugClient)()
cliptr = POINTER(POINTER(DbgEng.IDebugClient))(cli) cliptr = POINTER(POINTER(DbgEng.IDebugClient))(cli)
hr = original.CreateClient(cliptr) hr = original.CreateClient(cliptr)
@ -265,13 +279,13 @@ class GhidraDbg(object):
return DebugClient(client=cli) return DebugClient(client=cli)
@property @property
def _base(self): def _base(self) -> AllDbg:
if threading.current_thread() is not self._thread: if threading.current_thread() is not self._thread:
raise WrongThreadException("Was {}. Want {}".format( raise WrongThreadException("Was {}. Want {}".format(
threading.current_thread(), self._thread)) threading.current_thread(), self._thread))
return self._protected_base return self._protected_base
def run(self, fn, *args, **kwargs): def run(self, fn: Callable[..., T], *args, **kwargs) -> T:
# TODO: Remove this check? # TODO: Remove this check?
if hasattr(self, '_thread') and threading.current_thread() is self._thread: if hasattr(self, '_thread') and threading.current_thread() is self._thread:
raise WrongThreadException() raise WrongThreadException()
@ -283,64 +297,60 @@ class GhidraDbg(object):
except concurrent.futures.TimeoutError: except concurrent.futures.TimeoutError:
pass pass
def run_async(self, fn, *args, **kwargs): def run_async(self, fn: Callable[..., T], *args, **kwargs) -> Future[T]:
return self._queue.submit(fn, *args, **kwargs) return self._queue.submit(fn, *args, **kwargs)
@staticmethod @staticmethod
def check_thread(func): def check_thread(func: C) -> C:
''' """For methods inside of GhidraDbg, ensure it runs on the dbgeng
For methods inside of GhidraDbg, ensure it runs on the dbgeng thread."""
thread
'''
@functools.wraps(func) @functools.wraps(func)
def _func(self, *args, **kwargs): def _func(self, *args, **kwargs) -> Any:
if threading.current_thread() is self._thread: if threading.current_thread() is self._thread:
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
else: else:
return self.run(func, self, *args, **kwargs) return self.run(func, self, *args, **kwargs)
return _func return cast(C, _func)
def eng_thread(self, func): def eng_thread(self, func: C) -> C:
''' """For methods and functions outside of GhidraDbg, ensure it runs on
For methods and functions outside of GhidraDbg, ensure it this GhidraDbg's dbgeng thread."""
runs on this GhidraDbg's dbgeng thread
'''
@functools.wraps(func) @functools.wraps(func)
def _func(*args, **kwargs): def _func(*args, **kwargs) -> Any:
if threading.current_thread() is self._thread: if threading.current_thread() is self._thread:
return func(*args, **kwargs) return func(*args, **kwargs)
else: else:
return self.run(func, *args, **kwargs) return self.run(func, *args, **kwargs)
return _func return cast(C, _func)
def _ces_exec_status(self, argument): def _ces_exec_status(self, argument: int):
if argument & 0xfffffff == DbgEng.DEBUG_STATUS_BREAK: if argument & 0xfffffff == DbgEng.DEBUG_STATUS_BREAK:
self._queue._state_break() self._queue._state_break()
else: else:
self._queue._state_execute() self._queue._state_execute()
@check_thread @check_thread
def _install_stdin(self): def _install_stdin(self) -> None:
self.input = StdInputCallbacks(self) self.input = StdInputCallbacks(self)
self._base._client.SetInputCallbacks(self.input) self._base._client.SetInputCallbacks(self.input)
# Manually decorated to preserve undecorated # Manually decorated to preserve undecorated
def _dispatch_events(self, timeout=DbgEng.WAIT_INFINITE): def _dispatch_events(self, timeout: int = DbgEng.WAIT_INFINITE) -> None:
# NB: pybag's impl doesn't heed standalone # NB: pybag's impl doesn't heed standalone
self._protected_base._client.DispatchCallbacks(timeout) self._protected_base._client.DispatchCallbacks(timeout)
dispatch_events = check_thread(_dispatch_events) dispatch_events = check_thread(_dispatch_events)
# no check_thread. Must allow reentry # no check_thread. Must allow reentry
def exit_dispatch(self): def exit_dispatch(self) -> None:
self._protected_base._client.ExitDispatch() self._protected_base._client.ExitDispatch()
@check_thread @check_thread
def cmd(self, cmdline, quiet=True): def cmd(self, cmdline: str, quiet: bool = True) -> str:
# NB: pybag's impl always captures. # NB: pybag's impl always captures.
# Here, we let it print without capture if quiet is False # Here, we let it print without capture if quiet is False
if quiet: if quiet:
@ -356,20 +366,20 @@ class GhidraDbg(object):
return self._base._control.Execute(cmdline) return self._base._control.Execute(cmdline)
@check_thread @check_thread
def return_input(self, input): def return_input(self, input: str) -> None:
# TODO: Contribute fix upstream (check_hr -> check_err) # TODO: Contribute fix upstream (check_hr -> check_err)
# return self._base._control.ReturnInput(input.encode()) # return self._base._control.ReturnInput(input.encode())
hr = self._base._control._ctrl.ReturnInput(input.encode()) hr = self._base._control._ctrl.ReturnInput(input.encode())
exception.check_err(hr) exception.check_err(hr)
def interrupt(self): def interrupt(self) -> None:
# Contribute upstream? # Contribute upstream?
# NOTE: This can be called from any thread # NOTE: This can be called from any thread
self._protected_base._control.SetInterrupt( self._protected_base._control.SetInterrupt(
DbgEng.DEBUG_INTERRUPT_ACTIVE) DbgEng.DEBUG_INTERRUPT_ACTIVE)
@check_thread @check_thread
def get_current_system_id(self): def get_current_system_id(self) -> int:
# TODO: upstream? # TODO: upstream?
sys_id = c_ulong() sys_id = c_ulong()
hr = self._base._systems._sys.GetCurrentSystemId(byref(sys_id)) hr = self._base._systems._sys.GetCurrentSystemId(byref(sys_id))
@ -377,7 +387,7 @@ class GhidraDbg(object):
return sys_id.value return sys_id.value
@check_thread @check_thread
def get_prompt_text(self): def get_prompt_text(self) -> str:
# TODO: upstream? # TODO: upstream?
size = c_ulong() size = c_ulong()
hr = self._base._control._ctrl.GetPromptText(None, 0, byref(size)) hr = self._base._control._ctrl.GetPromptText(None, 0, byref(size))
@ -386,12 +396,12 @@ class GhidraDbg(object):
return prompt_buf.value.decode() return prompt_buf.value.decode()
@check_thread @check_thread
def get_actual_processor_type(self): def get_actual_processor_type(self) -> int:
return self._base._control.GetActualProcessorType() return self._base._control.GetActualProcessorType()
@property @property
@check_thread @check_thread
def pid(self): def pid(self) -> Optional[int]:
try: try:
if is_kernel(): if is_kernel():
return 0 return 0
@ -403,17 +413,12 @@ class GhidraDbg(object):
class TTDState(object): class TTDState(object):
def __init__(self): def __init__(self) -> None:
self._cursor = None self._first: Optional[Tuple[int, int]] = None
self._first = None self._last: Optional[Tuple[int, int]] = None
self._last = None self._lastpos: Optional[Tuple[int, int]] = None
self._lastmajor = None self.evttypes: Dict[Tuple[int, int], str] = {}
self._lastpos = None self.MAX_STEP: int
self.breakpoints = []
self.events = {}
self.evttypes = {}
self.starts = {}
self.stops = {}
dbg = GhidraDbg() dbg = GhidraDbg()
@ -421,16 +426,16 @@ ttd = TTDState()
@dbg.eng_thread @dbg.eng_thread
def compute_pydbg_ver(): def compute_pydbg_ver() -> DbgVersion:
pat = re.compile( pat = re.compile(
'(?P<name>.*Debugger.*) Version (?P<dotted>[\\d\\.]*) (?P<arch>\\w*)') '(?P<name>.*Debugger.*) Version (?P<dotted>[\\d\\.]*) (?P<arch>\\w*)')
blurb = dbg.cmd('version') blurb = dbg.cmd('version')
matches = [pat.match(l) for l in blurb.splitlines() if pat.match(l)] matches_opt = [pat.match(l) for l in blurb.splitlines()]
matches = [m for m in matches_opt if m is not None]
if len(matches) == 0: if len(matches) == 0:
return DbgVersion('Unknown', 'Unknown', '0', 'Unknown') return DbgVersion('Unknown', 'Unknown', '0', 'Unknown')
m = matches[0] m = matches[0]
return DbgVersion(full=m.group(), **m.groupdict()) return DbgVersion(full=m.group(), **m.groupdict())
name, dotted_and_arch = full.split(' Version ')
DBG_VERSION = compute_pydbg_ver() DBG_VERSION = compute_pydbg_ver()
@ -441,26 +446,27 @@ def get_target():
@dbg.eng_thread @dbg.eng_thread
def disassemble1(addr): def disassemble1(addr: int) -> CsInsn:
return DbgUtil.disassemble_instruction(dbg._base.bitness(), addr, dbg.read(addr, 15)) data = dbg.read(addr, 15) # type:ignore
return DbgUtil.disassemble_instruction(dbg._base.bitness(), addr, data)
def get_inst(addr): def get_inst(addr: int) -> str:
return str(disassemble1(addr)) return str(disassemble1(addr))
def get_inst_sz(addr): def get_inst_sz(addr: int) -> int:
return int(disassemble1(addr).size) return int(disassemble1(addr).size)
@dbg.eng_thread @dbg.eng_thread
def get_breakpoints(): def get_breakpoints() -> Iterable[Tuple[str, str, str, str, str]]:
ids = [bpid for bpid in dbg._base.breakpoints] ids = [bpid for bpid in dbg._base.breakpoints]
offset_set = [] offset_set: List[str] = []
expr_set = [] expr_set: List[str] = []
prot_set = [] prot_set: List[str] = []
width_set = [] width_set: List[str] = []
stat_set = [] stat_set: List[str] = []
for bpid in ids: for bpid in ids:
try: try:
bp = dbg._base._control.GetBreakpointById(bpid) bp = dbg._base._control.GetBreakpointById(bpid)
@ -496,7 +502,7 @@ def get_breakpoints():
@dbg.eng_thread @dbg.eng_thread
def selected_process(): def selected_process() -> int:
try: try:
if is_exdi(): if is_exdi():
return 0 return 0
@ -504,7 +510,8 @@ def selected_process():
do = dbg._base._systems.GetCurrentProcessDataOffset() do = dbg._base._systems.GetCurrentProcessDataOffset()
id = c_ulong() id = c_ulong()
offset = c_ulonglong(do) offset = c_ulonglong(do)
nproc = dbg._base._systems._sys.GetProcessIdByDataOffset(offset, byref(id)) nproc = dbg._base._systems._sys.GetProcessIdByDataOffset(
offset, byref(id))
return id.value return id.value
if dbg.use_generics: if dbg.use_generics:
return dbg._base._systems.GetCurrentProcessSystemId() return dbg._base._systems.GetCurrentProcessSystemId()
@ -515,7 +522,7 @@ def selected_process():
@dbg.eng_thread @dbg.eng_thread
def selected_process_space(): def selected_process_space() -> int:
try: try:
if is_exdi(): if is_exdi():
return 0 return 0
@ -528,7 +535,7 @@ def selected_process_space():
@dbg.eng_thread @dbg.eng_thread
def selected_thread(): def selected_thread() -> Optional[int]:
try: try:
if is_kernel(): if is_kernel():
return 0 return 0
@ -540,7 +547,7 @@ def selected_thread():
@dbg.eng_thread @dbg.eng_thread
def selected_frame(): def selected_frame() -> Optional[int]:
try: try:
line = dbg.cmd('.frame').strip() line = dbg.cmd('.frame').strip()
if not line: if not line:
@ -553,40 +560,47 @@ def selected_frame():
return None return None
def require(t: Optional[T]) -> T:
if t is None:
raise ValueError("Unexpected None")
return t
@dbg.eng_thread @dbg.eng_thread
def select_process(id: int): def select_process(id: int) -> None:
if is_kernel(): if is_kernel():
# TODO: Ideally this should get the data offset from the id and then call # TODO: Ideally this should get the data offset from the id and then call
# SetImplicitProcessDataOffset # SetImplicitProcessDataOffset
return return
if dbg.use_generics: if dbg.use_generics:
id = get_proc_id(id) id = require(get_proc_id(id))
return dbg._base._systems.SetCurrentProcessId(id) return dbg._base._systems.SetCurrentProcessId(id)
@dbg.eng_thread @dbg.eng_thread
def select_thread(id: int): def select_thread(id: int) -> None:
if is_kernel(): if is_kernel():
# TODO: Ideally this should get the data offset from the id and then call # TODO: Ideally this should get the data offset from the id and then call
# SetImplicitThreadDataOffset # SetImplicitThreadDataOffset
return return
if dbg.use_generics: if dbg.use_generics:
id = get_thread_id(id) id = require(get_thread_id(id))
return dbg._base._systems.SetCurrentThreadId(id) return dbg._base._systems.SetCurrentThreadId(id)
@dbg.eng_thread @dbg.eng_thread
def select_frame(id: int): def select_frame(id: int) -> str:
return dbg.cmd('.frame /c {}'.format(id)) return dbg.cmd('.frame /c {}'.format(id))
@dbg.eng_thread @dbg.eng_thread
def reset_frames(): def reset_frames() -> str:
return dbg.cmd('.cxr') return dbg.cmd('.cxr')
@dbg.eng_thread @dbg.eng_thread
def parse_and_eval(expr, type=None): def parse_and_eval(expr: Union[str, int],
type: Optional[int] = None) -> Union[int, float, bytes]:
if isinstance(expr, int): if isinstance(expr, int):
return expr return expr
# TODO: This could be contributed upstream # TODO: This could be contributed upstream
@ -617,20 +631,22 @@ def parse_and_eval(expr, type=None):
return value.u.F82Bytes return value.u.F82Bytes
if type == DbgEng.DEBUG_VALUE_FLOAT128: if type == DbgEng.DEBUG_VALUE_FLOAT128:
return value.u.F128Bytes return value.u.F128Bytes
raise ValueError(
f"Could not evaluate '{expr}' or convert result '{value}'")
@dbg.eng_thread @dbg.eng_thread
def get_pc(): def get_pc() -> int:
return dbg._base.reg.get_pc() return dbg._base.reg.get_pc()
@dbg.eng_thread @dbg.eng_thread
def get_sp(): def get_sp() -> int:
return dbg._base.reg.get_sp() return dbg._base.reg.get_sp()
@dbg.eng_thread @dbg.eng_thread
def GetProcessIdsByIndex(count=0): def GetProcessIdsByIndex(count: int = 0) -> Tuple[List[int], List[int]]:
# TODO: This could be contributed upstream? # TODO: This could be contributed upstream?
if count == 0: if count == 0:
try: try:
@ -643,11 +659,11 @@ def GetProcessIdsByIndex(count=0):
hr = dbg._base._systems._sys.GetProcessIdsByIndex( hr = dbg._base._systems._sys.GetProcessIdsByIndex(
0, count, ids, sysids) 0, count, ids, sysids)
exception.check_err(hr) exception.check_err(hr)
return (tuple(ids), tuple(sysids)) return (list(ids), list(sysids))
@dbg.eng_thread @dbg.eng_thread
def GetCurrentProcessExecutableName(): def GetCurrentProcessExecutableName() -> str:
# TODO: upstream? # TODO: upstream?
_dbg = dbg._base _dbg = dbg._base
size = c_ulong() size = c_ulong()
@ -659,17 +675,15 @@ def GetCurrentProcessExecutableName():
size = exesize size = exesize
hr = _dbg._systems._sys.GetCurrentProcessExecutableName(buffer, size, None) hr = _dbg._systems._sys.GetCurrentProcessExecutableName(buffer, size, None)
exception.check_err(hr) exception.check_err(hr)
buffer = buffer[:size.value] return buffer.value.decode()
buffer = buffer.rstrip(b'\x00')
return buffer
@dbg.eng_thread @dbg.eng_thread
def GetCurrentProcessPeb(): def GetCurrentProcessPeb() -> int:
# TODO: upstream? # TODO: upstream?
_dbg = dbg._base _dbg = dbg._base
offset = c_ulonglong() offset = c_ulonglong()
if dbg.is_kernel(): if is_kernel():
hr = _dbg._systems._sys.GetCurrentProcessDataOffset(byref(offset)) hr = _dbg._systems._sys.GetCurrentProcessDataOffset(byref(offset))
else: else:
hr = _dbg._systems._sys.GetCurrentProcessPeb(byref(offset)) hr = _dbg._systems._sys.GetCurrentProcessPeb(byref(offset))
@ -678,7 +692,7 @@ def GetCurrentProcessPeb():
@dbg.eng_thread @dbg.eng_thread
def GetCurrentThreadTeb(): def GetCurrentThreadTeb() -> int:
# TODO: upstream? # TODO: upstream?
_dbg = dbg._base _dbg = dbg._base
offset = c_ulonglong() offset = c_ulonglong()
@ -691,7 +705,7 @@ def GetCurrentThreadTeb():
@dbg.eng_thread @dbg.eng_thread
def GetExitCode(): def GetExitCode() -> int:
# TODO: upstream? # TODO: upstream?
if is_kernel(): if is_kernel():
return STILL_ACTIVE return STILL_ACTIVE
@ -704,8 +718,9 @@ def GetExitCode():
@dbg.eng_thread @dbg.eng_thread
def process_list(running=False): def process_list(running: bool = False) -> Union[
"""Get the list of all processes""" Iterable[Tuple[int, str, int]], Iterable[Tuple[int]]]:
"""Get the list of all processes."""
_dbg = dbg._base _dbg = dbg._base
ids, sysids = GetProcessIdsByIndex() ids, sysids = GetProcessIdsByIndex()
pebs = [] pebs = []
@ -725,12 +740,16 @@ def process_list(running=False):
return zip(sysids) return zip(sysids)
finally: finally:
if not running and curid is not None: if not running and curid is not None:
try:
_dbg._systems.SetCurrentProcessId(curid) _dbg._systems.SetCurrentProcessId(curid)
except Exception as e:
print(f"Couldn't restore current process: {e}")
@dbg.eng_thread @dbg.eng_thread
def thread_list(running=False): def thread_list(running: bool = False) -> Union[
"""Get the list of all threads""" Iterable[Tuple[int, int, str]], Iterable[Tuple[int]]]:
"""Get the list of all threads."""
_dbg = dbg._base _dbg = dbg._base
try: try:
ids, sysids = _dbg._systems.GetThreadIdsByIndex() ids, sysids = _dbg._systems.GetThreadIdsByIndex()
@ -758,8 +777,8 @@ def thread_list(running=False):
@dbg.eng_thread @dbg.eng_thread
def get_proc_id(pid): def get_proc_id(pid: int) -> Optional[int]:
"""Get the list of all processes""" """Get the id for the given system process id."""
# TODO: Implement GetProcessIdBySystemId and replace this logic # TODO: Implement GetProcessIdBySystemId and replace this logic
_dbg = dbg._base _dbg = dbg._base
map = {} map = {}
@ -773,7 +792,7 @@ def get_proc_id(pid):
return None return None
def full_mem(): def full_mem() -> List[DbgEng._MEMORY_BASIC_INFORMATION64]:
info = DbgEng._MEMORY_BASIC_INFORMATION64() info = DbgEng._MEMORY_BASIC_INFORMATION64()
info.BaseAddress = 0 info.BaseAddress = 0
info.RegionSize = (1 << 64) - 1 info.RegionSize = (1 << 64) - 1
@ -783,8 +802,8 @@ def full_mem():
@dbg.eng_thread @dbg.eng_thread
def get_thread_id(tid): def get_thread_id(tid: int) -> Optional[int]:
"""Get the list of all threads""" """Get the id for the given system thread id."""
# TODO: Implement GetThreadIdBySystemId and replace this logic # TODO: Implement GetThreadIdBySystemId and replace this logic
_dbg = dbg._base _dbg = dbg._base
map = {} map = {}
@ -799,8 +818,8 @@ def get_thread_id(tid):
@dbg.eng_thread @dbg.eng_thread
def open_trace_or_dump(filename): def open_trace_or_dump(filename: Union[str, bytes]) -> None:
"""Open a trace or dump file""" """Open a trace or dump file."""
_cli = dbg._base._client._cli _cli = dbg._base._client._cli
if isinstance(filename, str): if isinstance(filename, str):
filename = filename.encode() filename = filename.encode()
@ -808,7 +827,7 @@ def open_trace_or_dump(filename):
exception.check_err(hr) exception.check_err(hr)
def split_path(pathString): def split_path(pathString: str) -> List[str]:
list = [] list = []
segs = pathString.split(".") segs = pathString.split(".")
for s in segs: for s in segs:
@ -823,23 +842,23 @@ def split_path(pathString):
return list return list
def IHostDataModelAccess(): def IHostDataModelAccess() -> HostDataModelAccess:
return HostDataModelAccess( return HostDataModelAccess(dbg._base._client._cli.QueryInterface(
dbg._base._client._cli.QueryInterface(interface=DbgMod.IHostDataModelAccess)) interface=DbgMod.IHostDataModelAccess))
def IModelMethod(method_ptr): def IModelMethod(method_ptr) -> ModelMethod:
return ModelMethod( return ModelMethod(method_ptr.GetIntrinsicValue().value.QueryInterface(
method_ptr.GetIntrinsicValue().value.QueryInterface(interface=DbgMod.IModelMethod)) interface=DbgMod.IModelMethod))
@dbg.eng_thread @dbg.eng_thread
def get_object(relpath): def get_object(relpath: str) -> Optional[ModelObject]:
"""Get the list of all threads""" """Get the model object at the given path."""
_cli = dbg._base._client._cli _cli = dbg._base._client._cli
access = HostDataModelAccess(_cli.QueryInterface( access = HostDataModelAccess(_cli.QueryInterface(
interface=DbgMod.IHostDataModelAccess)) interface=DbgMod.IHostDataModelAccess))
(mgr, host) = access.GetDataModel() mgr, host = access.GetDataModel()
root = mgr.GetRootNamespace() root = mgr.GetRootNamespace()
pathstr = "Debugger" pathstr = "Debugger"
if relpath != '': if relpath != '':
@ -850,11 +869,13 @@ def get_object(relpath):
@dbg.eng_thread @dbg.eng_thread
def get_method(context_path, method_name): def get_method(context_path: str, method_name: str) -> Optional[ModelMethod]:
"""Get the list of all threads""" """Get method for the given object (path) and name."""
obj = get_object(context_path) obj = get_object(context_path)
if obj is None:
return None
keys = obj.EnumerateKeys() keys = obj.EnumerateKeys()
(k, v) = keys.GetNext() k, v = keys.GetNext()
while k is not None: while k is not None:
if k.value == method_name: if k.value == method_name:
break break
@ -865,24 +886,24 @@ def get_method(context_path, method_name):
@dbg.eng_thread @dbg.eng_thread
def get_attributes(obj): def get_attributes(obj: ModelObject) -> Dict[str, ModelObject]:
"""Get the list of attributes""" """Get the list of attributes."""
if obj is None: if obj is None:
return None return None
return obj.GetAttributes() return obj.GetAttributes()
@dbg.eng_thread @dbg.eng_thread
def get_elements(obj): def get_elements(obj: ModelObject) -> List[Tuple[int, ModelObject]]:
"""Get the list of elements""" """Get the list of elements."""
if obj is None: if obj is None:
return None return None
return obj.GetElements() return obj.GetElements()
@dbg.eng_thread @dbg.eng_thread
def get_kind(obj): def get_kind(obj) -> Optional[int]:
"""Get the kind""" """Get the kind."""
if obj is None: if obj is None:
return None return None
kind = obj.GetKind() kind = obj.GetKind()
@ -891,65 +912,66 @@ def get_kind(obj):
return obj.GetKind().value return obj.GetKind().value
@dbg.eng_thread # DOESN'T WORK YET
def get_type(obj): # @dbg.eng_thread
"""Get the type""" # def get_type(obj: ModelObject):
if obj is None: # """Get the type."""
return None # if obj is None:
return obj.GetTypeKind() # return None
# return obj.GetTypeKind()
@dbg.eng_thread @dbg.eng_thread
def get_value(obj): def get_value(obj: ModelObject) -> Any:
"""Get the value""" """Get the value."""
if obj is None: if obj is None:
return None return None
return obj.GetValue() return obj.GetValue()
@dbg.eng_thread @dbg.eng_thread
def get_intrinsic_value(obj): def get_intrinsic_value(obj: ModelObject) -> VARIANT:
"""Get the intrinsic value""" """Get the intrinsic value."""
if obj is None: if obj is None:
return None return None
return obj.GetIntrinsicValue() return obj.GetIntrinsicValue()
@dbg.eng_thread @dbg.eng_thread
def get_target_info(obj): def get_target_info(obj: ModelObject) -> ModelObject:
"""Get the target info""" """Get the target info."""
if obj is None: if obj is None:
return None return None
return obj.GetTargetInfo() return obj.GetTargetInfo()
@dbg.eng_thread @dbg.eng_thread
def get_type_info(obj): def get_type_info(obj: ModelObject) -> ModelObject:
"""Get the type info""" """Get the type info."""
if obj is None: if obj is None:
return None return None
return obj.GetTypeInfo() return obj.GetTypeInfo()
@dbg.eng_thread @dbg.eng_thread
def get_name(obj): def get_name(obj: ModelObject) -> str:
"""Get the name""" """Get the name."""
if obj is None: if obj is None:
return None return None
return obj.GetName().value return obj.GetName().value
@dbg.eng_thread @dbg.eng_thread
def to_display_string(obj): def to_display_string(obj: ModelObject) -> str:
"""Get the display string""" """Get the display string."""
if obj is None: if obj is None:
return None return None
return obj.ToDisplayString() return obj.ToDisplayString()
@dbg.eng_thread @dbg.eng_thread
def get_location(obj): def get_location(obj: ModelObject) -> Optional[str]:
"""Get the location""" """Get the location."""
if obj is None: if obj is None:
return None return None
try: try:
@ -961,10 +983,10 @@ def get_location(obj):
return None return None
conv_map = {} conv_map: Dict[str, str] = {}
def get_convenience_variable(id): def get_convenience_variable(id: str) -> Any:
if id not in conv_map: if id not in conv_map:
return "auto" return "auto"
val = conv_map[id] val = conv_map[id]
@ -973,77 +995,89 @@ def get_convenience_variable(id):
return val return val
def get_cursor(): def get_last_position() -> Optional[Tuple[int, int]]:
return ttd._cursor
def get_last_position():
return ttd._lastpos return ttd._lastpos
def set_last_position(pos): def set_last_position(pos: Tuple[int, int]) -> None:
ttd._lastpos = pos ttd._lastpos = pos
def get_event_type(rng): def get_event_type(pos: Tuple[int, int]) -> Optional[str]:
if ttd.evttypes.__contains__(rng): if ttd.evttypes.__contains__(pos):
return ttd.evttypes[rng] return ttd.evttypes[pos]
return None
def pos2snap(pos): def split2schedule(pos: Tuple[int, int]) -> Schedule:
pmap = get_attributes(pos) major, minor = pos
major = get_value(pmap["Sequence"]) return mm2schedule(major, minor)
minor = get_value(pmap["Steps"])
return mm2snap(major, minor)
def mm2snap(major, minor): def schedule2split(time: Schedule) -> Tuple[int, int]:
return time.snap, time.steps
def mm2schedule(major: int, minor: int) -> Schedule:
index = int(major) index = int(major)
if index < 0 or index >= ttd.MAX_STEP: if index < 0 or hasattr(ttd, 'MAX_STEP') and index >= ttd.MAX_STEP:
return int(ttd._lastmajor) # << 32 return Schedule(require(ttd._last)[0])
snap = index # << 32 + int(minor) if index >= 1 << 63:
return snap return Schedule((1 << 63) - 1)
return Schedule(index, minor)
def pos2split(pos): def pos2split(pos: ModelObject) -> Tuple[int, int]:
pmap = get_attributes(pos) pmap = get_attributes(pos)
major = get_value(pmap["Sequence"]) major = get_value(pmap["Sequence"])
minor = get_value(pmap["Steps"]) minor = get_value(pmap["Steps"])
return (major, minor) return (major, minor)
def set_convenience_variable(id, value): def schedule2ss(time: Schedule) -> str:
return f'{time.snap:x}:{time.steps:x}'
def compute_description(time: Optional[Schedule], fallback: str) -> str:
if time is None:
return fallback
evt_type = get_event_type(schedule2split(time))
evt_str = evt_type or fallback
return DESCRIPTION_PATTERN.format(major=time.snap, minor=time.steps,
type=evt_str)
def set_convenience_variable(id: str, value: Any) -> None:
conv_map[id] = value conv_map[id] = value
def set_kernel(value): def set_kernel(value: bool) -> None:
dbg.IS_KERNEL = value dbg.IS_KERNEL = value
def is_kernel(): def is_kernel() -> bool:
return dbg.IS_KERNEL return dbg.IS_KERNEL
def set_exdi(value): def set_exdi(value: bool) -> None:
dbg.IS_EXDI = value dbg.IS_EXDI = value
def is_exdi(): def is_exdi() -> bool:
return dbg.IS_EXDI return dbg.IS_EXDI
def set_remote(value): def set_remote(value: bool) -> None:
dbg.IS_REMOTE = value dbg.IS_REMOTE = value
def is_remote(): def is_remote() -> bool:
return dbg.IS_REMOTE return dbg.IS_REMOTE
def set_trace(value): def set_trace(value: bool) -> None:
dbg.IS_TRACE = value dbg.IS_TRACE = value
def is_trace(): def is_trace() -> bool:
return dbg.IS_TRACE return dbg.IS_TRACE

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "ghidradrgn" name = "ghidradrgn"
version = "11.3" version = "11.4"
authors = [ authors = [
{ name="Ghidra Development Team" }, { name="Ghidra Development Team" },
] ]
@ -17,7 +17,7 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"ghidratrace==11.3", "ghidratrace==11.4",
] ]
[project.urls] [project.urls]

View file

@ -21,11 +21,14 @@ import re
import socket import socket
import sys import sys
import time import time
from typing import Union
import drgn import drgn
import drgn.cli import drgn.cli
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import Client, Address, AddressRange, TraceObject from ghidratrace.client import (
Client, Address, AddressRange, TraceObject, Schedule)
from ghidratrace.display import print_tabular_values, wait
from . import util, arch, methods, hooks from . import util, arch, methods, hooks
@ -62,9 +65,10 @@ SYMBOL_PATTERN = SYMBOLS_PATTERN + SYMBOL_KEY_PATTERN
PROGRAMS = {} PROGRAMS = {}
class ErrorWithCode(Exception): class ErrorWithCode(Exception):
def __init__(self, code): def __init__(self, code: int) -> None:
self.code = code self.code = code
def __str__(self) -> str: def __str__(self) -> str:
@ -73,10 +77,10 @@ class ErrorWithCode(Exception):
class State(object): class State(object):
def __init__(self): def __init__(self) -> None:
self.reset_client() self.reset_client()
def require_client(self): def require_client(self) -> Client:
if self.client is None: if self.client is None:
raise RuntimeError("Not connected") raise RuntimeError("Not connected")
return self.client return self.client
@ -190,7 +194,8 @@ def start_trace(name):
language, compiler = arch.compute_ghidra_lcsp() language, compiler = arch.compute_ghidra_lcsp()
if name is None: if name is None:
name = 'drgn/noname' name = 'drgn/noname'
STATE.trace = STATE.client.create_trace(name, language, compiler) STATE.trace = STATE.client.create_trace(
name, language, compiler, extra=None)
# TODO: Is adding an attribute like this recommended in Python? # TODO: Is adding an attribute like this recommended in Python?
STATE.trace.memory_mapper = arch.compute_memory_mapper(language) STATE.trace.memory_mapper = arch.compute_memory_mapper(language)
STATE.trace.register_mapper = arch.compute_register_mapper(language) STATE.trace.register_mapper = arch.compute_register_mapper(language)
@ -230,7 +235,6 @@ def ghidra_trace_restart(name=None):
start_trace(name) start_trace(name)
def ghidra_trace_create(start_trace=True): def ghidra_trace_create(start_trace=True):
""" """
Create a session. Create a session.
@ -343,7 +347,7 @@ def ghidra_trace_save():
STATE.require_trace().save() STATE.require_trace().save()
def ghidra_trace_new_snap(description=None): def ghidra_trace_new_snap(description=None, time: Union[int, Schedule, None] = None):
""" """
Create a new snapshot Create a new snapshot
@ -351,25 +355,16 @@ def ghidra_trace_new_snap(description=None):
""" """
description = str(description) description = str(description)
if isinstance(time, int):
time = Schedule(time)
STATE.require_tx() STATE.require_tx()
return {'snap': STATE.require_trace().snapshot(description)} return {'snap': STATE.require_trace().snapshot(description, time=time)}
def ghidra_trace_set_snap(snap=None):
"""
Go to a snapshot
Subsequent modifications to machine state will affect the given snapshot.
"""
STATE.require_trace().set_snap(int(snap))
def quantize_pages(start, end): def quantize_pages(start, end):
return (start // PAGE_SIZE * PAGE_SIZE, (end + PAGE_SIZE - 1) // PAGE_SIZE * PAGE_SIZE) return (start // PAGE_SIZE * PAGE_SIZE, (end + PAGE_SIZE - 1) // PAGE_SIZE * PAGE_SIZE)
def put_bytes(start, end, pages, display_result): def put_bytes(start, end, pages, display_result):
trace = STATE.require_trace() trace = STATE.require_trace()
if pages: if pages:
@ -560,7 +555,8 @@ def put_object(lpath, key, value):
lobj.set_value('Kind', str(vkind)) lobj.set_value('Kind', str(vkind))
lobj.set_value('Type', str(vtype)) lobj.set_value('Type', str(vtype))
else: else:
lobj.set_value('_display', '{} [{}:{}]'.format(key, type(value), str(value))) lobj.set_value('_display', '{} [{}:{}]'.format(
key, type(value), str(value)))
lobj.set_value('Value', str(value)) lobj.set_value('Value', str(value))
return return
@ -581,7 +577,8 @@ def put_object(lpath, key, value):
lobj.set_value('Address', addr) lobj.set_value('Address', addr)
return return
if vkind is drgn.TypeKind.TYPEDEF: if vkind is drgn.TypeKind.TYPEDEF:
lobj.set_value('_display', '{} [{}:{}]'.format(key, type(vvalue), str(vvalue))) lobj.set_value('_display', '{} [{}:{}]'.format(
key, type(vvalue), str(vvalue)))
lobj.set_value('Value', str(vvalue)) lobj.set_value('Value', str(vvalue))
return return
if vkind is drgn.TypeKind.UNION or vkind is drgn.TypeKind.STRUCT or vkind is drgn.TypeKind.CLASS: if vkind is drgn.TypeKind.UNION or vkind is drgn.TypeKind.STRUCT or vkind is drgn.TypeKind.CLASS:
@ -589,7 +586,8 @@ def put_object(lpath, key, value):
put_object(lobj.path+'.Members', k, vvalue[k]) put_object(lobj.path+'.Members', k, vvalue[k])
return return
lobj.set_value('_display', '{} [{}:{}]'.format(key, type(vvalue), str(vvalue))) lobj.set_value('_display', '{} [{}:{}]'.format(
key, type(vvalue), str(vvalue)))
lobj.set_value('Value', str(vvalue)) lobj.set_value('Value', str(vvalue))
@ -631,7 +629,6 @@ def ghidra_trace_put_locals():
put_locals() put_locals()
def ghidra_trace_create_obj(path=None): def ghidra_trace_create_obj(path=None):
""" """
Create an object in the Ghidra trace. Create an object in the Ghidra trace.
@ -798,69 +795,14 @@ def ghidra_trace_get_obj(path):
print("{}\t{}".format(object.id, object.path)) print("{}\t{}".format(object.id, object.path))
class TableColumn(object):
def __init__(self, head):
self.head = head
self.contents = [head]
self.is_last = False
def add_data(self, data):
self.contents.append(str(data))
def finish(self):
self.width = max(len(d) for d in self.contents) + 1
def print_cell(self, i):
print(
self.contents[i] if self.is_last else self.contents[i].ljust(self.width), end='')
class Tabular(object):
def __init__(self, heads):
self.columns = [TableColumn(h) for h in heads]
self.columns[-1].is_last = True
self.num_rows = 1
def add_row(self, datas):
for c, d in zip(self.columns, datas):
c.add_data(d)
self.num_rows += 1
def print_table(self):
for c in self.columns:
c.finish()
for rn in range(self.num_rows):
for c in self.columns:
c.print_cell(rn)
print('')
def val_repr(value):
if isinstance(value, TraceObject):
return value.path
elif isinstance(value, Address):
return '{}:{:08x}'.format(value.space, value.offset)
return repr(value)
def print_values(values):
table = Tabular(['Parent', 'Key', 'Span', 'Value', 'Type'])
for v in values:
table.add_row(
[v.parent.path, v.key, v.span, val_repr(v.value), v.schema])
table.print_table()
def ghidra_trace_get_values(pattern): def ghidra_trace_get_values(pattern):
""" """
List all values matching a given path pattern. List all values matching a given path pattern.
""" """
trace = STATE.require_trace() trace = STATE.require_trace()
values = trace.get_values(pattern) values = wait(trace.get_values(pattern))
print_values(values) print_tabular_values(values, print)
def ghidra_trace_get_values_rng(address, length): def ghidra_trace_get_values_rng(address, length):
@ -873,8 +815,8 @@ def ghidra_trace_get_values_rng(address, length):
nproc = util.selected_process() nproc = util.selected_process()
base, addr = trace.memory_mapper.map(nproc, start) base, addr = trace.memory_mapper.map(nproc, start)
# Do not create the space. We're querying. No tx. # Do not create the space. We're querying. No tx.
values = trace.get_values_intersecting(addr.extend(end - start)) values = wait(trace.get_values_intersecting(addr.extend(end - start)))
print_values(values) print_tabular_values(values, print)
def activate(path=None): def activate(path=None):
@ -892,7 +834,8 @@ def activate(path=None):
if frame is None: if frame is None:
path = THREAD_PATTERN.format(procnum=nproc, tnum=nthrd) path = THREAD_PATTERN.format(procnum=nproc, tnum=nthrd)
else: else:
path = FRAME_PATTERN.format(procnum=nproc, tnum=nthrd, level=frame) path = FRAME_PATTERN.format(
procnum=nproc, tnum=nthrd, level=frame)
trace.proxy_object_path(path).activate() trace.proxy_object_path(path).activate()
@ -951,7 +894,6 @@ def ghidra_trace_put_processes():
put_processes() put_processes()
def put_environment(): def put_environment():
nproc = util.selected_process() nproc = util.selected_process()
epath = ENV_PATTERN.format(procnum=nproc) epath = ENV_PATTERN.format(procnum=nproc)
@ -1009,7 +951,6 @@ if hasattr(drgn, 'RelocatableModule'):
STATE.trace.proxy_object_path( STATE.trace.proxy_object_path(
MEMORY_PATTERN.format(procnum=nproc)).retain_values(keys) MEMORY_PATTERN.format(procnum=nproc)).retain_values(keys)
def ghidra_trace_put_regions(): def ghidra_trace_put_regions():
""" """
Read the memory map, if applicable, and write to the trace's Regions Read the memory map, if applicable, and write to the trace's Regions
@ -1020,7 +961,6 @@ if hasattr(drgn, 'RelocatableModule'):
put_regions() put_regions()
# Detect whether this is supported before defining the command # Detect whether this is supported before defining the command
if hasattr(drgn, 'RelocatableModule'): if hasattr(drgn, 'RelocatableModule'):
def put_modules(): def put_modules():
@ -1066,7 +1006,6 @@ if hasattr(drgn, 'RelocatableModule'):
STATE.trace.proxy_object_path(MODULES_PATTERN.format( STATE.trace.proxy_object_path(MODULES_PATTERN.format(
procnum=nproc)).retain_values(mod_keys) procnum=nproc)).retain_values(mod_keys)
def ghidra_trace_put_modules(): def ghidra_trace_put_modules():
""" """
Gather object files, if applicable, and write to the trace's Modules Gather object files, if applicable, and write to the trace's Modules
@ -1090,9 +1029,11 @@ if hasattr(drgn, 'RelocatableModule'):
maddr = hex(m.address_range[0]) maddr = hex(m.address_range[0])
for key in sections.keys(): for key in sections.keys():
value = sections[key] value = sections[key]
spath = SECTION_PATTERN.format(procnum=nproc, modpath=maddr, secname=key) spath = SECTION_PATTERN.format(
procnum=nproc, modpath=maddr, secname=key)
sobj = STATE.trace.create_object(spath) sobj = STATE.trace.create_object(spath)
section_keys.append(SECTION_KEY_PATTERN.format(modpath=maddr, secname=key)) section_keys.append(SECTION_KEY_PATTERN.format(
modpath=maddr, secname=key))
base_base, base_addr = mapper.map(nproc, value) base_base, base_addr = mapper.map(nproc, value)
if base_base != base_addr.space: if base_base != base_addr.space:
STATE.trace.create_overlay_space(base_base, base_addr.space) STATE.trace.create_overlay_space(base_base, base_addr.space)
@ -1104,7 +1045,6 @@ if hasattr(drgn, 'RelocatableModule'):
procnum=nproc, modpath=maddr)).retain_values(section_keys) procnum=nproc, modpath=maddr)).retain_values(section_keys)
def convert_state(t): def convert_state(t):
if t.IsSuspended(): if t.IsSuspended():
return 'SUSPENDED' return 'SUSPENDED'
@ -1131,7 +1071,8 @@ def put_threads(running=False):
short = '{:d} {:x}:{:x}'.format(i, nproc, nthrd) short = '{:d} {:x}:{:x}'.format(i, nproc, nthrd)
tobj.set_value('_short_display', short) tobj.set_value('_short_display', short)
if hasattr(t, 'name'): if hasattr(t, 'name'):
tobj.set_value('_display', '{:x} {:x}:{:x} {}'.format(i, nproc, nthrd, t.name)) tobj.set_value('_display', '{:x} {:x}:{:x} {}'.format(
i, nproc, nthrd, t.name))
tobj.set_value('Name', t.name) tobj.set_value('Name', t.name)
else: else:
tobj.set_value('_display', short) tobj.set_value('_display', short)
@ -1153,7 +1094,6 @@ def ghidra_trace_put_threads():
put_threads() put_threads()
def put_frames(): def put_frames():
nproc = util.selected_process() nproc = util.selected_process()
if nproc < 0: if nproc < 0:
@ -1186,7 +1126,8 @@ def put_frames():
fobj.set_value('PC', offset_inst) fobj.set_value('PC', offset_inst)
fobj.set_value('SP', offset_stack) fobj.set_value('SP', offset_stack)
fobj.set_value('Name', f.name) fobj.set_value('Name', f.name)
fobj.set_value('_display', "#{} {} {}".format(i, hex(offset_inst.offset), f.name)) fobj.set_value('_display', "#{} {} {}".format(
i, hex(offset_inst.offset), f.name))
fobj.insert() fobj.insert()
aobj = STATE.trace.create_object(fpath+".Attributes") aobj = STATE.trace.create_object(fpath+".Attributes")
aobj.insert() aobj.insert()
@ -1220,7 +1161,6 @@ def ghidra_trace_put_frames():
put_frames() put_frames()
def put_symbols(pattern=None): def put_symbols(pattern=None):
nproc = util.selected_process() nproc = util.selected_process()
if nproc is None: if nproc is None:
@ -1264,7 +1204,6 @@ def ghidra_trace_put_symbols():
put_symbols() put_symbols()
def set_display(key, value, obj): def set_display(key, value, obj):
kind = util.get_kind(value) kind = util.get_kind(value)
vstr = util.get_value(value) vstr = util.get_value(value)
@ -1404,4 +1343,3 @@ def get_sp():
frame = stack[util.selected_frame()] frame = stack[util.selected_frame()]
return frame.sp return frame.sp

View file

@ -19,12 +19,14 @@ from io import StringIO
import re import re
import sys import sys
import time import time
from typing import Annotated, Any, Dict, Optional
import drgn import drgn
import drgn.cli import drgn.cli
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import MethodRegistry, ParamDesc, Address, AddressRange from ghidratrace.client import (
MethodRegistry, ParamDesc, Address, AddressRange, TraceObject)
from . import util, commands, hooks from . import util, commands, hooks
@ -185,11 +187,59 @@ def find_module_by_obj(object):
return find_module_by_pattern(MODULE_PATTERN, object, "a Module") return find_module_by_pattern(MODULE_PATTERN, object, "a Module")
shared_globals = dict() shared_globals: Dict[str, Any] = dict()
@REGISTRY.method class Environment(TraceObject):
def execute(cmd: str, to_string: bool=False): pass
class LocalsContainer(TraceObject):
pass
class Memory(TraceObject):
pass
class ModuleContainer(TraceObject):
pass
class Process(TraceObject):
pass
class ProcessContainer(TraceObject):
pass
class Stack(TraceObject):
pass
class RegisterValueContainer(TraceObject):
pass
class StackFrame(TraceObject):
pass
class SymbolContainer(TraceObject):
pass
class Thread(TraceObject):
pass
class ThreadContainer(TraceObject):
pass
@REGISTRY.method()
def execute(cmd: str, to_string: bool = False) -> Optional[str]:
"""Execute a Python3 command or script.""" """Execute a Python3 command or script."""
if to_string: if to_string:
data = StringIO() data = StringIO()
@ -198,31 +248,32 @@ def execute(cmd: str, to_string: bool=False):
return data.getvalue() return data.getvalue()
else: else:
exec(cmd, shared_globals) exec(cmd, shared_globals)
return None
@REGISTRY.method(action='refresh', display='Refresh Processes') @REGISTRY.method(action='refresh', display='Refresh Processes')
def refresh_processes(node: sch.Schema('ProcessContainer')): def refresh_processes(node: ProcessContainer) -> None:
"""Refresh the list of processes.""" """Refresh the list of processes."""
with commands.open_tracked_tx('Refresh Processes'): with commands.open_tracked_tx('Refresh Processes'):
commands.ghidra_trace_put_processes() commands.ghidra_trace_put_processes()
@REGISTRY.method(action='refresh', display='Refresh Environment') @REGISTRY.method(action='refresh', display='Refresh Environment')
def refresh_environment(node: sch.Schema('Environment')): def refresh_environment(node: Environment) -> None:
"""Refresh the environment descriptors (arch, os, endian).""" """Refresh the environment descriptors (arch, os, endian)."""
with commands.open_tracked_tx('Refresh Environment'): with commands.open_tracked_tx('Refresh Environment'):
commands.ghidra_trace_put_environment() commands.ghidra_trace_put_environment()
@REGISTRY.method(action='refresh', display='Refresh Threads') @REGISTRY.method(action='refresh', display='Refresh Threads')
def refresh_threads(node: sch.Schema('ThreadContainer')): def refresh_threads(node: ThreadContainer) -> None:
"""Refresh the list of threads in the process.""" """Refresh the list of threads in the process."""
with commands.open_tracked_tx('Refresh Threads'): with commands.open_tracked_tx('Refresh Threads'):
commands.ghidra_trace_put_threads() commands.ghidra_trace_put_threads()
# @REGISTRY.method(action='refresh', display='Refresh Symbols') # @REGISTRY.method(action='refresh', display='Refresh Symbols')
# def refresh_symbols(node: sch.Schema('SymbolContainer')): # def refresh_symbols(node: SymbolContainer) -> None:
# """Refresh the list of symbols in the process.""" # """Refresh the list of symbols in the process."""
# with commands.open_tracked_tx('Refresh Symbols'): # with commands.open_tracked_tx('Refresh Symbols'):
# commands.ghidra_trace_put_symbols() # commands.ghidra_trace_put_symbols()
@ -230,17 +281,15 @@ def refresh_threads(node: sch.Schema('ThreadContainer')):
@REGISTRY.method(action='show_symbol', display='Retrieve Symbols') @REGISTRY.method(action='show_symbol', display='Retrieve Symbols')
def retrieve_symbols( def retrieve_symbols(
session: sch.Schema('SymbolContainer'), conainer: SymbolContainer,
pattern: ParamDesc(str, display='Pattern')): pattern: Annotated[str, ParamDesc(display='Pattern')]) -> None:
""" """Load the symbol set matching the pattern."""
Load the symbol set matching the pattern.
"""
with commands.open_tracked_tx('Retrieve Symbols'): with commands.open_tracked_tx('Retrieve Symbols'):
commands.put_symbols(pattern) commands.put_symbols(pattern)
@REGISTRY.method(action='refresh', display='Refresh Stack') @REGISTRY.method(action='refresh', display='Refresh Stack')
def refresh_stack(node: sch.Schema('Stack')): def refresh_stack(node: Stack) -> None:
"""Refresh the backtrace for the thread.""" """Refresh the backtrace for the thread."""
tnum = find_thread_by_stack_obj(node) tnum = find_thread_by_stack_obj(node)
with commands.open_tracked_tx('Refresh Stack'): with commands.open_tracked_tx('Refresh Stack'):
@ -248,53 +297,51 @@ def refresh_stack(node: sch.Schema('Stack')):
@REGISTRY.method(action='refresh', display='Refresh Registers') @REGISTRY.method(action='refresh', display='Refresh Registers')
def refresh_registers(node: sch.Schema('RegisterValueContainer')): def refresh_registers(node: RegisterValueContainer) -> None:
"""Refresh the register values for the selected frame""" """Refresh the register values for the selected frame."""
level = find_frame_by_regs_obj(node) level = find_frame_by_regs_obj(node)
with commands.open_tracked_tx('Refresh Registers'): with commands.open_tracked_tx('Refresh Registers'):
commands.ghidra_trace_putreg() commands.ghidra_trace_putreg()
@REGISTRY.method(action='refresh', display='Refresh Locals') @REGISTRY.method(action='refresh', display='Refresh Locals')
def refresh_locals(node: sch.Schema('LocalsContainer')): def refresh_locals(node: LocalsContainer) -> None:
"""Refresh the local values for the selected frame""" """Refresh the local values for the selected frame."""
level = find_frame_by_locals_obj(node) level = find_frame_by_locals_obj(node)
with commands.open_tracked_tx('Refresh Registers'): with commands.open_tracked_tx('Refresh Registers'):
commands.ghidra_trace_put_locals() commands.ghidra_trace_put_locals()
if hasattr(drgn, 'RelocatableModule'): @REGISTRY.method(action='refresh', display='Refresh Memory',
@REGISTRY.method(action='refresh', display='Refresh Memory') condition=hasattr(drgn, 'RelocatableModule'))
def refresh_mappings(node: sch.Schema('Memory')): def refresh_mappings(node: Memory) -> None:
"""Refresh the list of memory regions for the process.""" """Refresh the list of memory regions for the process."""
with commands.open_tracked_tx('Refresh Memory Regions'): with commands.open_tracked_tx('Refresh Memory Regions'):
commands.ghidra_trace_put_regions() commands.ghidra_trace_put_regions()
if hasattr(drgn, 'RelocatableModule'): @REGISTRY.method(action='refresh', display='Refresh Modules',
@REGISTRY.method(action='refresh', display='Refresh Modules') condition=hasattr(drgn, 'RelocatableModule'))
def refresh_modules(node: sch.Schema('ModuleContainer')): def refresh_modules(node: ModuleContainer) -> None:
""" """Refresh the modules list for the process."""
Refresh the modules list for the process.
"""
with commands.open_tracked_tx('Refresh Modules'): with commands.open_tracked_tx('Refresh Modules'):
commands.ghidra_trace_put_modules() commands.ghidra_trace_put_modules()
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_process(process: sch.Schema('Process')): def activate_process(process: Process) -> None:
"""Switch to the process.""" """Switch to the process."""
find_proc_by_obj(process) find_proc_by_obj(process)
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_thread(thread: sch.Schema('Thread')): def activate_thread(thread: Thread) -> None:
"""Switch to the thread.""" """Switch to the thread."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
@REGISTRY.method(action='activate') @REGISTRY.method(action='activate')
def activate_frame(frame: sch.Schema('StackFrame')): def activate_frame(frame: StackFrame) -> None:
"""Select the frame.""" """Select the frame."""
i, f = find_frame_by_obj(frame) i, f = find_frame_by_obj(frame)
util.select_frame(i) util.select_frame(i)
@ -304,12 +351,12 @@ def activate_frame(frame: sch.Schema('StackFrame')):
commands.ghidra_trace_putreg() commands.ghidra_trace_putreg()
@REGISTRY.method @REGISTRY.method()
def read_mem(process: sch.Schema('Process'), range: AddressRange): def read_mem(process: Process, range: AddressRange) -> None:
"""Read memory.""" """Read memory."""
# print("READ_MEM: process={}, range={}".format(process, range)) # print("READ_MEM: process={}, range={}".format(process, range))
nproc = find_proc_by_obj(process) nproc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
nproc, Address(range.space, range.min)) nproc, Address(range.space, range.min))
with commands.open_tracked_tx('Read Memory'): with commands.open_tracked_tx('Read Memory'):
result = commands.put_bytes( result = commands.put_bytes(
@ -320,9 +367,8 @@ def read_mem(process: sch.Schema('Process'), range: AddressRange):
@REGISTRY.method(action='attach', display='Attach by pid') @REGISTRY.method(action='attach', display='Attach by pid')
def attach_pid( def attach_pid(processes: ProcessContainer,
processes: sch.Schema('ProcessContainer'), pid: Annotated[str, ParamDesc(display='PID')]) -> None:
pid: ParamDesc(str, display='PID')):
"""Attach the process to the given target.""" """Attach the process to the given target."""
prog = drgn.Program() prog = drgn.Program()
prog.set_pid(int(pid)) prog.set_pid(int(pid))
@ -341,9 +387,8 @@ def attach_pid(
@REGISTRY.method(action='attach', display='Attach core dump') @REGISTRY.method(action='attach', display='Attach core dump')
def attach_core( def attach_core(processes: ProcessContainer,
processes: sch.Schema('ProcessContainer'), core: Annotated[str, ParamDesc(display='Core dump')]) -> None:
core: ParamDesc(str, display='Core dump')):
"""Attach the process to the given target.""" """Attach the process to the given target."""
prog = drgn.Program() prog = drgn.Program()
prog.set_core_dump(core) prog.set_core_dump(core)
@ -361,7 +406,8 @@ def attach_core(
@REGISTRY.method(action='step_into') @REGISTRY.method(action='step_into')
def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction exactly.""" """Step one instruction exactly."""
find_thread_by_obj(thread) find_thread_by_obj(thread)
time.sleep(1) time.sleep(1)
@ -369,22 +415,20 @@ def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1):
# @REGISTRY.method # @REGISTRY.method
# def kill(process: sch.Schema('Process')): # def kill(process: Process) -> None:
# """Kill execution of the process.""" # """Kill execution of the process."""
# commands.ghidra_trace_kill() # commands.ghidra_trace_kill()
# @REGISTRY.method(action='resume') # @REGISTRY.method(action='resume')
# def go(process: sch.Schema('Process')): # def go(process: Process) -> None:
# """Continue execution of the process.""" # """Continue execution of the process."""
# util.dbg.run_async(lambda: dbg().go()) # util.dbg.run_async(lambda: dbg().go())
# @REGISTRY.method # @REGISTRY.method
# def interrupt(process: sch.Schema('Process')): # def interrupt(process: Process) -> None:
# """Interrupt the execution of the debugged program.""" # """Interrupt the execution of the debugged program."""
# # SetInterrupt is reentrant, so bypass the thread checks # # SetInterrupt is reentrant, so bypass the thread checks
# util.dbg._protected_base._control.SetInterrupt( # util.dbg._protected_base._control.SetInterrupt(
# DbgEng.DEBUG_INTERRUPT_ACTIVE) # DbgEng.DEBUG_INTERRUPT_ACTIVE)

View file

@ -18,5 +18,6 @@ src/main/py/LICENSE||GHIDRA||||END|
src/main/py/MANIFEST.in||GHIDRA||||END| src/main/py/MANIFEST.in||GHIDRA||||END|
src/main/py/README.md||GHIDRA||||END| src/main/py/README.md||GHIDRA||||END|
src/main/py/pyproject.toml||GHIDRA||||END| src/main/py/pyproject.toml||GHIDRA||||END|
src/main/py/src/ghidragdb/py.typed||GHIDRA||||END|
src/main/py/src/ghidragdb/schema.xml||GHIDRA||||END| src/main/py/src/ghidragdb/schema.xml||GHIDRA||||END|
src/main/py/tests/EMPTY||GHIDRA||||END| src/main/py/tests/EMPTY||GHIDRA||||END|

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "ghidragdb" name = "ghidragdb"
version = "11.3" version = "11.4"
authors = [ authors = [
{ name="Ghidra Development Team" }, { name="Ghidra Development Team" },
] ]
@ -17,9 +17,12 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"ghidratrace==11.3", "ghidratrace==11.4",
] ]
[project.urls] [project.urls]
"Homepage" = "https://github.com/NationalSecurityAgency/ghidra" "Homepage" = "https://github.com/NationalSecurityAgency/ghidra"
"Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues" "Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues"
[tool.setuptools.package-data]
ghidragdb = ["py.typed"]

View file

@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from ghidratrace.client import Address, RegVal from ghidratrace.client import Address, RegVal
import gdb import gdb
# NOTE: This map is derived from the ldefs using a script # NOTE: This map is derived from the ldefs using a script
# i386 is hand-patched # i386 is hand-patched
language_map = { language_map: Dict[str, List[str]] = {
'aarch64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon', 'AARCH64:LE:64:v8A'], 'aarch64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon',
'aarch64:ilp32': ['AARCH64:BE:32:ilp32', 'AARCH64:LE:32:ilp32', 'AARCH64:LE:64:AppleSilicon'], 'AARCH64:LE:64:v8A'],
'aarch64:ilp32': ['AARCH64:BE:32:ilp32', 'AARCH64:LE:32:ilp32',
'AARCH64:LE:64:AppleSilicon'],
'arm': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'], 'arm': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'],
'arm_any': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'], 'arm_any': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8',
'ARM:LE:32:v8T'],
'armv2': ['ARM:BE:32:v4', 'ARM:LE:32:v4'], 'armv2': ['ARM:BE:32:v4', 'ARM:LE:32:v4'],
'armv2a': ['ARM:BE:32:v4', 'ARM:LE:32:v4'], 'armv2a': ['ARM:BE:32:v4', 'ARM:LE:32:v4'],
'armv3': ['ARM:BE:32:v4', 'ARM:LE:32:v4'], 'armv3': ['ARM:BE:32:v4', 'ARM:LE:32:v4'],
@ -55,7 +59,8 @@ language_map = {
'i386:x86-64': ['x86:LE:64:default'], 'i386:x86-64': ['x86:LE:64:default'],
'i386:x86-64:intel': ['x86:LE:64:default'], 'i386:x86-64:intel': ['x86:LE:64:default'],
'i8086': ['x86:LE:16:Protected Mode', 'x86:LE:16:Real Mode'], 'i8086': ['x86:LE:16:Protected Mode', 'x86:LE:16:Real Mode'],
'iwmmxt': ['ARM:BE:32:v7', 'ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v7', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'], 'iwmmxt': ['ARM:BE:32:v7', 'ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v7',
'ARM:LE:32:v8', 'ARM:LE:32:v8T'],
'm68hc12': ['HC-12:BE:16:default'], 'm68hc12': ['HC-12:BE:16:default'],
'm68k': ['68000:BE:32:default'], 'm68k': ['68000:BE:32:default'],
'm68k:68020': ['68000:BE:32:MC68020'], 'm68k:68020': ['68000:BE:32:MC68020'],
@ -63,41 +68,51 @@ language_map = {
'm9s12x': ['HCS-12:BE:24:default', 'HCS-12X:BE:24:default'], 'm9s12x': ['HCS-12:BE:24:default', 'HCS-12X:BE:24:default'],
'mips:3000': ['MIPS:BE:32:default', 'MIPS:LE:32:default'], 'mips:3000': ['MIPS:BE:32:default', 'MIPS:LE:32:default'],
'mips:4000': ['MIPS:BE:32:default', 'MIPS:LE:32:default'], 'mips:4000': ['MIPS:BE:32:default', 'MIPS:LE:32:default'],
'mips:5000': ['MIPS:BE:64:64-32addr', 'MIPS:BE:64:default', 'MIPS:LE:64:64-32addr', 'MIPS:LE:64:default'], 'mips:5000': ['MIPS:BE:64:64-32addr', 'MIPS:BE:64:default',
'MIPS:LE:64:64-32addr', 'MIPS:LE:64:default'],
'mips:micromips': ['MIPS:BE:32:micro'], 'mips:micromips': ['MIPS:BE:32:micro'],
'msp:430X': ['TI_MSP430:LE:16:default'], 'msp:430X': ['TI_MSP430:LE:16:default'],
'powerpc:403': ['PowerPC:BE:32:4xx', 'PowerPC:LE:32:4xx'], 'powerpc:403': ['PowerPC:BE:32:4xx', 'PowerPC:LE:32:4xx'],
'powerpc:MPC8XX': ['PowerPC:BE:32:MPC8270', 'PowerPC:BE:32:QUICC', 'PowerPC:LE:32:QUICC'], 'powerpc:MPC8XX': ['PowerPC:BE:32:MPC8270', 'PowerPC:BE:32:QUICC',
'PowerPC:LE:32:QUICC'],
'powerpc:common': ['PowerPC:BE:32:default', 'PowerPC:LE:32:default'], 'powerpc:common': ['PowerPC:BE:32:default', 'PowerPC:LE:32:default'],
'powerpc:common64': ['PowerPC:BE:64:64-32addr', 'PowerPC:BE:64:default', 'PowerPC:LE:64:64-32addr', 'PowerPC:LE:64:default'], 'powerpc:common64': ['PowerPC:BE:64:64-32addr', 'PowerPC:BE:64:default',
'PowerPC:LE:64:64-32addr', 'PowerPC:LE:64:default'],
'powerpc:e500': ['PowerPC:BE:32:e500', 'PowerPC:LE:32:e500'], 'powerpc:e500': ['PowerPC:BE:32:e500', 'PowerPC:LE:32:e500'],
'powerpc:e500mc': ['PowerPC:BE:64:A2ALT', 'PowerPC:LE:64:A2ALT'], 'powerpc:e500mc': ['PowerPC:BE:64:A2ALT', 'PowerPC:LE:64:A2ALT'],
'powerpc:e500mc64': ['PowerPC:BE:64:A2-32addr', 'PowerPC:BE:64:A2ALT-32addr', 'PowerPC:LE:64:A2-32addr', 'PowerPC:LE:64:A2ALT-32addr'], 'powerpc:e500mc64': ['PowerPC:BE:64:A2-32addr',
'riscv:rv32': ['RISCV:LE:32:RV32G', 'RISCV:LE:32:RV32GC', 'RISCV:LE:32:RV32I', 'RISCV:LE:32:RV32IC', 'RISCV:LE:32:RV32IMC', 'RISCV:LE:32:default'], 'PowerPC:BE:64:A2ALT-32addr',
'riscv:rv64': ['RISCV:LE:64:RV64G', 'RISCV:LE:64:RV64GC', 'RISCV:LE:64:RV64I', 'RISCV:LE:64:RV64IC', 'RISCV:LE:64:default'], 'PowerPC:LE:64:A2-32addr',
'PowerPC:LE:64:A2ALT-32addr'],
'riscv:rv32': ['RISCV:LE:32:RV32G', 'RISCV:LE:32:RV32GC',
'RISCV:LE:32:RV32I', 'RISCV:LE:32:RV32IC',
'RISCV:LE:32:RV32IMC', 'RISCV:LE:32:default'],
'riscv:rv64': ['RISCV:LE:64:RV64G', 'RISCV:LE:64:RV64GC',
'RISCV:LE:64:RV64I', 'RISCV:LE:64:RV64IC',
'RISCV:LE:64:default'],
'sh4': ['SuperH4:BE:32:default', 'SuperH4:LE:32:default'], 'sh4': ['SuperH4:BE:32:default', 'SuperH4:LE:32:default'],
'sparc:v9b': ['sparc:BE:32:default', 'sparc:BE:64:default'], 'sparc:v9b': ['sparc:BE:32:default', 'sparc:BE:64:default'],
'xscale': ['ARM:BE:32:v6', 'ARM:LE:32:v6'], 'xscale': ['ARM:BE:32:v6', 'ARM:LE:32:v6'],
'z80': ['z80:LE:16:default', 'z8401x:LE:16:default'] 'z80': ['z80:LE:16:default', 'z8401x:LE:16:default']
} }
data64_compiler_map = { data64_compiler_map: Dict[Optional[str], str] = {
None: 'pointer64', None: 'pointer64',
} }
x86_compiler_map = { x86_compiler_map: Dict[Optional[str], str] = {
'GNU/Linux': 'gcc', 'GNU/Linux': 'gcc',
'Windows': 'windows', 'Windows': 'windows',
# This may seem wrong, but Ghidra cspecs really describe the ABI # This may seem wrong, but Ghidra cspecs really describe the ABI
'Cygwin': 'windows', 'Cygwin': 'windows',
} }
riscv_compiler_map = { riscv_compiler_map: Dict[Optional[str], str] = {
'GNU/Linux': 'gcc', 'GNU/Linux': 'gcc',
'Cygwin': 'gcc', 'Cygwin': 'gcc',
} }
compiler_map = { compiler_map: Dict[str, Dict[Optional[str], str]] = {
'DATA:BE:64:default': data64_compiler_map, 'DATA:BE:64:default': data64_compiler_map,
'DATA:LE:64:default': data64_compiler_map, 'DATA:LE:64:default': data64_compiler_map,
'x86:LE:32:default': x86_compiler_map, 'x86:LE:32:default': x86_compiler_map,
@ -107,14 +122,14 @@ compiler_map = {
} }
def get_arch(): def get_arch() -> str:
return gdb.selected_inferior().architecture().name() return gdb.selected_inferior().architecture().name()
def get_endian(): def get_endian() -> str:
parm = gdb.parameter('endian') parm = gdb.parameter('endian')
if not parm in ['', 'auto', 'default']: if not parm in ['', 'auto', 'default']:
return parm return str(parm)
# Once again, we have to hack using the human-readable 'show' # Once again, we have to hack using the human-readable 'show'
show = gdb.execute('show endian', to_string=True) show = gdb.execute('show endian', to_string=True)
if 'little' in show: if 'little' in show:
@ -124,10 +139,10 @@ def get_endian():
return 'unrecognized' return 'unrecognized'
def get_osabi(): def get_osabi() -> str:
parm = gdb.parameter('osabi') parm = gdb.parameter('osabi')
if not parm in ['', 'auto', 'default']: if not parm in ['', 'auto', 'default']:
return parm return str(parm)
# We have to hack around the fact the GDB won't give us the current OS ABI # We have to hack around the fact the GDB won't give us the current OS ABI
# via the API if it is "auto" or "default". Using "show", we can get it, but # via the API if it is "auto" or "default". Using "show", we can get it, but
# we have to parse output meant for a human. The current value will be on # we have to parse output meant for a human. The current value will be on
@ -138,11 +153,11 @@ def get_osabi():
return line.split('"')[-2] return line.split('"')[-2]
def compute_ghidra_language(): def compute_ghidra_language() -> str:
# First, check if the parameter is set # First, check if the parameter is set
lang = gdb.parameter('ghidra-language') lang = gdb.parameter('ghidra-language')
if not lang in ['', 'auto', 'default']: if not lang in ['', 'auto', 'default']:
return lang return str(lang)
# Get the list of possible languages for the arch. We'll need to sift # Get the list of possible languages for the arch. We'll need to sift
# through them by endian and probably prefer default/simpler variants. The # through them by endian and probably prefer default/simpler variants. The
@ -163,11 +178,11 @@ def compute_ghidra_language():
return 'DATA' + lebe + '64:default' return 'DATA' + lebe + '64:default'
def compute_ghidra_compiler(lang): def compute_ghidra_compiler(lang: str) -> str:
# First, check if the parameter is set # First, check if the parameter is set
comp = gdb.parameter('ghidra-compiler') comp = gdb.parameter('ghidra-compiler')
if not comp in ['', 'auto', 'default']: if not comp in ['', 'auto', 'default']:
return comp return str(comp)
# Check if the selected lang has specific compiler recommendations # Check if the selected lang has specific compiler recommendations
if not lang in compiler_map: if not lang in compiler_map:
@ -185,7 +200,7 @@ def compute_ghidra_compiler(lang):
return 'default' return 'default'
def compute_ghidra_lcsp(): def compute_ghidra_lcsp() -> Tuple[str, str]:
lang = compute_ghidra_language() lang = compute_ghidra_language()
comp = compute_ghidra_compiler(lang) comp = compute_ghidra_compiler(lang)
return lang, comp return lang, comp
@ -193,10 +208,10 @@ def compute_ghidra_lcsp():
class DefaultMemoryMapper(object): class DefaultMemoryMapper(object):
def __init__(self, defaultSpace): def __init__(self, defaultSpace: str) -> None:
self.defaultSpace = defaultSpace self.defaultSpace = defaultSpace
def map(self, inf: gdb.Inferior, offset: int): def map(self, inf: gdb.Inferior, offset: int) -> Tuple[str, Address]:
if inf.num == 1: if inf.num == 1:
space = self.defaultSpace space = self.defaultSpace
else: else:
@ -213,10 +228,10 @@ class DefaultMemoryMapper(object):
DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram') DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram')
memory_mappers = {} memory_mappers: Dict[str, DefaultMemoryMapper] = {}
def compute_memory_mapper(lang): def compute_memory_mapper(lang: str) -> DefaultMemoryMapper:
if not lang in memory_mappers: if not lang in memory_mappers:
return DEFAULT_MEMORY_MAPPER return DEFAULT_MEMORY_MAPPER
return memory_mappers[lang] return memory_mappers[lang]
@ -224,16 +239,16 @@ def compute_memory_mapper(lang):
class DefaultRegisterMapper(object): class DefaultRegisterMapper(object):
def __init__(self, byte_order): def __init__(self, byte_order: str) -> None:
if not byte_order in ['big', 'little']: if not byte_order in ['big', 'little']:
raise ValueError("Invalid byte_order: {}".format(byte_order)) raise ValueError(f"Invalid byte_order: {byte_order}")
self.byte_order = byte_order self.byte_order = byte_order
self.union_winners = {}
def map_name(self, inf, name): def map_name(self, inf: gdb.Inferior, name: str):
return name return name
def convert_value(self, value, type=None): def convert_value(self, value: gdb.Value,
type: Optional[gdb.Type] = None) -> bytes:
if type is None: if type is None:
type = value.dynamic_type.strip_typedefs() type = value.dynamic_type.strip_typedefs()
l = type.sizeof l = type.sizeof
@ -241,39 +256,43 @@ class DefaultRegisterMapper(object):
# NOTE: Might like to pre-lookup 'unsigned char', but it depends on the # NOTE: Might like to pre-lookup 'unsigned char', but it depends on the
# architecture *at the time of lookup*. # architecture *at the time of lookup*.
cv = value.cast(gdb.lookup_type('unsigned char').array(l - 1)) cv = value.cast(gdb.lookup_type('unsigned char').array(l - 1))
rng = range(l) rng: Sequence[int] = range(l)
if self.byte_order == 'little': it = reversed(rng) if self.byte_order == 'little' else rng
rng = reversed(rng) result = bytes(cv[i] for i in it)
return bytes(cv[i] for i in rng) return result
def map_value(self, inf, name, value): def map_value(self, inf: gdb.Inferior, name: str,
value: gdb.Value) -> RegVal:
try: try:
av = self.convert_value(value) av = self.convert_value(value)
except gdb.error as e: except gdb.error as e:
raise gdb.GdbError("Cannot convert {}'s value: '{}', type: '{}'" raise gdb.GdbError(
.format(name, value, value.type)) f"Cannot convert {name}'s value: '{value}', type: '{value.type}'")
return RegVal(self.map_name(inf, name), av) return RegVal(self.map_name(inf, name), av)
def convert_value_back(self, value, size=None): def convert_value_back(self, value: bytes,
size: Optional[int] = None) -> bytes:
if size is not None: if size is not None:
value = value[-size:].rjust(size, b'\0') value = value[-size:].rjust(size, b'\0')
if self.byte_order == 'little': if self.byte_order == 'little':
value = bytes(reversed(value)) value = bytes(reversed(value))
return value return value
def map_name_back(self, inf, name): def map_name_back(self, inf: gdb.Inferior, name: str) -> str:
return name return name
def map_value_back(self, inf, name, value): def map_value_back(self, inf: gdb.Inferior, name: str,
return RegVal(self.map_name_back(inf, name), self.convert_value_back(value)) value: bytes) -> RegVal:
return RegVal(
self.map_name_back(inf, name), self.convert_value_back(value))
class Intel_x86_64_RegisterMapper(DefaultRegisterMapper): class Intel_x86_64_RegisterMapper(DefaultRegisterMapper):
def __init__(self): def __init__(self) -> None:
super().__init__('little') super().__init__('little')
def map_name(self, inf, name): def map_name(self, inf: gdb.Inferior, name: str) -> str:
if name == 'eflags': if name == 'eflags':
return 'rflags' return 'rflags'
if name.startswith('zmm'): if name.startswith('zmm'):
@ -281,13 +300,14 @@ class Intel_x86_64_RegisterMapper(DefaultRegisterMapper):
return 'ymm' + name[3:] return 'ymm' + name[3:]
return super().map_name(inf, name) return super().map_name(inf, name)
def map_value(self, inf, name, value): def map_value(self, inf: gdb.Inferior, name: str,
value: gdb.Value) -> RegVal:
rv = super().map_value(inf, name, value) rv = super().map_value(inf, name, value)
if rv.name.startswith('ymm') and len(rv.value) > 32: if rv.name.startswith('ymm') and len(rv.value) > 32:
return RegVal(rv.name, rv.value[-32:]) return RegVal(rv.name, rv.value[-32:])
return rv return rv
def map_name_back(self, inf, name): def map_name_back(self, inf: gdb.Inferior, name: str) -> str:
if name == 'rflags': if name == 'rflags':
return 'eflags' return 'eflags'
return name return name
@ -296,12 +316,12 @@ class Intel_x86_64_RegisterMapper(DefaultRegisterMapper):
DEFAULT_BE_REGISTER_MAPPER = DefaultRegisterMapper('big') DEFAULT_BE_REGISTER_MAPPER = DefaultRegisterMapper('big')
DEFAULT_LE_REGISTER_MAPPER = DefaultRegisterMapper('little') DEFAULT_LE_REGISTER_MAPPER = DefaultRegisterMapper('little')
register_mappers = { register_mappers: Dict[str, DefaultRegisterMapper] = {
'x86:LE:64:default': Intel_x86_64_RegisterMapper() 'x86:LE:64:default': Intel_x86_64_RegisterMapper()
} }
def compute_register_mapper(lang): def compute_register_mapper(lang: str) -> DefaultRegisterMapper:
if not lang in register_mappers: if not lang in register_mappers:
if ':BE:' in lang: if ':BE:' in lang:
return DEFAULT_BE_REGISTER_MAPPER return DEFAULT_BE_REGISTER_MAPPER

View file

@ -13,68 +13,69 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from dataclasses import dataclass, field
import functools import functools
import time import time
import traceback import traceback
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, cast
import gdb import gdb
from ghidratrace.client import Batch
from . import commands, util from . import commands, util
class GhidraHookPrefix(gdb.Command): class GhidraHookPrefix(gdb.Command):
"""Commands for exporting data to a Ghidra trace""" """Commands for exporting data to a Ghidra trace."""
def __init__(self): def __init__(self) -> None:
super().__init__('hooks-ghidra', gdb.COMMAND_NONE, prefix=True) super().__init__('hooks-ghidra', gdb.COMMAND_NONE, prefix=True)
GhidraHookPrefix() GhidraHookPrefix()
@dataclass(frozen=False)
class HookState(object): class HookState(object):
__slots__ = ('installed', 'batch', 'skip_continue', 'in_break_w_cont') installed = False
batch: Optional[Batch] = None
skip_continue = False
in_break_w_cont = False
def __init__(self): def ensure_batch(self) -> None:
self.installed = False
self.batch = None
self.skip_continue = False
self.in_break_w_cont = False
def ensure_batch(self):
if self.batch is None: if self.batch is None:
self.batch = commands.STATE.client.start_batch() self.batch = commands.STATE.require_client().start_batch()
def end_batch(self): def end_batch(self) -> None:
if self.batch is None: if self.batch is None:
return return
self.batch = None self.batch = None
commands.STATE.client.end_batch() commands.STATE.require_client().end_batch()
def check_skip_continue(self): def check_skip_continue(self) -> bool:
skip = self.skip_continue skip = self.skip_continue
self.skip_continue = False self.skip_continue = False
return skip return skip
@dataclass(frozen=False)
class InferiorState(object): class InferiorState(object):
__slots__ = ('first', 'regions', 'modules', 'threads', 'breaks', 'visited') first = True
def __init__(self):
self.first = True
# For things we can detect changes to between stops # For things we can detect changes to between stops
self.regions = [] regions: List[util.Region] = field(default_factory=list)
self.modules = False modules = False
self.threads = False threads = False
self.breaks = False breaks = False
# For frames and threads that have already been synced since last stop # For frames and threads that have already been synced since last stop
self.visited = set() visited: set[Any] = field(default_factory=set)
def record(self, description=None): def record(self, description: Optional[str] = None) -> None:
first = self.first first = self.first
self.first = False self.first = False
trace = commands.STATE.require_trace()
if description is not None: if description is not None:
commands.STATE.trace.snapshot(description) trace.snapshot(description)
if first: if first:
commands.put_inferiors() commands.put_inferiors()
commands.put_environment() commands.put_environment()
@ -106,7 +107,8 @@ class InferiorState(object):
print(f"Couldn't record page with SP: {e}") print(f"Couldn't record page with SP: {e}")
self.visited.add(hashable_frame) self.visited.add(hashable_frame)
# NB: These commands (put_modules/put_regions) will fail if the process is running # NB: These commands (put_modules/put_regions) will fail if the process is running
regions_changed, regions = util.REGION_INFO_READER.have_changed(self.regions) regions_changed, regions = util.REGION_INFO_READER.have_changed(
self.regions)
if regions_changed: if regions_changed:
self.regions = commands.put_regions(regions) self.regions = commands.put_regions(regions)
if first or self.modules: if first or self.modules:
@ -116,31 +118,29 @@ class InferiorState(object):
commands.put_breakpoints() commands.put_breakpoints()
self.breaks = False self.breaks = False
def record_continued(self): def record_continued(self) -> None:
commands.put_inferiors() commands.put_inferiors()
commands.put_threads() commands.put_threads()
def record_exited(self, exit_code): def record_exited(self, exit_code: int) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
ipath = commands.INFERIOR_PATTERN.format(infnum=inf.num) ipath = commands.INFERIOR_PATTERN.format(infnum=inf.num)
infobj = commands.STATE.trace.proxy_object_path(ipath) infobj = commands.STATE.require_trace().proxy_object_path(ipath)
infobj.set_value('Exit Code', exit_code) infobj.set_value('Exit Code', exit_code)
infobj.set_value('State', 'TERMINATED') infobj.set_value('State', 'TERMINATED')
@dataclass(frozen=False)
class BrkState(object): class BrkState(object):
__slots__ = ('break_loc_counts',) break_loc_counts: Dict[gdb.Breakpoint, int] = field(default_factory=dict)
def __init__(self): def update_brkloc_count(self, b: gdb.Breakpoint, count: int) -> None:
self.break_loc_counts = {}
def update_brkloc_count(self, b, count):
self.break_loc_counts[b] = count self.break_loc_counts[b] = count
def get_brkloc_count(self, b): def get_brkloc_count(self, b: gdb.Breakpoint) -> int:
return self.break_loc_counts.get(b, 0) return self.break_loc_counts.get(b, 0)
def del_brkloc_count(self, b): def del_brkloc_count(self, b: gdb.Breakpoint) -> int:
if b not in self.break_loc_counts: if b not in self.break_loc_counts:
return 0 # TODO: Print a warning? return 0 # TODO: Print a warning?
count = self.break_loc_counts[b] count = self.break_loc_counts[b]
@ -150,40 +150,41 @@ class BrkState(object):
HOOK_STATE = HookState() HOOK_STATE = HookState()
BRK_STATE = BrkState() BRK_STATE = BrkState()
INF_STATES = {} INF_STATES: Dict[int, InferiorState] = {}
def log_errors(func): C = TypeVar('C', bound=Callable)
'''
Wrap a function in a try-except that prints and reraises the
exception. def log_errors(func: C) -> C:
"""Wrap a function in a try-except that prints and reraises the exception.
This is needed because pybag and/or the COM wrappers do not print This is needed because pybag and/or the COM wrappers do not print
exceptions that occur during event callbacks. exceptions that occur during event callbacks.
''' """
@functools.wraps(func) @functools.wraps(func)
def _func(*args, **kwargs): def _func(*args, **kwargs) -> Any:
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except: except:
traceback.print_exc() traceback.print_exc()
raise raise
return _func return cast(C, _func)
@log_errors @log_errors
def on_new_inferior(event): def on_new_inferior(event: gdb.NewInferiorEvent) -> None:
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("New Inferior {}".format(event.inferior.num)): with trace.open_tx(f"New Inferior {event.inferior.num}"):
commands.put_inferiors() # TODO: Could put just the one.... commands.put_inferiors() # TODO: Could put just the one....
def on_inferior_selected(): def on_inferior_selected() -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -191,25 +192,25 @@ def on_inferior_selected():
if trace is None: if trace is None:
return return
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Inferior {} selected".format(inf.num)): with trace.open_tx(f"Inferior {inf.num} selected"):
INF_STATES[inf.num].record() INF_STATES[inf.num].record()
commands.activate() commands.activate()
@log_errors @log_errors
def on_inferior_deleted(event): def on_inferior_deleted(event: gdb.InferiorDeletedEvent) -> None:
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
if event.inferior.num in INF_STATES: if event.inferior.num in INF_STATES:
del INF_STATES[event.inferior.num] del INF_STATES[event.inferior.num]
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Inferior {} deleted".format(event.inferior.num)): with trace.open_tx(f"Inferior {event.inferior.num} deleted"):
commands.put_inferiors() # TODO: Could just delete the one.... commands.put_inferiors() # TODO: Could just delete the one....
@log_errors @log_errors
def on_new_thread(event): def on_new_thread(event: gdb.ThreadEvent) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -217,7 +218,7 @@ def on_new_thread(event):
# TODO: Syscall clone/exit to detect thread destruction? # TODO: Syscall clone/exit to detect thread destruction?
def on_thread_selected(): def on_thread_selected(event: Optional[gdb.ThreadEvent]) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -226,12 +227,12 @@ def on_thread_selected():
return return
t = gdb.selected_thread() t = gdb.selected_thread()
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Thread {}.{} selected".format(inf.num, t.num)): with trace.open_tx(f"Thread {inf.num}.{t.num} selected"):
INF_STATES[inf.num].record() INF_STATES[inf.num].record()
commands.activate() commands.activate()
def on_frame_selected(): def on_frame_selected() -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -243,13 +244,13 @@ def on_frame_selected():
if f is None: if f is None:
return return
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Frame {}.{}.{} selected".format(inf.num, t.num, util.get_level(f))): with trace.open_tx(f"Frame {inf.num}.{t.num}.{util.get_level(f)} selected"):
INF_STATES[inf.num].record() INF_STATES[inf.num].record()
commands.activate() commands.activate()
@log_errors @log_errors
def on_memory_changed(event): def on_memory_changed(event: gdb.MemoryChangedEvent) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -257,13 +258,15 @@ def on_memory_changed(event):
if trace is None: if trace is None:
return return
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Memory *0x{:08x} changed".format(event.address)): address = int(event.address)
commands.put_bytes(event.address, event.address + event.length, length = int(event.length)
with trace.open_tx(f"Memory *0x{address:08x} changed"):
commands.put_bytes(address, address + length,
pages=False, is_mi=False, from_tty=False) pages=False, is_mi=False, from_tty=False)
@log_errors @log_errors
def on_register_changed(event): def on_register_changed(event: gdb.RegisterChangedEvent) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -274,7 +277,7 @@ def on_register_changed(event):
# TODO: How do I get the descriptor from the number? # TODO: How do I get the descriptor from the number?
# For now, just record the lot # For now, just record the lot
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Register {} changed".format(event.regnum)): with trace.open_tx(f"Register {event.regnum} changed"):
commands.putreg(event.frame, util.get_register_descs( commands.putreg(event.frame, util.get_register_descs(
event.frame.architecture())) event.frame.architecture()))
@ -300,8 +303,9 @@ def on_cont(event):
state.record_continued() state.record_continued()
def check_for_continue(event): def check_for_continue(event: Optional[gdb.StopEvent]) -> bool:
if hasattr(event, 'breakpoints'): # Attribute check because of version differences
if isinstance(event, gdb.StopEvent) and hasattr(event, 'breakpoints'):
if HOOK_STATE.in_break_w_cont: if HOOK_STATE.in_break_w_cont:
return True return True
for brk in event.breakpoints: for brk in event.breakpoints:
@ -315,7 +319,7 @@ def check_for_continue(event):
@log_errors @log_errors
def on_stop(event): def on_stop(event: Optional[gdb.StopEvent]) -> None:
if check_for_continue(event): if check_for_continue(event):
HOOK_STATE.skip_continue = True HOOK_STATE.skip_continue = True
return return
@ -336,7 +340,7 @@ def on_stop(event):
@log_errors @log_errors
def on_exited(event): def on_exited(event: gdb.ExitedEvent) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
return return
@ -358,13 +362,13 @@ def on_exited(event):
HOOK_STATE.end_batch() HOOK_STATE.end_batch()
def notify_others_breaks(inf): def notify_others_breaks(inf: gdb.Inferior) -> None:
for num, state in INF_STATES.items(): for num, state in INF_STATES.items():
if num != inf.num: if num != inf.num:
state.breaks = True state.breaks = True
def modules_changed(): def modules_changed() -> None:
# Assumption: affects the current inferior # Assumption: affects the current inferior
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
@ -373,22 +377,22 @@ def modules_changed():
@log_errors @log_errors
def on_clear_objfiles(event): def on_clear_objfiles(event: gdb.ClearObjFilesEvent) -> None:
modules_changed() modules_changed()
@log_errors @log_errors
def on_new_objfile(event): def on_new_objfile(event: gdb.NewObjFileEvent) -> None:
modules_changed() modules_changed()
@log_errors @log_errors
def on_free_objfile(event): def on_free_objfile(event: gdb.FreeObjFileEvent) -> None:
modules_changed() modules_changed()
@log_errors @log_errors
def on_breakpoint_created(b): def on_breakpoint_created(b: gdb.Breakpoint) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
notify_others_breaks(inf) notify_others_breaks(inf)
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
@ -398,7 +402,7 @@ def on_breakpoint_created(b):
return return
ibpath = commands.INF_BREAKS_PATTERN.format(infnum=inf.num) ibpath = commands.INF_BREAKS_PATTERN.format(infnum=inf.num)
HOOK_STATE.ensure_batch() HOOK_STATE.ensure_batch()
with trace.open_tx("Breakpoint {} created".format(b.number)): with trace.open_tx(f"Breakpoint {b.number} created"):
ibobj = trace.create_object(ibpath) ibobj = trace.create_object(ibpath)
# Do not use retain_values or it'll remove other locs # Do not use retain_values or it'll remove other locs
commands.put_single_breakpoint(b, ibobj, inf, []) commands.put_single_breakpoint(b, ibobj, inf, [])
@ -406,7 +410,7 @@ def on_breakpoint_created(b):
@log_errors @log_errors
def on_breakpoint_modified(b): def on_breakpoint_modified(b: gdb.Breakpoint) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
notify_others_breaks(inf) notify_others_breaks(inf)
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
@ -429,7 +433,7 @@ def on_breakpoint_modified(b):
@log_errors @log_errors
def on_breakpoint_deleted(b): def on_breakpoint_deleted(b: gdb.Breakpoint) -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
notify_others_breaks(inf) notify_others_breaks(inf)
if inf.num not in INF_STATES: if inf.num not in INF_STATES:
@ -451,17 +455,28 @@ def on_breakpoint_deleted(b):
@log_errors @log_errors
def on_before_prompt(): def on_before_prompt(n: None) -> object:
HOOK_STATE.end_batch() HOOK_STATE.end_batch()
return None
def cmd_hook(name): @dataclass(frozen=True)
class HookFunc(object):
wrapped: Callable[[], None]
hook: Type[gdb.Command]
unhook: Callable[[], None]
def _cmd_hook(func): def __call__(self) -> None:
self.wrapped()
def cmd_hook(name: str):
def _cmd_hook(func: Callable[[], None]) -> HookFunc:
class _ActiveCommand(gdb.Command): class _ActiveCommand(gdb.Command):
def __init__(self): def __init__(self) -> None:
# It seems we can't hook commands using the Python API.... # It seems we can't hook commands using the Python API....
super().__init__(f"hooks-ghidra def-{name}", gdb.COMMAND_USER) super().__init__(f"hooks-ghidra def-{name}", gdb.COMMAND_USER)
gdb.execute(f""" gdb.execute(f"""
@ -470,50 +485,48 @@ def cmd_hook(name):
end end
""") """)
def invoke(self, argument, from_tty): def invoke(self, argument: str, from_tty: bool) -> None:
self.dont_repeat() self.dont_repeat()
func() func()
def _unhook_command(): def _unhook_command() -> None:
gdb.execute(f""" gdb.execute(f"""
define {name} define {name}
end end
""") """)
func.hook = _ActiveCommand return HookFunc(func, _ActiveCommand, _unhook_command)
func.unhook = _unhook_command
return func
return _cmd_hook return _cmd_hook
@cmd_hook('hookpost-inferior') @cmd_hook('hookpost-inferior')
def hook_inferior(): def hook_inferior() -> None:
on_inferior_selected() on_inferior_selected()
@cmd_hook('hookpost-thread') @cmd_hook('hookpost-thread')
def hook_thread(): def hook_thread() -> None:
on_thread_selected() on_thread_selected(None)
@cmd_hook('hookpost-frame') @cmd_hook('hookpost-frame')
def hook_frame(): def hook_frame() -> None:
on_frame_selected() on_frame_selected()
@cmd_hook('hookpost-up') @cmd_hook('hookpost-up')
def hook_frame_up(): def hook_frame_up() -> None:
on_frame_selected() on_frame_selected()
@cmd_hook('hookpost-down') @cmd_hook('hookpost-down')
def hook_frame_down(): def hook_frame_down() -> None:
on_frame_selected() on_frame_selected()
# TODO: Checks and workarounds for events missing in gdb 9 # TODO: Checks and workarounds for events missing in gdb 9
def install_hooks(): def install_hooks() -> None:
if HOOK_STATE.installed: if HOOK_STATE.installed:
return return
HOOK_STATE.installed = True HOOK_STATE.installed = True
@ -548,7 +561,7 @@ def install_hooks():
gdb.events.before_prompt.connect(on_before_prompt) gdb.events.before_prompt.connect(on_before_prompt)
def remove_hooks(): def remove_hooks() -> None:
if not HOOK_STATE.installed: if not HOOK_STATE.installed:
return return
HOOK_STATE.installed = False HOOK_STATE.installed = False
@ -582,12 +595,12 @@ def remove_hooks():
gdb.events.before_prompt.disconnect(on_before_prompt) gdb.events.before_prompt.disconnect(on_before_prompt)
def enable_current_inferior(): def enable_current_inferior() -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
INF_STATES[inf.num] = InferiorState() INF_STATES[inf.num] = InferiorState()
def disable_current_inferior(): def disable_current_inferior() -> None:
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
if inf.num in INF_STATES: if inf.num in INF_STATES:
# Silently ignore already disabled # Silently ignore already disabled

View file

@ -16,28 +16,30 @@
from concurrent.futures import Future, Executor from concurrent.futures import Future, Executor
from contextlib import contextmanager from contextlib import contextmanager
import re import re
from typing import Annotated, Generator, Optional, Tuple, Union
import gdb import gdb
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import MethodRegistry, ParamDesc, Address, AddressRange from ghidratrace.client import (MethodRegistry, ParamDesc, Address,
AddressRange, Trace, TraceObject)
from . import commands, hooks, util from . import commands, hooks, util
@contextmanager @contextmanager
def no_pagination(): def no_pagination() -> Generator[None, None, None]:
before = gdb.parameter('pagination') before = gdb.parameter('pagination')
util.set_bool_param('pagination', False) util.set_bool_param('pagination', False)
yield yield
util.set_bool_param('pagination', before) util.set_bool_param('pagination', bool(before))
@contextmanager @contextmanager
def no_confirm(): def no_confirm() -> Generator[None, None, None]:
before = gdb.parameter('confirm') before = gdb.parameter('confirm')
util.set_bool_param('confirm', False) util.set_bool_param('confirm', False)
yield yield
util.set_bool_param('confirm', before) util.set_bool_param('confirm', bool(before))
class GdbExecutor(Executor): class GdbExecutor(Executor):
@ -60,27 +62,28 @@ class GdbExecutor(Executor):
REGISTRY = MethodRegistry(GdbExecutor()) REGISTRY = MethodRegistry(GdbExecutor())
def extre(base, ext): def extre(base: re.Pattern, ext: str) -> re.Pattern:
return re.compile(base.pattern + ext) return re.compile(base.pattern + ext)
AVAILABLE_PATTERN = re.compile('Available\[(?P<pid>\\d*)\]') AVAILABLE_PATTERN = re.compile('Available\\[(?P<pid>\\d*)\\]')
BREAKPOINT_PATTERN = re.compile('Breakpoints\[(?P<breaknum>\\d*)\]') BREAKPOINT_PATTERN = re.compile('Breakpoints\\[(?P<breaknum>\\d*)\\]')
BREAK_LOC_PATTERN = extre(BREAKPOINT_PATTERN, '\[(?P<locnum>\\d*)\]') BREAK_LOC_PATTERN = extre(BREAKPOINT_PATTERN, '\\[(?P<locnum>\\d*)\\]')
INFERIOR_PATTERN = re.compile('Inferiors\[(?P<infnum>\\d*)\]') INFERIOR_PATTERN = re.compile('Inferiors\\[(?P<infnum>\\d*)\\]')
INF_BREAKS_PATTERN = extre(INFERIOR_PATTERN, '\.Breakpoints') INF_BREAKS_PATTERN = extre(INFERIOR_PATTERN, '\\.Breakpoints')
ENV_PATTERN = extre(INFERIOR_PATTERN, '\.Environment') ENV_PATTERN = extre(INFERIOR_PATTERN, '\\.Environment')
THREADS_PATTERN = extre(INFERIOR_PATTERN, '\.Threads') THREADS_PATTERN = extre(INFERIOR_PATTERN, '\\.Threads')
THREAD_PATTERN = extre(THREADS_PATTERN, '\[(?P<tnum>\\d*)\]') THREAD_PATTERN = extre(THREADS_PATTERN, '\\[(?P<tnum>\\d*)\\]')
STACK_PATTERN = extre(THREAD_PATTERN, '\.Stack') STACK_PATTERN = extre(THREAD_PATTERN, '\\.Stack')
FRAME_PATTERN = extre(STACK_PATTERN, '\[(?P<level>\\d*)\]') FRAME_PATTERN = extre(STACK_PATTERN, '\\[(?P<level>\\d*)\\]')
REGS_PATTERN = extre(FRAME_PATTERN, '\.Registers') REGS_PATTERN = extre(FRAME_PATTERN, '\\.Registers')
MEMORY_PATTERN = extre(INFERIOR_PATTERN, '\.Memory') MEMORY_PATTERN = extre(INFERIOR_PATTERN, '\\.Memory')
MODULES_PATTERN = extre(INFERIOR_PATTERN, '\.Modules') MODULES_PATTERN = extre(INFERIOR_PATTERN, '\\.Modules')
MODULE_PATTERN = extre(MODULES_PATTERN, '\[(?P<modname>.*)\]') MODULE_PATTERN = extre(MODULES_PATTERN, '\\[(?P<modname>.*)\\]')
def find_availpid_by_pattern(pattern, object, err_msg): def find_availpid_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> int:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -88,18 +91,19 @@ def find_availpid_by_pattern(pattern, object, err_msg):
return pid return pid
def find_availpid_by_obj(object): def find_availpid_by_obj(object: TraceObject) -> int:
return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Available") return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Available")
def find_inf_by_num(infnum): def find_inf_by_num(infnum: int) -> gdb.Inferior:
for inf in gdb.inferiors(): for inf in gdb.inferiors():
if inf.num == infnum: if inf.num == infnum:
return inf return inf
raise KeyError(f"Inferiors[{infnum}] does not exist") raise KeyError(f"Inferiors[{infnum}] does not exist")
def find_inf_by_pattern(object, pattern, err_msg): def find_inf_by_pattern(object: TraceObject, pattern: re.Pattern,
err_msg: str) -> gdb.Inferior:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -107,50 +111,51 @@ def find_inf_by_pattern(object, pattern, err_msg):
return find_inf_by_num(infnum) return find_inf_by_num(infnum)
def find_inf_by_obj(object): def find_inf_by_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, INFERIOR_PATTERN, "an Inferior") return find_inf_by_pattern(object, INFERIOR_PATTERN, "an Inferior")
def find_inf_by_infbreak_obj(object): def find_inf_by_infbreak_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, INF_BREAKS_PATTERN, return find_inf_by_pattern(object, INF_BREAKS_PATTERN,
"a BreakpointLocationContainer") "a BreakpointLocationContainer")
def find_inf_by_env_obj(object): def find_inf_by_env_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, ENV_PATTERN, "an Environment") return find_inf_by_pattern(object, ENV_PATTERN, "an Environment")
def find_inf_by_threads_obj(object): def find_inf_by_threads_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, THREADS_PATTERN, "a ThreadContainer") return find_inf_by_pattern(object, THREADS_PATTERN, "a ThreadContainer")
def find_inf_by_mem_obj(object): def find_inf_by_mem_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, MEMORY_PATTERN, "a Memory") return find_inf_by_pattern(object, MEMORY_PATTERN, "a Memory")
def find_inf_by_modules_obj(object): def find_inf_by_modules_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, MODULES_PATTERN, "a ModuleContainer") return find_inf_by_pattern(object, MODULES_PATTERN, "a ModuleContainer")
def find_inf_by_mod_obj(object): def find_inf_by_mod_obj(object: TraceObject) -> gdb.Inferior:
return find_inf_by_pattern(object, MODULE_PATTERN, "a Module") return find_inf_by_pattern(object, MODULE_PATTERN, "a Module")
def find_module_name_by_mod_obj(object): def find_module_name_by_mod_obj(object: TraceObject) -> str:
mat = MODULE_PATTERN.fullmatch(object.path) mat = MODULE_PATTERN.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not a Module") raise TypeError(f"{object} is not a Module")
return mat['modname'] return mat['modname']
def find_thread_by_num(inf, tnum): def find_thread_by_num(inf: gdb.Inferior, tnum: int) -> gdb.InferiorThread:
for t in inf.threads(): for t in inf.threads():
if t.num == tnum: if t.num == tnum:
return t return t
raise KeyError(f"Inferiors[{inf.num}].Threads[{tnum}] does not exist") raise KeyError(f"Inferiors[{inf.num}].Threads[{tnum}] does not exist")
def find_thread_by_pattern(pattern, object, err_msg): def find_thread_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> gdb.InferiorThread:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -160,15 +165,16 @@ def find_thread_by_pattern(pattern, object, err_msg):
return find_thread_by_num(inf, tnum) return find_thread_by_num(inf, tnum)
def find_thread_by_obj(object): def find_thread_by_obj(object: TraceObject) -> gdb.InferiorThread:
return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread") return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread")
def find_thread_by_stack_obj(object): def find_thread_by_stack_obj(object: TraceObject) -> gdb.InferiorThread:
return find_thread_by_pattern(STACK_PATTERN, object, "a Stack") return find_thread_by_pattern(STACK_PATTERN, object, "a Stack")
def find_frame_by_level(thread, level): def find_frame_by_level(thread: gdb.InferiorThread,
level: int) -> Optional[gdb.Frame]:
# Because threads don't have any attribute to get at frames # Because threads don't have any attribute to get at frames
thread.switch() thread.switch()
f = util.selected_frame() f = util.selected_frame()
@ -192,7 +198,8 @@ def find_frame_by_level(thread, level):
return f return f
def find_frame_by_pattern(pattern, object, err_msg): def find_frame_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> Optional[gdb.Frame]:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -204,17 +211,18 @@ def find_frame_by_pattern(pattern, object, err_msg):
return find_frame_by_level(t, level) return find_frame_by_level(t, level)
def find_frame_by_obj(object): def find_frame_by_obj(object: TraceObject) -> Optional[gdb.Frame]:
return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame") return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame")
def find_frame_by_regs_obj(object): def find_frame_by_regs_obj(object: TraceObject) -> Optional[gdb.Frame]:
return find_frame_by_pattern(REGS_PATTERN, object, return find_frame_by_pattern(REGS_PATTERN, object,
"a RegisterValueContainer") "a RegisterValueContainer")
# Because there's no method to get a register by name.... # Because there's no method to get a register by name....
def find_reg_by_name(f, name): def find_reg_by_name(f: gdb.Frame, name: str) -> Union[gdb.RegisterDescriptor,
util.RegisterDesc]:
for reg in util.get_register_descs(f.architecture()): for reg in util.get_register_descs(f.architecture()):
# TODO: gdb appears to be case sensitive, but until we encounter a # TODO: gdb appears to be case sensitive, but until we encounter a
# situation where case matters, we'll be insensitive # situation where case matters, we'll be insensitive
@ -225,7 +233,7 @@ def find_reg_by_name(f, name):
# Oof. no gdb/Python method to get breakpoint by number # Oof. no gdb/Python method to get breakpoint by number
# I could keep my own cache in a dict, but why? # I could keep my own cache in a dict, but why?
def find_bpt_by_number(breaknum): def find_bpt_by_number(breaknum: int) -> gdb.Breakpoint:
# TODO: If len exceeds some threshold, use binary search? # TODO: If len exceeds some threshold, use binary search?
for b in gdb.breakpoints(): for b in gdb.breakpoints():
if b.number == breaknum: if b.number == breaknum:
@ -233,7 +241,8 @@ def find_bpt_by_number(breaknum):
raise KeyError(f"Breakpoints[{breaknum}] does not exist") raise KeyError(f"Breakpoints[{breaknum}] does not exist")
def find_bpt_by_pattern(pattern, object, err_msg): def find_bpt_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> gdb.Breakpoint:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -241,74 +250,140 @@ def find_bpt_by_pattern(pattern, object, err_msg):
return find_bpt_by_number(breaknum) return find_bpt_by_number(breaknum)
def find_bpt_by_obj(object): def find_bpt_by_obj(object: TraceObject) -> gdb.Breakpoint:
return find_bpt_by_pattern(BREAKPOINT_PATTERN, object, "a BreakpointSpec") return find_bpt_by_pattern(BREAKPOINT_PATTERN, object, "a BreakpointSpec")
def find_bptlocnum_by_pattern(pattern, object, err_msg): def find_bptlocnum_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> Tuple[int, int]:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
breaknum = int(mat['breaknum']) breaknum = int(mat['breaknum'])
locnum = int(mat['locnum']) locnum = int(mat['locnum'])
return breaknum, locnum return breaknum, locnum
def find_bptlocnum_by_obj(object): def find_bptlocnum_by_obj(object: TraceObject) -> Tuple[int, int]:
return find_bptlocnum_by_pattern(BREAK_LOC_PATTERN, object, return find_bptlocnum_by_pattern(BREAK_LOC_PATTERN, object,
"a BreakpointLocation") "a BreakpointLocation")
def find_bpt_loc_by_obj(object): def find_bpt_loc_by_obj(object: TraceObject) -> gdb.BreakpointLocation:
breaknum, locnum = find_bptlocnum_by_obj(object) breaknum, locnum = find_bptlocnum_by_obj(object)
bpt = find_bpt_by_number(breaknum) bpt = find_bpt_by_number(breaknum)
# Requires gdb-13.1 or later # Requires gdb-13.1 or later
return bpt.locations[locnum - 1] # Display is 1-up return bpt.locations[locnum - 1] # Display is 1-up
def switch_inferior(inferior): def switch_inferior(inferior: gdb.Inferior) -> None:
if gdb.selected_inferior().num == inferior.num: if gdb.selected_inferior().num == inferior.num:
return return
gdb.execute(f'inferior {inferior.num}') gdb.execute(f'inferior {inferior.num}')
@REGISTRY.method class Attachable(TraceObject):
def execute(cmd: str, to_string: bool=False): pass
class AvailableContainer(TraceObject):
pass
class BreakpointContainer(TraceObject):
pass
class BreakpointLocation(TraceObject):
pass
class BreakpointLocationContainer(TraceObject):
pass
class BreakpointSpec(TraceObject):
pass
class Environment(TraceObject):
pass
class Inferior(TraceObject):
pass
class InferiorContainer(TraceObject):
pass
class Memory(TraceObject):
pass
class Module(TraceObject):
pass
class ModuleContainer(TraceObject):
pass
class RegisterValueContainer(TraceObject):
pass
class Stack(TraceObject):
pass
class StackFrame(TraceObject):
pass
class Thread(TraceObject):
pass
class ThreadContainer(TraceObject):
pass
@REGISTRY.method()
def execute(cmd: str, to_string: bool = False) -> Optional[str]:
"""Execute a CLI command.""" """Execute a CLI command."""
return gdb.execute(cmd, to_string=to_string) return gdb.execute(cmd, to_string=to_string)
@REGISTRY.method(action='refresh', display='Refresh Available') @REGISTRY.method(action='refresh', display='Refresh Available')
def refresh_available(node: sch.Schema('AvailableContainer')): def refresh_available(node: AvailableContainer) -> None:
"""List processes on gdb's host system.""" """List processes on gdb's host system."""
with commands.open_tracked_tx('Refresh Available'): with commands.open_tracked_tx('Refresh Available'):
gdb.execute('ghidra trace put-available') gdb.execute('ghidra trace put-available')
@REGISTRY.method(action='refresh', display='Refresh Breakpoints') @REGISTRY.method(action='refresh', display='Refresh Breakpoints')
def refresh_breakpoints(node: sch.Schema('BreakpointContainer')): def refresh_breakpoints(node: BreakpointContainer) -> None:
""" """Refresh the list of breakpoints (including locations for the current
Refresh the list of breakpoints (including locations for the current inferior)."""
inferior).
"""
with commands.open_tracked_tx('Refresh Breakpoints'): with commands.open_tracked_tx('Refresh Breakpoints'):
gdb.execute('ghidra trace put-breakpoints') gdb.execute('ghidra trace put-breakpoints')
@REGISTRY.method(action='refresh', display='Refresh Inferiors') @REGISTRY.method(action='refresh', display='Refresh Inferiors')
def refresh_inferiors(node: sch.Schema('InferiorContainer')): def refresh_inferiors(node: InferiorContainer) -> None:
"""Refresh the list of inferiors.""" """Refresh the list of inferiors."""
with commands.open_tracked_tx('Refresh Inferiors'): with commands.open_tracked_tx('Refresh Inferiors'):
gdb.execute('ghidra trace put-inferiors') gdb.execute('ghidra trace put-inferiors')
@REGISTRY.method(action='refresh', display='Refresh Breakpoint Locations') @REGISTRY.method(action='refresh', display='Refresh Breakpoint Locations')
def refresh_inf_breakpoints(node: sch.Schema('BreakpointLocationContainer')): def refresh_inf_breakpoints(node: BreakpointLocationContainer) -> None:
""" """Refresh the breakpoint locations for the inferior.
Refresh the breakpoint locations for the inferior.
In the course of refreshing the locations, the breakpoint list will also be In the course of refreshing the locations, the breakpoint list will
refreshed. also be refreshed.
""" """
switch_inferior(find_inf_by_infbreak_obj(node)) switch_inferior(find_inf_by_infbreak_obj(node))
with commands.open_tracked_tx('Refresh Breakpoint Locations'): with commands.open_tracked_tx('Refresh Breakpoint Locations'):
@ -316,7 +391,7 @@ def refresh_inf_breakpoints(node: sch.Schema('BreakpointLocationContainer')):
@REGISTRY.method(action='refresh', display='Refresh Environment') @REGISTRY.method(action='refresh', display='Refresh Environment')
def refresh_environment(node: sch.Schema('Environment')): def refresh_environment(node: Environment) -> None:
"""Refresh the environment descriptors (arch, os, endian).""" """Refresh the environment descriptors (arch, os, endian)."""
switch_inferior(find_inf_by_env_obj(node)) switch_inferior(find_inf_by_env_obj(node))
with commands.open_tracked_tx('Refresh Environment'): with commands.open_tracked_tx('Refresh Environment'):
@ -324,7 +399,7 @@ def refresh_environment(node: sch.Schema('Environment')):
@REGISTRY.method(action='refresh', display='Refresh Threads') @REGISTRY.method(action='refresh', display='Refresh Threads')
def refresh_threads(node: sch.Schema('ThreadContainer')): def refresh_threads(node: ThreadContainer) -> None:
"""Refresh the list of threads in the inferior.""" """Refresh the list of threads in the inferior."""
switch_inferior(find_inf_by_threads_obj(node)) switch_inferior(find_inf_by_threads_obj(node))
with commands.open_tracked_tx('Refresh Threads'): with commands.open_tracked_tx('Refresh Threads'):
@ -332,7 +407,7 @@ def refresh_threads(node: sch.Schema('ThreadContainer')):
@REGISTRY.method(action='refresh', display='Refresh Stack') @REGISTRY.method(action='refresh', display='Refresh Stack')
def refresh_stack(node: sch.Schema('Stack')): def refresh_stack(node: Stack) -> None:
"""Refresh the backtrace for the thread.""" """Refresh the backtrace for the thread."""
find_thread_by_stack_obj(node).switch() find_thread_by_stack_obj(node).switch()
with commands.open_tracked_tx('Refresh Stack'): with commands.open_tracked_tx('Refresh Stack'):
@ -340,7 +415,7 @@ def refresh_stack(node: sch.Schema('Stack')):
@REGISTRY.method(action='refresh', display='Refresh Registers') @REGISTRY.method(action='refresh', display='Refresh Registers')
def refresh_registers(node: sch.Schema('RegisterValueContainer')): def refresh_registers(node: RegisterValueContainer) -> None:
"""Refresh the register values for the frame.""" """Refresh the register values for the frame."""
f = find_frame_by_regs_obj(node) f = find_frame_by_regs_obj(node)
if f is None: if f is None:
@ -352,7 +427,7 @@ def refresh_registers(node: sch.Schema('RegisterValueContainer')):
@REGISTRY.method(action='refresh', display='Refresh Memory') @REGISTRY.method(action='refresh', display='Refresh Memory')
def refresh_mappings(node: sch.Schema('Memory')): def refresh_mappings(node: Memory) -> None:
"""Refresh the list of memory regions for the inferior.""" """Refresh the list of memory regions for the inferior."""
switch_inferior(find_inf_by_mem_obj(node)) switch_inferior(find_inf_by_mem_obj(node))
with commands.open_tracked_tx('Refresh Memory Regions'): with commands.open_tracked_tx('Refresh Memory Regions'):
@ -360,10 +435,8 @@ def refresh_mappings(node: sch.Schema('Memory')):
@REGISTRY.method(action='refresh', display="Refresh Modules") @REGISTRY.method(action='refresh', display="Refresh Modules")
def refresh_modules(node: sch.Schema('ModuleContainer')): def refresh_modules(node: ModuleContainer) -> None:
""" """Refresh the modules list for the inferior."""
Refresh the modules list for the inferior.
"""
switch_inferior(find_inf_by_modules_obj(node)) switch_inferior(find_inf_by_modules_obj(node))
with commands.open_tracked_tx('Refresh Modules'): with commands.open_tracked_tx('Refresh Modules'):
gdb.execute('ghidra trace put-modules') gdb.execute('ghidra trace put-modules')
@ -371,20 +444,16 @@ def refresh_modules(node: sch.Schema('ModuleContainer')):
# node is Module so this appears in Modules panel # node is Module so this appears in Modules panel
@REGISTRY.method(display='Refresh all Modules and all Sections') @REGISTRY.method(display='Refresh all Modules and all Sections')
def load_all_sections(node: sch.Schema('Module')): def load_all_sections(node: Module) -> None:
""" """Load/refresh all modules and all sections."""
Load/refresh all modules and all sections.
"""
switch_inferior(find_inf_by_mod_obj(node)) switch_inferior(find_inf_by_mod_obj(node))
with commands.open_tracked_tx('Refresh all Modules and all Sections'): with commands.open_tracked_tx('Refresh all Modules and all Sections'):
gdb.execute('ghidra trace put-sections -all-objects') gdb.execute('ghidra trace put-sections -all-objects')
@REGISTRY.method(action='refresh', display="Refresh Module and Sections") @REGISTRY.method(action='refresh', display="Refresh Module and Sections")
def refresh_sections(node: sch.Schema('Module')): def refresh_sections(node: Module) -> None:
""" """Load/refresh the module and its sections."""
Load/refresh the module and its sections.
"""
switch_inferior(find_inf_by_mod_obj(node)) switch_inferior(find_inf_by_mod_obj(node))
with commands.open_tracked_tx('Refresh Module and Sections'): with commands.open_tracked_tx('Refresh Module and Sections'):
modname = find_module_name_by_mod_obj(node) modname = find_module_name_by_mod_obj(node)
@ -392,31 +461,33 @@ def refresh_sections(node: sch.Schema('Module')):
@REGISTRY.method(action='activate', display="Activate Inferior") @REGISTRY.method(action='activate', display="Activate Inferior")
def activate_inferior(inferior: sch.Schema('Inferior')): def activate_inferior(inferior: Inferior) -> None:
"""Switch to the inferior.""" """Switch to the inferior."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
@REGISTRY.method(action='activate', display="Activate Thread") @REGISTRY.method(action='activate', display="Activate Thread")
def activate_thread(thread: sch.Schema('Thread')): def activate_thread(thread: Thread) -> None:
"""Switch to the thread.""" """Switch to the thread."""
find_thread_by_obj(thread).switch() find_thread_by_obj(thread).switch()
@REGISTRY.method(action='activate', display="Activate Frame") @REGISTRY.method(action='activate', display="Activate Frame")
def activate_frame(frame: sch.Schema('StackFrame')): def activate_frame(frame: StackFrame) -> None:
"""Select the frame.""" """Select the frame."""
find_frame_by_obj(frame).select() f = find_frame_by_obj(frame)
if not f is None:
f.select()
@REGISTRY.method(display='Add Inferior') @REGISTRY.method(display='Add Inferior')
def add_inferior(container: sch.Schema('InferiorContainer')): def add_inferior(container: InferiorContainer) -> None:
"""Add a new inferior.""" """Add a new inferior."""
gdb.execute('add-inferior') gdb.execute('add-inferior')
@REGISTRY.method(action='delete', display="Delete Inferior") @REGISTRY.method(action='delete', display="Delete Inferior")
def delete_inferior(inferior: sch.Schema('Inferior')): def delete_inferior(inferior: Inferior) -> None:
"""Remove the inferior.""" """Remove the inferior."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
gdb.execute(f'remove-inferior {inf.num}') gdb.execute(f'remove-inferior {inf.num}')
@ -424,14 +495,14 @@ def delete_inferior(inferior: sch.Schema('Inferior')):
# TODO: Separate method for each of core, exec, remote, etc...? # TODO: Separate method for each of core, exec, remote, etc...?
@REGISTRY.method(display='Connect Target') @REGISTRY.method(display='Connect Target')
def connect(inferior: sch.Schema('Inferior'), spec: str): def connect(inferior: Inferior, spec: str) -> None:
"""Connect to a target machine or process.""" """Connect to a target machine or process."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute(f'target {spec}') gdb.execute(f'target {spec}')
@REGISTRY.method(action='attach', display='Attach') @REGISTRY.method(action='attach', display='Attach')
def attach_obj(target: sch.Schema('Attachable')): def attach_obj(target: Attachable) -> None:
"""Attach the inferior to the given target.""" """Attach the inferior to the given target."""
# switch_inferior(find_inf_by_obj(inferior)) # switch_inferior(find_inf_by_obj(inferior))
pid = find_availpid_by_obj(target) pid = find_availpid_by_obj(target)
@ -439,25 +510,24 @@ def attach_obj(target: sch.Schema('Attachable')):
@REGISTRY.method(action='attach', display='Attach by PID') @REGISTRY.method(action='attach', display='Attach by PID')
def attach_pid(inferior: sch.Schema('Inferior'), pid: int): def attach_pid(inferior: Inferior, pid: int) -> None:
"""Attach the inferior to the given target.""" """Attach the inferior to the given target."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute(f'attach {pid}') gdb.execute(f'attach {pid}')
@REGISTRY.method(display='Detach') @REGISTRY.method(display='Detach')
def detach(inferior: sch.Schema('Inferior')): def detach(inferior: Inferior) -> None:
"""Detach the inferior's target.""" """Detach the inferior's target."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute('detach') gdb.execute('detach')
@REGISTRY.method(action='launch', display='Launch at main') @REGISTRY.method(action='launch', display='Launch at main')
def launch_main(inferior: sch.Schema('Inferior'), def launch_main(inferior: Inferior,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '') -> None:
""" """Start a native process with the given command line, stopping at 'main'
Start a native process with the given command line, stopping at 'main'
(start). (start).
If 'main' is not defined in the file, this behaves like 'run'. If 'main' is not defined in the file, this behaves like 'run'.
@ -472,13 +542,11 @@ def launch_main(inferior: sch.Schema('Inferior'),
@REGISTRY.method(action='launch', display='Launch at Loader', @REGISTRY.method(action='launch', display='Launch at Loader',
condition=util.GDB_VERSION.major >= 9) condition=util.GDB_VERSION.major >= 9)
def launch_loader(inferior: sch.Schema('Inferior'), def launch_loader(inferior: Inferior,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '') -> None:
""" """Start a native process with the given command line, stopping at first
Start a native process with the given command line, stopping at first instruction (starti)."""
instruction (starti).
"""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute(f''' gdb.execute(f'''
file {file} file {file}
@ -488,14 +556,13 @@ def launch_loader(inferior: sch.Schema('Inferior'),
@REGISTRY.method(action='launch', display='Launch and Run') @REGISTRY.method(action='launch', display='Launch and Run')
def launch_run(inferior: sch.Schema('Inferior'), def launch_run(inferior: Inferior,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '') -> None:
""" """Run a native process with the given command line (run).
Run a native process with the given command line (run).
The process will not stop until it hits one of your breakpoints, or it is The process will not stop until it hits one of your breakpoints, or
signaled. it is signaled.
""" """
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute(f''' gdb.execute(f'''
@ -505,23 +572,24 @@ def launch_run(inferior: sch.Schema('Inferior'),
''') ''')
@REGISTRY.method @REGISTRY.method()
def kill(inferior: sch.Schema('Inferior')): def kill(inferior: Inferior) -> None:
"""Kill execution of the inferior.""" """Kill execution of the inferior."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
with no_confirm(): with no_confirm():
gdb.execute('kill') gdb.execute('kill')
@REGISTRY.method @REGISTRY.method()
def resume(inferior: sch.Schema('Inferior')): def resume(inferior: Inferior) -> None:
"""Continue execution of the inferior.""" """Continue execution of the inferior."""
switch_inferior(find_inf_by_obj(inferior)) switch_inferior(find_inf_by_obj(inferior))
gdb.execute('continue') gdb.execute('continue')
@REGISTRY.method(action='step_ext', icon='icon.debugger.resume.back', condition=util.IS_TRACE) @REGISTRY.method(action='step_ext', icon='icon.debugger.resume.back',
def resume_back(thread: sch.Schema('Inferior')): condition=util.IS_TRACE)
def resume_back(inferior: Inferior) -> None:
"""Continue execution of the inferior backwards.""" """Continue execution of the inferior backwards."""
gdb.execute('reverse-continue') gdb.execute('reverse-continue')
@ -529,44 +597,46 @@ def resume_back(thread: sch.Schema('Inferior')):
# Technically, inferior is not required, but it hints that the affected object # Technically, inferior is not required, but it hints that the affected object
# is the current inferior. This in turn queues the UI to enable or disable the # is the current inferior. This in turn queues the UI to enable or disable the
# button appropriately # button appropriately
@REGISTRY.method @REGISTRY.method()
def interrupt(inferior: sch.Schema('Inferior')): def interrupt(inferior: Inferior) -> None:
"""Interrupt the execution of the debugged program.""" """Interrupt the execution of the debugged program."""
gdb.execute('interrupt') gdb.execute('interrupt')
@REGISTRY.method @REGISTRY.method()
def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction exactly (stepi).""" """Step one instruction exactly (stepi)."""
find_thread_by_obj(thread).switch() find_thread_by_obj(thread).switch()
gdb.execute('stepi') gdb.execute('stepi')
@REGISTRY.method @REGISTRY.method()
def step_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_over(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction, but proceed through subroutine calls (nexti).""" """Step one instruction, but proceed through subroutine calls (nexti)."""
find_thread_by_obj(thread).switch() find_thread_by_obj(thread).switch()
gdb.execute('nexti') gdb.execute('nexti')
@REGISTRY.method @REGISTRY.method()
def step_out(thread: sch.Schema('Thread')): def step_out(thread: Thread) -> None:
"""Execute until the current stack frame returns (finish).""" """Execute until the current stack frame returns (finish)."""
find_thread_by_obj(thread).switch() find_thread_by_obj(thread).switch()
gdb.execute('finish') gdb.execute('finish')
@REGISTRY.method(action='step_ext', display='Advance') @REGISTRY.method(action='step_ext', display='Advance')
def step_advance(thread: sch.Schema('Thread'), address: Address): def step_advance(thread: Thread, address: Address) -> None:
"""Continue execution up to the given address (advance).""" """Continue execution up to the given address (advance)."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.switch() t.switch()
offset = thread.trace.memory_mapper.map_back(t.inferior, address) offset = thread.trace.extra.require_mm().map_back(t.inferior, address)
gdb.execute(f'advance *0x{offset:x}') gdb.execute(f'advance *0x{offset:x}')
@REGISTRY.method(action='step_ext', display='Return') @REGISTRY.method(action='step_ext', display='Return')
def step_return(thread: sch.Schema('Thread'), value: int=None): def step_return(thread: Thread, value: Optional[int] = None) -> None:
"""Skip the remainder of the current function (return).""" """Skip the remainder of the current function (return)."""
find_thread_by_obj(thread).switch() find_thread_by_obj(thread).switch()
if value is None: if value is None:
@ -575,104 +645,109 @@ def step_return(thread: sch.Schema('Thread'), value: int=None):
gdb.execute(f'return {value}') gdb.execute(f'return {value}')
@REGISTRY.method(action='step_ext', icon='icon.debugger.step.back.into', condition=util.IS_TRACE) @REGISTRY.method(action='step_ext', icon='icon.debugger.step.back.into',
def step_back_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): condition=util.IS_TRACE)
def step_back_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step backwards one instruction exactly (reverse-stepi).""" """Step backwards one instruction exactly (reverse-stepi)."""
gdb.execute('reverse-stepi') gdb.execute('reverse-stepi')
@REGISTRY.method(action='step_ext', icon='icon.debugger.step.back.over', condition=util.IS_TRACE) @REGISTRY.method(action='step_ext', icon='icon.debugger.step.back.over',
def step_back_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): condition=util.IS_TRACE)
"""Step one instruction backwards, but proceed through subroutine calls (reverse-nexti).""" def step_back_over(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction backwards, but proceed through subroutine calls
(reverse-nexti)."""
gdb.execute('reverse-nexti') gdb.execute('reverse-nexti')
@REGISTRY.method(action='break_sw_execute') @REGISTRY.method(action='break_sw_execute')
def break_sw_execute_address(inferior: sch.Schema('Inferior'), address: Address): def break_sw_execute_address(inferior: Inferior, address: Address) -> None:
"""Set a breakpoint (break).""" """Set a breakpoint (break)."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset = inferior.trace.memory_mapper.map_back(inf, address) offset = inferior.trace.extra.require_mm().map_back(inf, address)
gdb.execute(f'break *0x{offset:x}') gdb.execute(f'break *0x{offset:x}')
@REGISTRY.method(action='break_ext', display="Set Breakpoint") @REGISTRY.method(action='break_ext', display="Set Breakpoint")
def break_sw_execute_expression(expression: str): def break_sw_execute_expression(expression: str) -> None:
"""Set a breakpoint (break).""" """Set a breakpoint (break)."""
# TODO: Escape? # TODO: Escape?
gdb.execute(f'break {expression}') gdb.execute(f'break {expression}')
@REGISTRY.method(action='break_hw_execute') @REGISTRY.method(action='break_hw_execute')
def break_hw_execute_address(inferior: sch.Schema('Inferior'), address: Address): def break_hw_execute_address(inferior: Inferior, address: Address) -> None:
"""Set a hardware-assisted breakpoint (hbreak).""" """Set a hardware-assisted breakpoint (hbreak)."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset = inferior.trace.memory_mapper.map_back(inf, address) offset = inferior.trace.extra.require_mm().map_back(inf, address)
gdb.execute(f'hbreak *0x{offset:x}') gdb.execute(f'hbreak *0x{offset:x}')
@REGISTRY.method(action='break_ext', display="Set Hardware Breakpoint") @REGISTRY.method(action='break_ext', display="Set Hardware Breakpoint")
def break_hw_execute_expression(expression: str): def break_hw_execute_expression(expression: str) -> None:
"""Set a hardware-assisted breakpoint (hbreak).""" """Set a hardware-assisted breakpoint (hbreak)."""
# TODO: Escape? # TODO: Escape?
gdb.execute(f'hbreak {expression}') gdb.execute(f'hbreak {expression}')
@REGISTRY.method(action='break_read') @REGISTRY.method(action='break_read')
def break_read_range(inferior: sch.Schema('Inferior'), range: AddressRange): def break_read_range(inferior: Inferior, range: AddressRange) -> None:
"""Set a read watchpoint (rwatch).""" """Set a read watchpoint (rwatch)."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset_start = inferior.trace.memory_mapper.map_back( offset_start = inferior.trace.extra.require_mm().map_back(
inf, Address(range.space, range.min)) inf, Address(range.space, range.min))
gdb.execute( gdb.execute(
f'rwatch -location *((char(*)[{range.length()}]) 0x{offset_start:x})') f'rwatch -location *((char(*)[{range.length()}]) 0x{offset_start:x})')
@REGISTRY.method(action='break_ext', display="Set Read Watchpoint") @REGISTRY.method(action='break_ext', display="Set Read Watchpoint")
def break_read_expression(expression: str): def break_read_expression(expression: str) -> None:
"""Set a read watchpoint (rwatch).""" """Set a read watchpoint (rwatch)."""
gdb.execute(f'rwatch {expression}') gdb.execute(f'rwatch {expression}')
@REGISTRY.method(action='break_write') @REGISTRY.method(action='break_write')
def break_write_range(inferior: sch.Schema('Inferior'), range: AddressRange): def break_write_range(inferior: Inferior, range: AddressRange) -> None:
"""Set a watchpoint (watch).""" """Set a watchpoint (watch)."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset_start = inferior.trace.memory_mapper.map_back( offset_start = inferior.trace.extra.require_mm().map_back(
inf, Address(range.space, range.min)) inf, Address(range.space, range.min))
gdb.execute( gdb.execute(
f'watch -location *((char(*)[{range.length()}]) 0x{offset_start:x})') f'watch -location *((char(*)[{range.length()}]) 0x{offset_start:x})')
@REGISTRY.method(action='break_ext', display="Set Watchpoint") @REGISTRY.method(action='break_ext', display="Set Watchpoint")
def break_write_expression(expression: str): def break_write_expression(expression: str) -> None:
"""Set a watchpoint (watch).""" """Set a watchpoint (watch)."""
gdb.execute(f'watch {expression}') gdb.execute(f'watch {expression}')
@REGISTRY.method(action='break_access') @REGISTRY.method(action='break_access')
def break_access_range(inferior: sch.Schema('Inferior'), range: AddressRange): def break_access_range(inferior: Inferior, range: AddressRange) -> None:
"""Set an access watchpoint (awatch).""" """Set an access watchpoint (awatch)."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset_start = inferior.trace.memory_mapper.map_back( offset_start = inferior.trace.extra.require_mm().map_back(
inf, Address(range.space, range.min)) inf, Address(range.space, range.min))
gdb.execute( gdb.execute(
f'awatch -location *((char(*)[{range.length()}]) 0x{offset_start:x})') f'awatch -location *((char(*)[{range.length()}]) 0x{offset_start:x})')
@REGISTRY.method(action='break_ext', display="Set Access Watchpoint") @REGISTRY.method(action='break_ext', display="Set Access Watchpoint")
def break_access_expression(expression: str): def break_access_expression(expression: str) -> None:
"""Set an access watchpoint (awatch).""" """Set an access watchpoint (awatch)."""
gdb.execute(f'awatch {expression}') gdb.execute(f'awatch {expression}')
@REGISTRY.method(action='break_ext', display='Catch Event') @REGISTRY.method(action='break_ext', display='Catch Event')
def break_event(inferior: sch.Schema('Inferior'), spec: str): def break_event(inferior: Inferior, spec: str) -> None:
"""Set a catchpoint (catch).""" """Set a catchpoint (catch)."""
gdb.execute(f'catch {spec}') gdb.execute(f'catch {spec}')
@REGISTRY.method(action='toggle', display="Toggle Breakpoint") @REGISTRY.method(action='toggle', display="Toggle Breakpoint")
def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool): def toggle_breakpoint(breakpoint: BreakpointSpec, enabled: bool) -> None:
"""Toggle a breakpoint.""" """Toggle a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
bpt.enabled = enabled bpt.enabled = enabled
@ -680,7 +755,8 @@ def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool):
@REGISTRY.method(action='toggle', display="Toggle Breakpoint Location", @REGISTRY.method(action='toggle', display="Toggle Breakpoint Location",
condition=util.GDB_VERSION.major >= 13) condition=util.GDB_VERSION.major >= 13)
def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabled: bool): def toggle_breakpoint_location(location: BreakpointLocation,
enabled: bool) -> None:
"""Toggle a breakpoint location.""" """Toggle a breakpoint location."""
loc = find_bpt_loc_by_obj(location) loc = find_bpt_loc_by_obj(location)
loc.enabled = enabled loc.enabled = enabled
@ -688,7 +764,8 @@ def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabl
@REGISTRY.method(action='toggle', display="Toggle Breakpoint Location", @REGISTRY.method(action='toggle', display="Toggle Breakpoint Location",
condition=util.GDB_VERSION.major < 13) condition=util.GDB_VERSION.major < 13)
def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabled: bool): def toggle_breakpoint_location_pre13(location: BreakpointLocation,
enabled: bool) -> None:
"""Toggle a breakpoint location.""" """Toggle a breakpoint location."""
bptnum, locnum = find_bptlocnum_by_obj(location) bptnum, locnum = find_bptlocnum_by_obj(location)
cmd = 'enable' if enabled else 'disable' cmd = 'enable' if enabled else 'disable'
@ -696,17 +773,17 @@ def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabl
@REGISTRY.method(action='delete', display="Delete Breakpoint") @REGISTRY.method(action='delete', display="Delete Breakpoint")
def delete_breakpoint(breakpoint: sch.Schema('BreakpointSpec')): def delete_breakpoint(breakpoint: BreakpointSpec) -> None:
"""Delete a breakpoint.""" """Delete a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
bpt.delete() bpt.delete()
@REGISTRY.method @REGISTRY.method()
def read_mem(inferior: sch.Schema('Inferior'), range: AddressRange): def read_mem(inferior: Inferior, range: AddressRange) -> None:
"""Read memory.""" """Read memory."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset_start = inferior.trace.memory_mapper.map_back( offset_start = inferior.trace.extra.require_mm().map_back(
inf, Address(range.space, range.min)) inf, Address(range.space, range.min))
with commands.open_tracked_tx('Read Memory'): with commands.open_tracked_tx('Read Memory'):
try: try:
@ -717,22 +794,25 @@ def read_mem(inferior: sch.Schema('Inferior'), range: AddressRange):
f'ghidra trace putmem-state 0x{offset_start:x} {range.length()} error') f'ghidra trace putmem-state 0x{offset_start:x} {range.length()} error')
@REGISTRY.method @REGISTRY.method()
def write_mem(inferior: sch.Schema('Inferior'), address: Address, data: bytes): def write_mem(inferior: Inferior, address: Address, data: bytes) -> None:
"""Write memory.""" """Write memory."""
inf = find_inf_by_obj(inferior) inf = find_inf_by_obj(inferior)
offset = inferior.trace.memory_mapper.map_back(inf, address) offset = inferior.trace.extra.require_mm().map_back(inf, address)
inf.write_memory(offset, data) inf.write_memory(offset, data)
@REGISTRY.method @REGISTRY.method()
def write_reg(frame: sch.Schema('StackFrame'), name: str, value: bytes): def write_reg(frame: StackFrame, name: str, value: bytes) -> None:
"""Write a register.""" """Write a register."""
f = find_frame_by_obj(frame) f = find_frame_by_obj(frame)
if f is None:
raise gdb.GdbError(f"Frame {frame.path} no longer exists")
f.select() f.select()
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
mname, mval = frame.trace.register_mapper.map_value_back(inf, name, value) trace: Trace[commands.Extra] = frame.trace
reg = find_reg_by_name(f, mname) rv = trace.extra.require_rm().map_value_back(inf, name, value)
reg = find_reg_by_name(f, rv.name)
size = int(gdb.parse_and_eval(f'sizeof(${reg.name})')) size = int(gdb.parse_and_eval(f'sizeof(${reg.name})'))
arr = '{' + ','.join(str(b) for b in mval) + '}' arr = '{' + ','.join(str(b) for b in rv.value) + '}'
gdb.execute(f'set ((unsigned char[{size}])${reg.name}) = {arr}') gdb.execute(f'set ((unsigned char[{size}])${reg.name}) = {arr}')

View file

@ -26,9 +26,11 @@ class GhidraLanguageParameter(gdb.Parameter):
LanguageID. LanguageID.
""" """
def __init__(self): def __init__(self) -> None:
super().__init__('ghidra-language', gdb.COMMAND_DATA, gdb.PARAM_STRING) super().__init__('ghidra-language', gdb.COMMAND_DATA, gdb.PARAM_STRING)
self.value = 'auto' self.value = 'auto'
GhidraLanguageParameter() GhidraLanguageParameter()
@ -39,8 +41,9 @@ class GhidraCompilerParameter(gdb.Parameter):
that valid compiler spec ids depend on the language id. that valid compiler spec ids depend on the language id.
""" """
def __init__(self): def __init__(self) -> None:
super().__init__('ghidra-compiler', gdb.COMMAND_DATA, gdb.PARAM_STRING) super().__init__('ghidra-compiler', gdb.COMMAND_DATA, gdb.PARAM_STRING)
self.value = 'auto' self.value = 'auto'
GhidraCompilerParameter()
GhidraCompilerParameter()

View file

@ -13,17 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from abc import abstractmethod
from collections import namedtuple from collections import namedtuple
import bisect import bisect
from dataclasses import dataclass
import re import re
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import gdb import gdb
GdbVersion = namedtuple('GdbVersion', ['full', 'major', 'minor']) @dataclass(frozen=True)
class GdbVersion:
full: str
major: int
minor: int
def _compute_gdb_ver(): def _compute_gdb_ver() -> GdbVersion:
blurb = gdb.execute('show version', to_string=True) blurb = gdb.execute('show version', to_string=True)
top = blurb.split('\n')[0] top = blurb.split('\n')[0]
full = top.split(' ')[-1] full = top.split(' ')[-1]
@ -57,19 +64,49 @@ OBJFILE_SECTION_PATTERN_V9 = re.compile("\\s*" +
GNU_DEBUGDATA_PREFIX = ".gnu_debugdata for " GNU_DEBUGDATA_PREFIX = ".gnu_debugdata for "
class Module(namedtuple('BaseModule', ['name', 'base', 'max', 'sections'])): @dataclass(frozen=True)
pass class Region:
start: int
end: int
offset: int
perms: Optional[str]
objfile: str
@dataclass(frozen=True)
class Section:
name: str
start: int
end: int
offset: int
attrs: List[str]
def better(self, other: 'Section') -> 'Section':
start = self.start if self.start != 0 else other.start
end = self.end if self.end != 0 else other.end
offset = self.offset if self.offset != 0 else other.offset
attrs = dict.fromkeys(self.attrs)
attrs.update(dict.fromkeys(other.attrs))
return Section(self.name, start, end, offset, list(attrs))
@dataclass(frozen=True)
class Module:
name: str
base: int
max: int
sections: Dict[str, Section]
class Index: class Index:
def __init__(self, regions): def __init__(self, regions: List[Region]) -> None:
self.regions = {} self.regions: Dict[int, Region] = {}
self.bases = [] self.bases: List[int] = []
for r in regions: for r in regions:
self.regions[r.start] = r self.regions[r.start] = r
self.bases.append(r.start) self.bases.append(r.start)
def compute_base(self, address): def compute_base(self, address: int) -> int:
index = bisect.bisect_right(self.bases, address) - 1 index = bisect.bisect_right(self.bases, address) - 1
if index == -1: if index == -1:
return address return address
@ -84,34 +121,28 @@ class Index:
return region.start return region.start
class Section(namedtuple('BaseSection', ['name', 'start', 'end', 'offset', 'attrs'])): def try_hexint(val: str, name: str) -> int:
def better(self, other):
start = self.start if self.start != 0 else other.start
end = self.end if self.end != 0 else other.end
offset = self.offset if self.offset != 0 else other.offset
attrs = dict.fromkeys(self.attrs)
attrs.update(dict.fromkeys(other.attrs))
return Section(self.name, start, end, offset, list(attrs))
def try_hexint(val, name):
try: try:
return int(val, 16) return int(val, 16)
except ValueError: except ValueError:
gdb.write("Invalid {}: {}".format(name, val), stream=gdb.STDERR) gdb.write(f"Invalid {name}: {val}\n", stream=gdb.STDERR)
return 0 return 0
# AFAICT, Objfile does not give info about load addresses :( # AFAICT, Objfile does not give info about load addresses :(
class ModuleInfoReader(object): class ModuleInfoReader(object):
def name_from_line(self, line): cmd: str
objfile_pattern: re.Pattern
section_pattern: re.Pattern
def name_from_line(self, line: str) -> Optional[str]:
mat = self.objfile_pattern.fullmatch(line) mat = self.objfile_pattern.fullmatch(line)
if mat is None: if mat is None:
return None return None
n = mat['name'] n = mat['name']
return None if mat is None else mat['name'] return None if mat is None else mat['name']
def section_from_line(self, line, max_addr): def section_from_line(self, line: str, max_addr: int) -> Optional[Section]:
mat = self.section_pattern.fullmatch(line) mat = self.section_pattern.fullmatch(line)
if mat is None: if mat is None:
return None return None
@ -122,7 +153,8 @@ class ModuleInfoReader(object):
attrs = [a for a in mat['attrs'].split(' ') if a != ''] attrs = [a for a in mat['attrs'].split(' ') if a != '']
return Section(name, start, end, offset, attrs) return Section(name, start, end, offset, attrs)
def finish_module(self, name, sections, index): def finish_module(self, name: str, sections: Dict[str, Section],
index: Index) -> Module:
alloc = {k: s for k, s in sections.items() if 'ALLOC' in s.attrs} alloc = {k: s for k, s in sections.items() if 'ALLOC' in s.attrs}
if len(alloc) == 0: if len(alloc) == 0:
return Module(name, 0, 0, alloc) return Module(name, 0, 0, alloc)
@ -130,13 +162,13 @@ class ModuleInfoReader(object):
max_addr = max(s.end for s in alloc.values()) max_addr = max(s.end for s in alloc.values())
return Module(name, base_addr, max_addr, alloc) return Module(name, base_addr, max_addr, alloc)
def get_modules(self): def get_modules(self) -> Dict[str, Module]:
modules = {} modules = {}
index = Index(REGION_INFO_READER.get_regions()) index = Index(REGION_INFO_READER.get_regions())
out = gdb.execute(self.cmd, to_string=True) out = gdb.execute(self.cmd, to_string=True)
max_addr = compute_max_addr() max_addr = compute_max_addr()
name = None name = None
sections = None sections: Dict[str, Section] = {}
for line in out.split('\n'): for line in out.split('\n'):
n = self.name_from_line(line) n = self.name_from_line(line)
if n is not None: if n is not None:
@ -176,7 +208,7 @@ class ModuleInfoReaderV11(ModuleInfoReader):
section_pattern = OBJFILE_SECTION_PATTERN_V9 section_pattern = OBJFILE_SECTION_PATTERN_V9
def _choose_module_info_reader(): def _choose_module_info_reader() -> ModuleInfoReader:
if GDB_VERSION.major == 8: if GDB_VERSION.major == 8:
return ModuleInfoReaderV8() return ModuleInfoReaderV8()
elif GDB_VERSION.major == 9: elif GDB_VERSION.major == 9:
@ -207,15 +239,11 @@ REGION_PATTERN = re.compile("\\s*" +
"(?P<objfile>.*)") "(?P<objfile>.*)")
class Region(namedtuple('BaseRegion', ['start', 'end', 'offset', 'perms', 'objfile'])):
pass
class RegionInfoReader(object): class RegionInfoReader(object):
cmd = REGIONS_CMD cmd = REGIONS_CMD
region_pattern = REGION_PATTERN region_pattern = REGION_PATTERN
def region_from_line(self, line, max_addr): def region_from_line(self, line: str, max_addr: int) -> Optional[Region]:
mat = self.region_pattern.fullmatch(line) mat = self.region_pattern.fullmatch(line)
if mat is None: if mat is None:
return None return None
@ -226,8 +254,8 @@ class RegionInfoReader(object):
objfile = mat['objfile'] objfile = mat['objfile']
return Region(start, end, offset, perms, objfile) return Region(start, end, offset, perms, objfile)
def get_regions(self): def get_regions(self) -> List[Region]:
regions = [] regions: List[Region] = []
try: try:
out = gdb.execute(self.cmd, to_string=True) out = gdb.execute(self.cmd, to_string=True)
max_addr = compute_max_addr() max_addr = compute_max_addr()
@ -240,12 +268,12 @@ class RegionInfoReader(object):
regions.append(r) regions.append(r)
return regions return regions
def full_mem(self): def full_mem(self) -> Region:
# TODO: This may not work for Harvard architectures # TODO: This may not work for Harvard architectures
max_addr = compute_max_addr() max_addr = compute_max_addr()
return Region(0, max_addr+1, 0, None, 'full memory') return Region(0, max_addr+1, 0, None, 'full memory')
def have_changed(self, regions): def have_changed(self, regions: List[Region]) -> Tuple[bool, Optional[List[Region]]]:
if len(regions) == 1 and regions[0].objfile == 'full memory': if len(regions) == 1 and regions[0].objfile == 'full memory':
return False, None return False, None
new_regions = self.get_regions() new_regions = self.get_regions()
@ -257,7 +285,7 @@ class RegionInfoReader(object):
return mat['perms'] return mat['perms']
def _choose_region_info_reader(): def _choose_region_info_reader() -> RegionInfoReader:
if 8 <= GDB_VERSION.major: if 8 <= GDB_VERSION.major:
return RegionInfoReader() return RegionInfoReader()
else: else:
@ -273,18 +301,23 @@ BREAK_PATTERN = re.compile('')
BREAK_LOC_PATTERN = re.compile('') BREAK_LOC_PATTERN = re.compile('')
class BreakpointLocation(namedtuple('BaseBreakpointLocation', ['address', 'enabled', 'thread_groups'])): @dataclass(frozen=True)
class BreakpointLocation:
address: int
enabled: bool
thread_groups: List[int]
class BreakpointLocationInfoReader(object):
@abstractmethod
def get_locations(self, breakpoint: gdb.Breakpoint) -> List[Union[
BreakpointLocation, gdb.BreakpointLocation]]:
pass pass
class BreakpointLocationInfoReaderV8(object): class BreakpointLocationInfoReaderV8(BreakpointLocationInfoReader):
def breakpoint_from_line(self, line): def get_locations(self, breakpoint: gdb.Breakpoint) -> List[Union[
pass BreakpointLocation, gdb.BreakpointLocation]]:
def location_from_line(self, line):
pass
def get_locations(self, breakpoint):
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
thread_groups = [inf.num] thread_groups = [inf.num]
if breakpoint.location is not None and breakpoint.location.startswith("*0x"): if breakpoint.location is not None and breakpoint.location.startswith("*0x"):
@ -295,20 +328,16 @@ class BreakpointLocationInfoReaderV8(object):
return [] return []
class BreakpointLocationInfoReaderV9(object): class BreakpointLocationInfoReaderV9(BreakpointLocationInfoReader):
def breakpoint_from_line(self, line): def get_locations(self, breakpoint: gdb.Breakpoint) -> List[Union[
pass BreakpointLocation, gdb.BreakpointLocation]]:
def location_from_line(self, line):
pass
def get_locations(self, breakpoint):
inf = gdb.selected_inferior() inf = gdb.selected_inferior()
thread_groups = [inf.num] thread_groups = [inf.num]
if breakpoint.location is None: if breakpoint.location is None:
return [] return []
try: try:
address = gdb.parse_and_eval(breakpoint.location).address address = int(gdb.parse_and_eval(
breakpoint.location).address) & compute_max_addr()
loc = BreakpointLocation( loc = BreakpointLocation(
address, breakpoint.enabled, thread_groups) address, breakpoint.enabled, thread_groups)
return [loc] return [loc]
@ -317,12 +346,13 @@ class BreakpointLocationInfoReaderV9(object):
return [] return []
class BreakpointLocationInfoReaderV13(object): class BreakpointLocationInfoReaderV13(BreakpointLocationInfoReader):
def get_locations(self, breakpoint): def get_locations(self, breakpoint: gdb.Breakpoint) -> List[Union[
BreakpointLocation, gdb.BreakpointLocation]]:
return breakpoint.locations return breakpoint.locations
def _choose_breakpoint_location_info_reader(): def _choose_breakpoint_location_info_reader() -> BreakpointLocationInfoReader:
if GDB_VERSION.major >= 13: if GDB_VERSION.major >= 13:
return BreakpointLocationInfoReaderV13() return BreakpointLocationInfoReaderV13()
if GDB_VERSION.major >= 9: if GDB_VERSION.major >= 9:
@ -337,16 +367,16 @@ def _choose_breakpoint_location_info_reader():
BREAKPOINT_LOCATION_INFO_READER = _choose_breakpoint_location_info_reader() BREAKPOINT_LOCATION_INFO_READER = _choose_breakpoint_location_info_reader()
def set_bool_param_by_api(name, value): def set_bool_param_by_api(name: str, value: bool) -> None:
gdb.set_parameter(name, value) gdb.set_parameter(name, value)
def set_bool_param_by_cmd(name, value): def set_bool_param_by_cmd(name: str, value: bool) -> None:
val = 'on' if value else 'off' val = 'on' if value else 'off'
gdb.execute(f'set {name} {val}') gdb.execute(f'set {name} {val}')
def choose_set_parameter(): def choose_set_parameter() -> Callable[[str, bool], None]:
if GDB_VERSION.major >= 13: if GDB_VERSION.major >= 13:
return set_bool_param_by_api return set_bool_param_by_api
else: else:
@ -356,30 +386,32 @@ def choose_set_parameter():
set_bool_param = choose_set_parameter() set_bool_param = choose_set_parameter()
def get_level(frame): def get_level(frame: gdb.Frame) -> int:
if hasattr(frame, "level"): if hasattr(frame, "level"):
return frame.level() return frame.level()
else: else:
level = -1 level = -1
f = frame f: Optional[gdb.Frame] = frame
while f is not None: while f is not None:
level += 1 level += 1
f = f.newer() f = f.newer()
return level return level
class RegisterDesc(namedtuple('BaseRegisterDesc', ['name'])): @dataclass(frozen=True)
pass class RegisterDesc:
name: str
def get_register_descs(arch, group='all'): def get_register_descs(arch: gdb.Architecture, group: str = 'all') -> List[
Union[RegisterDesc, gdb.RegisterDescriptor]]:
if hasattr(arch, "registers"): if hasattr(arch, "registers"):
try: try:
return arch.registers(group) return list(arch.registers(group))
except ValueError: # No such group, or version too old except ValueError: # No such group, or version too old
return arch.registers() return list(arch.registers())
else: else:
descs = [] descs: List[Union[RegisterDesc, gdb.RegisterDescriptor]] = []
try: try:
regset = gdb.execute( regset = gdb.execute(
f"info registers {group}", to_string=True).strip().split('\n') f"info registers {group}", to_string=True).strip().split('\n')
@ -393,7 +425,7 @@ def get_register_descs(arch, group='all'):
return descs return descs
def selected_frame(): def selected_frame() -> Optional[gdb.Frame]:
try: try:
return gdb.selected_frame() return gdb.selected_frame()
except Exception as e: except Exception as e:
@ -401,5 +433,5 @@ def selected_frame():
return None return None
def compute_max_addr(): def compute_max_addr() -> int:
return (1 << (int(gdb.parse_and_eval("sizeof(void*)")) * 8)) - 1 return (1 << (int(gdb.parse_and_eval("sizeof(void*)")) * 8)) - 1

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from dataclasses import dataclass
from typing import Dict, List, Optional
import gdb import gdb
from . import util from . import util
@ -23,22 +25,23 @@ from .commands import install, cmd
class GhidraWinePrefix(gdb.Command): class GhidraWinePrefix(gdb.Command):
"""Commands for tracing Wine processes""" """Commands for tracing Wine processes"""
def __init__(self): def __init__(self) -> None:
super().__init__('ghidra wine', gdb.COMMAND_SUPPORT, prefix=True) super().__init__('ghidra wine', gdb.COMMAND_SUPPORT, prefix=True)
def is_mapped(pe_file): def is_mapped(pe_file: str) -> bool:
return pe_file in gdb.execute("info proc mappings", to_string=True) return pe_file in gdb.execute("info proc mappings", to_string=True)
def set_break(command): def set_break(command: str) -> gdb.Breakpoint:
breaks_before = set(gdb.breakpoints()) breaks_before = set(gdb.breakpoints())
gdb.execute(command) gdb.execute(command)
return (set(gdb.breakpoints()) - breaks_before).pop() return (set(gdb.breakpoints()) - breaks_before).pop()
@cmd('ghidra wine run-to-image', '-ghidra-wine-run-to-image', gdb.COMMAND_SUPPORT, False) @cmd('ghidra wine run-to-image', '-ghidra-wine-run-to-image',
def ghidra_wine_run_to_image(pe_file, *, is_mi, **kwargs): gdb.COMMAND_SUPPORT, False)
def ghidra_wine_run_to_image(pe_file: str, *, is_mi: bool, **kwargs) -> None:
mprot_catchpoint = set_break(""" mprot_catchpoint = set_break("""
catch syscall mprotect catch syscall mprotect
commands commands
@ -53,13 +56,16 @@ end
ORIG_MODULE_INFO_READER = util.MODULE_INFO_READER ORIG_MODULE_INFO_READER = util.MODULE_INFO_READER
@dataclass(frozen=False)
class Range(object): class Range(object):
min: int
max: int
def expand(self, region): @staticmethod
if not hasattr(self, 'min'): def from_region(region: util.Region) -> 'Range':
self.min = region.start return Range(region.start, region.end)
self.max = region.end
else: def expand(self, region: util.Region):
self.min = min(self.min, region.start) self.min = min(self.min, region.start)
self.max = max(self.max, region.end) self.max = max(self.max, region.end)
return self return self
@ -69,14 +75,14 @@ class Range(object):
MODULE_SUFFIXES = (".exe", ".dll") MODULE_SUFFIXES = (".exe", ".dll")
class WineModuleInfoReader(object): class WineModuleInfoReader(util.ModuleInfoReader):
def get_modules(self): def get_modules(self) -> Dict[str, util.Module]:
modules = ORIG_MODULE_INFO_READER.get_modules() modules = ORIG_MODULE_INFO_READER.get_modules()
ranges = dict() ranges = dict()
for region in util.REGION_INFO_READER.get_regions(): for region in util.REGION_INFO_READER.get_regions():
if not region.objfile in ranges: if not region.objfile in ranges:
ranges[region.objfile] = Range().expand(region) ranges[region.objfile] = Range.from_region(region)
else: else:
ranges[region.objfile].expand(region) ranges[region.objfile].expand(region)
for k, v in ranges.items(): for k, v in ranges.items():

View file

@ -15,4 +15,5 @@ src/main/py/LICENSE||GHIDRA||||END|
src/main/py/MANIFEST.in||GHIDRA||||END| src/main/py/MANIFEST.in||GHIDRA||||END|
src/main/py/README.md||GHIDRA||||END| src/main/py/README.md||GHIDRA||||END|
src/main/py/pyproject.toml||GHIDRA||||END| src/main/py/pyproject.toml||GHIDRA||||END|
src/main/py/src/ghidralldb/py.typed||GHIDRA||||END|
src/main/py/src/ghidralldb/schema.xml||GHIDRA||||END| src/main/py/src/ghidralldb/schema.xml||GHIDRA||||END|

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "ghidralldb" name = "ghidralldb"
version = "11.3" version = "11.4"
authors = [ authors = [
{ name="Ghidra Development Team" }, { name="Ghidra Development Team" },
] ]
@ -17,9 +17,12 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"ghidratrace==11.3", "ghidratrace==11.4",
] ]
[project.urls] [project.urls]
"Homepage" = "https://github.com/NationalSecurityAgency/ghidra" "Homepage" = "https://github.com/NationalSecurityAgency/ghidra"
"Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues" "Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues"
[tool.setuptools.package-data]
ghidralldb = ["py.typed"]

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from typing import Dict, List, Optional, Tuple
from ghidratrace.client import Address, RegVal from ghidratrace.client import Address, RegVal
import lldb import lldb
@ -20,8 +21,9 @@ from . import util
# NOTE: This map is derived from the ldefs using a script # NOTE: This map is derived from the ldefs using a script
language_map = { language_map: Dict[str, List[str]] = {
'aarch64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon', 'AARCH64:LE:64:v8A'], 'aarch64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon',
'AARCH64:LE:64:v8A'],
'arm': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'], 'arm': ['ARM:BE:32:v8', 'ARM:BE:32:v8T', 'ARM:LE:32:v8', 'ARM:LE:32:v8T'],
'armv4': ['ARM:BE:32:v4', 'ARM:LE:32:v4'], 'armv4': ['ARM:BE:32:v4', 'ARM:LE:32:v4'],
'armv4t': ['ARM:BE:32:v4t', 'ARM:LE:32:v4t'], 'armv4t': ['ARM:BE:32:v4t', 'ARM:LE:32:v4t'],
@ -50,8 +52,10 @@ language_map = {
'thumbv7em': ['ARM:BE:32:Cortex', 'ARM:LE:32:Cortex'], 'thumbv7em': ['ARM:BE:32:Cortex', 'ARM:LE:32:Cortex'],
'armv8': ['ARM:BE:32:v8', 'ARM:LE:32:v8'], 'armv8': ['ARM:BE:32:v8', 'ARM:LE:32:v8'],
'armv8l': ['ARM:BE:32:v8', 'ARM:LE:32:v8'], 'armv8l': ['ARM:BE:32:v8', 'ARM:LE:32:v8'],
'arm64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon', 'AARCH64:LE:64:v8A'], 'arm64': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon',
'arm64e': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon', 'AARCH64:LE:64:v8A'], 'AARCH64:LE:64:v8A'],
'arm64e': ['AARCH64:BE:64:v8A', 'AARCH64:LE:64:AppleSilicon',
'AARCH64:LE:64:v8A'],
'arm64_32': ['ARM:BE:32:v8', 'ARM:LE:32:v8'], 'arm64_32': ['ARM:BE:32:v8', 'ARM:LE:32:v8'],
'mips': ['MIPS:BE:32:default', 'MIPS:LE:32:default'], 'mips': ['MIPS:BE:32:default', 'MIPS:LE:32:default'],
'mipsr2': ['MIPS:BE:32:default', 'MIPS:LE:32:default'], 'mipsr2': ['MIPS:BE:32:default', 'MIPS:LE:32:default'],
@ -102,8 +106,11 @@ language_map = {
'hexagon': [], 'hexagon': [],
'hexagonv4': [], 'hexagonv4': [],
'hexagonv5': [], 'hexagonv5': [],
'riscv32': ['RISCV:LE:32:RV32G', 'RISCV:LE:32:RV32GC', 'RISCV:LE:32:RV32I', 'RISCV:LE:32:RV32IC', 'RISCV:LE:32:RV32IMC', 'RISCV:LE:32:default'], 'riscv32': ['RISCV:LE:32:RV32G', 'RISCV:LE:32:RV32GC', 'RISCV:LE:32:RV32I',
'riscv64': ['RISCV:LE:64:RV64G', 'RISCV:LE:64:RV64GC', 'RISCV:LE:64:RV64I', 'RISCV:LE:64:RV64IC', 'RISCV:LE:64:default'], 'RISCV:LE:32:RV32IC', 'RISCV:LE:32:RV32IMC',
'RISCV:LE:32:default'],
'riscv64': ['RISCV:LE:64:RV64G', 'RISCV:LE:64:RV64GC', 'RISCV:LE:64:RV64I',
'RISCV:LE:64:RV64IC', 'RISCV:LE:64:default'],
'unknown-mach-32': ['DATA:LE:32:default', 'DATA:LE:32:default'], 'unknown-mach-32': ['DATA:LE:32:default', 'DATA:LE:32:default'],
'unknown-mach-64': ['DATA:LE:64:default', 'DATA:LE:64:default'], 'unknown-mach-64': ['DATA:LE:64:default', 'DATA:LE:64:default'],
'arc': [], 'arc': [],
@ -111,19 +118,20 @@ language_map = {
'wasm32': ['x86:LE:32:default'], 'wasm32': ['x86:LE:32:default'],
} }
data64_compiler_map = { data64_compiler_map: Dict[Optional[str], str] = {
None: 'pointer64', None: 'pointer64',
} }
x86_compiler_map = { x86_compiler_map: Dict[Optional[str], str] = {
'windows': 'windows', 'windows': 'windows',
'Cygwin': 'windows', 'Cygwin': 'windows',
'linux': 'gcc', 'linux': 'gcc',
'default': 'gcc', 'default': 'gcc',
'unknown': 'gcc', 'unknown': 'gcc',
None: 'gcc',
} }
default_compiler_map = { default_compiler_map: Dict[Optional[str], str] = {
'freebsd': 'gcc', 'freebsd': 'gcc',
'linux': 'gcc', 'linux': 'gcc',
'netbsd': 'gcc', 'netbsd': 'gcc',
@ -138,7 +146,7 @@ default_compiler_map = {
'unknown': 'default', 'unknown': 'default',
} }
compiler_map = { compiler_map: Dict[str, Dict[Optional[str], str]] = {
'DATA:BE:64:': data64_compiler_map, 'DATA:BE:64:': data64_compiler_map,
'DATA:LE:64:': data64_compiler_map, 'DATA:LE:64:': data64_compiler_map,
'x86:LE:32:': x86_compiler_map, 'x86:LE:32:': x86_compiler_map,
@ -148,7 +156,7 @@ compiler_map = {
} }
def find_host_triple(): def find_host_triple() -> str:
dbg = util.get_debugger() dbg = util.get_debugger()
for i in range(dbg.GetNumPlatforms()): for i in range(dbg.GetNumPlatforms()):
platform = dbg.GetPlatformAtIndex(i) platform = dbg.GetPlatformAtIndex(i)
@ -157,19 +165,19 @@ def find_host_triple():
return 'unrecognized' return 'unrecognized'
def find_triple(): def find_triple() -> str:
triple = util.get_target().triple triple = util.get_target().triple
if triple is not None: if triple is not None:
return triple return triple
return find_host_triple() return find_host_triple()
def get_arch(): def get_arch() -> str:
triple = find_triple() triple = find_triple()
return triple.split('-')[0] return triple.split('-')[0]
def get_endian(): def get_endian() -> str:
parm = util.get_convenience_variable('endian') parm = util.get_convenience_variable('endian')
if parm != 'auto': if parm != 'auto':
return parm return parm
@ -183,7 +191,7 @@ def get_endian():
return 'unrecognized' return 'unrecognized'
def get_osabi(): def get_osabi() -> str:
parm = util.get_convenience_variable('osabi') parm = util.get_convenience_variable('osabi')
if not parm in ['auto', 'default']: if not parm in ['auto', 'default']:
return parm return parm
@ -195,7 +203,7 @@ def get_osabi():
return triple.split('-')[2] return triple.split('-')[2]
def compute_ghidra_language(): def compute_ghidra_language() -> str:
# First, check if the parameter is set # First, check if the parameter is set
lang = util.get_convenience_variable('ghidra-language') lang = util.get_convenience_variable('ghidra-language')
if lang != 'auto': if lang != 'auto':
@ -223,37 +231,33 @@ def compute_ghidra_language():
return 'DATA' + lebe + '64:default' return 'DATA' + lebe + '64:default'
def compute_ghidra_compiler(lang): def compute_ghidra_compiler(lang: str) -> str:
# First, check if the parameter is set # First, check if the parameter is set
comp = util.get_convenience_variable('ghidra-compiler') comp = util.get_convenience_variable('ghidra-compiler')
if comp != 'auto': if comp != 'auto':
return comp return comp
# Check if the selected lang has specific compiler recommendations # Check if the selected lang has specific compiler recommendations
matched_lang = sorted( # NOTE: Unlike other agents, we put prefixes in map keys
(l for l in compiler_map if l in lang), matches = [l for l in compiler_map if lang.startswith(l)]
key=lambda l: compiler_map[l] if len(matches) == 0:
)
if len(matched_lang) == 0:
print(f"{lang} not found in compiler map - using default compiler") print(f"{lang} not found in compiler map - using default compiler")
return 'default' return 'default'
comp_map = compiler_map[matches[0]]
comp_map = compiler_map[matched_lang[0]]
if comp_map == data64_compiler_map: if comp_map == data64_compiler_map:
print(f"Using the DATA64 compiler map") print(f"Using the DATA64 compiler map")
osabi = get_osabi() osabi = get_osabi()
if osabi in comp_map: if osabi in comp_map:
return comp_map[osabi] return comp_map[osabi]
if lang.startswith("x86:"):
print(f"{osabi} not found in compiler map - using gcc")
return 'gcc'
if None in comp_map: if None in comp_map:
return comp_map[None] def_comp = comp_map[None]
print(f"{osabi} not found in compiler map - using {def_comp} compiler")
return def_comp
print(f"{osabi} not found in compiler map - using default compiler") print(f"{osabi} not found in compiler map - using default compiler")
return 'default' return 'default'
def compute_ghidra_lcsp(): def compute_ghidra_lcsp() -> Tuple[str, str]:
lang = compute_ghidra_language() lang = compute_ghidra_language()
comp = compute_ghidra_compiler(lang) comp = compute_ghidra_compiler(lang)
return lang, comp return lang, comp
@ -261,10 +265,10 @@ def compute_ghidra_lcsp():
class DefaultMemoryMapper(object): class DefaultMemoryMapper(object):
def __init__(self, defaultSpace): def __init__(self, defaultSpace: str) -> None:
self.defaultSpace = defaultSpace self.defaultSpace = defaultSpace
def map(self, proc: lldb.SBProcess, offset: int): def map(self, proc: lldb.SBProcess, offset: int) -> Tuple[str, Address]:
space = self.defaultSpace space = self.defaultSpace
return self.defaultSpace, Address(space, offset) return self.defaultSpace, Address(space, offset)
@ -277,10 +281,10 @@ class DefaultMemoryMapper(object):
DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram') DEFAULT_MEMORY_MAPPER = DefaultMemoryMapper('ram')
memory_mappers = {} memory_mappers: Dict[str, DefaultMemoryMapper] = {}
def compute_memory_mapper(lang): def compute_memory_mapper(lang: str) -> DefaultMemoryMapper:
if not lang in memory_mappers: if not lang in memory_mappers:
return DEFAULT_MEMORY_MAPPER return DEFAULT_MEMORY_MAPPER
return memory_mappers[lang] return memory_mappers[lang]
@ -288,31 +292,31 @@ def compute_memory_mapper(lang):
class DefaultRegisterMapper(object): class DefaultRegisterMapper(object):
def __init__(self, byte_order): def __init__(self, byte_order: str) -> None:
if not byte_order in ['big', 'little']: if not byte_order in ['big', 'little']:
raise ValueError("Invalid byte_order: {}".format(byte_order)) raise ValueError("Invalid byte_order: {}".format(byte_order))
self.byte_order = byte_order self.byte_order = byte_order
self.union_winners = {}
def map_name(self, proc, name): def map_name(self, proc: lldb.SBProcess, name: str) -> str:
return name return name
def map_value(self, proc, name, value): def map_value(self, proc: lldb.SBProcess, name: str, value: bytes) -> RegVal:
return RegVal(self.map_name(proc, name), value) return RegVal(self.map_name(proc, name), value)
def map_name_back(self, proc, name): def map_name_back(self, proc: lldb.SBProcess, name: str) -> str:
return name return name
def map_value_back(self, proc, name, value): def map_value_back(self, proc: lldb.SBProcess, name: str,
value: bytes) -> RegVal:
return RegVal(self.map_name_back(proc, name), value) return RegVal(self.map_name_back(proc, name), value)
class Intel_x86_64_RegisterMapper(DefaultRegisterMapper): class Intel_x86_64_RegisterMapper(DefaultRegisterMapper):
def __init__(self): def __init__(self) -> None:
super().__init__('little') super().__init__('little')
def map_name(self, proc, name): def map_name(self, proc: lldb.SBProcess, name: str) -> str:
if name is None: if name is None:
return 'UNKNOWN' return 'UNKNOWN'
if name == 'eflags': if name == 'eflags':
@ -322,26 +326,27 @@ class Intel_x86_64_RegisterMapper(DefaultRegisterMapper):
return 'ymm' + name[3:] return 'ymm' + name[3:]
return super().map_name(proc, name) return super().map_name(proc, name)
def map_value(self, proc, name, value): def map_value(self, proc: lldb.SBProcess, name: str, value: bytes) -> RegVal:
rv = super().map_value(proc, name, value) rv = super().map_value(proc, name, value)
if rv.name.startswith('ymm') and len(rv.value) > 32: if rv.name.startswith('ymm') and len(rv.value) > 32:
return RegVal(rv.name, rv.value[-32:]) return RegVal(rv.name, rv.value[-32:])
return rv return rv
def map_name_back(self, proc, name): def map_name_back(self, proc: lldb.SBProcess, name: str) -> str:
if name == 'rflags': if name == 'rflags':
return 'eflags' return 'eflags'
return super().map_name_back(proc, name)
DEFAULT_BE_REGISTER_MAPPER = DefaultRegisterMapper('big') DEFAULT_BE_REGISTER_MAPPER = DefaultRegisterMapper('big')
DEFAULT_LE_REGISTER_MAPPER = DefaultRegisterMapper('little') DEFAULT_LE_REGISTER_MAPPER = DefaultRegisterMapper('little')
register_mappers = { register_mappers: Dict[str, DefaultRegisterMapper] = {
'x86:LE:64:default': Intel_x86_64_RegisterMapper() 'x86:LE:64:default': Intel_x86_64_RegisterMapper()
} }
def compute_register_mapper(lang): def compute_register_mapper(lang: str) -> DefaultRegisterMapper:
if not lang in register_mappers: if not lang in register_mappers:
if ':BE:' in lang: if ':BE:' in lang:
return DEFAULT_BE_REGISTER_MAPPER return DEFAULT_BE_REGISTER_MAPPER

View file

@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from dataclasses import dataclass, field
import threading import threading
import time import time
from typing import Any, Optional, Union
import lldb import lldb
@ -24,34 +26,41 @@ from . import commands, util
ALL_EVENTS = 0xFFFF ALL_EVENTS = 0xFFFF
@dataclass(frozen=False)
class HookState(object): class HookState(object):
__slots__ = ('installed', 'mem_catchpoint') installed = False
def __init__(self): def __init__(self) -> None:
self.installed = False self.installed = False
self.mem_catchpoint = None
@dataclass(frozen=False)
class ProcessState(object): class ProcessState(object):
__slots__ = ('first', 'regions', 'modules', 'threads', first = True
'breaks', 'watches', 'visited')
def __init__(self):
self.first = True
# For things we can detect changes to between stops # For things we can detect changes to between stops
regions = False
modules = False
threads = False
breaks = False
watches = False
# For frames and threads that have already been synced since last stop
visited: set[Any] = field(default_factory=set)
def __init__(self) -> None:
self.first = True
self.regions = False self.regions = False
self.modules = False self.modules = False
self.threads = False self.threads = False
self.breaks = False self.breaks = False
self.watches = False self.watches = False
# For frames and threads that have already been synced since last stop
self.visited = set() self.visited = set()
def record(self, description=None): def record(self, description: Optional[str] = None) -> None:
first = self.first first = self.first
self.first = False self.first = False
trace = commands.STATE.require_trace()
if description is not None: if description is not None:
commands.STATE.trace.snapshot(description) trace.snapshot(description)
if first: if first:
commands.put_processes() commands.put_processes()
commands.put_environment() commands.put_environment()
@ -121,7 +130,8 @@ class QuitSentinel(object):
QUIT = QuitSentinel() QUIT = QuitSentinel()
def process_event(self, listener, event): def process_event(self, listener: lldb.SBListener,
event: lldb.SBEvent) -> Union[QuitSentinel, bool]:
try: try:
desc = util.get_description(event) desc = util.get_description(event)
# print(f"Event: {desc}") # print(f"Event: {desc}")
@ -130,7 +140,7 @@ def process_event(self, listener, event):
# LLDB may crash on event.GetBroadcasterClass, otherwise # LLDB may crash on event.GetBroadcasterClass, otherwise
# All the checks below, e.g. SBTarget.EventIsTargetEvent, call this # All the checks below, e.g. SBTarget.EventIsTargetEvent, call this
print(f"Ignoring {desc} because target is invalid") print(f"Ignoring {desc} because target is invalid")
return return False
event_process = util.get_process() event_process = util.get_process()
if event_process.IsValid() and event_process.GetProcessID() not in PROC_STATE: if event_process.IsValid() and event_process.GetProcessID() not in PROC_STATE:
PROC_STATE[event_process.GetProcessID()] = ProcessState() PROC_STATE[event_process.GetProcessID()] = ProcessState()
@ -260,13 +270,14 @@ def process_event(self, listener, event):
return True return True
except BaseException as e: except BaseException as e:
print(e) print(e)
return False
class EventThread(threading.Thread): class EventThread(threading.Thread):
func = process_event func = process_event
event = lldb.SBEvent() event = lldb.SBEvent()
def run(self): def run(self) -> None:
# Let's only try at most 4 times to retrieve any kind of event. # Let's only try at most 4 times to retrieve any kind of event.
# After that, the thread exits. # After that, the thread exits.
listener = lldb.SBListener('eventlistener') listener = lldb.SBListener('eventlistener')
@ -365,40 +376,40 @@ class EventThread(threading.Thread):
""" """
def on_new_process(event): def on_new_process(event: lldb.SBEvent) -> None:
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("New Process {}".format(event.process.num)): with trace.open_tx(f"New Process {event.process.num}"):
commands.put_processes() # TODO: Could put just the one.... commands.put_processes() # TODO: Could put just the one....
def on_process_selected(): def on_process_selected() -> None:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Process {} selected".format(proc.GetProcessID())): with trace.open_tx(f"Process {proc.GetProcessID()} selected"):
PROC_STATE[proc.GetProcessID()].record() PROC_STATE[proc.GetProcessID()].record()
commands.activate() commands.activate()
def on_process_deleted(event): def on_process_deleted(event: lldb.SBEvent) -> None:
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return
if event.process.num in PROC_STATE: if event.process.num in PROC_STATE:
del PROC_STATE[event.process.num] del PROC_STATE[event.process.num]
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Process {} deleted".format(event.process.num)): with trace.open_tx(f"Process {event.process.num} deleted"):
commands.put_processes() # TODO: Could just delete the one.... commands.put_processes() # TODO: Could just delete the one....
def on_new_thread(event): def on_new_thread(event: lldb.SBEvent) -> None:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return
@ -406,224 +417,237 @@ def on_new_thread(event):
# TODO: Syscall clone/exit to detect thread destruction? # TODO: Syscall clone/exit to detect thread destruction?
def on_thread_selected(): def on_thread_selected() -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
t = util.selected_thread() t = util.selected_thread()
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Thread {}.{} selected".format(proc.GetProcessID(), t.GetThreadID())): with trace.open_tx(f"Thread {proc.GetProcessID()}.{t.GetThreadID()} selected"):
PROC_STATE[proc.GetProcessID()].record() PROC_STATE[proc.GetProcessID()].record()
commands.put_threads() commands.put_threads()
commands.activate() commands.activate()
return True
def on_frame_selected(): def on_frame_selected() -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
f = util.selected_frame() f = util.selected_frame()
t = f.GetThread() t = f.GetThread()
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Frame {}.{}.{} selected".format(proc.GetProcessID(), t.GetThreadID(), f.GetFrameID())): with trace.open_tx(f"Frame {proc.GetProcessID()}.{t.GetThreadID()}.{f.GetFrameID()} selected"):
PROC_STATE[proc.GetProcessID()].record() PROC_STATE[proc.GetProcessID()].record()
commands.put_threads() commands.put_threads()
commands.put_frames() commands.put_frames()
commands.activate() commands.activate()
return True
def on_syscall_memory(): def on_syscall_memory() -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
PROC_STATE[proc.GetProcessID()].regions = True PROC_STATE[proc.GetProcessID()].regions = True
return True
def on_memory_changed(event): def on_memory_changed(event: lldb.SBEvent) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Memory *0x{:08x} changed".format(event.address)): with trace.open_tx(f"Memory *0x{event.address:08x} changed"):
commands.put_bytes(event.address, event.address + event.length, commands.put_bytes(event.address, event.address + event.length,
pages=False, is_mi=False, result=None) pages=False, result=None)
return True
def on_register_changed(event): def on_register_changed(event: lldb.SBEvent) -> bool:
# print("Register changed: {}".format(dir(event)))
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
# I'd rather have a descriptor! with trace.client.batch():
# TODO: How do I get the descriptor from the number? with trace.open_tx(f"Register {event.regnum} changed"):
# For now, just record the lot
with commands.STATE.client.batch():
with trace.open_tx("Register {} changed".format(event.regnum)):
banks = event.frame.GetRegisters() banks = event.frame.GetRegisters()
commands.putreg( commands.putreg(
event.frame, banks.GetFirstValueByName(commands.DEFAULT_REGISTER_BANK)) event.frame, banks.GetFirstValueByName(commands.DEFAULT_REGISTER_BANK))
return True
def on_cont(event): def on_cont(event: lldb.SBEvent) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
state = PROC_STATE[proc.GetProcessID()] state = PROC_STATE[proc.GetProcessID()]
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Continued"): with trace.open_tx("Continued"):
state.record_continued() state.record_continued()
return True
def on_stop(event): def on_stop(event: lldb.SBEvent) -> bool:
proc = lldb.SBProcess.GetProcessFromEvent( proc = lldb.SBProcess.GetProcessFromEvent(
event) if event is not None else util.get_process() event) if event is not None else util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
print("not in state") print("not in state")
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
print("no trace") print("no trace")
return return False
state = PROC_STATE[proc.GetProcessID()] state = PROC_STATE[proc.GetProcessID()]
state.visited.clear() state.visited.clear()
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Stopped"): with trace.open_tx("Stopped"):
state.record("Stopped") state.record("Stopped")
commands.put_event_thread() commands.put_event_thread()
commands.put_threads() commands.put_threads()
commands.put_frames() commands.put_frames()
commands.activate() commands.activate()
return True
def on_exited(event): def on_exited(event: lldb.SBEvent) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
state = PROC_STATE[proc.GetProcessID()] state = PROC_STATE[proc.GetProcessID()]
state.visited.clear() state.visited.clear()
exit_code = proc.GetExitStatus() exit_code = proc.GetExitStatus()
description = "Exited with code {}".format(exit_code) description = "Exited with code {}".format(exit_code)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx(description): with trace.open_tx(description):
state.record(description) state.record(description)
state.record_exited(exit_code) state.record_exited(exit_code)
commands.put_event_thread() commands.put_event_thread()
commands.activate() commands.activate()
return False
def modules_changed(): def modules_changed() -> bool:
# Assumption: affects the current process # Assumption: affects the current process
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
PROC_STATE[proc.GetProcessID()].modules = True PROC_STATE[proc.GetProcessID()].modules = True
return True
def on_new_objfile(event): def on_new_objfile(event: lldb.SBEvent) -> bool:
modules_changed() modules_changed()
return True
def on_free_objfile(event): def on_free_objfile(event: lldb.SBEvent) -> bool:
modules_changed() modules_changed()
return True
def on_breakpoint_created(b): def on_breakpoint_created(b: lldb.SBBreakpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} created".format(b.GetID())): with trace.open_tx("Breakpoint {} created".format(b.GetID())):
commands.put_single_breakpoint(b, proc) commands.put_single_breakpoint(b, proc)
return True
def on_breakpoint_modified(b): def on_breakpoint_modified(b: lldb.SBBreakpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} modified".format(b.GetID())): with trace.open_tx("Breakpoint {} modified".format(b.GetID())):
commands.put_single_breakpoint(b, proc) commands.put_single_breakpoint(b, proc)
return True
def on_breakpoint_deleted(b): def on_breakpoint_deleted(b: lldb.SBBreakpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
bpt_path = commands.PROC_BREAK_PATTERN.format( bpt_path = commands.PROC_BREAK_PATTERN.format(
procnum=proc.GetProcessID(), breaknum=b.GetID()) procnum=proc.GetProcessID(), breaknum=b.GetID())
bpt_obj = trace.proxy_object_path(bpt_path) bpt_obj = trace.proxy_object_path(bpt_path)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} deleted".format(b.GetID())): with trace.open_tx("Breakpoint {} deleted".format(b.GetID())):
bpt_obj.remove(tree=True) bpt_obj.remove(tree=True)
return True
def on_watchpoint_created(b): def on_watchpoint_created(b: lldb.SBWatchpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Breakpoint {} created".format(b.GetID())): with trace.open_tx("Breakpoint {} created".format(b.GetID())):
commands.put_single_watchpoint(b, proc) commands.put_single_watchpoint(b, proc)
return True
def on_watchpoint_modified(b): def on_watchpoint_modified(b: lldb.SBWatchpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Watchpoint {} modified".format(b.GetID())): with trace.open_tx("Watchpoint {} modified".format(b.GetID())):
commands.put_single_watchpoint(b, proc) commands.put_single_watchpoint(b, proc)
return True
def on_watchpoint_deleted(b): def on_watchpoint_deleted(b: lldb.SBWatchpoint) -> bool:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() not in PROC_STATE: if proc.GetProcessID() not in PROC_STATE:
return return False
trace = commands.STATE.trace trace = commands.STATE.trace
if trace is None: if trace is None:
return return False
wpt_path = commands.PROC_WATCH_PATTERN.format( wpt_path = commands.PROC_WATCH_PATTERN.format(
procnum=proc.GetProcessID(), watchnum=b.GetID()) procnum=proc.GetProcessID(), watchnum=b.GetID())
wpt_obj = trace.proxy_object_path(wpt_path) wpt_obj = trace.proxy_object_path(wpt_path)
with commands.STATE.client.batch(): with trace.client.batch():
with trace.open_tx("Watchpoint {} deleted".format(b.GetID())): with trace.open_tx("Watchpoint {} deleted".format(b.GetID())):
wpt_obj.remove(tree=True) wpt_obj.remove(tree=True)
return True
def install_hooks(): def install_hooks() -> None:
if HOOK_STATE.installed: if HOOK_STATE.installed:
return return
HOOK_STATE.installed = True HOOK_STATE.installed = True
@ -632,18 +656,18 @@ def install_hooks():
event_thread.start() event_thread.start()
def remove_hooks(): def remove_hooks() -> None:
if not HOOK_STATE.installed: if not HOOK_STATE.installed:
return return
HOOK_STATE.installed = False HOOK_STATE.installed = False
def enable_current_process(): def enable_current_process() -> None:
proc = util.get_process() proc = util.get_process()
PROC_STATE[proc.GetProcessID()] = ProcessState() PROC_STATE[proc.GetProcessID()] = ProcessState()
def disable_current_process(): def disable_current_process() -> None:
proc = util.get_process() proc = util.get_process()
if proc.GetProcessID() in PROC_STATE: if proc.GetProcessID() in PROC_STATE:
# Silently ignore already disabled # Silently ignore already disabled

View file

@ -16,11 +16,13 @@
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
import re import re
import sys import sys
from typing import Annotated, Any, Optional, Tuple
import lldb import lldb
from ghidratrace import sch from ghidratrace import sch
from ghidratrace.client import MethodRegistry, ParamDesc, Address, AddressRange from ghidratrace.client import (
MethodRegistry, ParamDesc, Address, AddressRange, TraceObject)
from . import commands, util from . import commands, util
@ -28,7 +30,7 @@ from . import commands, util
REGISTRY = MethodRegistry(ThreadPoolExecutor(max_workers=1)) REGISTRY = MethodRegistry(ThreadPoolExecutor(max_workers=1))
def extre(base, ext): def extre(base: re.Pattern, ext: str) -> re.Pattern:
return re.compile(base.pattern + ext) return re.compile(base.pattern + ext)
@ -49,7 +51,7 @@ MEMORY_PATTERN = extre(PROCESS_PATTERN, '\.Memory')
MODULES_PATTERN = extre(PROCESS_PATTERN, '\.Modules') MODULES_PATTERN = extre(PROCESS_PATTERN, '\.Modules')
def find_availpid_by_pattern(pattern, object, err_msg): def find_availpid_by_pattern(pattern: re.Pattern, object: TraceObject, err_msg: str) -> int:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -57,15 +59,16 @@ def find_availpid_by_pattern(pattern, object, err_msg):
return pid return pid
def find_availpid_by_obj(object): def find_availpid_by_obj(object: TraceObject) -> int:
return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Available") return find_availpid_by_pattern(AVAILABLE_PATTERN, object, "an Available")
def find_proc_by_num(procnum): def find_proc_by_num(procnum: int) -> lldb.SBProcess:
return util.get_process() return util.get_process()
def find_proc_by_pattern(object, pattern, err_msg): def find_proc_by_pattern(object: TraceObject, pattern: re.Pattern,
err_msg: str) -> lldb.SBProcess:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -73,37 +76,37 @@ def find_proc_by_pattern(object, pattern, err_msg):
return find_proc_by_num(procnum) return find_proc_by_num(procnum)
def find_proc_by_obj(object): def find_proc_by_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, PROCESS_PATTERN, "a Process") return find_proc_by_pattern(object, PROCESS_PATTERN, "a Process")
def find_proc_by_procbreak_obj(object): def find_proc_by_procbreak_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, PROC_BREAKS_PATTERN, return find_proc_by_pattern(object, PROC_BREAKS_PATTERN,
"a BreakpointLocationContainer") "a BreakpointLocationContainer")
def find_proc_by_procwatch_obj(object): def find_proc_by_procwatch_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, PROC_WATCHES_PATTERN, return find_proc_by_pattern(object, PROC_WATCHES_PATTERN,
"a WatchpointContainer") "a WatchpointContainer")
def find_proc_by_env_obj(object): def find_proc_by_env_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, ENV_PATTERN, "an Environment") return find_proc_by_pattern(object, ENV_PATTERN, "an Environment")
def find_proc_by_threads_obj(object): def find_proc_by_threads_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, THREADS_PATTERN, "a ThreadContainer") return find_proc_by_pattern(object, THREADS_PATTERN, "a ThreadContainer")
def find_proc_by_mem_obj(object): def find_proc_by_mem_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, MEMORY_PATTERN, "a Memory") return find_proc_by_pattern(object, MEMORY_PATTERN, "a Memory")
def find_proc_by_modules_obj(object): def find_proc_by_modules_obj(object: TraceObject) -> lldb.SBProcess:
return find_proc_by_pattern(object, MODULES_PATTERN, "a ModuleContainer") return find_proc_by_pattern(object, MODULES_PATTERN, "a ModuleContainer")
def find_thread_by_num(proc, tnum): def find_thread_by_num(proc: lldb.SBThread, tnum: int) -> lldb.SBThread:
for t in proc.threads: for t in proc.threads:
if t.GetThreadID() == tnum: if t.GetThreadID() == tnum:
return t return t
@ -111,7 +114,8 @@ def find_thread_by_num(proc, tnum):
f"Processes[{proc.GetProcessID()}].Threads[{tnum}] does not exist") f"Processes[{proc.GetProcessID()}].Threads[{tnum}] does not exist")
def find_thread_by_pattern(pattern, object, err_msg): def find_thread_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> lldb.SBThread:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -121,19 +125,19 @@ def find_thread_by_pattern(pattern, object, err_msg):
return find_thread_by_num(proc, tnum) return find_thread_by_num(proc, tnum)
def find_thread_by_obj(object): def find_thread_by_obj(object: TraceObject) -> lldb.SBThread:
return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread") return find_thread_by_pattern(THREAD_PATTERN, object, "a Thread")
def find_thread_by_stack_obj(object): def find_thread_by_stack_obj(object: TraceObject) -> lldb.SBThread:
return find_thread_by_pattern(STACK_PATTERN, object, "a Stack") return find_thread_by_pattern(STACK_PATTERN, object, "a Stack")
def find_frame_by_level(thread, level): def find_frame_by_level(thread: lldb.SBThread, level: int) -> lldb.SBFrame:
return thread.GetFrameAtIndex(level) return thread.GetFrameAtIndex(level)
def find_frame_by_pattern(pattern, object, err_msg): def find_frame_by_pattern(pattern: re.Pattern, object: TraceObject, err_msg: str) -> lldb.SBFrame:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -145,26 +149,18 @@ def find_frame_by_pattern(pattern, object, err_msg):
return find_frame_by_level(t, level) return find_frame_by_level(t, level)
def find_frame_by_obj(object): def find_frame_by_obj(object: TraceObject) -> lldb.SBFrame:
return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame") return find_frame_by_pattern(FRAME_PATTERN, object, "a StackFrame")
def find_frame_by_regs_obj(object): def find_frame_by_regs_obj(object: TraceObject) -> lldb.SBFrame:
return find_frame_by_pattern(REGS_PATTERN, object, return find_frame_by_pattern(REGS_PATTERN, object,
"a RegisterValueContainer") "a RegisterValueContainer")
# Because there's no method to get a register by name....
def find_reg_by_name(f, name):
for reg in f.architecture().registers():
if reg.name == name:
return reg
raise KeyError(f"No such register: {name}")
# Oof. no lldb/Python method to get breakpoint by number # Oof. no lldb/Python method to get breakpoint by number
# I could keep my own cache in a dict, but why? # I could keep my own cache in a dict, but why?
def find_bpt_by_number(breaknum): def find_bpt_by_number(breaknum: int) -> lldb.SBBreakpoint:
# TODO: If len exceeds some threshold, use binary search? # TODO: If len exceeds some threshold, use binary search?
for i in range(0, util.get_target().GetNumBreakpoints()): for i in range(0, util.get_target().GetNumBreakpoints()):
b = util.get_target().GetBreakpointAtIndex(i) b = util.get_target().GetBreakpointAtIndex(i)
@ -173,7 +169,8 @@ def find_bpt_by_number(breaknum):
raise KeyError(f"Breakpoints[{breaknum}] does not exist") raise KeyError(f"Breakpoints[{breaknum}] does not exist")
def find_bpt_by_pattern(pattern, object, err_msg): def find_bpt_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> lldb.SBBreakpoint:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -181,13 +178,13 @@ def find_bpt_by_pattern(pattern, object, err_msg):
return find_bpt_by_number(breaknum) return find_bpt_by_number(breaknum)
def find_bpt_by_obj(object): def find_bpt_by_obj(object: TraceObject) -> lldb.SBBreakpoint:
return find_bpt_by_pattern(PROC_BREAK_PATTERN, object, "a BreakpointSpec") return find_bpt_by_pattern(PROC_BREAK_PATTERN, object, "a BreakpointSpec")
# Oof. no lldb/Python method to get breakpoint by number # Oof. no lldb/Python method to get breakpoint by number
# I could keep my own cache in a dict, but why? # I could keep my own cache in a dict, but why?
def find_wpt_by_number(watchnum): def find_wpt_by_number(watchnum: int) -> lldb.SBWatchpoint:
# TODO: If len exceeds some threshold, use binary search? # TODO: If len exceeds some threshold, use binary search?
for i in range(0, util.get_target().GetNumWatchpoints()): for i in range(0, util.get_target().GetNumWatchpoints()):
w = util.get_target().GetWatchpointAtIndex(i) w = util.get_target().GetWatchpointAtIndex(i)
@ -196,7 +193,8 @@ def find_wpt_by_number(watchnum):
raise KeyError(f"Watchpoints[{watchnum}] does not exist") raise KeyError(f"Watchpoints[{watchnum}] does not exist")
def find_wpt_by_pattern(pattern, object, err_msg): def find_wpt_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> lldb.SBWatchpoint:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypeError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
@ -204,32 +202,33 @@ def find_wpt_by_pattern(pattern, object, err_msg):
return find_wpt_by_number(watchnum) return find_wpt_by_number(watchnum)
def find_wpt_by_obj(object): def find_wpt_by_obj(object: TraceObject) -> lldb.SBWatchpoint:
return find_wpt_by_pattern(PROC_WATCH_PATTERN, object, "a WatchpointSpec") return find_wpt_by_pattern(PROC_WATCH_PATTERN, object, "a WatchpointSpec")
def find_bptlocnum_by_pattern(pattern, object, err_msg): def find_bptlocnum_by_pattern(pattern: re.Pattern, object: TraceObject,
err_msg: str) -> Tuple[int, int]:
mat = pattern.fullmatch(object.path) mat = pattern.fullmatch(object.path)
if mat is None: if mat is None:
raise TypError(f"{object} is not {err_msg}") raise TypeError(f"{object} is not {err_msg}")
breaknum = int(mat['breaknum']) breaknum = int(mat['breaknum'])
locnum = int(mat['locnum']) locnum = int(mat['locnum'])
return breaknum, locnum return breaknum, locnum
def find_bptlocnum_by_obj(object): def find_bptlocnum_by_obj(object: TraceObject) -> Tuple[int, int]:
return find_bptlocnum_by_pattern(PROC_BREAKLOC_PATTERN, object, return find_bptlocnum_by_pattern(PROC_BREAKLOC_PATTERN, object,
"a BreakpointLocation") "a BreakpointLocation")
def find_bpt_loc_by_obj(object): def find_bpt_loc_by_obj(object: TraceObject) -> lldb.SBBreakpointLocation:
breaknum, locnum = find_bptlocnum_by_obj(object) breaknum, locnum = find_bptlocnum_by_obj(object)
bpt = find_bpt_by_number(breaknum) bpt = find_bpt_by_number(breaknum)
# Requires lldb-13.1 or later # Requires lldb-13.1 or later
return bpt.locations[locnum - 1] # Display is 1-up return bpt.locations[locnum - 1] # Display is 1-up
def exec_convert_errors(cmd, to_string=False): def exec_convert_errors(cmd: str, to_string: bool = False) -> Optional[str]:
res = lldb.SBCommandReturnObject() res = lldb.SBCommandReturnObject()
util.get_debugger().GetCommandInterpreter().HandleCommand(cmd, res) util.get_debugger().GetCommandInterpreter().HandleCommand(cmd, res)
if not res.Succeeded(): if not res.Succeeded():
@ -239,77 +238,142 @@ def exec_convert_errors(cmd, to_string=False):
if to_string: if to_string:
return res.GetOutput() return res.GetOutput()
print(res.GetOutput(), end="") print(res.GetOutput(), end="")
return None
@REGISTRY.method class Attachable(TraceObject):
def execute(cmd: str, to_string: bool=False): pass
class AvailableContainer(TraceObject):
pass
class BreakpointContainer(TraceObject):
pass
class BreakpointLocation(TraceObject):
pass
class BreakpointSpec(TraceObject):
pass
class Environment(TraceObject):
pass
class Memory(TraceObject):
pass
class ModuleContainer(TraceObject):
pass
class Process(TraceObject):
pass
class ProcessContainer(TraceObject):
pass
class RegisterValueContainer(TraceObject):
pass
class Stack(TraceObject):
pass
class StackFrame(TraceObject):
pass
class Thread(TraceObject):
pass
class ThreadContainer(TraceObject):
pass
class WatchpointContainer(TraceObject):
pass
class WatchpointSpec(TraceObject):
pass
@REGISTRY.method()
def execute(cmd: str, to_string: bool = False) -> Optional[str]:
"""Execute a CLI command.""" """Execute a CLI command."""
# TODO: Check for eCommandInterpreterResultQuitRequested? # TODO: Check for eCommandInterpreterResultQuitRequested?
return exec_convert_errors(cmd, to_string) return exec_convert_errors(cmd, to_string)
@REGISTRY.method(display='Evaluate') @REGISTRY.method(display='Evaluate')
def evaluate(expr: str): def evaluate(expr: str) -> Any:
"""Evaluate an expression.""" """Evaluate an expression."""
value = util.get_target().EvaluateExpression(expr) value = util.get_target().EvaluateExpression(expr)
if value.GetError().Fail(): if value.GetError().Fail():
raise RuntimeError(value.GetError().GetCString()) raise RuntimeError(value.GetError().GetCString())
return commands.convert_value(value) return commands.eval_value(value)
@REGISTRY.method(display="Python Evaluate") @REGISTRY.method(display="Python Evaluate")
def pyeval(expr: str): def pyeval(expr: str) -> Any:
return eval(expr) return eval(expr)
@REGISTRY.method(action='refresh', display="Refresh Available") @REGISTRY.method(action='refresh', display="Refresh Available")
def refresh_available(node: sch.Schema('AvailableContainer')): def refresh_available(node: AvailableContainer) -> None:
"""List processes on lldb's host system.""" """List processes on lldb's host system."""
with commands.open_tracked_tx('Refresh Available'): with commands.open_tracked_tx('Refresh Available'):
exec_convert_errors('ghidra trace put-available') exec_convert_errors('ghidra trace put-available')
@REGISTRY.method(action='refresh', display="Refresh Processes") @REGISTRY.method(action='refresh', display="Refresh Processes")
def refresh_processes(node: sch.Schema('ProcessContainer')): def refresh_processes(node: ProcessContainer) -> None:
"""Refresh the list of processes.""" """Refresh the list of processes."""
with commands.open_tracked_tx('Refresh Processes'): with commands.open_tracked_tx('Refresh Processes'):
exec_convert_errors('ghidra trace put-threads') exec_convert_errors('ghidra trace put-threads')
@REGISTRY.method(action='refresh', display="Refresh Breakpoints") @REGISTRY.method(action='refresh', display="Refresh Breakpoints")
def refresh_proc_breakpoints(node: sch.Schema('BreakpointContainer')): def refresh_proc_breakpoints(node: BreakpointContainer) -> None:
""" """Refresh the breakpoints for the process."""
Refresh the breakpoints for the process.
"""
with commands.open_tracked_tx('Refresh Breakpoint Locations'): with commands.open_tracked_tx('Refresh Breakpoint Locations'):
exec_convert_errors('ghidra trace put-breakpoints') exec_convert_errors('ghidra trace put-breakpoints')
@REGISTRY.method(action='refresh', display="Refresh Watchpoints") @REGISTRY.method(action='refresh', display="Refresh Watchpoints")
def refresh_proc_watchpoints(node: sch.Schema('WatchpointContainer')): def refresh_proc_watchpoints(node: WatchpointContainer) -> None:
""" """Refresh the watchpoints for the process."""
Refresh the watchpoints for the process.
"""
with commands.open_tracked_tx('Refresh Watchpoint Locations'): with commands.open_tracked_tx('Refresh Watchpoint Locations'):
exec_convert_errors('ghidra trace put-watchpoints') exec_convert_errors('ghidra trace put-watchpoints')
@REGISTRY.method(action='refresh', display="Refresh Environment") @REGISTRY.method(action='refresh', display="Refresh Environment")
def refresh_environment(node: sch.Schema('Environment')): def refresh_environment(node: Environment) -> None:
"""Refresh the environment descriptors (arch, os, endian).""" """Refresh the environment descriptors (arch, os, endian)."""
with commands.open_tracked_tx('Refresh Environment'): with commands.open_tracked_tx('Refresh Environment'):
exec_convert_errors('ghidra trace put-environment') exec_convert_errors('ghidra trace put-environment')
@REGISTRY.method(action='refresh', display="Refresh Threads") @REGISTRY.method(action='refresh', display="Refresh Threads")
def refresh_threads(node: sch.Schema('ThreadContainer')): def refresh_threads(node: ThreadContainer) -> None:
"""Refresh the list of threads in the process.""" """Refresh the list of threads in the process."""
with commands.open_tracked_tx('Refresh Threads'): with commands.open_tracked_tx('Refresh Threads'):
exec_convert_errors('ghidra trace put-threads') exec_convert_errors('ghidra trace put-threads')
@REGISTRY.method(action='refresh', display="Refresh Stack") @REGISTRY.method(action='refresh', display="Refresh Stack")
def refresh_stack(node: sch.Schema('Stack')): def refresh_stack(node: Stack) -> None:
"""Refresh the backtrace for the thread.""" """Refresh the backtrace for the thread."""
t = find_thread_by_stack_obj(node) t = find_thread_by_stack_obj(node)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
@ -318,7 +382,7 @@ def refresh_stack(node: sch.Schema('Stack')):
@REGISTRY.method(action='refresh', display="Refresh Registers") @REGISTRY.method(action='refresh', display="Refresh Registers")
def refresh_registers(node: sch.Schema('RegisterValueContainer')): def refresh_registers(node: RegisterValueContainer) -> None:
"""Refresh the register values for the frame.""" """Refresh the register values for the frame."""
f = find_frame_by_regs_obj(node) f = find_frame_by_regs_obj(node)
f.thread.SetSelectedFrame(f.GetFrameID()) f.thread.SetSelectedFrame(f.GetFrameID())
@ -328,83 +392,83 @@ def refresh_registers(node: sch.Schema('RegisterValueContainer')):
@REGISTRY.method(action='refresh', display="Refresh Memory") @REGISTRY.method(action='refresh', display="Refresh Memory")
def refresh_mappings(node: sch.Schema('Memory')): def refresh_mappings(node: Memory) -> None:
"""Refresh the list of memory regions for the process.""" """Refresh the list of memory regions for the process."""
with commands.open_tracked_tx('Refresh Memory Regions'): with commands.open_tracked_tx('Refresh Memory Regions'):
exec_convert_errors('ghidra trace put-regions') exec_convert_errors('ghidra trace put-regions')
@REGISTRY.method(action='refresh', display="Refresh Modules") @REGISTRY.method(action='refresh', display="Refresh Modules")
def refresh_modules(node: sch.Schema('ModuleContainer')): def refresh_modules(node: ModuleContainer) -> None:
""" """Refresh the modules and sections list for the process.
Refresh the modules and sections list for the process.
This will refresh the sections for all modules, not just the selected one. This will refresh the sections for all modules, not just the
selected one.
""" """
with commands.open_tracked_tx('Refresh Modules'): with commands.open_tracked_tx('Refresh Modules'):
exec_convert_errors('ghidra trace put-modules') exec_convert_errors('ghidra trace put-modules')
@REGISTRY.method(action='activate', display='Activate Process') @REGISTRY.method(action='activate', display='Activate Process')
def activate_process(process: sch.Schema('Process')): def activate_process(process: Process) -> None:
"""Switch to the process.""" """Switch to the process."""
# TODO # TODO
return return
@REGISTRY.method(action='activate', display='Activate Thread') @REGISTRY.method(action='activate', display='Activate Thread')
def activate_thread(thread: sch.Schema('Thread')): def activate_thread(thread: Thread) -> None:
"""Switch to the thread.""" """Switch to the thread."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
@REGISTRY.method(action='activate', display='Activate Frame') @REGISTRY.method(action='activate', display='Activate Frame')
def activate_frame(frame: sch.Schema('StackFrame')): def activate_frame(frame: StackFrame) -> None:
"""Select the frame.""" """Select the frame."""
f = find_frame_by_obj(frame) f = find_frame_by_obj(frame)
f.thread.SetSelectedFrame(f.GetFrameID()) f.thread.SetSelectedFrame(f.GetFrameID())
@REGISTRY.method(action='delete', display='Remove Process') @REGISTRY.method(action='delete', display='Remove Process')
def remove_process(process: sch.Schema('Process')): def remove_process(process: Process) -> None:
"""Remove the process.""" """Remove the process."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
exec_convert_errors(f'target delete 0') exec_convert_errors(f'target delete 0')
@REGISTRY.method(action='connect', display="Connect Target") @REGISTRY.method(action='connect', display="Connect Target")
def target(process: sch.Schema('Process'), spec: str): def target(process: Process, spec: str) -> None:
"""Connect to a target machine or process.""" """Connect to a target machine or process."""
exec_convert_errors(f'target select {spec}') exec_convert_errors(f'target select {spec}')
@REGISTRY.method(action='attach', display="Attach by Attachable") @REGISTRY.method(action='attach', display="Attach by Attachable")
def attach_obj(process: sch.Schema('Process'), target: sch.Schema('Attachable')): def attach_obj(process: Process, target: Attachable) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
pid = find_availpid_by_obj(target) pid = find_availpid_by_obj(target)
exec_convert_errors(f'process attach -p {pid}') exec_convert_errors(f'process attach -p {pid}')
@REGISTRY.method(action='attach', display="Attach by PID") @REGISTRY.method(action='attach', display="Attach by PID")
def attach_pid(process: sch.Schema('Process'), pid: int): def attach_pid(process: Process, pid: int) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
exec_convert_errors(f'process attach -p {pid}') exec_convert_errors(f'process attach -p {pid}')
@REGISTRY.method(action='attach', display="Attach by Name") @REGISTRY.method(action='attach', display="Attach by Name")
def attach_name(process: sch.Schema('Process'), name: str): def attach_name(process: Process, name: str) -> None:
"""Attach the process to the given target.""" """Attach the process to the given target."""
exec_convert_errors(f'process attach -n {name}') exec_convert_errors(f'process attach -n {name}')
@REGISTRY.method(display="Detach") @REGISTRY.method(display="Detach")
def detach(process: sch.Schema('Process')): def detach(process: Process) -> None:
"""Detach the process's target.""" """Detach the process's target."""
exec_convert_errors(f'process detach') exec_convert_errors(f'process detach')
def do_launch(process, file, args, cmd): def do_launch(process: Process, file: str, args: str, cmd: str):
exec_convert_errors(f'file {file}') exec_convert_errors(f'file {file}')
if args != '': if args != '':
exec_convert_errors(f'settings set target.run-args {args}') exec_convert_errors(f'settings set target.run-args {args}')
@ -412,11 +476,10 @@ def do_launch(process, file, args, cmd):
@REGISTRY.method(action='launch', display="Launch at Entry") @REGISTRY.method(action='launch', display="Launch at Entry")
def launch_loader(process: sch.Schema('Process'), def launch_loader(process: Process,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '') -> None:
""" """Start a native process with the given command line, stopping at 'main'.
Start a native process with the given command line, stopping at 'main'.
If 'main' is not defined in the file, this behaves like 'run'. If 'main' is not defined in the file, this behaves like 'run'.
""" """
@ -424,32 +487,31 @@ def launch_loader(process: sch.Schema('Process'),
@REGISTRY.method(action='launch', display="Launch and Run") @REGISTRY.method(action='launch', display="Launch and Run")
def launch(process: sch.Schema('Process'), def launch(process: Process,
file: ParamDesc(str, display='File'), file: Annotated[str, ParamDesc(display='File')],
args: ParamDesc(str, display='Arguments')=''): args: Annotated[str, ParamDesc(display='Arguments')] = '') -> None:
""" """Run a native process with the given command line.
Run a native process with the given command line.
The process will not stop until it hits one of your breakpoints, or it is The process will not stop until it hits one of your breakpoints, or
signaled. it is signaled.
""" """
do_launch(process, file, args, 'run') do_launch(process, file, args, 'run')
@REGISTRY.method @REGISTRY.method()
def kill(process: sch.Schema('Process')): def kill(process: Process) -> None:
"""Kill execution of the process.""" """Kill execution of the process."""
exec_convert_errors('process kill') exec_convert_errors('process kill')
@REGISTRY.method(name='continue', action='resume', display="Continue") @REGISTRY.method(name='continue', action='resume', display="Continue")
def _continue(process: sch.Schema('Process')): def _continue(process: Process):
"""Continue execution of the process.""" """Continue execution of the process."""
exec_convert_errors('process continue') exec_convert_errors('process continue')
@REGISTRY.method @REGISTRY.method()
def interrupt(process: sch.Schema('Process')): def interrupt(process: Process):
"""Interrupt the execution of the debugged program.""" """Interrupt the execution of the debugged program."""
exec_convert_errors('process interrupt') exec_convert_errors('process interrupt')
# util.get_process().SendAsyncInterrupt() # util.get_process().SendAsyncInterrupt()
@ -458,7 +520,8 @@ def interrupt(process: sch.Schema('Process')):
@REGISTRY.method(action='step_into') @REGISTRY.method(action='step_into')
def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_into(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step on instruction exactly.""" """Step on instruction exactly."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
@ -466,7 +529,8 @@ def step_into(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1):
@REGISTRY.method(action='step_over') @REGISTRY.method(action='step_over')
def step_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1): def step_over(thread: Thread,
n: Annotated[int, ParamDesc(display='N')] = 1) -> None:
"""Step one instruction, but proceed through subroutine calls.""" """Step one instruction, but proceed through subroutine calls."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
@ -474,7 +538,7 @@ def step_over(thread: sch.Schema('Thread'), n: ParamDesc(int, display='N')=1):
@REGISTRY.method(action='step_out') @REGISTRY.method(action='step_out')
def step_out(thread: sch.Schema('Thread')): def step_out(thread: Thread) -> None:
"""Execute until the current stack frame returns.""" """Execute until the current stack frame returns."""
if thread is not None: if thread is not None:
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
@ -483,16 +547,16 @@ def step_out(thread: sch.Schema('Thread')):
@REGISTRY.method(action='step_ext', display="Advance") @REGISTRY.method(action='step_ext', display="Advance")
def step_advance(thread: sch.Schema('Thread'), address: Address): def step_advance(thread: Thread, address: Address) -> None:
"""Continue execution up to the given address.""" """Continue execution up to the given address."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
offset = thread.trace.memory_mapper.map_back(t.process, address) offset = thread.trace.extra.require_mm().map_back(t.process, address)
exec_convert_errors(f'thread until -a {offset}') exec_convert_errors(f'thread until -a {offset}')
@REGISTRY.method(action='step_ext', display="Return") @REGISTRY.method(action='step_ext', display="Return")
def step_return(thread: sch.Schema('Thread'), value: int=None): def step_return(thread: Thread, value: Optional[int] = None) -> None:
"""Skip the remainder of the current function.""" """Skip the remainder of the current function."""
t = find_thread_by_obj(thread) t = find_thread_by_obj(thread)
t.process.SetSelectedThread(t) t.process.SetSelectedThread(t)
@ -503,10 +567,10 @@ def step_return(thread: sch.Schema('Thread'), value: int=None):
@REGISTRY.method(action='break_sw_execute') @REGISTRY.method(action='break_sw_execute')
def break_address(process: sch.Schema('Process'), address: Address): def break_address(process: Process, address: Address) -> None:
"""Set a breakpoint.""" """Set a breakpoint."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset = process.trace.memory_mapper.map_back(proc, address) offset = process.trace.extra.require_mm().map_back(proc, address)
exec_convert_errors(f'breakpoint set -a 0x{offset:x}') exec_convert_errors(f'breakpoint set -a 0x{offset:x}')
@ -518,25 +582,25 @@ def break_expression(expression: str):
@REGISTRY.method(action='break_hw_execute') @REGISTRY.method(action='break_hw_execute')
def break_hw_address(process: sch.Schema('Process'), address: Address): def break_hw_address(process: Process, address: Address) -> None:
"""Set a hardware-assisted breakpoint.""" """Set a hardware-assisted breakpoint."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset = process.trace.memory_mapper.map_back(proc, address) offset = process.trace.extra.require_mm().map_back(proc, address)
exec_convert_errors(f'breakpoint set -H -a 0x{offset:x}') exec_convert_errors(f'breakpoint set -H -a 0x{offset:x}')
@REGISTRY.method(action='break_ext', display='Set Hardware Breakpoint') @REGISTRY.method(action='break_ext', display='Set Hardware Breakpoint')
def break_hw_expression(expression: str): def break_hw_expression(expression: str) -> None:
"""Set a hardware-assisted breakpoint.""" """Set a hardware-assisted breakpoint."""
# TODO: Escape? # TODO: Escape?
exec_convert_errors(f'breakpoint set -H -name {expression}') exec_convert_errors(f'breakpoint set -H -name {expression}')
@REGISTRY.method(action='break_read') @REGISTRY.method(action='break_read')
def break_read_range(process: sch.Schema('Process'), range: AddressRange): def break_read_range(process: Process, range: AddressRange) -> None:
"""Set a read watchpoint.""" """Set a read watchpoint."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
proc, Address(range.space, range.min)) proc, Address(range.space, range.min))
sz = range.length() sz = range.length()
exec_convert_errors( exec_convert_errors(
@ -544,7 +608,7 @@ def break_read_range(process: sch.Schema('Process'), range: AddressRange):
@REGISTRY.method(action='break_ext', display='Set Read Watchpoint') @REGISTRY.method(action='break_ext', display='Set Read Watchpoint')
def break_read_expression(expression: str, size=None): def break_read_expression(expression: str, size: Optional[str] = None) -> None:
"""Set a read watchpoint.""" """Set a read watchpoint."""
size_part = '' if size is None else f'-s {size}' size_part = '' if size is None else f'-s {size}'
exec_convert_errors( exec_convert_errors(
@ -552,10 +616,10 @@ def break_read_expression(expression: str, size=None):
@REGISTRY.method(action='break_write') @REGISTRY.method(action='break_write')
def break_write_range(process: sch.Schema('Process'), range: AddressRange): def break_write_range(process: Process, range: AddressRange) -> None:
"""Set a watchpoint.""" """Set a watchpoint."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
proc, Address(range.space, range.min)) proc, Address(range.space, range.min))
sz = range.length() sz = range.length()
exec_convert_errors( exec_convert_errors(
@ -563,7 +627,7 @@ def break_write_range(process: sch.Schema('Process'), range: AddressRange):
@REGISTRY.method(action='break_ext', display='Set Watchpoint') @REGISTRY.method(action='break_ext', display='Set Watchpoint')
def break_write_expression(expression: str, size=None): def break_write_expression(expression: str, size: Optional[str] = None) -> None:
"""Set a watchpoint.""" """Set a watchpoint."""
size_part = '' if size is None else f'-s {size}' size_part = '' if size is None else f'-s {size}'
exec_convert_errors( exec_convert_errors(
@ -571,10 +635,10 @@ def break_write_expression(expression: str, size=None):
@REGISTRY.method(action='break_access') @REGISTRY.method(action='break_access')
def break_access_range(process: sch.Schema('Process'), range: AddressRange): def break_access_range(process: Process, range: AddressRange) -> None:
"""Set a read/write watchpoint.""" """Set a read/write watchpoint."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
proc, Address(range.space, range.min)) proc, Address(range.space, range.min))
sz = range.length() sz = range.length()
exec_convert_errors( exec_convert_errors(
@ -582,7 +646,8 @@ def break_access_range(process: sch.Schema('Process'), range: AddressRange):
@REGISTRY.method(action='break_ext', display='Set Read/Write Watchpoint') @REGISTRY.method(action='break_ext', display='Set Read/Write Watchpoint')
def break_access_expression(expression: str, size=None): def break_access_expression(expression: str,
size: Optional[str] = None) -> None:
"""Set a read/write watchpoint.""" """Set a read/write watchpoint."""
size_part = '' if size is None else f'-s {size}' size_part = '' if size is None else f'-s {size}'
exec_convert_errors( exec_convert_errors(
@ -590,13 +655,13 @@ def break_access_expression(expression: str, size=None):
@REGISTRY.method(action='break_ext', display="Break on Exception") @REGISTRY.method(action='break_ext', display="Break on Exception")
def break_exception(lang: str): def break_exception(lang: str) -> None:
"""Set a catchpoint.""" """Set a catchpoint."""
exec_convert_errors(f'breakpoint set -E {lang}') exec_convert_errors(f'breakpoint set -E {lang}')
@REGISTRY.method(action='toggle', display='Toggle Watchpoint') @REGISTRY.method(action='toggle', display='Toggle Watchpoint')
def toggle_watchpoint(watchpoint: sch.Schema('WatchpointSpec'), enabled: bool): def toggle_watchpoint(watchpoint: WatchpointSpec, enabled: bool) -> None:
"""Toggle a watchpoint.""" """Toggle a watchpoint."""
wpt = find_wpt_by_obj(watchpoint) wpt = find_wpt_by_obj(watchpoint)
wpt.enabled = enabled wpt.enabled = enabled
@ -605,7 +670,7 @@ def toggle_watchpoint(watchpoint: sch.Schema('WatchpointSpec'), enabled: bool):
@REGISTRY.method(action='toggle', display='Toggle Breakpoint') @REGISTRY.method(action='toggle', display='Toggle Breakpoint')
def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool): def toggle_breakpoint(breakpoint: BreakpointSpec, enabled: bool) -> None:
"""Toggle a breakpoint.""" """Toggle a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
cmd = 'enable' if enabled else 'disable' cmd = 'enable' if enabled else 'disable'
@ -613,7 +678,8 @@ def toggle_breakpoint(breakpoint: sch.Schema('BreakpointSpec'), enabled: bool):
@REGISTRY.method(action='toggle', display='Toggle Breakpoint Location') @REGISTRY.method(action='toggle', display='Toggle Breakpoint Location')
def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabled: bool): def toggle_breakpoint_location(location: BreakpointLocation,
enabled: bool) -> None:
"""Toggle a breakpoint location.""" """Toggle a breakpoint location."""
bptnum, locnum = find_bptlocnum_by_obj(location) bptnum, locnum = find_bptlocnum_by_obj(location)
cmd = 'enable' if enabled else 'disable' cmd = 'enable' if enabled else 'disable'
@ -621,7 +687,7 @@ def toggle_breakpoint_location(location: sch.Schema('BreakpointLocation'), enabl
@REGISTRY.method(action='delete', display='Delete Watchpoint') @REGISTRY.method(action='delete', display='Delete Watchpoint')
def delete_watchpoint(watchpoint: sch.Schema('WatchpointSpec')): def delete_watchpoint(watchpoint: WatchpointSpec) -> None:
"""Delete a watchpoint.""" """Delete a watchpoint."""
wpt = find_wpt_by_obj(watchpoint) wpt = find_wpt_by_obj(watchpoint)
wptnum = wpt.GetID() wptnum = wpt.GetID()
@ -629,18 +695,18 @@ def delete_watchpoint(watchpoint: sch.Schema('WatchpointSpec')):
@REGISTRY.method(action='delete', display='Delete Breakpoint') @REGISTRY.method(action='delete', display='Delete Breakpoint')
def delete_breakpoint(breakpoint: sch.Schema('BreakpointSpec')): def delete_breakpoint(breakpoint: BreakpointSpec) -> None:
"""Delete a breakpoint.""" """Delete a breakpoint."""
bpt = find_bpt_by_obj(breakpoint) bpt = find_bpt_by_obj(breakpoint)
bptnum = bpt.GetID() bptnum = bpt.GetID()
exec_convert_errors(f'breakpoint delete {bptnum}') exec_convert_errors(f'breakpoint delete {bptnum}')
@REGISTRY.method @REGISTRY.method()
def read_mem(process: sch.Schema('Process'), range: AddressRange): def read_mem(process: Process, range: AddressRange) -> None:
"""Read memory.""" """Read memory."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset_start = process.trace.memory_mapper.map_back( offset_start = process.trace.extra.require_mm().map_back(
proc, Address(range.space, range.min)) proc, Address(range.space, range.min))
ci = util.get_debugger().GetCommandInterpreter() ci = util.get_debugger().GetCommandInterpreter()
with commands.open_tracked_tx('Read Memory'): with commands.open_tracked_tx('Read Memory'):
@ -654,22 +720,21 @@ def read_mem(process: sch.Schema('Process'), range: AddressRange):
f'ghidra trace putmem-state 0x{offset_start:x} {range.length()} error') f'ghidra trace putmem-state 0x{offset_start:x} {range.length()} error')
@REGISTRY.method @REGISTRY.method()
def write_mem(process: sch.Schema('Process'), address: Address, data: bytes): def write_mem(process: Process, address: Address, data: bytes) -> None:
"""Write memory.""" """Write memory."""
proc = find_proc_by_obj(process) proc = find_proc_by_obj(process)
offset = process.trace.memory_mapper.map_back(proc, address) offset = process.trace.extra.require_mm().map_back(proc, address)
proc.write_memory(offset, data) proc.write_memory(offset, data)
@REGISTRY.method @REGISTRY.method()
def write_reg(frame: sch.Schema('StackFrame'), name: str, value: bytes): def write_reg(frame: StackFrame, name: str, value: bytes) -> None:
"""Write a register.""" """Write a register."""
f = find_frame_by_obj(frame) f = find_frame_by_obj(frame)
f.select() f.select()
proc = lldb.selected_process() proc = lldb.selected_process()
mname, mval = frame.trace.register_mapper.map_value_back(proc, name, value) mname, mval = frame.trace.extra.require_rm().map_value_back(proc, name, value)
reg = find_reg_by_name(f, mname)
size = int(lldb.parse_and_eval(f'sizeof(${mname})')) size = int(lldb.parse_and_eval(f'sizeof(${mname})'))
arr = '{' + ','.join(str(b) for b in mval) + '}' arr = '{' + ','.join(str(b) for b in mval) + '}'
exec_convert_errors( exec_convert_errors(

View file

@ -14,17 +14,24 @@
# limitations under the License. # limitations under the License.
## ##
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass
import os import os
import re import re
import sys import sys
from typing import Any, Dict, List, Optional, Union
import lldb import lldb
LldbVersion = namedtuple('LldbVersion', ['display', 'full', 'major', 'minor']) @dataclass(frozen=True)
class LldbVersion:
display: str
full: str
major: int
minor: int
def _compute_lldb_ver(): def _compute_lldb_ver() -> LldbVersion:
blurb = lldb.debugger.GetVersionString() blurb = lldb.debugger.GetVersionString()
top = blurb.split('\n')[0] top = blurb.split('\n')[0]
if ' version ' in top: if ' version ' in top:
@ -40,12 +47,15 @@ LLDB_VERSION = _compute_lldb_ver()
GNU_DEBUGDATA_PREFIX = ".gnu_debugdata for " GNU_DEBUGDATA_PREFIX = ".gnu_debugdata for "
class Module(namedtuple('BaseModule', ['name', 'base', 'max', 'sections'])): @dataclass
pass class Section:
name: str
start: int
end: int
offset: int
attrs: List[str]
def better(self, other: 'Section') -> 'Section':
class Section(namedtuple('BaseSection', ['name', 'start', 'end', 'offset', 'attrs'])):
def better(self, other):
start = self.start if self.start != 0 else other.start start = self.start if self.start != 0 else other.start
end = self.end if self.end != 0 else other.end end = self.end if self.end != 0 else other.end
offset = self.offset if self.offset != 0 else other.offset offset = self.offset if self.offset != 0 else other.offset
@ -54,18 +64,17 @@ class Section(namedtuple('BaseSection', ['name', 'start', 'end', 'offset', 'attr
return Section(self.name, start, end, offset, list(attrs)) return Section(self.name, start, end, offset, list(attrs))
@dataclass(frozen=True)
class Module:
name: str
base: int
max: int
sections: Dict[str, Section]
# AFAICT, Objfile does not give info about load addresses :( # AFAICT, Objfile does not give info about load addresses :(
class ModuleInfoReader(object): class ModuleInfoReader(object):
def name_from_line(self, line): def section_from_sbsection(self, s: lldb.SBSection) -> Section:
mat = self.objfile_pattern.fullmatch(line)
if mat is None:
return None
n = mat['name']
if n.startswith(GNU_DEBUGDATA_PREFIX):
return None
return None if mat is None else mat['name']
def section_from_sbsection(self, s):
start = s.GetLoadAddress(get_target()) start = s.GetLoadAddress(get_target())
if start >= sys.maxsize*2: if start >= sys.maxsize*2:
start = 0 start = 0
@ -75,7 +84,7 @@ class ModuleInfoReader(object):
attrs = s.GetPermissions() attrs = s.GetPermissions()
return Section(name, start, end, offset, attrs) return Section(name, start, end, offset, attrs)
def finish_module(self, name, sections): def finish_module(self, name: str, sections: Dict[str, Section]) -> Module:
alloc = {k: s for k, s in sections.items()} alloc = {k: s for k, s in sections.items()}
if len(alloc) == 0: if len(alloc) == 0:
return Module(name, 0, 0, alloc) return Module(name, 0, 0, alloc)
@ -91,10 +100,10 @@ class ModuleInfoReader(object):
max_addr = max(s.end for s in alloc.values()) max_addr = max(s.end for s in alloc.values())
return Module(name, base_addr, max_addr, alloc) return Module(name, base_addr, max_addr, alloc)
def get_modules(self): def get_modules(self) -> Dict[str, Module]:
modules = {} modules = {}
name = None name = None
sections = {} sections: Dict[str, Section] = {}
for i in range(0, get_target().GetNumModules()): for i in range(0, get_target().GetNumModules()):
module = get_target().GetModuleAtIndex(i) module = get_target().GetModuleAtIndex(i)
fspec = module.GetFileSpec() fspec = module.GetFileSpec()
@ -108,19 +117,24 @@ class ModuleInfoReader(object):
return modules return modules
def _choose_module_info_reader(): def _choose_module_info_reader() -> ModuleInfoReader:
return ModuleInfoReader() return ModuleInfoReader()
MODULE_INFO_READER = _choose_module_info_reader() MODULE_INFO_READER = _choose_module_info_reader()
class Region(namedtuple('BaseRegion', ['start', 'end', 'offset', 'perms', 'objfile'])): @dataclass
pass class Region:
start: int
end: int
offset: int
perms: Optional[str]
objfile: str
class RegionInfoReader(object): class RegionInfoReader(object):
def region_from_sbmemreg(self, info): def region_from_sbmemreg(self, info: lldb.SBMemoryRegionInfo) -> Region:
start = info.GetRegionBase() start = info.GetRegionBase()
end = info.GetRegionEnd() end = info.GetRegionEnd()
offset = info.GetRegionBase() offset = info.GetRegionBase()
@ -136,7 +150,7 @@ class RegionInfoReader(object):
objfile = info.GetName() objfile = info.GetName()
return Region(start, end, offset, perms, objfile) return Region(start, end, offset, perms, objfile)
def get_regions(self): def get_regions(self) -> List[Region]:
regions = [] regions = []
reglist = get_process().GetMemoryRegions() reglist = get_process().GetMemoryRegions()
for i in range(0, reglist.GetSize()): for i in range(0, reglist.GetSize()):
@ -148,7 +162,7 @@ class RegionInfoReader(object):
regions.append(r) regions.append(r)
return regions return regions
def full_mem(self): def full_mem(self) -> Region:
# TODO: This may not work for Harvard architectures # TODO: This may not work for Harvard architectures
try: try:
sizeptr = int(parse_and_eval('sizeof(void*)')) * 8 sizeptr = int(parse_and_eval('sizeof(void*)')) * 8
@ -157,7 +171,7 @@ class RegionInfoReader(object):
return Region(0, 1 << 64, 0, None, 'full memory') return Region(0, 1 << 64, 0, None, 'full memory')
def _choose_region_info_reader(): def _choose_region_info_reader() -> RegionInfoReader:
return RegionInfoReader() return RegionInfoReader()
@ -169,68 +183,69 @@ BREAK_PATTERN = re.compile('')
BREAK_LOC_PATTERN = re.compile('') BREAK_LOC_PATTERN = re.compile('')
class BreakpointLocation(namedtuple('BaseBreakpointLocation', ['address', 'enabled', 'thread_groups'])):
pass
class BreakpointLocationInfoReader(object): class BreakpointLocationInfoReader(object):
def get_locations(self, breakpoint): def get_locations(self, breakpoint: lldb.SBBreakpoint) -> List[
lldb.SBBreakpointLocation]:
return breakpoint.locations return breakpoint.locations
def _choose_breakpoint_location_info_reader(): def _choose_breakpoint_location_info_reader() -> BreakpointLocationInfoReader:
return BreakpointLocationInfoReader() return BreakpointLocationInfoReader()
BREAKPOINT_LOCATION_INFO_READER = _choose_breakpoint_location_info_reader() BREAKPOINT_LOCATION_INFO_READER = _choose_breakpoint_location_info_reader()
def get_debugger(): def get_debugger() -> lldb.SBDebugger:
return lldb.SBDebugger.FindDebuggerWithID(1) return lldb.SBDebugger.FindDebuggerWithID(1)
def get_target(): def get_target() -> lldb.SBTarget:
return get_debugger().GetTargetAtIndex(0) return get_debugger().GetTargetAtIndex(0)
def get_process(): def get_process() -> lldb.SBProcess:
return get_target().GetProcess() return get_target().GetProcess()
def selected_thread(): def selected_thread() -> lldb.SBThread:
return get_process().GetSelectedThread() return get_process().GetSelectedThread()
def selected_frame(): def selected_frame() -> lldb.SBFrame:
return selected_thread().GetSelectedFrame() return selected_thread().GetSelectedFrame()
def parse_and_eval(expr, signed=False): def parse_and_eval(expr: str, signed: bool = False) -> int:
if signed is True: if signed is True:
return get_eval(expr).GetValueAsSigned() return get_eval(expr).GetValueAsSigned()
return get_eval(expr).GetValueAsUnsigned() return get_eval(expr).GetValueAsUnsigned()
def get_eval(expr): def get_eval(expr: str) -> lldb.SBValue:
eval = get_target().EvaluateExpression(expr) eval = get_target().EvaluateExpression(expr)
if eval.GetError().Fail(): if eval.GetError().Fail():
raise ValueError(eval.GetError().GetCString()) raise ValueError(eval.GetError().GetCString())
return eval return eval
def get_description(object, level=None): def get_description(object: Union[
lldb.SBThread, lldb.SBBreakpoint, lldb.SBWatchpoint, lldb.SBEvent],
level: Optional[int] = None) -> str:
stream = lldb.SBStream() stream = lldb.SBStream()
if level is None: if level is None:
object.GetDescription(stream) object.GetDescription(stream)
else: elif isinstance(object, lldb.SBWatchpoint):
object.GetDescription(stream, level) object.GetDescription(stream, level)
else:
raise ValueError(f"Object {object} does not support description level")
return escape_ansi(stream.GetData()) return escape_ansi(stream.GetData())
conv_map = {} conv_map: Dict[str, str] = {}
def get_convenience_variable(id): def get_convenience_variable(id: str) -> str:
# val = get_target().GetEnvironment().Get(id) # val = get_target().GetEnvironment().Get(id)
if id not in conv_map: if id not in conv_map:
return "auto" return "auto"
@ -240,21 +255,21 @@ def get_convenience_variable(id):
return val return val
def set_convenience_variable(id, value): def set_convenience_variable(id: str, value: str) -> None:
# env = get_target().GetEnvironment() # env = get_target().GetEnvironment()
# return env.Set(id, value, True) # return env.Set(id, value, True)
conv_map[id] = value conv_map[id] = value
def escape_ansi(line): def escape_ansi(line: str) -> str:
ansi_escape = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]') ansi_escape = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]')
return ansi_escape.sub('', line) return ansi_escape.sub('', line)
def debracket(init): def debracket(init: Optional[str]) -> str:
val = init
if init is None: if init is None:
return "" return ""
val = init
val = val.replace("[", "(") val = val.replace("[", "(")
val = val.replace("]", ")") val = val.replace("]", ")")
return val return val

View file

@ -43,6 +43,7 @@ import ghidra.trace.model.program.TraceVariableSnapProgramView;
import ghidra.trace.model.thread.TraceThread; import ghidra.trace.model.thread.TraceThread;
import ghidra.trace.model.time.schedule.PatchStep; import ghidra.trace.model.time.schedule.PatchStep;
import ghidra.trace.model.time.schedule.TraceSchedule; import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
import ghidra.trace.util.TraceRegisterUtils; import ghidra.trace.util.TraceRegisterUtils;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -389,22 +390,47 @@ public enum ControlMode {
*/ */
public DebuggerCoordinates validateCoordinates(PluginTool tool, public DebuggerCoordinates validateCoordinates(PluginTool tool,
DebuggerCoordinates coordinates, ActivationCause cause) { DebuggerCoordinates coordinates, ActivationCause cause) {
if (!followsPresent()) { if (!followsPresent() || cause != ActivationCause.USER) {
return coordinates; return coordinates;
} }
Target target = coordinates.getTarget(); Target target = coordinates.getTarget();
if (target == null) { if (target == null) {
return coordinates; return coordinates;
} }
if (cause == ActivationCause.USER &&
(!coordinates.getTime().isSnapOnly() || coordinates.getSnap() != target.getSnap())) { ScheduleForm form =
tool.setStatusInfo( target.getSupportedTimeForm(coordinates.getObject(), coordinates.getSnap());
"Cannot navigate time in %s mode. Switch to Trace or Emulate mode first." if (form == null) {
if (coordinates.getTime().isSnapOnly() &&
coordinates.getSnap() == target.getSnap()) {
return coordinates;
}
else {
tool.setStatusInfo("""
Cannot navigate time in %s mode. Switch to Trace or Emulate mode first."""
.formatted(name), .formatted(name),
true); true);
return null; return null;
} }
return coordinates; }
TraceSchedule norm = form.validate(coordinates.getTrace(), coordinates.getTime());
if (norm != null) {
return coordinates.time(norm);
}
String errMsg = switch (form) {
case SNAP_ONLY -> """
Target can only navigate to snapshots. Switch to Emulate mode first.""";
case SNAP_EVT_STEPS -> """
Target can only replay steps on the event thread. Switch to Emulate mode \
first.""";
case SNAP_ANY_STEPS -> """
Target cannot perform p-code steps. Switch to Emulate mode first.""";
case SNAP_ANY_STEPS_OPS -> throw new AssertionError();
};
tool.setStatusInfo(errMsg, true);
return null;
} }
protected TracePlatform platformFor(DebuggerCoordinates coordinates, Address address) { protected TracePlatform platformFor(DebuggerCoordinates coordinates, Address address) {

View file

@ -34,8 +34,12 @@ import ghidra.trace.model.breakpoint.TraceBreakpointKind;
import ghidra.trace.model.guest.TracePlatform; import ghidra.trace.model.guest.TracePlatform;
import ghidra.trace.model.memory.TraceMemoryState; import ghidra.trace.model.memory.TraceMemoryState;
import ghidra.trace.model.stack.TraceStackFrame; import ghidra.trace.model.stack.TraceStackFrame;
import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.path.KeyPath; import ghidra.trace.model.target.path.KeyPath;
import ghidra.trace.model.thread.TraceThread; import ghidra.trace.model.thread.TraceThread;
import ghidra.trace.model.time.TraceSnapshot;
import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
import ghidra.util.Swing; import ghidra.util.Swing;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -278,13 +282,51 @@ public interface Target {
* Get the current snapshot key for the target * Get the current snapshot key for the target
* *
* <p> * <p>
* For most targets, this is the most recently created snapshot. * For most targets, this is the most recently created snapshot. For time-traveling targets, if
* may not be. If this returns a negative number, then it refers to a scratch snapshot and
* almost certainly indicates time travel with instruction steps. Use {@link #getTime()} in that
* case to get a more precise schedule.
* *
* @return the snapshot * @return the snapshot
*/ */
// TODO: Should this be TraceSchedule getTime()?
long getSnap(); long getSnap();
/**
* Get the current time
*
* @return the current time
*/
default TraceSchedule getTime() {
long snap = getSnap();
if (snap >= 0) {
return TraceSchedule.snap(snap);
}
TraceSnapshot snapshot = getTrace().getTimeManager().getSnapshot(snap, false);
if (snapshot == null) {
return null;
}
return snapshot.getSchedule();
}
/**
* Get the form of schedules supported by "activate" on the back end
*
* <p>
* A non-null return value indicates the back end supports time travel. If it does, the return
* value indicates the form of schedules that can be activated, (i.e., via some "go to time"
* command). NOTE: Switching threads is considered an event by every time-traveling back end
* that we know of. Events are usually mapped to a Ghidra trace's snapshots, and so most back
* ends are constrained to schedules of the form {@link ScheduleForm#SNAP_EVT_STEPS}. A back-end
* based on emulation may support thread switching. To support p-code op stepping, the back-end
* will certainly have to be based on p-code emulation, and it must be using the same Sleigh
* language as Ghidra.
*
* @param obj the object (or an ancestor) that may support time travel
* @param snap the <em>destination</em> snapshot
* @return the form
*/
public ScheduleForm getSupportedTimeForm(TraceObject obj, long snap);
/** /**
* Collect all actions that implement the given common debugger command * Collect all actions that implement the given common debugger command
* *

View file

@ -402,7 +402,16 @@ public class DebuggerCoordinates {
return new DebuggerCoordinates(trace, platform, target, thread, view, newTime, frame, path); return new DebuggerCoordinates(trace, platform, target, thread, view, newTime, frame, path);
} }
/**
* Get these same coordinates with time replaced by the given schedule
*
* @param newTime the new schedule
* @return the new coordinates
*/
public DebuggerCoordinates time(TraceSchedule newTime) { public DebuggerCoordinates time(TraceSchedule newTime) {
if (Objects.equals(time, newTime)) {
return this;
}
if (trace == null) { if (trace == null) {
return NOWHERE; return NOWHERE;
} }

View file

@ -36,27 +36,36 @@ dependencies {
testImplementation project(path: ':Framework-TraceModeling', configuration: 'testArtifacts') testImplementation project(path: ':Framework-TraceModeling', configuration: 'testArtifacts')
} }
task generateProtoPy { task configureGenerateProtoPy {
ext.srcdir = file("src/main/proto")
ext.src = fileTree(srcdir) {
include "**/*.proto"
}
ext.outdir = file("build/generated/source/proto/main/py")
outputs.dir(outdir)
inputs.files(src)
dependsOn(configurations.protocArtifact) dependsOn(configurations.protocArtifact)
doLast { doLast {
def exe = configurations.protocArtifact.first() def exe = configurations.protocArtifact.first()
if (!isCurrentWindows()) { if (!isCurrentWindows()) {
exe.setExecutable(true) exe.setExecutable(true)
} }
providers.exec { generateProtoPy.commandLine exe
commandLine exe, "--python_out=$outdir", "-I$srcdir" generateProtoPy.args "--python_out=${generateProtoPy.outdir}"
args src generateProtoPy.args "--pyi_out=${generateProtoPy.stubsOutdir}"
}.result.get() generateProtoPy.args "-I${generateProtoPy.srcdir}"
generateProtoPy.args generateProtoPy.src
} }
} }
// Can't use providers.exec, or else we see no output
task generateProtoPy(type:Exec) {
dependsOn(configureGenerateProtoPy)
ext.srcdir = file("src/main/proto")
ext.src = fileTree(srcdir) {
include "**/*.proto"
}
ext.outdir = file("build/generated/source/proto/main/py")
ext.stubsOutdir = file("build/generated/source/proto/main/pyi/ghidratrace")
outputs.dir(outdir)
outputs.dir(stubsOutdir)
inputs.files(src)
}
tasks.assemblePyPackage { tasks.assemblePyPackage {
from(generateProtoPy) { from(generateProtoPy) {
into "src/ghidratrace" into "src/ghidratrace"

View file

@ -14,4 +14,5 @@ src/main/help/help/topics/TraceRmiLauncherServicePlugin/images/GdbTerminal.png||
src/main/py/LICENSE||GHIDRA||||END| src/main/py/LICENSE||GHIDRA||||END|
src/main/py/README.md||GHIDRA||||END| src/main/py/README.md||GHIDRA||||END|
src/main/py/pyproject.toml||GHIDRA||||END| src/main/py/pyproject.toml||GHIDRA||||END|
src/main/py/src/ghidratrace/py.typed||GHIDRA||||END|
src/main/py/tests/EMPTY||GHIDRA||||END| src/main/py/tests/EMPTY||GHIDRA||||END|

View file

@ -44,11 +44,16 @@ public abstract class AbstractTraceRmiConnection implements TraceRmiConnection {
protected void doActivate(TraceObject object, Trace trace, TraceSnapshot snapshot) { protected void doActivate(TraceObject object, Trace trace, TraceSnapshot snapshot) {
DebuggerCoordinates coords = getTraceManager().getCurrent(); DebuggerCoordinates coords = getTraceManager().getCurrent();
if (coords.getTrace() != trace) { if (coords.getTrace() != trace) {
coords = DebuggerCoordinates.NOWHERE; coords = DebuggerCoordinates.NOWHERE.trace(trace);
} }
if (snapshot != null && followsPresent(trace)) { if (snapshot != null && followsPresent(trace)) {
if (snapshot.getKey() > 0 || snapshot.getSchedule() == null) {
coords = coords.snap(snapshot.getKey()); coords = coords.snap(snapshot.getKey());
} }
else {
coords = coords.time(snapshot.getSchedule());
}
}
DebuggerCoordinates finalCoords = object == null ? coords : coords.object(object); DebuggerCoordinates finalCoords = object == null ? coords : coords.object(object);
Swing.runLater(() -> { Swing.runLater(() -> {
DebuggerTraceManagerService traceManager = getTraceManager(); DebuggerTraceManagerService traceManager = getTraceManager();
@ -68,5 +73,4 @@ public abstract class AbstractTraceRmiConnection implements TraceRmiConnection {
} }
}); });
} }
} }

View file

@ -26,6 +26,7 @@ import ghidra.rmi.trace.TraceRmi.*;
import ghidra.trace.model.Trace; import ghidra.trace.model.Trace;
import ghidra.trace.model.target.TraceObject; import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.time.TraceSnapshot; import ghidra.trace.model.time.TraceSnapshot;
import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.util.Msg; import ghidra.util.Msg;
class OpenTrace implements ValueDecoder { class OpenTrace implements ValueDecoder {
@ -79,9 +80,16 @@ class OpenTrace implements ValueDecoder {
trace.release(consumer); trace.release(consumer);
} }
public TraceSnapshot createSnapshot(Snap snap, String description) { public TraceSnapshot createSnapshot(long snap) {
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap.getSnap(), true); TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap, true);
snapshot.setDescription(description); return this.lastSnapshot = snapshot;
}
public TraceSnapshot createSnapshot(TraceSchedule schedule) {
if (schedule.isSnapOnly()) {
return createSnapshot(schedule.getSnap());
}
TraceSnapshot snapshot = trace.getTimeManager().findScratchSnapshot(schedule);
return this.lastSnapshot = snapshot; return this.lastSnapshot = snapshot;
} }

View file

@ -61,12 +61,13 @@ import ghidra.trace.model.target.path.*;
import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName; import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName;
import ghidra.trace.model.target.schema.XmlSchemaContext; import ghidra.trace.model.target.schema.XmlSchemaContext;
import ghidra.trace.model.time.TraceSnapshot; import ghidra.trace.model.time.TraceSnapshot;
import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.util.*; import ghidra.util.*;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.exception.DuplicateFileException; import ghidra.util.exception.DuplicateFileException;
public class TraceRmiHandler extends AbstractTraceRmiConnection { public class TraceRmiHandler extends AbstractTraceRmiConnection {
public static final String VERSION = "11.3"; public static final String VERSION = "11.4";
protected static class VersionMismatchError extends TraceRmiError { protected static class VersionMismatchError extends TraceRmiError {
public VersionMismatchError(String remote) { public VersionMismatchError(String remote) {
@ -740,77 +741,43 @@ public class TraceRmiHandler extends AbstractTraceRmiConnection {
} }
protected static Value makeValue(Object value) { protected static Value makeValue(Object value) {
if (value instanceof Void) { return switch (value) {
return Value.newBuilder().setNullValue(Null.getDefaultInstance()).build(); case Void v -> Value.newBuilder().setNullValue(Null.getDefaultInstance()).build();
} case Boolean b -> Value.newBuilder().setBoolValue(b).build();
if (value instanceof Boolean b) { case Byte b -> Value.newBuilder().setByteValue(b).build();
return Value.newBuilder().setBoolValue(b).build(); case Character c -> Value.newBuilder().setCharValue(c).build();
} case Short s -> Value.newBuilder().setShortValue(s).build();
if (value instanceof Byte b) { case Integer i -> Value.newBuilder().setIntValue(i).build();
return Value.newBuilder().setByteValue(b).build(); case Long l -> Value.newBuilder().setLongValue(l).build();
} case String s -> Value.newBuilder().setStringValue(s).build();
if (value instanceof Character c) { case boolean[] ba -> Value.newBuilder()
return Value.newBuilder().setCharValue(c).build();
}
if (value instanceof Short s) {
return Value.newBuilder().setShortValue(s).build();
}
if (value instanceof Integer i) {
return Value.newBuilder().setIntValue(i).build();
}
if (value instanceof Long l) {
return Value.newBuilder().setLongValue(l).build();
}
if (value instanceof String s) {
return Value.newBuilder().setStringValue(s).build();
}
if (value instanceof boolean[] ba) {
return Value.newBuilder()
.setBoolArrValue( .setBoolArrValue(
BoolArr.newBuilder().addAllArr(Arrays.asList(ArrayUtils.toObject(ba)))) BoolArr.newBuilder().addAllArr(Arrays.asList(ArrayUtils.toObject(ba))))
.build(); .build();
} case byte[] ba -> Value.newBuilder().setBytesValue(ByteString.copyFrom(ba)).build();
if (value instanceof byte[] ba) { case char[] ca -> Value.newBuilder().setCharArrValue(new String(ca)).build();
return Value.newBuilder().setBytesValue(ByteString.copyFrom(ba)).build(); case short[] sa -> Value.newBuilder()
}
if (value instanceof char[] ca) {
return Value.newBuilder().setCharArrValue(new String(ca)).build();
}
if (value instanceof short[] sa) {
return Value.newBuilder()
.setShortArrValue(ShortArr.newBuilder() .setShortArrValue(ShortArr.newBuilder()
.addAllArr( .addAllArr(
Stream.of(ArrayUtils.toObject(sa)).map(s -> (int) s).toList())) Stream.of(ArrayUtils.toObject(sa)).map(s -> (int) s).toList()))
.build(); .build();
} case int[] ia -> Value.newBuilder()
if (value instanceof int[] ia) {
return Value.newBuilder()
.setIntArrValue( .setIntArrValue(
IntArr.newBuilder().addAllArr(IntStream.of(ia).mapToObj(i -> i).toList())) IntArr.newBuilder().addAllArr(IntStream.of(ia).mapToObj(i -> i).toList()))
.build(); .build();
} case long[] la -> Value.newBuilder()
if (value instanceof long[] la) {
return Value.newBuilder()
.setLongArrValue( .setLongArrValue(
LongArr.newBuilder().addAllArr(LongStream.of(la).mapToObj(l -> l).toList())) LongArr.newBuilder().addAllArr(LongStream.of(la).mapToObj(l -> l).toList()))
.build(); .build();
} case String[] sa -> Value.newBuilder()
if (value instanceof String[] sa) {
return Value.newBuilder()
.setStringArrValue(StringArr.newBuilder().addAllArr(List.of(sa))) .setStringArrValue(StringArr.newBuilder().addAllArr(List.of(sa)))
.build(); .build();
} case Address a -> Value.newBuilder().setAddressValue(makeAddr(a)).build();
if (value instanceof Address a) { case AddressRange r -> Value.newBuilder().setRangeValue(makeAddrRange(r)).build();
return Value.newBuilder().setAddressValue(makeAddr(a)).build(); case TraceObject o -> Value.newBuilder().setChildDesc(makeObjDesc(o)).build();
} default -> throw new AssertionError(
if (value instanceof AddressRange r) {
return Value.newBuilder().setRangeValue(makeAddrRange(r)).build();
}
if (value instanceof TraceObject o) {
return Value.newBuilder().setChildDesc(makeObjDesc(o)).build();
}
throw new AssertionError(
"Cannot encode value: " + value + "(type=" + value.getClass() + ")"); "Cannot encode value: " + value + "(type=" + value.getClass() + ")");
};
} }
protected static MethodArgument makeArgument(String name, Object value) { protected static MethodArgument makeArgument(String name, Object value) {
@ -958,8 +925,9 @@ public class TraceRmiHandler extends AbstractTraceRmiConnection {
dis.applyTo(open.trace.getFixedProgramView(snap), monitor); dis.applyTo(open.trace.getFixedProgramView(snap), monitor);
} }
AddressSetView result = dis.getDisassembledAddressSet();
return ReplyDisassemble.newBuilder() return ReplyDisassemble.newBuilder()
.setLength(dis.getDisassembledAddressSet().getNumAddresses()) .setLength(result == null ? 0 : result.getNumAddresses())
.build(); .build();
} }
@ -1180,13 +1148,21 @@ public class TraceRmiHandler extends AbstractTraceRmiConnection {
protected ReplySnapshot handleSnapshot(RequestSnapshot req) { protected ReplySnapshot handleSnapshot(RequestSnapshot req) {
OpenTrace open = requireOpenTrace(req.getOid()); OpenTrace open = requireOpenTrace(req.getOid());
TraceSnapshot snapshot = open.createSnapshot(req.getSnap(), req.getDescription()); TraceSnapshot snapshot = switch (req.getTimeCase()) {
case TIME_NOT_SET -> throw new TraceRmiError("snap or time required");
case SNAP -> open.createSnapshot(req.getSnap().getSnap());
case SCHEDULE -> open
.createSnapshot(TraceSchedule.parse(req.getSchedule().getSchedule()));
};
snapshot.setDescription(req.getDescription());
if (!"".equals(req.getDatetime())) { if (!"".equals(req.getDatetime())) {
Instant instant = Instant instant =
DateTimeFormatter.ISO_INSTANT.parse(req.getDatetime()).query(Instant::from); DateTimeFormatter.ISO_INSTANT.parse(req.getDatetime()).query(Instant::from);
snapshot.setRealTime(instant.toEpochMilli()); snapshot.setRealTime(instant.toEpochMilli());
} }
return ReplySnapshot.getDefaultInstance(); return ReplySnapshot.newBuilder()
.setSnap(Snap.newBuilder().setSnap(snapshot.getKey()))
.build();
} }
protected ReplyStartTx handleStartTx(RequestStartTx req) { protected ReplyStartTx handleStartTx(RequestStartTx req) {

View file

@ -61,11 +61,11 @@ import ghidra.trace.model.target.schema.*;
import ghidra.trace.model.target.schema.PrimitiveTraceObjectSchema.MinimalSchemaContext; import ghidra.trace.model.target.schema.PrimitiveTraceObjectSchema.MinimalSchemaContext;
import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName; import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName;
import ghidra.trace.model.thread.*; import ghidra.trace.model.thread.*;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
import ghidra.util.Msg; import ghidra.util.Msg;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
public class TraceRmiTarget extends AbstractTarget { public class TraceRmiTarget extends AbstractTarget {
class TraceRmiActionEntry implements ActionEntry { class TraceRmiActionEntry implements ActionEntry {
private final RemoteMethod method; private final RemoteMethod method;
private final Map<String, Object> args; private final Map<String, Object> args;
@ -169,6 +169,48 @@ public class TraceRmiTarget extends AbstractTarget {
} }
} }
protected ScheduleForm getSupportedTimeFormByMethod(TraceObject obj) {
KeyPath path = obj.getCanonicalPath();
MatchedMethod activate = matches.getBest(ActivateMatcher.class, path, ActionName.ACTIVATE,
ActivateMatcher.makeBySpecificity(obj.getRoot().getSchema(), path));
if (activate == null) {
return null;
}
if (activate.params.get("time") != null) {
return ScheduleForm.SNAP_ANY_STEPS_OPS;
}
if (activate.params.get("snap") != null) {
return ScheduleForm.SNAP_ONLY;
}
return null;
}
protected ScheduleForm getSupportedTimeFormByAttribute(TraceObject obj, long snap) {
TraceObject eventScope = obj.findSuitableInterface(TraceObjectEventScope.class);
if (eventScope == null) {
return null;
}
TraceObjectValue timeSupportStr =
eventScope.getAttribute(snap, TraceObjectEventScope.KEY_TIME_SUPPORT);
if (timeSupportStr == null) {
return null;
}
return ScheduleForm.valueOf(timeSupportStr.castValue());
}
@Override
public ScheduleForm getSupportedTimeForm(TraceObject obj, long snap) {
ScheduleForm byMethod = getSupportedTimeFormByMethod(obj);
if (byMethod == null) {
return null;
}
ScheduleForm byAttr = getSupportedTimeFormByAttribute(obj, snap);
if (byAttr == null) {
return null;
}
return byMethod.intersect(byAttr);
}
@Override @Override
public TraceExecutionState getThreadExecutionState(TraceThread thread) { public TraceExecutionState getThreadExecutionState(TraceThread thread) {
if (!(thread instanceof TraceObjectThread tot)) { if (!(thread instanceof TraceObjectThread tot)) {
@ -385,7 +427,8 @@ public class TraceRmiTarget extends AbstractTarget {
.orElse(null); .orElse(null);
} }
record ParamAndObjectArg(RemoteParameter param, TraceObject obj) {} record ParamAndObjectArg(RemoteParameter param, TraceObject obj) {
}
protected ParamAndObjectArg getFirstObjectArgument(RemoteMethod method, protected ParamAndObjectArg getFirstObjectArgument(RemoteMethod method,
Map<String, Object> args) { Map<String, Object> args) {
@ -828,7 +871,8 @@ public class TraceRmiTarget extends AbstractTarget {
static final List<ToggleBreakMatcher> SPEC = matchers(HAS_SPEC); static final List<ToggleBreakMatcher> SPEC = matchers(HAS_SPEC);
} }
record MatchKey(Class<? extends MethodMatcher> cls, ActionName action, TraceObjectSchema sch) {} record MatchKey(Class<? extends MethodMatcher> cls, ActionName action, TraceObjectSchema sch) {
}
protected class Matches { protected class Matches {
private final Map<MatchKey, MatchedMethod> map = new HashMap<>(); private final Map<MatchKey, MatchedMethod> map = new HashMap<>();
@ -975,7 +1019,8 @@ public class TraceRmiTarget extends AbstractTarget {
@Override @Override
public CompletableFuture<Void> activateAsync(DebuggerCoordinates prev, public CompletableFuture<Void> activateAsync(DebuggerCoordinates prev,
DebuggerCoordinates coords) { DebuggerCoordinates coords) {
if (prev.getSnap() != coords.getSnap()) { boolean timeNeq = !Objects.equals(prev.getTime(), coords.getTime());
if (timeNeq) {
requestCaches.invalidate(); requestCaches.invalidate();
} }
TraceObject object = coords.getObject(); TraceObject object = coords.getObject();
@ -983,10 +1028,9 @@ public class TraceRmiTarget extends AbstractTarget {
return AsyncUtils.nil(); return AsyncUtils.nil();
} }
MatchedMethod activate = KeyPath path = object.getCanonicalPath();
matches.getBest(ActivateMatcher.class, object.getCanonicalPath(), ActionName.ACTIVATE, MatchedMethod activate = matches.getBest(ActivateMatcher.class, path, ActionName.ACTIVATE,
() -> ActivateMatcher.makeBySpecificity(trace.getObjectManager().getRootSchema(), ActivateMatcher.makeBySpecificity(object.getRoot().getSchema(), path));
object.getCanonicalPath()));
if (activate == null) { if (activate == null) {
return AsyncUtils.nil(); return AsyncUtils.nil();
} }
@ -996,11 +1040,11 @@ public class TraceRmiTarget extends AbstractTarget {
args.put(paramFocus.name(), args.put(paramFocus.name(),
object.findSuitableSchema(getSchemaContext().getSchema(paramFocus.type()))); object.findSuitableSchema(getSchemaContext().getSchema(paramFocus.type())));
RemoteParameter paramTime = activate.params.get("time"); RemoteParameter paramTime = activate.params.get("time");
if (paramTime != null) { if (paramTime != null && (paramTime.required() || timeNeq)) {
args.put(paramTime.name(), coords.getTime().toString()); args.put(paramTime.name(), coords.getTime().toString());
} }
RemoteParameter paramSnap = activate.params.get("snap"); RemoteParameter paramSnap = activate.params.get("snap");
if (paramSnap != null) { if (paramSnap != null && (paramSnap.required() || timeNeq)) {
args.put(paramSnap.name(), coords.getSnap()); args.put(paramSnap.name(), coords.getSnap());
} }
return activate.method.invokeAsync(args).toCompletableFuture().thenApply(__ -> null); return activate.method.invokeAsync(args).toCompletableFuture().thenApply(__ -> null);

View file

@ -56,6 +56,10 @@ message Snap {
int64 snap = 1; int64 snap = 1;
} }
message Schedule {
string schedule = 1;
}
message Span { message Span {
int64 min = 1; int64 min = 1;
int64 max = 2; int64 max = 2;
@ -392,10 +396,14 @@ message RequestSnapshot {
DomObjId oid = 1; DomObjId oid = 1;
string description = 2; string description = 2;
string datetime = 3; string datetime = 3;
oneof time {
Snap snap = 4; Snap snap = 4;
Schedule schedule = 5;
}
} }
message ReplySnapshot { message ReplySnapshot {
Snap snap = 1;
} }
// Client commands // Client commands

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "ghidratrace" name = "ghidratrace"
version = "11.3" version = "11.4"
authors = [ authors = [
{ name="Ghidra Development Team" }, { name="Ghidra Development Team" },
] ]
@ -23,3 +23,6 @@ dependencies = [
[project.urls] [project.urls]
"Homepage" = "https://github.com/NationalSecurityAgency/ghidra" "Homepage" = "https://github.com/NationalSecurityAgency/ghidra"
"Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues" "Bug Tracker" = "https://github.com/NationalSecurityAgency/ghidra/issues"
[tool.setuptools.package-data]
ghidratrace = ["py.typed"]

View file

@ -0,0 +1,114 @@
## ###
# IP: GHIDRA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from concurrent.futures import Future
from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union
from .client import Address, TraceObject, TraceObjectValue
T = TypeVar('T')
def wait_opt(val: Union[T, Future[T], None]) -> Optional[T]:
if val is None:
return None
if isinstance(val, Future):
return val.result()
return val
def wait(val: Union[T, Future[T]]) -> T:
if isinstance(val, Future):
return val.result()
return val
class TableColumn(object):
def __init__(self, head: str) -> None:
self.head = head
self.contents = [head]
self.is_last = False
def add_data(self, data: str) -> None:
self.contents.append(data)
def finish(self) -> None:
self.width = max(len(d) for d in self.contents) + 1
def format_cell(self, i: int) -> str:
return (self.contents[i] if self.is_last
else self.contents[i].ljust(self.width))
class Tabular(object):
def __init__(self, heads: List[str]) -> None:
self.columns = [TableColumn(h) for h in heads]
self.columns[-1].is_last = True
self.num_rows = 1
def add_row(self, datas: List[str]) -> None:
for c, d in zip(self.columns, datas):
c.add_data(d)
self.num_rows += 1
def print_table(self, println: Callable[[str], None]) -> None:
for c in self.columns:
c.finish()
for rn in range(self.num_rows):
println(''.join(c.format_cell(rn) for c in self.columns))
def repr_or_future(val: Union[T, Future[T]]) -> str:
if isinstance(val, Future):
if val.done():
return str(val.result())
else:
return "<Future>"
else:
return str(val)
def obj_repr(obj: TraceObject) -> str:
if obj.path is None:
if obj.id is None:
return "<ERR: no path nor id>"
else:
return f"<id={repr_or_future(obj.id)}>"
elif isinstance(obj.path, Future):
if obj.path.done():
return obj.path.result()
elif obj.id is None:
return "<path=<Future>>"
else:
return f"<id={repr_or_future(obj.id)}>"
else:
return obj.path
def val_repr(value: Any) -> str:
if isinstance(value, TraceObject):
return obj_repr(value)
elif isinstance(value, Address):
return f'{value.space}:{value.offset:08x}'
return repr(value)
def print_tabular_values(values: Sequence[TraceObjectValue],
println: Callable[[str], None]) -> None:
table = Tabular(['Parent', 'Key', 'Span', 'Value', 'Type'])
for v in values:
table.add_row([obj_repr(v.parent), v.key, str(v.span),
val_repr(v.value), str(v.schema)])
table.print_table(println)

View file

@ -16,7 +16,6 @@
from dataclasses import dataclass from dataclasses import dataclass
# Use instances as type annotations or as schema
@dataclass(frozen=True) @dataclass(frozen=True)
class Schema: class Schema:
name: str name: str
@ -25,6 +24,7 @@ class Schema:
return self.name return self.name
UNSPECIFIED = Schema('')
ANY = Schema('ANY') ANY = Schema('ANY')
OBJECT = Schema('OBJECT') OBJECT = Schema('OBJECT')
VOID = Schema('VOID') VOID = Schema('VOID')

View file

@ -13,21 +13,26 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
## ##
from concurrent.futures import Future
import socket import socket
import traceback from typing import TypeVar
from google.protobuf import message as _message
M = TypeVar('M', bound=_message.Message)
def send_length(s, value): def send_length(s: socket.socket, value: int) -> None:
s.sendall(value.to_bytes(4, 'big')) s.sendall(value.to_bytes(4, 'big'))
def send_delimited(s, msg): def send_delimited(s: socket.socket, msg: _message.Message) -> None:
data = msg.SerializeToString() data = msg.SerializeToString()
send_length(s, len(data)) send_length(s, len(data))
s.sendall(data) s.sendall(data)
def recv_all(s, size): def recv_all(s, size: int) -> bytes:
buf = b'' buf = b''
while len(buf) < size: while len(buf) < size:
part = s.recv(size - len(buf)) part = s.recv(size - len(buf))
@ -38,14 +43,14 @@ def recv_all(s, size):
# return s.recv(size, socket.MSG_WAITALL) # return s.recv(size, socket.MSG_WAITALL)
def recv_length(s): def recv_length(s: socket.socket) -> int:
buf = recv_all(s, 4) buf = recv_all(s, 4)
if len(buf) < 4: if len(buf) < 4:
raise Exception("Socket closed") raise Exception("Socket closed")
return int.from_bytes(buf, 'big') return int.from_bytes(buf, 'big')
def recv_delimited(s, msg, dbg_seq): def recv_delimited(s: socket.socket, msg: M, dbg_seq: int) -> M:
size = recv_length(s) size = recv_length(s)
buf = recv_all(s, size) buf = recv_all(s, size)
if len(buf) < size: if len(buf) < size:

View file

@ -513,8 +513,9 @@ public class ObjectTableModel extends AbstractQueryTableModel<ValueRow> {
} }
protected Lifespan computeFullRange() { protected Lifespan computeFullRange() {
Long max = getTrace() == null ? null : getTrace().getTimeManager().getMaxSnap(); Long maxBoxed = getTrace() == null ? null : getTrace().getTimeManager().getMaxSnap();
return Lifespan.span(0L, max == null ? 1 : max + 1); long max = maxBoxed == null ? 0 : maxBoxed;
return Lifespan.span(0L, max == Lifespan.DOMAIN.lmax() ? max : (max + 1));
} }
protected void updateTimelineMax() { protected void updateTimelineMax() {

View file

@ -122,8 +122,9 @@ public class PathTableModel extends AbstractQueryTableModel<PathRow> {
} }
protected void updateTimelineMax() { protected void updateTimelineMax() {
Long max = getTrace() == null ? null : getTrace().getTimeManager().getMaxSnap(); Long maxBoxed = getTrace() == null ? null : getTrace().getTimeManager().getMaxSnap();
Lifespan fullRange = Lifespan.span(0L, max == null ? 1 : max + 1); long max = maxBoxed == null ? 0 : maxBoxed;
Lifespan fullRange = Lifespan.span(0L, max == Lifespan.DOMAIN.lmax() ? max : (max + 1));
lifespanPlotColumn.setFullRange(fullRange); lifespanPlotColumn.setFullRange(fullRange);
} }

View file

@ -17,8 +17,7 @@ package ghidra.app.plugin.core.debug.gui.time;
import java.awt.BorderLayout; import java.awt.BorderLayout;
import java.awt.Component; import java.awt.Component;
import java.util.Collection; import java.util.*;
import java.util.Objects;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -28,6 +27,7 @@ import javax.swing.table.*;
import docking.widgets.table.*; import docking.widgets.table.*;
import docking.widgets.table.DefaultEnumeratedColumnTableModel.EnumeratedTableColumn; import docking.widgets.table.DefaultEnumeratedColumnTableModel.EnumeratedTableColumn;
import ghidra.debug.api.tracemgr.DebuggerCoordinates;
import ghidra.docking.settings.Settings; import ghidra.docking.settings.Settings;
import ghidra.framework.model.DomainObjectEvent; import ghidra.framework.model.DomainObjectEvent;
import ghidra.framework.plugintool.PluginTool; import ghidra.framework.plugintool.PluginTool;
@ -112,9 +112,6 @@ public class DebuggerSnapshotTablePanel extends JPanel {
} }
SnapshotRow row = new SnapshotRow(currentTrace, snapshot); SnapshotRow row = new SnapshotRow(currentTrace, snapshot);
snapshotTableModel.add(row); snapshotTableModel.add(row);
if (currentSnap == snapshot.getKey()) {
snapshotFilterPanel.setSelectedItem(row);
}
} }
private void snapChanged(TraceSnapshot snapshot) { private void snapChanged(TraceSnapshot snapshot) {
@ -132,7 +129,7 @@ public class DebuggerSnapshotTablePanel extends JPanel {
} }
} }
final TableCellRenderer boldCurrentRenderer = new AbstractGColumnRenderer<Object>() { final TableCellRenderer styleCurrentRenderer = new AbstractGColumnRenderer<Object>() {
@Override @Override
public String getFilterString(Object t, Settings settings) { public String getFilterString(Object t, Settings settings) {
return t == null ? "<null>" : t.toString(); return t == null ? "<null>" : t.toString();
@ -142,9 +139,16 @@ public class DebuggerSnapshotTablePanel extends JPanel {
public Component getTableCellRendererComponent(GTableCellRenderingData data) { public Component getTableCellRendererComponent(GTableCellRenderingData data) {
super.getTableCellRendererComponent(data); super.getTableCellRendererComponent(data);
SnapshotRow row = (SnapshotRow) data.getRowObject(); SnapshotRow row = (SnapshotRow) data.getRowObject();
if (row != null && currentSnap != null && currentSnap.longValue() == row.getSnap()) { if (row == null || current == DebuggerCoordinates.NOWHERE) {
// When used in a dialog, only currentTrace is set
return this;
}
if (current.getViewSnap() == row.getSnap()) {
setBold(); setBold();
} }
else if (current.getSnap() == row.getSnap()) {
setItalic();
}
return this; return this;
} }
}; };
@ -155,7 +159,7 @@ public class DebuggerSnapshotTablePanel extends JPanel {
protected boolean hideScratch = true; protected boolean hideScratch = true;
private Trace currentTrace; private Trace currentTrace;
private volatile Long currentSnap; private volatile DebuggerCoordinates current = DebuggerCoordinates.NOWHERE;
protected final SnapshotListener listener = new SnapshotListener(); protected final SnapshotListener listener = new SnapshotListener();
@ -173,19 +177,19 @@ public class DebuggerSnapshotTablePanel extends JPanel {
TableColumnModel columnModel = snapshotTable.getColumnModel(); TableColumnModel columnModel = snapshotTable.getColumnModel();
TableColumn snapCol = columnModel.getColumn(SnapshotTableColumns.SNAP.ordinal()); TableColumn snapCol = columnModel.getColumn(SnapshotTableColumns.SNAP.ordinal());
snapCol.setPreferredWidth(40); snapCol.setPreferredWidth(40);
snapCol.setCellRenderer(boldCurrentRenderer); snapCol.setCellRenderer(styleCurrentRenderer);
TableColumn timeCol = columnModel.getColumn(SnapshotTableColumns.TIMESTAMP.ordinal()); TableColumn timeCol = columnModel.getColumn(SnapshotTableColumns.TIMESTAMP.ordinal());
timeCol.setPreferredWidth(200); timeCol.setPreferredWidth(200);
timeCol.setCellRenderer(boldCurrentRenderer); timeCol.setCellRenderer(styleCurrentRenderer);
TableColumn etCol = columnModel.getColumn(SnapshotTableColumns.EVENT_THREAD.ordinal()); TableColumn etCol = columnModel.getColumn(SnapshotTableColumns.EVENT_THREAD.ordinal());
etCol.setPreferredWidth(40); etCol.setPreferredWidth(40);
etCol.setCellRenderer(boldCurrentRenderer); etCol.setCellRenderer(styleCurrentRenderer);
TableColumn schdCol = columnModel.getColumn(SnapshotTableColumns.SCHEDULE.ordinal()); TableColumn schdCol = columnModel.getColumn(SnapshotTableColumns.SCHEDULE.ordinal());
schdCol.setPreferredWidth(60); schdCol.setPreferredWidth(60);
schdCol.setCellRenderer(boldCurrentRenderer); schdCol.setCellRenderer(styleCurrentRenderer);
TableColumn descCol = columnModel.getColumn(SnapshotTableColumns.DESCRIPTION.ordinal()); TableColumn descCol = columnModel.getColumn(SnapshotTableColumns.DESCRIPTION.ordinal());
descCol.setPreferredWidth(200); descCol.setPreferredWidth(200);
descCol.setCellRenderer(boldCurrentRenderer); descCol.setCellRenderer(styleCurrentRenderer);
} }
private void addNewListeners() { private void addNewListeners() {
@ -235,14 +239,18 @@ public class DebuggerSnapshotTablePanel extends JPanel {
return; return;
} }
TraceTimeManager manager = currentTrace.getTimeManager(); TraceTimeManager manager = currentTrace.getTimeManager();
Collection<? extends TraceSnapshot> snapshots =
hideScratch ? manager.getSnapshots(0, true, Long.MAX_VALUE, true) List<SnapshotRow> toAdd = new ArrayList<>();
: manager.getAllSnapshots(); for (TraceSnapshot snapshot : hideScratch
// Use .collect instead of .toList to avoid size/sync issues ? manager.getSnapshots(0, true, Long.MAX_VALUE, true)
// Even though access is synchronized, size may change during iteration : manager.getAllSnapshots()) {
snapshotTableModel.addAll(snapshots.stream() SnapshotRow row = new SnapshotRow(currentTrace, snapshot);
.map(s -> new SnapshotRow(currentTrace, s)) toAdd.add(row);
.collect(Collectors.toList())); if (current != DebuggerCoordinates.NOWHERE &&
snapshot.getKey() == current.getViewSnap()) {
}
}
snapshotTableModel.addAll(toAdd);
} }
protected void deleteScratchSnapshots() { protected void deleteScratchSnapshots() {
@ -270,10 +278,13 @@ public class DebuggerSnapshotTablePanel extends JPanel {
return row == null ? null : row.getSnap(); return row == null ? null : row.getSnap();
} }
public void setCurrentSnapshot(Long snap) { public void setCurrent(DebuggerCoordinates coords) {
currentSnap = snap; boolean fire = coords.getViewSnap() != current.getViewSnap();
current = coords;
if (fire) {
snapshotTableModel.fireTableDataChanged(); snapshotTableModel.fireTableDataChanged();
} }
}
public void setSelectedSnapshot(Long snap) { public void setSelectedSnapshot(Long snap) {
if (snap == null) { if (snap == null) {

View file

@ -19,7 +19,6 @@ import static ghidra.app.plugin.core.debug.gui.DebuggerResources.*;
import java.awt.event.*; import java.awt.event.*;
import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles;
import java.util.Objects;
import javax.swing.Icon; import javax.swing.Icon;
import javax.swing.JComponent; import javax.swing.JComponent;
@ -37,6 +36,8 @@ import ghidra.framework.plugintool.*;
import ghidra.framework.plugintool.AutoService.Wiring; import ghidra.framework.plugintool.AutoService.Wiring;
import ghidra.framework.plugintool.annotation.AutoConfigStateField; import ghidra.framework.plugintool.annotation.AutoConfigStateField;
import ghidra.framework.plugintool.annotation.AutoServiceConsumed; import ghidra.framework.plugintool.annotation.AutoServiceConsumed;
import ghidra.trace.model.Trace;
import ghidra.trace.model.time.TraceSnapshot;
import ghidra.trace.model.time.schedule.TraceSchedule; import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.util.HelpLocation; import ghidra.util.HelpLocation;
@ -62,16 +63,6 @@ public class DebuggerTimeProvider extends ComponentProviderAdapter {
} }
} }
protected static boolean sameCoordinates(DebuggerCoordinates a, DebuggerCoordinates b) {
if (!Objects.equals(a.getTrace(), b.getTrace())) {
return false;
}
if (!Objects.equals(a.getTime(), b.getTime())) {
return false;
}
return true;
}
protected final DebuggerTimePlugin plugin; protected final DebuggerTimePlugin plugin;
DebuggerCoordinates current = DebuggerCoordinates.NOWHERE; DebuggerCoordinates current = DebuggerCoordinates.NOWHERE;
@ -154,7 +145,7 @@ public class DebuggerTimeProvider extends ComponentProviderAdapter {
@Override @Override
public void mouseClicked(MouseEvent e) { public void mouseClicked(MouseEvent e) {
if (e.getClickCount() == 2 && e.getButton() == MouseEvent.BUTTON1) { if (e.getClickCount() == 2 && e.getButton() == MouseEvent.BUTTON1) {
activateSelectedSnapshot(); activateSelectedSnapshot(e);
} }
} }
}); });
@ -162,18 +153,44 @@ public class DebuggerTimeProvider extends ComponentProviderAdapter {
@Override @Override
public void keyPressed(KeyEvent e) { public void keyPressed(KeyEvent e) {
if (e.getKeyCode() == KeyEvent.VK_ENTER) { if (e.getKeyCode() == KeyEvent.VK_ENTER) {
activateSelectedSnapshot(); activateSelectedSnapshot(e);
e.consume(); // lest it select the next row down e.consume(); // lest it select the next row down
} }
} }
}); });
} }
private void activateSelectedSnapshot() { private TraceSchedule computeSelectedSchedule(InputEvent e, long snap) {
Long snap = mainPanel.getSelectedSnapshot(); if ((e.getModifiersEx() & InputEvent.SHIFT_DOWN_MASK) != 0) {
if (snap != null && traceManager != null) { return TraceSchedule.snap(snap);
traceManager.activateSnap(snap);
} }
if (snap >= 0) {
return TraceSchedule.snap(snap);
}
Trace trace = current.getTrace();
if (trace == null) {
return TraceSchedule.snap(snap);
}
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap, false);
if (snapshot == null) { // Really shouldn't happen, but okay
return TraceSchedule.snap(snap);
}
TraceSchedule schedule = snapshot.getSchedule();
if (schedule == null) {
return TraceSchedule.snap(snap);
}
return schedule;
}
private void activateSelectedSnapshot(InputEvent e) {
if (traceManager == null) {
return;
}
Long snap = mainPanel.getSelectedSnapshot();
if (snap == null) {
return;
}
traceManager.activateTime(computeSelectedSchedule(e, snap));
} }
protected void createActions() { protected void createActions() {
@ -202,14 +219,9 @@ public class DebuggerTimeProvider extends ComponentProviderAdapter {
} }
public void coordinatesActivated(DebuggerCoordinates coordinates) { public void coordinatesActivated(DebuggerCoordinates coordinates) {
if (sameCoordinates(current, coordinates)) {
current = coordinates; current = coordinates;
return;
}
current = coordinates;
mainPanel.setTrace(current.getTrace()); mainPanel.setTrace(current.getTrace());
mainPanel.setCurrentSnapshot(current.getSnap()); mainPanel.setCurrent(current);
} }
public void writeConfigState(SaveState saveState) { public void writeConfigState(SaveState saveState) {

View file

@ -634,27 +634,6 @@ public class DebuggerEmulationServicePlugin extends Plugin implements DebuggerEm
return task.future; return task.future;
} }
protected TraceSnapshot findScratch(Trace trace, TraceSchedule time) {
Collection<? extends TraceSnapshot> exist =
trace.getTimeManager().getSnapshotsWithSchedule(time);
if (!exist.isEmpty()) {
return exist.iterator().next();
}
/**
* TODO: This could be more sophisticated.... Does it need to be, though? Ideally, we'd only
* keep state around that has annotations, e.g., bookmarks and code units. That needs a new
* query (latestStartSince) on those managers, though. It must find the latest start tick
* since a given snap. We consider only start snaps because placed code units go "from now
* on out".
*/
TraceSnapshot last = trace.getTimeManager().getMostRecentSnapshot(-1);
long snap = last == null ? Long.MIN_VALUE : last.getKey() + 1;
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap, true);
snapshot.setDescription("Emulated");
snapshot.setSchedule(time);
return snapshot;
}
protected void installBreakpoints(Trace trace, long snap, DebuggerPcodeMachine<?> emu) { protected void installBreakpoints(Trace trace, long snap, DebuggerPcodeMachine<?> emu) {
Lifespan span = Lifespan.at(snap); Lifespan span = Lifespan.at(snap);
TraceBreakpointManager bm = trace.getBreakpointManager(); TraceBreakpointManager bm = trace.getBreakpointManager();
@ -753,7 +732,8 @@ public class DebuggerEmulationServicePlugin extends Plugin implements DebuggerEm
protected TraceSnapshot writeToScratch(CacheKey key, CachedEmulator ce) { protected TraceSnapshot writeToScratch(CacheKey key, CachedEmulator ce) {
TraceSnapshot destSnap; TraceSnapshot destSnap;
try (Transaction tx = key.trace.openTransaction("Emulate")) { try (Transaction tx = key.trace.openTransaction("Emulate")) {
destSnap = findScratch(key.trace, key.time); destSnap = key.trace.getTimeManager().findScratchSnapshot(key.time);
destSnap.setDescription("Emulated");
try { try {
ce.emulator().writeDown(key.platform, destSnap.getKey(), key.time.getSnap()); ce.emulator().writeDown(key.platform, destSnap.getKey(), key.time.getSnap());
} }

View file

@ -726,10 +726,42 @@ public class DebuggerTraceManagerServicePlugin extends Plugin
@Override @Override
public CompletableFuture<Long> materialize(DebuggerCoordinates coordinates) { public CompletableFuture<Long> materialize(DebuggerCoordinates coordinates) {
return materialize(DebuggerCoordinates.NOWHERE, coordinates, ActivationCause.USER);
}
protected CompletableFuture<Long> materialize(DebuggerCoordinates previous,
DebuggerCoordinates coordinates, ActivationCause cause) {
/**
* NOTE: If we're requested the snapshot, we don't care if we can find the snapshot already
* materialized. We're going to let the back end actually materialize and activate. When it
* activates (check the cause), we'll look for the materialized snapshot.
*
* If we go to a found snapshot on our request, the back-end may intermittently revert to
* the another snapshot, because it may not realize what we've done at the front end, or it
* may re-validate the request and go elsewhere, resulting in abrasive navigation. While we
* could attempt some bookkeeping on the back-end, we don't control how the native debugger
* issues events, so it's easier to just give it our request and then let it drive.
*/
ControlMode mode = getEffectiveControlMode(coordinates.getTrace());
Target target = coordinates.getTarget();
// NOTE: We've already validated at this point
if (mode.isTarget() && cause == ActivationCause.USER && target != null) {
// Yes, use getSnap for the materialized (view) snapshot
return target.activateAsync(previous, coordinates).thenApply(__ -> target.getSnap());
}
Long found = findSnapshot(coordinates); Long found = findSnapshot(coordinates);
if (found != null) { if (found != null) {
return CompletableFuture.completedFuture(found); return CompletableFuture.completedFuture(found);
} }
/**
* NOTE: We can actually reach this point in RO_TARGET mode, though ordinarily, it should
* only reach here in RW_EMULATOR mode. The reason is because during many automated tests,
* the "default" mode of RO_TARGET is taken as the effective mode, and the tests still
* expect emulation behavior. So do it.
*/
if (emulationService == null) { if (emulationService == null) {
Msg.warn(this, "Cannot navigate to coordinates with execution schedules, " + Msg.warn(this, "Cannot navigate to coordinates with execution schedules, " +
"because the emulation service is not available."); "because the emulation service is not available.");
@ -738,16 +770,20 @@ public class DebuggerTraceManagerServicePlugin extends Plugin
return emulationService.backgroundEmulate(coordinates.getPlatform(), coordinates.getTime()); return emulationService.backgroundEmulate(coordinates.getPlatform(), coordinates.getTime());
} }
protected CompletableFuture<Void> prepareViewAndFireEvent(DebuggerCoordinates coordinates, protected CompletableFuture<Void> prepareViewAndFireEvent(DebuggerCoordinates previous,
ActivationCause cause) { DebuggerCoordinates coordinates, ActivationCause cause) {
TraceVariableSnapProgramView varView = (TraceVariableSnapProgramView) coordinates.getView(); TraceVariableSnapProgramView varView = (TraceVariableSnapProgramView) coordinates.getView();
if (varView == null) { // Should only happen with NOWHERE if (varView == null) { // Should only happen with NOWHERE
fireLocationEvent(coordinates, cause); fireLocationEvent(coordinates, cause);
return AsyncUtils.nil(); return AsyncUtils.nil();
} }
return materialize(coordinates).thenAcceptAsync(snap -> { return materialize(previous, coordinates, cause).thenAcceptAsync(snap -> {
if (snap == null) { if (snap == null) {
return; // The tool is probably closing /**
* Either the tool is closing, or we're going to let the target materialize and
* activate the actual snap.
*/
return;
} }
if (!coordinates.equals(current)) { if (!coordinates.equals(current)) {
return; // We navigated elsewhere before emulation completed return; // We navigated elsewhere before emulation completed
@ -1150,22 +1186,13 @@ public class DebuggerTraceManagerServicePlugin extends Plugin
return AsyncUtils.nil(); return AsyncUtils.nil();
} }
CompletableFuture<Void> future = CompletableFuture<Void> future =
prepareViewAndFireEvent(resolved, cause).exceptionally(ex -> { prepareViewAndFireEvent(prev, resolved, cause).exceptionally(ex -> {
// Emulation service will already display error // Emulation service will already display error
doSetCurrent(prev); doSetCurrent(prev);
return null; return null;
}); });
if (cause != ActivationCause.USER) {
return future; return future;
} }
Target target = resolved.getTarget();
if (target == null) {
return future;
}
return future.thenCompose(__ -> target.activateAsync(prev, resolved));
}
@Override @Override
public void activate(DebuggerCoordinates coordinates, ActivationCause cause) { public void activate(DebuggerCoordinates coordinates, ActivationCause cause) {

View file

@ -24,6 +24,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -69,8 +70,7 @@ import ghidra.trace.database.ToyDBTraceBuilder;
import ghidra.trace.model.Trace; import ghidra.trace.model.Trace;
import ghidra.trace.model.target.schema.SchemaContext; import ghidra.trace.model.target.schema.SchemaContext;
import ghidra.trace.model.target.schema.XmlSchemaContext; import ghidra.trace.model.target.schema.XmlSchemaContext;
import ghidra.util.InvalidNameException; import ghidra.util.*;
import ghidra.util.NumericUtilities;
import ghidra.util.datastruct.TestDataStructureErrorHandlerInstaller; import ghidra.util.datastruct.TestDataStructureErrorHandlerInstaller;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.ConsoleTaskMonitor; import ghidra.util.task.ConsoleTaskMonitor;
@ -310,6 +310,25 @@ public abstract class AbstractGhidraHeadedDebuggerTest
}, () -> lastError.get().getMessage()); }, () -> lastError.get().getMessage());
} }
public static void waitForPass(Object originator, Runnable runnable, long duration,
TimeUnit unit) {
long start = System.currentTimeMillis();
while (System.currentTimeMillis() - start < unit.toMillis(duration)) {
try {
waitForPass(runnable);
break;
}
catch (Throwable e) {
Msg.warn(originator, "Long wait: " + e);
try {
Thread.sleep(500);
}
catch (InterruptedException e1) {
}
}
}
}
public static <T> T waitForPass(Supplier<T> supplier) { public static <T> T waitForPass(Supplier<T> supplier) {
var locals = new Object() { var locals = new Object() {
AssertionError lastError; AssertionError lastError;

View file

@ -33,8 +33,10 @@ import ghidra.trace.model.breakpoint.TraceBreakpoint;
import ghidra.trace.model.breakpoint.TraceBreakpointKind; import ghidra.trace.model.breakpoint.TraceBreakpointKind;
import ghidra.trace.model.guest.TracePlatform; import ghidra.trace.model.guest.TracePlatform;
import ghidra.trace.model.stack.TraceStackFrame; import ghidra.trace.model.stack.TraceStackFrame;
import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.path.KeyPath; import ghidra.trace.model.target.path.KeyPath;
import ghidra.trace.model.thread.TraceThread; import ghidra.trace.model.thread.TraceThread;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -70,6 +72,11 @@ public class MockTarget implements Target {
return snap; return snap;
} }
@Override
public ScheduleForm getSupportedTimeForm(TraceObject obj, long snap) {
return null;
}
@Override @Override
public Map<String, ActionEntry> collectActions(ActionName name, ActionContext context, public Map<String, ActionEntry> collectActions(ActionName name, ActionContext context,
ObjectArgumentPolicy policy) { ObjectArgumentPolicy policy) {

View file

@ -201,6 +201,9 @@ public class DBTraceTimeViewport implements TraceTimeViewport {
while (true) { while (true) {
TraceSnapshot fork = locateMostRecentFork(timeManager, curSnap); TraceSnapshot fork = locateMostRecentFork(timeManager, curSnap);
long prevSnap = fork == null ? Long.MIN_VALUE : fork.getKey(); long prevSnap = fork == null ? Long.MIN_VALUE : fork.getKey();
if (curSnap >= 0 && prevSnap < 0) {
prevSnap = 0;
}
if (!addSnapRange(prevSnap, curSnap, spanSet, ordered)) { if (!addSnapRange(prevSnap, curSnap, spanSet, ordered)) {
return; return;
} }

View file

@ -585,6 +585,9 @@ public class DBTraceMemorySpace
protected void doPutFutureBytes(OffsetSnap loc, ByteBuffer buf, int dstOffset, int maxLen, protected void doPutFutureBytes(OffsetSnap loc, ByteBuffer buf, int dstOffset, int maxLen,
OutSnap lastSnap, Set<TraceAddressSnapRange> changed) throws IOException { OutSnap lastSnap, Set<TraceAddressSnapRange> changed) throws IOException {
if (loc.snap == Lifespan.DOMAIN.lmax()) {
return;
}
// NOTE: Do not leave the buffer advanced from here // NOTE: Do not leave the buffer advanced from here
int pos = buf.position(); int pos = buf.position();
// exclusive? // exclusive?
@ -616,7 +619,7 @@ public class DBTraceMemorySpace
} }
} }
if (!remaining.isEmpty()) { if (!remaining.isEmpty()) {
lastSnap.snap = Long.MAX_VALUE; lastSnap.snap = Lifespan.DOMAIN.lmax();
for (AddressRange rng : remaining) { for (AddressRange rng : remaining) {
changed.add( changed.add(
new ImmutableTraceAddressSnapRange(rng, Lifespan.nowOnMaybeScratch(loc.snap))); new ImmutableTraceAddressSnapRange(rng, Lifespan.nowOnMaybeScratch(loc.snap)));

View file

@ -133,6 +133,26 @@ public class DBTraceTimeManager implements TraceTimeManager, DBTraceManager {
return snapshotsBySchedule.get(schedule.toString()); return snapshotsBySchedule.get(schedule.toString());
} }
@Override
public TraceSnapshot findScratchSnapshot(TraceSchedule schedule) {
Collection<? extends TraceSnapshot> exist = getSnapshotsWithSchedule(schedule);
if (!exist.isEmpty()) {
return exist.iterator().next();
}
/**
* TODO: This could be more sophisticated.... Does it need to be, though? Ideally, we'd only
* keep state around that has annotations, e.g., bookmarks and code units. That needs a new
* query (latestStartSince) on those managers, though. It must find the latest start tick
* since a given snap. We consider only start snaps because placed code units go "from now
* on out".
*/
TraceSnapshot last = getMostRecentSnapshot(-1);
long snap = last == null ? Long.MIN_VALUE : last.getKey() + 1;
TraceSnapshot snapshot = getSnapshot(snap, true);
snapshot.setSchedule(schedule);
return snapshot;
}
@Override @Override
public Collection<? extends DBTraceSnapshot> getAllSnapshots() { public Collection<? extends DBTraceSnapshot> getAllSnapshots() {
return Collections.unmodifiableCollection(snapshotStore.asMap().values()); return Collections.unmodifiableCollection(snapshotStore.asMap().values());

View file

@ -806,6 +806,6 @@ public interface TraceObject extends TraceUniqueObject {
if (stateVal == null) { if (stateVal == null) {
return TraceExecutionState.INACTIVE; return TraceExecutionState.INACTIVE;
} }
return TraceExecutionState.valueOf((String) stateVal.getValue()); return TraceExecutionState.valueOf(stateVal.castValue());
} }
} }

View file

@ -16,6 +16,7 @@
package ghidra.trace.model.target.iface; package ghidra.trace.model.target.iface;
import ghidra.trace.model.target.info.TraceObjectInfo; import ghidra.trace.model.target.info.TraceObjectInfo;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
/** /**
* An object that can emit events affecting itself and its successors * An object that can emit events affecting itself and its successors
@ -28,8 +29,11 @@ import ghidra.trace.model.target.info.TraceObjectInfo;
shortName = "event scope", shortName = "event scope",
attributes = { attributes = {
TraceObjectEventScope.KEY_EVENT_THREAD, TraceObjectEventScope.KEY_EVENT_THREAD,
TraceObjectEventScope.KEY_TIME_SUPPORT,
}, },
fixedKeys = {}) fixedKeys = {})
public interface TraceObjectEventScope extends TraceObjectInterface { public interface TraceObjectEventScope extends TraceObjectInterface {
String KEY_EVENT_THREAD = "_event_thread"; String KEY_EVENT_THREAD = "_event_thread";
/** See {@link ScheduleForm} */
String KEY_TIME_SUPPORT = "_time_support";
} }

View file

@ -52,10 +52,25 @@ public interface TraceTimeManager {
* at most one snapshot. * at most one snapshot.
* *
* @param schedule the schedule to find * @param schedule the schedule to find
* @return the snapshot, or {@code null} if no such snapshot exists * @return the snapshots
*/ */
Collection<? extends TraceSnapshot> getSnapshotsWithSchedule(TraceSchedule schedule); Collection<? extends TraceSnapshot> getSnapshotsWithSchedule(TraceSchedule schedule);
/**
* Find or create a the snapshot with the given schedule
*
* <p>
* If a snapshot with the given schedule already exists, this returns the first such snapshot
* found. Ideally, there is exactly one. If this method is consistently used for creating
* scratch snapshots, then that should always be the case. If no such snapshot exists, this
* creates a snapshot with the minimum available negative snapshot key, that is starting at
* {@link Long#MIN_VALUE} and increasing from there.
*
* @param schedule the schedule to find
* @return the snapshot
*/
TraceSnapshot findScratchSnapshot(TraceSchedule schedule);
/** /**
* List all snapshots in the trace * List all snapshots in the trace
* *

View file

@ -30,8 +30,153 @@ import ghidra.util.task.TaskMonitor;
* A sequence of emulator stepping commands, essentially comprising a "point in time." * A sequence of emulator stepping commands, essentially comprising a "point in time."
*/ */
public class TraceSchedule implements Comparable<TraceSchedule> { public class TraceSchedule implements Comparable<TraceSchedule> {
/**
* The initial snapshot (with no steps)
*/
public static final TraceSchedule ZERO = TraceSchedule.snap(0); public static final TraceSchedule ZERO = TraceSchedule.snap(0);
/**
* Specifies forms of a stepping schedule.
*
* <p>
* Each form defines a set of stepping schedules. It happens that each is a subset of the next.
* A {@link #SNAP_ONLY} schedule is also a {@link #SNAP_ANY_STEPS_OPS} schedule, but not
* necessarily vice versa.
*/
public enum ScheduleForm {
/**
* The schedule consists only of a snapshot. No stepping after.
*/
SNAP_ONLY {
@Override
public boolean contains(Trace trace, TraceSchedule schedule) {
return schedule.steps.isNop() && schedule.pSteps.isNop();
}
},
/**
* The schedule consists of a snapshot and some number of instruction steps on the event
* thread only.
*/
SNAP_EVT_STEPS {
@Override
public boolean contains(Trace trace, TraceSchedule schedule) {
if (!schedule.pSteps.isNop()) {
return false;
}
List<Step> steps = schedule.steps.getSteps();
if (steps.isEmpty()) {
return true;
}
if (steps.size() != 1) {
return false;
}
if (!(steps.getFirst() instanceof TickStep ticks)) {
return false;
}
if (ticks.getThreadKey() == -1) {
return true;
}
if (trace == null) {
return false;
}
TraceThread eventThread = schedule.getEventThread(trace);
TraceThread thread = ticks.getThread(trace.getThreadManager(), eventThread);
if (eventThread != thread) {
return false;
}
return true;
}
@Override
public TraceSchedule validate(Trace trace, TraceSchedule schedule) {
if (!schedule.pSteps.isNop()) {
return null;
}
List<Step> steps = schedule.steps.getSteps();
if (steps.isEmpty()) {
return schedule;
}
if (steps.size() != 1) {
return null;
}
if (!(steps.getFirst() instanceof TickStep ticks)) {
return null;
}
if (ticks.getThreadKey() == -1) {
return schedule;
}
if (trace == null) {
return null;
}
TraceThread eventThread = schedule.getEventThread(trace);
TraceThread thread = ticks.getThread(trace.getThreadManager(), eventThread);
if (eventThread != thread) {
return null;
}
return TraceSchedule.snap(schedule.snap).steppedForward(null, ticks.getTickCount());
}
},
/**
* The schedule consists of a snapshot and a sequence of instruction steps on any
* threads(s).
*/
SNAP_ANY_STEPS {
@Override
public boolean contains(Trace trace, TraceSchedule schedule) {
return schedule.pSteps.isNop();
}
},
/**
* The schedule consists of a snapshot and a sequence of instruction steps then p-code op
* steps on any thread(s).
*
* <p>
* This is the most capable form supported by {@link TraceSchedule}.
*/
SNAP_ANY_STEPS_OPS {
@Override
public boolean contains(Trace trace, TraceSchedule schedule) {
return true;
}
};
public static final List<ScheduleForm> VALUES = List.of(ScheduleForm.values());
/**
* Check if the given schedule conforms
*
* @param trace if available, a trace for determining the event thread
* @param schedule the schedule to test
* @return true if the schedule adheres to this form
*/
public abstract boolean contains(Trace trace, TraceSchedule schedule);
/**
* If the given schedule conforms, normalize the schedule to prove it does.
*
* @param trace if available, a trace for determining the event thread
* @param schedule the schedule to test
* @return the non-null normalized schedule, or null if the given schedule does not conform
*/
public TraceSchedule validate(Trace trace, TraceSchedule schedule) {
if (!contains(trace, schedule)) {
return null;
}
return schedule;
}
/**
* Get the more restrictive of this and the given form
*
* @param that the other form
* @return the more restrictive form
*/
public ScheduleForm intersect(ScheduleForm that) {
int ord = Math.min(this.ordinal(), that.ordinal());
return VALUES.get(ord);
}
}
/** /**
* Create a schedule that consists solely of a snapshot * Create a schedule that consists solely of a snapshot
* *
@ -256,7 +401,7 @@ public class TraceSchedule implements Comparable<TraceSchedule> {
* loading a snapshot * loading a snapshot
*/ */
public boolean isSnapOnly() { public boolean isSnapOnly() {
return steps.isNop() && pSteps.isNop(); return ScheduleForm.SNAP_ONLY.contains(null, this);
} }
/** /**

View file

@ -0,0 +1,593 @@
/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package agent;
import static org.junit.Assert.*;
import java.io.*;
import java.net.*;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.hamcrest.Matchers;
import org.junit.*;
import db.NoTransactionException;
import generic.Unique;
import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest;
import ghidra.app.plugin.core.debug.service.tracermi.TraceRmiPlugin;
import ghidra.app.plugin.core.debug.utils.ManagedDomainObject;
import ghidra.app.services.TraceRmiService;
import ghidra.debug.api.tracermi.*;
import ghidra.framework.*;
import ghidra.framework.main.ApplicationLevelOnlyPlugin;
import ghidra.framework.model.DomainFile;
import ghidra.framework.plugintool.Plugin;
import ghidra.framework.plugintool.PluginsConfiguration;
import ghidra.framework.plugintool.util.*;
import ghidra.pty.*;
import ghidra.pty.PtyChild.Echo;
import ghidra.pty.testutil.DummyProc;
import ghidra.trace.model.Trace;
import ghidra.trace.model.target.schema.PrimitiveTraceObjectSchema.MinimalSchemaContext;
import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName;
import ghidra.trace.model.time.TraceSnapshot;
import ghidra.util.Msg;
import ghidra.util.SystemUtilities;
public class TraceRmiPythonClientTest extends AbstractGhidraHeadedDebuggerTest {
public static final String PREAMBLE =
"""
import socket
from typing import Annotated, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from ghidratrace import sch
from ghidratrace.client import (Client, Address, AddressRange, TraceObject,
MethodRegistry, Schedule, TraceRmiError, ParamDesc)
registry = MethodRegistry(ThreadPoolExecutor())
def connect(addr):
cs = socket.socket()
cs.connect(addr)
return Client(cs, "test-client", registry)
""";
protected static final int CONNECT_TIMEOUT_MS = 3000;
protected static final int TIMEOUT_SECONDS = 10;
protected static final int QUIT_TIMEOUT_MS = 1000;
protected TraceRmiService traceRmi;
private Path pathToPython;
@BeforeClass
public static void setupPython() throws Throwable {
if (SystemUtilities.isInTestingBatchMode()) {
return; // gradle should already have done it
}
String gradle = switch (OperatingSystem.CURRENT_OPERATING_SYSTEM) {
case WINDOWS -> DummyProc.which("gradle.bat");
default -> "gradle";
};
assertEquals(0, new ProcessBuilder(gradle, "Debugger-rmi-trace:assemblePyPackage")
.directory(TestApplicationUtils.getInstallationDirectory())
.inheritIO()
.start()
.waitFor());
}
protected void setPythonPath(Map<String, String> env) throws IOException {
String sep =
OperatingSystem.CURRENT_OPERATING_SYSTEM == OperatingSystem.WINDOWS ? ";" : ":";
String rmiPyPkg = Application.getModuleSubDirectory("Debugger-rmi-trace",
"build/pypkg/src").getAbsolutePath();
String add = rmiPyPkg;
env.compute("PYTHONPATH", (k, v) -> v == null ? add : (v + sep + add));
}
protected Path getPathToPython() {
return Paths.get(DummyProc.which("python"));
}
@Before
public void setupTraceRmi() throws Throwable {
traceRmi = addPlugin(tool, TraceRmiPlugin.class);
pathToPython = getPathToPython();
}
protected void addAllDebuggerPlugins() throws PluginException {
PluginsConfiguration plugConf = new PluginsConfiguration() {
@Override
protected boolean accepts(Class<? extends Plugin> pluginClass) {
return !ApplicationLevelOnlyPlugin.class.isAssignableFrom(pluginClass);
}
};
for (PluginDescription pd : plugConf
.getPluginDescriptions(PluginPackage.getPluginPackage("Debugger"))) {
addPlugin(tool, pd.getPluginClass());
}
}
protected static String addrToStringForPython(InetAddress address) {
if (address.isAnyLocalAddress()) {
return "127.0.0.1"; // Can't connect to 0.0.0.0 as such. Choose localhost.
}
return address.getHostAddress();
}
protected static String sockToStringForPython(SocketAddress address) {
if (address instanceof InetSocketAddress tcp) {
return "('%s', %d)".formatted(addrToStringForPython(tcp.getAddress()), tcp.getPort());
}
throw new AssertionError("Unhandled address type " + address);
}
protected static class PyError extends RuntimeException {
public final int exitCode;
public final String out;
public PyError(int exitCode, String out) {
super("""
exitCode=%d:
----out----
%s
""".formatted(exitCode, out));
this.exitCode = exitCode;
this.out = out;
}
}
protected record PyResult(boolean timedOut, int exitCode, String out) {
protected String handle() {
if (0 != exitCode || out.contains("Traceback")) {
throw new PyError(exitCode, out);
}
return out;
}
}
protected record ExecInPy(PtySession session, PrintWriter stdin,
CompletableFuture<PyResult> future) {
}
@SuppressWarnings("resource") // Do not close stdin
protected ExecInPy execInPy(String script) throws IOException {
Map<String, String> env = new HashMap<>(System.getenv());
setPythonPath(env);
Pty pty = PtyFactory.local().openpty();
PtySession session =
pty.getChild().session(new String[] { pathToPython.toString() }, env, Echo.ON);
ByteArrayOutputStream out = new ByteArrayOutputStream();
new Thread(() -> {
InputStream is = pty.getParent().getInputStream();
byte[] buf = new byte[1024];
while (true) {
try {
int len = is.read(buf);
out.write(buf, 0, len);
System.out.write(buf, 0, len);
}
catch (IOException e) {
System.out.println("<EOF>");
return;
}
}
}).start();
PrintWriter stdin = new PrintWriter(pty.getParent().getOutputStream());
script.lines().forEach(stdin::println); // to transform newlines.
stdin.flush();
return new ExecInPy(session, stdin, CompletableFuture.supplyAsync(() -> {
try {
int exitCode = session.waitExited(TIMEOUT_SECONDS, TimeUnit.SECONDS);
Msg.info(this, "Python exited with code " + exitCode);
return new PyResult(false, exitCode, out.toString());
}
catch (TimeoutException e) {
Msg.error(this, "Timed out waiting for GDB");
session.destroyForcibly();
try {
session.waitExited(TIMEOUT_SECONDS, TimeUnit.SECONDS);
}
catch (InterruptedException | TimeoutException e1) {
throw new AssertionError(e1);
}
return new PyResult(true, -1, out.toString());
}
catch (Exception e) {
return ExceptionUtils.rethrow(e);
}
finally {
session.destroyForcibly();
}
}));
}
protected String runThrowError(String script) throws Exception {
CompletableFuture<PyResult> result = execInPy(script).future;
return result.get(TIMEOUT_SECONDS, TimeUnit.SECONDS).handle();
}
protected record PyAndConnection(ExecInPy exec, TraceRmiConnection connection)
implements AutoCloseable {
protected RemoteMethod getMethod(String name) {
return Objects.requireNonNull(connection.getMethods().get(name));
}
@Override
public void close() throws Exception {
Msg.info(this, "Cleaning up python");
try {
exec.stdin.println("exit()");
exec.stdin.close();
PyResult r = exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
r.handle();
waitForPass(() -> assertTrue(connection.isClosed()));
}
finally {
exec.stdin.close();
exec.session.destroyForcibly();
}
}
}
protected PyAndConnection startAndConnectPy(Function<String, String> scriptSupplier)
throws Exception {
TraceRmiAcceptor acceptor = traceRmi.acceptOne(null);
ExecInPy exec =
execInPy(scriptSupplier.apply(sockToStringForPython(acceptor.getAddress())));
acceptor.setTimeout(CONNECT_TIMEOUT_MS);
try {
TraceRmiConnection connection = acceptor.accept();
return new PyAndConnection(exec, connection);
}
catch (SocketTimeoutException e) {
Msg.error(this, "Timed out waiting for client connection");
exec.session.destroyForcibly();
exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS).handle();
throw e;
}
}
protected PyAndConnection startAndConnectPy() throws Exception {
return startAndConnectPy(addr -> """
%s
c = connect(%s)
""".formatted(PREAMBLE, addr));
}
@SuppressWarnings("resource")
protected String runThrowError(Function<String, String> scriptSupplier)
throws Exception {
PyAndConnection conn = startAndConnectPy(scriptSupplier);
PyResult r = conn.exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
String stdout = r.handle();
waitForPass(() -> assertTrue(conn.connection.isClosed()));
return stdout;
}
protected ManagedDomainObject openDomainObject(String path) throws Exception {
DomainFile df = env.getProject().getProjectData().getFile(path);
assertNotNull(df);
return new ManagedDomainObject(df, false, false, monitor);
}
protected ManagedDomainObject waitDomainObject(String path) throws Exception {
DomainFile df;
long start = System.currentTimeMillis();
while (true) {
df = env.getProject().getProjectData().getFile(path);
if (df != null) {
return new ManagedDomainObject(df, false, false, monitor);
}
Thread.sleep(1000);
if (System.currentTimeMillis() - start > 30000) {
throw new TimeoutException("30 seconds expired waiting for domain file");
}
}
}
protected void waitTxDone() {
waitFor(() -> tb.trace.getCurrentTransactionInfo() == null);
}
@Test
public void testConnect() throws Exception {
runThrowError(addr -> """
%s
c = connect(%s)
exit()
""".formatted(PREAMBLE, addr));
}
@Test
public void testClose() throws Exception {
runThrowError(addr -> """
%s
c = connect(%s)
c.close()
exit()
""".formatted(PREAMBLE, addr));
}
@Test
public void testCreateTrace() throws Exception {
runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
print(trace)
exit()
""".formatted(PREAMBLE, addr));
try (ManagedDomainObject obj = openDomainObject("/New Traces/test")) {
switch (obj.get()) {
case Trace trace -> {
}
default -> fail("Wrong type");
}
}
}
@Test
public void testMethodRegistrationAndInvocation() throws Exception {
try (PyAndConnection pac = startAndConnectPy(addr -> """
%s
@registry.method()
def py_eval(expr: str) -> str:
return repr(eval(expr))
c = connect(%s)
""".formatted(PREAMBLE, addr))) {
RemoteMethod pyEval = pac.getMethod("py_eval");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(pyEval.retType()).getType());
assertEquals("expr", Unique.assertOne(pyEval.parameters().keySet()));
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(pyEval.parameters().get("expr").type())
.getType());
String result = (String) pyEval.invoke(Map.ofEntries(
Map.entry("expr", "c")));
assertThat(result, Matchers.startsWith("<ghidratrace.Client <socket.socket"));
}
}
@Test
public void testRegisterAnnotated() throws Exception {
try (PyAndConnection pac = startAndConnectPy(addr -> """
%s
@registry.method()
def py_eval(expr: Annotated[str, ParamDesc(display="Expression")]) -> Annotated[
Any, ParamDesc(schema=sch.STRING)]:
return repr(eval(expr))
c = connect(%s)
""".formatted(PREAMBLE, addr))) {
RemoteMethod pyEval = pac.getMethod("py_eval");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(pyEval.retType()).getType());
assertEquals("expr", Unique.assertOne(pyEval.parameters().keySet()));
RemoteParameter param = pyEval.parameters().get("expr");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(param.type()).getType());
assertEquals("Expression", param.display());
String result = (String) pyEval.invoke(Map.ofEntries(
Map.entry("expr", "c")));
assertThat(result, Matchers.startsWith("<ghidratrace.Client <socket.socket"));
}
}
@Test
public void testRegisterOptional() throws Exception {
try (PyAndConnection pac = startAndConnectPy(addr -> """
%s
@registry.method()
def py_eval(expr: Optional[str]) -> Optional[str]:
return repr(eval(expr))
c = connect(%s)
""".formatted(PREAMBLE, addr))) {
RemoteMethod pyEval = pac.getMethod("py_eval");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(pyEval.retType()).getType());
assertEquals("expr", Unique.assertOne(pyEval.parameters().keySet()));
RemoteParameter param = pyEval.parameters().get("expr");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(param.type()).getType());
String result = (String) pyEval.invoke(Map.ofEntries(
Map.entry("expr", "c")));
assertThat(result, Matchers.startsWith("<ghidratrace.Client <socket.socket"));
}
}
@Test
public void testRegisterObject() throws Exception {
try (PyAndConnection pac = startAndConnectPy(addr -> """
%s
class Session(TraceObject):
pass
@registry.method()
def py_eval(session: Session, expr: str) -> str:
return repr(eval(expr))
c = connect(%s)
""".formatted(PREAMBLE, addr))) {
RemoteMethod pyEval = pac.getMethod("py_eval");
assertEquals(String.class,
MinimalSchemaContext.INSTANCE.getSchema(pyEval.retType()).getType());
assertEquals(Set.of("session", "expr"), pyEval.parameters().keySet());
RemoteParameter param = pyEval.parameters().get("session");
assertEquals(new SchemaName("Session"), param.type());
}
}
@Test
public void testRegisterObjectBad() throws Exception {
String out = runThrowError(addr -> """
%s
c = connect(%s)
class Session(object):
pass
def py_eval(session: Session, expr: str) -> str:
return repr(eval(expr))
try:
registry.method()(py_eval)
except TypeError as e:
print(f"---ERR:{e}---")
exit()
""".formatted(PREAMBLE, addr));
assertThat(out, Matchers.containsString(
"---ERR:Cannot get schema for <class '__main__.Session'>---"));
}
@Test
public void testSnapshotDefaultNoTx() throws Exception {
String out = runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
try:
trace.snapshot("Test")
raise Exception("Expected error")
except TraceRmiError as e:
print(f"---ERR:{e}---")
exit()
""".formatted(PREAMBLE, addr));
assertThat(out,
Matchers.containsString("---ERR:%s".formatted(NoTransactionException.class.getName())));
}
@Test
public void testSnapshotDefault() throws Exception {
runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
with trace.open_tx("Create snapshot") as tx:
trace.snapshot("Test")
exit()
""".formatted(PREAMBLE, addr));
try (ManagedDomainObject obj = openDomainObject("/New Traces/test")) {
Trace trace = (Trace) obj.get();
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(0, false);
assertEquals("Test", snapshot.getDescription());
}
}
@Test
public void testSnapshotSnapOnly() throws Exception {
runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
with trace.open_tx("Create snapshot") as tx:
trace.snapshot("Test", time=Schedule(10, 0))
exit()
""".formatted(PREAMBLE, addr));
try (ManagedDomainObject obj = openDomainObject("/New Traces/test")) {
Trace trace = (Trace) obj.get();
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(10, false);
assertEquals("Test", snapshot.getDescription());
}
}
protected Matcher matchOne(String out, Pattern pat) {
return Unique.assertOne(out.lines().map(pat::matcher).filter(Matcher::find));
}
@Test
public void testSnapshotSchedule() throws Exception {
String out = runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
with trace.open_tx("Create snapshot") as tx:
snap = trace.snapshot("Test", time=Schedule(10, 500))
print(f"---SNAP:{snap}---")
exit()
""".formatted(PREAMBLE, addr));
long snap = Long.parseLong(matchOne(out, Pattern.compile("---SNAP:(-?\\d*)---")).group(1));
assertThat(snap, Matchers.lessThan(0L));
try (ManagedDomainObject obj = openDomainObject("/New Traces/test")) {
Trace trace = (Trace) obj.get();
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap, false);
assertEquals("Test", snapshot.getDescription());
}
}
@Test
public void testSnapshotScheduleInBatch() throws Exception {
String out = runThrowError(addr -> """
%s
c = connect(%s)
trace = c.create_trace("/test", "DATA:LE:64:default", "pointer64", extra=None)
with trace.open_tx("Create snapshot") as tx:
with c.batch() as b:
snap = trace.snapshot("Test", time=Schedule(10, 500))
print(f"---SNAP:{snap}---")
exit()
""".formatted(PREAMBLE, addr));
long snap = Long.parseLong(matchOne(out, Pattern.compile("---SNAP:(-?\\d*)---")).group(1));
assertThat(snap, Matchers.lessThan(0L));
try (ManagedDomainObject obj = openDomainObject("/New Traces/test")) {
Trace trace = (Trace) obj.get();
TraceSnapshot snapshot = trace.getTimeManager().getSnapshot(snap, false);
assertEquals("Test", snapshot.getDescription());
}
}
}

View file

@ -48,8 +48,7 @@ import ghidra.trace.model.breakpoint.TraceBreakpointKind;
import ghidra.trace.model.breakpoint.TraceBreakpointKind.TraceBreakpointKindSet; import ghidra.trace.model.breakpoint.TraceBreakpointKind.TraceBreakpointKindSet;
import ghidra.trace.model.target.TraceObject; import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.TraceObjectValue; import ghidra.trace.model.target.TraceObjectValue;
import ghidra.util.Msg; import ghidra.util.*;
import ghidra.util.NumericUtilities;
public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDebuggerTest { public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDebuggerTest {
/** /**
@ -58,6 +57,7 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
*/ */
public static final String PREAMBLE = """ public static final String PREAMBLE = """
from ghidradbg.commands import * from ghidradbg.commands import *
from ghidratrace.client import Schedule
"""; """;
// Connecting should be the first thing the script does, so use a tight timeout. // Connecting should be the first thing the script does, so use a tight timeout.
protected static final int CONNECT_TIMEOUT_MS = 3000; protected static final int CONNECT_TIMEOUT_MS = 3000;
@ -111,14 +111,18 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
assumeTrue(OperatingSystem.CURRENT_OPERATING_SYSTEM == OperatingSystem.WINDOWS); assumeTrue(OperatingSystem.CURRENT_OPERATING_SYSTEM == OperatingSystem.WINDOWS);
} }
//@BeforeClass @BeforeClass
public static void setupPython() throws Throwable { public static void setupPython() throws Throwable {
if (didSetupPython) { if (didSetupPython) {
// Only do this once when running the full suite. // Only do this once when running the full suite.
return; return;
} }
if (SystemUtilities.isInTestingBatchMode()) {
// Don't run gradle in gradle. It already did this task.
return;
}
String gradle = DummyProc.which("gradle.bat"); String gradle = DummyProc.which("gradle.bat");
new ProcessBuilder(gradle, "Debugger-agent-dbgeng:assemblePyPackage") new ProcessBuilder(gradle, "assemblePyPackage")
.directory(TestApplicationUtils.getInstallationDirectory()) .directory(TestApplicationUtils.getInstallationDirectory())
.inheritIO() .inheritIO()
.start() .start()
@ -137,6 +141,10 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
pb.environment().compute("PYTHONPATH", (k, v) -> v == null ? add : (v + sep + add)); pb.environment().compute("PYTHONPATH", (k, v) -> v == null ? add : (v + sep + add));
} }
protected void setWindbgPath(ProcessBuilder pb) throws IOException {
pb.environment().put("WINDBG_DIR", "C:\\Program Files\\Amazon Corretto\\jdk21.0.3_9\\bin");
}
@Before @Before
public void setupTraceRmi() throws Throwable { public void setupTraceRmi() throws Throwable {
traceRmi = addPlugin(tool, TraceRmiPlugin.class); traceRmi = addPlugin(tool, TraceRmiPlugin.class);
@ -147,6 +155,9 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
catch (RuntimeException e) { catch (RuntimeException e) {
pythonPath = Paths.get(DummyProc.which("python")); pythonPath = Paths.get(DummyProc.which("python"));
} }
pythonPath = new File("/C:/Python313/python.exe").toPath();
assertTrue(pythonPath.toFile().exists());
outFile = Files.createTempFile("pydbgout", null); outFile = Files.createTempFile("pydbgout", null);
errFile = Files.createTempFile("pydbgerr", null); errFile = Files.createTempFile("pydbgerr", null);
} }
@ -194,10 +205,49 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
protected record ExecInPython(Process python, CompletableFuture<PythonResult> future) {} protected record ExecInPython(Process python, CompletableFuture<PythonResult> future) {}
protected void pump(InputStream streamIn, OutputStream streamOut) {
Thread t = new Thread(() -> {
try (PrintStream printOut = new PrintStream(streamOut);
BufferedReader reader = new BufferedReader(new InputStreamReader(streamIn))) {
String line;
while ((line = reader.readLine()) != null) {
printOut.println(line);
printOut.flush();
}
}
catch (IOException e) {
Msg.info(this, "Terminating stdin pump, because " + e);
}
});
t.setDaemon(true);
t.start();
}
protected void pumpTee(InputStream streamIn, File fileOut, PrintStream streamOut) {
Thread t = new Thread(() -> {
try (PrintStream fileStream = new PrintStream(fileOut);
BufferedReader reader = new BufferedReader(new InputStreamReader(streamIn))) {
String line;
while ((line = reader.readLine()) != null) {
streamOut.println(line);
streamOut.flush();
fileStream.println(line);
fileStream.flush();
}
}
catch (IOException e) {
Msg.info(this, "Terminating tee: " + fileOut + ", because " + e);
}
});
t.setDaemon(true);
t.start();
}
@SuppressWarnings("resource") // Do not close stdin @SuppressWarnings("resource") // Do not close stdin
protected ExecInPython execInPython(String script) throws IOException { protected ExecInPython execInPython(String script) throws IOException {
ProcessBuilder pb = new ProcessBuilder(pythonPath.toString(), "-i"); ProcessBuilder pb = new ProcessBuilder(pythonPath.toString(), "-i");
setPythonPath(pb); setPythonPath(pb);
setWindbgPath(pb);
// If commands come from file, Python will quit after EOF. // If commands come from file, Python will quit after EOF.
Msg.info(this, "outFile: " + outFile); Msg.info(this, "outFile: " + outFile);
@ -205,13 +255,29 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
//pb.inheritIO(); //pb.inheritIO();
pb.redirectInput(ProcessBuilder.Redirect.PIPE); pb.redirectInput(ProcessBuilder.Redirect.PIPE);
if (SystemUtilities.isInTestingBatchMode()) {
pb.redirectOutput(outFile.toFile()); pb.redirectOutput(outFile.toFile());
pb.redirectError(errFile.toFile()); pb.redirectError(errFile.toFile());
}
else {
pb.redirectOutput(ProcessBuilder.Redirect.PIPE);
pb.redirectError(ProcessBuilder.Redirect.PIPE);
}
Process pyproc = pb.start(); Process pyproc = pb.start();
if (!SystemUtilities.isInTestingBatchMode()) {
pumpTee(pyproc.getInputStream(), outFile.toFile(), System.out);
pumpTee(pyproc.getErrorStream(), errFile.toFile(), System.err);
}
OutputStream stdin = pyproc.getOutputStream(); OutputStream stdin = pyproc.getOutputStream();
stdin.write(script.getBytes()); stdin.write(script.getBytes());
stdin.flush(); stdin.flush();
//stdin.close();
if (!SystemUtilities.isInTestingBatchMode()) {
pump(System.in, stdin);
}
return new ExecInPython(pyproc, CompletableFuture.supplyAsync(() -> { return new ExecInPython(pyproc, CompletableFuture.supplyAsync(() -> {
try { try {
if (!pyproc.waitFor(TIMEOUT_SECONDS, TimeUnit.SECONDS)) { if (!pyproc.waitFor(TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
@ -286,7 +352,8 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
try { try {
PythonResult r = exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS); PythonResult r = exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
r.handle(); r.handle();
waitForPass(() -> assertTrue(connection.isClosed())); waitForPass(this, () -> assertTrue(connection.isClosed()),
TIMEOUT_SECONDS, TimeUnit.SECONDS);
} }
finally { finally {
exec.python.destroyForcibly(); exec.python.destroyForcibly();
@ -324,18 +391,21 @@ public abstract class AbstractDbgEngTraceRmiTest extends AbstractGhidraHeadedDeb
PythonAndConnection conn = startAndConnectPython(scriptSupplier); PythonAndConnection conn = startAndConnectPython(scriptSupplier);
PythonResult r = conn.exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS); PythonResult r = conn.exec.future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
String stdout = r.handle(); String stdout = r.handle();
waitForPass(() -> assertTrue(conn.connection.isClosed())); waitForPass(this, () -> assertTrue(conn.connection.isClosed()),
TIMEOUT_SECONDS, TimeUnit.SECONDS);
return stdout; return stdout;
} }
protected void waitStopped(String message) { protected void waitStopped(String message) {
TraceObject proc = Objects.requireNonNull(tb.objAny("Processes[]", Lifespan.at(0))); TraceObject proc =
Objects.requireNonNull(tb.objAny("Sessions[].Processes[]", Lifespan.at(0)));
waitForPass(() -> assertEquals(message, "STOPPED", tb.objValue(proc, 0, "_state"))); waitForPass(() -> assertEquals(message, "STOPPED", tb.objValue(proc, 0, "_state")));
waitTxDone(); waitTxDone();
} }
protected void waitRunning(String message) { protected void waitRunning(String message) {
TraceObject proc = Objects.requireNonNull(tb.objAny("Processes[]", Lifespan.at(0))); TraceObject proc =
Objects.requireNonNull(tb.objAny("Sessions[].Processes[]", Lifespan.at(0)));
waitForPass(() -> assertEquals(message, "RUNNING", tb.objValue(proc, 0, "_state"))); waitForPass(() -> assertEquals(message, "RUNNING", tb.objValue(proc, 0, "_state")));
waitTxDone(); waitTxDone();
} }

View file

@ -108,7 +108,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
quit() quit()
""".formatted(PREAMBLE, addr)); """.formatted(PREAMBLE, addr));
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
@ -139,7 +139,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
addr -> """ addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe', start_trace=False) ghidra_trace_create('notepad.exe', start_trace=False, wait=True)
util.set_convenience_variable('ghidra-language','Toy:BE:64:default') util.set_convenience_variable('ghidra-language','Toy:BE:64:default')
util.set_convenience_variable('ghidra-compiler','default') util.set_convenience_variable('ghidra-compiler','default')
ghidra_trace_start('myToy') ghidra_trace_start('myToy')
@ -163,7 +163,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_stop() ghidra_trace_stop()
quit() quit()
""".formatted(PREAMBLE, addr)); """.formatted(PREAMBLE, addr));
@ -188,7 +188,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
print('---Connect---') print('---Connect---')
ghidra_trace_info() ghidra_trace_info()
print('---Create---') print('---Create---')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
print('---Start---') print('---Start---')
ghidra_trace_info() ghidra_trace_info()
ghidra_trace_stop() ghidra_trace_stop()
@ -233,7 +233,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
print('---Import---') print('---Import---')
ghidra_trace_info_lcsp() ghidra_trace_info_lcsp()
print('---Create---') print('---Create---')
ghidra_trace_create('notepad.exe', start_trace=False) ghidra_trace_create('notepad.exe', start_trace=False, wait=True)
print('---File---') print('---File---')
ghidra_trace_info_lcsp() ghidra_trace_info_lcsp()
util.set_convenience_variable('ghidra-language','Toy:BE:64:default') util.set_convenience_variable('ghidra-language','Toy:BE:64:default')
@ -250,6 +250,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
Selected Ghidra compiler: windows""", Selected Ghidra compiler: windows""",
extractOutSection(out, "---File---").replaceAll("\r", "")); extractOutSection(out, "---File---").replaceAll("\r", ""));
assertEquals(""" assertEquals("""
Toy:BE:64:default not found in compiler map
Selected Ghidra language: Toy:BE:64:default Selected Ghidra language: Toy:BE:64:default
Selected Ghidra compiler: default""", Selected Ghidra compiler: default""",
extractOutSection(out, "---Language---").replaceAll("\r", "")); extractOutSection(out, "---Language---").replaceAll("\r", ""));
@ -267,7 +268,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -282,7 +283,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -300,7 +301,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -319,7 +320,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
pc = util.get_pc() pc = util.get_pc()
@ -348,7 +349,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
pc = util.get_pc() pc = util.get_pc()
@ -380,7 +381,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
pc = util.get_pc() pc = util.get_pc()
@ -414,7 +415,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
util.dbg.cmd('r rax=0xdeadbeef') util.dbg.cmd('r rax=0xdeadbeef')
util.dbg.cmd('r st0=1.5') util.dbg.cmd('r st0=1.5')
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
@ -429,7 +430,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
long snap = Unique.assertOne(tb.trace.getTimeManager().getAllSnapshots()).getKey(); long snap = Unique.assertOne(tb.trace.getTimeManager().getAllSnapshots()).getKey();
List<TraceObjectValue> regVals = tb.trace.getObjectManager() List<TraceObjectValue> regVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Threads[].Registers")) PathFilter.parse("Sessions[].Processes[].Threads[].Registers"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.toList(); .toList();
TraceObjectValue tobj = regVals.get(0); TraceObjectValue tobj = regVals.get(0);
@ -440,6 +441,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
RegisterValue rax = regs.getValue(snap, tb.reg("rax")); RegisterValue rax = regs.getValue(snap, tb.reg("rax"));
assertEquals("deadbeef", rax.getUnsignedValue().toString(16)); assertEquals("deadbeef", rax.getUnsignedValue().toString(16));
@SuppressWarnings("unused") // not yet
TraceData st0; TraceData st0;
try (Transaction tx = tb.trace.openTransaction("Float80 unit")) { try (Transaction tx = tb.trace.openTransaction("Float80 unit")) {
TraceCodeSpace code = tb.trace.getCodeManager().getCodeSpace(t1f0, true); TraceCodeSpace code = tb.trace.getCodeManager().getCodeSpace(t1f0, true);
@ -460,7 +462,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
util.dbg.cmd('r rax=0xdeadbeef') util.dbg.cmd('r rax=0xdeadbeef')
ghidra_trace_txstart('Create snapshot') ghidra_trace_txstart('Create snapshot')
ghidra_trace_new_snap('Scripted snapshot') ghidra_trace_new_snap('Scripted snapshot')
@ -476,7 +478,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
long snap = Unique.assertOne(tb.trace.getTimeManager().getAllSnapshots()).getKey(); long snap = Unique.assertOne(tb.trace.getTimeManager().getAllSnapshots()).getKey();
List<TraceObjectValue> regVals = tb.trace.getObjectManager() List<TraceObjectValue> regVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Threads[].Registers")) PathFilter.parse("Sessions[].Processes[].Threads[].Registers"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.toList(); .toList();
TraceObjectValue tobj = regVals.get(0); TraceObjectValue tobj = regVals.get(0);
@ -544,11 +546,11 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
ghidra_trace_set_snap(1) ghidra_trace_new_snap(time=Schedule(1))
ghidra_trace_remove_obj('Test.Objects[1]') ghidra_trace_remove_obj('Test.Objects[1]')
ghidra_trace_txcommit() ghidra_trace_txcommit()
ghidra_trace_kill() ghidra_trace_kill()
@ -570,7 +572,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
@ -721,14 +723,14 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
ghidra_trace_set_value('Test.Objects[1]', '[1]', '"A"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[1]', '"A"', 'STRING')
ghidra_trace_set_value('Test.Objects[1]', '[2]', '"B"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[2]', '"B"', 'STRING')
ghidra_trace_set_value('Test.Objects[1]', '[3]', '"C"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[3]', '"C"', 'STRING')
ghidra_trace_set_snap(10) ghidra_trace_new_snap(time=Schedule(10))
ghidra_trace_retain_values('Test.Objects[1]', '[1] [3]') ghidra_trace_retain_values('Test.Objects[1]', '[1] [3]')
ghidra_trace_txcommit() ghidra_trace_txcommit()
ghidra_trace_kill() ghidra_trace_kill()
@ -770,7 +772,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
TraceObject object = tb.trace.getObjectManager() TraceObject object = tb.trace.getObjectManager()
.getObjectByCanonicalPath(KeyPath.parse("Test.Objects[1]")); .getObjectByCanonicalPath(KeyPath.parse("Test.Objects[1]"));
assertNotNull(object); assertNotNull(object);
assertEquals("1\tTest.Objects[1]", extractOutSection(out, "---GetObject---")); assertEquals("3\tTest.Objects[1]", extractOutSection(out, "---GetObject---"));
} }
} }
@ -779,7 +781,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
@ -840,7 +842,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
@ -866,7 +868,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
#set language c++ #set language c++
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
@ -888,7 +890,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
String out = runThrowError(addr -> """ String out = runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
pc = util.get_pc() pc = util.get_pc()
ghidra_trace_putmem(pc, 16) ghidra_trace_putmem(pc, 16)
@ -950,7 +952,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
Collection<TraceObject> available = tb.trace.getObjectManager() Collection<TraceObject> available = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), PathFilter.parse("Available[]")) .getValuePaths(Lifespan.at(0), PathFilter.parse("Sessions[].Available[]"))
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
.toList(); .toList();
assertThat(available.size(), greaterThan(2)); assertThat(available.size(), greaterThan(2));
@ -962,7 +964,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
pc = util.get_pc() pc = util.get_pc()
util.dbg.bp(expr=pc) util.dbg.bp(expr=pc)
util.dbg.ba(expr=pc+4) util.dbg.ba(expr=pc+4)
@ -976,7 +978,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
List<TraceObjectValue> procBreakLocVals = tb.trace.getObjectManager() List<TraceObjectValue> procBreakLocVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Breakpoints[]")) PathFilter.parse("Sessions[].Processes[].Debug.Breakpoints[]"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.sorted(Comparator.comparing(TraceObjectValue::getEntryKey)) .sorted(Comparator.comparing(TraceObjectValue::getEntryKey))
.toList(); .toList();
@ -999,7 +1001,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
pc = util.get_pc() pc = util.get_pc()
util.dbg.ba(expr=pc, access=DbgEng.DEBUG_BREAK_EXECUTE) util.dbg.ba(expr=pc, access=DbgEng.DEBUG_BREAK_EXECUTE)
@ -1014,7 +1016,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
List<TraceObjectValue> procBreakVals = tb.trace.getObjectManager() List<TraceObjectValue> procBreakVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Breakpoints[]")) PathFilter.parse("Sessions[].Processes[].Debug.Breakpoints[]"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.sorted(Comparator.comparing(TraceObjectValue::getEntryKey)) .sorted(Comparator.comparing(TraceObjectValue::getEntryKey))
.toList(); .toList();
@ -1043,7 +1045,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
ghidra_trace_put_environment() ghidra_trace_put_environment()
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -1054,7 +1056,8 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
// Assumes LLDB on Linux amd64 // Assumes LLDB on Linux amd64
TraceObject env = TraceObject env =
Objects.requireNonNull(tb.objAny("Processes[].Environment", Lifespan.at(0))); Objects.requireNonNull(
tb.objAny("Sessions[].Processes[].Environment", Lifespan.at(0)));
assertEquals("pydbg", env.getValue(0, "_debugger").getValue()); assertEquals("pydbg", env.getValue(0, "_debugger").getValue());
assertEquals("x86_64", env.getValue(0, "_arch").getValue()); assertEquals("x86_64", env.getValue(0, "_arch").getValue());
assertEquals("windows", env.getValue(0, "_os").getValue()); assertEquals("windows", env.getValue(0, "_os").getValue());
@ -1067,7 +1070,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
ghidra_trace_put_regions() ghidra_trace_put_regions()
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -1088,7 +1091,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
ghidra_trace_put_modules() ghidra_trace_put_modules()
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -1110,7 +1113,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
ghidra_trace_put_threads() ghidra_trace_put_threads()
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -1130,7 +1133,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra_trace_connect('%s') ghidra_trace_connect('%s')
ghidra_trace_create('notepad.exe') ghidra_trace_create('notepad.exe', wait=True)
ghidra_trace_txstart('Tx') ghidra_trace_txstart('Tx')
ghidra_trace_put_frames() ghidra_trace_put_frames()
ghidra_trace_txcommit() ghidra_trace_txcommit()
@ -1142,7 +1145,7 @@ public class DbgEngCommandsTest extends AbstractDbgEngTraceRmiTest {
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
List<TraceObject> stack = tb.trace.getObjectManager() List<TraceObject> stack = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[0].Threads[0].Stack[]")) PathFilter.parse("Sessions[0].Processes[0].Threads[0].Stack.Frames[]"))
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
.toList(); .toList();
assertThat(stack.size(), greaterThan(2)); assertThat(stack.size(), greaterThan(2));

View file

@ -86,31 +86,32 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
return conn.conn.connection().getLastSnapshot(tb.trace); return conn.conn.connection().getLastSnapshot(tb.trace);
} }
static final int INIT_NOTEPAD_THREAD_COUNT = 4; // This could be fragile
@Test @Test
public void testOnNewThread() throws Exception { public void testOnNewThread() throws Exception {
final int INIT_NOTEPAD_THREAD_COUNT = 4; // This could be fragile
try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) { try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) {
conn.execute("from ghidradbg.commands import *"); conn.execute("from ghidradbg.commands import *");
txPut(conn, "processes"); txPut(conn, "processes");
waitForPass(() -> { waitForPass(() -> {
TraceObject proc = tb.objAny0("Processes[]"); TraceObject proc = tb.objAny0("Sessions[].Processes[]");
assertNotNull(proc); assertNotNull(proc);
assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state")); assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state"));
}, RUN_TIMEOUT_MS, RETRY_MS); }, RUN_TIMEOUT_MS, RETRY_MS);
txPut(conn, "threads"); txPut(conn, "threads");
waitForPass(() -> assertEquals(INIT_NOTEPAD_THREAD_COUNT, waitForPass(() -> assertEquals(INIT_NOTEPAD_THREAD_COUNT,
tb.objValues(lastSnap(conn), "Processes[].Threads[]").size()), tb.objValues(lastSnap(conn), "Sessions[].Processes[].Threads[]").size()),
RUN_TIMEOUT_MS, RETRY_MS); RUN_TIMEOUT_MS, RETRY_MS);
// Via method, go is asynchronous // Via method, go is asynchronous
RemoteMethod go = conn.conn.getMethod("go"); RemoteMethod go = conn.conn.getMethod("go");
TraceObject proc = tb.objAny0("Processes[]"); TraceObject proc = tb.objAny0("Sessions[].Processes[]");
go.invoke(Map.of("process", proc)); go.invoke(Map.of("process", proc));
waitForPass( waitForPass(() -> assertThat(
() -> assertThat(tb.objValues(lastSnap(conn), "Processes[].Threads[]").size(), tb.objValues(lastSnap(conn), "Sessions[].Processes[].Threads[]").size(),
greaterThan(INIT_NOTEPAD_THREAD_COUNT)), greaterThan(INIT_NOTEPAD_THREAD_COUNT)),
RUN_TIMEOUT_MS, RETRY_MS); RUN_TIMEOUT_MS, RETRY_MS);
} }
@ -122,14 +123,14 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
txPut(conn, "processes"); txPut(conn, "processes");
waitForPass(() -> { waitForPass(() -> {
TraceObject proc = tb.obj("Processes[0]"); TraceObject proc = tb.obj("Sessions[0].Processes[0]");
assertNotNull(proc); assertNotNull(proc);
assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state")); assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state"));
}, RUN_TIMEOUT_MS, RETRY_MS); }, RUN_TIMEOUT_MS, RETRY_MS);
txPut(conn, "threads"); txPut(conn, "threads");
waitForPass(() -> assertEquals(4, waitForPass(() -> assertEquals(4,
tb.objValues(lastSnap(conn), "Processes[0].Threads[]").size()), tb.objValues(lastSnap(conn), "Sessions[0].Processes[0].Threads[]").size()),
RUN_TIMEOUT_MS, RETRY_MS); RUN_TIMEOUT_MS, RETRY_MS);
// Now the real test // Now the real test
@ -138,7 +139,8 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
waitForPass(() -> { waitForPass(() -> {
String tnum = conn.executeCapture("print(util.selected_thread())").strip(); String tnum = conn.executeCapture("print(util.selected_thread())").strip();
assertEquals("1", tnum); assertEquals("1", tnum);
assertEquals(tb.obj("Processes[0].Threads[1]"), traceManager.getCurrentObject()); String threadIndex = threadIndex(traceManager.getCurrentObject());
assertEquals("1", threadIndex);
}, RUN_TIMEOUT_MS, RETRY_MS); }, RUN_TIMEOUT_MS, RETRY_MS);
conn.execute("util.select_thread(2)"); conn.execute("util.select_thread(2)");
@ -182,11 +184,11 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
} }
protected String threadIndex(TraceObject object) { protected String threadIndex(TraceObject object) {
return getIndex(object, "Processes[].Threads[]", 1); return getIndex(object, "Sessions[].Processes[].Threads[]", 2);
} }
protected String frameIndex(TraceObject object) { protected String frameIndex(TraceObject object) {
return getIndex(object, "Processes[].Threads[].Stack[]", 2); return getIndex(object, "Sessions[].Processes[].Threads[].Stack.Frames[]", 3);
} }
@Test @Test
@ -247,7 +249,7 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
conn.execute("ghidra_trace_txcommit()"); conn.execute("ghidra_trace_txcommit()");
conn.execute("util.dbg.cmd('r rax=0x1234')"); conn.execute("util.dbg.cmd('r rax=0x1234')");
String path = "Processes[].Threads[].Registers"; String path = "Sessions[].Processes[].Threads[].Registers";
TraceObject registers = Objects.requireNonNull(tb.objAny(path, Lifespan.at(0))); TraceObject registers = Objects.requireNonNull(tb.objAny(path, Lifespan.at(0)));
AddressSpace space = tb.trace.getBaseAddressFactory() AddressSpace space = tb.trace.getBaseAddressFactory()
.getAddressSpace(registers.getCanonicalPath().toString()); .getAddressSpace(registers.getCanonicalPath().toString());
@ -272,7 +274,7 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
"""); """);
waitRunning("Missed running after go"); waitRunning("Missed running after go");
TraceObject proc = waitForValue(() -> tb.objAny0("Processes[]")); TraceObject proc = waitForValue(() -> tb.objAny0("Sessions[].Processes[]"));
waitForPass(() -> { waitForPass(() -> {
assertEquals("RUNNING", tb.objValue(proc, lastSnap(conn), "_state")); assertEquals("RUNNING", tb.objValue(proc, lastSnap(conn), "_state"));
}, RUN_TIMEOUT_MS, RETRY_MS); }, RUN_TIMEOUT_MS, RETRY_MS);
@ -284,7 +286,7 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) { try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) {
txPut(conn, "processes"); txPut(conn, "processes");
TraceObject proc = waitForValue(() -> tb.objAny0("Processes[]")); TraceObject proc = waitForValue(() -> tb.objAny0("Sessions[].Processes[]"));
waitForPass(() -> { waitForPass(() -> {
assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state")); assertEquals("STOPPED", tb.objValue(proc, lastSnap(conn), "_state"));
}, RUN_TIMEOUT_MS, RETRY_MS); }, RUN_TIMEOUT_MS, RETRY_MS);
@ -306,7 +308,7 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
assertNotNull(snapshot); assertNotNull(snapshot);
assertEquals("Exited with code 0", snapshot.getDescription()); assertEquals("Exited with code 0", snapshot.getDescription());
TraceObject proc = tb.objAny0("Processes[]"); TraceObject proc = tb.objAny0("Sessions[].Processes[]");
assertNotNull(proc); assertNotNull(proc);
Object val = tb.objValue(proc, lastSnap(conn), "_exit_code"); Object val = tb.objValue(proc, lastSnap(conn), "_exit_code");
assertThat(val, instanceOf(Number.class)); assertThat(val, instanceOf(Number.class));
@ -319,13 +321,15 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
public void testOnBreakpointCreated() throws Exception { public void testOnBreakpointCreated() throws Exception {
try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) { try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) {
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
assertEquals(0, tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]").size()); assertEquals(0,
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]").size());
conn.execute("pc = util.get_pc()"); conn.execute("pc = util.get_pc()");
conn.execute("util.dbg.bp(expr=pc)"); conn.execute("util.dbg.bp(expr=pc)");
waitForPass(() -> { waitForPass(() -> {
List<Object> brks = tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]"); List<Object> brks =
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]");
assertEquals(1, brks.size()); assertEquals(1, brks.size());
}); });
} }
@ -335,13 +339,15 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
public void testOnBreakpointModified() throws Exception { public void testOnBreakpointModified() throws Exception {
try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) { try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) {
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
assertEquals(0, tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]").size()); assertEquals(0,
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]").size());
conn.execute("pc = util.get_pc()"); conn.execute("pc = util.get_pc()");
conn.execute("util.dbg.bp(expr=pc)"); conn.execute("util.dbg.bp(expr=pc)");
TraceObject brk = waitForPass(() -> { TraceObject brk = waitForPass(() -> {
List<Object> brks = tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]"); List<Object> brks =
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]");
assertEquals(1, brks.size()); assertEquals(1, brks.size());
return (TraceObject) brks.get(0); return (TraceObject) brks.get(0);
}); });
@ -362,13 +368,15 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
public void testOnBreakpointDeleted() throws Exception { public void testOnBreakpointDeleted() throws Exception {
try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) { try (PythonAndTrace conn = startAndSyncPython("notepad.exe")) {
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
assertEquals(0, tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]").size()); assertEquals(0,
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]").size());
conn.execute("pc = util.get_pc()"); conn.execute("pc = util.get_pc()");
conn.execute("util.dbg.bp(expr=pc)"); conn.execute("util.dbg.bp(expr=pc)");
TraceObject brk = waitForPass(() -> { TraceObject brk = waitForPass(() -> {
List<Object> brks = tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]"); List<Object> brks =
tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]");
assertEquals(1, brks.size()); assertEquals(1, brks.size());
return (TraceObject) brks.get(0); return (TraceObject) brks.get(0);
}); });
@ -380,14 +388,14 @@ public class DbgEngHooksTest extends AbstractDbgEngTraceRmiTest {
conn.execute("util.dbg.cmd('bc %s')".formatted(id)); conn.execute("util.dbg.cmd('bc %s')".formatted(id));
waitForPass(() -> assertEquals(0, waitForPass(() -> assertEquals(0,
tb.objValues(lastSnap(conn), "Processes[].Breakpoints[]").size())); tb.objValues(lastSnap(conn), "Sessions[].Processes[].Debug.Breakpoints[]").size()));
} }
} }
private void start(PythonAndConnection conn, String obj) { private void start(PythonAndConnection conn, String obj) {
conn.execute("from ghidradbg.commands import *"); conn.execute("from ghidradbg.commands import *");
if (obj != null) if (obj != null)
conn.execute("ghidra_trace_create('" + obj + "')"); conn.execute("ghidra_trace_create('" + obj + "', wait=True)");
else else
conn.execute("ghidra_trace_create()"); conn.execute("ghidra_trace_create()");
conn.execute("ghidra_trace_sync_enable()"); conn.execute("ghidra_trace_sync_enable()");

View file

@ -19,11 +19,18 @@ import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import java.io.File;
import java.util.*; import java.util.*;
import org.hamcrest.Matchers;
import org.junit.AssumptionViolatedException;
import org.junit.Test; import org.junit.Test;
import db.Transaction;
import generic.Unique; import generic.Unique;
import ghidra.app.plugin.core.debug.gui.control.DebuggerMethodActionsPlugin;
import ghidra.app.plugin.core.debug.gui.model.DebuggerModelPlugin;
import ghidra.app.plugin.core.debug.gui.time.DebuggerTimePlugin;
import ghidra.app.plugin.core.debug.utils.ManagedDomainObject; import ghidra.app.plugin.core.debug.utils.ManagedDomainObject;
import ghidra.debug.api.tracermi.RemoteMethod; import ghidra.debug.api.tracermi.RemoteMethod;
import ghidra.program.model.address.*; import ghidra.program.model.address.*;
@ -37,18 +44,27 @@ import ghidra.trace.model.memory.TraceMemoryRegion;
import ghidra.trace.model.memory.TraceMemorySpace; import ghidra.trace.model.memory.TraceMemorySpace;
import ghidra.trace.model.modules.TraceModule; import ghidra.trace.model.modules.TraceModule;
import ghidra.trace.model.target.TraceObject; import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.TraceObject.ConflictResolution;
import ghidra.trace.model.target.TraceObjectValue; import ghidra.trace.model.target.TraceObjectValue;
import ghidra.trace.model.target.path.PathFilter; import ghidra.trace.model.target.path.*;
import ghidra.trace.model.target.path.PathPattern; import ghidra.trace.model.time.TraceSnapshot;
import ghidra.trace.model.time.schedule.TraceSchedule;
public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest { public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testEvaluate() throws Exception { public void testEvaluate() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
start(conn, null);
RemoteMethod evaluate = conn.getMethod("evaluate"); RemoteMethod evaluate = conn.getMethod("evaluate");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get());
assertEquals("11", assertEquals("11",
evaluate.invoke(Map.of("expr", "3+4*2"))); evaluate.invoke(Map.ofEntries(
Map.entry("session", tb.obj("Sessions[0]")),
Map.entry("expr", "3+4*2"))));
}
} }
} }
@ -77,18 +93,19 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
public void testRefreshAvailable() throws Exception { public void testRefreshAvailable() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
start(conn, null); start(conn, null);
txCreate(conn, "Available"); // Fake its creation, so it's empty before the refresh
txCreate(conn, "Sessions[0].Available");
RemoteMethod refreshAvailable = conn.getMethod("refresh_available"); RemoteMethod refreshAvailable = conn.getMethod("refresh_available");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject available = Objects.requireNonNull(tb.objAny0("Available")); TraceObject available = Objects.requireNonNull(tb.objAny0("Sessions[].Available"));
refreshAvailable.invoke(Map.of("node", available)); refreshAvailable.invoke(Map.of("node", available));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
List<TraceObject> list = tb.trace.getObjectManager() List<TraceObject> list = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), PathFilter.parse("Available[]")) .getValuePaths(Lifespan.at(0), PathFilter.parse("Sessions[].Available[]"))
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
.toList(); .toList();
assertThat(list.size(), greaterThan(2)); assertThat(list.size(), greaterThan(2));
@ -111,12 +128,12 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
conn.execute("util.dbg.ba(expr=pc+4)"); conn.execute("util.dbg.ba(expr=pc+4)");
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
TraceObject breakpoints = TraceObject breakpoints =
Objects.requireNonNull(tb.objAny0("Processes[].Breakpoints")); Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Debug.Breakpoints"));
refreshBreakpoints.invoke(Map.of("node", breakpoints)); refreshBreakpoints.invoke(Map.of("node", breakpoints));
List<TraceObjectValue> procBreakLocVals = tb.trace.getObjectManager() List<TraceObjectValue> procBreakLocVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Breakpoints[]")) PathFilter.parse("Sessions[].Processes[].Debug.Breakpoints[]"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.sorted(Comparator.comparing(TraceObjectValue::getEntryKey)) .sorted(Comparator.comparing(TraceObjectValue::getEntryKey))
.toList(); .toList();
@ -150,12 +167,12 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
conn.execute("util.dbg.ba(expr=pc+4, access=DbgEng.DEBUG_BREAK_READ)"); conn.execute("util.dbg.ba(expr=pc+4, access=DbgEng.DEBUG_BREAK_READ)");
conn.execute("util.dbg.ba(expr=pc+8, access=DbgEng.DEBUG_BREAK_WRITE)"); conn.execute("util.dbg.ba(expr=pc+8, access=DbgEng.DEBUG_BREAK_WRITE)");
TraceObject locations = TraceObject locations =
Objects.requireNonNull(tb.objAny0("Processes[].Breakpoints")); Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Debug.Breakpoints"));
refreshProcWatchpoints.invoke(Map.of("node", locations)); refreshProcWatchpoints.invoke(Map.of("node", locations));
List<TraceObjectValue> procBreakVals = tb.trace.getObjectManager() List<TraceObjectValue> procBreakVals = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Breakpoints[]")) PathFilter.parse("Sessions[].Processes[].Debug.Breakpoints[]"))
.map(p -> p.getLastEntry()) .map(p -> p.getLastEntry())
.sorted(Comparator.comparing(TraceObjectValue::getEntryKey)) .sorted(Comparator.comparing(TraceObjectValue::getEntryKey))
.toList(); .toList();
@ -186,20 +203,19 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshProcesses() throws Exception { public void testRefreshProcesses() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
start(conn, null); start(conn, "notepad.exe");
txCreate(conn, "Processes"); txCreate(conn, "Sessions[0].Processes");
txCreate(conn, "Processes[1]");
RemoteMethod refreshProcesses = conn.getMethod("refresh_processes"); RemoteMethod refreshProcesses = conn.getMethod("refresh_processes");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject processes = Objects.requireNonNull(tb.objAny0("Processes")); TraceObject processes = Objects.requireNonNull(tb.objAny0("Sessions[].Processes"));
refreshProcesses.invoke(Map.of("node", processes)); refreshProcesses.invoke(Map.of("node", processes));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
List<TraceObject> list = tb.trace.getObjectManager() List<TraceObject> list = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), PathFilter.parse("Processes[]")) .getValuePaths(Lifespan.at(0), PathFilter.parse("Sessions[].Processes[]"))
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
.toList(); .toList();
assertEquals(1, list.size()); assertEquals(1, list.size());
@ -210,14 +226,14 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshEnvironment() throws Exception { public void testRefreshEnvironment() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Environment";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
txPut(conn, "all"); txPut(conn, "all");
RemoteMethod refreshEnvironment = conn.getMethod("refresh_environment"); RemoteMethod refreshEnvironment = conn.getMethod("refresh_environment");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject env = Objects.requireNonNull(tb.objAny0(path)); TraceObject env =
Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Environment"));
refreshEnvironment.invoke(Map.of("node", env)); refreshEnvironment.invoke(Map.of("node", env));
@ -233,15 +249,15 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshThreads() throws Exception { public void testRefreshThreads() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Threads";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
txCreate(conn, path); txPut(conn, "processes");
RemoteMethod refreshThreads = conn.getMethod("refresh_threads"); RemoteMethod refreshThreads = conn.getMethod("refresh_threads");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject threads = Objects.requireNonNull(tb.objAny0(path));
TraceObject proc = tb.objAny0("Sessions[].Processes[]");
TraceObject threads = fakeEmpty(proc, "Threads");
refreshThreads.invoke(Map.of("node", threads)); refreshThreads.invoke(Map.of("node", threads));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
@ -254,7 +270,6 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshStack() throws Exception { public void testRefreshStack() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Threads[].Stack";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
txPut(conn, "processes"); txPut(conn, "processes");
@ -263,13 +278,14 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
txPut(conn, "frames"); txPut(conn, "frames");
TraceObject stack = Objects.requireNonNull(tb.objAny0(path)); TraceObject stack = Objects.requireNonNull(
tb.objAny0("Sessions[].Processes[].Threads[].Stack.Frames"));
refreshStack.invoke(Map.of("node", stack)); refreshStack.invoke(Map.of("node", stack));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
List<TraceObject> list = tb.trace.getObjectManager() List<TraceObject> list = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), .getValuePaths(Lifespan.at(0),
PathFilter.parse("Processes[].Threads[].Stack[]")) PathFilter.parse("Sessions[].Processes[].Threads[].Stack.Frames[]"))
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
.toList(); .toList();
assertTrue(list.size() > 1); assertTrue(list.size() > 1);
@ -280,7 +296,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshRegisters() throws Exception { public void testRefreshRegisters() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Threads[].Registers"; String path = "Sessions[].Processes[].Threads[].Registers";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
conn.execute("ghidra_trace_txstart('Tx')"); conn.execute("ghidra_trace_txstart('Tx')");
conn.execute("ghidra_trace_putreg()"); conn.execute("ghidra_trace_putreg()");
@ -308,15 +324,15 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
@Test @Test
public void testRefreshMappings() throws Exception { public void testRefreshMappings() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Memory";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
txCreate(conn, path); txPut(conn, "processes");
RemoteMethod refreshMappings = conn.getMethod("refresh_mappings"); RemoteMethod refreshMappings = conn.getMethod("refresh_mappings");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject memory = Objects.requireNonNull(tb.objAny0(path));
TraceObject proc = tb.objAny0("Sessions[].Processes[]");
TraceObject memory = fakeEmpty(proc, "Memory");
refreshMappings.invoke(Map.of("node", memory)); refreshMappings.invoke(Map.of("node", memory));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
@ -327,18 +343,28 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
} }
} }
protected TraceObject fakeEmpty(TraceObject parent, String ext) {
KeyPath path = parent.getCanonicalPath().extend(KeyPath.parse(ext));
Trace trace = parent.getTrace();
try (Transaction tx = trace.openTransaction("Fake %s".formatted(path))) {
TraceObject obj = trace.getObjectManager().createObject(path);
obj.insert(parent.getLife().bound(), ConflictResolution.DENY);
return obj;
}
}
@Test @Test
public void testRefreshModules() throws Exception { public void testRefreshModules() throws Exception {
try (PythonAndConnection conn = startAndConnectPython()) { try (PythonAndConnection conn = startAndConnectPython()) {
String path = "Processes[].Modules";
start(conn, "notepad.exe"); start(conn, "notepad.exe");
txCreate(conn, path); txPut(conn, "processes");
RemoteMethod refreshModules = conn.getMethod("refresh_modules"); RemoteMethod refreshModules = conn.getMethod("refresh_modules");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject modules = Objects.requireNonNull(tb.objAny0(path));
TraceObject proc = tb.objAny0("Sessions[].Processes[]");
TraceObject modules = fakeEmpty(proc, "Modules");
refreshModules.invoke(Map.of("node", modules)); refreshModules.invoke(Map.of("node", modules));
// Would be nice to control / validate the specifics // Would be nice to control / validate the specifics
@ -363,7 +389,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
txPut(conn, "threads"); txPut(conn, "threads");
PathPattern pattern = PathPattern pattern =
PathFilter.parse("Processes[].Threads[]").getSingletonPattern(); PathFilter.parse("Sessions[].Processes[].Threads[]").getSingletonPattern();
List<TraceObject> list = tb.trace.getObjectManager() List<TraceObject> list = tb.trace.getObjectManager()
.getValuePaths(Lifespan.at(0), pattern) .getValuePaths(Lifespan.at(0), pattern)
.map(p -> p.getDestination(null)) .map(p -> p.getDestination(null))
@ -374,7 +400,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
activateThread.invoke(Map.of("thread", t)); activateThread.invoke(Map.of("thread", t));
String out = conn.executeCapture("print(util.dbg.get_thread())").strip(); String out = conn.executeCapture("print(util.dbg.get_thread())").strip();
List<String> indices = pattern.matchKeys(t.getCanonicalPath(), true); List<String> indices = pattern.matchKeys(t.getCanonicalPath(), true);
assertEquals("%s".formatted(indices.get(1)), out); assertEquals("%s".formatted(indices.get(2)), out);
} }
} }
} }
@ -390,7 +416,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/netstat.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/netstat.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject proc2 = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc2 = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
removeProcess.invoke(Map.of("process", proc2)); removeProcess.invoke(Map.of("process", proc2));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
@ -409,9 +435,10 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
RemoteMethod attachObj = conn.getMethod("attach_obj"); RemoteMethod attachObj = conn.getMethod("attach_obj");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject target = TraceObject target = Objects.requireNonNull(tb.obj(
Objects.requireNonNull(tb.obj("Available[%d]".formatted(dproc.pid))); "Sessions[0].Available[%d]".formatted(dproc.pid)));
attachObj.invoke(Map.of("target", target)); attachObj.invoke(Map.ofEntries(
Map.entry("target", target)));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
assertThat(out, containsString("%d".formatted(dproc.pid))); assertThat(out, containsString("%d".formatted(dproc.pid)));
@ -430,9 +457,11 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
RemoteMethod attachPid = conn.getMethod("attach_pid"); RemoteMethod attachPid = conn.getMethod("attach_pid");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/noname")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
Objects.requireNonNull( Objects.requireNonNull(tb.obj(
tb.objAny("Available[" + dproc.pid + "]", Lifespan.at(0))); "Sessions[0].Available[%d]".formatted(dproc.pid)));
attachPid.invoke(Map.of("pid", dproc.pid)); attachPid.invoke(Map.ofEntries(
Map.entry("session", tb.obj("Sessions[0]")),
Map.entry("pid", dproc.pid)));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
assertThat(out, containsString("%d".formatted(dproc.pid))); assertThat(out, containsString("%d".formatted(dproc.pid)));
@ -451,7 +480,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/netstat.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/netstat.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
detach.invoke(Map.of("process", proc)); detach.invoke(Map.of("process", proc));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
@ -471,7 +500,9 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
launch.invoke(Map.ofEntries( launch.invoke(Map.ofEntries(
Map.entry("file", "notepad.exe"))); Map.entry("session", tb.obj("Sessions[0]")),
Map.entry("file", "notepad.exe"),
Map.entry("wait", true)));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
assertThat(out, containsString("notepad.exe")); assertThat(out, containsString("notepad.exe"));
@ -490,8 +521,10 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
launch.invoke(Map.ofEntries( launch.invoke(Map.ofEntries(
Map.entry("session", tb.obj("Sessions[0]")),
Map.entry("initial_break", true), Map.entry("initial_break", true),
Map.entry("file", "notepad.exe"))); Map.entry("file", "notepad.exe"),
Map.entry("wait", true)));
txPut(conn, "processes"); txPut(conn, "processes");
@ -512,7 +545,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
kill.invoke(Map.of("process", proc)); kill.invoke(Map.of("process", proc));
String out = conn.executeCapture("print(list(util.process_list()))"); String out = conn.executeCapture("print(list(util.process_list()))");
@ -535,7 +568,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
go.invoke(Map.of("process", proc)); go.invoke(Map.of("process", proc));
@ -561,7 +594,8 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
txPut(conn, "threads"); txPut(conn, "threads");
TraceObject thread = Objects.requireNonNull(tb.objAny0("Processes[].Threads[]")); TraceObject thread =
Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Threads[]"));
while (!getInst(conn).contains("call")) { while (!getInst(conn).contains("call")) {
stepInto.invoke(Map.of("thread", thread)); stepInto.invoke(Map.of("thread", thread));
@ -595,7 +629,8 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
txPut(conn, "threads"); txPut(conn, "threads");
TraceObject thread = Objects.requireNonNull(tb.objAny0("Processes[].Threads[]")); TraceObject thread =
Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Threads[]"));
while (!getInst(conn).contains("call")) { while (!getInst(conn).contains("call")) {
stepOver.invoke(Map.of("thread", thread)); stepOver.invoke(Map.of("thread", thread));
@ -623,7 +658,8 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
txPut(conn, "threads"); txPut(conn, "threads");
TraceObject thread = Objects.requireNonNull(tb.objAny0("Processes[].Threads[]")); TraceObject thread =
Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Threads[]"));
while (!getInst(conn).contains("call")) { while (!getInst(conn).contains("call")) {
stepInto.invoke(Map.of("thread", thread)); stepInto.invoke(Map.of("thread", thread));
} }
@ -635,7 +671,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
} }
long pcNext = getAddressAtOffset(conn, sz); long pcNext = getAddressAtOffset(conn, sz);
stepTo.invoke(Map.of("thread", thread, "address", tb.addr(pcNext), "max", 10)); stepTo.invoke(Map.of("thread", thread, "address", tb.addr(pcNext), "max", 10L));
long pc = getAddressAtOffset(conn, 0); long pc = getAddressAtOffset(conn, 0);
assertEquals(pcNext, pc); assertEquals(pcNext, pc);
@ -656,7 +692,8 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
txPut(conn, "threads"); txPut(conn, "threads");
TraceObject thread = Objects.requireNonNull(tb.objAny0("Processes[].Threads[]")); TraceObject thread =
Objects.requireNonNull(tb.objAny0("Sessions[].Processes[].Threads[]"));
while (!getInst(conn).contains("call")) { while (!getInst(conn).contains("call")) {
stepInto.invoke(Map.of("thread", thread)); stepInto.invoke(Map.of("thread", thread));
@ -683,7 +720,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address))); breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address)));
@ -723,7 +760,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) { try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/notepad.exe")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address))); breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address)));
@ -764,7 +801,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
AddressRange range = tb.range(address, address + 3); // length 4 AddressRange range = tb.range(address, address + 3); // length 4
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
@ -809,7 +846,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
AddressRange range = tb.range(address, address + 3); // length 4 AddressRange range = tb.range(address, address + 3); // length 4
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
@ -854,7 +891,7 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
waitStopped("Missed initial stop"); waitStopped("Missed initial stop");
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
AddressRange range = tb.range(address, address + 3); // length 4 AddressRange range = tb.range(address, address + 3); // length 4
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
@ -900,11 +937,12 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address))); breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address)));
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
TraceObject bpt = Objects.requireNonNull(tb.objAny0("Processes[].Breakpoints[]")); TraceObject bpt = Objects
.requireNonNull(tb.objAny0("Sessions[].Processes[].Debug.Breakpoints[]"));
toggleBreakpoint.invoke(Map.of("breakpoint", bpt, "enabled", false)); toggleBreakpoint.invoke(Map.of("breakpoint", bpt, "enabled", false));
@ -926,11 +964,12 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
tb = new ToyDBTraceBuilder((Trace) mdo.get()); tb = new ToyDBTraceBuilder((Trace) mdo.get());
long address = getAddressAtOffset(conn, 0); long address = getAddressAtOffset(conn, 0);
TraceObject proc = Objects.requireNonNull(tb.objAny0("Processes[]")); TraceObject proc = Objects.requireNonNull(tb.objAny0("Sessions[].Processes[]"));
breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address))); breakAddress.invoke(Map.of("process", proc, "address", tb.addr(address)));
txPut(conn, "breakpoints"); txPut(conn, "breakpoints");
TraceObject bpt = Objects.requireNonNull(tb.objAny0("Processes[].Breakpoints[]")); TraceObject bpt = Objects
.requireNonNull(tb.objAny0("Sessions[].Processes[].Debug.Breakpoints[]"));
deleteBreakpoint.invoke(Map.of("breakpoint", bpt)); deleteBreakpoint.invoke(Map.of("breakpoint", bpt));
@ -940,22 +979,115 @@ public class DbgEngMethodsTest extends AbstractDbgEngTraceRmiTest {
} }
} }
protected static final File TRACE_RUN_FILE = new File("C:\\TTD_Testing\\cmd01.run");
/**
* Tracing with the TTD.exe utility (or WinDbg for that matter) requires Administrative
* privileges, which we cannot assume we have. Likely, we should assume we DO NOT have those.
* Thus, it is up to the person running the tests to ensure the required trace output exists.
* Follow the directions on MSDN to install the TTD.exe command-line utility, then issue the
* following in an Administrator command prompt:
*
* <pre>
* C:
* cd \TTD_Testing
* ttd -launch cmd /c exit
* </pre>
*
* You may need to set ownership and/or permissions on the output to ensure the tests can read
* it. You'll also need to install/copy the dbgeng.dll and related files to support TTD into
* C:\TTD_Testing.
*/
protected void createMsTtdTrace() {
// Can't actually do anything as standard user. Just check and ignore if missing
if (!TRACE_RUN_FILE.exists()) {
throw new AssumptionViolatedException(TRACE_RUN_FILE + " does not exist");
}
assertTrue("Cannot read " + TRACE_RUN_FILE, TRACE_RUN_FILE.canRead());
}
@Test
public void testTtdOpenTrace() throws Exception {
createMsTtdTrace();
try (PythonAndConnection conn = startAndConnectPython()) {
openTtdTrace(conn);
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/cmd01.run")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get());
}
}
}
@Test
public void testTtdActivateFrame() throws Exception {
addPlugin(tool, DebuggerModelPlugin.class);
addPlugin(tool, DebuggerMethodActionsPlugin.class);
addPlugin(tool, DebuggerTimePlugin.class);
createMsTtdTrace();
try (PythonAndConnection conn = startAndConnectPython()) {
openTtdTrace(conn);
txPut(conn, "frames");
txPut(conn, "events");
RemoteMethod activate = conn.getMethod("activate_frame");
try (ManagedDomainObject mdo = openDomainObject("/New Traces/pydbg/cmd01.run")) {
tb = new ToyDBTraceBuilder((Trace) mdo.get());
traceManager.openTrace(tb.trace);
traceManager.activateTrace(tb.trace);
TraceSnapshot init =
tb.trace.getTimeManager().getSnapshot(traceManager.getCurrentSnap(), false);
assertThat(init.getDescription(), Matchers.containsString("ThreadCreated"));
TraceObject frame0 = tb.objAny0("Sessions[].Processes[].Threads[].Stack.Frames[]");
TraceSchedule time = TraceSchedule.snap(init.getKey() + 1);
activate.invoke(Map.ofEntries(
Map.entry("frame", frame0),
Map.entry("time", time.toString())));
assertEquals(time, traceManager.getCurrent().getTime());
}
}
}
private void start(PythonAndConnection conn, String obj) { private void start(PythonAndConnection conn, String obj) {
conn.execute("from ghidradbg.commands import *"); conn.execute("from ghidradbg.commands import *");
if (obj != null) if (obj != null)
conn.execute("ghidra_trace_create('" + obj + "')"); conn.execute("ghidra_trace_create('%s', wait=True)".formatted(obj));
else else
conn.execute("ghidra_trace_create()"); conn.execute("ghidra_trace_create()");
} }
private void openTtdTrace(PythonAndConnection conn) {
/**
* NOTE: dbg.wait() must precede sync_enable() or else the PROC_STATE will have the wrong
* PID, and later events will all get snuffed.
*/
conn.execute("""
import os
from ghidradbg.commands import *
from ghidradbg.util import dbg
os.environ['USE_TTD'] = 'true'
dbg.IS_TRACE = True
os.environ['OPT_USE_DBGMODEL'] = 'true'
dbg.use_generics = True
ghidra_trace_open(r'%s', start_trace=False)
dbg.wait()
ghidra_trace_start(r'%s')
ghidra_trace_sync_enable()
""".formatted(TRACE_RUN_FILE, TRACE_RUN_FILE));
}
private void txPut(PythonAndConnection conn, String obj) { private void txPut(PythonAndConnection conn, String obj) {
conn.execute("ghidra_trace_txstart('Tx')"); conn.execute("ghidra_trace_txstart('Tx-put %s')".formatted(obj));
conn.execute("ghidra_trace_put_" + obj + "()"); conn.execute("ghidra_trace_put_%s()".formatted(obj));
conn.execute("ghidra_trace_txcommit()"); conn.execute("ghidra_trace_txcommit()");
} }
private void txCreate(PythonAndConnection conn, String path) { private void txCreate(PythonAndConnection conn, String path) {
conn.execute("ghidra_trace_txstart('Fake')"); conn.execute("ghidra_trace_txstart('Fake %s')".formatted(path));
conn.execute("ghidra_trace_create_obj('%s')".formatted(path)); conn.execute("ghidra_trace_create_obj('%s')".formatted(path));
conn.execute("ghidra_trace_txcommit()"); conn.execute("ghidra_trace_txcommit()");
} }

View file

@ -31,6 +31,7 @@ import java.util.function.*;
import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass;
import generic.jar.ResourceFile; import generic.jar.ResourceFile;
import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest; import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest;
@ -46,6 +47,7 @@ import ghidra.framework.plugintool.PluginsConfiguration;
import ghidra.framework.plugintool.util.*; import ghidra.framework.plugintool.util.*;
import ghidra.pty.testutil.DummyProc; import ghidra.pty.testutil.DummyProc;
import ghidra.util.Msg; import ghidra.util.Msg;
import ghidra.util.SystemUtilities;
import junit.framework.AssertionFailedError; import junit.framework.AssertionFailedError;
public abstract class AbstractDrgnTraceRmiTest extends AbstractGhidraHeadedDebuggerTest { public abstract class AbstractDrgnTraceRmiTest extends AbstractGhidraHeadedDebuggerTest {
@ -81,14 +83,18 @@ public abstract class AbstractDrgnTraceRmiTest extends AbstractGhidraHeadedDebug
assumeTrue(OperatingSystem.CURRENT_OPERATING_SYSTEM == OperatingSystem.LINUX); assumeTrue(OperatingSystem.CURRENT_OPERATING_SYSTEM == OperatingSystem.LINUX);
} }
//@BeforeClass @BeforeClass
public static void setupPython() throws Throwable { public static void setupPython() throws Throwable {
if (didSetupPython) { if (didSetupPython) {
// Only do this once when running the full suite. // Only do this once when running the full suite.
return; return;
} }
if (SystemUtilities.isInTestingBatchMode()) {
// Don't run gradle in gradle. It already did this task.
return;
}
String gradle = DummyProc.which("gradle"); String gradle = DummyProc.which("gradle");
new ProcessBuilder(gradle, "Debugger-agent-drgn:assemblePyPackage") new ProcessBuilder(gradle, "assemblePyPackage")
.directory(TestApplicationUtils.getInstallationDirectory()) .directory(TestApplicationUtils.getInstallationDirectory())
.inheritIO() .inheritIO()
.start() .start()

View file

@ -407,7 +407,7 @@ public class DrgnCommandsTest extends AbstractDrgnTraceRmiTest {
ghidra_trace_txstart('Create Object') ghidra_trace_txstart('Create Object')
ghidra_trace_create_obj('Test.Objects[1]') ghidra_trace_create_obj('Test.Objects[1]')
ghidra_trace_insert_obj('Test.Objects[1]') ghidra_trace_insert_obj('Test.Objects[1]')
ghidra_trace_set_snap(1) ghidra_trace_new_snap("Snap 1", time=1)
ghidra_trace_remove_obj('Test.Objects[1]') ghidra_trace_remove_obj('Test.Objects[1]')
ghidra_trace_txcommit() ghidra_trace_txcommit()
quit() quit()
@ -585,7 +585,7 @@ public class DrgnCommandsTest extends AbstractDrgnTraceRmiTest {
ghidra_trace_set_value('Test.Objects[1]', '[1]', '"A"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[1]', '"A"', 'STRING')
ghidra_trace_set_value('Test.Objects[1]', '[2]', '"B"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[2]', '"B"', 'STRING')
ghidra_trace_set_value('Test.Objects[1]', '[3]', '"C"', 'STRING') ghidra_trace_set_value('Test.Objects[1]', '[3]', '"C"', 'STRING')
ghidra_trace_set_snap(10) ghidra_trace_new_snap("Snap 10", time=10)
ghidra_trace_retain_values('Test.Objects[1]', '[1] [3]') ghidra_trace_retain_values('Test.Objects[1]', '[1] [3]')
ghidra_trace_txcommit() ghidra_trace_txcommit()
quit() quit()
@ -761,10 +761,7 @@ public class DrgnCommandsTest extends AbstractDrgnTraceRmiTest {
String extract = extractOutSection(out, "---Disassemble---"); String extract = extractOutSection(out, "---Disassemble---");
String[] split = extract.split("\r\n"); String[] split = extract.split("\r\n");
// NB: core.12137 has no memory // NB: core.12137 has no memory
//assertEquals("Disassembled %d bytes".formatted(total), assertEquals("Disassembled %d bytes".formatted(total), split[0]);
// split[0]);
assertEquals(0, total);
assertEquals("", split[0]);
} }
} }

View file

@ -31,6 +31,7 @@ import java.util.stream.Stream;
import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass;
import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest; import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest;
import ghidra.app.plugin.core.debug.service.tracermi.TraceRmiPlugin; import ghidra.app.plugin.core.debug.service.tracermi.TraceRmiPlugin;
@ -52,8 +53,7 @@ import ghidra.trace.model.breakpoint.TraceBreakpointKind.TraceBreakpointKindSet;
import ghidra.trace.model.target.TraceObject; import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.TraceObjectValue; import ghidra.trace.model.target.TraceObjectValue;
import ghidra.trace.model.target.path.KeyPath; import ghidra.trace.model.target.path.KeyPath;
import ghidra.util.Msg; import ghidra.util.*;
import ghidra.util.NumericUtilities;
public abstract class AbstractGdbTraceRmiTest extends AbstractGhidraHeadedDebuggerTest { public abstract class AbstractGdbTraceRmiTest extends AbstractGhidraHeadedDebuggerTest {
/** /**
@ -71,7 +71,7 @@ public abstract class AbstractGdbTraceRmiTest extends AbstractGhidraHeadedDebugg
"""; """;
// Connecting should be the first thing the script does, so use a tight timeout. // Connecting should be the first thing the script does, so use a tight timeout.
protected static final int CONNECT_TIMEOUT_MS = 3000; protected static final int CONNECT_TIMEOUT_MS = 3000;
protected static final int TIMEOUT_SECONDS = 300; protected static final int TIMEOUT_SECONDS = 10;
protected static final int QUIT_TIMEOUT_MS = 1000; protected static final int QUIT_TIMEOUT_MS = 1000;
public static final String INSTRUMENT_STOPPED = """ public static final String INSTRUMENT_STOPPED = """
ghidra trace tx-open "Fake" 'ghidra trace create-obj Inferiors[1]' ghidra trace tx-open "Fake" 'ghidra trace create-obj Inferiors[1]'
@ -95,18 +95,29 @@ public abstract class AbstractGdbTraceRmiTest extends AbstractGhidraHeadedDebugg
/** Some snapshot likely to exceed the latest */ /** Some snapshot likely to exceed the latest */
protected static final long SNAP = 100; protected static final long SNAP = 100;
protected static boolean didSetupPython = false;
protected TraceRmiService traceRmi; protected TraceRmiService traceRmi;
private Path gdbPath; private Path gdbPath;
private Path outFile; private Path outFile;
private Path errFile; private Path errFile;
// @BeforeClass @BeforeClass
public static void setupPython() throws Throwable { public static void setupPython() throws Throwable {
new ProcessBuilder("gradle", "Debugger-agent-gdb:assemblePyPackage") if (didSetupPython) {
// Only do this once when running the full suite.
return;
}
if (SystemUtilities.isInTestingBatchMode()) {
// Don't run gradle in gradle. It already did this task.
return;
}
new ProcessBuilder("gradle", "assemblePyPackage")
.directory(TestApplicationUtils.getInstallationDirectory()) .directory(TestApplicationUtils.getInstallationDirectory())
.inheritIO() .inheritIO()
.start() .start()
.waitFor(); .waitFor();
didSetupPython = true;
} }
protected void setPythonPath(ProcessBuilder pb) throws IOException { protected void setPythonPath(ProcessBuilder pb) throws IOException {

View file

@ -584,6 +584,7 @@ public class GdbCommandsTest extends AbstractGdbTraceRmiTest {
@Test @Test
public void testRemoveObj() throws Exception { public void testRemoveObj() throws Exception {
// Must give 1 for new-snap, since snap 0 was never created
runThrowError(addr -> """ runThrowError(addr -> """
%s %s
ghidra trace connect %s ghidra trace connect %s
@ -591,7 +592,7 @@ public class GdbCommandsTest extends AbstractGdbTraceRmiTest {
ghidra trace tx-start "Create Object" ghidra trace tx-start "Create Object"
ghidra trace create-obj Test.Objects[1] ghidra trace create-obj Test.Objects[1]
ghidra trace insert-obj Test.Objects[1] ghidra trace insert-obj Test.Objects[1]
ghidra trace set-snap 1 ghidra trace new-snap 1 "Snap 1"
ghidra trace remove-obj Test.Objects[1] ghidra trace remove-obj Test.Objects[1]
ghidra trace tx-commit ghidra trace tx-commit
quit quit
@ -779,7 +780,7 @@ public class GdbCommandsTest extends AbstractGdbTraceRmiTest {
ghidra trace set-value Test.Objects[1] [1] '"A"' ghidra trace set-value Test.Objects[1] [1] '"A"'
ghidra trace set-value Test.Objects[1] [2] '"B"' ghidra trace set-value Test.Objects[1] [2] '"B"'
ghidra trace set-value Test.Objects[1] [3] '"C"' ghidra trace set-value Test.Objects[1] [3] '"C"'
ghidra trace set-snap 10 ghidra trace new-snap 10 "Snap 10"
ghidra trace retain-values Test.Objects[1] [1] [3] ghidra trace retain-values Test.Objects[1] [1] [3]
ghidra trace tx-commit ghidra trace tx-commit
kill kill

View file

@ -275,6 +275,9 @@ public class GdbHooksTest extends AbstractGdbTraceRmiTest {
TraceMemorySpace regs = tb.trace.getMemoryManager().getMemorySpace(space, false); TraceMemorySpace regs = tb.trace.getMemoryManager().getMemorySpace(space, false);
waitForPass(() -> assertEquals("1234", waitForPass(() -> assertEquals("1234",
regs.getValue(lastSnap(conn), tb.reg("RAX")).getUnsignedValue().toString(16))); regs.getValue(lastSnap(conn), tb.reg("RAX")).getUnsignedValue().toString(16)));
assertEquals(List.of("0x1234"),
tb.objValues(lastSnap(conn), "Inferiors[1].Threads[1].Stack[0].Registers.rax"));
} }
} }

View file

@ -33,6 +33,7 @@ import org.apache.commons.io.output.TeeOutputStream;
import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.commons.lang3.exception.ExceptionUtils;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass;
import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest; import ghidra.app.plugin.core.debug.gui.AbstractGhidraHeadedDebuggerTest;
import ghidra.app.plugin.core.debug.service.tracermi.TraceRmiPlugin; import ghidra.app.plugin.core.debug.service.tracermi.TraceRmiPlugin;
@ -68,6 +69,8 @@ public abstract class AbstractLldbTraceRmiTest extends AbstractGhidraHeadedDebug
public static final PlatDep PLAT = computePlat(); public static final PlatDep PLAT = computePlat();
protected static boolean didSetupPython = false;
static PlatDep computePlat() { static PlatDep computePlat() {
return switch (System.getProperty("os.arch")) { return switch (System.getProperty("os.arch")) {
case "aarch64" -> PlatDep.ARM64; case "aarch64" -> PlatDep.ARM64;
@ -112,15 +115,22 @@ public abstract class AbstractLldbTraceRmiTest extends AbstractGhidraHeadedDebug
protected TraceRmiService traceRmi; protected TraceRmiService traceRmi;
private Path lldbPath; private Path lldbPath;
// @BeforeClass @BeforeClass
public static void setupPython() throws Throwable { public static void setupPython() throws Throwable {
new ProcessBuilder("gradle", if (didSetupPython) {
"Debugger-rmi-trace:assemblePyPackage", // Only do this once when running the full suite.
"Debugger-agent-lldb:assemblePyPackage") return;
}
if (SystemUtilities.isInTestingBatchMode()) {
// Don't run gradle in gradle. It already did this task.
return;
}
new ProcessBuilder("gradle", "assemblePyPackage")
.directory(TestApplicationUtils.getInstallationDirectory()) .directory(TestApplicationUtils.getInstallationDirectory())
.inheritIO() .inheritIO()
.start() .start()
.waitFor(); .waitFor();
didSetupPython = true;
} }
protected void setPythonPath(Map<String, String> env) throws IOException { protected void setPythonPath(Map<String, String> env) throws IOException {

View file

@ -587,7 +587,7 @@ public class LldbCommandsTest extends AbstractLldbTraceRmiTest {
ghidra trace tx-start "Create Object" ghidra trace tx-start "Create Object"
ghidra trace create-obj Test.Objects[1] ghidra trace create-obj Test.Objects[1]
ghidra trace insert-obj Test.Objects[1] ghidra trace insert-obj Test.Objects[1]
ghidra trace set-snap 1 ghidra trace new-snap 1 "Next"
ghidra trace remove-obj Test.Objects[1] ghidra trace remove-obj Test.Objects[1]
ghidra trace tx-commit ghidra trace tx-commit
kill kill
@ -805,7 +805,7 @@ public class LldbCommandsTest extends AbstractLldbTraceRmiTest {
ghidra trace set-value Test.Objects[1] [1] 10 ghidra trace set-value Test.Objects[1] [1] 10
ghidra trace set-value Test.Objects[1] [2] 20 ghidra trace set-value Test.Objects[1] [2] 20
ghidra trace set-value Test.Objects[1] [3] 30 ghidra trace set-value Test.Objects[1] [3] 30
ghidra trace set-snap 10 ghidra trace new-snap 10 "Snap 10"
ghidra trace retain-values Test.Objects[1] [1] [3] ghidra trace retain-values Test.Objects[1] [1] [3]
ghidra trace tx-commit ghidra trace tx-commit
kill kill

View file

@ -15,8 +15,7 @@
*/ */
package agent.lldb.rmi; package agent.lldb.rmi;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue; import static org.junit.Assume.assumeTrue;
@ -870,7 +869,9 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
assertThat(out, containsString("0x%x".formatted(address))); assertThat(out, anyOf(
containsString("0x%x".formatted(address)),
containsString("0x%08x".formatted(address))));
assertThat(out, containsString("size = 1")); assertThat(out, containsString("size = 1"));
assertThat(out, containsString("type = r")); assertThat(out, containsString("type = r"));
} }
@ -889,7 +890,7 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakExpression.invoke(Map.of( breakExpression.invoke(Map.of(
"expression", "`(void(*)())main`", "expression", "`(void(*)())main`",
"size", 1)); "size", "1"));
long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]); long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]);
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
@ -916,9 +917,13 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
assertThat(out, containsString("0x%x".formatted(address))); assertThat(out, anyOf(
containsString("0x%x".formatted(address)),
containsString("0x%08x".formatted(address))));
assertThat(out, containsString("size = 1")); assertThat(out, containsString("size = 1"));
assertThat(out, containsString("type = w")); assertThat(out, anyOf(
containsString("type = w"),
containsString("type = m")));
} }
} }
} }
@ -935,12 +940,14 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakExpression.invoke(Map.of( breakExpression.invoke(Map.of(
"expression", "`(void(*)())main`", "expression", "`(void(*)())main`",
"size", 1)); "size", "1"));
long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]); long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]);
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
assertThat(out, containsString(Long.toHexString(address))); assertThat(out, containsString(Long.toHexString(address)));
assertThat(out, containsString("type = w")); assertThat(out, anyOf(
containsString("type = w"),
containsString("type = m")));
} }
} }
} }
@ -962,7 +969,9 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakRange.invoke(Map.of("process", proc, "range", range)); breakRange.invoke(Map.of("process", proc, "range", range));
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
assertThat(out, containsString("0x%x".formatted(address))); assertThat(out, anyOf(
containsString("0x%x".formatted(address)),
containsString("0x%08x".formatted(address))));
assertThat(out, containsString("size = 1")); assertThat(out, containsString("size = 1"));
assertThat(out, containsString("type = rw")); assertThat(out, containsString("type = rw"));
} }
@ -981,7 +990,7 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakExpression.invoke(Map.of( breakExpression.invoke(Map.of(
"expression", "`(void(*)())main`", "expression", "`(void(*)())main`",
"size", 1)); "size", "1"));
long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]); long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]);
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");
@ -1094,7 +1103,7 @@ public class LldbMethodsTest extends AbstractLldbTraceRmiTest {
breakExpression.invoke(Map.of( breakExpression.invoke(Map.of(
"expression", "`(void(*)())main`", "expression", "`(void(*)())main`",
"size", 1)); "size", "1"));
long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]); long address = Long.decode(conn.executeCapture("dis -c1 -n main").split("\\s+")[1]);
String out = conn.executeCapture("watchpoint list"); String out = conn.executeCapture("watchpoint list");

View file

@ -56,6 +56,7 @@ public class AbstractGhidraHeadedDebuggerIntegrationTest
public static final SchemaContext SCHEMA_CTX = xmlSchema(""" public static final SchemaContext SCHEMA_CTX = xmlSchema("""
<context> <context>
<schema name='Session' elementResync='NEVER' attributeResync='ONCE'> <schema name='Session' elementResync='NEVER' attributeResync='ONCE'>
<interface name='EventScope' />
<attribute name='Processes' schema='ProcessContainer' /> <attribute name='Processes' schema='ProcessContainer' />
</schema> </schema>
<schema name='ProcessContainer' canonical='yes' elementResync='NEVER' <schema name='ProcessContainer' canonical='yes' elementResync='NEVER'
@ -220,6 +221,66 @@ public class AbstractGhidraHeadedDebuggerIntegrationTest
rmiCx.getMethods().add(rmiMethodActivateFrame); rmiCx.getMethods().add(rmiMethodActivateFrame);
} }
protected void addActivateWithSnapMethods() {
rmiMethodActivateProcess =
new TestRemoteMethod("activate_process", ActionName.ACTIVATE, "Activate Process",
"Activate a process", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("process", new SchemaName("Process"), true, null, "Process",
"The process to activate"),
new TestRemoteParameter("snap", PrimitiveTraceObjectSchema.LONG, false, null,
"Time", "The snapshot to activate"));
rmiMethodActivateThread =
new TestRemoteMethod("activate_thread", ActionName.ACTIVATE, "Activate Thread",
"Activate a thread", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("thread", new SchemaName("Thread"), true, null, "Thread",
"The thread to activate"),
new TestRemoteParameter("snap", PrimitiveTraceObjectSchema.LONG, false, null,
"Time", "The snapshot to activate"));
rmiMethodActivateFrame =
new TestRemoteMethod("activate_frame", ActionName.ACTIVATE, "Activate Frame",
"Activate a frame", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("frame", new SchemaName("Frame"), true, null, "Frame",
"The frame to activate"),
new TestRemoteParameter("snap", PrimitiveTraceObjectSchema.LONG, false, null,
"Time", "The snapshot to activate"));
rmiCx.getMethods().add(rmiMethodActivateProcess);
rmiCx.getMethods().add(rmiMethodActivateThread);
rmiCx.getMethods().add(rmiMethodActivateFrame);
}
protected void addActivateWithTimeMethods() {
rmiMethodActivateProcess =
new TestRemoteMethod("activate_process", ActionName.ACTIVATE, "Activate Process",
"Activate a process", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("process", new SchemaName("Process"), true, null, "Process",
"The process to activate"),
new TestRemoteParameter("time", PrimitiveTraceObjectSchema.STRING, false, null,
"Time", "The schedule to activate"));
rmiMethodActivateThread =
new TestRemoteMethod("activate_thread", ActionName.ACTIVATE, "Activate Thread",
"Activate a thread", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("thread", new SchemaName("Thread"), true, null, "Thread",
"The thread to activate"),
new TestRemoteParameter("time", PrimitiveTraceObjectSchema.STRING, false, null,
"Time", "The schedule to activate"));
rmiMethodActivateFrame =
new TestRemoteMethod("activate_frame", ActionName.ACTIVATE, "Activate Frame",
"Activate a frame", PrimitiveTraceObjectSchema.VOID,
new TestRemoteParameter("frame", new SchemaName("Frame"), true, null, "Frame",
"The frame to activate"),
new TestRemoteParameter("time", PrimitiveTraceObjectSchema.STRING, false, null,
"Time", "The schedule to activate"));
rmiCx.getMethods().add(rmiMethodActivateProcess);
rmiCx.getMethods().add(rmiMethodActivateThread);
rmiCx.getMethods().add(rmiMethodActivateFrame);
}
protected boolean activationMethodsQueuesEmpty() { protected boolean activationMethodsQueuesEmpty() {
return rmiMethodActivateProcess.argQueue().isEmpty() && return rmiMethodActivateProcess.argQueue().isEmpty() &&
rmiMethodActivateThread.argQueue().isEmpty() && rmiMethodActivateThread.argQueue().isEmpty() &&

View file

@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals;
import java.util.*; import java.util.*;
import org.junit.After;
import org.junit.Before; import org.junit.Before;
import db.Transaction; import db.Transaction;
@ -41,6 +42,23 @@ public class DebuggerRmiBreakpointsProviderTest
addBreakpointMethods(); addBreakpointMethods();
} }
@After
public void tearDownBreakpointTest() {
waitForTasks();
runSwing(() -> {
if (traceManager == null) {
return;
}
traceManager.setSaveTracesByDefault(false);
});
if (tb3 != null) {
if (traceManager != null && traceManager.getOpenTraces().contains(tb3.trace)) {
traceManager.closeTraceNoConfirm(tb3.trace);
}
tb3.close();
}
}
@Override @Override
protected TraceRmiTarget createTarget1() throws Throwable { protected TraceRmiTarget createTarget1() throws Throwable {
createTrace(); createTrace();

View file

@ -19,6 +19,7 @@ import static org.junit.Assert.*;
import java.util.*; import java.util.*;
import org.hamcrest.Matchers;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.experimental.categories.Category; import org.junit.experimental.categories.Category;
@ -37,12 +38,15 @@ import ghidra.trace.database.target.DBTraceObjectManagerTest;
import ghidra.trace.model.Lifespan; import ghidra.trace.model.Lifespan;
import ghidra.trace.model.Trace; import ghidra.trace.model.Trace;
import ghidra.trace.model.target.TraceObject; import ghidra.trace.model.target.TraceObject;
import ghidra.trace.model.target.iface.TraceObjectEventScope;
import ghidra.trace.model.target.path.KeyPath; import ghidra.trace.model.target.path.KeyPath;
import ghidra.trace.model.target.schema.SchemaContext; import ghidra.trace.model.target.schema.SchemaContext;
import ghidra.trace.model.target.schema.XmlSchemaContext;
import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName; import ghidra.trace.model.target.schema.TraceObjectSchema.SchemaName;
import ghidra.trace.model.target.schema.XmlSchemaContext;
import ghidra.trace.model.thread.TraceObjectThread; import ghidra.trace.model.thread.TraceObjectThread;
import ghidra.trace.model.thread.TraceThread; import ghidra.trace.model.thread.TraceThread;
import ghidra.trace.model.time.schedule.TraceSchedule;
import ghidra.trace.model.time.schedule.TraceSchedule.ScheduleForm;
@Category(NightlyCategory.class) // this may actually be an @PortSensitive test @Category(NightlyCategory.class) // this may actually be an @PortSensitive test
public class DebuggerTraceManagerServiceTest extends AbstractGhidraHeadedDebuggerIntegrationTest { public class DebuggerTraceManagerServiceTest extends AbstractGhidraHeadedDebuggerIntegrationTest {
@ -492,4 +496,194 @@ public class DebuggerTraceManagerServiceTest extends AbstractGhidraHeadedDebugge
// Focus should never be reflected back to target // Focus should never be reflected back to target
assertTrue(activationMethodsQueuesEmpty()); assertTrue(activationMethodsQueuesEmpty());
} }
@Test
public void testSynchronizeTimeTargetToGui() throws Throwable {
createRmiConnection();
addActivateWithTimeMethods();
createAndOpenTrace();
TraceObjectThread thread;
try (Transaction tx = tb.startTransaction()) {
tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION);
thread = tb.createObjectsProcessAndThreads();
tb.createObjectsFramesAndRegs(thread, Lifespan.nowOn(0), tb.host, 2);
}
rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
assertTrue(activationMethodsQueuesEmpty());
assertNull(traceManager.getCurrentTrace());
try (Transaction tx = tb.startTransaction()) {
rmiCx.setLastSnapshot(tb.trace, Long.MIN_VALUE)
.setSchedule(TraceSchedule.parse("0:10"));
}
rmiCx.synthActivate(tb.obj("Processes[1].Threads[1].Stack[0]"));
waitForSwing();
assertEquals(TraceSchedule.parse("0:10"), traceManager.getCurrent().getTime());
assertTrue(activationMethodsQueuesEmpty());
}
@Test
public void testTimeSupportNoTimeParam() throws Throwable {
createRmiConnection();
addActivateMethods();
createAndOpenTrace();
TraceObjectThread thread;
try (Transaction tx = tb.startTransaction()) {
tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION);
thread = tb.createObjectsProcessAndThreads();
}
Target target = rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
assertNull(target.getSupportedTimeForm(thread.getObject(), 0));
}
@Test
public void testTimeSupportSnapParam() throws Throwable {
createRmiConnection();
addActivateWithSnapMethods();
createAndOpenTrace();
TraceObject thread;
TraceObject root;
try (Transaction tx = tb.startTransaction()) {
root = tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION).getChild();
thread = tb.createObjectsProcessAndThreads().getObject();
}
Target target = rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
assertNull(target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_ONLY.name());
}
assertEquals(ScheduleForm.SNAP_ONLY, target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_ANY_STEPS_OPS.name());
}
// Constrained by method parameter
assertEquals(ScheduleForm.SNAP_ONLY, target.getSupportedTimeForm(thread, 0));
}
@Test
public void testTimeSupportTimeParam() throws Throwable {
createRmiConnection();
addActivateWithTimeMethods();
createAndOpenTrace();
TraceObject thread;
TraceObject root;
try (Transaction tx = tb.startTransaction()) {
root = tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION).getChild();
thread = tb.createObjectsProcessAndThreads().getObject();
}
Target target = rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
assertNull(target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_ONLY.name());
}
assertEquals(ScheduleForm.SNAP_ONLY, target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_EVT_STEPS.name());
}
assertEquals(ScheduleForm.SNAP_EVT_STEPS, target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_ANY_STEPS.name());
}
assertEquals(ScheduleForm.SNAP_ANY_STEPS, target.getSupportedTimeForm(thread, 0));
try (Transaction tx = tb.startTransaction()) {
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_ANY_STEPS_OPS.name());
}
assertEquals(ScheduleForm.SNAP_ANY_STEPS_OPS, target.getSupportedTimeForm(thread, 0));
}
@Test
public void testSynchronizeTimeGuiToTargetFailsWhenNoTimeParam() throws Throwable {
createRmiConnection();
addActivateMethods();
createAndOpenTrace();
TraceObjectThread thread;
try (Transaction tx = tb.startTransaction()) {
tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION);
thread = tb.createObjectsProcessAndThreads();
tb.trace.getTimeManager()
.getSnapshot(0, true)
.setEventThread(thread);
}
rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
var activate1 = rmiMethodActivateThread.expect(args -> {
assertEquals(Map.ofEntries(
Map.entry("thread", thread.getObject())),
args);
return null;
});
traceManager.activate(DebuggerCoordinates.NOWHERE.thread(thread).snap(0));
waitOn(activate1);
var activate2 = rmiMethodActivateThread.expect(args -> {
fail();
return null;
});
traceManager.activateSnap(1);
waitForSwing();
assertThat(tool.getStatusInfo(), Matchers.containsString("Switch to Trace or Emulate"));
assertFalse(activate2.isDone());
}
@Test
public void testSynchronizeTimeGuiToTarget() throws Throwable {
createRmiConnection();
addActivateWithTimeMethods();
createAndOpenTrace();
TraceObjectThread thread;
TraceObject root;
try (Transaction tx = tb.startTransaction()) {
root = tb.trace.getObjectManager().createRootObject(SCHEMA_SESSION).getChild();
thread = tb.createObjectsProcessAndThreads();
root.setAttribute(Lifespan.nowOn(0), TraceObjectEventScope.KEY_TIME_SUPPORT,
ScheduleForm.SNAP_EVT_STEPS.name());
tb.trace.getTimeManager()
.getSnapshot(0, true)
.setEventThread(thread);
}
rmiCx.publishTarget(tool, tb.trace);
waitForSwing();
var activate1 = rmiMethodActivateThread.expect(args -> {
assertEquals(Map.ofEntries(
Map.entry("thread", thread.getObject())),
// time is optional and not changed, so omitted
args);
return null;
});
traceManager.activate(DebuggerCoordinates.NOWHERE.thread(thread).snap(0));
waitOn(activate1);
var activate2 = rmiMethodActivateThread.expect(args -> {
assertEquals(Map.ofEntries(
Map.entry("thread", thread.getObject()),
Map.entry("time", "0:1")),
args);
return null;
});
traceManager.activateTime(TraceSchedule.snap(0).steppedForward(thread, 1));
waitOn(activate2);
}
} }

View file

@ -59,7 +59,22 @@ dependencies {
} }
}*/ }*/
task generateProto { task configureGenerateProto {
dependsOn(configurations.protocArtifact)
doLast {
def exe = configurations.protocArtifact.first()
if (!isCurrentWindows()) {
exe.setExecutable(true)
}
generateProto.commandLine exe, "--java_out=${generateProto.outdir}", "-I${generateProto.srcdir}"
generateProto.args generateProto.src
}
}
// Can't use providers.exec, or else we see no output
task generateProto(type:Exec) {
dependsOn(configureGenerateProto)
ext.srcdir = file("src/main/proto") ext.srcdir = file("src/main/proto")
ext.src = fileTree(srcdir) { ext.src = fileTree(srcdir) {
include "**/*.proto" include "**/*.proto"
@ -67,17 +82,6 @@ task generateProto {
ext.outdir = file("build/generated/source/proto/main/java") ext.outdir = file("build/generated/source/proto/main/java")
outputs.dir(outdir) outputs.dir(outdir)
inputs.files(src) inputs.files(src)
dependsOn(configurations.protocArtifact)
doLast {
def exe = configurations.protocArtifact.first()
if (!isCurrentWindows()) {
exe.setExecutable(true)
}
providers.exec {
commandLine exe, "--java_out=$outdir", "-I$srcdir"
args src
}.result.get()
}
} }
tasks.compileJava.dependsOn(tasks.generateProto) tasks.compileJava.dependsOn(tasks.generateProto)