Merge remote-tracking branch 'origin/GP-4400_ghintern_mlextension_improvements'

This commit is contained in:
Ryan Kurtz 2025-07-29 15:22:49 -04:00
commit 0d8f57ba2f
10 changed files with 316 additions and 122 deletions

View file

@ -24,8 +24,7 @@ import ghidra.app.cmd.disassemble.DisassembleCommand;
import ghidra.app.cmd.function.CreateFunctionCmd; import ghidra.app.cmd.function.CreateFunctionCmd;
import ghidra.app.script.GhidraScript; import ghidra.app.script.GhidraScript;
import ghidra.machinelearning.functionfinding.*; import ghidra.machinelearning.functionfinding.*;
import ghidra.program.model.address.Address; import ghidra.program.model.address.*;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.block.BasicBlockModel; import ghidra.program.model.block.BasicBlockModel;
//NOTE: This script is referenced by name in the help for the //NOTE: This script is referenced by name in the help for the
@ -131,7 +130,7 @@ public class FindFunctionsRFExampleScript extends GhidraScript {
new GetAddressesToClassifyTask(currentProgram, minUndefinedRange); new GetAddressesToClassifyTask(currentProgram, minUndefinedRange);
getAddressTask.run(monitor); getAddressTask.run(monitor);
AddressSet toClassify = getAddressTask.getAddressesToClassify(); AddressSetView toClassify = getAddressTask.getAddressesToClassify();
Map<Address, Double> potentialStarts = classifier.classify(toClassify, monitor); Map<Address, Double> potentialStarts = classifier.classify(toClassify, monitor);

View file

@ -110,6 +110,11 @@
this to ensure that the value in <A href="#MaximumNumberOfStarts">Maximum Number this to ensure that the value in <A href="#MaximumNumberOfStarts">Maximum Number
of Starts</A> field doesn't cause all starts to be used for training (leaving of Starts</A> field doesn't cause all starts to be used for training (leaving
none for testing).</P> none for testing).</P>
<H4><A name="MinimumUndefRangeSize"></A> Minimum Undefined Range Size </H4>
<P> This value is the minimum size of an undefined address range that will be considered when
applying the model to a program. Defaults to the value stored in the plugin options, see
<A href="#MinLengthUndefinedRange">Minimum Length of Undefined Ranges to Search</A>. </P>
<H4><A name="RestrictSearchToAlignedAddresses"></A> Restrict Search to Aligned Addresses </H4> <H4><A name="RestrictSearchToAlignedAddresses"></A> Restrict Search to Aligned Addresses </H4>
<P> If this is checked, only addresses which are zero modulo the value in the <P> If this is checked, only addresses which are zero modulo the value in the
@ -139,11 +144,15 @@
options). The results are displayed in a options). The results are displayed in a
<A href="#FunctionStartTable"> Function Start Table</A>. </P> <A href="#FunctionStartTable"> Function Start Table</A>. </P>
<H4><A name="ApplyModelTo"></A> Apply Model To... Action </H4> <H4><A name="ApplyModelToOtherProgram"></A> Apply Model To Other Program... Action </H4>
<P> This action will open a dialog to select another program in the current project and <P> This action will open a dialog to select another program in the current project and
then apply the model to it. Note that the only check that the model is compatible with then apply the model to it. Note that the only check that the model is compatible with
the selected program is that any context registers specified when training must be the selected program is that any context registers specified when training must be
present in the selected program. </P> present in the selected program. </P>
<H4><A name="ApplyModelToSelection"></A> Apply Model To Selection Action </H4>
<P> This action will apply the model to the current selection in the program used to train it.
</P>
<H4><A name="DebugModel"></A> Debug Model Action </H4> <H4><A name="DebugModel"></A> Debug Model Action </H4>
<P> This action will display a <A href="#DebugModelTable"> Debug Model Table</A>, which shows <P> This action will display a <A href="#DebugModelTable"> Debug Model Table</A>, which shows

View file

