diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index aace4d19..6ba0fd47 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -39,10 +39,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b6_nightly_20240923.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b7_nightly_20240930.tar.gz", ], - strip_prefix = "yacl-0.4.5b6_nightly_20240923", - sha256 = "14eaaf7ad4aead7f2244e56453fead4a47973a020e23739ca0fe93873866bb5f", + strip_prefix = "yacl-0.4.5b7_nightly_20240930", + sha256 = "cf8dc7cceb9c5d05df00f1c086feec99d554db3e3cbe101253cf2a5a1adb9072", ) def _libpsi(): diff --git a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir index 423528ac..342b7d3e 100644 --- a/libspu/compiler/tests/interpret/dynamic_update_slice.mlir +++ b/libspu/compiler/tests/interpret/dynamic_update_slice.mlir @@ -19,3 +19,22 @@ func.func @dynamic_update_slice() { pphlo.custom_call @expect_eq (%result, %expected) : (tensor<4x4xi64>,tensor<4x4xi64>)->() func.return } + +// ----- + +func.func @dynamic_update_slice() { + %operand = pphlo.constant dense<[[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 2, 2, 2], + [1, 2, 2, 2]]> : tensor<4x4xi64> + %update = pphlo.constant dense<[[1, 1, 1], + [1, 1, 1]]> : tensor<2x3xi64> + %i0 = pphlo.constant dense<4> : tensor + %start_indices0 = pphlo.convert %i0 : (tensor) -> tensor> + %start_indices1 = pphlo.constant dense<4> : tensor + %result = pphlo.dynamic_update_slice %operand, %update, %start_indices0, %start_indices1 : + (tensor<4x4xi64>, tensor<2x3xi64>, tensor>, tensor) -> tensor<4x4x!pphlo.secret> + %expected = pphlo.constant dense<[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]> : tensor<4x4xi64> + pphlo.custom_call @expect_eq (%result, %expected) : (tensor<4x4x!pphlo.secret>, tensor<4x4xi64>)->() + func.return +} diff --git a/libspu/core/xt_helper.h b/libspu/core/xt_helper.h index f76e7d92..44921507 100644 --- a/libspu/core/xt_helper.h +++ b/libspu/core/xt_helper.h @@ -66,4 +66,3 @@ NdArrayRef xt_to_ndarray(const xt::xexpression& e) { template struct fmt::is_range, char> : std::false_type {}; - diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 6cfe3039..1b22d8a7 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -421,13 +421,16 @@ LogicalResult inferDynamicUpdateSliceOp( // dynamic_update_slice_c1 TypeTools tools(operand.getContext()); - auto common_vis = - tools.computeCommonVisibility({tools.getTypeVisibility(operandType), - tools.getTypeVisibility(updateType)}); + auto vis = llvm::map_to_vector(startIndices, [&](mlir::Value v) { + return tools.getTypeVisibility(v.getType()); + }); + vis.emplace_back(tools.getTypeVisibility(operand.getType())); + vis.emplace_back(tools.getTypeVisibility(update.getType())); - inferredReturnTypes.emplace_back(RankedTensorType::get( - operandType.getShape(), - tools.getType(operandType.getElementType(), common_vis))); + inferredReturnTypes.emplace_back( + RankedTensorType::get(operandType.getShape(), + tools.getType(operandType.getElementType(), + tools.computeCommonVisibility(vis)))); return success(); } diff --git a/libspu/mpc/aby3/permute.cc b/libspu/mpc/aby3/permute.cc index fb80bfd0..bd5f95c1 100644 --- a/libspu/mpc/aby3/permute.cc +++ b/libspu/mpc/aby3/permute.cc @@ -23,33 +23,13 @@ namespace spu::mpc::aby3 { -namespace { - -PermVector ring2pv(const NdArrayRef& x) { - SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", - x.eltype()); - const auto field = x.eltype().as()->field(); - PermVector pv(x.numel()); - DISPATCH_ALL_FIELDS(field, [&]() { - NdArrayView _x(x); - pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); - }); - return pv; -} - -} // namespace - NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(), shape); - // generate a RandU64 pair as permutation seeds auto* prg_state = ctx->getState(); - const auto [seed_self, seed_next] = - prg_state->genPrssPair(FieldType::FM64, {1}, PrgState::GenPrssCtrl::Both); - NdArrayView _seed_self(seed_self); - NdArrayView _seed_next(seed_next); - const auto pv_self = genRandomPerm(out.numel(), _seed_self[0]); - const auto pv_next = genRandomPerm(out.numel(), _seed_next[0]); + const auto& pvs = prg_state->genPrssPermPair(out.numel()); + const auto& pv_self = pvs.first; + const auto& pv_next = pvs.second; const auto field = out.eltype().as()->field(); auto out1 = getFirstShare(out); @@ -74,8 +54,8 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in.eltype().as()->field(); auto* prg_state = ctx->getState(); - PermVector pv_self = ring2pv(getFirstShare(perm)); - PermVector pv_next = ring2pv(getSecondShare(perm)); + auto pv_self = getFirstShare(perm); + auto pv_next = getSecondShare(perm); NdArrayRef out(in.eltype(), in.shape()); DISPATCH_ALL_FIELDS(field, [&]() { @@ -90,77 +70,87 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, prg_state->fillPrssPair(a0.data(), a1.data(), a0.size(), PrgState::GenPrssCtrl::Both); - if (comm->getRank() == 0) { - std::vector tmp(numel); - std::vector delta(numel); - pforeach(0, numel, [&](int64_t idx) { - tmp[idx] = _in[pv_self[idx]][0] + _in[pv_self[idx]][1] - a0[idx]; - }); - pforeach(0, numel, - [&](int64_t idx) { delta[idx] = tmp[pv_next[idx]] - a1[idx]; }); - comm->sendAsync(2, delta, "delta"); - - // 2to3 re-share - std::vector r0(numel); - std::vector r1(numel); - prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), - PrgState::GenPrssCtrl::Both); - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = r1[idx]; - }); - - } else if (comm->getRank() == 1) { - auto gama = comm->recv(2, "gama"); - std::vector tmp(numel); - std::vector beta(numel); - - pforeach(0, numel, - [&](int64_t idx) { tmp[idx] = gama[pv_self[idx]] + a0[idx]; }); - pforeach(0, numel, [&](int64_t idx) { beta[idx] = tmp[pv_next[idx]]; }); - - // 2to3 re-share - std::vector r0(numel); - prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), - PrgState::GenPrssCtrl::First); - pforeach(0, numel, [&](int64_t idx) { beta[idx] -= r0[idx]; }); - - comm->sendAsync(2, beta, "2to3"); - tmp = comm->recv(2, "2to3"); - - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = beta[idx] + tmp[idx]; - }); - - } else if (comm->getRank() == 2) { - std::vector gama(numel); - std::vector beta(numel); - pforeach(0, numel, [&](int64_t idx) { - gama[idx] = _in[pv_next[idx]][0] + a1[idx]; - }); - comm->sendAsync(1, gama, "gama"); - auto delta = comm->recv(0, "delta"); - pforeach(0, numel, [&](int64_t idx) { beta[idx] = delta[pv_self[idx]]; }); - - // 2to3 re-share - std::vector r1(numel); - prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), - PrgState::GenPrssCtrl::Second); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r1[idx]; - }); - comm->sendAsync(1, beta, "2to3"); - auto tmp = comm->recv(1, "2to3"); - - // rebuild the final result. - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = beta[idx] + tmp[idx]; - _out[idx][1] = r1[idx]; - }); - } else { - SPU_THROW("Party number exceeds 3!"); - } + const auto pv_field = pv_self.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + using pv_t = ring2k_t; + NdArrayView _pv_self(pv_self); + NdArrayView _pv_next(pv_next); + if (comm->getRank() == 0) { + std::vector tmp(numel); + std::vector delta(numel); + pforeach(0, numel, [&](int64_t idx) { + tmp[idx] = _in[_pv_self[idx]][0] + _in[_pv_self[idx]][1] - a0[idx]; + }); + pforeach(0, numel, [&](int64_t idx) { + delta[idx] = tmp[_pv_next[idx]] - a1[idx]; + }); + comm->sendAsync(2, delta, "delta"); + + // 2to3 re-share + std::vector r0(numel); + std::vector r1(numel); + prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), + PrgState::GenPrssCtrl::Both); + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + } else if (comm->getRank() == 1) { + auto gama = comm->recv(2, "gama"); + std::vector tmp(numel); + std::vector beta(numel); + + pforeach(0, numel, [&](int64_t idx) { + tmp[idx] = gama[_pv_self[idx]] + a0[idx]; + }); + pforeach(0, numel, + [&](int64_t idx) { beta[idx] = tmp[_pv_next[idx]]; }); + + // 2to3 re-share + std::vector r0(numel); + prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), + PrgState::GenPrssCtrl::First); + pforeach(0, numel, [&](int64_t idx) { beta[idx] -= r0[idx]; }); + + comm->sendAsync(2, beta, "2to3"); + tmp = comm->recv(2, "2to3"); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = beta[idx] + tmp[idx]; + }); + + } else if (comm->getRank() == 2) { + std::vector gama(numel); + std::vector beta(numel); + pforeach(0, numel, [&](int64_t idx) { + gama[idx] = _in[_pv_next[idx]][0] + a1[idx]; + }); + comm->sendAsync(1, gama, "gama"); + auto delta = comm->recv(0, "delta"); + pforeach(0, numel, + [&](int64_t idx) { beta[idx] = delta[_pv_self[idx]]; }); + + // 2to3 re-share + std::vector r1(numel); + prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), + PrgState::GenPrssCtrl::Second); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r1[idx]; + }); + comm->sendAsync(1, beta, "2to3"); + auto tmp = comm->recv(1, "2to3"); + + // rebuild the final result. + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = beta[idx] + tmp[idx]; + _out[idx][1] = r1[idx]; + }); + } else { + SPU_THROW("Party number exceeds 3!"); + } + }); }); return out; } @@ -170,11 +160,10 @@ NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); if (out.numel() != 0) { - PermVector pv = ring2pv(perm); - const auto& in1 = getFirstShare(in); - const auto& in2 = getSecondShare(in); - auto perm1 = applyPerm(in1, pv); - auto perm2 = applyPerm(in2, pv); + const auto in1 = getFirstShare(in); + const auto in2 = getSecondShare(in); + auto perm1 = applyPerm(in1, perm); + auto perm2 = applyPerm(in2, perm); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); @@ -194,8 +183,8 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in.eltype().as()->field(); auto* prg_state = ctx->getState(); - PermVector pv_self = ring2pv(getFirstShare(perm)); - PermVector pv_next = ring2pv(getSecondShare(perm)); + auto pv_self = getFirstShare(perm); + auto pv_next = getSecondShare(perm); NdArrayRef out(in.eltype(), in.shape()); DISPATCH_ALL_FIELDS(field, [&]() { @@ -210,82 +199,91 @@ NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, prg_state->fillPrssPair(a0.data(), a1.data(), a0.size(), PrgState::GenPrssCtrl::Both); - if (comm->getRank() == 0) { - std::vector beta(numel); - std::vector tmp(numel); - auto gama = comm->recv(2, "gama"); - - pforeach(0, numel, [&](int64_t idx) { - tmp[pv_next[idx]] = gama[idx] + a1[pv_next[idx]]; - }); - pforeach(0, numel, [&](int64_t idx) { beta[pv_self[idx]] = tmp[idx]; }); - - // 2to3 re-share - std::vector r1(numel); - prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), - PrgState::GenPrssCtrl::Second); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r1[idx]; - }); - comm->sendAsync(2, beta, "2to3"); - tmp = comm->recv(2, "2to3"); - - // rebuild the final result. - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = beta[idx] + tmp[idx]; - _out[idx][1] = r1[idx]; - }); - } else if (comm->getRank() == 1) { - std::vector tmp(numel); - std::vector delta(numel); - - pforeach(0, numel, [&](int64_t idx) { - tmp[pv_next[idx]] = _in[idx][0] + _in[idx][1] - a1[pv_next[idx]]; - }); - pforeach(0, numel, [&](int64_t idx) { - delta[pv_self[idx]] = tmp[idx] - a0[pv_self[idx]]; - }); - comm->sendAsync(2, delta, "delta"); - - // 2to3 re-share - std::vector r0(numel); - std::vector r1(numel); - prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), - PrgState::GenPrssCtrl::Both); - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = r1[idx]; - }); - - } else if (comm->getRank() == 2) { - std::vector gama(numel); - std::vector beta(numel); - pforeach(0, numel, [&](int64_t idx) { - gama[pv_self[idx]] = _in[idx][1] + a0[pv_self[idx]]; - }); - comm->sendAsync(0, gama, "gama"); - auto delta = comm->recv(1, "delta"); - pforeach(0, numel, [&](int64_t idx) { beta[pv_next[idx]] = delta[idx]; }); - - // 2to3 re-share - std::vector r0(numel); - prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), - PrgState::GenPrssCtrl::First); - pforeach(0, numel, [&](int64_t idx) { // - beta[idx] -= r0[idx]; - }); - - comm->sendAsync(0, beta, "2to3"); - auto tmp = comm->recv(0, "2to3"); - - pforeach(0, numel, [&](int64_t idx) { - _out[idx][0] = r0[idx]; - _out[idx][1] = beta[idx] + tmp[idx]; - }); - - } else { - SPU_THROW("Party number exceeds 3!"); - } + const auto pv_field = pv_self.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + using pv_t = ring2k_t; + NdArrayView _pv_self(pv_self); + NdArrayView _pv_next(pv_next); + + if (comm->getRank() == 0) { + std::vector beta(numel); + std::vector tmp(numel); + auto gama = comm->recv(2, "gama"); + + pforeach(0, numel, [&](int64_t idx) { + tmp[_pv_next[idx]] = gama[idx] + a1[_pv_next[idx]]; + }); + pforeach(0, numel, + [&](int64_t idx) { beta[_pv_self[idx]] = tmp[idx]; }); + + // 2to3 re-share + std::vector r1(numel); + prg_state->fillPrssPair(nullptr, r1.data(), r1.size(), + PrgState::GenPrssCtrl::Second); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r1[idx]; + }); + comm->sendAsync(2, beta, "2to3"); + tmp = comm->recv(2, "2to3"); + + // rebuild the final result. + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = beta[idx] + tmp[idx]; + _out[idx][1] = r1[idx]; + }); + } else if (comm->getRank() == 1) { + std::vector tmp(numel); + std::vector delta(numel); + + pforeach(0, numel, [&](int64_t idx) { + tmp[_pv_next[idx]] = _in[idx][0] + _in[idx][1] - a1[_pv_next[idx]]; + }); + pforeach(0, numel, [&](int64_t idx) { + delta[_pv_self[idx]] = tmp[idx] - a0[_pv_self[idx]]; + }); + comm->sendAsync(2, delta, "delta"); + + // 2to3 re-share + std::vector r0(numel); + std::vector r1(numel); + prg_state->fillPrssPair(r0.data(), r1.data(), r1.size(), + PrgState::GenPrssCtrl::Both); + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + } else if (comm->getRank() == 2) { + std::vector gama(numel); + std::vector beta(numel); + pforeach(0, numel, [&](int64_t idx) { + gama[_pv_self[idx]] = _in[idx][1] + a0[_pv_self[idx]]; + }); + comm->sendAsync(0, gama, "gama"); + auto delta = comm->recv(1, "delta"); + pforeach(0, numel, + [&](int64_t idx) { beta[_pv_next[idx]] = delta[idx]; }); + + // 2to3 re-share + std::vector r0(numel); + prg_state->fillPrssPair(r0.data(), nullptr, r0.size(), + PrgState::GenPrssCtrl::First); + pforeach(0, numel, [&](int64_t idx) { // + beta[idx] -= r0[idx]; + }); + + comm->sendAsync(0, beta, "2to3"); + auto tmp = comm->recv(0, "2to3"); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = beta[idx] + tmp[idx]; + }); + + } else { + SPU_THROW("Party number exceeds 3!"); + } + }); }); return out; } @@ -295,12 +293,11 @@ NdArrayRef InvPermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); if (out.numel() != 0) { - PermVector pv = ring2pv(perm); - const auto& in1 = getFirstShare(in); - const auto& in2 = getSecondShare(in); + const auto in1 = getFirstShare(in); + const auto in2 = getSecondShare(in); - auto perm1 = applyInvPerm(in1, pv); - auto perm2 = applyInvPerm(in2, pv); + auto perm1 = applyInvPerm(in1, perm); + auto perm2 = applyInvPerm(in2, perm); auto out1 = getFirstShare(out); auto out2 = getSecondShare(out); diff --git a/libspu/mpc/common/BUILD.bazel b/libspu/mpc/common/BUILD.bazel index 017c316a..375db84d 100644 --- a/libspu/mpc/common/BUILD.bazel +++ b/libspu/mpc/common/BUILD.bazel @@ -65,6 +65,7 @@ spu_cc_library( hdrs = ["prg_state.h"], deps = [ "//libspu/core:object", + "//libspu/mpc/utils:permute", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/link:context", diff --git a/libspu/mpc/common/prg_state.cc b/libspu/mpc/common/prg_state.cc index 0f68fcff..374fba19 100644 --- a/libspu/mpc/common/prg_state.cc +++ b/libspu/mpc/common/prg_state.cc @@ -19,6 +19,8 @@ #include "yacl/link/algorithm/allgather.h" #include "yacl/utils/serialize.h" +#include "libspu/mpc/utils/permute.h" + namespace spu::mpc { PrgState::PrgState() { @@ -108,4 +110,15 @@ NdArrayRef PrgState::genPubl(FieldType field, const Shape& shape) { return res; } +Index PrgState::genPrivPerm(size_t numel) { + return genRandomPerm(numel, priv_seed_, &priv_counter_); +} + +std::pair PrgState::genPrssPermPair(size_t numel) { + std::pair res; + res.first = genRandomPerm(numel, self_seed_, &r0_counter_); + res.second = genRandomPerm(numel, next_seed_, &r1_counter_); + return res; +} + } // namespace spu::mpc diff --git a/libspu/mpc/common/prg_state.h b/libspu/mpc/common/prg_state.h index 76ab79b6..93c0461b 100644 --- a/libspu/mpc/common/prg_state.h +++ b/libspu/mpc/common/prg_state.h @@ -15,6 +15,7 @@ #pragma once #include "absl/types/span.h" +#include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" #include "yacl/link/context.h" @@ -60,6 +61,11 @@ class PrgState : public State { NdArrayRef genPubl(FieldType field, const Shape& shape); + Index genPrivPerm(size_t numel); + + // Generate a random permutation pair (p0, p1). + std::pair genPrssPermPair(size_t numel); + // Generate a random pair (r0, r1), where // r1 = next_party.r0 // diff --git a/libspu/mpc/securenn/arithmetic.cc b/libspu/mpc/securenn/arithmetic.cc index fa44545b..2bb7bb6b 100644 --- a/libspu/mpc/securenn/arithmetic.cc +++ b/libspu/mpc/securenn/arithmetic.cc @@ -18,6 +18,8 @@ #include #include +#include "yacl/crypto/rand/rand.h" + #include "libspu/core/type_util.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" @@ -758,10 +760,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // P0 and P1 end execute if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - auto a_0 = comm->recv(0, ty, "a_"); auto a_1 = comm->recv(1, ty, "a_"); a_0 = a_0.reshape(a.shape()); @@ -784,7 +782,8 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, NdArrayView _dp_x_p0(dp_x_p0); NdArrayView _dp_x_p1(dp_x_p1); - NdArrayRef delta_p0(ty, a.shape()); + NdArrayRef delta_p0 = + ring_rand_range(field, a.shape(), 0, L_1 - 1); // (ty, a.shape()); NdArrayRef delta_p1(ty, a.shape()); NdArrayView _delta_p0(delta_p0); NdArrayView _delta_p1(delta_p1); @@ -803,7 +802,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // split delta in Z_(L_1) - _delta_p0[idx] = dis(gen); _delta_p1[idx] = _delta[idx] - _delta_p0[idx]; if (_delta[idx] < _delta_p0[idx]) _delta_p1[idx] -= (U)1; // when overflow @@ -816,7 +814,7 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, comm->sendAsync(1, delta_p1, "delta"); // split eta_ in Z_(L_1) - NdArrayRef eta_p0(ty, a.shape()); + NdArrayRef eta_p0 = ring_rand_range(field, a.shape(), 0, L_1 - 1); NdArrayRef eta_p1(ty, a.shape()); NdArrayView _eta_p0(eta_p0); NdArrayView _eta_p1(eta_p1); @@ -843,7 +841,6 @@ NdArrayRef ShareConvert::proc(KernelEvalContext* ctx, } // split eta_ in Z_(L_1) - _eta_p0[idx] = dis(gen); _eta_p1[idx] = _eta_[idx] - _eta_p0[idx]; if (_eta_[idx] < _eta_p0[idx]) _eta_p1[idx] -= (U)1; // when overflow }); // end pforeach @@ -889,10 +886,6 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto [u_r0, u_r1] = prg_state->genPrssPair(field, {size * k}, PrgState::GenPrssCtrl::Both); if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - // random for beaver // P2 generate a0, a1, b0, b1, c0 by PRF // and calculate c1 @@ -908,12 +901,12 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto c1 = ring_sub(ring_mul(ring_add(a0, a1), ring_add(b0, b1)), c0); // end beaver (c1 will be sent with x to reduce one round latency) - NdArrayRef x(ty, in.shape()); + NdArrayRef x = ring_rand_range(field, in.shape(), 0, L_1 - 1); NdArrayView _x(x); // split x into x_p0 and x_p1 in Z_(L-1), (L=2^k) - NdArrayRef x_p0(ty, in.shape()); + NdArrayRef x_p0 = ring_rand_range(field, in.shape(), 0, L_1 - 1); NdArrayRef x_p1(ty, in.shape()); NdArrayView _x_p0(x_p0); NdArrayView _x_p1(x_p1); @@ -932,11 +925,9 @@ NdArrayRef Msb::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef lsb_x(ty, in.shape()); NdArrayView _lsb_x(lsb_x); pforeach(0, size, [&](int64_t idx) { - _x[idx] = dis(gen); auto dp_x = bitDecompose(_x[idx], k); // vector // split x - _x_p0[idx] = dis(gen); _x_p1[idx] = _x[idx] - _x_p0[idx]; if (_x[idx] < _x_p0[idx]) _x_p1[idx] -= (U)1; // when overflow @@ -1237,10 +1228,6 @@ NdArrayRef Msb_opt::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto [beta_0, beta_1] = prg_state->genPrssPair(field, in.shape(), PrgState::GenPrssCtrl::Both); if (rank == 2) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(0, L_1 - 1); - // random for beaver // P2 generate a0, a1, b0, b1, c0 by PRF // and calculate c1 diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 58dcdc80..78cc71d2 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -870,7 +870,9 @@ TEST_P(BeaverTest, PermPair) { const size_t adjust_rank = std::get<4>(GetParam()); const int64_t kNumel = 10; std::random_device rd; - const auto r_perm = genRandomPerm(kNumel, rd()); + uint128_t seed = rd(); + uint64_t ctr = rd(); + const auto r_perm = genRandomPerm(kNumel, seed, &ctr); for (size_t r = 0; r < kWorldSize; ++r) { std::vector pairs(kWorldSize); diff --git a/libspu/mpc/semi2k/permute.cc b/libspu/mpc/semi2k/permute.cc index 21787a42..71f68ef4 100644 --- a/libspu/mpc/semi2k/permute.cc +++ b/libspu/mpc/semi2k/permute.cc @@ -40,18 +40,35 @@ inline int64_t getOwner(const NdArrayRef& x) { return x.eltype().as()->owner(); } +Index ring2pv(const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", + x.eltype()); + const auto field = x.eltype().as()->field(); + Index pv(x.numel()); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); + }); + return pv; +} + // Secure inverse permutation of x by perm_rank's permutation pv // The idea here is: // Input permutation pv, beaver generates perm pair {, } that // InversePermute(A, pv) = B. So we can get = InversePermute(open( - // ), pv) + that y = InversePermute(x, pv). NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x, - size_t perm_rank, absl::Span pv) { + const NdArrayRef& perm, size_t perm_rank) { const auto lctx = ctx->lctx(); const auto field = x.eltype().as()->field(); auto* beaver = ctx->getState()->beaver(); auto numel = x.numel(); + Index pv; + if (perm.eltype().isa() || + (perm.eltype().isa() && isOwner(ctx, perm.eltype()))) { + pv = ring2pv(perm); + } auto [a_buf, b_buf] = beaver->PermPair(field, numel, perm_rank, pv); NdArrayRef a(std::make_shared(std::move(a_buf)), x.eltype(), @@ -75,11 +92,8 @@ NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef RandPermM::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(), shape); - // generate a RandU64 as permutation seed auto* prg_state = ctx->getState(); - const auto seed = prg_state->genPriv(FieldType::FM64, {1}); - NdArrayView _seed(seed); - const auto perm_vector = genRandomPerm(out.numel(), _seed[0]); + const auto perm_vector = prg_state->genPrivPerm(out.numel()); const auto field = out.eltype().as()->field(); DISPATCH_ALL_FIELDS(field, [&]() { @@ -95,51 +109,37 @@ NdArrayRef PermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); - PermVector pv = ring2pv(perm); NdArrayRef out(in); for (size_t i = 0; i < comm->getWorldSize(); ++i) { - out = SecureInvPerm(ctx, out, i, pv); + out = SecureInvPerm(ctx, out, perm, i); } - return out; } NdArrayRef PermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv = ring2pv(perm); - auto out = applyPerm(in, pv); - return out; + return applyPerm(in, perm); } NdArrayRef InvPermAM::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { auto* comm = ctx->getState(); - PermVector pv = ring2pv(perm); NdArrayRef out(in); - auto inv_pv = genInversePerm(pv); + auto inv_perm = genInversePerm(perm); for (int i = comm->getWorldSize() - 1; i >= 0; --i) { - out = SecureInvPerm(ctx, out, i, inv_pv); + out = SecureInvPerm(ctx, out, inv_perm, i); } - return out; } NdArrayRef InvPermAP::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv = ring2pv(perm); - auto out = applyInvPerm(in, pv); - return out; + return applyInvPerm(in, perm); } NdArrayRef InvPermAV::proc(KernelEvalContext* ctx, const NdArrayRef& in, const NdArrayRef& perm) const { - PermVector pv; - const auto lctx = ctx->lctx(); - if (isOwner(ctx, perm.eltype())) { - pv = ring2pv(perm); - } - auto out = SecureInvPerm(ctx, in, getOwner(perm), pv); - return out; + return SecureInvPerm(ctx, in, perm, getOwner(perm)); } } // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index ea70dbb9..00287e26 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -50,6 +50,7 @@ spu_cc_library( hdrs = ["permute.h"], deps = [ "//libspu/core:ndarray_ref", + "@yacl//yacl/crypto/rand", ], ) diff --git a/libspu/mpc/utils/permute.cc b/libspu/mpc/utils/permute.cc index 9ef42a61..62a8e85a 100644 --- a/libspu/mpc/utils/permute.cc +++ b/libspu/mpc/utils/permute.cc @@ -17,16 +17,18 @@ #include #include +#include "yacl/crypto/rand/rand.h" + #include "libspu/core/ndarray_ref.h" #include "libspu/core/type_util.h" namespace spu::mpc { -PermVector ring2pv(const NdArrayRef& x) { +Index ring2pv(const NdArrayRef& x) { SPU_ENFORCE(x.eltype().isa(), "must be ring2k_type, got={}", x.eltype()); const auto field = x.eltype().as()->field(); - PermVector pv(x.numel()); + Index pv(x.numel()); DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { pv[idx] = int64_t(_x[idx]); }); @@ -49,6 +51,26 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv) { return y; } +NdArrayRef applyInvPerm(const NdArrayRef& x, const NdArrayRef& pv) { + SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); + SPU_ENFORCE_EQ(x.shape(), pv.shape(), "x and pv should have same shape"); + + NdArrayRef y(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + NdArrayView _y(y); + const auto pv_field = pv.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + NdArrayView _pv(pv); + for (int64_t i = 0; i < y.numel(); i++) { + _y[_pv[i]] = _x[i]; + } + }); + }); + return y; +} + NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); @@ -64,26 +86,42 @@ NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv) { return y; } -PermVector genRandomPerm(size_t size, uint64_t seed) { - PermVector perm(size); - std::iota(perm.begin(), perm.end(), 0); - // TODO: change PRNG to CSPRNG - std::mt19937 rng(seed); - std::shuffle(perm.begin(), perm.end(), rng); - return perm; +NdArrayRef applyPerm(const NdArrayRef& x, const NdArrayRef& pv) { + SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); + SPU_ENFORCE_EQ(x.shape(), pv.shape(), "x and pv should have same shape"); + + NdArrayRef y(x.eltype(), x.shape()); + const auto field = x.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _x(x); + NdArrayView _y(y); + const auto pv_field = pv.eltype().as()->field(); + DISPATCH_ALL_FIELDS(pv_field, [&]() { + NdArrayView _pv(pv); + for (int64_t i = 0; i < y.numel(); i++) { + _y[i] = _x[_pv[i]]; + } + }); + }); + return y; } -PermVector genInversePerm(absl::Span pv) { - PermVector ret(pv.size()); - for (size_t i = 0; i < pv.size(); ++i) { - ret[pv[i]] = i; - } +NdArrayRef genInversePerm(const NdArrayRef& perm) { + NdArrayRef ret(perm.eltype(), perm.shape()); + auto field = perm.eltype().as()->field(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _perm(perm); + for (int64_t i = 0; i < perm.numel(); ++i) { + _ret[_perm[i]] = ring2k_t(i); + } + }); return ret; } -PermVector genPermBySort(const NdArrayRef& x) { +Index genPermBySort(const NdArrayRef& x) { SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor"); - PermVector perm(x.shape()[0]); + Index perm(x.shape()[0]); std::iota(perm.begin(), perm.end(), 0); const auto field = x.eltype().as()->field(); DISPATCH_ALL_FIELDS(field, [&]() { @@ -96,4 +134,11 @@ PermVector genPermBySort(const NdArrayRef& x) { return perm; } +Index genRandomPerm(size_t numel, uint128_t seed, uint64_t* ctr) { + Index perm(numel); + std::iota(perm.begin(), perm.end(), 0); + yacl::crypto::ReplayShuffle(perm.begin(), perm.end(), seed, ctr); + return perm; +} + } // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/permute.h b/libspu/mpc/utils/permute.h index 9034d79b..5c4eb1d2 100644 --- a/libspu/mpc/utils/permute.h +++ b/libspu/mpc/utils/permute.h @@ -20,24 +20,23 @@ namespace spu::mpc { constexpr char kPermModule[] = "Permute"; -using PermVector = std::vector; - -PermVector genRandomPerm(size_t size, uint64_t seed); - -PermVector genInversePerm(absl::Span pv); +NdArrayRef genInversePerm(const NdArrayRef& perm); // generate permutation vector that can make x ordered -PermVector genPermBySort(const NdArrayRef& x); +Index genPermBySort(const NdArrayRef& x); // reorder 1-d tensor element by applying inverse permutation. // ret = ApplyInvPerm(x, pv) -> ret[pv[i]] = x[i] NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span pv); +NdArrayRef applyInvPerm(const NdArrayRef& x, const NdArrayRef& pv); // reorder 1-d tensor element by applying permutation. // ret = ApplyPerm(x, pv) -> ret[i] = x[pv[i]] NdArrayRef applyPerm(const NdArrayRef& x, absl::Span pv); +NdArrayRef applyPerm(const NdArrayRef& x, const NdArrayRef& pv); // get a permutation vector from a ring -PermVector ring2pv(const NdArrayRef& x); +Index ring2pv(const NdArrayRef& x); +Index genRandomPerm(size_t numel, uint128_t seed, uint64_t* ctr); } // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 5d0eca0c..e0e67b2b 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -234,21 +234,27 @@ NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, return res; } -NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, - int32_t max) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution dis(min, max); +NdArrayRef ring_rand_range(FieldType field, const Shape& shape, uint128_t min, + uint128_t max) { + constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = + yacl::crypto::SymmetricCrypto::CryptoType::AES128_ECB; + constexpr uint64_t kAesInitialVector = 0U; + uint64_t cnt = 0; NdArrayRef x(makeType(field), shape); auto numel = x.numel(); DISPATCH_ALL_FIELDS(field, [&]() { + std::vector rand_range(numel); + yacl::crypto::FillPRandWithLtN( + kCryptoType, yacl::crypto::SecureRandSeed(), kAesInitialVector, cnt, + absl::MakeSpan(rand_range), static_cast(max - min + 1)); SPU_ENFORCE(sizeof(ring2k_t) >= sizeof(int32_t)); auto iter = x.begin(); for (auto idx = 0; idx < numel; ++idx, ++iter) { - iter.getScalarValue() = static_cast(dis(gen)); + iter.getScalarValue() = + rand_range[idx] + static_cast(min); } }); @@ -292,17 +298,15 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { } NdArrayRef ring_randbit(FieldType field, const Shape& shape) { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> distrib(0, RAND_MAX); - NdArrayRef ret(makeType(field), shape); auto numel = ret.numel(); + auto rand_bytes = yacl::crypto::RandBytes(numel, false); + return DISPATCH_ALL_FIELDS(field, [&]() { NdArrayView _ret(ret); for (auto idx = 0; idx < numel; ++idx) { - _ret[idx] = distrib(gen) & 0x1; + _ret[idx] = static_cast(rand_bytes[idx]) & 0x1; } return ret; }); diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index 8a960ebc..4fc04a28 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -48,8 +48,8 @@ void ring_print(const NdArrayRef& x, std::string_view name = "_"); NdArrayRef ring_rand(FieldType field, const Shape& shape); NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, uint64_t* prg_counter); -NdArrayRef ring_rand_range(FieldType field, const Shape& shape, int32_t min, - int32_t max); +NdArrayRef ring_rand_range(FieldType field, const Shape& shape, uint128_t min, + uint128_t max); NdArrayRef ring_zeros(FieldType field, const Shape& shape);