Skip to content

Commit

Permalink
Merge pull request #687 from bytedance/memshell-jvm-benchmark
Browse files Browse the repository at this point in the history
Memshell jvm benchmark
  • Loading branch information
yoloyyh authored Sep 26, 2024
2 parents 63aa263 + 13cc787 commit 9a65e84
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 6 deletions.
11 changes: 10 additions & 1 deletion rasp/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ DEBUG_SYMBOLS ?= debug
LIB_OUTPUT ?= $(OUTPUT)/lib-$(VERSION)
VCPKG_OVERLAY_PORTS ?= $(abspath overlay-ports)

.PHONY: all help install clean set-version agent-plugin nsenter pangolin jattach JVMProbe python-probe python-loader go-probe go-probe-ebpf node-probe php-probe librasp rasp-server NSMount
.PHONY: all help install clean check_benchmark set-version agent-plugin nsenter pangolin jattach JVMProbe python-probe python-loader go-probe go-probe-ebpf node-probe php-probe librasp rasp-server NSMount

all: rasp-linux-default-x86_64-$(VERSION).tar.gz rasp-linux-default-x86_64-$(VERSION)-debug.tar.gz SHA256SUMS

Expand Down Expand Up @@ -41,6 +41,15 @@ rasp-linux-default-x86_64-$(VERSION)-debug.tar.gz: | $(DEBUG_SYMBOLS)
SHA256SUMS: rasp-linux-default-x86_64-$(VERSION).tar.gz
sha256sum $(OUTPUT)/rasp rasp-linux-default-x86_64-$(VERSION).tar.gz > $@

# for benchmark
check_benchmark:
if echo $(VERSION) | grep -q benchmark; then \
echo "Version contains benchmark. Modifying files..."; \
sed -i "s/isBenchMark = false/isBenchMark = true/g" "jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java" \
sed -i "s/isBenchMark = false/isBenchMark = true/g" "jvm/JVMProbe/src/main/java/com/security/smith/asm/SmithMethodVisitor.java" \
else \
echo "Version does not contain benchmark. Skipping modification."; \
fi

set-version:
sed -i "s/1.0.0.1/${VERSION}/g" "librasp/src/settings.rs"
Expand Down
78 changes: 74 additions & 4 deletions rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.lmax.disruptor.EventHandler;

import com.lmax.disruptor.InsufficientCapacityException;
import com.lmax.disruptor.RingBuffer;
import com.lmax.disruptor.dsl.Disruptor;
import com.lmax.disruptor.util.DaemonThreadFactory;
import com.security.smith.asm.SmithClassVisitor;
Expand Down Expand Up @@ -54,9 +55,6 @@
import java.io.File;
import java.io.FileOutputStream;
import java.security.CodeSource;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.jar.JarFile;


Expand All @@ -73,6 +71,10 @@ public class SmithProbe implements ClassFileTransformer, MessageHandler, EventHa

private final Map<String, SmithClass> smithClasses;
private final Map<String, Patcher> patchers;
private final Map<Pair<Integer, Integer>, List<Long>> records;
private final Map<Pair<Integer, Integer>, List<Long>> recordsTotal;
private final Map<Pair<Integer, Integer>, Long> hooktimeRecords;
private final Map<Pair<Integer, Integer>, Long> runtimeRecords;
private final Map<Pair<Integer, Integer>, Filter> filters;
private final Map<Pair<Integer, Integer>, Block> blocks;
private final Map<Pair<Integer, Integer>, Integer> limits;
Expand All @@ -81,6 +83,7 @@ public class SmithProbe implements ClassFileTransformer, MessageHandler, EventHa
private final Rule_Mgr rulemgr;
private final Rule_Config ruleconfig;
private SmithProbeProxy smithProxy;
private boolean isBenchMark;

