Skip to content

Commit

Permalink
repo-sync-2024-10-15T16:59:25+0800 (#888)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
Jimmy MA authored Oct 16, 2024
1 parent a07303c commit 1c5d8e5
Show file tree
Hide file tree
Showing 45 changed files with 1,962 additions and 93 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
>
> please add your unreleased change here.
- [Improvement] Optimize exponential computation for semi2k (**experimental**)
- [Feature] Add more send/recv actions profiling

## 20240716
Expand Down
6 changes: 3 additions & 3 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def _libpsi():
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.3.dev240919.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.5.0.dev241016.tar.gz",
],
strip_prefix = "psi-0.4.3.dev240919",
sha256 = "1ee34fbbd9a8f36dea8f7c45588a858e8c31f3a38e60e1fc67cb428ea79334e3",
strip_prefix = "psi-0.5.0.dev241016",
sha256 = "1672e4284f819c40e34c65b0d5b1dfe4cc959b81d6f63daef7b39f7eb8d742e2",
)

def _rules_proto_grpc():
Expand Down
2 changes: 2 additions & 0 deletions libspu/compiler/front_end/hlo_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class CompilationContext;

class HloImporter final {
public:
// clang-format off
explicit HloImporter(CompilationContext *context) : context_(context) {};
// clang-format on

/// Load a xla module and returns a mlir-hlo module
mlir::OwningOpRef<mlir::ModuleOp>
Expand Down
11 changes: 11 additions & 0 deletions libspu/core/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ void populateRuntimeConfig(RuntimeConfig& cfg) {
if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) {
cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR);
}
if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) {
// 0 offset is not supported
if (cfg.experimental_exp_prime_offset() == 0) {
// For FM128 default offset is 13
if (cfg.field() == FieldType::FM128) {
cfg.set_experimental_exp_prime_offset(13);
}
// TODO: set defaults for other fields, currently only FM128 is
// supported
}
}

if (cfg.fxp_exp_iters() == 0) {
cfg.set_fxp_exp_iters(8);
Expand Down
61 changes: 60 additions & 1 deletion libspu/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ class Ring2k {
FieldType field() const { return field_; }
};

// This trait means the data is maintained in Galois prime field.
class Gfp {
protected:
uint128_t prime_{0};

public:
virtual ~Gfp() = default;
uint128_t p() const { return prime_; }
};

// The public interface.
//
// The value of this type is public visible for parties.
Expand Down Expand Up @@ -384,6 +394,54 @@ class RingTy : public TypeImpl<RingTy, TypeObject, Ring2k> {
}
};

// Galois field type of Mersenne primes, e.g., 2^127-1
class GfmpTy : public TypeImpl<GfmpTy, TypeObject, Gfp, Ring2k> {
using Base = TypeImpl<GfmpTy, TypeObject, Gfp, Ring2k>;

protected:
size_t mersenne_prime_exp_;

public:
using Base::Base;
explicit GfmpTy(FieldType field) {
field_ = field;
mersenne_prime_exp_ = GetMersennePrimeExp(field);
prime_ = (static_cast<uint128_t>(1) << mersenne_prime_exp_) - 1;
}

static std::string_view getStaticId() { return "Gfmp"; }

size_t size() const override {
if (field_ == FT_INVALID) {
return 0;
}
return SizeOf(GetStorageType(field_));
}

size_t mp_exp() const { return mersenne_prime_exp_; }

void fromString(std::string_view detail) override {
auto comma = detail.find_first_of(',');
auto field_str = detail.substr(0, comma);
auto mp_exp_str = detail.substr(comma + 1);
SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_),
"parse failed from={}", detail);
mersenne_prime_exp_ = std::stoul(std::string(mp_exp_str));
prime_ = (static_cast<uint128_t>(1) << mersenne_prime_exp_) - 1;
}

std::string toString() const override {
return fmt::format("{},{}", FieldType_Name(field()), mersenne_prime_exp_);
}

