Skip to content

Commit

Permalink
Fix host read
Browse files Browse the repository at this point in the history
  • Loading branch information
jbruestle committed Dec 31, 2024
1 parent cc4d252 commit f3975db
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 33 deletions.
10 changes: 4 additions & 6 deletions zirgen/circuit/rv32im/v2/dsl/arr.zir
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// This file contains utilities that work with bits and twits.
// RUN: zirgen --test %s

// Vector / List functions
Expand Down Expand Up @@ -37,11 +36,10 @@ component EqArr<SIZE: Val>(a: Array<Val, SIZE>, b: Array<Val, SIZE>) {
// Tests....

test ShiftAndRotate {
// TODO: Now that these support non-bit values, maybe make new tests
// Remember: array entry 0 is the low bit, so there seem backwards
EqArr<8>(ShiftRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [0, 0, 1, 1, 1, 0, 1, 0]);
EqArr<8>(RotateRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 1, 1]);
EqArr<8>(RotateLeft<8>([1, 1, 1, 0, 1, 0, 0, 1], 2), [0, 1, 1, 1, 1, 0, 1, 0]);
EqArr<8>(ShiftRight<8>([3, 1, 5, 0, 2, 0, 0, 0], 2), [5, 0, 2, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 4, 2, 0, 6, 0, 0, 0], 2), [0, 0, 1, 4, 2, 0, 6, 0]);
EqArr<8>(RotateRight<8>([7, 6, 1, 0, 2, 0, 0, 0], 2), [1, 0, 2, 0, 0, 0, 7, 6]);
EqArr<8>(RotateLeft<8>([4, 5, 1, 0, 1, 0, 0, 3], 2), [0, 3, 4, 5, 1, 0, 1, 0]);
}

73 changes: 58 additions & 15 deletions zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inst;
import consts;


// Prepare to read a certain length, maybe return a smaller one
extern HostReadPrepare(fd: Val, len: Val): Val;

Expand Down Expand Up @@ -87,8 +86,8 @@ component ECallHostReadSetup(cycle: Reg, input: InstInput) {
lenDecomp := DecomposeLow2(newLen);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check if things are 'uneven'
uneven := Reg(len123 * ptrDecomp.low2Nonzero);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + ptrDecomp.low2Nonzero - len123 * ptrDecomp.low2Nonzero);
// Now pick the next cycle
nextCycle :=
// If length == 0, go back to decoding
Expand Down Expand Up @@ -118,36 +117,80 @@ component ECallHostWrite(cycle: Reg, input: InstInput) {
ECallOutput(StateDecode(), 0, 0, 0)
}

component ECallHostReadBytes(cycle: Reg, input: InstInput) {
// TODO
component ECallHostReadBytes(cycle: Reg, input: InstInput, ptrWord: Val, ptrLow2: Val, len: Val) {
input.state = StateHostReadBytes();
0 = 1;
ECallOutput(16, 0, 0, 0)
// Decompose next len
lenDecomp := DecomposeLow2(len - 1);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check is next pointer is even (this can only happen if Low2 == 3)
nextPtrEven := IsZero(ptrLow2 - 3);
nextPtrUneven := 1 - nextPtrEven;
nextPtrWord := nextPtrEven * (ptrWord + 1) + nextPtrUneven * ptrWord;
nextPtrLow2 := nextPtrUneven * (ptrLow2 + 1);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + nextPtrUneven - len123 * nextPtrUneven);
// Check is length is exactly zero
lenZero := IsZero(len - 1);
// Split low bits into parts
low0 := NondetBitReg(ptrLow2 & 1);
low1 := BitReg((ptrLow2 - low0) / 2);
// Load the original word
origWord := MemoryRead(cycle, ptrWord);
// Write the answer
io := MemoryWriteUnconstrained(cycle, ptrWord).io;
// Make the non-specified half matches
if (low1) {
origWord.low = io.newTxn.dataLow;
} else {
origWord.high = io.newTxn.dataHigh;
};
// Get the half that changed
oldHalf := low1 * origWord.high + (1 - low1) * origWord.low;
newHalf := low1 * io.newTxn.dataHigh + (1 - low1) * io.newTxn.dataLow;
// Split both into bytes
oldBytes := SplitWord(oldHalf);
newBytes := SplitWord(newHalf);
// Make sure the non-specified bytes matchs
if (low0) {
oldBytes.byte0 = newBytes.byte0;
} else {
oldBytes.byte1 = newBytes.byte1;
};
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * uneven * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - uneven) * StateHostReadWords();
ECallOutput(nextCycle, nextPtrWord, nextPtrLow2, len - 1)
}

