Skip to content

Commit

Permalink
repo-sync-2024-09-24T14:07:49+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Sep 24, 2024
1 parent d7cb0cf commit 184395a
Show file tree
Hide file tree
Showing 20 changed files with 359 additions and 179 deletions.
6 changes: 3 additions & 3 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b5_nightly_20240919.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b6_nightly_20240923.tar.gz",
],
strip_prefix = "yacl-0.4.5b5_nightly_20240919",
sha256 = "0ef295f6878dce6160fd44e6af59fa369099f736fa8d4a10f9685dda66aefa71",
strip_prefix = "yacl-0.4.5b6_nightly_20240923",
sha256 = "14eaaf7ad4aead7f2244e56453fead4a47973a020e23739ca0fe93873866bb5f",
)

def _libpsi():
Expand Down
2 changes: 1 addition & 1 deletion libspu/compiler/tools/spu-translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void isEqual(const xt::xarray<T> &lhs, const xt::xarray<T> &rhs) {

auto error = lhs - rhs;

for (auto v : error) {
for (T v : error) {
if (v != 0) {
llvm::report_fatal_error(fmt::format("Diff = {}", v).c_str());
}
Expand Down
1 change: 1 addition & 0 deletions libspu/core/ndarray_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "absl/types/span.h"
#include "fmt/ostream.h"
#include "fmt/ranges.h"
#include "yacl/base/buffer.h"

#include "libspu/core/bit_utils.h"
Expand Down
1 change: 0 additions & 1 deletion libspu/core/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <vector>

#include "absl/types/span.h"
#include "fmt/format.h"
#include "fmt/ranges.h"
#include "spdlog/spdlog.h"
#include "yacl/link/context.h"
Expand Down
3 changes: 3 additions & 0 deletions libspu/core/xt_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ NdArrayRef xt_to_ndarray(const xt::xexpression<E>& e) {
}

} // namespace spu

template <typename T>
struct fmt::is_range<xt::xarray<T>, char> : std::false_type {};
8 changes: 7 additions & 1 deletion libspu/dialect/pphlo/IR/type_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,14 @@ LogicalResult inferDynamicUpdateSliceOp(
}

// dynamic_update_slice_c1
TypeTools tools(operand.getContext());
auto common_vis =
tools.computeCommonVisibility({tools.getTypeVisibility(operandType),
tools.getTypeVisibility(updateType)});

inferredReturnTypes.emplace_back(RankedTensorType::get(
operandType.getShape(), operandType.getElementType()));
operandType.getShape(),
tools.getType(operandType.getElementType(), common_vis)));
return success();
}

Expand Down
4 changes: 4 additions & 0 deletions libspu/mpc/common/prg_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) {
return ring_rand(desc.field, desc.shape, seed, &counter);
}

inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) {
return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter);
}

} // namespace spu::mpc
23 changes: 12 additions & 11 deletions libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size,

if (lctx_->Rank() == 0) {
ops[2].seeds = seeds_;
auto adjust = TrustedParty::adjustMul(ops);
auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops));
ring_add_(c, adjust);
}

Expand Down Expand Up @@ -158,7 +158,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size,

if (lctx_->Rank() == 0) {
ops[1].seeds = seeds_;
auto adjust = TrustedParty::adjustSquare(ops);
auto adjust = TrustedParty::adjustSquare(absl::MakeSpan(ops));
ring_add_(b, adjust);
}