bool equals(TypeObject const* other) const override {
auto const* derived_other = dynamic_cast<GfmpTy const*>(other);
SPU_ENFORCE(derived_other);
return field() == derived_other->field() &&
mp_exp() == derived_other->mp_exp() && p() == derived_other->p();
}
};

class TypeContext final {
public:
using TypeCreateFn =
Expand All @@ -395,7 +453,8 @@ class TypeContext final {

public:
TypeContext() {
addTypes<VoidTy, PtTy, RingTy>(); // Base types that we need to register
addTypes<VoidTy, PtTy, RingTy,
GfmpTy>(); // Base types that we need to register
}

template <typename T>
Expand Down
20 changes: 20 additions & 0 deletions libspu/core/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,24 @@ TEST(TypeTest, RingTy) {
EXPECT_EQ(Type::fromString(fm128.toString()), fm128);
}

TEST(TypeTest, GfmpTy) {
Type gfmp31 = makeType<GfmpTy>(FM32);
EXPECT_EQ(gfmp31.size(), 4);
EXPECT_TRUE(gfmp31.isa<GfmpTy>());
EXPECT_EQ(gfmp31.toString(), "Gfmp<FM32,31>");
EXPECT_EQ(Type::fromString(gfmp31.toString()), gfmp31);

Type gfmp61 = makeType<GfmpTy>(FM64);
EXPECT_EQ(gfmp61.size(), 8);
EXPECT_TRUE(gfmp61.isa<GfmpTy>());
EXPECT_EQ(gfmp61.toString(), "Gfmp<FM64,61>");
EXPECT_EQ(Type::fromString(gfmp61.toString()), gfmp61);

Type gfmp127 = makeType<GfmpTy>(FM128);
EXPECT_EQ(gfmp127.size(), 16);
EXPECT_TRUE(gfmp127.isa<GfmpTy>());
EXPECT_EQ(gfmp127.toString(), "Gfmp<FM128,127>");
EXPECT_EQ(Type::fromString(gfmp127.toString()), gfmp127);
}

} // namespace spu
16 changes: 16 additions & 0 deletions libspu/core/type_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,22 @@ std::ostream& operator<<(std::ostream& os, ProtocolKind protocol) {
return os;
}

//////////////////////////////////////////////////////////////
// Field GFP mappings, currently only support Mersenne primes
//////////////////////////////////////////////////////////////
size_t GetMersennePrimeExp(FieldType field) {
#define CASE(Name, ScalarT, MersennePrimeExp) \
case FieldType::Name: \
return MersennePrimeExp; \
break;
switch (field) {
FIELD_TO_MERSENNE_PRIME_EXP_MAP(CASE)
default:
SPU_THROW("unknown supported field {}", field);
}
#undef CASE
}

//////////////////////////////////////////////////////////////
// Field 2k types, TODO(jint) support Zq
//////////////////////////////////////////////////////////////
Expand Down
27 changes: 26 additions & 1 deletion libspu/core/type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,17 @@ FOREACH_PT_TYPES(CASE)
std::ostream& operator<<(std::ostream& os, ProtocolKind protocol);

//////////////////////////////////////////////////////////////
// Field 2k types, TODO(jint) support Zq
// Field GFP mappings, currently only support Mersenne primes
//////////////////////////////////////////////////////////////
#define FIELD_TO_MERSENNE_PRIME_EXP_MAP(FN) \
FN(FM32, uint32_t, 31) \
FN(FM64, uint64_t, 61) \
FN(FM128, uint128_t, 127)

size_t GetMersennePrimeExp(FieldType field);

//////////////////////////////////////////////////////////////
// Field 2k types
//////////////////////////////////////////////////////////////
#define FIELD_TO_STORAGE_MAP(FN) \
FN(FM32, PT_U32) \
Expand Down Expand Up @@ -259,6 +269,21 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); }
} \
}()