component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Val) {
input.state = StateHostReadWords();
lenDecomp := DecomposeLow2(len);
wordsDecomp := DecomposeLow2(lenDecomp.high);
doWord := [
wordsDecomp.low2Hot[1] * wordsDecomp.highZero,
wordsDecomp.low2Hot[2] * wordsDecomp.highZero,
wordsDecomp.low2Hot[3] * wordsDecomp.highZero,
(wordsDecomp.low2Hot[1] + wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3])* wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(1 - wordsDecomp.highZero)
];
count := reduce doWord init 0 with Add;
for i : 0..4 {
addr := Reg(doWord[i] * (ptrWord + i) + (1 - doWord[i]) * SafeWriteWord());
MemoryWriteUnconstrained(cycle, addr);
};
lenZero := IsZero(len - 4 * count);
newLenHighZero := IsZero(lenDecomp.high - count);
lenZero := Reg(newLenHighZero * (1 - lenDecomp.low2Nonzero));
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * (lenDecomp.low2Nonzero) * StateHostReadBytes() +
// If lengtj != 0 and even, more words
(1 - lenZero) * (1 - lenDecomp.low2Nonzero) * StateHostReadWords();
(1 - lenZero) * newLenHighZero * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - newLenHighZero) * StateHostReadWords();
ECallOutput(nextCycle, ptrWord + count, 0, len - count * 4)
}

Expand All @@ -163,7 +206,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
ECallTerminate(cycle, inst_input),
ECallHostReadSetup(cycle, inst_input),
ECallHostWrite(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input, s0@1, s1@1, s2@1),
ECallHostReadWords(cycle, inst_input, s0@1, s2@1),
IllegalECall(),
IllegalECall()
Expand Down
2 changes: 1 addition & 1 deletion zirgen/circuit/rv32im/v2/dsl/mem.zir
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) {

// Let the host write anythings (used in host read words)
component MemoryWriteUnconstrained(cycle: Reg, addr: Val) {
io := MemoryIO(2*cycle + 1, addr);
public io := MemoryIO(2*cycle + 1, addr);
IsForward(io);
}

Expand Down
4 changes: 0 additions & 4 deletions zirgen/circuit/rv32im/v2/dsl/top.zir
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
// RUN: true

// TODO: Now that the v2 circuit uses an extern to compute major/minor it no
// longer makes sense to do rv32im conformance testing here. Make sure
// integration tests are covering this.

import inst_div;
import inst_misc;
import inst_mul;
Expand Down
1 change: 1 addition & 0 deletions zirgen/circuit/rv32im/v2/emu/preflight.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ struct PreflightContext {
}
size_t rlen = segment.readRecord[curRead].size();
memcpy(data, segment.readRecord[curRead].data(), rlen);
curRead++;
return rlen;
}

