ghidra/Ghidra/Extensions/MachineLearning/ghidra_scripts/FindFunctionsRFExampleScript.java

201 lines
7.9 KiB
Java

/* ###
* IP: GHIDRA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//Example script for training random forests to find function starts
//@category Training
import java.util.*;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import ghidra.app.cmd.disassemble.DisassembleCommand;
import ghidra.app.cmd.function.CreateFunctionCmd;
import ghidra.app.script.GhidraScript;
import ghidra.machinelearning.functionfinding.*;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressSet;
import ghidra.program.model.block.BasicBlockModel;
//NOTE: This script is referenced by name in the help for the
//RandomForestFunctionFinderPlugin. If you change the name be
//sure to update the help.
/**
 * Example script showing how to drive the random forest function start finder
 * without the GUI: it trains several models on {@code currentProgram}, selects the
 * one with the fewest false positives on its test set, prints diagnostics about
 * that model's errors, and then applies the model to disassemble and create
 * functions at high-probability addresses.
 */
public class FindFunctionsRFExampleScript extends GhidraScript {

	//minimum probability of being a function start for an undefined address
	//to be disassembled
	private static final double DISASSEMBLE_THRESHOLD = 0.7d;

	//minimum probability of being a function start for a block start to be
	//turned into a function
	private static final double CREATE_FUNCTION_THRESHOLD = 0.8d;