//////////////////////////////////////////////////////////////
// Field Prime types
//////////////////////////////////////////////////////////////
template <typename T>
struct ScalarTypeToPrime {};

#define DEF_TRAITS(Field, ScalarT, Exp) \
template <> \
struct ScalarTypeToPrime<ScalarT> { \
static constexpr size_t exp = Exp; \
static constexpr ScalarT prime = (static_cast<ScalarT>(1) << Exp) - 1; \
};
FIELD_TO_MERSENNE_PRIME_EXP_MAP(DEF_TRAITS)
#undef DEF_TRAITS

//////////////////////////////////////////////////////////////
// Value range information, should it be here, at top level(jint)?
//////////////////////////////////////////////////////////////
Expand Down
1 change: 1 addition & 0 deletions libspu/kernel/hal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ spu_cc_test(
deps = [
":fxp_approx",
"//libspu/kernel:test_util",
"//libspu/mpc/utils:simulate",
],
)

Expand Down
36 changes: 35 additions & 1 deletion libspu/kernel/hal/fxp_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ Value exp_taylor(SPUContext* ctx, const Value& x) {
return res;
}

Value exp_prime(SPUContext* ctx, const Value& x) {
auto clamped_x = x;
auto offset = ctx->config().experimental_exp_prime_offset();
auto fxp = ctx->getFxpBits();
if (!ctx->config().experimental_exp_prime_disable_lower_bound()) {
// currently the bound is tied to FM128
SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128);
auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E;
clamped_x = _clamp_lower(ctx, clamped_x,
constant(ctx, lower_bound, x.dtype(), x.shape()))
.setDtype(x.dtype());
}
if (ctx->config().experimental_exp_prime_enable_upper_bound()) {
// currently the bound is tied to FM128
SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128);
auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E;
clamped_x = _clamp_upper(ctx, clamped_x,
constant(ctx, upper_bound, x.dtype(), x.shape()))
.setDtype(x.dtype());
}

auto ret = dynDispatch<spu::Value>(ctx, "exp_a", clamped_x);
return ret.setDtype(x.dtype());
}

namespace {

// Pade approximation of exp2(x), x is in [0, 1].
Expand Down Expand Up @@ -439,13 +464,22 @@ Value f_exp(SPUContext* ctx, const Value& x) {
case RuntimeConfig::EXP_PADE: {
// The valid input for exp_pade is [-kInputLimit, kInputLimit].
// TODO(junfeng): should merge clamp into exp_pade to save msb ops.
const float kInputLimit = 32 / std::log2(std::exp(1));
const float kInputLimit = 32.0 / std::log2(std::exp(1));
const auto clamped_x =
_clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()),
constant(ctx, kInputLimit, x.dtype(), x.shape()))
.setDtype(x.dtype());
return detail::exp_pade(ctx, clamped_x);
}
case RuntimeConfig::EXP_PRIME:
if (ctx->hasKernel("exp_a")) {
return detail::exp_prime(ctx, x);
} else {
SPU_THROW(
"exp_a is not implemented for this protocol, currently only "
"2pc "
"semi2k is supported.");
}
default:
SPU_THROW("unexpected exp approximation method {}",
ctx->config().fxp_exp_mode());
Expand Down
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ Value exp2_pade(SPUContext* ctx, const Value& x);
// Works for range [-12.0, 18.0]
Value exp_pade(SPUContext* ctx, const Value& x);

Value exp_prime(SPUContext* ctx, const Value& x);

Value tanh_chebyshev(SPUContext* ctx, const Value& x);

} // namespace detail
Expand Down
28 changes: 27 additions & 1 deletion libspu/kernel/hal/fxp_approx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/type_cast.h"
#include "libspu/kernel/test_util.h"
#include "libspu/mpc/utils/simulate.h"