Expand Down Expand Up @@ -223,7 +223,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Dot(FieldType field, int64_t m,

if (lctx_->Rank() == 0) {
ops[2].seeds = seeds_;
auto adjust = TrustedParty::adjustDot(ops);
auto adjust = TrustedParty::adjustDot(absl::MakeSpan(ops));
ring_add_(c, adjust);
}

Expand All @@ -250,7 +250,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::And(int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustAnd(ops);
auto adjust = TrustedParty::adjustAnd(absl::MakeSpan(ops));
ring_xor_(c, adjust);
}

Expand All @@ -276,7 +276,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Trunc(FieldType field, int64_t size,
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustTrunc(ops, bits);
auto adjust = TrustedParty::adjustTrunc(absl::MakeSpan(ops), bits);
ring_add_(b, adjust);
}

Expand All @@ -300,7 +300,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::TruncPr(FieldType field, int64_t size,
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjusts = TrustedParty::adjustTruncPr(ops, bits);
auto adjusts = TrustedParty::adjustTruncPr(absl::MakeSpan(ops), bits);
ring_add_(rc, std::get<0>(adjusts));
ring_add_(rb, std::get<1>(adjusts));
}
Expand All @@ -322,7 +322,7 @@ BeaverTfpUnsafe::Array BeaverTfpUnsafe::RandBit(FieldType field, int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustRandBit(ops);
auto adjust = TrustedParty::adjustRandBit(absl::MakeSpan(ops));
ring_add_(a, adjust);
}

Expand All @@ -348,10 +348,11 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::PermPair(
auto pv_buf = lctx_->Recv(perm_rank, kTag);

ring_add_(b, TrustedParty::adjustPerm(
ops, absl::MakeSpan(pv_buf.data<int64_t>(),
pv_buf.size() / sizeof(int64_t))));
absl::MakeSpan(ops),
absl::MakeSpan(pv_buf.data<int64_t>(),
pv_buf.size() / sizeof(int64_t))));
} else {
ring_add_(b, TrustedParty::adjustPerm(ops, perm_vec));
ring_add_(b, TrustedParty::adjustPerm(absl::MakeSpan(ops), perm_vec));
}
} else if (perm_rank == lctx_->Rank()) {
lctx_->SendAsync(
Expand Down Expand Up @@ -380,7 +381,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Eqz(FieldType field, int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustEqz(ops);
auto adjust = TrustedParty::adjustEqz(absl::MakeSpan(ops));
ring_xor_(b, adjust);
}

Expand Down
132 changes: 73 additions & 59 deletions libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,53 +98,48 @@ class StreamReader : public brpc::StreamInputHandler {
kStreamFailed,
};

StreamReader() {
StreamReader(int32_t num_buf, size_t buf_len) {
SPU_ENFORCE(num_buf > 0);
SPU_ENFORCE(buf_len > 0);
buf_vec_.resize(num_buf);
buf_len_ = buf_len;
future_finished_ = promise_finished_.get_future();
future_closed_ = promise_closed_.get_future();
}

int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[],
size_t size) override {
SPDLOG_DEBUG("on_received_messages, stream id: {}", id);
if (status_ != Status::kNotFinished) {
SPDLOG_ERROR("unexpected messages received");
return -1;
}

for (size_t i = 0; i < size; ++i) {
if (status_ != Status::kNotFinished) {
SPDLOG_ERROR("unexpected messages received");
return -1;
}

SPDLOG_DEBUG("receive buf size: {}", messages[i]->size());
const auto& message = messages[i];
if (!buf_lens_.has_value()) {
beaver::ttp_server::BeaverDownStreamMeta meta{};
message->copy_to(&meta, sizeof(meta));
message->pop_front(sizeof(meta));
if (meta.err_code != 0) {
SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}",
meta.err_code, message->to_string());
status_ = Status::kAbnormalFinished;
promise_finished_.set_value(status_);
return -2;
}
SPU_ENFORCE(meta.total_buf_num > 0);
buf_.emplace_back();
buf_lens_.emplace(meta.total_buf_num);
size_t meta_bytes = meta.total_buf_num * sizeof(uint64_t);
SPU_ENFORCE(message->length() >= meta_bytes);
message->copy_to(buf_lens_.value().data(), meta_bytes);
message->pop_front(meta_bytes);
beaver::ttp_server::BeaverDownStreamMeta meta;
message->copy_to(&meta, sizeof(meta));
message->pop_front(sizeof(meta));
if (meta.err_code != 0) {
SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}",
meta.err_code, message->to_string());
status_ = Status::kAbnormalFinished;
promise_finished_.set_value(status_);
return -2;
}

size_t cur_buf_idx = buf_.size() - 1;
size_t cur_buf_size = buf_lens_.value().at(cur_buf_idx);
buf_.back().append(message->movable());
SPU_ENFORCE(buf_.back().length() <= cur_buf_size);
if (buf_.back().length() == cur_buf_size) {
if (cur_buf_idx == buf_lens_.value().size() - 1) {
status_ = Status::kNormalFinished;
promise_finished_.set_value(status_);
} else {
buf_.emplace_back();
}
SPU_ENFORCE(message->length() % buf_vec_.size() == 0);
size_t msg_len = message->length() / buf_vec_.size();
for (size_t buf_idx = 0; buf_idx < buf_vec_.size(); ++buf_idx) {
message->append_to(&buf_vec_[buf_idx], msg_len, buf_idx * msg_len);
}

SPU_ENFORCE(buf_vec_[0].length() <= buf_len_,
"unexpected bytes received");
if (buf_vec_[0].length() == buf_len_) {
status_ = Status::kNormalFinished;
promise_finished_.set_value(status_);
}
}
return 0;
Expand All @@ -169,23 +164,41 @@ class StreamReader : public brpc::StreamInputHandler {

const auto& GetBufVecRef() const {
SPU_ENFORCE(status_ == Status::kNormalFinished);
return buf_;
return buf_vec_;
}

Status WaitFinished() { return future_finished_.get(); };

void WaitClosed() { future_closed_.wait(); }

private:
std::vector<butil::IOBuf> buf_;
std::optional<std::vector<uint64_t>> buf_lens_;
std::vector<butil::IOBuf> buf_vec_;
size_t buf_len_;
Status status_ = Status::kNotFinished;
std::promise<Status> promise_finished_;
std::promise<void> promise_closed_;
std::future<Status> future_finished_;
std::future<void> future_closed_;
};

// Obtain a tuple containing num_buf and buf_len
template <class AdjustRequest>
std::tuple<int32_t, int64_t> GetBufferLength(const AdjustRequest& req) {
if constexpr (std::is_same_v<AdjustRequest,
beaver::ttp_server::AdjustDotRequest>) {
SPU_ENFORCE_EQ(req.prg_inputs().size(), 3);
return {1, req.prg_inputs()[2].buffer_len()};
} else if constexpr (std::is_same_v<
AdjustRequest,
beaver::ttp_server::AdjustTruncPrRequest>) {
SPU_ENFORCE_GE(req.prg_inputs().size(), 1);
return {2, req.prg_inputs()[0].buffer_len()};
} else {
SPU_ENFORCE_GE(req.prg_inputs().size(), 1);
return {1, req.prg_inputs()[0].buffer_len()};
}
}

template <class AdjustRequest>
std::vector<NdArrayRef> RpcCall(
brpc::Channel& channel, AdjustRequest req, FieldType ret_field,
Expand All @@ -194,9 +207,10 @@ std::vector<NdArrayRef> RpcCall(
beaver::ttp_server::BeaverService::Stub stub(&channel);
beaver::ttp_server::AdjustResponse rsp;

StreamReader reader;
auto [num_buf, buf_len] = GetBufferLength(req);
StreamReader reader(num_buf, buf_len);
brpc::StreamOptions stream_options;
stream_options.max_buf_size = 0;
stream_options.max_buf_size = 2 * beaver::ttp_server::kUpStreamChunkSize;
stream_options.handler = &reader;
brpc::StreamId stream_id;
SPU_ENFORCE_EQ(brpc::StreamCreate(&stream_id, cntl, &stream_options), 0,
Expand All @@ -206,14 +220,6 @@ std::vector<NdArrayRef> RpcCall(
reader.WaitClosed();
});

if (upstream_messages != nullptr) {
for (const auto& message : *upstream_messages) {
SPU_ENFORCE_EQ(brpc::StreamWrite(stream_id, message), 0);
SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(),
stream_id);
}
}

if constexpr (std::is_same_v<AdjustRequest,
beaver::ttp_server::AdjustMulRequest>) {
stub.AdjustMul(&cntl, &req, &rsp, nullptr);
Expand Down Expand Up @@ -255,6 +261,19 @@ std::vector<NdArrayRef> RpcCall(
"Adjust server failed code={}, error={}",
ErrorCode_Name(rsp.code()), rsp.message());

if (upstream_messages != nullptr) {
for (const auto& message : *upstream_messages) {
int ret = brpc::StreamWrite(stream_id, message);
if (ret == EAGAIN) {
SPU_ENFORCE_EQ(brpc::StreamWait(stream_id, nullptr), 0);
ret = brpc::StreamWrite(stream_id, message);
}
SPU_ENFORCE_EQ(ret, 0, "Write stream failed");
SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(),
stream_id);
}
}

auto status = reader.WaitFinished();
SPU_ENFORCE(status == StreamReader::Status::kNormalFinished,
"Stream reader finished abnormally, status: {}",
Expand Down Expand Up @@ -590,25 +609,20 @@ BeaverTtp::Pair BeaverTtp::PermPair(FieldType field, int64_t size,
if (lctx_->Rank() == perm_rank) {
auto req = BuildAdjustRequest<beaver::ttp_server::AdjustPermRequest>(
descs, descs_seed);
std::vector<butil::IOBuf> buf_vec;
beaver::ttp_server::BeaverPermUpStreamMeta meta{};
meta.total_buf_size = perm_vec.size() * sizeof(int64_t);
std::vector<butil::IOBuf> stream_data;
size_t left_buf_size = perm_vec.size() * sizeof(int64_t);
size_t chunk_idx = 0;
while (left_buf_size > 0) {
using beaver::ttp_server::kUpStreamChunkSize;
size_t cur_chunk_size = std::min(left_buf_size, kUpStreamChunkSize);
buf_vec.emplace_back();
if (chunk_idx == 0) {
buf_vec.back().append(&meta, sizeof(meta));
}
buf_vec.back().append(reinterpret_cast<const char*>(perm_vec.data()) +
(chunk_idx * kUpStreamChunkSize),
cur_chunk_size);
stream_data.emplace_back();
stream_data.back().append(reinterpret_cast<const char*>(perm_vec.data()) +
(chunk_idx * kUpStreamChunkSize),
cur_chunk_size);
++chunk_idx;
left_buf_size -= cur_chunk_size;
}
auto adjusts = RpcCall(channel_, req, field, &buf_vec);
auto adjusts = RpcCall(channel_, req, field, &stream_data);
SPU_ENFORCE_EQ(adjusts.size(), 1U);
ring_add_(b, adjusts[0].reshape(b.shape()));
}
Expand Down
Loading

0 comments on commit 184395a

Please sign in to comment.