From f3975dba19132e80b2890e62be1e8b58adf891fc Mon Sep 17 00:00:00 2001 From: Jeremy Bruestle Date: Mon, 30 Dec 2024 16:26:35 -0800 Subject: [PATCH] Fix host read --- zirgen/circuit/rv32im/v2/dsl/arr.zir | 10 +-- zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir | 73 ++++++++++++---- zirgen/circuit/rv32im/v2/dsl/mem.zir | 2 +- zirgen/circuit/rv32im/v2/dsl/top.zir | 4 - zirgen/circuit/rv32im/v2/emu/preflight.cpp | 1 + zirgen/circuit/rv32im/v2/emu/r0vm.h | 22 +++-- zirgen/circuit/rv32im/v2/test/BUILD.bazel | 23 +++++ zirgen/circuit/rv32im/v2/test/test_io.cpp | 51 ++++++++++++ .../circuit/rv32im/v2/test/test_io_kernel.cpp | 83 +++++++++++++++++++ 9 files changed, 236 insertions(+), 33 deletions(-) create mode 100644 zirgen/circuit/rv32im/v2/test/test_io.cpp create mode 100644 zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp diff --git a/zirgen/circuit/rv32im/v2/dsl/arr.zir b/zirgen/circuit/rv32im/v2/dsl/arr.zir index 0e99c723..e8dae002 100644 --- a/zirgen/circuit/rv32im/v2/dsl/arr.zir +++ b/zirgen/circuit/rv32im/v2/dsl/arr.zir @@ -1,4 +1,3 @@ -// This file contains utilities that work with bits and twits. // RUN: zirgen --test %s // Vector / List functions @@ -37,11 +36,10 @@ component EqArr(a: Array, b: Array) { // Tests.... test ShiftAndRotate { - // TODO: Now that these support non-bit values, maybe make new tests // Remember: array entry 0 is the low bit, so there seem backwards - EqArr<8>(ShiftRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 0, 0]); - EqArr<8>(ShiftLeft<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [0, 0, 1, 1, 1, 0, 1, 0]); - EqArr<8>(RotateRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 1, 1]); - EqArr<8>(RotateLeft<8>([1, 1, 1, 0, 1, 0, 0, 1], 2), [0, 1, 1, 1, 1, 0, 1, 0]); + EqArr<8>(ShiftRight<8>([3, 1, 5, 0, 2, 0, 0, 0], 2), [5, 0, 2, 0, 0, 0, 0, 0]); + EqArr<8>(ShiftLeft<8>([1, 4, 2, 0, 6, 0, 0, 0], 2), [0, 0, 1, 4, 2, 0, 6, 0]); + EqArr<8>(RotateRight<8>([7, 6, 1, 0, 2, 0, 0, 0], 2), [1, 0, 2, 0, 0, 0, 7, 6]); + EqArr<8>(RotateLeft<8>([4, 5, 1, 0, 1, 0, 0, 3], 2), [0, 3, 4, 5, 1, 0, 1, 0]); } diff --git a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir index 133dde56..54ce2df7 100644 --- a/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir +++ b/zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir @@ -3,7 +3,6 @@ import inst; import consts; - // Prepare to read a certain length, maybe return a smaller one extern HostReadPrepare(fd: Val, len: Val): Val; @@ -87,8 +86,8 @@ component ECallHostReadSetup(cycle: Reg, input: InstInput) { lenDecomp := DecomposeLow2(newLen); // Check if length is exactly 1, 2, or 3 len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero); - // Check if things are 'uneven' - uneven := Reg(len123 * ptrDecomp.low2Nonzero); + // Check if things are 'uneven' (this is an 'or') + uneven := Reg(len123 + ptrDecomp.low2Nonzero - len123 * ptrDecomp.low2Nonzero); // Now pick the next cycle nextCycle := // If length == 0, go back to decoding @@ -118,11 +117,54 @@ component ECallHostWrite(cycle: Reg, input: InstInput) { ECallOutput(StateDecode(), 0, 0, 0) } -component ECallHostReadBytes(cycle: Reg, input: InstInput) { - // TODO +component ECallHostReadBytes(cycle: Reg, input: InstInput, ptrWord: Val, ptrLow2: Val, len: Val) { input.state = StateHostReadBytes(); - 0 = 1; - ECallOutput(16, 0, 0, 0) + // Decompose next len + lenDecomp := DecomposeLow2(len - 1); + // Check if length is exactly 1, 2, or 3 + len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero); + // Check is next pointer is even (this can only happen if Low2 == 3) + nextPtrEven := IsZero(ptrLow2 - 3); + nextPtrUneven := 1 - nextPtrEven; + nextPtrWord := nextPtrEven * (ptrWord + 1) + nextPtrUneven * ptrWord; + nextPtrLow2 := nextPtrUneven * (ptrLow2 + 1); + // Check if things are 'uneven' (this is an 'or') + uneven := Reg(len123 + nextPtrUneven - len123 * nextPtrUneven); + // Check is length is exactly zero + lenZero := IsZero(len - 1); + // Split low bits into parts + low0 := NondetBitReg(ptrLow2 & 1); + low1 := BitReg((ptrLow2 - low0) / 2); + // Load the original word + origWord := MemoryRead(cycle, ptrWord); + // Write the answer + io := MemoryWriteUnconstrained(cycle, ptrWord).io; + // Make the non-specified half matches + if (low1) { + origWord.low = io.newTxn.dataLow; + } else { + origWord.high = io.newTxn.dataHigh; + }; + // Get the half that changed + oldHalf := low1 * origWord.high + (1 - low1) * origWord.low; + newHalf := low1 * io.newTxn.dataHigh + (1 - low1) * io.newTxn.dataLow; + // Split both into bytes + oldBytes := SplitWord(oldHalf); + newBytes := SplitWord(newHalf); + // Make sure the non-specified bytes matchs + if (low0) { + oldBytes.byte0 = newBytes.byte0; + } else { + oldBytes.byte1 = newBytes.byte1; + }; + nextCycle := + // If length == 0, go back to decoding + lenZero * StateDecode() + + // If length != 0 and uneven, do bytes + (1 - lenZero) * uneven * StateHostReadBytes() + + // If length != 0 and even, more words + (1 - lenZero) * (1 - uneven) * StateHostReadWords(); + ECallOutput(nextCycle, nextPtrWord, nextPtrLow2, len - 1) } component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Val) { @@ -130,9 +172,9 @@ component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Va lenDecomp := DecomposeLow2(len); wordsDecomp := DecomposeLow2(lenDecomp.high); doWord := [ - wordsDecomp.low2Hot[1] * wordsDecomp.highZero, - wordsDecomp.low2Hot[2] * wordsDecomp.highZero, - wordsDecomp.low2Hot[3] * wordsDecomp.highZero, + (wordsDecomp.low2Hot[1] + wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero), + (wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3])* wordsDecomp.highZero + (1 - wordsDecomp.highZero), + (wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero), (1 - wordsDecomp.highZero) ]; count := reduce doWord init 0 with Add; @@ -140,14 +182,15 @@ component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Va addr := Reg(doWord[i] * (ptrWord + i) + (1 - doWord[i]) * SafeWriteWord()); MemoryWriteUnconstrained(cycle, addr); }; - lenZero := IsZero(len - 4 * count); + newLenHighZero := IsZero(lenDecomp.high - count); + lenZero := Reg(newLenHighZero * (1 - lenDecomp.low2Nonzero)); nextCycle := // If length == 0, go back to decoding lenZero * StateDecode() + // If length != 0 and uneven, do bytes - (1 - lenZero) * (lenDecomp.low2Nonzero) * StateHostReadBytes() + - // If lengtj != 0 and even, more words - (1 - lenZero) * (1 - lenDecomp.low2Nonzero) * StateHostReadWords(); + (1 - lenZero) * newLenHighZero * StateHostReadBytes() + + // If length != 0 and even, more words + (1 - lenZero) * (1 - newLenHighZero) * StateHostReadWords(); ECallOutput(nextCycle, ptrWord + count, 0, len - count * 4) } @@ -163,7 +206,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) { ECallTerminate(cycle, inst_input), ECallHostReadSetup(cycle, inst_input), ECallHostWrite(cycle, inst_input), - ECallHostReadBytes(cycle, inst_input), + ECallHostReadBytes(cycle, inst_input, s0@1, s1@1, s2@1), ECallHostReadWords(cycle, inst_input, s0@1, s2@1), IllegalECall(), IllegalECall() diff --git a/zirgen/circuit/rv32im/v2/dsl/mem.zir b/zirgen/circuit/rv32im/v2/dsl/mem.zir index a5b07963..e0923b9e 100644 --- a/zirgen/circuit/rv32im/v2/dsl/mem.zir +++ b/zirgen/circuit/rv32im/v2/dsl/mem.zir @@ -101,7 +101,7 @@ component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) { // Let the host write anythings (used in host read words) component MemoryWriteUnconstrained(cycle: Reg, addr: Val) { - io := MemoryIO(2*cycle + 1, addr); + public io := MemoryIO(2*cycle + 1, addr); IsForward(io); } diff --git a/zirgen/circuit/rv32im/v2/dsl/top.zir b/zirgen/circuit/rv32im/v2/dsl/top.zir index 50df4719..4f91fc68 100644 --- a/zirgen/circuit/rv32im/v2/dsl/top.zir +++ b/zirgen/circuit/rv32im/v2/dsl/top.zir @@ -1,9 +1,5 @@ // RUN: true -// TODO: Now that the v2 circuit uses an extern to compute major/minor it no -// longer makes sense to do rv32im conformance testing here. Make sure -// integration tests are covering this. - import inst_div; import inst_misc; import inst_mul; diff --git a/zirgen/circuit/rv32im/v2/emu/preflight.cpp b/zirgen/circuit/rv32im/v2/emu/preflight.cpp index e2d30b47..2f217ed6 100644 --- a/zirgen/circuit/rv32im/v2/emu/preflight.cpp +++ b/zirgen/circuit/rv32im/v2/emu/preflight.cpp @@ -221,6 +221,7 @@ struct PreflightContext { } size_t rlen = segment.readRecord[curRead].size(); memcpy(data, segment.readRecord[curRead].data(), rlen); + curRead++; return rlen; } diff --git a/zirgen/circuit/rv32im/v2/emu/r0vm.h b/zirgen/circuit/rv32im/v2/emu/r0vm.h index 0fcfd6a8..84bec4b8 100644 --- a/zirgen/circuit/rv32im/v2/emu/r0vm.h +++ b/zirgen/circuit/rv32im/v2/emu/r0vm.h @@ -173,18 +173,22 @@ template struct R0Context { std::vector bytes(len); rlen = context.read(fd, bytes.data(), len); storeReg(REG_A0, rlen); + uint32_t i = 0; if (rlen == 0) { context.pc += 4; } context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); curState = nextState(ptr, rlen); - uint32_t i = 0; while (rlen > 0 && ptr % 4 != 0) { writeByte(ptr, bytes[i]); - // context.hostReadBytes(ptr); ptr++; i++; rlen--; + if (rlen == 0) { + context.pc += 4; + } + context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); + curState = nextState(ptr, rlen); } while (rlen >= 4) { uint32_t words = std::min(rlen / 4, uint32_t(4)); @@ -195,12 +199,12 @@ template struct R0Context { word |= bytes[i + k] << (8 * k); } storeMem(ptr / 4, word); + ptr += 4; + i += 4; + rlen -= 4; } else { storeMem(SAFE_WRITE_WORD, 0); } - ptr += words; - i += words; - rlen -= words; } if (rlen == 0) { context.pc += 4; @@ -208,12 +212,16 @@ template struct R0Context { context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); curState = nextState(ptr, rlen); } - while (rlen > 0 && ptr % 4 != 0) { + while (rlen > 0) { writeByte(ptr, bytes[i]); - // context.hostReadBytes(ptr); ptr++; i++; rlen--; + if (rlen == 0) { + context.pc += 4; + } + context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen); + curState = nextState(ptr, rlen); } return false; } diff --git a/zirgen/circuit/rv32im/v2/test/BUILD.bazel b/zirgen/circuit/rv32im/v2/test/BUILD.bazel index 259f238d..90b7ee36 100644 --- a/zirgen/circuit/rv32im/v2/test/BUILD.bazel +++ b/zirgen/circuit/rv32im/v2/test/BUILD.bazel @@ -52,4 +52,27 @@ cc_binary( ], ) +risc0_cc_kernel_binary( + name = "test_io_kernel", + srcs = [ + "entry.s", + "test_io_kernel.cpp", + ], + deps = ["//zirgen/circuit/rv32im/v2/platform:core"], +) + +cc_test( + name = "test_io", + srcs = [ + "test_io.cpp", + ], + data = [ + ":test_io_kernel", + ], + deps = [ + "//risc0/core", + "//zirgen/circuit/rv32im/v2/run", + ] +) + riscv_test_suite() diff --git a/zirgen/circuit/rv32im/v2/test/test_io.cpp b/zirgen/circuit/rv32im/v2/test/test_io.cpp new file mode 100644 index 00000000..33287ef9 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_io.cpp @@ -0,0 +1,51 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" +#include "zirgen/circuit/rv32im/v2/run/run.h" + +using namespace zirgen::rv32im_v2; + +const std::string kernelName = "zirgen/circuit/rv32im/v2/test/test_io_kernel"; + +// Allows reads of any size, fill with a pattern to check in kernel +struct RandomReadSizeHandler : public HostIoHandler { + uint32_t write(uint32_t fd, const uint8_t* data, uint32_t len) override { return len; } + uint32_t read(uint32_t fd, uint8_t* data, uint32_t len) override { + std::cout << "DOING READ OF SIZE " << len << "\n"; + for(size_t i = 0; i < len; i++) { + data[i] = i; + } + return len; + } +}; + + +int main() { + size_t cycles = 100000; + RandomReadSizeHandler io; + + // Load image + auto image = MemoryImage::fromRawElf(kernelName); + // Do executions + auto segments = execute(image, io, cycles, cycles); + // Do 'run' (preflight + expansion) + for (const auto& segment : segments) { + std::cout << "HEY, doing a segment!\n"; + runSegment(segment, cycles + 1000); + } + std::cout << "What a fine day\n"; +} diff --git a/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp b/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp new file mode 100644 index 00000000..015f3dc1 --- /dev/null +++ b/zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp @@ -0,0 +1,83 @@ +// Copyright 2024 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "zirgen/circuit/rv32im/v2/platform/constants.h" + +using namespace zirgen::rv32im_v2; + +inline void die() { + asm("fence\n"); +} + +// Implement machine mode ECALLS + +inline void terminate(uint32_t val) { + register uintptr_t a0 asm("a0") = val; + register uintptr_t a7 asm("a7") = 0; + asm volatile("ecall\n" + : // no outputs + : "r"(a0), "r"(a7) // inputs + : // no clobbers + ); +} + +inline uint32_t host_read(uint32_t fd, uint32_t buf, uint32_t len) { + register uintptr_t a0 asm("a0") = fd; + register uintptr_t a1 asm("a1") = buf; + register uintptr_t a2 asm("a2") = len; + register uintptr_t a7 asm("a7") = 1; + asm volatile("ecall\n" + : "+r"(a0) // outputs + : "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs + : // no clobbers + ); + return a0; +} + +inline uint32_t host_write(uint32_t fd, uint32_t buf, uint32_t len) { + register uintptr_t a0 asm("a0") = fd; + register uintptr_t a1 asm("a1") = buf; + register uintptr_t a2 asm("a2") = len; + register uintptr_t a7 asm("a7") = 2; + asm volatile("ecall\n" + : "+r"(a0) // outputs + : "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs + : // no clobbers + ); + return a0; +} + +constexpr uint32_t sizes[11] = { 0, 1, 2, 3, 4, 5, 7, 13, 19, 40, 101 }; + +void test_multi_read() { + uint8_t buf[200]; + // Try all 4 alignments + for (size_t i = 0; i < 4; i++) { + // Try a variety of size + for (size_t j = 0; j < 11; j++) { + host_read(0, (uint32_t) (buf + i), sizes[j]); + for (size_t k = 0; k < sizes[j]; k++) { + if (buf[i + k] != k) { die(); } + } + } + } +} + +extern "C" void start() { + test_multi_read(); + terminate(0); +}