namespace spu::kernel::hal {

Expand Down Expand Up @@ -78,10 +79,35 @@ TEST(FxpTest, ExponentialPade) {
<< y;
}

TEST(FxpTest, ExponentialPrime) {
std::cout << "test exp_prime" << std::endl;
spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
RuntimeConfig conf;
conf.set_protocol(ProtocolKind::SEMI2K);
conf.set_field(FieldType::FM128);
conf.set_fxp_fraction_bits(40);
conf.set_experimental_enable_exp_prime(true);
SPUContext ctx = test::makeSPUContext(conf, lctx);

auto offset = ctx.config().experimental_exp_prime_offset();
auto fxp = ctx.getFxpBits();
auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E;
auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E;

xt::xarray<float> x = xt::linspace<float>(lower_bound, upper_bound, 4000);

Value a = test::makeValue(&ctx, x, VIS_SECRET);
Value c = detail::exp_prime(&ctx, a);
auto y = dump_public_as<float>(&ctx, reveal(&ctx, c));
EXPECT_TRUE(xt::allclose(xt::exp(x), y, 0.01, 0.001))
<< xt::exp(x) << std::endl
<< y;
});
}

TEST(FxpTest, Log) {
// GIVEN
SPUContext ctx = test::makeSPUContext();

xt::xarray<float> x = {{0.05, 0.5}, {5, 50}};
// public log
{
Expand Down
15 changes: 13 additions & 2 deletions libspu/kernel/hal/ring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,25 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) {
Value _clamp(SPUContext* ctx, const Value& x, const Value& minv,
const Value& maxv) {
SPU_TRACE_HAL_LEAF(ctx, x, minv, maxv);

// clamp lower bound, res = x < minv ? minv : x
auto res = _mux(ctx, _less(ctx, x, minv), minv, x);

// clamp upper bound, res = res < maxv ? res, maxv
return _mux(ctx, _less(ctx, res, maxv), res, maxv);
}

// TODO: refactor polymorphic, and may use select functions in polymorphic
Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv) {
SPU_TRACE_HAL_LEAF(ctx, x, minv);
// clamp lower bound, res = x < minv ? minv : x
return _mux(ctx, _less(ctx, x, minv), minv, x);
}

Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv) {
SPU_TRACE_HAL_LEAF(ctx, x, maxv);
// clamp upper bound, x = x < maxv ? x, maxv
return _mux(ctx, _less(ctx, x, maxv), x, maxv);
}

Value _constant(SPUContext* ctx, uint128_t init, const Shape& shape) {
return _make_p(ctx, init, shape);
}
Expand Down
5 changes: 5 additions & 0 deletions libspu/kernel/hal/ring.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b);
// TODO: test me
Value _clamp(SPUContext* ctx, const Value& x, const Value& minv,
const Value& maxv);

Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv);

Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv);

// Make a public value from uint128_t init value.
//
// If the current working field has less than 128bit, the lower sizeof(field)
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/aby3/oram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,14 @@ void OramContext<T>::onehotB2A(KernelEvalContext *ctx, DpfGenCtrl ctrl) {
const std::vector<T> v = convert_help_v[dpf_idx];
std::for_each(e.begin(), e.end(), [&](T ele) { pm += ele; });
std::for_each(v.begin(), v.end(), [&](T ele) { F -= ele; });
auto blinded_pm = pm + r[0];
T blinded_pm = pm + r[0];

// open blinded_pm
comm->sendAsync<T>(dst_rank, {blinded_pm}, "open(blinded_pm)");
blinded_pm += comm->recv<T>(dst_rank, "open(blinded_pm)")[0];

auto pm_mul_F = mul2pc<T>(ctx, {pm}, {F}, static_cast<size_t>(ctrl));
auto blinded_F = pm_mul_F[0] + r[0];
T blinded_F = pm_mul_F[0] + r[0];

// open blinded_F
comm->sendAsync<T>(dst_rank, {blinded_F}, "open(blinded_F)");
Expand Down
Loading

0 comments on commit 1c5d8e5

Please sign in to comment.