diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 1ceee975..aace4d19 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -39,10 +39,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b5_nightly_20240919.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b6_nightly_20240923.tar.gz", ], - strip_prefix = "yacl-0.4.5b5_nightly_20240919", - sha256 = "0ef295f6878dce6160fd44e6af59fa369099f736fa8d4a10f9685dda66aefa71", + strip_prefix = "yacl-0.4.5b6_nightly_20240923", + sha256 = "14eaaf7ad4aead7f2244e56453fead4a47973a020e23739ca0fe93873866bb5f", ) def _libpsi(): diff --git a/libspu/compiler/tools/spu-translate.cc b/libspu/compiler/tools/spu-translate.cc index acf2ed87..9375c809 100644 --- a/libspu/compiler/tools/spu-translate.cc +++ b/libspu/compiler/tools/spu-translate.cc @@ -72,7 +72,7 @@ void isEqual(const xt::xarray &lhs, const xt::xarray &rhs) { auto error = lhs - rhs; - for (auto v : error) { + for (T v : error) { if (v != 0) { llvm::report_fatal_error(fmt::format("Diff = {}", v).c_str()); } diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index b74a0427..a797d803 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -20,6 +20,7 @@ #include "absl/types/span.h" #include "fmt/ostream.h" +#include "fmt/ranges.h" #include "yacl/base/buffer.h" #include "libspu/core/bit_utils.h" diff --git a/libspu/core/trace.h b/libspu/core/trace.h index dbe4b460..582257d7 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -22,7 +22,6 @@ #include #include "absl/types/span.h" -#include "fmt/format.h" #include "fmt/ranges.h" #include "spdlog/spdlog.h" #include "yacl/link/context.h" diff --git a/libspu/core/xt_helper.h b/libspu/core/xt_helper.h index 3230eada..44921507 100644 --- a/libspu/core/xt_helper.h +++ b/libspu/core/xt_helper.h @@ -63,3 +63,6 @@ NdArrayRef xt_to_ndarray(const xt::xexpression& e) { } } // namespace spu + +template +struct fmt::is_range, char> : std::false_type {}; diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 6c95aac3..6cfe3039 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -420,8 +420,14 @@ LogicalResult inferDynamicUpdateSliceOp( } // dynamic_update_slice_c1 + TypeTools tools(operand.getContext()); + auto common_vis = + tools.computeCommonVisibility({tools.getTypeVisibility(operandType), + tools.getTypeVisibility(updateType)}); + inferredReturnTypes.emplace_back(RankedTensorType::get( - operandType.getShape(), operandType.getElementType())); + operandType.getShape(), + tools.getType(operandType.getElementType(), common_vis))); return success(); } diff --git a/libspu/mpc/common/prg_tensor.h b/libspu/mpc/common/prg_tensor.h index ced75981..54de9171 100644 --- a/libspu/mpc/common/prg_tensor.h +++ b/libspu/mpc/common/prg_tensor.h @@ -42,4 +42,8 @@ inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) { return ring_rand(desc.field, desc.shape, seed, &counter); } +inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) { + return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); +} + } // namespace spu::mpc diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index d3d09401..402d1c0d 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -107,7 +107,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; - auto adjust = TrustedParty::adjustMul(ops); + auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops)); ring_add_(c, adjust); } @@ -158,7 +158,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size, if (lctx_->Rank() == 0) { ops[1].seeds = seeds_; - auto adjust = TrustedParty::adjustSquare(ops); + auto adjust = TrustedParty::adjustSquare(absl::MakeSpan(ops)); ring_add_(b, adjust); } @@ -223,7 +223,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Dot(FieldType field, int64_t m, if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; - auto adjust = TrustedParty::adjustDot(ops); + auto adjust = TrustedParty::adjustDot(absl::MakeSpan(ops)); ring_add_(c, adjust); } @@ -250,7 +250,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::And(int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustAnd(ops); + auto adjust = TrustedParty::adjustAnd(absl::MakeSpan(ops)); ring_xor_(c, adjust); } @@ -276,7 +276,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Trunc(FieldType field, int64_t size, for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustTrunc(ops, bits); + auto adjust = TrustedParty::adjustTrunc(absl::MakeSpan(ops), bits); ring_add_(b, adjust); } @@ -300,7 +300,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::TruncPr(FieldType field, int64_t size, for (auto& op : ops) { op.seeds = seeds_; } - auto adjusts = TrustedParty::adjustTruncPr(ops, bits); + auto adjusts = TrustedParty::adjustTruncPr(absl::MakeSpan(ops), bits); ring_add_(rc, std::get<0>(adjusts)); ring_add_(rb, std::get<1>(adjusts)); } @@ -322,7 +322,7 @@ BeaverTfpUnsafe::Array BeaverTfpUnsafe::RandBit(FieldType field, int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustRandBit(ops); + auto adjust = TrustedParty::adjustRandBit(absl::MakeSpan(ops)); ring_add_(a, adjust); } @@ -348,10 +348,11 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::PermPair( auto pv_buf = lctx_->Recv(perm_rank, kTag); ring_add_(b, TrustedParty::adjustPerm( - ops, absl::MakeSpan(pv_buf.data(), - pv_buf.size() / sizeof(int64_t)))); + absl::MakeSpan(ops), + absl::MakeSpan(pv_buf.data(), + pv_buf.size() / sizeof(int64_t)))); } else { - ring_add_(b, TrustedParty::adjustPerm(ops, perm_vec)); + ring_add_(b, TrustedParty::adjustPerm(absl::MakeSpan(ops), perm_vec)); } } else if (perm_rank == lctx_->Rank()) { lctx_->SendAsync( @@ -380,7 +381,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Eqz(FieldType field, int64_t size) { for (auto& op : ops) { op.seeds = seeds_; } - auto adjust = TrustedParty::adjustEqz(ops); + auto adjust = TrustedParty::adjustEqz(absl::MakeSpan(ops)); ring_xor_(b, adjust); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index 5f6febb3..c29fa32b 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -98,7 +98,11 @@ class StreamReader : public brpc::StreamInputHandler { kStreamFailed, }; - StreamReader() { + StreamReader(int32_t num_buf, size_t buf_len) { + SPU_ENFORCE(num_buf > 0); + SPU_ENFORCE(buf_len > 0); + buf_vec_.resize(num_buf); + buf_len_ = buf_len; future_finished_ = promise_finished_.get_future(); future_closed_ = promise_closed_.get_future(); } @@ -106,45 +110,36 @@ class StreamReader : public brpc::StreamInputHandler { int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[], size_t size) override { SPDLOG_DEBUG("on_received_messages, stream id: {}", id); - if (status_ != Status::kNotFinished) { - SPDLOG_ERROR("unexpected messages received"); - return -1; - } - for (size_t i = 0; i < size; ++i) { + if (status_ != Status::kNotFinished) { + SPDLOG_ERROR("unexpected messages received"); + return -1; + } + SPDLOG_DEBUG("receive buf size: {}", messages[i]->size()); const auto& message = messages[i]; - if (!buf_lens_.has_value()) { - beaver::ttp_server::BeaverDownStreamMeta meta{}; - message->copy_to(&meta, sizeof(meta)); - message->pop_front(sizeof(meta)); - if (meta.err_code != 0) { - SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}", - meta.err_code, message->to_string()); - status_ = Status::kAbnormalFinished; - promise_finished_.set_value(status_); - return -2; - } - SPU_ENFORCE(meta.total_buf_num > 0); - buf_.emplace_back(); - buf_lens_.emplace(meta.total_buf_num); - size_t meta_bytes = meta.total_buf_num * sizeof(uint64_t); - SPU_ENFORCE(message->length() >= meta_bytes); - message->copy_to(buf_lens_.value().data(), meta_bytes); - message->pop_front(meta_bytes); + beaver::ttp_server::BeaverDownStreamMeta meta; + message->copy_to(&meta, sizeof(meta)); + message->pop_front(sizeof(meta)); + if (meta.err_code != 0) { + SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}", + meta.err_code, message->to_string()); + status_ = Status::kAbnormalFinished; + promise_finished_.set_value(status_); + return -2; } - size_t cur_buf_idx = buf_.size() - 1; - size_t cur_buf_size = buf_lens_.value().at(cur_buf_idx); - buf_.back().append(message->movable()); - SPU_ENFORCE(buf_.back().length() <= cur_buf_size); - if (buf_.back().length() == cur_buf_size) { - if (cur_buf_idx == buf_lens_.value().size() - 1) { - status_ = Status::kNormalFinished; - promise_finished_.set_value(status_); - } else { - buf_.emplace_back(); - } + SPU_ENFORCE(message->length() % buf_vec_.size() == 0); + size_t msg_len = message->length() / buf_vec_.size(); + for (size_t buf_idx = 0; buf_idx < buf_vec_.size(); ++buf_idx) { + message->append_to(&buf_vec_[buf_idx], msg_len, buf_idx * msg_len); + } + + SPU_ENFORCE(buf_vec_[0].length() <= buf_len_, + "unexpected bytes received"); + if (buf_vec_[0].length() == buf_len_) { + status_ = Status::kNormalFinished; + promise_finished_.set_value(status_); } } return 0; @@ -169,7 +164,7 @@ class StreamReader : public brpc::StreamInputHandler { const auto& GetBufVecRef() const { SPU_ENFORCE(status_ == Status::kNormalFinished); - return buf_; + return buf_vec_; } Status WaitFinished() { return future_finished_.get(); }; @@ -177,8 +172,8 @@ class StreamReader : public brpc::StreamInputHandler { void WaitClosed() { future_closed_.wait(); } private: - std::vector buf_; - std::optional> buf_lens_; + std::vector buf_vec_; + size_t buf_len_; Status status_ = Status::kNotFinished; std::promise promise_finished_; std::promise promise_closed_; @@ -186,6 +181,24 @@ class StreamReader : public brpc::StreamInputHandler { std::future future_closed_; }; +// Obtain a tuple containing num_buf and buf_len +template +std::tuple GetBufferLength(const AdjustRequest& req) { + if constexpr (std::is_same_v) { + SPU_ENFORCE_EQ(req.prg_inputs().size(), 3); + return {1, req.prg_inputs()[2].buffer_len()}; + } else if constexpr (std::is_same_v< + AdjustRequest, + beaver::ttp_server::AdjustTruncPrRequest>) { + SPU_ENFORCE_GE(req.prg_inputs().size(), 1); + return {2, req.prg_inputs()[0].buffer_len()}; + } else { + SPU_ENFORCE_GE(req.prg_inputs().size(), 1); + return {1, req.prg_inputs()[0].buffer_len()}; + } +} + template std::vector RpcCall( brpc::Channel& channel, AdjustRequest req, FieldType ret_field, @@ -194,9 +207,10 @@ std::vector RpcCall( beaver::ttp_server::BeaverService::Stub stub(&channel); beaver::ttp_server::AdjustResponse rsp; - StreamReader reader; + auto [num_buf, buf_len] = GetBufferLength(req); + StreamReader reader(num_buf, buf_len); brpc::StreamOptions stream_options; - stream_options.max_buf_size = 0; + stream_options.max_buf_size = 2 * beaver::ttp_server::kUpStreamChunkSize; stream_options.handler = &reader; brpc::StreamId stream_id; SPU_ENFORCE_EQ(brpc::StreamCreate(&stream_id, cntl, &stream_options), 0, @@ -206,14 +220,6 @@ std::vector RpcCall( reader.WaitClosed(); }); - if (upstream_messages != nullptr) { - for (const auto& message : *upstream_messages) { - SPU_ENFORCE_EQ(brpc::StreamWrite(stream_id, message), 0); - SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(), - stream_id); - } - } - if constexpr (std::is_same_v) { stub.AdjustMul(&cntl, &req, &rsp, nullptr); @@ -255,6 +261,19 @@ std::vector RpcCall( "Adjust server failed code={}, error={}", ErrorCode_Name(rsp.code()), rsp.message()); + if (upstream_messages != nullptr) { + for (const auto& message : *upstream_messages) { + int ret = brpc::StreamWrite(stream_id, message); + if (ret == EAGAIN) { + SPU_ENFORCE_EQ(brpc::StreamWait(stream_id, nullptr), 0); + ret = brpc::StreamWrite(stream_id, message); + } + SPU_ENFORCE_EQ(ret, 0, "Write stream failed"); + SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(), + stream_id); + } + } + auto status = reader.WaitFinished(); SPU_ENFORCE(status == StreamReader::Status::kNormalFinished, "Stream reader finished abnormally, status: {}", @@ -590,25 +609,20 @@ BeaverTtp::Pair BeaverTtp::PermPair(FieldType field, int64_t size, if (lctx_->Rank() == perm_rank) { auto req = BuildAdjustRequest( descs, descs_seed); - std::vector buf_vec; - beaver::ttp_server::BeaverPermUpStreamMeta meta{}; - meta.total_buf_size = perm_vec.size() * sizeof(int64_t); + std::vector stream_data; size_t left_buf_size = perm_vec.size() * sizeof(int64_t); size_t chunk_idx = 0; while (left_buf_size > 0) { using beaver::ttp_server::kUpStreamChunkSize; size_t cur_chunk_size = std::min(left_buf_size, kUpStreamChunkSize); - buf_vec.emplace_back(); - if (chunk_idx == 0) { - buf_vec.back().append(&meta, sizeof(meta)); - } - buf_vec.back().append(reinterpret_cast(perm_vec.data()) + - (chunk_idx * kUpStreamChunkSize), - cur_chunk_size); + stream_data.emplace_back(); + stream_data.back().append(reinterpret_cast(perm_vec.data()) + + (chunk_idx * kUpStreamChunkSize), + cur_chunk_size); ++chunk_idx; left_buf_size -= cur_chunk_size; } - auto adjusts = RpcCall(channel_, req, field, &buf_vec); + auto adjusts = RpcCall(channel_, req, field, &stream_data); SPU_ENFORCE_EQ(adjusts.size(), 1U); ring_add_(b, adjusts[0].reshape(b.shape())); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index 80ef18e7..1ff405ad 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -25,15 +25,20 @@ enum class RecOp : uint8_t { XOR = 1, }; -std::vector reconstruct( - RecOp op, absl::Span ops) { +std::vector reconstruct(RecOp op, + absl::Span ops) { std::vector rs(ops.size()); const auto world_size = ops[0].seeds.size(); for (size_t rank = 0; rank < world_size; rank++) { for (size_t idx = 0; idx < ops.size(); idx++) { // FIXME: TTP adjuster server and client MUST have same endianness. - auto t = prgReplayArray(ops[idx].seeds[rank], ops[idx].desc); + NdArrayRef t; + if (rank < world_size - 1) { + t = prgReplayArray(ops[idx].seeds[rank], ops[idx].desc); + } else { + t = prgReplayArrayMutable(ops[idx].seeds[rank], ops[idx].desc); + } if (rank == 0) { rs[idx] = t; @@ -65,7 +70,7 @@ void checkOperands(absl::Span ops, } // namespace -NdArrayRef TrustedParty::adjustMul(absl::Span ops) { +NdArrayRef TrustedParty::adjustMul(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); @@ -74,7 +79,7 @@ NdArrayRef TrustedParty::adjustMul(absl::Span ops) { return ring_sub(ring_mul(rs[0], rs[1]), rs[2]); } -NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { +NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 2U); auto rs = reconstruct(RecOp::ADD, ops); @@ -82,7 +87,7 @@ NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { return ring_sub(ring_mul(rs[0], rs[0]), rs[1]); } -NdArrayRef TrustedParty::adjustDot(absl::Span ops) { +NdArrayRef TrustedParty::adjustDot(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops, true, true); SPU_ENFORCE(ops[2].transpose == false); @@ -99,7 +104,7 @@ NdArrayRef TrustedParty::adjustDot(absl::Span ops) { return ring_sub(ring_mmul(rs[0], rs[1]), rs[2]); } -NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { +NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); @@ -108,8 +113,7 @@ NdArrayRef TrustedParty::adjustAnd(absl::Span ops) { return ring_xor(ring_and(rs[0], rs[1]), rs[2]); } -NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, - size_t bits) { +NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, size_t bits) { SPU_ENFORCE_EQ(ops.size(), 2U); checkOperands(ops); @@ -119,7 +123,7 @@ NdArrayRef TrustedParty::adjustTrunc(absl::Span ops, } std::pair TrustedParty::adjustTruncPr( - absl::Span ops, size_t bits) { + absl::Span ops, size_t bits) { // descs[0] is r, descs[1] adjust to r[k-2, bits], descs[2] adjust to r[k-1] SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); @@ -139,7 +143,7 @@ std::pair TrustedParty::adjustTruncPr( return {adjust1, adjust2}; } -NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { +NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 1U); auto rs = reconstruct(RecOp::ADD, ops); @@ -147,7 +151,7 @@ NdArrayRef TrustedParty::adjustRandBit(absl::Span ops) { return ring_sub(ring_randbit(ops[0].desc.field, ops[0].desc.shape), rs[0]); } -NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { +NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 2U); checkOperands(ops); auto rs_a = reconstruct(RecOp::ADD, ops.subspan(0, 1)); @@ -156,7 +160,7 @@ NdArrayRef TrustedParty::adjustEqz(absl::Span ops) { return ring_xor(rs_a[0], rs_b[0]); } -NdArrayRef TrustedParty::adjustPerm(absl::Span ops, +NdArrayRef TrustedParty::adjustPerm(absl::Span ops, absl::Span perm_vec) { SPU_ENFORCE_EQ(ops.size(), 2U); auto rs = reconstruct(RecOp::ADD, ops); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h index de58b591..55a412e9 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h @@ -31,24 +31,24 @@ class TrustedParty { bool transpose{false}; }; - static NdArrayRef adjustMul(absl::Span); + static NdArrayRef adjustMul(absl::Span); - static NdArrayRef adjustSquare(absl::Span); + static NdArrayRef adjustSquare(absl::Span); - static NdArrayRef adjustDot(absl::Span); + static NdArrayRef adjustDot(absl::Span); - static NdArrayRef adjustAnd(absl::Span); + static NdArrayRef adjustAnd(absl::Span); - static NdArrayRef adjustTrunc(absl::Span, size_t bits); + static NdArrayRef adjustTrunc(absl::Span, size_t bits); - static std::pair adjustTruncPr( - absl::Span, size_t bits); + static std::pair adjustTruncPr(absl::Span, + size_t bits); - static NdArrayRef adjustRandBit(absl::Span); + static NdArrayRef adjustRandBit(absl::Span); - static NdArrayRef adjustEqz(absl::Span); + static NdArrayRef adjustEqz(absl::Span); - static NdArrayRef adjustPerm(absl::Span, + static NdArrayRef adjustPerm(absl::Span, absl::Span perm_vec); }; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc index 865cbb9c..30a72277 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc @@ -184,7 +184,8 @@ class StreamReader : public brpc::StreamInputHandler { kStreamFailed, }; - StreamReader() { + explicit StreamReader(size_t total_buf_len) { + total_buf_len_ = total_buf_len; future_finished_ = promise_finished_.get_future(); future_closed_ = promise_closed_.get_future(); } @@ -192,28 +193,20 @@ class StreamReader : public brpc::StreamInputHandler { int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[], size_t size) override { SPDLOG_DEBUG("on_received_messages, stream id: {}", id); - if (status_ != Status::kNotFinished) { - SPDLOG_WARN("unexpected messages received"); - return -1; - } - for (size_t i = 0; i < size; ++i) { + if (status_ != Status::kNotFinished) { + SPDLOG_WARN("unexpected messages received"); + return -1; + } const auto& message = messages[i]; SPDLOG_DEBUG("receive buf size: {}", message->size()); - if (!total_buf_size_.has_value()) { - beaver::ttp_server::BeaverPermUpStreamMeta meta{}; - message->copy_to(&meta, sizeof(meta)); - message->pop_front(sizeof(meta)); - total_buf_size_.emplace(meta.total_buf_size); - } - buf_.append(message->movable()); - if (buf_.length() == total_buf_size_.value()) { + if (buf_.length() == total_buf_len_) { status_ = Status::kNormalFinished; promise_finished_.set_value(status_); - } else if (buf_.length() > total_buf_size_.value()) { + } else if (buf_.length() > total_buf_len_) { SPDLOG_ERROR("buf length ({}) greater than expected buf size ({})", - buf_.length(), total_buf_size_.value()); + buf_.length(), total_buf_len_); status_ = Status::kAbnormalFinished; promise_finished_.set_value(status_); } @@ -249,7 +242,7 @@ class StreamReader : public brpc::StreamInputHandler { private: butil::IOBuf buf_; - std::optional total_buf_size_; + size_t total_buf_len_; Status status_ = Status::kNotFinished; std::promise promise_finished_; std::promise promise_closed_; @@ -258,18 +251,57 @@ class StreamReader : public brpc::StreamInputHandler { }; template -std::vector AdjustImpl( - const AdjustRequest& req, StreamReader& stream_reader, - const std::unique_ptr& decryptor) { - std::vector ret; - size_t field_size; - if constexpr (std::is_same_v) { - field_size = 128 / 8; - } else { - field_size = req.field_size(); +size_t GetBufferLength(const AdjustRequest& req) { + if constexpr (std::is_same_v) { + if (req.prg_inputs().size() > 0 && req.field_size() > 0) { + return req.prg_inputs()[0].buffer_len() / req.field_size() * + sizeof(int64_t); + } else { + SPDLOG_ERROR("Invalid request, prg_inputs size: {}, field_size: {}", + req.prg_inputs().size(), req.field_size()); + } } - auto [ops, seeds, pad_length] = BuildOperand(req, field_size, decryptor); + return 0; +} +void SendStreamData(brpc::StreamId stream_id, + absl::Span buf_vec) { + SPU_ENFORCE(!buf_vec.empty()); + for (size_t idx = 1; idx < buf_vec.size(); ++idx) { + SPU_ENFORCE_EQ(buf_vec[0].size(), buf_vec[idx].size()); + } + + size_t chunk_size = kDownStreamChunkSize / buf_vec.size(); + // FIXME: TTP adjuster server and client MUST have same endianness. + size_t left_buf_size = buf_vec[0].size(); + int64_t chunk_idx = 0; + while (left_buf_size > 0) { + butil::IOBuf io_buf; + BeaverDownStreamMeta meta; + io_buf.append(&meta, sizeof(meta)); + + size_t cur_chunk_size = std::min(left_buf_size, chunk_size); + for (const auto& buf : buf_vec) { + int ret = io_buf.append(buf.data() + (chunk_idx * chunk_size), + cur_chunk_size); + SPU_ENFORCE_EQ(ret, 0, "Append data to IO buffer failed"); + } + + // StreamWrite result cannot be EAGAIN, given that we have not set + // max_buf_size + SPU_ENFORCE_EQ(brpc::StreamWrite(stream_id, io_buf), 0); + + left_buf_size -= cur_chunk_size; + ++chunk_idx; + } +} + +template +std::vector AdjustImpl(const AdjustRequest& req, + absl::Span ops, + StreamReader& stream_reader) { + std::vector ret; if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustMul(ops); ret.push_back(std::move(adjust)); @@ -311,7 +343,53 @@ std::vector AdjustImpl( "not support AdjustRequest type"); } - return StripNdArray(ret, pad_length); + return ret; +} + +template +void AdjustAndSend( + const AdjustRequest& req, brpc::StreamId stream_id, + StreamReader& stream_reader, + const std::unique_ptr& decryptor) { + size_t field_size; + if constexpr (std::is_same_v) { + field_size = 128 / 8; + } else { + field_size = req.field_size(); + } + auto [ops, seeds, pad_length] = BuildOperand(req, field_size, decryptor); + + if constexpr (std::is_same_v || + std::is_same_v) { + auto adjusts = AdjustImpl(req, absl::MakeSpan(ops), stream_reader); + auto buf_vec = StripNdArray(adjusts, pad_length); + SendStreamData(stream_id, buf_vec); + return; + } + + SPU_ENFORCE_EQ(beaver::ttp_server::kReplayChunkSize % 128, 0U); + SPU_ENFORCE(!ops.empty()); + for (size_t idx = 1; idx < ops.size(); idx++) { + SPU_ENFORCE(ops[0].desc.shape == ops[idx].desc.shape); + } + int64_t left_elements = ops[0].desc.shape.at(0); + int64_t chunk_elements = + beaver::ttp_server::kReplayChunkSize / SizeOf(ops[0].desc.field); + while (left_elements > 0) { + int64_t cur_elements = std::min(left_elements, chunk_elements); + left_elements -= cur_elements; + for (auto& op : ops) { + op.desc.shape[0] = cur_elements; + } + auto adjusts = AdjustImpl(req, absl::MakeSpan(ops), stream_reader); + if (left_elements > 0) { + auto buf_vec = StripNdArray(adjusts, 0); + SendStreamData(stream_id, buf_vec); + } else { + auto buf_vec = StripNdArray(adjusts, pad_length); + SendStreamData(stream_id, buf_vec); + } + } } } // namespace @@ -338,9 +416,9 @@ class ServiceImpl final : public BeaverService { ::google::protobuf::Closure* done) const { auto* cntl = static_cast(controller); std::string client_side(butil::endpoint2str(cntl->remote_side()).c_str()); - StreamReader reader; brpc::StreamId stream_id = brpc::INVALID_STREAM_ID; auto request = *req; + StreamReader reader(GetBufferLength(*req)); // To address the scenario where clients transmit data after an RPC // response, give precedence to setting up absl::MakeCleanup before invoking @@ -353,14 +431,13 @@ class ServiceImpl final : public BeaverService { reader.WaitClosed(); } }); - std::vector adjusts; try { - adjusts = AdjustImpl(request, reader, decryptor_); + AdjustAndSend(request, stream_id, reader, decryptor_); } catch (const DecryptError& e) { auto err = fmt::format("Seed Decrypt error {}", e.what()); SPDLOG_ERROR("{}, client {}", err, client_side); // TODO: catch the function name - BeaverDownStreamMeta meta{}; + BeaverDownStreamMeta meta; meta.err_code = ErrorCode::SeedDecryptError; butil::IOBuf buf; SPU_ENFORCE_EQ(buf.append(&meta, sizeof(meta)), 0); @@ -370,7 +447,7 @@ class ServiceImpl final : public BeaverService { } catch (const std::exception& e) { auto err = fmt::format("adjust error {}", e.what()); SPDLOG_ERROR("{}, client {}", err, client_side); - BeaverDownStreamMeta meta{}; + BeaverDownStreamMeta meta; meta.err_code = ErrorCode::OpAdjustError; butil::IOBuf buf; SPU_ENFORCE_EQ(buf.append(&meta, sizeof(meta)), 0); @@ -378,45 +455,11 @@ class ServiceImpl final : public BeaverService { brpc::StreamWrite(stream_id, buf); return; } - - butil::IOBuf buf; - BeaverDownStreamMeta meta{}; - meta.total_buf_num = adjusts.size(); - SPU_ENFORCE_EQ(buf.append(&meta, sizeof(meta)), 0); - for (const auto& adjust : adjusts) { - uint64_t cur_buf_size = adjust.size(); - buf.append(&cur_buf_size, sizeof(cur_buf_size)); - } - - for (auto& adjust : adjusts) { - // FIXME: TTP adjuster server and client MUST have same endianness. - size_t left_buf_size = adjust.size(); - int64_t chunk_idx = 0; - while (left_buf_size > 0) { - size_t cur_chunk_size = std::min(left_buf_size, kDownStreamChunkSize); - SPU_ENFORCE_EQ(buf.append(adjust.data() + - (chunk_idx * kDownStreamChunkSize), - cur_chunk_size), - - 0); - // ret cannot be EAGAIN, given that we have not set max_buf_size - int ret = brpc::StreamWrite(stream_id, buf); - if (ret != 0) { - SPDLOG_ERROR("brpc::StreamWrite return {}", ret); - return; - } - - buf.clear(); - left_buf_size -= cur_chunk_size; - ++chunk_idx; - } - adjust.reset(); - } }); brpc::ClosureGuard done_guard(done); brpc::StreamOptions stream_options; - stream_options.max_buf_size = 0; + stream_options.max_buf_size = 0; // there is no flow control for downstream stream_options.handler = &reader; if (brpc::StreamAccept(&stream_id, *cntl, &stream_options) != 0) { SPDLOG_ERROR("Failed to accept stream"); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h index 04d0ffa4..04dcc88e 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h @@ -18,16 +18,13 @@ namespace spu::mpc::semi2k::beaver::ttp_server { +constexpr size_t kReplayChunkSize = 50 * 1024 * 1024; // bytes + constexpr size_t kUpStreamChunkSize = 50 * 1024 * 1024; // bytes constexpr size_t kDownStreamChunkSize = 50 * 1024 * 1024; // bytes -struct BeaverPermUpStreamMeta { - uint64_t total_buf_size; -}; - // A list of buffer streams struct BeaverDownStreamMeta { - uint32_t total_buf_num; // total buffer stream num int32_t err_code = 0; }; diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index e741b51d..5d0eca0c 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -223,7 +223,7 @@ NdArrayRef ring_rand(FieldType field, const Shape& shape) { NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, uint64_t* prg_counter) { constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = - yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + yacl::crypto::SymmetricCrypto::CryptoType::AES128_ECB; constexpr uint128_t kAesInitialVector = 0U; NdArrayRef res(makeType(field), shape); diff --git a/spu/tests/data/BUILD.bazel b/spu/tests/data/BUILD.bazel index c5e9d338..647c6e0a 100644 --- a/spu/tests/data/BUILD.bazel +++ b/spu/tests/data/BUILD.bazel @@ -17,12 +17,15 @@ package(default_visibility = ["//visibility:public"]) filegroup( name = "data", data = [ - "100K-1-16.json", "alice.csv", "bob.csv", "carol.csv", "db.csv", "ground_truth.csv", + "pir/100K-1-16.json", + "pir/db.csv", + "pir/ground_truth.csv", + "pir/query.csv", "query.csv", ], ) diff --git a/spu/tests/data/100K-1-16.json b/spu/tests/data/pir/100K-1-16.json similarity index 100% rename from spu/tests/data/100K-1-16.json rename to spu/tests/data/pir/100K-1-16.json diff --git a/spu/tests/data/pir/db.csv b/spu/tests/data/pir/db.csv new file mode 100644 index 00000000..daf120eb --- /dev/null +++ b/spu/tests/data/pir/db.csv @@ -0,0 +1,101 @@ +key,value +aPYaKgvvcESwAtfghnRUIAYYIZeCsGeaWbAAEUCXQzNrxGRPVOcACqMBJmdfiveq,LdQNbKmBMhlpCctB +eXeLfYoSlEntpRaKuddYhtaImMOdEhNTIolxElSrlPYMhgZoWxccpUTOjFciaRcD,wcsaUIrVHmxDxnXr +gOUgSoWKKYaLpmWiuTyuVNostvTJUHxBZjYWJZukOzqICDlmKdavyERZimgOkaHn,FgdZUDbsfINsODTe +BrImQzkUpelejHVFeatNDJTqdTuxCJmHKunkODDPsUtsgEfCkXyLNEAZXniJZgvf,PkxxKRwEEEwrcLbR +RhoATcXotsJDbAICIkIbAYdDUXqreTDbZVPoSyziyQJvBVVaZhYqWmHVQsqNbdIP,OTFagQCudOpRtKqT +bibVQsoLIFjXrSFIjPFOTWclPDhakarsWICsYNvbMtlNwBIKqnaWmvCPcFPesBVy,YysQfuPwTiXayxlt +CQZcsLKmugDFfsYAOihJENLpBwCYwoLuATHrMiIChXkycvhejgjjUCoxKDpTDMuZ,OIGhXryCxSXCZKBn +aizidqjurwmpHHqioJclEQvccvnWQKmZTMBNSGgmfxxqHilZwFPdjHlZElQaayJN,SjCyivgzSTLpXmnz +eNGJWxNcevvcthSXdXLseTORLpDVdVCAdnbQtcovUVAislbievgctCGMidsncXyL,rFJEAlFBMUMZNPrc +OQArDeXFlhxHBPERmlYuwZHQAgCQtRorzCkDenmMdnbFdGKJggqmhApWMrbtqDZw,saeWaHguQALRalnZ +AZKVktpAYSLSZTpjylrifvdItAuhXvzvDhtWnbjMlshYFTooHXtwNfCsmiuWgXLa,sntRGrWUapWvBwAM +xNTfGmPFpdczqsPvlLFMSzXLdqGkyupfwXepkfJHECoAIDZaRwqoIullNzMfWHLK,jVerCzWaKQoYYBeP +hKFRJSHErdbPbPFRZpEkigOemyqRgRDzWHMhiVDlHDtTAjKHZcQCogTvFlbwVVbA,ftLFWLGNrSkavyEX +tGYkoEeaQinZmEsPtFyhyJxXGiJreOCTHTxPfSWFFQwqChLKADXjeWiTsVnLyFTb,vKUHgdGvKkJKSAcJ +KSyQGEXCvzziOdoaIwQYynSjQlyPOkJYroZAxblcGRHOqhjWuByOVIUYfuacruUQ,NnsQMUvXYZOhswVL +WsQkdjbaExqvNkQlbMWqwxtChiWRBsmylFanbONarkbmmWAnKPvCdPLlhQtvFaXP,phEafGnFQIcoYCKG +rESDWnSxXeTvwvTDxSsSmwPswujcWVlrvlgVAuGUYjvyXFjCSGenqEZhfFHgZXIB,sHxPvGwZWDbJgCYn +UKgGYBiPUdYdfuQZrpKEtprGxofPaLVqhdiuSUzKrSDzXuCgYuOQJlyhsNHyLrCx,sdNdnkjuRNLyPhaB +eEXLhpsjwQQRFVXxPgqQnFMCQDAMOMxwoAeeubrJXwxKaiWgnilOwjzoEZuRUJLV,WXVgFQQFBEwRomjA +QVGpdCqJIpZStCSUOEKyEoOJJurZSWmQZDCnIrANHGJYhpbfxAhsPvrVZnVrhQKn,bUIqpQPVvtiPhItz +UOmpkfibQXxQlYJQzZgdfoIHckIWHWiVBcoLaLSQnlpnIBQZcrnCEXCfHTDwsFDX,OdckqFMcdfnVnBoB +TDzDtfzzMugZZxhNGhmwYsMrOCFvCUWmmOUOLNGAYMRMnZVGuOMSXZZgaTufrqXK,DAToiYOldpgNGOqn +zmRKIEQFtIjCYUXaFgyAvZZEDIukHAwYlzUwxbttWndcAGFEoRzGyAUuLsKnbfZi,rfwdiePuXVtvKgat +vmtpIcBkPJyFRqKWIYHWcecKdgCoUShJwkhYvjHZPdhwmcdBGwQDDVynyOwSZcYj,PDNsnMKRZubVpMRT +EjQMHoawzxMREpZaJFKJBNsnKdzQTWeGmAMkhsuSfEzoDpQfdUWUeTFKvKClRNPz,rauheCdowFiOAMFk +mjjTWkjovIsCsMuZfdtXIKVZEcwuspLRUtCVPKpMdkkaGQtUUmFrXaZHaDuPKvsa,wCscOftxAHuBnsSW +MjvTMaePwIVFpEbspToomYGFAOmpGuKlmgJvIOhtVoHNgWaHReuMELUapHWAaZjL,MeizqMvAktGZLkCH +EwJZCBgPDuKRnTTZwuJRKfkznXpHGdbfMOZTnVjixKGciMLkdLSzWBXkBhMGzwSS,RhpjUzFsJtSSXund +HSaBXSGcBxYSUIXYnlFnYrdTclIehDdMhKqIRJuAYebfViJttknfMmCqbyYOJAXE,NIZPwgQebsKBehaN +fLOyQDLsIUaWZUjwzsrxlGHlGTYNWVZyTEWJZenWqZiMqHEpLWAvGojmOQvteOqS,XGTuKgtLshqQUtfr +wvKVPbYksmYXTsRqvJETrjXJethrvgmBLIwMQhJBCMTfLGOFKHwxrrBcGQqdMjZe,ICnEQAovJhrWaIiY +WpbdxcjMqKMkLdSlUBowCTWDGVtRJLiEDQytMenWEIkWFWLKByiEhrvIpCncUQDS,xEGgXHFKHYDlGdYF +AVSBpQelmdheyUZPdmRrhrEqHmKowFAIDNjxzphVCoLgSypBfHNtVuDgoIVqCoLF,ZLFAfNImdiEwcupl +XtndzLknyWTeElGFXjZfbrHGqYzqHcTzEtXquKkpckuwhkQPcCkmXIhfCnLYCrVG,UHbNwRmAJFMalnbt +lvLHBgacmSdZJqpzrezjTYTfWIBFUDIaMcGwErtmnAwgjDXwmIHxMDqYTrJvjUyq,NoWnhJQWDJVCtExB +SKenZwjdFvsAiARRmpBzTAXGWtByjJcIniiAhovlsAHLZXQJCmDyRJxKevZBttDa,mHAfCtCExIyoRuWG +vKgaRAkJVOSESTgmEVXpQIVXDADJiHmFFaAwxtjwUyFVrQouyJcZeDwhMUZPROkA,CiuvQKMCsVGLkoBM +cdxMDaSfvmlcpSfFqvfzgNIyUcmkDEyVswXcCKCJfYyAqrSCGWBGIEQlBxKTWSCj,LTIsYYovMEnpawzQ +bnfGLKRBqwOCQBNWSRAbEWVLqyUAlzrYNiJRWUZnmGnXtjMFBQhLHBhVJBygIrGU,vLYmsxCGzRXOMiqB +bxidiEYIOjbFDbHRqnaYXZuQcZJbNxynsmjNPCpzujEKzATaBeTrUchoylhvqLjx,AXlyMrMWiWGMqoIs +eSuMvTbZgMRDrEIxJwgFYdpNWkQzEzrsyyybeaJlUPEhEZZBWpPwQFqImIGnLFar,BTfGxXWpxvvPzcDS +uONlUmHMFQiPkCfdPrqqUDleaUBHKnxuQbFJAqfMqzoSpqPzawdOIvtVQMSWfCRv,uFDJTeJjZNFjIxKT +IEWyZnggNMulCyYklMdZaMYiIsqQNtbzbcpMHBUfPeOKaoSCMeezBSqcQwVXNJho,zpDlpaygXXZyslHj +ItlLEykbpzIeTErDcbaxxfzntBAYcPHVcLFneOzhNhxYYgwsbKZEmHHuHTnPnhSW,nKPAjZgLwcVTMNGi +uTQQIgVItVwPUSygrxoUwrLuFbAbahqbnixUuRIKRnJConAViRHYsRerKyEieFYI,fhPHVglAfXtHjNae +etMRJJVQSxIMgvzdSoCPsVpGJcKtjpXtMqtzXgaGQDryplTJifNvOFYGWChHLOUo,WJrEYWYGMOqveCgd +BYHllipDRYZVvMQYYhIRzLHabgftPTSnFbUCRmZejFUoeLLQoZtrZPJrTjVqWfNO,EztmadvCnpbgQtBl +ZEkjAbQOpxQrbtEDVlDhKgChCNsxTxSQtUXARrEeVJQrzPuVPkYHuuoMXjVCyeCk,kcgMJGMwDDDiFDSP +WdhxSOyjKNLItJZiXZtkKBdIcjGLuLbZHPmSfJCzlvqBxnrobjDTPxsFRXhEInhh,OPuAWqKpmpZHbvIY +wSGlgcmRasbLYVClIhhCppYezjZrWIhhaiASQcDrCDxdsGJIJjNmTWwbsFuKlbAt,VToAPASiTJHUGIEk +CFzAfFrDFyhndtOJalJiSiufNlCjWcwxwQjnjRgbqFlaAlIzXgrwVJmwISEAKHxx,jpylGyqZUinEIiVo +BwluSThaftzdTOrrzjBfqmHdbXqLUDEPPqmduYZESLtaSQAWLOKeuECRKPDEumJA,vgocckgnbQZCgpSq +OUguorluoGFpgVuXujAFBOIkBsIAaNCXgcywsWjutvEcrJrrDBRAHgcKwnfNLpXr,vwsRUvSJpupXvGta +TXQWeuibBvqDmmTaLAPZNsEHosjhcLBsixvomJaiAPmLmBDemETNOZMwrwREVRir,EIgmXjWGNYpcfdOs +BbQiYWirQnKUzBachIeJZgJWvmeZkpBUphzYEGrGcGpxUwWvSSnSYpQzBRsOFuzC,oUBiVRJNarkuUdRV +dugMhlZPXHEnWTZjaOFaxmmZgIHdVBmzUsIfUUPZkdVzDvCLRZyTBPrHhDVrYAVk,OdxERipabMlyoyEa +aKPvrmNwyclBMcPMEAgGItoshSJVrSonWVrcHMWiBXxqTpdHjuGKrRLTaHhQyunk,TmBLACtwsObAkgoz +LSndhzUuoIbprGCzRDfZryKqdcqwLpLWYfHoOHgBJDSkZRoYMQxmIoVCUdBSxHsZ,qvhvMIqFvXELhWqK +lInwjXRJTuZgrvcbDEvFHgPFGpqlnSuIJWtNRJYizWEfZJbZtfLexmEQMnGLxNlW,MURWsZhFTeLoqeAd +DOlyDBVEZFruhsBwHZgrnWTXckcVJcVzrwniSnJFEYUNiFgkIyukstlbdrluVhag,vUuhogLghMFjNjyK +zgDxGOAAYMGtAOoMviwtSQDLEpHluVuqFsqisVvoKCLfnMdPVTKgKCrchKrAmlmz,YaojDSYxfntenXUp +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp +nJJyhHkDkQRttGsYjkMBBuGeQuPPDHQEQQdnGQmMbOdRFifRDpZVUqdfqeskxngR,nyikwrUYIuWbTawJ +ZcwBhgoFFiWyDeZbeMpliSvhbfsAkccEQQhYreLTGdfVHuNLpmCsduhkIlRKMNkx,nWeXVNOUmsoakAMC +BCJYdYFkkLRWUxhnnwpJGbXJchPvVCtbcaMdkArTcNLdRmopwncgdgOLGhJkZOnC,JWAVSCWqKUcmfoqk +NkxNbLyBjixoCzClTdwshhuZcFRjJJdDdWgCfiQIttZWQWqouBkYyMGpampLdUAr,nkuaUoDVcCUBUoOv +WgwvZoDsjspDAWHEflQMWzlbqnssWiBElmABhLmhgDPqFbNmAHSnzQrbAqSVAmWS,IQIlCVVcjVcTdshJ +yXelVXMUEuAtfNgzPrhjvYOpiAVEMZuqPfsQEUQoshjSIekxxzkFxdftfqFzfzpa,AyTbbGVAXKCYUUln +sAtWqpSPxPSkDtmIJKfNvlKjgStnYMOmrLsQnzmIFAusETPPzLDTjKcBASKWNRAJ,bDQFllamogAjBEPU +RAOMDMMZkezCBWxQDWLjvHLkbpvFyrbUbDDEekWciXejYwKifSVumcsocUmMkmpa,MnwKAECsMVGVLIZM +LMizeIoxMxHCKwikjqOSSPbuiqXWDAmbTLMBOXpyorUmpunjWFTNLVqvHNcCNHrN,zDzyAkfeTYZzaxhG +IzLqciYsaKtWrsrjOQldeIvEqavoIZEYnupJZLizJVeOhoLqtLQFaoNRdvZMWSQH,klLegifptLAxnhha +VkyQsnlXcIGjGhcSJUcZKeQyiDUbIcgIWSbHaEsbfEydSTHqRlxImGdYGEurZczg,jyzyCGlvuBdKwyIX +dkSRcFpHPXWjNIHrCpWlOPaIkVqjtyPRhlJeYMksjieDxYhiUcGhbuvamVlrMDDx,FATEzwNXGerINvHD +KravfnzVbNOhLhstPcaLLVWpbqYzXckQGbuALlEiXqqhFfUyThFZFhzSLhjldPMB,sowcmsLQTFiKpNXy +ITZWxxjqnizlxuRWMlPQLUnBopyDtOxNfcaoFDbRKetIpVxKLSRoJOauSCcDwWUP,NmqaomxmRQGqiKCV +guhzFIXdUEACMMHObJptkrZqJgbclDcdRxCPYSvuAdITaKgHfJaKLNHFzdRmpHni,DiONbLbxHfUHhoTU +AnZjaIBmnsEBplkEmHBstdggPnYmhblyQQttVqYzxxNtOXwlNQetkvCOySSXRUpw,FCdUtLDyvAqszerb +wkCcCXGKcZJEDwTkzOoNRkMbxHdNciQlVruGSKcJrHpokspcZIVfupcTxapISupH,pzEvnKzLQbzNDSQN +ppieExXmNHqBXVgLFhjlHHHhHSAddipMCmPXhXDfHZVTtNhqcMMVauyjKOFGBHPe,tXpiHGkKGTzzMluO +lnVknQNrrYyqFbEKYPxsQWNPKLpsEVmUGtbbWMWDThMuScSByeZRwuusLYzKPbHE,HzxgCCtiIFYvgwWO +oYSnlwjpsWaNzYunBnhNLwiICrmAEFiZRczbdHYpQgwSrrMQCixgtjfCGOptTkmd,IYmbhaQueIKQvcBc +saWvQlIiiYqAPmcEDGsVXNAIJNNGTyZKhrMMKYHXJQnniGVuIClgwvAEXeIPGeFN,epReKmWNANFpINhn +PZIEJMwirPArOGCfJJAfdwGydRDBGGQojUzWFJtVoJZTFAFYwaDOuLFruvRjolHq,yCMGUjViZoTPMTtg +bMCOdGAYjXaDSPyZyGegyuRnnwYSySrRzbLbvtgBjfFXfCMPIVIGFTagRyBpiKLa,zRTXXYnUXmkIMDoP +JQWdUrNElPWswpQvnVqCmboMEjMhebKISRcmznakzemGxBjUughzOVbctPzmVTLW,CdcnGSQhRbdsoQrg +OdAvNbDdCQpTrAbJWrUrVprpgVIXwvvSStooIVwzUfDIThtvdBHldyUFFkvabfyj,ueAqmPurXOjNtvWr +zjIkKNEzyUFiubvlxYWNXdjoIIEwZavalnqwSCgDgcZUldjZOkzhKXuRciwSTNJg,MWwWigLZKqLgZkLp +WTaCVgYrnoyEoShtBDUmrRHeRSYIAjvUpZnVAUTxTyaIGzvQIdwcPafAnkIbplSq,mdIEdIajBbeAyPCk +tNdAhhdqVzLdGfPoctgRkehzEOIRvjEwDpmAQrMjbWtfRQGjeUiVJNafrhVKFieX,IxHvvUzKrlMpWhpR +vhgPWqsRnvDRMFIYHppovDbKlWPzEFwbBXSihpYbwCYpkeXFIXbIYdWSLfcHpnWX,qhwFSRGRKwcPlLJs +ARxEUJokZaGDgXHGxwPiSqqvNSmoowUxRDDkozqbvcUvQuPtdNaeaKOKykMIUkmR,XqmCMQPKPtAPzBZd +RwztiezZCzbSLKzIyYqfEMjDTcLpASCiGWoaseuxBWpvSVutmtdEgdZornGkHrQf,YcRFZNgodJFPNoop +YTcHYrADMhlKAnvdGBdQBXWBqcftxkNpFceODelYVRXwFOZTHdXkVGAfJTzZcyhD,tBGtrQaLFgACGOEE +fHFCvDLRGGhYZWSnxaIqKTgvNbCPLzyvOnpHyAhrKEAsApdPgkxAptCTtgYAnmEq,vxGOPFzvJOVBEblg +zckpuLjSVdhSFnhTqPfDoHdJdjpfZBDdlzGbYgzVbKgDMJQDBGCHZSJBdtzlvHro,TeeGbXAcEbwzglGf +muAQTPuNCQTZurKTDlYzTQgvlWNyRXOlKizgsnGSrKdYWCSBlQtOvIyEWVthaYhO,ZnYBDVQYoJOoTMlS +UQswwuiprHWAbguGNZgOAdFrgEIdsDRImrqXXTmbqppVgnJrjjiOdZaNUpIQGcTR,VwugWpNMzEKHAFqo +GDRPaAUIAymOEEksSqccGOqpUYvGUyvBKjfRqKSTAyNadpaMYnMYboPOrEEfXVWf,noDbJmsjYCgqHsBu +cVjSBnCUnKfKXwETABIPvavwLXMGSLSpoVylUSCRlRCzpDvDVjfNAIrSiRWNHJZS,OszhlCboIvNdCTYH diff --git a/spu/tests/data/pir/ground_truth.csv b/spu/tests/data/pir/ground_truth.csv new file mode 100644 index 00000000..6160abfa --- /dev/null +++ b/spu/tests/data/pir/ground_truth.csv @@ -0,0 +1,2 @@ +key,value +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp diff --git a/spu/tests/data/pir/query.csv b/spu/tests/data/pir/query.csv new file mode 100644 index 00000000..5cb2f046 --- /dev/null +++ b/spu/tests/data/pir/query.csv @@ -0,0 +1,2 @@ +key +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv diff --git a/spu/tests/pir_test.py b/spu/tests/pir_test.py index c75d6bbe..3bf1cd68 100644 --- a/spu/tests/pir_test.py +++ b/spu/tests/pir_test.py @@ -30,8 +30,8 @@ def test_pir(self): # setup stage sender_setup_config_json = f''' {{ - "db_file": "spu/tests/data/db.csv", - "params_file": "spu/tests/data/100K-1-16.json", + "source_file": "spu/tests/data/pir/db.csv", + "params_file": "spu/tests/data/pir/100K-1-16.json", "sdb_out_file": "{temp_dir}/sdb", "save_db_only": true }} @@ -51,9 +51,9 @@ def test_pir(self): receiver_online_config_json = f''' {{ - "query_file": "spu/tests/data/query.csv", + "query_file": "spu/tests/data/pir/query.csv", "output_file": "{temp_dir}/result.csv", - "params_file": "spu/tests/data/100K-1-16.json" + "params_file": "spu/tests/data/pir/100K-1-16.json" }} ''' @@ -92,7 +92,7 @@ def receiver_wrap(rank, link_desc, config): import pandas as pd df1 = pd.read_csv(f'{temp_dir}/result.csv') - df2 = pd.read_csv('spu/tests/data/ground_truth.csv') + df2 = pd.read_csv('spu/tests/data/pir/ground_truth.csv') self.assertTrue(df1.equals(df2))