@ -36,7 +36,7 @@ import ghidra.app.services.ProgramManager;
import ghidra.framework.main.ProgramFileChooser; import ghidra.framework.main.ProgramFileChooser;
import ghidra.framework.model.DomainFile; import ghidra.framework.model.DomainFile;
import ghidra.framework.preferences.Preferences; import ghidra.framework.preferences.Preferences;
import ghidra.program.model.address.AddressSet; import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.*; import ghidra.program.model.listing.*;
import ghidra.util.HelpLocation; import ghidra.util.HelpLocation;
import ghidra.util.Msg; import ghidra.util.Msg;
@ -97,6 +97,10 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
private static final String ALIGNMENT_MODULUS_TIP = private static final String ALIGNMENT_MODULUS_TIP =
"Use to define the alignment for restricted search"; "Use to define the alignment for restricted search";
private static final String MIN_UNDEFINED_RANGE_SIZE_TEXT = "Minimum Undefined Range Size";
private static final String MIN_UNDEFINED_RANGE_SIZE_TIP =
"Minimum size of an undefined range of addresses to search over for function starts";
private static final String DEFAULT_INITIAL_BYTES = "8,16"; private static final String DEFAULT_INITIAL_BYTES = "8,16";
private static final String INITIAL_BYTES_PROPERTY = "functionStartRFParams_initialBytes"; private static final String INITIAL_BYTES_PROPERTY = "functionStartRFParams_initialBytes";
private static final String DEFAULT_PRE_BYTES = "2,8"; private static final String DEFAULT_PRE_BYTES = "2,8";
@ -123,16 +127,23 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
private static final String APPLY_MODEL_ACTION_NAME = "ApplyModel"; private static final String APPLY_MODEL_ACTION_NAME = "ApplyModel";
private static final String APPLY_MODEL_MENU_TEXT = "Apply Model"; private static final String APPLY_MODEL_MENU_TEXT = "Apply Model";
private static final String APPLY_MODEL_TO_ACTION_NAME = "ApplyModelTo"; private static final String APPLY_MODEL_TO_ACTION_NAME = "ApplyModelToOtherProgram";
private static final String APPLY_MODEL_TO_MENU_TEXT = "Apply Model To..."; private static final String APPLY_MODEL_TO_MENU_TEXT = "Apply Model To Other Program...";
private static final String APPLY_MODEL_SELECTION_ACTION_NAME = "ApplyModelToSelection";
private static final String APPLY_MODEL_SELECTION_MENU_TEXT = "Apply Model To Selection";
private static final String DEBUG_MODEL_ACTION_NAME = "DebugModel"; private static final String DEBUG_MODEL_ACTION_NAME = "DebugModel";
private static final String DEBUG_MODEL_MENU_TEXT = "DEBUG - Show test set errors"; private static final String DEBUG_MODEL_MENU_TEXT = "DEBUG - Show Test Set Errors";
private static final String ACTION_GROUP_APPLY_LOCAL = "A0_ApplyLocal";
private static final String ACTION_GROUP_APPLY_OTHER = "A1_ApplyOther";
private static final String ACTION_GROUP_DEBUG = "A2_Debug";
private JTextField initialBytesField; private JTextField initialBytesField;
private JTextField preBytesField; private JTextField preBytesField;
private JTextField factorField; private JTextField factorField;
private IntegerTextField minimumSizeField; private IntegerTextField minimumSizeField;
private IntegerTextField maxStartsField; private IntegerTextField maxStartsField;
private IntegerTextField minUndefRangeField;
private JTextField contextRegistersField; private JTextField contextRegistersField;
private JLabel numFuncsField; private JLabel numFuncsField;
private JScrollPane tableScrollPane; private JScrollPane tableScrollPane;
@ -284,19 +295,35 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
.popupWhen(c -> trainingSource != null) .popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1) .enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(APPLY_MODEL_MENU_TEXT) .popupMenuPath(APPLY_MODEL_MENU_TEXT)
.popupMenuGroup(ACTION_GROUP_APPLY_LOCAL)
.inWindow(ActionBuilder.When.ALWAYS) .inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> { .onAction(c -> {
searchTrainingProgram(tableModel.getLastSelectedObjects().get(0)); searchTrainingProgram(tableModel.getLastSelectedObjects().get(0), false);
}) })
.build(); .build();
addAction(applyAction); addAction(applyAction);
DockingAction applySelectionAction =
new ActionBuilder(APPLY_MODEL_SELECTION_ACTION_NAME, plugin.getName())
.description("Apply Model to Current Program Selection")
.popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(APPLY_MODEL_SELECTION_MENU_TEXT)
.popupMenuGroup(ACTION_GROUP_APPLY_LOCAL)
.inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> {
searchTrainingProgram(tableModel.getLastSelectedObjects().get(0), true);
})
.build();
addAction(applySelectionAction);
DockingAction applyToAction = DockingAction applyToAction =
new ActionBuilder(APPLY_MODEL_TO_ACTION_NAME, plugin.getName()) new ActionBuilder(APPLY_MODEL_TO_ACTION_NAME, plugin.getName())
.description("Choose Program and Apply Model to it") .description("Choose Program and Apply Model to it")
.popupWhen(c -> trainingSource != null) .popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1) .enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(APPLY_MODEL_TO_MENU_TEXT) .popupMenuPath(APPLY_MODEL_TO_MENU_TEXT)
.popupMenuGroup(ACTION_GROUP_APPLY_OTHER)
.inWindow(ActionBuilder.When.ALWAYS) .inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> { .onAction(c -> {
searchOtherProgram(tableModel.getLastSelectedObjects().get(0)); searchOtherProgram(tableModel.getLastSelectedObjects().get(0));
@ -309,6 +336,7 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
.popupWhen(c -> trainingSource != null) .popupWhen(c -> trainingSource != null)
.enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1) .enabledWhen(c -> tableModel.getLastSelectedObjects().size() == 1)
.popupMenuPath(DEBUG_MODEL_MENU_TEXT) .popupMenuPath(DEBUG_MODEL_MENU_TEXT)
.popupMenuGroup(ACTION_GROUP_DEBUG)
.inWindow(ActionBuilder.When.ALWAYS) .inWindow(ActionBuilder.When.ALWAYS)
.onAction(c -> { .onAction(c -> {
showTestErrors(tableModel.getLastSelectedObjects().get(0)); showTestErrors(tableModel.getLastSelectedObjects().get(0));
@ -406,6 +434,14 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
updateNumFuncsField(); updateNumFuncsField();
funcDataPanel.add(numFuncsField); funcDataPanel.add(numFuncsField);
JLabel minUndefRangeLabel = new GDLabel(MIN_UNDEFINED_RANGE_SIZE_TEXT);
minUndefRangeLabel.setToolTipText(MIN_UNDEFINED_RANGE_SIZE_TIP);
funcDataPanel.add(minUndefRangeLabel);
minUndefRangeField = new IntegerTextField();
minUndefRangeField.setAllowNegativeValues(false);
minUndefRangeField.setValue(plugin.getMinUndefinedRangeSize());
funcDataPanel.add(minUndefRangeField.getComponent());
JLabel restrictLabel = new GDLabel(RESTRICT_SEARCH_TEXT); JLabel restrictLabel = new GDLabel(RESTRICT_SEARCH_TEXT);
restrictLabel.setToolTipText(RESTRICT_SEARCH_TIP); restrictLabel.setToolTipText(RESTRICT_SEARCH_TIP);
funcDataPanel.add(restrictLabel); funcDataPanel.add(restrictLabel);
@ -478,8 +514,8 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
numFuncsField.setText(Integer.toString(numFuncs)); numFuncsField.setText(Integer.toString(numFuncs));
} }
private void searchTrainingProgram(RandomForestRowObject modelRow) { private void searchTrainingProgram(RandomForestRowObject modelRow, boolean useSelection) {
searchProgram(trainingSource, modelRow); searchProgram(trainingSource, modelRow, useSelection);
} }
private void searchOtherProgram(RandomForestRowObject modelRow) { private void searchOtherProgram(RandomForestRowObject modelRow) {
@ -499,7 +535,7 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
" is not compatible with training source program " + trainingSource.getName()); " is not compatible with training source program " + trainingSource.getName());
return; return;
} }
searchProgram(p, modelRow); searchProgram(p, modelRow, false);
} }
private void showTestErrors(RandomForestRowObject modelRow) { private void showTestErrors(RandomForestRowObject modelRow) {
@ -508,21 +544,32 @@ public class FunctionStartRFParamsDialog extends ReusableDialogComponentProvider
addGeneralActions(provider, trainingSource); addGeneralActions(provider, trainingSource);
} }
private void searchProgram(Program targetProgram, RandomForestRowObject modelRow) { private void searchProgram(Program targetProgram, RandomForestRowObject modelRow,
GetAddressesToClassifyTask getTask = boolean useSelection) {
new GetAddressesToClassifyTask(targetProgram, plugin.getMinUndefinedRangeSize());
GetAddressesToClassifyTask getTask = null;
if (useSelection) {
getTask =
new GetAddressesToClassifyTask(targetProgram, 1, plugin.getProgramSelection());
}
else {
getTask =
new GetAddressesToClassifyTask(targetProgram, minUndefRangeField.getLongValue());
}
//don't want to use the dialog's progress bar //don't want to use the dialog's progress bar
TaskLauncher.launchModal("Gathering Addresses To Classify", getTask); TaskLauncher.launchModal("Gathering Addresses To Classify", getTask);
if (getTask.isCancelled()) { if (getTask.isCancelled()) {
return; return;
} }
AddressSet execNonFunc = null; AddressSetView execNonFunc = null;
if (restrictBox.isSelected()) { if (restrictBox.isSelected()) {
execNonFunc = getTask.getAddressesToClassify((long) modBox.getSelectedItem()); execNonFunc = getTask.getAddressesToClassify((long) modBox.getSelectedItem());
} }
else { else {
execNonFunc = getTask.getAddressesToClassify(); execNonFunc = getTask.getAddressesToClassify();
} }
FunctionStartTableProvider provider = FunctionStartTableProvider provider =
new FunctionStartTableProvider(plugin, targetProgram, execNonFunc, modelRow, false); new FunctionStartTableProvider(plugin, targetProgram, execNonFunc, modelRow, false);
addGeneralActions(provider, targetProgram); addGeneralActions(provider, targetProgram);

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -24,8 +24,9 @@ import ghidra.docking.settings.Settings;
import ghidra.framework.plugintool.PluginTool; import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.ServiceProvider; import ghidra.framework.plugintool.ServiceProvider;
import ghidra.program.model.address.Address; import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet; import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.block.BasicBlockModel; import ghidra.program.model.block.BasicBlockModel;
import ghidra.program.model.listing.Instruction;
import ghidra.program.model.listing.Program; import ghidra.program.model.listing.Program;
import ghidra.util.datastruct.Accumulator; import ghidra.util.datastruct.Accumulator;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
@ -38,7 +39,7 @@ import ghidra.util.task.TaskMonitor;
*/ */
public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStartRowObject> { public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStartRowObject> {
private RandomForestRowObject modelRow; private RandomForestRowObject modelRow;
private AddressSet addressesToClassify; private AddressSetView addressesToClassify;
private boolean debug; private boolean debug;
private BasicBlockModel blockModel; private BasicBlockModel blockModel;
private Map<Address, Double> addressToProbability; private Map<Address, Double> addressToProbability;
@ -54,7 +55,7 @@ public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStar
* @param modelRow trained model info * @param modelRow trained model info
* @param debug is table displaying debug data * @param debug is table displaying debug data
*/ */
public FunctionStartTableModel(PluginTool plugin, Program program, AddressSet toClassify, public FunctionStartTableModel(PluginTool plugin, Program program, AddressSetView toClassify,
RandomForestRowObject modelRow, boolean debug) { RandomForestRowObject modelRow, boolean debug) {
super(program.getName(), plugin, program, null, false); super(program.getName(), plugin, program, null, false);
this.modelRow = modelRow; this.modelRow = modelRow;
@ -100,6 +101,7 @@ public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStar
descriptor.addVisibleColumn(new AddressTableColumn()); descriptor.addVisibleColumn(new AddressTableColumn());
descriptor.addVisibleColumn(new ProbabilityTableColumn(), 0, false); descriptor.addVisibleColumn(new ProbabilityTableColumn(), 0, false);
descriptor.addVisibleColumn(new InterpretationTableColumn()); descriptor.addVisibleColumn(new InterpretationTableColumn());
descriptor.addVisibleColumn(new FallthroughTableColumn());
descriptor.addVisibleColumn(new DataReferencesTableColumn()); descriptor.addVisibleColumn(new DataReferencesTableColumn());
descriptor.addVisibleColumn(new UnconditionalFlowReferencesTableColumn()); descriptor.addVisibleColumn(new UnconditionalFlowReferencesTableColumn());
descriptor.addVisibleColumn(new ConditionalFlowReferencesTableColumn()); descriptor.addVisibleColumn(new ConditionalFlowReferencesTableColumn());
@ -173,6 +175,32 @@ public class FunctionStartTableModel extends AddressBasedTableModel<FunctionStar
} }
} }
private class FallthroughTableColumn
extends AbstractDynamicTableColumn<FunctionStartRowObject, Boolean, Object> {
@Override
public String getColumnName() {
return "Is Fallthrough Target";
}
@Override
public Boolean getValue(FunctionStartRowObject rowObject, Settings settings, Object data,
ServiceProvider services) throws IllegalArgumentException {
Instruction preInstr =
program.getListing().getInstructionBefore(rowObject.getAddress());
if (preInstr == null) {
return false;
}
if (!preInstr.hasFallthrough()) {
return false;
}
return preInstr.getFallThrough().equals(rowObject.getAddress());
}
}
private class DataReferencesTableColumn private class DataReferencesTableColumn
extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> { extends AbstractDynamicTableColumn<FunctionStartRowObject, Integer, Object> {

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -25,7 +25,7 @@ import javax.swing.*;
import ghidra.framework.model.DomainObjectChangedEvent; import ghidra.framework.model.DomainObjectChangedEvent;
import ghidra.framework.model.DomainObjectListener; import ghidra.framework.model.DomainObjectListener;
import ghidra.framework.plugintool.ComponentProviderAdapter; import ghidra.framework.plugintool.ComponentProviderAdapter;
import ghidra.program.model.address.AddressSet; import ghidra.program.model.address.AddressSetView;
import ghidra.program.model.listing.Program; import ghidra.program.model.listing.Program;
import ghidra.util.HelpLocation; import ghidra.util.HelpLocation;
import ghidra.util.table.*; import ghidra.util.table.*;
@ -42,7 +42,7 @@ public class FunctionStartTableProvider extends ProgramAssociatedComponentProvid
private RandomForestFunctionFinderPlugin plugin; private RandomForestFunctionFinderPlugin plugin;
private Program program; private Program program;
private RandomForestRowObject modelRow; private RandomForestRowObject modelRow;
private AddressSet toClassify; private AddressSetView toClassify;
private boolean debug; private boolean debug;
private String subTitle; private String subTitle;
private GhidraTable startTable; private GhidraTable startTable;
@ -59,7 +59,7 @@ public class FunctionStartTableProvider extends ProgramAssociatedComponentProvid
* @param debug whether to display debug version of table * @param debug whether to display debug version of table
*/ */
public FunctionStartTableProvider(RandomForestFunctionFinderPlugin plugin, Program program, public FunctionStartTableProvider(RandomForestFunctionFinderPlugin plugin, Program program,
AddressSet toClassify, RandomForestRowObject modelRow, boolean debug) { AddressSetView toClassify, RandomForestRowObject modelRow, boolean debug) {
super( super(
debug ? "Debug: Test Set Errors in " + program.getDomainFile().getPathname() debug ? "Debug: Test Set Errors in " + program.getDomainFile().getPathname()
: "Potential Functions in " + program.getDomainFile().getPathname(), : "Potential Functions in " + program.getDomainFile().getPathname(),

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -17,6 +17,7 @@ package ghidra.machinelearning.functionfinding;
import ghidra.program.model.address.*; import ghidra.program.model.address.*;
import ghidra.program.model.listing.*; import ghidra.program.model.listing.*;
import ghidra.program.util.ProgramSelection;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.Task; import ghidra.util.task.Task;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -27,8 +28,25 @@ import ghidra.util.task.TaskMonitor;
public class GetAddressesToClassifyTask extends Task { public class GetAddressesToClassifyTask extends Task {
private Program prog; private Program prog;
private AddressSet execNonFunc; private AddressSetView execNonFunc;
private long minUndefinedRangeSize; private long minUndefinedRangeSize;
private ProgramSelection userSelection;
/**
* Creates a {@link Task} that creates a set of addresses to check for function starts. The
* {code minUndefinedRangeSize} parameter determines how large a run of undefined bytes must be
* to be checked for function starts
* @param prog source program
* @param minUndefinedRangeSize minimum size of undefined range
* @param selection program selection used to filter addresses
*/
public GetAddressesToClassifyTask(Program prog, long minUndefinedRangeSize,
ProgramSelection selection) {
super("Gathering Addresses to Classify", true, true, false, false);
this.prog = prog;
this.minUndefinedRangeSize = minUndefinedRangeSize;
this.userSelection = selection;
}
/** /**
* Creates a {@link Task} that creates a set of addresses to check for function starts. The * Creates a {@link Task} that creates a set of addresses to check for function starts. The
@ -38,25 +56,29 @@ public class GetAddressesToClassifyTask extends Task {
* @param minUndefinedRangeSize minimum size of undefined range * @param minUndefinedRangeSize minimum size of undefined range
*/ */
public GetAddressesToClassifyTask(Program prog, long minUndefinedRangeSize) { public GetAddressesToClassifyTask(Program prog, long minUndefinedRangeSize) {
super("Gathering Addresses to Classify", true, true, false, false); this(prog, minUndefinedRangeSize, null);
this.prog = prog;
this.minUndefinedRangeSize = minUndefinedRangeSize;
} }
@Override @Override
public void run(TaskMonitor monitor) throws CancelledException { public void run(TaskMonitor monitor) throws CancelledException {
execNonFunc = new AddressSet(); execNonFunc = new AddressSet();
AddressSetView executable = prog.getMemory().getExecuteSet(); if (userSelection != null) {
AddressSetView initialized = prog.getMemory().getLoadedAndInitializedAddressSet(); execNonFunc = userSelection;
execNonFunc = executable.intersect(initialized);
monitor.initialize(prog.getFunctionManager().getFunctionCount());
FunctionIterator fIter = prog.getFunctionManager().getFunctions(true);
while (fIter.hasNext()) {
monitor.checkCancelled();
monitor.incrementProgress(1);
Function func = fIter.next();
execNonFunc = execNonFunc.subtract(func.getBody());
} }
else {
AddressSetView executable = prog.getMemory().getExecuteSet();
AddressSetView initialized = prog.getMemory().getLoadedAndInitializedAddressSet();
execNonFunc = executable.intersect(initialized);
monitor.initialize(prog.getFunctionManager().getFunctionCount());
FunctionIterator fIter = prog.getFunctionManager().getFunctions(true);
while (fIter.hasNext()) {
monitor.checkCancelled();
monitor.incrementProgress(1);
Function func = fIter.next();
execNonFunc = execNonFunc.subtract(func.getBody());
}
}
//remove small undefined ranges to avoid (for example) searching for //remove small undefined ranges to avoid (for example) searching for
//function starts in an address range of length 3 between two known //function starts in an address range of length 3 between two known
//functions. "small" is controlled by a plugin option. //functions. "small" is controlled by a plugin option.
@ -77,7 +99,7 @@ public class GetAddressesToClassifyTask extends Task {
* Returns the set of addresses to classify * Returns the set of addresses to classify
* @return addresses * @return addresses
*/ */
public AddressSet getAddressesToClassify() { public AddressSetView getAddressesToClassify() {
return execNonFunc; return execNonFunc;
} }
@ -87,7 +109,7 @@ public class GetAddressesToClassifyTask extends Task {
* @param modulus alignment modulus * @param modulus alignment modulus
* @return aligned addresses * @return aligned addresses
*/ */
public AddressSet getAddressesToClassify(long modulus) { public AddressSetView getAddressesToClassify(long modulus) {
AddressSet aligned = new AddressSet(); AddressSet aligned = new AddressSet();
for (Address a : execNonFunc.getAddresses(true)) { for (Address a : execNonFunc.getAddresses(true)) {
if (a.getOffset() % modulus == 0) { if (a.getOffset() % modulus == 0) {

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -205,16 +205,16 @@ public class RandomForestFunctionFinderPlugin extends ProgramPlugin
private void createActions() { private void createActions() {
new ActionBuilder(ACTION_NAME, getName()) new ActionBuilder(ACTION_NAME, getName())
.menuPath(ToolConstants.MENU_SEARCH, MENU_PATH_ENTRY) .menuPath(ToolConstants.MENU_SEARCH, MENU_PATH_ENTRY)
.menuGroup("search for", null) .menuGroup("search for", null)
.description("Train models to search for function starts") .description("Train models to search for function starts")
.helpLocation(new HelpLocation(getName(), getName())) .helpLocation(new HelpLocation(getName(), getName()))
.withContext(NavigatableActionContext.class, true) .withContext(NavigatableActionContext.class, true)
.validContextWhen(c -> !(c instanceof RestrictedAddressSetContext)) .validWhen(c -> !(c instanceof RestrictedAddressSetContext))
.onAction(c -> { .onAction(c -> {
displayDialog(c); displayDialog(c);
}) })
.buildAndInstall(tool); .buildAndInstall(tool);
} }
private void displayDialog(NavigatableActionContext c) { private void displayDialog(NavigatableActionContext c) {

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -18,6 +18,7 @@ package ghidra.machinelearning.functionfinding;
import java.util.*; import java.util.*;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import ghidra.pcodeCPort.utils.MutableLong;
import ghidra.program.model.address.*; import ghidra.program.model.address.*;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -43,25 +44,32 @@ public class RandomSubsetUtils {
* @return random subset of size k * @return random subset of size k
* @throws CancelledException if monitor is canceled * @throws CancelledException if monitor is canceled
*/ */
public static AddressSet randomSubset(AddressSetView addresses, long k, TaskMonitor monitor) public static AddressSet randomSubset(AddressSetView addresses, long k,
throws CancelledException { TaskMonitor monitor) throws CancelledException {
List<Long> sortedRandom = generateRandomIntegerSubset(addresses.getNumAddresses(), k); long[] sortedRandom = generateRandomIntegerSubset(addresses.getNumAddresses(), k);
Collections.sort(sortedRandom); Arrays.sort(sortedRandom);
AddressSet randomAddresses = new AddressSet(); AddressSet randomAddresses = new AddressSet();
AddressIterator iter = addresses.getAddresses(true);
int addressesAdded = 0; long addressesVisited = 0;
int addressesVisited = 0;
int listIndex = 0; int listIndex = 0;
while (iter.hasNext() && addressesAdded < k) { for (AddressRange range : addresses) {
monitor.checkCancelled(); long rangeEnd = addressesVisited + range.getLength();
Address addr = iter.next(); for (; listIndex < k; listIndex++) {
if (sortedRandom.get(listIndex) == addressesVisited) { monitor.checkCancelled();
long next = sortedRandom[listIndex];
if (next >= rangeEnd) {
// Next address is outside of this range
break;
}
Address addr = range.getMinAddress().add(next - addressesVisited);
randomAddresses.add(addr); randomAddresses.add(addr);
addressesAdded += 1;
listIndex += 1;
} }
addressesVisited += 1; if (listIndex == k) {
break;
}
addressesVisited += range.getLength();
} }
return randomAddresses; return randomAddresses;
} }
@ -72,24 +80,34 @@ public class RandomSubsetUtils {
* @param k size of random subset (must be >= 0) * @param k size of random subset (must be >= 0)
* @return list of indices of elements in random subset * @return list of indices of elements in random subset
*/ */
public static List<Long> generateRandomIntegerSubset(long n, long k) { public static long[] generateRandomIntegerSubset(long n, long k) {
if (n < 0) { if (n < 0) {
throw new IllegalArgumentException("n cannot be negative"); throw new IllegalArgumentException("n cannot be negative");
} }
if (k < 0) { if (k < 0) {
throw new IllegalArgumentException("k cannot be negative"); throw new IllegalArgumentException("k cannot be negative");
} }
if (k > Integer.MAX_VALUE) {
// Could probably just switch k to an int. Since we were using ArrayList before
// that was already going to blow up if k > Integer.MAX_VALUE
throw new IllegalArgumentException("k cannot exceed bounds of integer");
}
if (n < k) { if (n < k) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"size of subset (" + k + ") cannot be larger than size of set (" + n + ")"); "size of subset (" + k + ") cannot be larger than size of set (" + n + ")");
} }
Map<Long, Long> permutation = new HashMap<>();
for (long i = 0; i < k; ++i) { Map<Long, MutableLong> permutation = new HashMap<>();
for (long i = 0; i < k; i++) {
swap(permutation, i, ThreadLocalRandom.current().nextLong(i, n)); swap(permutation, i, ThreadLocalRandom.current().nextLong(i, n));
} }
List<Long> random = new ArrayList<>();
for (long i = 0; i < k; i++) { long[] random = new long[(int) k];
random.add(permutation.getOrDefault(i, i));
for (int i = 0; i < k; i++) {
random[i] =
permutation.computeIfAbsent(Long.valueOf(i), key -> new MutableLong(key)).get();
} }
return random; return random;
} }
@ -102,14 +120,15 @@ public class RandomSubsetUtils {
* @param i index * @param i index
* @param j index * @param j index
*/ */
public static void swap(Map<Long, Long> permutation, long i, long j) { public static void swap(Map<Long, MutableLong> permutation, long i, long j) {
if (i == j) { if (i == j) {
return; return;
} }
long ith = permutation.getOrDefault(i, i); MutableLong ith = permutation.computeIfAbsent(i, key -> new MutableLong(key));
long jth = permutation.getOrDefault(j, j); MutableLong jth = permutation.computeIfAbsent(j, key -> new MutableLong(key));
permutation.put(i, jth); long temp = ith.get();
permutation.put(j, ith); ith.set(jth.get());
jth.set(temp);
} }
} }

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -20,6 +20,8 @@ import java.util.List;
import docking.widgets.table.AbstractDynamicTableColumn; import docking.widgets.table.AbstractDynamicTableColumn;
import docking.widgets.table.TableColumnDescriptor; import docking.widgets.table.TableColumnDescriptor;
import ghidra.app.util.PseudoDisassembler;
import ghidra.app.util.PseudoInstruction;
import ghidra.docking.settings.Settings; import ghidra.docking.settings.Settings;
import ghidra.framework.plugintool.PluginTool; import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.ServiceProvider; import ghidra.framework.plugintool.ServiceProvider;
@ -75,6 +77,7 @@ public class SimilarStartsTableModel extends AddressBasedTableModel<SimilarStart
descriptor.addVisibleColumn(new AddressTableColumn()); descriptor.addVisibleColumn(new AddressTableColumn());
descriptor.addVisibleColumn(new SimilarityTableColumn(), 1, false); descriptor.addVisibleColumn(new SimilarityTableColumn(), 1, false);
descriptor.addVisibleColumn(new ByteStringTableColumn()); descriptor.addVisibleColumn(new ByteStringTableColumn());
descriptor.addVisibleColumn(new DisassemblyTableColumn());
return descriptor; return descriptor;
} }
@ -161,4 +164,47 @@ public class SimilarStartsTableModel extends AddressBasedTableModel<SimilarStart
return sb.toString(); return sb.toString();
} }
} }
private class DisassemblyTableColumn
extends AbstractDynamicTableColumn<SimilarStartRowObject, String, Object> {
@Override
public String getColumnName() {
return "Disassembly";
}
@Override
public String getColumnDescription() {
return "Disassembly (ignoring pre-bytes)";
}
@Override
public String getValue(SimilarStartRowObject rowObject, Settings settings, Object data,
ServiceProvider services) throws IllegalArgumentException {
PseudoDisassembler disasm = new PseudoDisassembler(program);
StringBuilder sb = new StringBuilder();
try {
Address addr = rowObject.funcStart();
while (addr.compareTo(
rowObject.funcStart().add(randomForestRow.getNumInitialBytes())) < 0) {
PseudoInstruction instr = disasm.disassemble(addr);
if (instr.isValid()) {
sb.append(instr.toString());
sb.append(" ");
}
else {
sb.append("? ");
}
addr = instr.getMaxAddress().add(1);
}
}
catch (Exception e) {
sb = new StringBuilder("??");
}
return sb.toString();
}
}
} }

View file

@ -4,9 +4,9 @@
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -22,8 +22,8 @@ import java.util.*;
import org.junit.Test; import org.junit.Test;
import generic.test.AbstractGenericTest; import generic.test.AbstractGenericTest;
import ghidra.program.model.address.AddressSet; import ghidra.pcodeCPort.utils.MutableLong;
import ghidra.program.model.address.TestAddress; import ghidra.program.model.address.*;
import ghidra.util.exception.CancelledException; import ghidra.util.exception.CancelledException;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.TaskMonitor;
@ -31,64 +31,88 @@ public class RandomSubsetTest extends AbstractGenericTest {
@Test @Test
public void testGenerateTrivialSubsets() { public void testGenerateTrivialSubsets() {
List<Long> empty = RandomSubsetUtils.generateRandomIntegerSubset(10, 0); long[] empty = RandomSubsetUtils.generateRandomIntegerSubset(10, 0);
assertEquals(0, empty.size()); assertEquals(0, empty.length);
empty = RandomSubsetUtils.generateRandomIntegerSubset(0, 0); empty = RandomSubsetUtils.generateRandomIntegerSubset(0, 0);
assertEquals(0, empty.size()); assertEquals(0, empty.length);
List<Long> complete = RandomSubsetUtils.generateRandomIntegerSubset(1000000, 1000000); long[] complete = RandomSubsetUtils.generateRandomIntegerSubset(1000000, 1000000);
Collections.sort(complete); Arrays.sort(complete);
Iterator<Long> iter = complete.iterator(); for (int current = 0; current < complete.length; current++) {
long current = 0; long elem = complete[current];
while (iter.hasNext()) { assertEquals(current, elem);
long elem = iter.next();
assertEquals(current++, elem);
} }
} }
@Test @Test
public void testBasicRandomSubsetOfAddresses() throws CancelledException { public void testBasicRandomSubsetOfAddresses() throws CancelledException {
// Check we are drawing unique addresses from the set
AddressSet addrs = new AddressSet(); AddressSet addrs = new AddressSet();
for (long i = 0; i < 10000; ++i) { for (long i = 0; i < 10000; ++i) {
addrs.add(new TestAddress(i)); addrs.add(new TestAddress(i));
} }
AddressSet rand = RandomSubsetUtils.randomSubset(addrs, 9998, TaskMonitor.DUMMY); AddressSet rand = RandomSubsetUtils.randomSubset(addrs, 9998, TaskMonitor.DUMMY);
assertEquals(9998, rand.getNumAddresses()); assertEquals(9998, rand.getNumAddresses());
addrs.clear();
// Check we correctly draw from multiple non-contiguous ranges
for (long i = 0; i < 10000; i += 1000) {
if (i % 2000 != 0)
continue;
for (long j = i; j < i + 1000; j++) {
addrs.add(new TestAddress(j));
}
}
rand = RandomSubsetUtils.randomSubset(addrs, 4998, TaskMonitor.DUMMY);
assertEquals(4998, rand.getNumAddresses());
for (Address addr : rand.getAddresses(true)) {
assertTrue(addrs.contains(addr));
}
} }
@Test @Test
public void testSwap() { public void testSwap() {
Map<Long, Long> permuted = new HashMap<>(); Map<Long, MutableLong> permuted = new HashMap<>();
assertTrue(permuted.isEmpty()); assertTrue(permuted.isEmpty());
//should do nothing //should do nothing
RandomSubsetUtils.swap(permuted, 1, 1); RandomSubsetUtils.swap(permuted, 1, 1);
assertTrue(permuted.isEmpty()); assertTrue(permuted.isEmpty());
permuted.put(0l, 5l); permuted.put(0l, new MutableLong(5l));
permuted.put(1l, 10l); permuted.put(1l, new MutableLong(10l));
RandomSubsetUtils.swap(permuted, 0, 1); RandomSubsetUtils.swap(permuted, 0, 1);
assertEquals(2, permuted.size()); assertEquals(2, permuted.size());
assertEquals(Long.valueOf(5), permuted.get(1l)); assertEquals(5l, permuted.get(1l).get());
assertEquals(Long.valueOf(10), permuted.get(0l)); assertEquals(10l, permuted.get(0l).get());
RandomSubsetUtils.swap(permuted, 100l, 200L); RandomSubsetUtils.swap(permuted, 100l, 200L);
assertEquals(4, permuted.size()); assertEquals(4, permuted.size());
assertEquals(Long.valueOf(100), permuted.get(200l)); assertEquals(100l, permuted.get(200l).get());
assertEquals(Long.valueOf(200), permuted.get(100l)); assertEquals(200l, permuted.get(100l).get());
} }
/** // @Test
@Test // public void timingTest() throws CancelledException, InterruptedException {
public void timingTest() throws CancelledException { // AddressSet big = new AddressSet(new TestAddress(0), new TestAddress(9999999));
AddressSet big = new AddressSet(new TestAddress(0), new TestAddress(999999)); //
// Thread.sleep(10000);
long start = System.nanoTime(); // long start;
List<Long> complete = RandomSubset.generateRandomIntegerSubset(1000000, 500000); // long end;
long end = System.nanoTime(); //
Msg.info(this, "choosing random subset of integers: " + // for (int i = 0; i < 10; i++) {
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND); // start = System.nanoTime();
start = System.nanoTime(); // long[] complete = RandomSubsetUtils.generateRandomIntegerSubset(10000000, 5000000);
AddressSet random = RandomSubset.randomSubset(big, 500000, TaskMonitor.DUMMY); // end = System.nanoTime();
end = System.nanoTime(); // Msg.info(this, "choosing random subset of integers: " +
Msg.info(this, "choosing random subset of addresses: " + // (end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
(end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND); // }
}*/ //
// for (int i = 0; i < 10; i++) {
// start = System.nanoTime();
// AddressSet random1 = RandomSubsetUtils.randomSubset(big, 5000000, TaskMonitor.DUMMY);
// end = System.nanoTime();
// Msg.info(this, "choosing random subset of addresses: " +
// (end - start) / RandomForestTrainingTask.NANOSECONDS_PER_SECOND);
// }
// }
} }