diff --git a/build.gradle.kts b/build.gradle.kts index 34c96e2..81cde81 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -29,6 +29,9 @@ dependencies { // Needed for AssemblyScript source mapping: //implementation("com.atlassian.sourcemap:sourcemap:2.0.0") implementation("com.google.code.gson:gson:2.11.0") + + // Logging + implementation("ch.qos.logback:logback-classic:1.5.32") } tasks.test { diff --git a/src/main/java/be/ugent/topl/mio/GdbStub.java b/src/main/java/be/ugent/topl/mio/GdbStub.java new file mode 100644 index 0000000..ed484fd --- /dev/null +++ b/src/main/java/be/ugent/topl/mio/GdbStub.java @@ -0,0 +1,408 @@ +package be.ugent.topl.mio; + +import be.ugent.topl.mio.debugger.Debugger; +import be.ugent.topl.mio.woodstate.Frame; +import be.ugent.topl.mio.woodstate.WOODDumpResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.*; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +public class GdbStub { + private final Debugger debugger; + private final String binaryLocation; + private OutputStream out; + private final Logger logger = LoggerFactory.getLogger(GdbStub.class); + private boolean stepForward = true; + + public GdbStub(Debugger debugger, String binaryLocation) { + this.debugger = debugger; + this.binaryLocation = binaryLocation; + + debugger.getBreakpointsListeners().add((pc) -> { + try { + logger.info("Stopped at breakpoint {}", pc); + sendStopPacket(out, "05"); + } catch (IOException e) { + throw new RuntimeException(e); + } + return null; + }); + } + + private String toHex(long data, int maxLen, boolean bigEndian) { + StringBuilder result = new StringBuilder(); + for (int i = 0; i < maxLen; i++) { + long b = (data >> i * 8) & 0xff; + if (!bigEndian) { + b = (data >> (maxLen - i - 1) * 8) & 0xff; + } + result.append(String.format("%02x", b)); + } + return result.toString(); + } + + private String toHex(long data) { + return toHex(data, 8, true); + } + + private long toWasmAddr(long addr) { + // We use the upper bits of the address to indicate the type, being part of the object (0) or part of the wasm memory (1). + return addr | 0x1L << 63; + } + + private long getAddrType(long addr) { + return addr >>> 63; + } + + private long stripAddrType(long addr) { + return addr & Long.MAX_VALUE; + } + + private String getTriple(String s) { + return s.chars() + .mapToObj(c -> String.format("%02x", c)) + .reduce("", String::concat); + } + + public WOODDumpResponse getCurrentState() { + return debugger.getCheckpoints().getLast().getSnapshot(); + } + + public void start(int port) throws IOException { + start(port, false); + } + + public void start(int port, boolean closeOnDetach) throws IOException { + debugger.pause(); + + byte[] wasmData = Files.readAllBytes(Path.of(binaryLocation)); + + ServerSocket server = new ServerSocket(port); + Socket sock; + InputStream in = null; + boolean attached = false; + + do { + if (!attached) { + logger.info("Waiting for LLDB connection on port {}...", port); + sock = server.accept(); + logger.info("LLDB connected! {}", sock.getInetAddress()); + + in = sock.getInputStream(); + out = sock.getOutputStream(); + attached = true; + } + + String pkt = recvPacket(in, out); + if (pkt == null) { + logger.info("Connection closed"); + break; + } + logger.trace("<- {}", pkt); + + if (pkt.startsWith("m")) { + pkt = pkt.substring(1); + String[] memArgs = pkt.split(","); + long pos = Long.parseUnsignedLong(memArgs[0], 16); + long addrType = getAddrType(pos); + pos = stripAddrType(pos); + long len = Long.parseUnsignedLong(memArgs[1], 16); + logger.info("Read memory {} bytes from {} with addr type = {}", len, pos, addrType); + + // TODO: Hack to temporarily prevent lldb from using breakpoints when stepping + // https://github.com/llvm/llvm-project/issues/189960 + /*if (addrType == 0) { + sendPacket(out, "E01"); + continue; + }*/ + + byte[] memory = wasmData; + if (addrType == 1) { + logger.info("Reading from wasm linear memory"); + memory = getCurrentState().getMemory().getBytes(); + } + + // The reply may contain fewer addressable memory units than requested if the server was reading from a trace frame memory and was able to read only part of the region of memory. + if (pos >= memory.length || pos < 0) { + logger.warn("Pos {} out of bounds for length {}", pos, memory.length); + sendPacket(out, "E01"); + continue; + } + + StringBuilder result = new StringBuilder(); + for (int i = 0; i < len; i++) { + // TODO: read from pos + i + // also think about little vs big endian + int index = (int) pos + i; + if (index >= memory.length) { + logger.info("Stop reading, out of bounds"); + break; + } + int b = memory[index]; + result.append(String.format("%02x", b & 0xFF)); + } + sendPacket(out, result.toString()); + continue; + } + else if (pkt.startsWith("qSupported")) { + //ReverseContinue+; + sendPacket(out, "qXfer:libraries:read+;vContSupported-;wasm+;ReverseStep+;"); + continue; + } + else if (pkt.startsWith("qXfer:libraries:read")) { + // https://github.com/v8/v8/blob/main/src/debug/wasm/gdb-server/gdb-remote-util.h#L51 + // For LLDB debugging, an address in a Wasm module code space is represented + // with 64 bits, where the first 32 bits identify the module id: + // +--------------------+--------------------+ + // | module_id | offset | + // +--------------------+--------------------+ + // <----- 32 bit -----> <----- 32 bit -----> + // Offset 0, module id 0 + sendPacket(out, String.format("l
", new File(binaryLocation).getAbsolutePath())); + continue; + } + // TODO: We can probably remove this: + else if (pkt.startsWith("Hc")) { + sendPacket(out, "OK"); + continue; + } + else if (pkt.startsWith("qWasmLocal:")) { + String[] args = pkt.substring("qWasmLocal:".length()).split(";"); + int frameIdx = Integer.parseInt(args[0]); + int localIdx = Integer.parseInt(args[1]); + logger.info("Reading local {} from frame {}", localIdx, frameIdx); + Frame frame = getCallStack(getCurrentState()).get(frameIdx); + logger.trace("{}", getCallStack(getCurrentState())); + logger.trace("{}", frame); + logger.trace("{}", getCurrentState().getStack()); + int fp = frame.getFp(); + long value = getCurrentState().getStack().get(fp + localIdx + 1).getValue(); + sendPacket(out, toHex(toWasmAddr(value))); // If a pointer on the stack is an address it will be clear it's a wasm memory pointer. + continue; + } + else if (pkt.startsWith("qRegisterInfo0")) { + sendPacket(out, "name:pc;alt-name:pc;bitsize:64;offset:0;encoding:uint;format:hex;set:General Purpose Registers;gcc:16;dwarf:16;generic:pc;"); + continue; + } + else if (pkt.startsWith("p0")) { + sendPacket(out, toHex(getCurrentState().getPc())); + continue; + } + else if (pkt.startsWith("Z")) { + String[] args = pkt.substring(1).split(","); + int type = Integer.parseInt(args[0]); + long addr = Long.parseUnsignedLong(args[1], 16); + int kind = Integer.parseInt(args[2]); + logger.info("Add breakpoint on {}", addr); + debugger.addBreakpoint((int) addr); + + // A remote target shall return an empty string for an unrecognized breakpoint or watchpoint packet type. + sendPacket(out, "OK"); + continue; + } + else if (pkt.startsWith("z")) { + // TODO: Fix code duplication + String[] args = pkt.substring(1).split(","); + int type = Integer.parseInt(args[0]); + long addr = Long.parseUnsignedLong(args[1], 16); + int kind = Integer.parseInt(args[2]); + logger.info("Remove breakpoint on {}", addr); + debugger.removeBreakpoint((int) addr); + + // A remote target shall return an empty string for an unrecognized breakpoint or watchpoint packet type. + sendPacket(out, "OK"); + continue; + } + else if (pkt.startsWith("stackInfo")) { + //debugger.stepBack(1, wasmData); + sendPacket(out, getCurrentState().getStack().toString()); + continue; + } else if (pkt.startsWith("bs")) { + if (pkt.equals("bs")) { + stepBack(1); + } else { + int n = Integer.parseInt(pkt.split(" ")[1]); + logger.info("Step back {} instruction(s)", n); + stepBack(n); + } + continue; + } else if (pkt.startsWith("QSetStepDir:")) { + stepForward = Integer.parseInt(pkt.substring("QSetStepDir:".length())) == 0; + logger.info("Change step direction to {}", stepForward ? "forward" : "backward"); + sendPacket(out, "OK"); + continue; + } + + switch (pkt) { + /*case "QStartNoAckMode": + sendPacket(out, "OK"); + break;*/ + case "qHostInfo": + sendPacket(out, "vendor:wamr;ostype:wasi;arch:wasm32;endian:little;ptrsize:4;"); + break; + case "qProcessInfo": + sendPacket(out, "pid:1;parent-pid:1;vendor:wamr;ostype:wasi;arch:wasm32;triple:" + getTriple("wasm32-unknown-unknown-wasm") + ";endian:little;ptrsize:4;"); + break; + case "qGetWorkingDir": + sendPacket(out, "/tmp"); + break; + case "qQueryGDBServer": + sendPacket(out, "PacketSize=4000"); + break; + case "qWasmCallStack:1": // Get the callstack for thread 1. + WOODDumpResponse state = getCurrentState(); + StringBuilder result = new StringBuilder(toHex(state.getPc())); + for (int i = state.getCallstack().size() - 1; i >= 0; i--) { + // Only functions are real callstack elements: + Frame f = state.getCallstack().get(i); + if (f.getType() == 0) { + result.append(toHex(f.getRa())); + } + } + sendPacket(out, result.toString()); + + break; + case "qC": // Get thread id + sendPacket(out, "QC 1"); + break; + case "qfThreadInfo": + sendPacket(out, "m 1"); // Active threads start list + break; + case "qsThreadInfo": + sendPacket(out, "l"); // End of list + break; + case "?": + sendPacket(out, "S05"); // SIGTRAP + break; + case "g": + sendPacket(out, String.format("%08x", getCurrentState().getPc())); + break; + case "s": + logger.info("Received step command from lldb"); + if (stepForward) { + logger.info("Step forward"); + debugger.stepInto(); + sendStopPacket(out, "05"); + } + else { + logger.info("Step backward"); + stepBack(1); + } + break; + case "c": + logger.info("Continue execution"); + debugger.run(); + break; + /*case "bc": + // TODO: Actual backwards continue in MIO + stepBack(1); + break;*/ + case "pause": + debugger.pause(); + sendStopPacket(out, "02"); + break; + case "D": + attached = false; + logger.info("Detach from target"); + sendPacket(out, "OK"); + break; + default: + logger.warn("Unknown packet: {}", pkt); + sendPacket(out, ""); + break; + } + } while(!closeOnDetach || attached); + debugger.close(); + } + + private void stepBack(int n) throws IOException { + if (!debugger.canStepBack(n)) { + sendPacket(out, "T" + "05" + "thread:1;name:warduino;thread-pcs:" + toHex(getCurrentState().getPc()) + ";00:" + toHex(getCurrentState().getPc()) + ";replaylog:begin;"); + return; + } + + debugger.stepBack(n, () -> null); + sendStopPacket(out, "05"); + } + + private void sendStopPacket(OutputStream out, String signal) throws IOException { + sendPacket(out, "T" + signal + "thread:1;name:warduino;thread-pcs:" + toHex(getCurrentState().getPc()) + ";00:" + toHex(getCurrentState().getPc()) + ";reason:trace"); + } + + private String recvPacket(InputStream in, OutputStream out) throws IOException { + int c; + + // Wait for '$' + do { + c = in.read(); + if (c == -1) return null; + } while (c != '$' && c != 0x03); // 0x03 == pause request + + if (c == 0x03) { + return "pause"; + } + + ByteArrayOutputStream payload = new ByteArrayOutputStream(); + + // Read until '#' + while ((c = in.read()) != '#') { + if (c == -1) return null; + payload.write(c); + } + + // Read checksum + int c1 = in.read(); + int c2 = in.read(); + if (c1 == -1 || c2 == -1) return null; + + int received = Integer.parseInt("" + (char)c1 + (char)c2, 16); + byte[] data = payload.toByteArray(); + int computed = checksum(data); + + if (received == computed) { + out.write('+'); // ACK + out.flush(); + return new String(data); + } else { + out.write('-'); // NAK + out.flush(); + return null; + } + } + + private void sendPacket(OutputStream out, String payload) throws IOException { + byte[] data = payload.getBytes(); + int sum = checksum(data); + + String pkt = "$" + payload + "#" + String.format("%02x", sum); + out.write(pkt.getBytes()); + out.flush(); + logger.trace("-> {}", pkt); + } + + private int checksum(byte[] data) { + int sum = 0; + for (byte b : data) { + sum = (sum + (b & 0xFF)) & 0xFF; + } + return sum; + } + + private List getCallStack(WOODDumpResponse state) { + List callStack = new ArrayList(); + for (int i = state.getCallstack().size() - 1; i >= 0; i--) { + Frame f = state.getCallstack().get(i); + if (f.getType() == 0) { + callStack.add(state.getCallstack().get(i)); + } + } + return callStack; + } +} diff --git a/src/main/kotlin/be/ugent/topl/mio/Main.kt b/src/main/kotlin/be/ugent/topl/mio/Main.kt index eb4b9d5..cfb7e4d 100644 --- a/src/main/kotlin/be/ugent/topl/mio/Main.kt +++ b/src/main/kotlin/be/ugent/topl/mio/Main.kt @@ -3,14 +3,14 @@ package be.ugent.topl.mio import be.ugent.topl.mio.connections.ProcessConnection import be.ugent.topl.mio.connections.SerialConnection import be.ugent.topl.mio.debugger.Debugger -import com.formdev.flatlaf.FlatDarkLaf -import com.formdev.flatlaf.FlatIntelliJLaf import be.ugent.topl.mio.sourcemap.AsSourceMapping import be.ugent.topl.mio.sourcemap.compileAndFlash import be.ugent.topl.mio.sourcemap.compileWat import be.ugent.topl.mio.sourcemap.getDwarfSourcemap import be.ugent.topl.mio.ui.InteractiveDebugger import be.ugent.topl.mio.ui.StartScreen +import com.formdev.flatlaf.FlatDarkLaf +import com.formdev.flatlaf.FlatIntelliJLaf import com.formdev.flatlaf.util.SystemInfo import java.io.File import java.io.FileNotFoundException @@ -52,6 +52,20 @@ fun main(args: Array) { } expectNArguments(args, 1) val config = DebuggerConfig() + + if (args[0].startsWith("-p=") || args[0].startsWith("--port=")) { + expectNArguments(args, 2) + val wasmFilename = args[1] + val port = args[0].split("=")[1].toInt() + val connection = if(config.useEmulator) ProcessConnection(config.wdcliPath, wasmFilename, "--no-socket", "--paused") else SerialConnection(config.port!!) + val debugger = Debugger(connection) + debugger.pause() + debugger.setSnapshotPolicy(Debugger.SnapshotPolicy.Checkpointing()) + val stub = GdbStub(debugger, wasmFilename) + stub.start(port) + return + } + when (args[0]) { "debug" -> { expectNArguments(args, 2) diff --git a/src/main/kotlin/be/ugent/topl/mio/debugger/Debugger.kt b/src/main/kotlin/be/ugent/topl/mio/debugger/Debugger.kt index 1983adb..f0647fb 100644 --- a/src/main/kotlin/be/ugent/topl/mio/debugger/Debugger.kt +++ b/src/main/kotlin/be/ugent/topl/mio/debugger/Debugger.kt @@ -14,8 +14,8 @@ import java.util.* import kotlin.concurrent.thread import kotlin.streams.toList -open class Debugger(private val connection: Connection, start: Boolean = true, private val onHitBreakpoint: (Int) -> Unit = {}) : Closeable, AutoCloseable { - private val requestQueue: Queue = LinkedList() +open class Debugger(private val connection: Connection, start: Boolean = true, onHitBreakpoint: (Int) -> Unit = {}) : Closeable, AutoCloseable { + var breakpointsListeners: MutableList<(Int) -> Unit> = mutableListOf(onHitBreakpoint) var printListener: ((String) -> Unit)? = null private val messageQueue = MessageQueue { for (msg in it) { @@ -39,6 +39,7 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p connection.read(readBuffer) messageQueue.push(String(readBuffer), true) + //println(String(readBuffer)) while (true) { val checkpointMessage = messageQueue.search { val match = Regex("CHECKPOINT (.*)").matchEntire(it.trimEnd('\r')) ?: throw Exception() @@ -97,7 +98,9 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p * incoming messages, so the request would never be completed. */ thread { - onHitBreakpoint(searchAtResult.second) + for (breakpointListener in breakpointsListeners) { + breakpointListener(searchAtResult.second) + } } } } @@ -175,7 +178,6 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p private fun send(code: Int, payload: String = "") { val str = String.format("%02d$payload\n", code) - requestQueue.add(code) print("Sending $str") val write = str.toByteArray() connection.write(write) @@ -217,18 +219,18 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p } } - private fun canStepBack(): Boolean { - return checkpoints.size > 1 + fun canStepBack(n: Int = 1): Boolean { + return checkpoints.size > n } - fun stepBackUntil(binaryInfo: WasmInfo, cond: (WOODDumpResponse) -> Boolean) { - stepBack(1, binaryInfo) {} + fun stepBackUntil(cond: (WOODDumpResponse) -> Boolean) { + stepBack() while (!cond(checkpoints.last()!!.snapshot)) { if (!canStepBack()) { System.err.println("WARNING: Can't go back further!") return } - stepBack(1, binaryInfo) + stepBack() } } @@ -261,7 +263,7 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p println("count = ${checkpoints.size}") } - open fun stepBack(n: Int, binaryInfo: WasmInfo, stepDone: () -> Unit = {}) { + open fun stepBack(n: Int = 1, stepDone: () -> Unit = {}) { if (n == 0) { return } @@ -269,7 +271,7 @@ open class Debugger(private val connection: Connection, start: Boolean = true, p val currentState = checkpoints.removeLast() // Remove current state, we don't need to restore this, we are already in this state. val nSnapshots = checkpoints.subList(checkpoints.size - n, checkpoints.size).toList() for (checkpoint in nSnapshots.reversed()) { - if (checkpoint != null && (checkpoint.snapshot.pc in binaryInfo.after_primitive_calls || nSnapshots.first() == checkpoint)) { + if (checkpoint != null && (checkpoint.fidx_called != null || nSnapshots.first() == checkpoint)) { //if (snapshot != null) { println("Snapshot to ${checkpoint.snapshot.pc}") val s = checkpoint.snapshot diff --git a/src/main/kotlin/be/ugent/topl/mio/debugger/MultiverseDebugger.kt b/src/main/kotlin/be/ugent/topl/mio/debugger/MultiverseDebugger.kt index edff384..062f80b 100644 --- a/src/main/kotlin/be/ugent/topl/mio/debugger/MultiverseDebugger.kt +++ b/src/main/kotlin/be/ugent/topl/mio/debugger/MultiverseDebugger.kt @@ -1,7 +1,6 @@ package be.ugent.topl.mio.debugger import WasmBinary -import WasmInfo import be.ugent.topl.mio.concolic.analyse import be.ugent.topl.mio.concolic.processPaths import be.ugent.topl.mio.connections.Connection @@ -202,12 +201,12 @@ class MultiverseDebugger( }*/ } - override fun stepBack(n: Int, binaryInfo: WasmInfo, stepDone: () -> Unit) { + override fun stepBack(n: Int, stepDone: () -> Unit) { var destinationNode = graph.currentNode for (i in 0 ..< n) { destinationNode = destinationNode.parent!! } - super.stepBack(n, binaryInfo, stepDone) + super.stepBack(n, stepDone) graph.currentNode = destinationNode graphUpdated() diff --git a/src/main/kotlin/be/ugent/topl/mio/ui/InteractiveDebugger.kt b/src/main/kotlin/be/ugent/topl/mio/ui/InteractiveDebugger.kt index 652a51c..b4fc55b 100644 --- a/src/main/kotlin/be/ugent/topl/mio/ui/InteractiveDebugger.kt +++ b/src/main/kotlin/be/ugent/topl/mio/ui/InteractiveDebugger.kt @@ -136,7 +136,7 @@ class InteractiveDebugger( stepBackButton.addActionListener { println("Step back") //debugger.stepBack() - debugger.stepBack(1, binaryInfo) {} + debugger.stepBack() updateStepBackButton() updatePcLabel() } @@ -193,7 +193,7 @@ class InteractiveDebugger( } catch(re: RuntimeException) { System.err.println("WARNING: " + re.message) } - debugger.stepBackUntil(binaryInfo) { + debugger.stepBackUntil { try { sourceMapping.getLineForPc(it.pc!!) != startLine } catch(re: RuntimeException) { @@ -541,7 +541,7 @@ class MultiversePanel(private val multiverseDebugger: MultiverseDebugger, config val totalLength = backwardsLength + forwardsLength val backwardPath = graphPanel.selectedPath!!.first.toMutableList() var finishedSteps = 0 - multiverseDebugger.stepBack(backwardPath.size, multiverseDebugger.wasmBinary.metadata) { + multiverseDebugger.stepBack(backwardPath.size) { graphPanel.completedPath.add(backwardPath.removeFirst()) graphPanel.repaint() val remaining = forwardsLength + backwardPath.size diff --git a/src/main/kotlin/be/ugent/topl/mio/woodstate/WOODState.kt b/src/main/kotlin/be/ugent/topl/mio/woodstate/WOODState.kt index aa90b6f..e815c64 100644 --- a/src/main/kotlin/be/ugent/topl/mio/woodstate/WOODState.kt +++ b/src/main/kotlin/be/ugent/topl/mio/woodstate/WOODState.kt @@ -156,6 +156,7 @@ data class Checkpoint( val instructions_executed: Int, val fidx_called: Int?, val args: List?, + val returns: List?, val snapshot: WOODDumpResponse ) diff --git a/src/main/resources/logback.xml b/src/main/resources/logback.xml new file mode 100644 index 0000000..c023bd2 --- /dev/null +++ b/src/main/resources/logback.xml @@ -0,0 +1,23 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %highlight(%-5level) %logger{36} - %msg%n + + + DEBUG + + + + + mio.log + false + + %-4relative [%thread] %-5level %logger{35} - %msg%n + + + + + + + +