enum Action {
STOP,
Expand All @@ -94,12 +97,17 @@ public static SmithProbe getInstance() {
public SmithProbe() {
disable = false;
scanswitch = true;
isBenchMark = false;

smithClasses = new ConcurrentHashMap<>();
patchers = new ConcurrentHashMap<>();
filters = new ConcurrentHashMap<>();
blocks = new ConcurrentHashMap<>();
limits = new ConcurrentHashMap<>();
records = new HashMap<>();
recordsTotal = new HashMap<>();
hooktimeRecords = new HashMap<>();
runtimeRecords = new HashMap<>();

heartbeat = new Heartbeat();
client = new Client(this);
Expand Down Expand Up @@ -185,6 +193,18 @@ public void run() {
0,
TimeUnit.MINUTES.toMillis(1)
);
if (isBenchMark) {
new Timer(true).schedule(
new TimerTask() {
@Override
public void run() {
show();
}
},
TimeUnit.SECONDS.toMillis(5),
TimeUnit.SECONDS.toMillis(10)
);
}
inst.addTransformer(this, true);
reloadClasses();
}
Expand Down Expand Up @@ -259,6 +279,56 @@ private void reloadClasses(Collection<String> classes) {
}
}

private Long tp(List<Long> times, double percent) {
return times.get((int)(percent / 100 * times.size() - 1));
}

private void show() {
synchronized (records) {
SmithLogger.logger.info("=================== statistics ===================");
records.forEach((k, v) -> {
Collections.sort(v);
List<Long> tv = recordsTotal.get(new ImmutablePair<>(k.getLeft(), k.getRight()));
Collections.sort(tv);
Long hooktime = hooktimeRecords.get(new ImmutablePair<>(k.getLeft(), k.getRight()));
Long runtime = runtimeRecords.get(new ImmutablePair<>(k.getLeft(), k.getRight()));
SmithLogger.logger.info(
String.format(
"class: %d method: %d count: %d tp50: %d tp90: %d tp95: %d tp99: %d tp99.99: %d max: %d total-max:%d hooktime:%d runtime:%d",
k.getLeft(),
k.getRight(),
v.size(),
tp(v, 50),
tp(v, 90),
tp(v, 95),
tp(v, 99),
tp(v, 99.99),
v.get(v.size() - 1),
tv.get(tv.size() - 1),
hooktime,
runtime
)
);
});
}
}
public void record(int classID, int methodID, long time,long totaltime) {
synchronized (records) {
records.computeIfAbsent(new ImmutablePair<>(classID, methodID), k -> new ArrayList<>()).add(time);
}
synchronized (recordsTotal) {
recordsTotal.computeIfAbsent(new ImmutablePair<>(classID, methodID), k -> new ArrayList<>()).add(totaltime);
}
synchronized (hooktimeRecords) {
hooktimeRecords.computeIfAbsent(new ImmutablePair<>(classID, methodID), k -> time);
hooktimeRecords.computeIfPresent(new ImmutablePair<>(classID, methodID),(k,v) -> v+time);
}
synchronized (runtimeRecords) {
runtimeRecords.computeIfAbsent(new ImmutablePair<>(classID, methodID), k -> totaltime);
runtimeRecords.computeIfPresent(new ImmutablePair<>(classID, methodID),(k,v) -> v+totaltime);
}
}

@Override
public void onEvent(Trace trace, long sequence, boolean endOfBatch) {
Filter filter = filters.get(new ImmutablePair<>(trace.getClassID(), trace.getMethodID()));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.security.smith.asm;

import com.security.smith.SmithProbe;
import com.security.smith.SmithProbeProxy;
import com.security.smith.processor.*;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
Expand All @@ -21,6 +23,8 @@ public class SmithMethodVisitor extends AdviceAdapter {
private final boolean canBlock;
private final boolean isStatic;
private final boolean isConstructor;
private int stopWatchVariable;
private int stopWatchTotalVariable;
private final int returnVariable;
private final int argumentsVariable;
private final Label start;
Expand All @@ -29,6 +33,7 @@ public class SmithMethodVisitor extends AdviceAdapter {
private String preHook;
private String postHook;
private String exceptionHook;
private final boolean isBenchMark;

private static final Map<String, Class<?>> smithProcessors = new HashMap<String, Class<?>>() {{
put("byte[]", ByteArrayProcessor.class);
Expand All @@ -55,11 +60,17 @@ protected SmithMethodVisitor(int api, Type classType, int classID, int methodID,
this.preHook = pre_hook;
this.postHook = post_hook;
this.exceptionHook = exception_hook;
this.isBenchMark = false;

start = new Label();
end = new Label();
handler = new Label();

if (isBenchMark) {
stopWatchTotalVariable = newLocal(Type.getType(StopWatch.class));
stopWatchVariable = newLocal(Type.getType(StopWatch.class));
}

argumentsVariable = newLocal(Type.getType(Object[].class));
returnVariable = newLocal(Type.getType(Object.class));

Expand Down Expand Up @@ -99,6 +110,30 @@ protected void onMethodEnter() {

visitTryCatchBlock(start, end, handler, Type.getInternalName(Exception.class));

if (isBenchMark) {
invokeStatic(
Type.getType(StopWatch.class),
new Method(
"createStarted",
Type.getType(StopWatch.class),
new Type[]{}
)
);

storeLocal(stopWatchTotalVariable);

invokeStatic(
Type.getType(StopWatch.class),
new Method(
"createStarted",
Type.getType(StopWatch.class),
new Type[]{}
)
);

storeLocal(stopWatchVariable);
}

loadArgArray();
storeLocal(argumentsVariable);

Expand All @@ -109,6 +144,19 @@ protected void onMethodEnter() {

if (preHook == null || preHook == "") {
if (!canBlock) {
if (isBenchMark) {
loadLocal(stopWatchVariable);

invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"suspend",
Type.VOID_TYPE,
new Type[]{}
)
);
}

return;
} else {
preHook = "detect";
Expand Down Expand Up @@ -140,6 +188,18 @@ protected void onMethodEnter() {
}
)
);
if (isBenchMark) {
loadLocal(stopWatchVariable);

invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"suspend",
Type.VOID_TYPE,
new Type[]{}
)
);
}
}

@Override
Expand All @@ -148,6 +208,19 @@ protected void onMethodExit(int opcode) {

if (opcode == ATHROW)
return;

if (isBenchMark) {
loadLocal(stopWatchVariable);

invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"resume",
Type.VOID_TYPE,
new Type[]{}
)
);
}

Type returnType = Type.getReturnType(methodDesc);

Expand Down Expand Up @@ -192,7 +265,6 @@ protected void onMethodExit(int opcode) {
loadLocal(returnVariable);
push(false);



invokeVirtual(
Type.getType(SmithProbeProxy.class),
Expand All @@ -208,6 +280,68 @@ protected void onMethodExit(int opcode) {
}
)
);

if (isBenchMark) {
loadLocal(stopWatchVariable);
invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"stop",
Type.VOID_TYPE,
new Type[]{}
)
);
loadLocal(stopWatchTotalVariable);
invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"stop",
Type.VOID_TYPE,
new Type[]{}
)
);
invokeStatic(
Type.getType(SmithProbe.class),
new Method(
"getInstance",
Type.getType(SmithProbe.class),
new Type[]{}
)
);
push(classID);
push(methodID);
loadLocal(stopWatchVariable);
invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"getNanoTime",
Type.LONG_TYPE,
new Type[]{}
)
);
loadLocal(stopWatchTotalVariable);
invokeVirtual(
Type.getType(StopWatch.class),
new Method(
"getNanoTime",
Type.LONG_TYPE,
new Type[]{}
)
);
invokeVirtual(
Type.getType(SmithProbe.class),
new Method(
"record",
Type.VOID_TYPE,
new Type[]{
Type.INT_TYPE,
Type.INT_TYPE,
Type.LONG_TYPE,
Type.LONG_TYPE
}
)
);
}
}

@Override
Expand Down

0 comments on commit 9a65e84

Please sign in to comment.