Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-306: Add SHA-2 accelerator, fix division overconstraint bug, add rv32im compliance tests to C++ #144

Merged
merged 6 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/arr.zir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// This file contains utilities that work with bits and twits.
// RUN: zirgen --test %s

// Vector / List functions

// Shifts + Rotates
component RotateLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { in[SIZE + i - n] }
}
}

component RotateRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { in[i + n - SIZE] }
}
}

component ShiftLeft<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i - n, SIZE)) { in[i - n] } else { 0 }
}
}

component ShiftRight<SIZE: Val>(in: Array<Val, SIZE>, n: Val) {
for i : 0..SIZE {
if (InRange(0, i + n, SIZE)) { in[i + n] } else { 0 }
}
}

component EqArr<SIZE: Val>(a: Array<Val, SIZE>, b: Array<Val, SIZE>) {
for i : 0..SIZE {
a[i] = b[i];
}
}

// Tests....

test ShiftAndRotate {
// TODO: Now that these support non-bit values, maybe make new tests
// Remember: array entry 0 is the low bit, so there seem backwards
EqArr<8>(ShiftRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [0, 0, 1, 1, 1, 0, 1, 0]);
EqArr<8>(RotateRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 1, 1]);
EqArr<8>(RotateLeft<8>([1, 1, 1, 0, 1, 0, 0, 1], 2), [0, 1, 1, 1, 1, 0, 1, 0]);
}

10 changes: 7 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/bits.zir
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ function AssertTwit(val: Val) {
val * (1 - val) * (2 - val) * (3 - val) = 0;
}

// Simple bit ops
component BitAnd(a: Val, b: Val) {
Reg(a * b)
a * b
}

component BitOr(a: Val, b: Val) {
Reg(1 - (1 - a) * (1 - b))
a + b - a * b
}

component BitXor(a: Val, b: Val) {
a + b - 2 * a * b
}

// Set a register nodeterministically, and then verify it is a twit
Expand Down Expand Up @@ -81,4 +86,3 @@ test TwitInRange{
test_fails TwitOutOfRange {
AssertTwit(4);
}

9 changes: 8 additions & 1 deletion zirgen/circuit/rv32im/v2/dsl/consts.zir
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ component StatePoseidonStoreState() { 23 }
component StatePoseidonExtRound() { 24 }
component StatePoseidonIntRounds() { 25 }

component StateDecode() { 32 }
component StateShaEcall() { 32 }
component StateShaLoadState() { 33 }
component StateShaLoadData() { 34 }
component StateShaMix() { 35 }
component StateShaStoreState() { 36 }

component StateDecode() { 40 }

component RegA0() { 10 }
component RegA1() { 11 }
component RegA2() { 12 }
component RegA3() { 13 }
component RegA4() { 14 }

component RegA7() { 17 }
23 changes: 20 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_div.zir
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,36 @@ component DoDiv(numer: ValU32, denom: ValU32, signed: Val, ones_comp: Val) {
settings := MultiplySettings(signed, signed, signed);
// Do the accumulate
mul := MultiplyAccumulate(quot, denom, rem, settings);
// Check the main result (numer = quot * denom + rem
// Check the main result (numer = quot * denom + rem)
AssertEqU32(mul.outLow, numer);
// The top bits should all be 0 or all be 1
topBitType := NondetBitReg(1 - Isz(mul.outHigh.low));
AssertEqU32(mul.outHigh, ValU32(0xffff * topBitType, 0xffff * topBitType));
// Check if denom is zero
isZero := IsZero(denom.low + denom.high);
// Get top bit of numerator
topNum := NondetBitReg((numer.high & 0x8000) / 0x8000);
// Verify we got it right
U16Reg((numer.high - 0x8000 * topNum) * 2);
numNeg := topNum * signed;
// Get the absolute value of the denominator
denomNeg := mul.bNeg;
denomAbs := NormalizeU32(DenormedValU32(
denomNeg * (0x10000 - denom.low) + (1 - denomNeg) * denom.low,
denomNeg * (0xffff - denom.high) + (1 - denomNeg) * denom.high
));
// Flip the sign of the remainder if the numerator is negative
remNormal := NormalizeU32(DenormedValU32(
numNeg * (0x10000 - rem.low) + (1 - numNeg) * rem.low,
numNeg * (0xffff - rem.high) + (1 - numNeg) * rem.high
));
// Decide if we need to swap order of
// If non-zero, make sure 0 <= rem < denom
if (isZero) {
AssertEqU32(rem, numer);
} else {
cmp := CmpLessThanUnsigned(rem, denom);
cmp.is_less_than = 1;
lt := CmpLessThanUnsigned(remNormal, denomAbs);
lt.is_less_than = 1;
};
DivideReturn(quot, rem)
}
Expand Down
8 changes: 5 additions & 3 deletions zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ component MachineECall(cycle: Reg, input: InstInput, pc_addr: Val) {
input.mode = 1;
dispatch_idx := MemoryRead(cycle, MachineRegBase() + RegA7());
dispatch_idx.high = 0;
dispatch := OneHot<4>(dispatch_idx.low);
dispatch := OneHot<5>(dispatch_idx.low);
state := dispatch -> (
StateTerminate(),
StateHostReadSetup(),
StateHostWrite(),
StatePoseidonEcall()
StatePoseidonEcall(),
StateShaEcall()
);
ECallOutput(state, 0, 0, 0)
}
Expand Down Expand Up @@ -172,6 +173,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
s2 := Reg(output.s2);
isDecode := IsZero(output.state - StateDecode());
isP2Entry := IsZero(output.state - StatePoseidonEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry) * 4, 0)));
isShaEcall := IsZero(output.state - StateShaEcall());
addPC := NormalizeU32(AddU32(inst_input.pc_u32, ValU32((isDecode + isP2Entry + isShaEcall) * 4, 0)));
InstOutput(addPC, output.state, 1)
}
2 changes: 2 additions & 0 deletions zirgen/circuit/rv32im/v2/dsl/inst_p2.zir
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ component PoseidonPaging(cycle: Reg, mode: Val, prev: PoseidonState) {

component Poseidon0(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonEntry() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonEntry(cycle, inst_input.pc_u32, inst_input.mode),
Expand All @@ -480,6 +481,7 @@ component Poseidon0(cycle:Reg, inst_input: InstInput) {

component Poseidon1(cycle:Reg, inst_input: InstInput) {
DoCycleTable(cycle);
inst_input.state = StatePoseidonExtRound() + inst_input.minor;
state : PoseidonState;
state := inst_input.minor_onehot -> (
PoseidonExtRound(state@1),
Expand Down
Loading
Loading