Expand Down
22 changes: 15 additions & 7 deletions zirgen/circuit/rv32im/v2/emu/r0vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,22 @@ template <typename Context> struct R0Context {
std::vector<uint8_t> bytes(len);
rlen = context.read(fd, bytes.data(), len);
storeReg(REG_A0, rlen);
uint32_t i = 0;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
uint32_t i = 0;
while (rlen > 0 && ptr % 4 != 0) {
writeByte(ptr, bytes[i]);
// context.hostReadBytes(ptr);
ptr++;
i++;
rlen--;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
while (rlen >= 4) {
uint32_t words = std::min(rlen / 4, uint32_t(4));
Expand All @@ -195,25 +199,29 @@ template <typename Context> struct R0Context {
word |= bytes[i + k] << (8 * k);
}
storeMem(ptr / 4, word);
ptr += 4;
i += 4;
rlen -= 4;
} else {
storeMem(SAFE_WRITE_WORD, 0);
}
ptr += words;
i += words;
rlen -= words;
}
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
while (rlen > 0 && ptr % 4 != 0) {
while (rlen > 0) {
writeByte(ptr, bytes[i]);
// context.hostReadBytes(ptr);
ptr++;
i++;
rlen--;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
return false;
}
Expand Down
23 changes: 23 additions & 0 deletions zirgen/circuit/rv32im/v2/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,27 @@ cc_binary(
],
)

risc0_cc_kernel_binary(
name = "test_io_kernel",
srcs = [
"entry.s",
"test_io_kernel.cpp",
],
deps = ["//zirgen/circuit/rv32im/v2/platform:core"],
)

cc_test(
name = "test_io",
srcs = [
"test_io.cpp",
],
data = [
":test_io_kernel",
],
deps = [
"//risc0/core",
"//zirgen/circuit/rv32im/v2/run",
]
)

riscv_test_suite()
51 changes: 51 additions & 0 deletions zirgen/circuit/rv32im/v2/test/test_io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>

#include "zirgen/circuit/rv32im/v2/platform/constants.h"
#include "zirgen/circuit/rv32im/v2/run/run.h"

using namespace zirgen::rv32im_v2;

const std::string kernelName = "zirgen/circuit/rv32im/v2/test/test_io_kernel";

// Allows reads of any size, fill with a pattern to check in kernel
struct RandomReadSizeHandler : public HostIoHandler {
uint32_t write(uint32_t fd, const uint8_t* data, uint32_t len) override { return len; }
uint32_t read(uint32_t fd, uint8_t* data, uint32_t len) override {
std::cout << "DOING READ OF SIZE " << len << "\n";
for(size_t i = 0; i < len; i++) {
data[i] = i;
}
return len;
}
};


int main() {
size_t cycles = 100000;
RandomReadSizeHandler io;

// Load image
auto image = MemoryImage::fromRawElf(kernelName);
// Do executions
auto segments = execute(image, io, cycles, cycles);
// Do 'run' (preflight + expansion)
for (const auto& segment : segments) {
std::cout << "HEY, doing a segment!\n";
runSegment(segment, cycles + 1000);
}
std::cout << "What a fine day\n";
}
83 changes: 83 additions & 0 deletions zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdint.h>
#include <sys/errno.h>

#include "zirgen/circuit/rv32im/v2/platform/constants.h"

using namespace zirgen::rv32im_v2;

inline void die() {
asm("fence\n");
}

// Implement machine mode ECALLS

inline void terminate(uint32_t val) {
register uintptr_t a0 asm("a0") = val;
register uintptr_t a7 asm("a7") = 0;
asm volatile("ecall\n"
: // no outputs
: "r"(a0), "r"(a7) // inputs
: // no clobbers
);
}

inline uint32_t host_read(uint32_t fd, uint32_t buf, uint32_t len) {
register uintptr_t a0 asm("a0") = fd;
register uintptr_t a1 asm("a1") = buf;
register uintptr_t a2 asm("a2") = len;
register uintptr_t a7 asm("a7") = 1;
asm volatile("ecall\n"
: "+r"(a0) // outputs
: "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs
: // no clobbers
);
return a0;
}

inline uint32_t host_write(uint32_t fd, uint32_t buf, uint32_t len) {
register uintptr_t a0 asm("a0") = fd;
register uintptr_t a1 asm("a1") = buf;
register uintptr_t a2 asm("a2") = len;
register uintptr_t a7 asm("a7") = 2;
asm volatile("ecall\n"
: "+r"(a0) // outputs
: "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs
: // no clobbers
);
return a0;
}

constexpr uint32_t sizes[11] = { 0, 1, 2, 3, 4, 5, 7, 13, 19, 40, 101 };

void test_multi_read() {
uint8_t buf[200];
// Try all 4 alignments
for (size_t i = 0; i < 4; i++) {
// Try a variety of size
for (size_t j = 0; j < 11; j++) {
host_read(0, (uint32_t) (buf + i), sizes[j]);
for (size_t k = 0; k < sizes[j]; k++) {
if (buf[i + k] != k) { die(); }
}
}
}
}

extern "C" void start() {
test_multi_read();
terminate(0);
}

0 comments on commit f3975db

Please sign in to comment.