	/**
	 * Trains the models, reports on the best one, and applies it to the current program.
	 * If training yields no models, an error is printed and the program is left unchanged.
	 *
	 * @throws Exception if any of the training, classification or analysis steps fails
	 */
	@Override
	protected void run() throws Exception {
		//get the parameters controlling how many models are trained and
		//what data is used to train/test them
		FunctionStartRFParams params = new FunctionStartRFParams(currentProgram);

		//maximum number of function starts to use in the training set
		//a warning will be issued if there are no functions left over for the test set
		params.setMaxStarts(1000);

		//minimum size of a function to be included in the training/test sets
		params.setMinFuncSize(16);

		//number of bytes before a function start
		params.setPreBytes(Arrays.asList(2, 8));

		//number of bytes after (and including) a function start
		params.setInitialBytes(Arrays.asList(8, 16));

		//number of non-starts to sample for each start in the training set
		params.setFactors(Arrays.asList(10, 50));

		//for every function start in the (test,training) set, add the code
		//units immediately before and immediately after to the (test,training) set
		//as non-starts.
		params.setIncludePrecedingAndFollowing(true);

		//uncomment to include features for each bit rather than just byte-level features
		//params.setIncludeBitFeatures(true);

		//bound for reducing the size of the test sets
		long testSetMax = 1_000_000L;

		//this is where the trained models will go
		List<RandomForestRowObject> trainedModels = new ArrayList<>();
		RandomForestTrainingTask trainingTask = new RandomForestTrainingTask(currentProgram,
			params, trainedModels::add, testSetMax);

		//launch the task to train the models in parallel
		trainingTask.run(monitor);

		//nothing was handed to the callback (e.g., training was cancelled or there was
		//not enough data), so there is no model to select or apply
		if (trainedModels.isEmpty()) {
			printerr("No models were trained; nothing to apply.");
			return;
		}

		//sort the models by the number of false positives (ascending)
		//if you actually need *unsigned* comparison it's likely that something has gone
		//horribly wrong
		trainedModels.sort((x, y) -> Integer.compareUnsigned(x.getNumFalsePositives(),
			y.getNumFalsePositives()));

		//grab the model with the fewest false positives
		//note: there could be ties; could sort the winners by recall
		RandomForestRowObject best = trainedModels.get(0);
		printf(
			"Best model: pre-bytes: %d, initialBytes: %d, sampling factor: %d, false positives: %d," +
				" precision: %g, recall: %g\n",
			best.getNumPreBytes(), best.getNumInitialBytes(), best.getSamplingFactor(),
			best.getNumFalsePositives(), best.getPrecision(), best.getRecall());

		//to get more information about the test set errors, apply the model to the error set
		FunctionStartClassifier classifier = new FunctionStartClassifier(currentProgram, best,
			RandomForestFunctionFinderPlugin.FUNC_START);
		//uncomment to see false negatives as well
		//classifier.setProbabilityThreshold(0.0);
		Map<Address, Double> errors = classifier.classify(best.getTestErrors(), monitor);

		//a false positive is an address reported by the classifier where no function exists;
		//order them by descending probability
		List<Entry<Address, Double>> falsePositives = errors.entrySet()
				.stream()
				.filter(x -> currentProgram.getFunctionManager().getFunctionAt(x.getKey()) == null)
				.sorted((x, y) -> Double.compare(y.getValue(), x.getValue()))
				.toList();

		//print out addresses of false positives
		printf("False positives:\n");
		falsePositives.forEach(x -> printf(" %s %g\n", x.getKey().toString(), x.getValue()));

		//show the true function starts most similar to one of the false positives
		//(the first one, i.e., the one with the highest probability)
		if (!falsePositives.isEmpty()) {
			Address firstFalsePositive = falsePositives.get(0).getKey();
			SimilarStartsFinder finder =
				new SimilarStartsFinder(currentProgram, currentProgram, best);
			List<SimilarStartRowObject> neighbors =
				finder.getSimilarFunctionStarts(firstFalsePositive, 10);
			printf("\nClosest function starts to false positive at %s :\n", firstFalsePositive);
			neighbors.forEach(n -> printf(" %s %d\n", n.funcStart(), n.numAgreements()));
		}

		//grab the set of bytes in executable memory which are undefined (i.e., initialized but
		//not yet assigned to be code or data) or instructions which are not assigned to a function
		//body
		//don't bother looking in small undefined ranges
		long minUndefinedRange = 16;
		GetAddressesToClassifyTask getAddressTask =
			new GetAddressesToClassifyTask(currentProgram, minUndefinedRange);
		getAddressTask.run(monitor);
		AddressSet toClassify = getAddressTask.getAddressesToClassify();
		Map<Address, Double> potentialStarts = classifier.classify(toClassify, monitor);

		//grab all of the addresses with probability of being a function start >= .7
		//and disassemble any that are currently undefined
		//block model is needed to get the interpretation of an address
		//e.g., undefined, block start, within block,...
		BasicBlockModel blockModel = new BasicBlockModel(currentProgram);
		List<Address> addresses = potentialStarts.entrySet()
				.stream()
				.filter(x -> x.getValue() >= DISASSEMBLE_THRESHOLD)
				.map(Entry::getKey)
				.collect(Collectors.toList());

		AddressSet toDisassemble = new AddressSet();
		for (Address addr : addresses) {
			Interpretation inter =
				Interpretation.getInterpretation(currentProgram, addr, blockModel, monitor);
			if (inter.equals(Interpretation.UNDEFINED)) {
				toDisassemble.add(addr);
			}
		}

		//see DisassemblyAndApplyContextAction#actionPerformed if you need to
		//apply context register values
		printf("Found %d addresses to disassemble\n", toDisassemble.getNumAddresses());
		DisassembleCommand cmd = new DisassembleCommand(toDisassemble, null, true);
		cmd.applyTo(currentProgram);

		//create functions at any address with probability of being a function start >= .8,
		//where an instruction exists,
		//which is defined as a BLOCK_START (so not already a function start)
		//and with no conditional flow references to it
		//FunctionStartRowObject represents a row in a table in the gui, but you don't
		//actually need the gui and it's a convenient container for information about an address
		List<FunctionStartRowObject> funcRows = potentialStarts.entrySet()
				.stream()
				.filter(x -> x.getValue() >= CREATE_FUNCTION_THRESHOLD)
				.map(x -> new FunctionStartRowObject(x.getKey(), x.getValue()))
				.collect(Collectors.toList());

		//interpretations are computed after the disassembly above has been applied
		for (FunctionStartRowObject funcRow : funcRows) {
			funcRow.setCurrentInterpretation(Interpretation.getInterpretation(currentProgram,
				funcRow.getAddress(), blockModel, monitor));
			FunctionStartRowObject.setReferenceData(funcRow, currentProgram);
		}

		AddressSet entries = new AddressSet();
		for (FunctionStartRowObject funcRow : funcRows) {
			boolean isBlockStart =
				funcRow.getCurrentInterpretation().equals(Interpretation.BLOCK_START);
			if (isBlockStart && funcRow.getNumConditionalFlowRefs() == 0) {
				entries.add(funcRow.getAddress());
			}
		}

		printf("Found %d addresses to create functions\n", entries.getNumAddresses());
		CreateFunctionCmd createCmd = new CreateFunctionCmd(entries);
		createCmd.applyTo(currentProgram);
	}
}