diff --git a/.bazelrc b/.bazelrc
index 95d2c838..5eeb8ee9 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -29,6 +29,12 @@ build:avx --host_copt=-mavx
build:avx --copt=-DCHECK_AVX
build:avx --host_copt=-DCHECK_AVX
+# CUDA build is disabled by default
+build --@rules_cuda//cuda:enable=false
+
+# Enable it only when explicitly requested
+build:gpu --@rules_cuda//cuda:archs=compute_80:compute_80
+build:gpu --@rules_cuda//cuda:enable=true
# Binary safety flags
build --copt=-fPIC
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 20b70d5d..484a5394 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,12 @@
>
> please add your unreleased change here.
+- [Bugfix] Fix compatibility with latest JAX
+- [Feature] Improve memory efficiency during data encode/decode
+- [Feature] Add radix sort support for SEMI2K
+- [Feature] Experimental: ABY3 matmul CUDA support
+- [Feature] Experimental: Private support under colocated mode
+
## 20230906
- [SPU] 0.5.0 release
@@ -168,7 +174,7 @@
## 20220325
- [SPU] 0.0.5.1 release
-- [Bugfix] Fix SEMI2K divivsion wrong answer
+- [Bugfix] Fix SEMI2K division wrong answer
## 20220324
@@ -198,7 +204,7 @@
- [API] merge (config.proto, executable.proto, types.proto) into single spu.proto.
- [API] change RuntimeConfig.enable_protocol_trace to enable_action_trace.
- [API] change RuntimeConfig.fxp_recirptocal_goldschmdit_iters to fxp_reciprocal_goldschmdit_iters.
-- [API] add RuntimConfig.reveal_secret_condition to allow reveal secret control flow condition.
+- [API] add RuntimeConfig.reveal_secret_condition to allow reveal secret control flow condition.
- [Bugfix] Fixed SEGV when reconstruct from an ABY3 scalar secret
- [Feature] Left/right shift now properly supports non-scalar inputs
diff --git a/INSTALLATION.md b/INSTALLATION.md
index d871052f..7ed108c2 100644
--- a/INSTALLATION.md
+++ b/INSTALLATION.md
@@ -41,3 +41,9 @@ pip install spu
python setup.py bdist_wheel
pip install dist/*.whl --force-reinstall
```
+
+- Whenever the GCC/bazel/python/Xcode version or other environment settings change, run the following command to ensure a clean build:
+
+```bash
+bazel clean --expunge
+```
diff --git a/bazel/nvidia_cutlass.BUILD b/bazel/nvidia_cutlass.BUILD
new file mode 100644
index 00000000..ca46fb0e
--- /dev/null
+++ b/bazel/nvidia_cutlass.BUILD
@@ -0,0 +1,19 @@
+load("@spulib//bazel:spu.bzl", "spu_cc_library")
+
+package(default_visibility = ["//visibility:public"])
+
+filegroup(
+    name = "all",
+    srcs = glob(["**"]),
+)
+
+spu_cc_library(
+    name = "cutlass",
+    srcs = [],
+    hdrs = glob([
+        "include/**/*.h",
+        "include/**/*.hpp",
+    ]),
+    strip_include_prefix = "include",
+    visibility = ["//visibility:public"],
+)
diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl
index c5c2da36..236829db 100644
--- a/bazel/repositories.bzl
+++ b/bazel/repositories.bzl
@@ -42,6 +42,7 @@ def spu_deps():
    _com_github_microsoft_gsl()
    _com_github_microsoft_kuku()
    _com_google_flatbuffers()
+    _com_github_nvidia_cutlass()
    maybe(
        git_repository,
@@ -76,9 +77,9 @@ def _rules_proto_grpc():
def _rules_cuda():
    http_archive(
        name = "rules_cuda",
-        sha256 = "fa1462c4c3104de44489800a1da055f55afa57795789539c835e069818786f71",
-        strip_prefix = "rules_cuda-cab1fa2dd0e1f8489f566c91a5025856cf5ae572",
-        urls = ["https://github.com/bazel-contrib/rules_cuda/archive/cab1fa2dd0e1f8489f566c91a5025856cf5ae572.tar.gz"],
+        sha256 = "2f8c8c8c85f727bec4423efecec12d3b751cb0a98bda99f0f9d351608a23b858",
+        strip_prefix = "rules_cuda-v0.2.1",
+        urls = ["https://github.com/bazel-contrib/rules_cuda/releases/download/v0.2.1/rules_cuda-v0.2.1.tar.gz"],
    )
def _bazel_platform():
@@ -357,3 +358,15 @@ def _com_google_flatbuffers():
            "https://github.com/google/flatbuffers/archive/refs/tags/v23.3.3.tar.gz",
        ],
    )
+
+def _com_github_nvidia_cutlass():
+    maybe(
+        http_archive,
+        name = "com_github_nvidia_cutlass",
+        strip_prefix = "cutlass-3.2.0",
+        urls = [
+            "https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.2.0.tar.gz",
+        ],
+        sha256 = "9637961560a9d63a6bb3f407faf457c7dbc4246d3afb54ac7dc1e014dd7f172f",
+        build_file = "@spulib//bazel:nvidia_cutlass.BUILD",
+    )
diff --git a/examples/cpp/pir/generate_pir_data.cc b/examples/cpp/pir/generate_pir_data.cc
index a7ff9189..13b8c273 100644
--- a/examples/cpp/pir/generate_pir_data.cc
+++ b/examples/cpp/pir/generate_pir_data.cc
@@ -13,11 +13,8 @@
// limitations under the License.
// clang-format off
-// build generate_pir_data
-// > bazel build //examples/cpp/pir:generate_pir_data -c opt
-//
// To run the example, start two terminals:
-// > ./generate_pir_data -data_count 10000 -label_len 32 -server_out_path pir_server.csv -client_out_path pir_client.csv
+// > bazel run //examples/cpp/pir:generate_pir_data -c opt -- -data_count 10000 -label_len 32 -server_out_path pir_server.csv -client_out_path pir_client.csv
// clang-format on
#include
@@ -100,4 +97,4 @@ int main(int argc, char **argv) {
  psi2_out_file.close();
  return 0;
-}
\ No newline at end of file
+}
diff --git a/examples/cpp/pir/keyword_pir_client.cc b/examples/cpp/pir/keyword_pir_client.cc
index 2ed7dd3c..1a3bd547 100644
--- a/examples/cpp/pir/keyword_pir_client.cc
+++ b/examples/cpp/pir/keyword_pir_client.cc
@@ -13,11 +13,8 @@
// limitations under the License.
// clang-format off
-// build keyword_pir_client
-// > bazel build //examples/cpp/pir:keyword_pir_client -c opt
-//
// To run the example, start terminals:
-// > ./keyword_pir_client -rank 1 -in_path ../../data/psi_client_data.csv.csv
+// > bazel run //examples/cpp/pir:keyword_pir_client -c opt -- -rank 1 -in_path ../../data/psi_client_data.csv.csv
// >   -key_columns id -out_path pir_out.csv
// clang-format on
diff --git a/examples/cpp/pir/keyword_pir_mem_server.cc b/examples/cpp/pir/keyword_pir_mem_server.cc
index 23e2002c..5fc66917 100644
--- a/examples/cpp/pir/keyword_pir_mem_server.cc
+++ b/examples/cpp/pir/keyword_pir_mem_server.cc
@@ -13,11 +13,8 @@
// limitations under the License.
// clang-format off
-// build keyword_pir_server
-// > bazel build //examples/cpp/pir:keyword_pir_server -c opt
-//
// To run the example, start terminals:
-// > ./keyword_pir_server -rank 0 -setup_path pir_setup_dir
+// > bazel run //examples/cpp/pir:keyword_pir_mem_server -c opt -- -rank 0 -setup_path pir_setup_dir
// >   -oprf_key_path secret_key.bin
// clang-format on
diff --git a/examples/cpp/pir/keyword_pir_server.cc b/examples/cpp/pir/keyword_pir_server.cc
index fcc0e287..901ccced 100644
--- a/examples/cpp/pir/keyword_pir_server.cc
+++ b/examples/cpp/pir/keyword_pir_server.cc
@@ -13,11 +13,8 @@
// limitations under the License.
// clang-format off
-// build keyword_pir_server
-// > bazel build //examples/cpp/pir:keyword_pir_server -c opt
-//
// To run the example, start terminals:
-// > ./keyword_pir_server -rank 0 -setup_path pir_setup_dir
+// > bazel run //examples/cpp/pir:keyword_pir_server -c opt -- -rank 0 -setup_path pir_setup_dir
// >   -oprf_key_path secret_key.bin
// clang-format on
diff --git a/examples/cpp/pir/keyword_pir_setup.cc b/examples/cpp/pir/keyword_pir_setup.cc
index befed3dc..6c3d3042 100644
--- a/examples/cpp/pir/keyword_pir_setup.cc
+++ b/examples/cpp/pir/keyword_pir_setup.cc
@@ -13,14 +13,10 @@
// limitations under the License.
// clang-format off
-// build keyword_pir_setup
-// > bazel build //examples/cpp/pir:keyword_pir_setup -c opt
-//
// To generate ecc oprf secret key, start terminals:
// > dd if=/dev/urandom of=secret_key.bin bs=32 count=1
-//
// To run the example, start terminals:
-// > ./keyword_pir_setup -in_path ../../data/psi_server_data.csv -oprf_key_path secret_key.bin
+// > bazel run //examples/cpp/pir:keyword_pir_setup -c opt -- -in_path ../../data/psi_server_data.csv -oprf_key_path secret_key.bin
// >   -key_columns id -label_columns label -data_per_query 1 -label_max_len 40
// >   -setup_path pir_setup_dir
// clang-format on
diff --git a/examples/cpp/simple_dp_psi.cc b/examples/cpp/simple_dp_psi.cc
index 108cf17c..df575639 100644
--- a/examples/cpp/simple_dp_psi.cc
+++ b/examples/cpp/simple_dp_psi.cc
@@ -13,12 +13,10 @@
// limitations under the License.
// clang-format off
-// build simple_dp_psi
-// > bazel build //examples/cpp:simple_dp_psi -c opt
-//
// To run the example, start two terminals:
-// > ./simple_dp_psi -rank 0 -in_path examples/data/psi_1.csv -field_names id
-// > ./simple_dp_psi -rank 1 -in_path examples/data/psi_2.csv -field_names id -out_path /tmp/p2.out
+// > bazel run //examples/cpp:simple_dp_psi -c opt -- -rank 0 -in_path examples/data/psi_1.csv -field_names id
+// > bazel run //examples/cpp:simple_dp_psi -c opt -- -rank 1 -in_path examples/data/psi_2.csv -field_names id -out_path /tmp/p2.out
+// To run with a non-default IP config, add -parties IP:port,IP:port to the above commands
// clang-format on
#include
diff --git a/examples/cpp/simple_in_memory_psi.cc b/examples/cpp/simple_in_memory_psi.cc
index b240db9a..079b1963 100644
--- a/examples/cpp/simple_in_memory_psi.cc
+++ b/examples/cpp/simple_in_memory_psi.cc
@@ -14,8 +14,9 @@
// clang-format off
// To run the example, start two terminals:
-// > bazel run //examples/cpp:simple_in_memory_psi -- --rank=0
-// > bazel run //examples/cpp:simple_in_memory_psi -- --rank=1
+// > bazel run //examples/cpp:simple_in_memory_psi -c opt -- -rank=0
+// > bazel run //examples/cpp:simple_in_memory_psi -c opt -- -rank=1
+// To run with a non-default IP config, add -parties IP:port,IP:port to the above commands
// clang-format on
#include
diff --git a/examples/cpp/simple_lr.cc b/examples/cpp/simple_lr.cc
index 77a076d7..ba441b1b 100644
--- a/examples/cpp/simple_lr.cc
+++ b/examples/cpp/simple_lr.cc
@@ -14,8 +14,9 @@
// clang-format off
// To run the example, start two terminals:
-// > bazel run //examples/cpp:simple_lr -- --dataset=examples/data/perfect_logit_a.csv --has_label=true
-// > bazel run //examples/cpp:simple_lr -- --dataset=examples/data/perfect_logit_b.csv --rank=1
+// > bazel run //examples/cpp:simple_lr -c opt -- -dataset=examples/data/perfect_logit_a.csv -has_label=true
+// > bazel run //examples/cpp:simple_lr -c opt -- -dataset=examples/data/perfect_logit_b.csv -rank=1
+// To run with a non-default IP config, add -parties IP:port,IP:port to the above commands
// clang-format on
#include
diff --git a/examples/cpp/simple_pphlo.cc b/examples/cpp/simple_pphlo.cc
index b95a4c95..1cef41c6 100644
--- a/examples/cpp/simple_pphlo.cc
+++ b/examples/cpp/simple_pphlo.cc
@@ -14,8 +14,9 @@
// clang-format off
// To run the example, start two terminals:
-// > bazel run //examples/cpp:simple_pphlo -- --rank=0
-// > bazel run //examples/cpp:simple_pphlo -- --rank=1
+// > bazel run //examples/cpp:simple_pphlo -c opt -- -rank=0
+// > bazel run //examples/cpp:simple_pphlo -c opt -- -rank=1
+// To run with a non-default IP config, add -parties IP:port,IP:port to the above commands
// clang-format on
#include "examples/cpp/utils.h"
diff --git a/examples/cpp/simple_psi.cc b/examples/cpp/simple_psi.cc
index ec81c260..d3093394 100644
--- a/examples/cpp/simple_psi.cc
+++ b/examples/cpp/simple_psi.cc
@@ -14,8 +14,9 @@
// clang-format off
// To run the example, start two terminals:
-// > bazel run //examples/cpp:simple_psi -- -rank 0 -protocol 1 -in_path examples/data/psi_1.csv -field_names id -out_path /tmp/p1.out
-// > bazel run //examples/cpp:simple_psi -- -rank 1 -protocol 1 -in_path examples/data/psi_2.csv -field_names id -out_path /tmp/p2.out
+// > bazel run //examples/cpp:simple_psi -c opt -- -rank 0 -protocol 1 -in_path examples/data/psi_1.csv -field_names id -out_path /tmp/p1.out
+// > bazel run //examples/cpp:simple_psi -c opt -- -rank 1 -protocol 1 -in_path examples/data/psi_2.csv -field_names id -out_path /tmp/p2.out
+// To run with a non-default IP config, add -parties IP:port,IP:port to the above commands
// clang-format on
#include "absl/strings/str_split.h"
diff --git a/examples/python/conf/3pc_colocated.json b/examples/python/conf/3pc_colocated.json
new file mode 100644
index 00000000..1158bdb4
--- /dev/null
+++ b/examples/python/conf/3pc_colocated.json
@@ -0,0 +1,52 @@
+{
+    "id": "colocated.3pc",
+    "nodes": {
+        "node:0": "127.0.0.1:9920",
+        "node:1": "127.0.0.1:9921",
+        "node:2": "127.0.0.1:9922"
+    },
+    "devices": {
+        "SPU": {
+            "kind": "SPU",
+            "config": {
+                "node_ids": [
+                    "node:0",
+                    "node:1",
+                    "node:2"
+                ],
+                "spu_internal_addrs": [
+                    "127.0.0.1:9930",
+                    "127.0.0.1:9931",
+                    "127.0.0.1:9932"
+                ],
+                "experimental_data_folder": [
+                    "/tmp/spu_data_0/",
+                    "/tmp/spu_data_1/",
+                    "/tmp/spu_data_2/"
+                ],
+                "runtime_config": {
+                    "protocol": "ABY3",
+                    "field": "FM64",
+                    "enable_pphlo_profile": true,
+                    "enable_hal_profile": true,
+                    "experimental_disable_mmul_split": true,
+                    "enable_pphlo_trace": false,
+                    "fxp_exp_mode": 1,
+                    "experimental_enable_colocated_optimization": true
+                }
+            }
+        },
+        "P1": {
+            "kind": "PYU",
+            "config": {
+                "node_id": "node:1"
+            }
+        },
+        "P2": {
+            "kind": "PYU",
+            "config": {
+                "node_id": "node:2"
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/python/conf/BUILD.bazel b/examples/python/conf/BUILD.bazel
index 4a6713cf..2fc047bf 100644
--- a/examples/python/conf/BUILD.bazel
+++ b/examples/python/conf/BUILD.bazel
@@ -19,6 +19,7 @@ filegroup(
    srcs = [
        "2pc.json",
        "3pc.json",
+        "3pc_colocated.json",
        "ds_breast_cancer_basic.json",
        "ds_mock_regression_basic.json",
    ],
diff --git a/examples/python/ml/flax_llama7b/README.md b/examples/python/ml/flax_llama7b/README.md
index 546085ed..3acabedd 100644
--- a/examples/python/ml/flax_llama7b/README.md
+++ b/examples/python/ml/flax_llama7b/README.md
@@ -3,6 +3,8 @@
This example demonstrates how to use SPU to run secure inference on a pre-trained [Llama-7B](https://research.facebook.com/publications/llama-open-and-efficient-foundation-language-models/) model using [Puma](https://arxiv.org/abs/2307.12533).
+> **_NOTE:_** To run LLaMA-7B with ABY3, each node requires at least 1 TB of RAM.
+
1. Install huggingface transformers library
```sh
diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc
index 8a81ae9a..e5be8f8c 100644
--- a/libspu/compiler/core/core.cc
+++ b/libspu/compiler/core/core.cc
@@ -70,6 +70,8 @@ void Core::buildPipeline(mlir::PassManager *pm) {
  optPM.addPass(mlir::createCSEPass());
+  optPM.addPass(mlir::pphlo::createConvertPushDownPass());
+
  if (!options.disable_reduce_truncation_optimization()) {
    optPM.addPass(mlir::pphlo::createReduceTruncationPass());
  }
diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel
index 913431bd..2df8f650 100644
--- a/libspu/compiler/passes/BUILD.bazel
+++ b/libspu/compiler/passes/BUILD.bazel
@@ -258,10 +258,23 @@ spu_cc_library(
    ],
)
+spu_cc_library(
+    name = "convert_push_down",
+    srcs = ["convert_push_down.cc"],
+    hdrs = ["passes.h"],
+    deps = [
+        ":pass_details",
+        "//libspu/dialect:pphlo_dialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:TransformUtils",
+    ],
+)
+
spu_cc_library(
    name = "all_passes",
    hdrs = ["register_passes.h"],
    deps = [
+        ":convert_push_down",
        ":decompose_comparison",
        ":decompose_minmax",
        ":expand_secret_gather",
diff --git a/libspu/compiler/passes/convert_push_down.cc b/libspu/compiler/passes/convert_push_down.cc
new file mode 100644
index 00000000..01734d71
--- /dev/null
+++ b/libspu/compiler/passes/convert_push_down.cc
@@ -0,0 +1,91 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
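+
+// ConvertPushDown: sinks pphlo.convert below type-agnostic shape ops
+// (reshape/transpose and friends) so that later mixed-type mul/dot
+// optimizations can match across them; see the "Idea" sketch below.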
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "libspu/compiler/passes/pass_details.h"
+#include "libspu/compiler/passes/passes.h"
+#include "libspu/dialect/pphlo_ops.h"
+#include "libspu/dialect/pphlo_types.h"
+
+namespace mlir::pphlo {
+
+namespace {
+
+// Idea here:
+//   %2 = convert(%0)
+//   %3 = reshape(%2)
+//   mul(%1, %3)
+// Can be rewritten into
+//   %2 = reshape(%0)
+//   %3 = convert(%2)
+//   mul(%1, %3)
+// which makes the mixed_mul/dot optimization easier.
+template <typename OpT>
+struct TypeAgnosticOpConverter : public OpRewritePattern<OpT> {
+public:
+  explicit TypeAgnosticOpConverter(MLIRContext *context)
+      : OpRewritePattern<OpT>(context) {}
+
+  LogicalResult matchAndRewrite(OpT op,
+                                PatternRewriter &rewriter) const override {
+    auto operand = op.getOperand();
+    auto parentConvert = operand.template getDefiningOp<ConvertOp>();
+    if (parentConvert == nullptr) {
+      return failure();
+    }
+
+    const auto &from_type = parentConvert.getOperand()
+                                .getType()
+                                .template dyn_cast<RankedTensorType>();
+    const auto &to_type =
+        op.getResult().getType().template dyn_cast<RankedTensorType>();
+
+    OpBuilder builder(op);
+
+    auto new_reshape = builder.create<OpT>(
+        op->getLoc(),
+        RankedTensorType::get(to_type.getShape(), from_type.getElementType()),
+        parentConvert.getOperand(), op->getAttrs());
+
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getType(), new_reshape);
+
+    return success();
+  }
+};
+
+struct ConvertPushDown : public ConvertPushDownBase<ConvertPushDown> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateOwningPatterns(&patterns, &getContext());
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
+private:
+  static void populateOwningPatterns(RewritePatternSet *patterns,
+                                     MLIRContext *ctx) {
+    patterns->insert<TypeAgnosticOpConverter<ReshapeOp>,
+                     TypeAgnosticOpConverter<TransposeOp>,
+                     TypeAgnosticOpConverter<SliceOp>>(ctx);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertPushDownPass() {
+  return std::make_unique<ConvertPushDown>();
+}
+
+} // namespace mlir::pphlo
diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
index aada80e1..6d996e71 100644
--- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
+++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc
@@ -54,10 +54,21 @@ VisibilityDiscovery(const llvm::ArrayRef<std::string> input_vis_list,
  for (const auto &blockargs : entry_func.getBody().getArguments()) {
    SPU_ENFORCE(blockargs.getArgNumber() < input_vis_list.size(),
                "Input visibility list does not match actual inputs.");
-    auto v =
-        symbolizeEnum<Visibility>(input_vis_list[blockargs.getArgNumber()]);
-    SPU_ENFORCE(v.has_value(), "Input visibility list has invalid value.");
-    vis_map.setValueVisibility(blockargs, *v);
+    Visibility v;
+
+    // There is no compile-time private support at this moment.
+    // Force compiler to treat private as secret for now.
+    if (input_vis_list[blockargs.getArgNumber()] == "VIS_PRIVATE") {
+      v = Visibility::VIS_SECRET;
+    } else {
+      auto v_optional =
+          symbolizeEnum<Visibility>(input_vis_list[blockargs.getArgNumber()]);
+      SPU_ENFORCE(v_optional.has_value(),
+                  "Input visibility list has invalid value. value = {}",
+                  input_vis_list[blockargs.getArgNumber()]);
+      v = *v_optional;
+    }
+    vis_map.setValueVisibility(blockargs, v);
  }
  VisibilityInference inference(vis_map);
diff --git a/libspu/compiler/passes/optimize_select.cc b/libspu/compiler/passes/optimize_select.cc
index 44765619..054f8b46 100644
--- a/libspu/compiler/passes/optimize_select.cc
+++ b/libspu/compiler/passes/optimize_select.cc
@@ -24,6 +24,19 @@ namespace mlir::pphlo {
namespace {
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(DenseElementsAttr val) {
+  auto type = val.getElementType();
+  if (llvm::isa<FloatType>(type)) {
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+  }
+  if (llvm::isa<IntegerType>(type)) {
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  }
+  return false;
+}
+
+// Pattern 1
// Idea here:
//   select(p, x, y)
// into
//   select(prefer_a(p), x, y)
// Rationale:
// If the predicate is used by multiple selects, explicitly doing a to_a op can
// reduce the cost of to_a
+
+// Pattern 2
+// Idea here:
+//   select(pred, x, const_0)
+// into
+//   mul(pred, x)
+// Rationale:
+// This is a pattern created by the XLA algebraic simplifier.
struct SelectConversion : public OpRewritePattern<SelectOp> {
public:
  explicit SelectConversion(MLIRContext *context)
      : OpRewritePattern<SelectOp>(context) {}
-  LogicalResult matchAndRewrite(SelectOp op, PatternRewriter &) const override {
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Pattern 2 first:
+    auto on_false = op.getOnFalse();
+    if (auto on_false_const = on_false.getDefiningOp<ConstantOp>()) {
+      auto dea = on_false_const.getValue().dyn_cast<DenseElementsAttr>();
+      if (isSplatZero(dea)) {
+        rewriter.replaceOpWithNewOp<MulOp>(op, op->getResultTypes(),
+                                           op.getPred(), op.getOnTrue());
+        return success();
+      }
+    }
+
+    // Pattern 1:
    auto pred = op.getPred();
    // Only do this for certain select...
    if (pred.getDefiningOp() != nullptr) {
diff --git a/libspu/compiler/passes/passes.h b/libspu/compiler/passes/passes.h
index 6fc54988..fbda4cfb 100644
--- a/libspu/compiler/passes/passes.h
+++ b/libspu/compiler/passes/passes.h
@@ -75,6 +75,8 @@
std::unique_ptr<OperationPass<func::FuncOp>> createInsertDeallocationOp();
std::unique_ptr<OperationPass<func::FuncOp>> createSortLowering();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertPushDownPass();
+
} // namespace pphlo
} // namespace mlir
diff --git a/libspu/compiler/passes/passes.td b/libspu/compiler/passes/passes.td
index ae045690..281649d4 100644
--- a/libspu/compiler/passes/passes.td
+++ b/libspu/compiler/passes/passes.td
@@ -104,3 +104,9 @@ def SortLowering: Pass<"sort-lowering", "func::FuncOp"> {
  let constructor = "createSortLowering()";
  let dependentDialects = ["pphlo::PPHloDialect"];
}
+
+def ConvertPushDown: Pass<"convert-push-down", "func::FuncOp"> {
+  let summary = "Push convert later";
+  let constructor = "createConvertPushDownPass()";
+  let dependentDialects = ["pphlo::PPHloDialect"];
+}
diff --git a/libspu/compiler/tests/convert_push_down.mlir b/libspu/compiler/tests/convert_push_down.mlir
new file mode 100644
index 00000000..d4dae75c
--- /dev/null
+++ b/libspu/compiler/tests/convert_push_down.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-pphlo-opt --convert-push-down --cse --split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: tensor<4x!pphlo.pub<i32>>, %arg1: tensor<2x2x!pphlo.pub<f32>>) -> (tensor<2x2x!pphlo.pub<f32>>) {
+  // CHECK: %0 = "pphlo.reshape"(%arg0) : (tensor<4x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<i32>>
+  // CHECK: %1 = "pphlo.convert"(%0) : (tensor<2x2x!pphlo.pub<i32>>) -> tensor<2x2x!pphlo.pub<f32>>
+  %0 = "pphlo.convert"(%arg0) : (tensor<4x!pphlo.pub<i32>>) -> tensor<4x!pphlo.pub<f32>>
+  %1 = "pphlo.reshape"(%0) : (tensor<4x!pphlo.pub<f32>>) -> tensor<2x2x!pphlo.pub<f32>>
+  %2 = "pphlo.multiply"(%1, %arg1) : (tensor<2x2x!pphlo.pub<f32>>, tensor<2x2x!pphlo.pub<f32>>) -> tensor<2x2x!pphlo.pub<f32>>
+  return %2 : tensor<2x2x!pphlo.pub<f32>>
+}
+
+// -----
+
+func.func @main(%arg0: tensor<2x3x!pphlo.pub<i32>>, %arg1: tensor<2x3x!pphlo.pub<f32>>) -> (tensor<3x3x!pphlo.pub<f32>>) {
+  // CHECK: %0 = "pphlo.transpose"(%arg0)
+  // CHECK: %1 = "pphlo.convert"(%0)
+  %0 = "pphlo.convert"(%arg0) : (tensor<2x3x!pphlo.pub<i32>>) -> tensor<2x3x!pphlo.pub<f32>>
+  %1 = "pphlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3x!pphlo.pub<f32>>) -> tensor<3x2x!pphlo.pub<f32>>
+  %2 = "pphlo.dot"(%1, %arg1) : (tensor<3x2x!pphlo.pub<f32>>, tensor<2x3x!pphlo.pub<f32>>) -> tensor<3x3x!pphlo.pub<f32>>
+  return %2 : tensor<3x3x!pphlo.pub<f32>>
+}
+
+// -----
diff --git a/libspu/compiler/tests/enum_conversion_test.cc b/libspu/compiler/tests/enum_conversion_test.cc
index 3d98fde3..889110a5 100644
--- a/libspu/compiler/tests/enum_conversion_test.cc
+++ b/libspu/compiler/tests/enum_conversion_test.cc
@@ -42,6 +42,8 @@ TEST(EnumConversion, ProtoKinds) {
      break;
    case Visibility::VIS_PUBLIC:
      break;
+    case Visibility::VIS_PRIVATE:
+      break;
    case Visibility::VIS_INVALID:
      break;
    case Visibility::Visibility_INT_MAX_SENTINEL_DO_NOT_USE_:
diff --git a/libspu/compiler/tests/optimize_select.mlir b/libspu/compiler/tests/optimize_select.mlir
index d8b430a3..4572f295 100644
--- a/libspu/compiler/tests/optimize_select.mlir
+++ b/libspu/compiler/tests/optimize_select.mlir
@@ -2,7 +2,7 @@
func.func @main() -> (tensor<!pphlo.pub<f32>>) {
    %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
-   %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
+   %1 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
    %2 = "pphlo.less"(%0, %1): (tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<i1>>
    //CHECK-NOT: pphlo.prefer_a
    %3 = "pphlo.select"(%2, %0, %1): (tensor<!pphlo.pub<i1>>, tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<f32>>
@@ -13,7 +13,7 @@
func.func @main() -> (tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) {
    %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
-   %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
+   %1 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
    %2 = "pphlo.less"(%0, %1): (tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<i1>>
    //CHECK: pphlo.prefer_a
    %3 = "pphlo.select"(%2, %0, %1): (tensor<!pphlo.pub<i1>>, tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<f32>>
@@ -25,7 +25,7 @@
func.func @main() -> (tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) {
    %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
-   %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
+   %1 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
    %2 = "pphlo.less"(%0, %1): (tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<i1>>
    //CHECK-NOT: pphlo.prefer_a
    %3 = "pphlo.select"(%2, %0, %1): (tensor<!pphlo.pub<i1>>, tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.pub<f32>>
@@ -37,9 +37,19 @@
func.func @main(%arg0: tensor<!pphlo.sec<i1>>) -> (tensor<!pphlo.sec<f32>>, tensor<!pphlo.sec<f32>>) {
    %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
-   %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
+   %1 = "pphlo.constant"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
    //CHECK: pphlo.prefer_a
    %2 = "pphlo.select"(%arg0, %0, %1): (tensor<!pphlo.sec<i1>>, tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.sec<f32>>
    %3 = "pphlo.select"(%arg0, %1, %0): (tensor<!pphlo.sec<i1>>, tensor<!pphlo.pub<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.sec<f32>>
    return %2, %3: tensor<!pphlo.sec<f32>>, tensor<!pphlo.sec<f32>>
}
+
+
+// -----
+
+func.func @main(%arg0: tensor<!pphlo.sec<i1>>, %arg1: tensor<!pphlo.sec<f32>>) -> (tensor<!pphlo.sec<f32>>) {
+    %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<!pphlo.pub<f32>>
+    //CHECK: pphlo.multiply
+    %2 = "pphlo.select"(%arg0, %arg1, %1): (tensor<!pphlo.sec<i1>>, tensor<!pphlo.sec<f32>>, tensor<!pphlo.pub<f32>>) -> tensor<!pphlo.sec<f32>>
+    return %2: tensor<!pphlo.sec<f32>>
+}
diff --git a/libspu/core/context.h b/libspu/core/context.h
index c27d23c4..1f0478e2 100644
--- a/libspu/core/context.h
+++ b/libspu/core/context.h
@@ -198,10 +198,6 @@
template <typename T>
using OptionalAPI = std::optional<T>;
inline constexpr std::nullopt_t NotAvailable = std::nullopt;
-// TODO: currently unstable, statically config it.
-// When it's stable move it to RuntimeConfig or even enable it by default.
-// #define SPU_ENABLE_PRIVATE_TYPE
-
void setupTrace(spu::SPUContext* sctx, const spu::RuntimeConfig& rt_config);
} // namespace spu
diff --git a/libspu/core/shape.h b/libspu/core/shape.h
index b78eb559..439f25bc 100644
--- a/libspu/core/shape.h
+++ b/libspu/core/shape.h
@@ -38,6 +38,8 @@ class Shape : public std::vector<int64_t> {
  /*explicit*/ Shape(llvm::ArrayRef<int64_t> arr) : Base(arr.begin(), arr.end()) {}
+  Shape(std::initializer_list<int64_t> list) : Base(list) {}
+
  template <size_t S>
  Shape(std::array<int64_t, S> arr) : Base(arr.begin(), arr.end()) {}
@@ -100,6 +102,8 @@ class Strides : public std::vector<int64_t> {
  /*explicit*/ Strides(llvm::ArrayRef<int64_t> arr) : Base(arr.begin(), arr.end()) {}
+  Strides(std::initializer_list<int64_t> list) : Base(list) {}
+
  friend std::ostream &operator<<(std::ostream &out, const Strides &s) {
    out << fmt::format("{}", fmt::join(s, "x"));
    return out;
  }
diff --git a/libspu/core/type_util.cc b/libspu/core/type_util.cc
index 896de751..8261e03f 100644
--- a/libspu/core/type_util.cc
+++ b/libspu/core/type_util.cc
@@ -27,6 +27,9 @@ std::ostream& operator<<(std::ostream& os, const Visibility& vtype) {
    case VIS_SECRET:
      os << "S";
      break;
+    case VIS_PRIVATE:
+      os << "V";
+      break;
    default:
      os << "Invalid";
  }
diff --git a/libspu/core/value.cc b/libspu/core/value.cc
index 7c30372f..360ba623 100644
--- a/libspu/core/value.cc
+++ b/libspu/core/value.cc
@@ -30,6 +30,8 @@ Visibility getVisibilityFromType(const Type& ty) {
    return VIS_SECRET;
  } else if (ty.isa<Public>()) {
    return VIS_PUBLIC;
+  } else if (ty.isa<Private>()) {
+    return VIS_PRIVATE;
  } else {
    return VIS_INVALID;
  }
@@ -196,8 +198,14 @@ Value Value::clone() const {
}
std::ostream& operator<<(std::ostream& out, const Value& v) {
-  out << fmt::format("Value<{}x{}{},s={}>", fmt::join(v.shape(), "x"),
-                     v.vtype(), v.dtype(), fmt::join(v.strides(), ","));
+  if (v.isPrivate()) {
+    out << fmt::format("Value<{}x{}{},s={},o={}>", fmt::join(v.shape(), "x"),
+                       v.vtype(), v.dtype(), fmt::join(v.strides(), ","),
+                       v.storage_type().as<Private>()->owner());
+  } else {
+    out << fmt::format("Value<{}x{}{},s={}>", fmt::join(v.shape(), "x"),
+                       v.vtype(), v.dtype(), fmt::join(v.strides(), ","));
+  }
  return out;
}
diff --git a/libspu/core/value.h b/libspu/core/value.h
index 759e393b..90043164 100644
--- a/libspu/core/value.h
+++ b/libspu/core/value.h
@@ -66,6 +66,7 @@ class Value final {
  Visibility vtype() const;
  bool isPublic() const { return vtype() == VIS_PUBLIC; }
  bool isSecret() const { return vtype() == VIS_SECRET; }
+  bool isPrivate() const { return vtype() == VIS_PRIVATE; }
  // Get dtype.
  DataType dtype() const { return dtype_; }
diff --git a/libspu/cuda_support/BUILD.bazel b/libspu/cuda_support/BUILD.bazel
new file mode 100644
index 00000000..7e413426
--- /dev/null
+++ b/libspu/cuda_support/BUILD.bazel
@@ -0,0 +1,73 @@
+# Copyright 2023 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
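+
+# CUDA kernels backing the experimental GPU code paths. Every target below is
+# tagged "manual", so a plain `bazel build //...` never touches them; enable
+# CUDA explicitly via the config defined in .bazelrc, e.g.:
+#
+#   bazel test --config=gpu //libspu/cuda_support:kernels_test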
+
+load("@rules_cuda//cuda:defs.bzl", "cuda_library")
+load("//bazel:spu.bzl", "spu_cc_test")
+
+package(default_visibility = ["//visibility:public"])
+
+cuda_library(
+    name = "kernels",
+    srcs = ["kernels.cu"],
+    hdrs = ["kernels.h"],
+    tags = [
+        "manual",  # Exclude this target from :all expansion
+    ],
+    deps = [
+        "@com_github_nvidia_cutlass//:cutlass",
+    ],
+)
+
+cuda_library(
+    name = "utils",
+    srcs = ["utils.cu"],
+    hdrs = ["utils.h"],
+    tags = [
+        "manual",  # Exclude this target from :all expansion
+    ],
+    deps = [
+        "//libspu/core:shape",
+    ],
+)
+
+spu_cc_test(
+    name = "utils_test",
+    srcs = ["utils_test.cc"],
+    tags = [
+        "manual",  # Exclude this target from :all expansion
+    ],
+    target_compatible_with = select({
+        "@rules_cuda//cuda:is_enabled": [],
+        "//conditions:default": ["@platforms//:incompatible"],
+    }),
+    deps = [
+        ":utils",
+    ],
+)
+
+spu_cc_test(
+    name = "kernels_test",
+    srcs = ["kernels_test.cc"],
+    tags = [
+        "manual",  # Exclude this target from :all expansion
+    ],
+    target_compatible_with = select({
+        "@rules_cuda//cuda:is_enabled": [],
+        "//conditions:default": ["@platforms//:incompatible"],
+    }),
+    deps = [
+        ":kernels",
+        ":utils",
+    ],
+)
diff --git a/libspu/cuda_support/kernels.cu b/libspu/cuda_support/kernels.cu
new file mode 100644
index 00000000..23db9f70
--- /dev/null
+++ b/libspu/cuda_support/kernels.cu
@@ -0,0 +1,67 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+
+#include "libspu/cuda_support/kernels.h"
+
+#define KERNEL_CALL(funcname, n) funcname<<<((n) + 255) / 256, 256>>>
+#define GLOBAL_INDEX (blockDim.x * blockIdx.x + threadIdx.x)
+
+namespace spu::cuda {
+
+namespace {
+
+// NN means neither A nor B is transposed, i.e. C = A * B.
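+// The GEMM is instantiated for uint64_t on SIMT cores (OpClassSimt, no
+// tensor cores), so products and sums wrap around natively, which is
+// exactly the mod-2^64 ring arithmetic the FM64 backend needs.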
+cudaError_t cutlassMatmul_u64(int64_t M, int64_t N, int64_t K, uint64_t alpha,
+                              uint64_t const* A, int64_t lda, uint64_t const* B,
+                              int64_t ldb, uint64_t beta, uint64_t* C,
+                              int64_t ldc) {
+  using CutlassGemm = cutlass::gemm::device::Gemm<
+      uint64_t, cutlass::layout::RowMajor, uint64_t, cutlass::layout::RowMajor,
+      uint64_t, cutlass::layout::RowMajor, uint64_t, cutlass::arch::OpClassSimt,
+      cutlass::arch::Sm80>;
+  CutlassGemm gemm_operator;
+  CutlassGemm::Arguments args({(int)M, (int)N, (int)K}, {A, lda}, {B, ldb},
+                              {C, ldc}, {C, ldc}, {alpha, beta});
+  cutlass::Status status = gemm_operator(args);
+  if (status != cutlass::Status::kSuccess) {
+    return cudaErrorUnknown;
+  }
+  return cudaSuccess;
+}
+
+__global__ void deviceArrayAddInplaceKernel(uint64_t* A, const uint64_t* B,
+                                            int64_t numel) {
+  size_t idx = GLOBAL_INDEX;
+  if (idx < numel) {
+    A[idx] += B[idx];
+  }
+}
+
+}  // namespace
+
+// The matrices are in row-major format and are expected to already reside in
+// GPU memory; the multiplication itself runs entirely on the device.
+void matmul(int64_t M, int64_t N, int64_t K, const uint64_t* A, uint64_t* B,
+            uint64_t* C) {
+  cutlassMatmul_u64(M, N, K, 1, A, K, B, N, 0, C, N);
+}
+
+void add(uint64_t* A, const uint64_t* B, int64_t numel) {
+  KERNEL_CALL(deviceArrayAddInplaceKernel, numel)(A, B, numel);
+}
+
+}  // namespace spu::cuda
diff --git a/libspu/cuda_support/kernels.h b/libspu/cuda_support/kernels.h
new file mode 100644
index 00000000..9746d770
--- /dev/null
+++ b/libspu/cuda_support/kernels.h
@@ -0,0 +1,27 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+
+namespace spu::cuda {
+
+// uint64 implementation
+void matmul(int64_t M, int64_t N, int64_t K, const uint64_t* A, uint64_t* B,
+            uint64_t* C);
+
+void add(uint64_t* A, const uint64_t* B, int64_t numel);
+
+}  // namespace spu::cuda
diff --git a/libspu/cuda_support/kernels_test.cc b/libspu/cuda_support/kernels_test.cc
new file mode 100644
index 00000000..14234dd6
--- /dev/null
+++ b/libspu/cuda_support/kernels_test.cc
@@ -0,0 +1,78 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
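+
+// Sanity tests for the raw device kernels. They require a CUDA-capable GPU
+// and are only reachable through the "manual", CUDA-enabled build described
+// in the BUILD file.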
+
+#include "libspu/cuda_support/kernels.h"
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "libspu/cuda_support/utils.h"
+
+namespace spu::cuda {
+
+template <typename T>
+std::shared_ptr<std::byte> GenerateGPUData(int64_t numel, T start = 0) {
+  auto gpu_ptr = allocate(numel * sizeof(T));
+
+  std::vector<T> data(numel);
+  std::iota(data.begin(), data.end(), start);
+
+  CopyToCudaDevice((std::byte*)data.data(), numel * sizeof(T), {numel}, {1},
+                   gpu_ptr.get(), sizeof(T));
+
+  return gpu_ptr;
+}
+
+template <typename T>
+std::vector<T> getGPUData(std::byte* gpu_ptr, int64_t numel) {
+  std::vector<T> cpu(numel);
+  CopyFromCudaDevice(gpu_ptr, (std::byte*)cpu.data(), numel, sizeof(T), 1);
+  return cpu;
+}
+
+TEST(CudaKernels, BasicAdd) {
+  auto x = GenerateGPUData<uint64_t>(10);
+  auto y = GenerateGPUData<uint64_t>(10, 10);
+
+  add(reinterpret_cast<uint64_t*>(x.get()),
+      reinterpret_cast<uint64_t*>(y.get()), 10);
+
+  auto ret = getGPUData<uint64_t>(x.get(), 10);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(ret[idx], idx + 10 + idx);
+  }
+}
+
+TEST(CudaKernels, Matmul) {
+  auto x = GenerateGPUData<uint64_t>(12, 1);  // 4x3
+  auto y = GenerateGPUData<uint64_t>(6, 1);   // 3x2
+  auto c = allocate(8 * sizeof(uint64_t));
+
+  matmul(4, 2, 3, reinterpret_cast<uint64_t*>(x.get()),
+         reinterpret_cast<uint64_t*>(y.get()),
+         reinterpret_cast<uint64_t*>(c.get()));
+
+  auto ret = getGPUData<uint64_t>(c.get(), 8);
+
+  std::vector<uint64_t> expected = {22, 28, 49, 64, 76, 100, 103, 136};
+  for (size_t idx = 0; idx < 8; ++idx) {
+    EXPECT_EQ(ret[idx], expected[idx]);
+  }
+}
+
+}  // namespace spu::cuda
diff --git a/libspu/cuda_support/utils.cu b/libspu/cuda_support/utils.cu
new file mode 100644
index 00000000..48e5ba73
--- /dev/null
+++ b/libspu/cuda_support/utils.cu
@@ -0,0 +1,234 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
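+
+// Host <-> device copy helpers. A strided host tensor cannot be moved with a
+// single cudaMemcpy, so it is staged into a temporary device buffer and then
+// compacted (or deinterleaved into two share buffers) by a small kernel.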
+
+#include <algorithm>
+#include <cstdio>
+#include <mutex>
+#include <string_view>
+
+#include "libspu/cuda_support/utils.h"
+
+namespace spu::cuda {
+
+namespace kernels {
+
+__global__ void printGPUData_(const int* gpu_ptr, size_t numel) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  printf("print idx = %d", idx);
+  for (size_t i = 0; i < numel; ++i) {
+    printf("%d ", gpu_ptr[i]);
+  }
+  printf("\n");
+}
+
+__global__ void compactGPUMemory(std::byte* strided, std::byte* compact,
+                                 int64_t col, int64_t s_row, int64_t s_col,
+                                 int64_t elsize, int64_t numel) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (idx >= numel) {
+    return;
+  }
+
+  auto r = idx / col;
+  auto c = idx % col;
+  auto s_idx = r * s_row + c * s_col;
+  s_idx *= elsize;
+  for (int offset = 0; offset < elsize; ++offset) {
+    compact[idx * elsize + offset] = strided[s_idx + offset];
+  }
+}
+
+__global__ void deinterleaveCopy(std::byte* strided_interleave,
+                                 std::byte* compact_p1, std::byte* compact_p2,
+                                 int64_t col, int64_t s_row, int64_t s_col,
+                                 int64_t elsize, int numel) {
+  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (idx >= numel) {
+    return;
+  }
+
+  auto r = idx / col;
+  auto c = idx % col;
+
+  auto s_idx = r * s_row + c * s_col;
+  for (int offset = 0; offset < elsize; ++offset) {
+    compact_p1[idx * elsize + offset] = strided_interleave[s_idx + offset];
+    compact_p2[idx * elsize + offset] =
+        strided_interleave[s_idx + offset + elsize];
+  }
+}
+
+}  // namespace kernels
+
+void checkGPUStatus(cudaError_t status, std::string_view msg) {
+  if (status != cudaSuccess) {
+    printf("%s %s", msg.data(), cudaGetErrorString(status));
+  }
+}
+
+void printGPUData(const int* ptr, size_t numel) {
+  dim3 block(1);
+  dim3 grid(1);
+  kernels::printGPUData_<<<grid, block>>>(ptr, numel);
+}
+
+std::shared_ptr<std::byte> allocate(size_t bytes) {
+  std::byte* ptr = nullptr;
+  cudaError_t error = cudaMalloc(&ptr, bytes);
+  checkGPUStatus(error, "Failed to allocate GPU memory:");
+  return std::shared_ptr<std::byte>(ptr, deallocate);
+}
+
+void deallocate(std::byte* ptr) noexcept {
+  auto error = cudaFree(ptr);
+  checkGPUStatus(error, "Failed to free GPU memory:");
+}
+
+void CopyToCudaDevice(const std::byte* src, size_t buf_size,
+                      const Shape& shape, const Strides& strides,
+                      std::byte* dst, size_t elsize) {
+#ifdef LOG_GPU_DATA_COPY
+  cudaEvent_t start;
+  cudaEvent_t stop;
+  cudaEventCreate(&start);
+  cudaEventCreate(&stop);
+  cudaEventRecord(start);
+#endif
+
+  auto numel = shape.numel();
+  cudaError_t error;
+
+  if (numel * elsize == buf_size) {
+    // Straight copy
+    error = cudaMemcpy(dst, src, buf_size, cudaMemcpyHostToDevice);
+  } else {
+    std::byte* tmp = nullptr;
+    error = cudaMalloc(&tmp, buf_size);
+    error = cudaMemcpy(tmp, src, buf_size, cudaMemcpyHostToDevice);
+
+    // we allow max 1024 threads per block, and then scale out the copy across
+    // multiple blocks
+    dim3 block(std::min<int64_t>(numel, 1024));
+    dim3 grid(numel / block.x + (numel % block.x == 0 ? 0 : 1));
+    if (shape.ndim() == 1) {
+      kernels::compactGPUMemory<<<grid, block>>>(tmp, dst, shape[0], 0,
+                                                 strides[0], elsize, numel);
+    } else {
+      kernels::compactGPUMemory<<<grid, block>>>(tmp, dst, shape[1], strides[0],
+                                                 strides[1], elsize, numel);
+    }
+    error = cudaFree(tmp);
+  }
+
+  checkGPUStatus(error, "Failed to copy to GPU: ");
+
+#ifdef LOG_GPU_DATA_COPY
+  cudaEventRecord(stop);
+  cudaEventSynchronize(stop);
+  float milliseconds = 0;
+  cudaEventElapsedTime(&milliseconds, start, stop);
+
+  printf("Copy %ld bytes from %p takes %fms\n", (long)shape.numel() * elsize,
+         src, milliseconds);
+#endif
+}
+
+void DeinterleaveCopyToCudaDevice(const std::byte* src, size_t buf_size,
+                                  const Shape& shape, const Strides& strides,
+                                  std::byte* dst0, std::byte* dst1,
+                                  size_t elsize) {
+#ifdef LOG_GPU_DATA_COPY
+  cudaEvent_t start;
+  cudaEvent_t stop;
+  cudaEventCreate(&start);
+  cudaEventCreate(&stop);
+  cudaEventRecord(start);
+#endif
+
+  auto numel = shape.numel();
+  cudaError_t error;
+
+  // First copy to a strided gpu buffer
+  std::byte* s_gpu_buffer = nullptr;
+  error = cudaMalloc(&s_gpu_buffer, buf_size);
+  error = cudaMemcpy(s_gpu_buffer, src, buf_size, cudaMemcpyHostToDevice);
+
+  // Deinterleave
+  dim3 block(std::min<int64_t>(numel, 1024));
+  dim3 grid(numel / block.x + (numel % block.x == 0 ? 0 : 1));
+
+  if (shape.ndim() == 1) {
+    kernels::deinterleaveCopy<<<grid, block>>>(s_gpu_buffer, dst0, dst1,
+                                               shape[0], 0, strides[0] * elsize,
+                                               elsize / 2, numel);
+  } else {
+    kernels::deinterleaveCopy<<<grid, block>>>(
+        s_gpu_buffer, dst0, dst1, shape[1], strides[0] * elsize,
+        strides[1] * elsize, elsize / 2, numel);
+  }
+
+  error = cudaFree(s_gpu_buffer);
+  checkGPUStatus(error, "Failed to copy to GPU and deinterleave data: ");
+
+#ifdef LOG_GPU_DATA_COPY
+  cudaEventRecord(stop);
+  cudaEventSynchronize(stop);
+  float milliseconds = 0;
+  cudaEventElapsedTime(&milliseconds, start, stop);
+
+  printf("Copy %ld bytes from %p takes %fms\n", (long)shape.numel() * elsize,
+         src, milliseconds);
+#endif
+}
+
+void CopyFromCudaDevice(const std::byte* src, std::byte* dst, int64_t numel,
+                        int64_t elsize, int64_t stride) {
+  cudaError_t result;
+  if (stride == 1) {
+    result = cudaMemcpy(dst, src, numel * elsize, cudaMemcpyDeviceToHost);
+  } else {
+    result = cudaMemcpy2D(dst, stride * elsize, src, elsize, elsize, numel,
+                          cudaMemcpyDeviceToHost);
+  }
+  checkGPUStatus(result, "Failed to copy to host: ");
+}
+
+bool initGPUState() {
+  int nDevices;
+  cudaGetDeviceCount(&nDevices);
+
+  for (int i = 0; i < nDevices; i++) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, i);
+
+    if (prop.major >= 0) {
+      printf("Use GPU:\n");
+      printf("  Device Number: %d\n", i);
+      printf("  Device name: %s\n", prop.name);
+      cudaSetDevice(i);
+      return true;
+    }
+  }
+
+  return false;
+}
+
+static bool hasGPU = false;
+static std::once_flag flag;
+
+bool hasGPUDevice() {
+  std::call_once(flag, []() { hasGPU = initGPUState(); });
+  return hasGPU;
+}
+
+}  // namespace spu::cuda
diff --git a/libspu/cuda_support/utils.h b/libspu/cuda_support/utils.h
new file mode 100644
index 00000000..b10f62bf
--- /dev/null
+++ b/libspu/cuda_support/utils.h
@@ -0,0 +1,43 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <memory>
+
+#include "libspu/core/shape.h"
+
+namespace spu::cuda {
+
+std::shared_ptr<std::byte> allocate(size_t bytes);
+
+void deallocate(std::byte* data) noexcept;
+
+void CopyToCudaDevice(const std::byte* src, size_t buf_size,
+                      const Shape& shape, const Strides& strides,
+                      std::byte* dst, size_t elsize);
+
+void DeinterleaveCopyToCudaDevice(const std::byte* src, size_t buf_size,
+                                  const Shape& shape, const Strides& strides,
+                                  std::byte* dst0, std::byte* dst1,
+                                  size_t elsize);
+
+void CopyFromCudaDevice(const std::byte* src, std::byte* dst, int64_t numel,
+                        int64_t elsize, int64_t stride);
+
+void printGPUData(const int* ptr, size_t numel);
+
+bool hasGPUDevice();
+
+}  // namespace spu::cuda
diff --git a/libspu/cuda_support/utils_test.cc b/libspu/cuda_support/utils_test.cc
new file mode 100644
index 00000000..1c420cbf
--- /dev/null
+++ b/libspu/cuda_support/utils_test.cc
@@ -0,0 +1,235 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "libspu/cuda_support/utils.h"
+
+#include <array>
+#include <numeric>
+
+#include "gtest/gtest.h"
+
+namespace spu::cuda {
+
+TEST(Memory, Allocate) {
+  auto gpu_ptr = allocate(10);
+
+  EXPECT_NE(gpu_ptr, nullptr);
+}
+
+TEST(Memory, TestCompactCopy1D) {
+  auto gpu_ptr = allocate(10 * sizeof(int));
+
+  std::array<int, 10> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()), 10 * sizeof(int),
+                   {10}, {1}, gpu_ptr.get(), sizeof(int));
+
+  std::array<int, 10> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 10,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(base[idx], from_gpu[idx]);
+  }
+}
+
+TEST(Memory, TestCompactDeinterleaveCopy1D) {
+  auto gpu_ptr1 = allocate(10 * sizeof(int));
+  auto gpu_ptr2 = allocate(10 * sizeof(int));
+
+  std::array<int, 20> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  DeinterleaveCopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()),
+                               20 * sizeof(int), {10}, {1}, gpu_ptr1.get(),
+                               gpu_ptr2.get(), 2 * sizeof(int));
+
+  std::array<int, 10> from_gpu1;
+  std::array<int, 10> from_gpu2;
+  CopyFromCudaDevice(gpu_ptr1.get(),
+                     reinterpret_cast<std::byte*>(from_gpu1.data()), 10,
+                     sizeof(int), 1);
+  CopyFromCudaDevice(gpu_ptr2.get(),
+                     reinterpret_cast<std::byte*>(from_gpu2.data()), 10,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(base[2 * idx], from_gpu1[idx]);
+    EXPECT_EQ(base[2 * idx + 1], from_gpu2[idx]);
+  }
+}
+
+TEST(Memory, TestCompactCopy2D) {
+  auto gpu_ptr = allocate(7 * 2304 * sizeof(int));
+
+  Shape s{7, 2304};
+  std::array<int, 7 * 2304> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()),
+                   7 * 2304 * sizeof(int), s, {2304, 1}, gpu_ptr.get(),
+                   sizeof(int));
+
+  std::array<int, 7 * 2304> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 7 * 2304,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 7 * 2304; ++idx) {
+    EXPECT_EQ(base[idx], from_gpu[idx]);
+  }
+}
+
+TEST(Memory, TestCompactDeinterleaveCopy2D) {
+  auto gpu_ptr1 = allocate(25 * sizeof(int));
+  auto gpu_ptr2 = allocate(25 * sizeof(int));
+
+  Shape s{5, 5};
+  std::array<int, 50> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  DeinterleaveCopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()),
+                               50 * sizeof(int), s, {5, 1}, gpu_ptr1.get(),
+                               gpu_ptr2.get(), 2 * sizeof(int));
+
+  std::array<int, 25> from_gpu1;
+  std::array<int, 25> from_gpu2;
+  CopyFromCudaDevice(gpu_ptr1.get(),
+                     reinterpret_cast<std::byte*>(from_gpu1.data()), 25,
+                     sizeof(int), 1);
+  CopyFromCudaDevice(gpu_ptr2.get(),
+                     reinterpret_cast<std::byte*>(from_gpu2.data()), 25,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 25; ++idx) {
+    EXPECT_EQ(base[2 * idx], from_gpu1[idx]);
+    EXPECT_EQ(base[2 * idx + 1], from_gpu2[idx]);
+  }
+}
+
+TEST(Memory, TestStridedCopy1D) {
+  auto gpu_ptr = allocate(10 * sizeof(int));
+
+  std::array<int, 20> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()), 20 * sizeof(int),
+                   {10}, {2}, gpu_ptr.get(), sizeof(int));
+
+  std::array<int, 10> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 10,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(base[2 * idx], from_gpu[idx]);
+  }
+}
+
+TEST(Memory, TestStridedDeinterleaveCopy1D) {
+  auto gpu_ptr1 = allocate(10 * sizeof(int));
+  auto gpu_ptr2 = allocate(10 * sizeof(int));
+
+  std::array<int, 40> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  DeinterleaveCopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()),
+                               40 * sizeof(int), {10}, {2}, gpu_ptr1.get(),
+                               gpu_ptr2.get(), 2 * sizeof(int));
+
+  std::array<int, 10> from_gpu1;
+  std::array<int, 10> from_gpu2;
+
+  CopyFromCudaDevice(gpu_ptr1.get(),
+                     reinterpret_cast<std::byte*>(from_gpu1.data()), 10,
+                     sizeof(int), 1);
+  CopyFromCudaDevice(gpu_ptr2.get(),
+                     reinterpret_cast<std::byte*>(from_gpu2.data()), 10,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(base[4 * idx], from_gpu1[idx]);
+    EXPECT_EQ(base[4 * idx + 1], from_gpu2[idx]);
+  }
+}
+
+TEST(Memory, TestStridedCopy2D) {
+  auto gpu_ptr = allocate(12 * sizeof(int));
+
+  Shape shape = {3, 4};
+  Strides strides = {8, 2};
+
+  std::array<int, 24> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()), 24 * sizeof(int),
+                   shape, strides, gpu_ptr.get(), sizeof(int));
+
+  std::array<int, 12> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 12,
+                     sizeof(int), 1);
+
+  for (size_t idx = 0; idx < 12; ++idx) {
+    EXPECT_EQ(base[2 * idx], from_gpu[idx]);
+  }
+}
+
+TEST(Memory, TestStridedCopy2D2) {
+  auto gpu_ptr = allocate(12 * sizeof(int));
+
+  Shape shape = {3, 4};
+  Strides strides = {16, 2};
+
+  std::array<int, 48> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()), 48 * sizeof(int),
+                   shape, strides, gpu_ptr.get(), sizeof(int));
+
+  std::array<int, 12> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 12,
+                     sizeof(int), 1);
+
+  for (size_t row = 0; row < 3; ++row) {
+    for (size_t col = 0; col < 4; ++col) {
+      EXPECT_EQ(base[row * strides[0] + col * strides[1]],
+                from_gpu[row * shape[1] + col]);
+    }
+  }
+}
+
+TEST(Memory, TestStridedCopyFrom) {
+  auto gpu_ptr = allocate(10 * sizeof(int));
+
+  std::array<int, 10> base;
+  std::iota(base.begin(), base.end(), 0);
+
+  CopyToCudaDevice(reinterpret_cast<std::byte*>(base.data()), 10 * sizeof(int),
+                   {10}, {1}, gpu_ptr.get(), sizeof(int));
+
+  std::array<int, 20> from_gpu;
+  CopyFromCudaDevice(gpu_ptr.get(),
+                     reinterpret_cast<std::byte*>(from_gpu.data()), 10,
+                     sizeof(int), 2);
+
+  for (size_t idx = 0; idx < 10; ++idx) {
+    EXPECT_EQ(base[idx], from_gpu[2 * idx]);
+  }
+}
+
+}  // namespace spu::cuda
diff --git a/libspu/device/io.cc b/libspu/device/io.cc
index 117fb333..4a40667e 100644
--- a/libspu/device/io.cc
+++ b/libspu/device/io.cc
@@ -89,7 +89,11 @@ std::vector<spu::Value> IoClient::makeShares(const PtBufferView &bv,
  NdArrayRef encoded = encodeToRing(bv, config_.field(), fxp_bits, &dtype);
  // make shares.
-  std::vector<spu::Value> shares = base_io_->toShares(encoded, vtype);
+  if (!config_.experimental_enable_colocated_optimization()) {
+    owner_rank = -1;
+  }
+  std::vector<spu::Value> shares =
+      base_io_->toShares(encoded, vtype, owner_rank);
  // build value.
  std::vector<spu::Value> result;
diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc
index 9f210160..dd724b97 100644
--- a/libspu/device/pphlo/pphlo_executor_test.cc
+++ b/libspu/device/pphlo/pphlo_executor_test.cc
@@ -2234,12 +2234,12 @@ TEST_P(ExecutorTest, CasePrivate) {
}
TEST_P(ExecutorTest, MixedPayload) {
-  xt::xarray<int32_t> op = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
-                            99, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-  xt::xarray<int32_t> expected_ret0 = {1, 2, 3, 4, 5, 6, 7, 7, 8, 8,
-                                       9, 9, 10, 10, 11, 12, 13, 14, 15, 99};
-  xt::xarray<int32_t> expected_ret1 = {9, 8, 7, 6, 5, 4, 3, 11, 12, 2,
-                                       1, 13, 14, 0, 15, 16, 17, 18, 19, 10};
+  xt::xarray<int32_t> op = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
+                            99, 97, 98, 96, 91, 11, 12, 13, 14, 15};
+  xt::xarray<int32_t> expected_ret0 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+                                       11, 12, 13, 14, 15, 91, 96, 97, 98, 99};
+  xt::xarray<int32_t> expected_ret1 = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+                                       15, 16, 17, 18, 19, 14, 13, 11, 12, 10};
  Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()),
           std::get<2>(GetParam()));
diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc
index a1e3feda..8b45c125 100644
--- a/libspu/kernel/hal/fxp_base.cc
+++ b/libspu/kernel/hal/fxp_base.cc
@@ -70,7 +70,7 @@ Value highestOneBit(SPUContext* ctx, const Value& x) {
}
// FIXME:
-// Use range propatation instead of directly set.
+// Use range propagation instead of directly set.
// or expose bit_decompose as mpc level api.
void hintNumberOfBits(const Value& a, size_t nbits) {
  if (a.storage_type().isa<BShare>()) {
diff --git a/libspu/kernel/hal/prot_wrapper.cc b/libspu/kernel/hal/prot_wrapper.cc
index 90bf27dc..e5204fa2 100644
--- a/libspu/kernel/hal/prot_wrapper.cc
+++ b/libspu/kernel/hal/prot_wrapper.cc
@@ -65,6 +65,11 @@ Type _common_type_s(SPUContext* ctx, const Type& a, const Type& b) {
  return mpc::common_type_s(ctx, a, b);
}
+Type _common_type_v(SPUContext* ctx, const Type& a, const Type& b) {
+  SPU_TRACE_HAL_DISP(ctx, a, b);
+  return mpc::common_type_v(ctx, a, b);
+}
+
Value _cast_type_s(SPUContext* ctx, const Value& in, const Type& to) {
  SPU_TRACE_HAL_DISP(ctx, in, to);
  auto ret = mpc::cast_type_s(ctx, in, to);
@@ -107,6 +112,10 @@ Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign) {
  return mpc::trunc_s(ctx, in, bits, sign);
}
+Value _trunc_v(SPUContext* ctx, const Value& in, size_t bits, SignType sign) {
+  SPU_TRACE_HAL_DISP(ctx, in, bits, sign);
+  return mpc::trunc_v(ctx, in, bits, sign);
+}
std::vector<Value> _sort_s(SPUContext* ctx, absl::Span<const Value> x) {
  SPU_TRACE_HAL_DISP(ctx, x.size());
  // FIXME(jimi): formalize mpc sort api
@@ -114,38 +123,93 @@
  // As passing absl::Span in dynDispatch is dangerous, we initialize a new
  // vector here. And the copy of Value is cheap, so it's ok.
  std::vector<Value> x_val(x.begin(), x.end());
-  return dynDispatch<std::vector<Value>>(ctx, "sort_a", x_val);
+  auto ret = dynDispatch<std::vector<Value>>(ctx, "sort_a", x_val);
+  SPU_ENFORCE_EQ(x_val.size(), ret.size(),
+                 "sorted results and inputs sizes should match");
+
+  for (size_t i = 0; i < x_val.size(); ++i) {
+    ret[i].setDtype(x_val[i].dtype());
+  }
+  return ret;
}
+// p<->s
MAP_UNARY_OP(p2s)
MAP_UNARY_OP(s2p)
+
+// p<->v
+MAP_UNARY_OP(v2p)
+Value _p2v(SPUContext* ctx, const Value& in, int owner) {
+  SPU_TRACE_HAL_DISP(ctx, in, owner);
+  return mpc::p2v(ctx, in, owner);
+}
+
+// s<->v
+MAP_UNARY_OP(v2s)
+Value _s2v(SPUContext* ctx, const Value& in, int owner) {
+  SPU_TRACE_HAL_DISP(ctx, in, owner);
+  return mpc::s2v(ctx, in, owner);
+}
+
+// Not family
MAP_UNARY_OP(not_p)
MAP_UNARY_OP(not_s)
+MAP_UNARY_OP(not_v)
+
+// Msb family
MAP_UNARY_OP(msb_p)
MAP_UNARY_OP(msb_s)
+MAP_UNARY_OP(msb_v)
+
+// lshift family
MAP_SHIFT_OP(lshift_p)
MAP_SHIFT_OP(lshift_s)
+MAP_SHIFT_OP(lshift_v)
+
+// rshift family
MAP_SHIFT_OP(rshift_p)
MAP_SHIFT_OP(rshift_s)
+MAP_SHIFT_OP(rshift_v)
+
+// arshift family
MAP_SHIFT_OP(arshift_p)
MAP_SHIFT_OP(arshift_s)
+MAP_SHIFT_OP(arshift_v)
+
+// bitrev family
MAP_BITREV_OP(bitrev_p)
MAP_BITREV_OP(bitrev_s)
+MAP_BITREV_OP(bitrev_v)
+
+// Add family
MAP_BINARY_OP(add_pp)
MAP_BINARY_OP(add_sp)
MAP_BINARY_OP(add_ss)
+MAP_BINARY_OP(add_sv)
+MAP_BINARY_OP(add_vp)
+MAP_BINARY_OP(add_vv)
+
+// Mul family
MAP_BINARY_OP(mul_pp)
MAP_BINARY_OP(mul_sp)
MAP_BINARY_OP(mul_ss)
+MAP_BINARY_OP(mul_sv)
+MAP_BINARY_OP(mul_vp)
+MAP_BINARY_OP(mul_vv)
+
+// And family
MAP_BINARY_OP(and_pp)
MAP_BINARY_OP(and_sp)
MAP_BINARY_OP(and_ss)
+MAP_BINARY_OP(and_sv)
+MAP_BINARY_OP(and_vp)
+MAP_BINARY_OP(and_vv)
+
+// Xor family
MAP_BINARY_OP(xor_pp)
MAP_BINARY_OP(xor_sp)
MAP_BINARY_OP(xor_ss)
+MAP_BINARY_OP(xor_sv)
+MAP_BINARY_OP(xor_vp)
+MAP_BINARY_OP(xor_vv)
+
+// mmul family
MAP_MMUL_OP(mmul_pp)
MAP_MMUL_OP(mmul_sp)
MAP_MMUL_OP(mmul_ss)
+MAP_MMUL_OP(mmul_sv)
+MAP_MMUL_OP(mmul_vp)
+MAP_MMUL_OP(mmul_vv)
#define MAP_OPTIONAL_BINARY_OP(NAME)                              \
  std::optional<Value> _##NAME(SPUContext* ctx, const Value& x,   \
diff --git a/libspu/kernel/hal/prot_wrapper.h b/libspu/kernel/hal/prot_wrapper.h
index 0d769ce7..4c6f7d99 100644
--- a/libspu/kernel/hal/prot_wrapper.h
+++ b/libspu/kernel/hal/prot_wrapper.h
@@ -28,16 +28,25 @@ namespace spu::kernel::hal {
// !!please read [README.md] for api naming conventions.
Type _common_type_s(SPUContext* ctx, const Type& a, const Type& b);
+Type _common_type_v(SPUContext* ctx, const Type& a, const Type& b);
Value _cast_type_s(SPUContext* ctx, const Value& in, const Type& to);
Value _p2s(SPUContext* ctx, const Value& in);
Value _s2p(SPUContext* ctx, const Value& in);
+Value _p2v(SPUContext* ctx, const Value& in, int owner);
+Value _v2p(SPUContext* ctx, const Value& in);
+
+Value _s2v(SPUContext* ctx, const Value& in, int owner);
+Value _v2s(SPUContext* ctx, const Value& in);
+
Value _not_p(SPUContext* ctx, const Value& in);
Value _not_s(SPUContext* ctx, const Value& in);
+Value _not_v(SPUContext* ctx, const Value& in);
Value _msb_p(SPUContext* ctx, const Value& in);
Value _msb_s(SPUContext* ctx, const Value& in);
+Value _msb_v(SPUContext* ctx, const Value& in);
Value _equal_pp(SPUContext* ctx, const Value& x, const Value& y);
std::optional<Value> _equal_sp(SPUContext* ctx, const Value& x, const Value& y);
@@ -45,26 +54,40 @@ std::optional<Value> _equal_ss(SPUContext* ctx, const Value& y);
Value _lshift_p(SPUContext* ctx, const Value& in, size_t bits);
Value _lshift_s(SPUContext* ctx, const Value& in, size_t bits);
+Value _lshift_v(SPUContext* ctx, const Value& in, size_t bits);
Value _rshift_p(SPUContext* ctx, const Value& in, size_t bits);
Value _rshift_s(SPUContext* ctx, const Value& in, size_t bits);
+Value _rshift_v(SPUContext* ctx, const Value& in, size_t bits);
Value _arshift_p(SPUContext* ctx, const Value& in, size_t bits);
Value _arshift_s(SPUContext* ctx, const Value& in, size_t bits);
+Value _arshift_v(SPUContext* ctx, const Value& in, size_t bits);
+
Value _trunc_p(SPUContext* ctx, const Value& in, size_t bits, SignType sign);
Value _trunc_s(SPUContext* ctx, const Value& in, size_t bits, SignType sign);
+Value _trunc_v(SPUContext* ctx, const Value& in, size_t bits, SignType sign);
Value _add_pp(SPUContext* ctx, const Value& x, const Value& y);
Value _add_sp(SPUContext* ctx, const Value& x, const Value& y);
Value _add_ss(SPUContext* ctx, const Value& x, const Value& y);
+Value _add_vv(SPUContext* ctx, const Value& x, const Value& y);
+Value _add_vp(SPUContext* ctx, const Value& x, const Value& y);
+Value _add_sv(SPUContext* ctx, const Value& x, const Value& y);
Value _mul_pp(SPUContext* ctx, const Value& x, const Value& y);
Value _mul_sp(SPUContext* ctx, const Value& x, const Value& y);
Value _mul_ss(SPUContext* ctx, const Value& x, const Value& y);
+Value _mul_vv(SPUContext* ctx, const Value& x, const Value& y);
+Value _mul_vp(SPUContext* ctx, const Value& x, const Value& y);
+Value _mul_sv(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_pp(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_sp(SPUContext* ctx, const Value& x, const Value& y);
Value _mmul_ss(SPUContext* ctx, const Value& x, const Value& y);
+Value _mmul_vv(SPUContext* ctx, const Value& x, const Value& y);
+Value _mmul_vp(SPUContext* ctx, const Value& x, const Value& y);
+Value _mmul_sv(SPUContext* ctx, const Value& x, const Value& y);
Value _conv2d_ss(SPUContext* ctx, const Value& input, const Value& kernel,
                 const Strides& strides);
@@ -72,13 +95,20 @@
Value _and_pp(SPUContext* ctx, const Value& x, const Value& y);
Value _and_sp(SPUContext* ctx, const Value& x, const Value& y);
Value _and_ss(SPUContext* ctx, const Value& x, const Value& y);
+Value _and_vv(SPUContext* ctx, const Value& x, const Value& y);
+Value _and_vp(SPUContext* ctx, const Value& x, const Value& y);
+Value _and_sv(SPUContext* ctx, const Value& x, const Value& y);
Value _xor_pp(SPUContext* ctx, const Value& x, const Value& y);
Value _xor_sp(SPUContext* ctx, const Value& x, const Value& y);
Value _xor_ss(SPUContext* ctx, const Value& x, const Value& y);
+Value _xor_vv(SPUContext* ctx, const Value& x, const Value& y);
+Value _xor_vp(SPUContext* ctx, const Value& x, const Value& y);
+Value _xor_sv(SPUContext* ctx, const Value& x, const Value& y);
Value _bitrev_p(SPUContext* ctx, const Value& in, size_t start, size_t end);
Value _bitrev_s(SPUContext* ctx, const Value& in, size_t start, size_t end);
+Value _bitrev_v(SPUContext* ctx, const Value& in, size_t start, size_t end);
Value _make_p(SPUContext* ctx, uint128_t init, const Shape& shape);
diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc
index d3cc3cf2..c6eaa9c7 100644
--- a/libspu/kernel/hal/ring.cc
+++ b/libspu/kernel/hal/ring.cc
@@ -25,53 +25,6 @@
#include "libspu/kernel/hal/shape_ops.h"
namespace spu::kernel::hal {
-namespace {
-
-std::tuple<int64_t, int64_t, int64_t> deduceMmulArgs(
-    const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
-  SPU_ENFORCE(!lhs.empty() && lhs.size() <= 2);
-  SPU_ENFORCE(!rhs.empty() && rhs.size() <= 2);
-
-  if (lhs.size() == 1 && rhs.size() == 1) {
-    SPU_ENFORCE(lhs[0] == rhs[0]);
-    return std::make_tuple(1, 1, rhs[0]);
-  }
-  if (lhs.size() == 1 && rhs.size() == 2) {
-    SPU_ENFORCE(lhs[0] == rhs[0]);
-    return std::make_tuple(1, rhs[1], rhs[0]);
-  }
-  if (lhs.size() == 2 && rhs.size() == 1) {
-    SPU_ENFORCE(lhs[1] == rhs[0]);
-    return std::make_tuple(lhs[0], 1, rhs[0]);
-  }
-  SPU_ENFORCE(lhs[1] == rhs[0]);
-  return std::make_tuple(lhs[0], rhs[1], rhs[0]);
-}
-
-std::tuple<int64_t, int64_t, int64_t> calcMmulTilingSize(int64_t m, int64_t n,
-                                                         int64_t k,
-                                                         size_t elsize,
-                                                         size_t mem_limit) {
-  if (m == 0 || n == 0 || k == 0) {
-    return {m, n, k};
-  }
-  const auto elnum_limit = static_cast<int64_t>(mem_limit / elsize);
-  const int64_t expected_step = std::ceil(std::sqrt(elnum_limit));
-
-  const int64_t expected_mn_step = std::min((m + n), expected_step);
-  const int64_t k_step = std::max(std::min(k, elnum_limit / expected_mn_step),
-                                  static_cast<int64_t>(1));
-
-  // split expected_mn_step into m/n by radio
-  const int64_t m_step =
-      std::max(expected_mn_step * m / (m + n), static_cast<int64_t>(1));
-  const int64_t n_step =
-      std::max(expected_mn_step * n / (m + n), static_cast<int64_t>(1));
-
-  return {m_step, n_step, k_step};
-}
-
-}  // namespace
Type _common_type(SPUContext* ctx, const Type& a, const Type& b) {
  if (a.isa<Secret>() && b.isa<Secret>()) {
@@ -99,79 +52,96 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) {
  }
}
-#define IMPL_UNARY_OP(Name, FnP, FnS)                        \
+#define IMPL_UNARY_OP(Name)                                  \
  Value Name(SPUContext* ctx, const Value& in) {              \
    SPU_TRACE_HAL_LEAF(ctx, in);                              \
    if (in.isPublic()) {                                      \
-      return FnP(ctx, in);                                   \
+      return Name##_p(ctx, in);                              \
    } else if (in.isSecret()) {                               \
-      return FnS(ctx, in);                                   \
+      return Name##_s(ctx, in);                              \
+    } else if (in.isPrivate()) {                             \
+      return Name##_v(ctx, in);                              \
    } else {                                                  \
      SPU_THROW("unsupport unary op={} for {}", #Name, in);   \
    }                                                         \
  }
-#define IMPL_SHIFT_OP(Name, FnP, FnS)                        \
+IMPL_UNARY_OP(_not)
+IMPL_UNARY_OP(_msb)
+
+#undef IMPL_UNARY_OP
+
+#define IMPL_SHIFT_OP(Name)                                  \
  Value Name(SPUContext* ctx, const Value& in, size_t bits) { \
    SPU_TRACE_HAL_LEAF(ctx, in, bits);                        \
    if (in.isPublic()) {                                      \
-      return FnP(ctx, in, bits);                             \
+      return Name##_p(ctx, in, bits);                        \
    } else if (in.isSecret()) {                               \
-      return FnS(ctx, in, bits);                             \
+      return Name##_s(ctx, in, bits);                        \
+    } else if (in.isPrivate()) {                             \
+      return Name##_v(ctx, in, bits);                        \
    } else {                                                  \
      SPU_THROW("unsupport 
unary op={} for {}", #Name, in); \ } \ } -#define IMPL_COMMUTATIVE_BINARY_OP(Name, FnPP, FnSP, FnSS) \ +IMPL_SHIFT_OP(_lshift) +IMPL_SHIFT_OP(_rshift) +IMPL_SHIFT_OP(_arshift) + +#undef IMPL_SHIFT_OP + +#define IMPL_COMMUTATIVE_BINARY_OP(Name) \ Value Name(SPUContext* ctx, const Value& x, const Value& y) { \ SPU_TRACE_HAL_LEAF(ctx, x, y); \ - if (x.isPublic() && y.isPublic()) { \ - return FnPP(ctx, x, y); \ - } else if (x.isSecret() && y.isPublic()) { \ - return FnSP(ctx, x, y); \ - } else if (x.isPublic() && y.isSecret()) { \ + if (x.isPublic() && y.isPublic()) { /*PP*/ \ + return Name##_pp(ctx, x, y); \ + } else if (x.isPrivate() && y.isPrivate()) { /*VV*/ \ + return Name##_vv(ctx, x, y); \ + } else if (x.isSecret() && y.isSecret()) { /*SS*/ \ + return Name##_ss(ctx, y, x); \ + } else if (x.isSecret() && y.isPublic()) { /*SP*/ \ + return Name##_sp(ctx, x, y); \ + } else if (x.isPublic() && y.isSecret()) { /*PS*/ \ /* commutative, swap args */ \ - return FnSP(ctx, y, x); \ - } else if (x.isSecret() && y.isSecret()) { \ - return FnSS(ctx, y, x); \ + return Name##_sp(ctx, y, x); \ + } else if (x.isPrivate() && y.isPublic()) { /*VP*/ \ + return Name##_vp(ctx, x, y); \ + } else if (x.isPublic() && y.isPrivate()) { /*PV*/ \ + /* commutative, swap args */ \ + return Name##_vp(ctx, y, x); \ + } else if (x.isPrivate() && y.isSecret()) { /*VS*/ \ + /* commutative, swap args */ \ + return Name##_sv(ctx, y, x); \ + } else if (x.isSecret() && y.isPrivate()) { /*SV*/ \ + return Name##_sv(ctx, x, y); \ } else { \ SPU_THROW("unsupported op {} for x={}, y={}", #Name, x, y); \ } \ } -IMPL_UNARY_OP(_not, _not_p, _not_s) -IMPL_UNARY_OP(_msb, _msb_p, _msb_s) - -IMPL_SHIFT_OP(_lshift, _lshift_p, _lshift_s) -IMPL_SHIFT_OP(_rshift, _rshift_p, _rshift_s) -IMPL_SHIFT_OP(_arshift, _arshift_p, _arshift_s) +IMPL_COMMUTATIVE_BINARY_OP(_add) +IMPL_COMMUTATIVE_BINARY_OP(_mul) +IMPL_COMMUTATIVE_BINARY_OP(_and) +IMPL_COMMUTATIVE_BINARY_OP(_xor) -IMPL_COMMUTATIVE_BINARY_OP(_add, _add_pp, _add_sp, _add_ss) -IMPL_COMMUTATIVE_BINARY_OP(_mul, _mul_pp, _mul_sp, _mul_ss) -IMPL_COMMUTATIVE_BINARY_OP(_and, _and_pp, _and_sp, _and_ss) -IMPL_COMMUTATIVE_BINARY_OP(_xor, _xor_pp, _xor_sp, _xor_ss) +#undef IMPL_COMMUTATIVE_BINARY_OP -Value _sub(SPUContext* ctx, const Value& x, const Value& y) { +static OptionalAPI _equal_impl(SPUContext* ctx, const Value& x, + const Value& y) { SPU_TRACE_HAL_LEAF(ctx, x, y); - return _add(ctx, x, _negate(ctx, y)); -} -// TODO: remove this kernel, the algorithm could be used for boolean equal test.
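// A worked dispatch trace, not part of the patch: with private (V) values
// added, each commutative op now dispatches over all nine
// (public/secret/private) operand pairings, using commutativity to
// canonicalize onto six kernels (pp, ss, vv, sp, vp, sv). For example,
// under the macro above:
//   _add(ctx, pub,  sec)  ==> add_sp(ctx, sec, pub);   // PS: swapped
//   _add(ctx, priv, sec)  ==> add_sv(ctx, sec, priv);  // VS: swapped
//   _add(ctx, sec,  priv) ==> add_sv(ctx, sec, priv);  // SV: direct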
-[[maybe_unused]] Value _eqz(SPUContext* ctx, const Value& x) { - SPU_TRACE_HAL_LEAF(ctx, x); - - // eqz(x) = not(lsb(pre_or(x))) - // all equal to zero means lsb equals to zero - auto _k1 = _constant(ctx, 1U, x.shape()); - auto res = _xor(ctx, _and(ctx, _prefix_or(ctx, x), _k1), _k1); - - // FIXME(jint): see hintNumberOfBits - if (res.storage_type().isa()) { - const_cast(res.storage_type()).as()->setNbits(1); + if (x.isPublic() && y.isPublic()) { + return _equal_pp(ctx, x, y); + } else if (x.isSecret() && y.isPublic()) { + return _equal_sp(ctx, x, y); + } else if (x.isPublic() && y.isSecret()) { /* commutative, swap args */ + return _equal_sp(ctx, y, x); + } else if (x.isSecret() && y.isSecret()) { + return _equal_ss(ctx, y, x); } - return res; + return NotAvailable; } Value _conv2d(SPUContext* ctx, const Value& input, const Value& kernel, @@ -184,19 +154,129 @@ Value _conv2d(SPUContext* ctx, const Value& input, const Value& kernel, } static Value _mmul_impl(SPUContext* ctx, const Value& x, const Value& y) { - if (x.isPublic() && y.isPublic()) { + if (x.isPublic() && y.isPublic()) { // PP return _mmul_pp(ctx, x, y); - } else if (x.isSecret() && y.isPublic()) { + } else if (x.isSecret() && y.isSecret()) { // SS + return _mmul_ss(ctx, x, y); + } else if (x.isPrivate() && y.isPrivate()) { // VV + return _mmul_vv(ctx, x, y); + } else if (x.isSecret() && y.isPublic()) { // SP return _mmul_sp(ctx, x, y); - } else if (x.isPublic() && y.isSecret()) { + } else if (x.isPublic() && y.isSecret()) { // PS return transpose(ctx, _mmul_sp(ctx, transpose(ctx, y), transpose(ctx, x))); - } else if (x.isSecret() && y.isSecret()) { - return _mmul_ss(ctx, x, y); + } else if (x.isPrivate() && y.isPublic()) { // VP + return _mmul_vp(ctx, x, y); + } else if (x.isPublic() && y.isPrivate()) { // PV + return transpose(ctx, _mmul_vp(ctx, transpose(ctx, y), transpose(ctx, x))); + } else if (x.isSecret() && y.isPrivate()) { // SV + return _mmul_sv(ctx, x, y); + } else if (x.isPrivate() && y.isSecret()) { // VS + return transpose(ctx, _mmul_sv(ctx, transpose(ctx, y), transpose(ctx, x))); } else { SPU_THROW("unsupported op {} for x={}, y={}", "_matmul", x, y); } }; +Value _trunc(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { + SPU_TRACE_HAL_LEAF(ctx, x, bits); + bits = (bits == 0) ? 
ctx->getFxpBits() : bits; + + if (x.isPublic()) { + return _trunc_p(ctx, x, bits, sign); + } else if (x.isSecret()) { + return _trunc_s(ctx, x, bits, sign); + } else if (x.isPrivate()) { + return _trunc_v(ctx, x, bits, sign); + } else { + SPU_THROW("unsupport unary op={} for {}", __func__, x); + } +} + +// swap bits of [start, end) +Value _bitrev(SPUContext* ctx, const Value& x, size_t start, size_t end) { + SPU_TRACE_HAL_LEAF(ctx, x, start, end); + + if (x.isPublic()) { + return _bitrev_p(ctx, x, start, end); + } else if (x.isSecret()) { + return _bitrev_s(ctx, x, start, end); + } else if (x.isPrivate()) { + return _bitrev_v(ctx, x, start, end); + } + + SPU_THROW("unsupport op={} for {}", "_bitrev", x); +} + +namespace { + +std::tuple deduceMmulArgs(const Shape& lhs, + const Shape& rhs) { + SPU_ENFORCE(lhs.ndim() > 0 && lhs.ndim() <= 2); + SPU_ENFORCE(rhs.ndim() > 0 && rhs.ndim() <= 2); + + if (lhs.size() == 1 && rhs.size() == 1) { + SPU_ENFORCE(lhs[0] == rhs[0]); + return std::make_tuple(1, 1, rhs[0]); + } + if (lhs.size() == 1 && rhs.size() == 2) { + SPU_ENFORCE(lhs[0] == rhs[0]); + return std::make_tuple(1, rhs[1], rhs[0]); + } + if (lhs.size() == 2 && rhs.size() == 1) { + SPU_ENFORCE(lhs[1] == rhs[0]); + return std::make_tuple(lhs[0], 1, rhs[0]); + } + SPU_ENFORCE(lhs[1] == rhs[0]); + return std::make_tuple(lhs[0], rhs[1], rhs[0]); +} + +std::tuple calcMmulTilingSize(int64_t m, int64_t n, + int64_t k, + size_t elsize, + size_t mem_limit) { + if (m == 0 || n == 0 || k == 0) { + return {m, n, k}; + } + const auto elnum_limit = static_cast(mem_limit / elsize); + const int64_t expected_step = std::ceil(std::sqrt(elnum_limit)); + + const int64_t expected_mn_step = std::min((m + n), expected_step); + const int64_t k_step = std::max(std::min(k, elnum_limit / expected_mn_step), + static_cast(1)); + + // split expected_mn_step into m/n by ratio + const int64_t m_step = + std::max(expected_mn_step * m / (m + n), static_cast(1)); + const int64_t n_step = + std::max(expected_mn_step * n / (m + n), static_cast(1)); + + return {m_step, n_step, k_step}; +} + +} // namespace + +Value _sub(SPUContext* ctx, const Value& x, const Value& y) { + SPU_TRACE_HAL_LEAF(ctx, x, y); + return _add(ctx, x, _negate(ctx, y)); +} + +// TODO: remove this kernel, the algorithm could be used for boolean equal test.
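// A worked example with illustrative numbers only, not part of the patch:
// calcMmulTilingSize above picks tile steps so one tile's working set stays
// under mem_limit. With elsize = 8 bytes, mem_limit = 8 MiB and
// m = n = k = 4096:
//   elnum_limit      = (8 << 20) / 8             = 1048576
//   expected_step    = ceil(sqrt(1048576))       = 1024
//   expected_mn_step = min(4096 + 4096, 1024)    = 1024
//   k_step           = min(4096, 1048576 / 1024) = 1024
//   m_step = n_step  = 1024 * 4096 / 8192        = 512
// so the tiled _mmul walks the product in 512 x 1024 by 1024 x 512 blocks.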
+[[maybe_unused]] Value _eqz(SPUContext* ctx, const Value& x) { + SPU_TRACE_HAL_LEAF(ctx, x); + + // eqz(x) = not(lsb(pre_or(x))) + // all equal to zero means lsb equals to zero + auto _k1 = _constant(ctx, 1U, x.shape()); + auto res = _xor(ctx, _and(ctx, _prefix_or(ctx, x), _k1), _k1); + + // FIXME(jint): see hintNumberOfBits + if (res.storage_type().isa()) { + const_cast(res.storage_type()).as()->setNbits(1); + } + + return res; +} + Value _mmul(SPUContext* ctx, const Value& x, const Value& y) { auto [m, n, k] = deduceMmulArgs(x.shape(), y.shape()); @@ -295,23 +375,6 @@ Value _or(SPUContext* ctx, const Value& x, const Value& y) { return _xor(ctx, x, _xor(ctx, y, _and(ctx, x, y))); } -static std::optional _equal_impl(SPUContext* ctx, const Value& x, - const Value& y) { - SPU_TRACE_HAL_LEAF(ctx, x, y); - - if (x.isPublic() && y.isPublic()) { - return _equal_pp(ctx, x, y); - } else if (x.isSecret() && y.isPublic()) { - return _equal_sp(ctx, x, y); - } else if (x.isPublic() && y.isSecret()) { /* commutative, swap args */ - return _equal_sp(ctx, y, x); - } else if (x.isSecret() && y.isSecret()) { - return _equal_ss(ctx, y, x); - } - - return std::nullopt; -} - Value _equal(SPUContext* ctx, const Value& x, const Value& y) { // First try use equal kernel, i.e. for 2PC , equal can be done with the same // cost of half MSB. @@ -331,20 +394,6 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y) { _xor(ctx, _less(ctx, y, x), _k1)); } -// TODO: -Value _trunc(SPUContext* ctx, const Value& x, size_t bits, SignType sign) { - SPU_TRACE_HAL_LEAF(ctx, x, bits); - bits = (bits == 0) ? ctx->getFxpBits() : bits; - - if (x.isPublic()) { - return _trunc_p(ctx, x, bits, sign); - } else if (x.isSecret()) { - return _trunc_s(ctx, x, bits, sign); - } else { - SPU_THROW("unsupport unary op={} for {}", __func__, x); - } -} - Value _negate(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); @@ -376,19 +425,6 @@ Value _less(SPUContext* ctx, const Value& x, const Value& y) { return _msb(ctx, _sub(ctx, x, y)); } -// swap bits of [start, end) -Value _bitrev(SPUContext* ctx, const Value& x, size_t start, size_t end) { - SPU_TRACE_HAL_LEAF(ctx, x, start, end); - - if (x.isPublic()) { - return _bitrev_p(ctx, x, start, end); - } else if (x.isSecret()) { - return _bitrev_s(ctx, x, start, end); - } - - SPU_THROW("unsupport op={} for {}", "_bitrev", x); -} - Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) { SPU_TRACE_HAL_LEAF(ctx, pred, a, b); diff --git a/libspu/kernel/hal/shape_ops.cc b/libspu/kernel/hal/shape_ops.cc index 7bf96065..733ce226 100644 --- a/libspu/kernel/hal/shape_ops.cc +++ b/libspu/kernel/hal/shape_ops.cc @@ -29,6 +29,8 @@ namespace { Type _common_type(SPUContext* ctx, const Type& a, const Type& b) { if (a.isa() && b.isa()) { return _common_type_s(ctx, a, b); + } else if (a.isa() && b.isa()) { + return _common_type_v(ctx, a, b); } else if (a.isa()) { return a; } else if (b.isa()) { @@ -48,10 +50,16 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) { } else if (x.isPublic() && to.isa()) { // FIXME: casting to BShare semantic is wrong. 
return _p2s(ctx, x); + } else if (x.isPublic() && to.isa()) { + return _p2v(ctx, x, to.as()->owner()); + } else if (x.isSecret() && to.isa()) { + return _s2v(ctx, x, to.as()->owner()); + } else if (x.isPrivate() && to.isa()) { + return _v2s(ctx, x); } else if (x.isSecret() && to.isa()) { return _cast_type_s(ctx, x, to); } else { - SPU_THROW("show not be here x={}, to={}", x, to); + SPU_THROW("should not be here x={}, to={}", x, to); } } diff --git a/libspu/kernel/hal/type_cast.cc b/libspu/kernel/hal/type_cast.cc index 8ff37ac7..53b65742 100644 --- a/libspu/kernel/hal/type_cast.cc +++ b/libspu/kernel/hal/type_cast.cc @@ -63,11 +63,17 @@ Value fxp2int(SPUContext* ctx, const Value& x, DataType to_type) { // TODO: move seal/reveal into a new header file. Value seal(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); + if (x.isPrivate()) { + return _v2s(ctx, x).setDtype(x.dtype()); + } return _p2s(ctx, x).setDtype(x.dtype()); } Value reveal(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); + if (x.isPrivate()) { + return _v2p(ctx, x).setDtype(x.dtype()); + } return _s2p(ctx, x).setDtype(x.dtype()); } diff --git a/libspu/kernel/hlo/indexing.cc b/libspu/kernel/hlo/indexing.cc index d3ef4388..66b65525 100644 --- a/libspu/kernel/hlo/indexing.cc +++ b/libspu/kernel/hlo/indexing.cc @@ -747,9 +747,8 @@ spu::Value DynamicSlice(SPUContext *ctx, const spu::Value &operand, SPU_ENFORCE(!start_indices.empty()); SPU_ENFORCE(!operand.isComplex()); - if (start_indices[0].isSecret()) { - return SecretDynamicSlice(ctx, operand, slice_size, start_indices); - } else { + if (std::all_of(start_indices.begin(), start_indices.end(), + [](const spu::Value &v) { return v.isPublic(); })) { // Start indices Index start_indices_i64(start_indices.size()); for (const auto &idx : llvm::enumerate(start_indices)) { @@ -774,6 +773,8 @@ spu::Value DynamicSlice(SPUContext *ctx, const spu::Value &operand, return hal::slice(ctx, operand, start_indices_i64, limit, strides); } + + return SecretDynamicSlice(ctx, operand, slice_size, start_indices); } spu::Value FilterByMask(SPUContext *, const spu::Value &operand, diff --git a/libspu/mpc/aby3/BUILD.bazel b/libspu/mpc/aby3/BUILD.bazel index 6d9fd664..e4c78a5a 100644 --- a/libspu/mpc/aby3/BUILD.bazel +++ b/libspu/mpc/aby3/BUILD.bazel @@ -50,6 +50,10 @@ spu_cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], + defines = select({ + "@rules_cuda//cuda:is_enabled": ["CUDA_ENABLED"], + "//conditions:default": ["CUDA_DISABLED"], + }), deps = [ ":ot", ":type", @@ -58,6 +62,28 @@ spu_cc_library( "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_state", "//libspu/mpc/utils:circuits", + ] + select({ + "@rules_cuda//cuda:is_enabled": [":arithmetic_gpu_ext"], + "//conditions:default": [], + }), +) + +spu_cc_library( + name = "arithmetic_gpu_ext", + srcs = ["arithmetic_gpu_ext.cc"], + hdrs = ["arithmetic_gpu_ext.h"], + tags = [ + "manual", # Exclude this target from :all expansion + ], + target_compatible_with = select({ + "@rules_cuda//cuda:is_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":type", + "//libspu/core:ndarray_ref", + "//libspu/cuda_support:kernels", + "//libspu/cuda_support:utils", ], ) diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index 90c85bbc..c9461eb9 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -14,7 +14,6 @@ #include "libspu/mpc/aby3/arithmetic.h" -#include #include #include "libspu/mpc/aby3/ot.h" @@ -25,6 
+24,11 @@ #include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/utils/ring_ops.h" +#ifdef CUDA_ENABLED +#include "libspu/cuda_support/utils.h" +#include "libspu/mpc/aby3/arithmetic_gpu_ext.h" +#endif + namespace spu::mpc::aby3 { namespace { @@ -628,32 +632,48 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, auto* comm = ctx->getState(); auto* prg_state = ctx->getState(); + auto M = x.shape()[0]; + auto N = y.shape()[1]; + auto r = std::async([&] { - auto [r0, r1] = prg_state->genPrssPair(field, {x.shape()[0], y.shape()[1]}, - PrgState::GenPrssCtrl::Both); + auto [r0, r1] = + prg_state->genPrssPair(field, {M, N}, PrgState::GenPrssCtrl::Both); return ring_sub(r0, r1); }); - auto x1 = getFirstShare(x); - auto x2 = getSecondShare(x); - - auto y1 = getFirstShare(y); - auto y2 = getSecondShare(y); - - // z1 := x1*y1 + x1*y2 + x2*y1 + k1 - // z2 := x2*y2 + x2*y3 + x3*y2 + k2 - // z3 := x3*y3 + x3*y1 + x1*y3 + k3 - NdArrayRef out(makeType(field), {x.shape()[0], y.shape()[1]}); + NdArrayRef out(makeType(field), {M, N}); auto o1 = getFirstShare(out); auto o2 = getSecondShare(out); - auto t2 = std::async(ring_mmul, x2, y1); - auto t0 = ring_mmul(x1, ring_add(y1, y2)); // - auto z1 = ring_sum({t0, t2.get(), r.get()}); +#ifdef CUDA_ENABLED + // FIXME: better heuristic? + if (!spu::cuda::hasGPUDevice() || M * N <= 20000 || field != FM64) { +#endif + auto x1 = getFirstShare(x); + auto x2 = getSecondShare(x); + + auto y1 = getFirstShare(y); + auto y2 = getSecondShare(y); + // z1 := x1*y1 + x1*y2 + x2*y1 + k1 + // z2 := x2*y2 + x2*y3 + x3*y2 + k2 + // z3 := x3*y3 + x3*y1 + x1*y3 + k3 + + // x1*(y1+y2) + x2*y1 + k1 + auto t2 = std::async(ring_mmul, x2, y1); + auto t0 = ring_mmul(x1, ring_add(y1, y2)); // + auto z1 = ring_sum({t0, t2.get(), r.get()}); + + auto f = std::async([&] { ring_assign(o1, z1); }); + ring_assign(o2, comm->rotate(z1, kBindName)); // comm => 1, k + f.get(); +#ifdef CUDA_ENABLED + } else { + matmul_aa_gpu(x, y, o1); + ring_add_(o1, r.get()); + ring_assign(o2, comm->rotate(o1, kBindName)); // comm => 1, k + } +#endif - auto f = std::async([&] { ring_assign(o1, z1); }); - ring_assign(o2, comm->rotate(z1, kBindName)); // comm => 1, k - f.get(); return out; } diff --git a/libspu/mpc/aby3/arithmetic_gpu_ext.cc b/libspu/mpc/aby3/arithmetic_gpu_ext.cc new file mode 100644 index 00000000..64c10333 --- /dev/null +++ b/libspu/mpc/aby3/arithmetic_gpu_ext.cc @@ -0,0 +1,75 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
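// A minimal sketch, not part of the patch: the CPU/GPU branch in MatMulAA
// above reduces to the hypothetical predicate below -- the GPU path runs
// only with a device present, a large output, and the FM64 ring. The 20000
// threshold is the diff's own FIXME heuristic, not a tuned constant.
static bool useGpuMatmulAA(int64_t M, int64_t N, FieldType field) {
  return spu::cuda::hasGPUDevice() && (M * N > 20000) && (field == FM64);
}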
+ +#include "libspu/mpc/aby3/arithmetic_gpu_ext.h" + +#include "spdlog/spdlog.h" + +#include "libspu/cuda_support/kernels.h" +#include "libspu/cuda_support/utils.h" + +namespace spu::mpc::aby3 { + +void matmul_aa_gpu(const NdArrayRef& x, const NdArrayRef& y, NdArrayRef& ret) { + auto M = x.shape()[0]; + auto K = x.shape()[1]; + auto N = y.shape()[1]; + + auto x_share_bytes = x.numel() * x.elsize(); + auto y_share_bytes = y.numel() * y.elsize(); + auto result_bytes = M * N * x.elsize(); + + auto g_x = cuda::allocate(x_share_bytes); + auto g_y = cuda::allocate(y_share_bytes); + auto g_ret = cuda::allocate(result_bytes); + + auto* g_x2 = g_x.get() + x_share_bytes / 2; + auto* g_y2 = g_y.get() + y_share_bytes / 2; + auto* result_2 = g_ret.get() + result_bytes / 2; + + // x1 + cuda::DeinterleaveCopyToCudaDevice(x.data(), + x.buf()->size() - x.offset(), x.shape(), + x.strides(), g_x.get(), g_x2, x.elsize()); + + // y1 + cuda::DeinterleaveCopyToCudaDevice(y.data(), + y.buf()->size() - y.offset(), y.shape(), + y.strides(), g_y.get(), g_y2, y.elsize()); + + // x1*(y1+y2) + x2*y1 + k1 + // y1 + y2 - > y2 + cuda::add(reinterpret_cast(g_y2), + reinterpret_cast(g_y.get()), y.shape().numel()); + + // x1*y2 + cuda::matmul(M, N, K, reinterpret_cast(g_x.get()), + reinterpret_cast(g_y2), + reinterpret_cast(g_ret.get())); + + // x2*y1 + cuda::matmul(M, N, K, reinterpret_cast(g_x2), + reinterpret_cast(g_y.get()), + reinterpret_cast(result_2)); + + // result1 + result2 + cuda::add(reinterpret_cast(g_ret.get()), + reinterpret_cast(result_2), M * N); + + // Copy back + cuda::CopyFromCudaDevice(g_ret.get(), ret.data(), ret.numel(), + ret.elsize(), ret.strides()[1]); +} + +} // namespace spu::mpc::aby3 \ No newline at end of file diff --git a/libspu/mpc/aby3/arithmetic_gpu_ext.h b/libspu/mpc/aby3/arithmetic_gpu_ext.h new file mode 100644 index 00000000..4ab5243d --- /dev/null +++ b/libspu/mpc/aby3/arithmetic_gpu_ext.h @@ -0,0 +1,22 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
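// A sketch, not part of the patch: a CPU reference for matmul_aa_gpu above,
// mirroring the CPU branch of MatMulAA::proc earlier in this diff and usable
// as a test oracle. It computes the party-local z1 = x1*(y1+y2) + x2*y1; the
// correlated randomness k1 is added by the caller, exactly as in the kernel.
NdArrayRef matmul_aa_cpu_reference(const NdArrayRef& x, const NdArrayRef& y) {
  auto x1 = getFirstShare(x);
  auto x2 = getSecondShare(x);
  auto y1 = getFirstShare(y);
  auto y2 = getSecondShare(y);
  // same algebra the GPU kernels perform piecewise on device buffers
  return ring_sum({ring_mmul(x1, ring_add(y1, y2)), ring_mmul(x2, y1)});
}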
+ +#pragma once + +#include "libspu/core/ndarray_ref.h" + +namespace spu::mpc::aby3 { + +void matmul_aa_gpu(const NdArrayRef& x, const NdArrayRef& y, NdArrayRef& ret); +} diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index ea83f619..2f030970 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -20,6 +20,7 @@ #include "libspu/core/parallel_utils.h" #include "libspu/core/prelude.h" +#include "libspu/core/trace.h" #include "libspu/mpc/ab_api.h" #include "libspu/mpc/aby3/type.h" #include "libspu/mpc/aby3/value.h" @@ -622,4 +623,16 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } } +void CommonTypeV::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + + const auto* lhs_v = lhs.as(); + const auto* rhs_v = rhs.as(); + + ctx->setOutput(makeType(std::max(lhs_v->field(), rhs_v->field()))); +} + } // namespace spu::mpc::aby3 diff --git a/libspu/mpc/aby3/conversion.h b/libspu/mpc/aby3/conversion.h index 59248867..76fe752a 100644 --- a/libspu/mpc/aby3/conversion.h +++ b/libspu/mpc/aby3/conversion.h @@ -120,4 +120,13 @@ class MsbA2B : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; +class CommonTypeV : public Kernel { + public: + static constexpr char kBindName[] = "common_type_v"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; +}; + } // namespace spu::mpc::aby3 diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index 6cc46c91..6fdf4af6 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -51,10 +51,6 @@ std::vector Aby3Io::toShares(const NdArrayRef& raw, Visibility vis, const auto share = raw.as(makeType(field)); return std::vector(world_size_, share); } else if (vis == VIS_SECRET) { -#if !defined(SPU_ENABLE_PRIVATE_TYPE) - owner_rank = -1; -#endif - if (owner_rank >= 0 && owner_rank <= 2) { // indicates private std::vector shares; diff --git a/libspu/mpc/aby3/protocol.cc b/libspu/mpc/aby3/protocol.cc index 61060772..9d01a953 100644 --- a/libspu/mpc/aby3/protocol.cc +++ b/libspu/mpc/aby3/protocol.cc @@ -64,6 +64,8 @@ void regAby3Protocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/api.cc b/libspu/mpc/api.cc index d82c58af..7503042b 100644 --- a/libspu/mpc/api.cc +++ b/libspu/mpc/api.cc @@ -223,6 +223,14 @@ Type common_type_s(SPUContext* ctx, const Type& a, const Type& b) { } } +Type common_type_v(SPUContext* ctx, const Type& a, const Type& b) { + SPU_TRACE_MPC_DISP(ctx, a, b); + if (a == b) { + return a; + } + return dynDispatch(ctx, __func__, a, b); +} + Value cast_type_s(SPUContext* ctx, const Value& frm, const Type& to_type) { SPU_TRACE_MPC_DISP(ctx, frm, to_type); diff --git a/libspu/mpc/api.h b/libspu/mpc/api.h index 104c2ef2..971ecf1a 100644 --- a/libspu/mpc/api.h +++ b/libspu/mpc/api.h @@ -79,6 +79,7 @@ Value export_s(SPUContext* ctx, const Value& x, const Type& t); // // This api calculate the common type. Type common_type_s(SPUContext* ctx, const Type& a, const Type& b); +Type common_type_v(SPUContext* ctx, const Type& a, const Type& b); Value cast_type_s(SPUContext* ctx, const Value& frm, const Type& to_type); // Make a public variable with given plaintext input. 
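// A minimal sketch, not part of the patch: all four protocols touched by
// this diff (aby3, cheetah, securenn, semi2k) implement common_type_v the
// same way -- the common private type lives in the wider of the two rings.
// Reduced to its core (private-type construction details such as the owner
// rank are protocol-specific and omitted here):
FieldType commonPrivateField(FieldType a, FieldType b) {
  return std::max(a, b);  // assumes FM32 < FM64 < FM128 in enum order
}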
diff --git a/libspu/mpc/cheetah/conversion.cc b/libspu/mpc/cheetah/conversion.cc index f51cf6bd..e82b760f 100644 --- a/libspu/mpc/cheetah/conversion.cc +++ b/libspu/mpc/cheetah/conversion.cc @@ -16,6 +16,7 @@ #include "yacl/utils/parallel.h" +#include "libspu/core/trace.h" #include "libspu/mpc/ab_api.h" #include "libspu/mpc/cheetah/state.h" #include "libspu/mpc/cheetah/type.h" @@ -82,4 +83,16 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { return out.as(makeType(field)); } +void CommonTypeV::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + + const auto* lhs_v = lhs.as(); + const auto* rhs_v = rhs.as(); + + ctx->setOutput(makeType(std::max(lhs_v->field(), rhs_v->field()))); +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/conversion.h b/libspu/mpc/cheetah/conversion.h index a49424e0..1272824a 100644 --- a/libspu/mpc/cheetah/conversion.h +++ b/libspu/mpc/cheetah/conversion.h @@ -40,4 +40,13 @@ class B2A : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; }; +class CommonTypeV : public Kernel { + public: + static constexpr char kBindName[] = "common_type_v"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; +}; + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/io.cc b/libspu/mpc/cheetah/io.cc index f70beadb..1fe02d47 100644 --- a/libspu/mpc/cheetah/io.cc +++ b/libspu/mpc/cheetah/io.cc @@ -47,10 +47,6 @@ std::vector CheetahIo::toShares(const NdArrayRef& raw, const auto share = raw.as(makeType(field)); return std::vector(world_size_, share); } else if (vis == VIS_SECRET) { -#if !defined(SPU_ENABLE_PRIVATE_TYPE) - owner_rank = -1; -#endif - if (owner_rank >= 0 && owner_rank < static_cast(world_size_)) { // indicates private std::vector shares; diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index 149811cd..79edf22b 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -69,6 +69,7 @@ void regCheetahProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index 71641d78..568a1c18 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -373,11 +373,9 @@ class MatMulVVV : public MatmulKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { SPU_ENFORCE(lhs.eltype() == rhs.eltype()); - if (isOwner(ctx, lhs.eltype())) { - return ring_mmul(lhs, rhs).as(lhs.eltype()); - } else { - return lhs; - } + // For parties other than owner, also do a matmul to make result shape + // correct. + return ring_mmul(lhs, rhs).as(lhs.eltype()); } }; @@ -391,12 +389,9 @@ class MatMulVP : public MatmulKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override { - SPU_ENFORCE(lhs.eltype() == rhs.eltype()); - if (isOwner(ctx, lhs.eltype())) { - return ring_mmul(lhs, rhs).as(lhs.eltype()); - } else { - return lhs; - } + // For parties other than owner, also do a matmul to make result shape + // correct. 
+ return ring_mmul(lhs, rhs).as(lhs.eltype()); } }; diff --git a/libspu/mpc/kernel.cc b/libspu/mpc/kernel.cc index cb737422..723a6de3 100644 --- a/libspu/mpc/kernel.cc +++ b/libspu/mpc/kernel.cc @@ -121,4 +121,22 @@ void CastTypeKernel::evaluate(KernelEvalContext* ctx) const { ctx->setOutput(WrapValue(res)); } +void SimpleSortKernel::evaluate(KernelEvalContext* ctx) const { + auto values = ctx->getParam>(0); + std::vector inputs; + + for (const auto& val : values) { + inputs.emplace_back(UnwrapValue(val)); + } + + auto res = proc(ctx, inputs); + + std::vector res_values; + + for (size_t i = 0; i < res.size(); ++i) { + res_values.emplace_back(WrapValue(std::move(res[i]))); + } + ctx->setOutput(res_values); +} + } // namespace spu::mpc diff --git a/libspu/mpc/kernel.h b/libspu/mpc/kernel.h index a1ce4795..ae5a3de7 100644 --- a/libspu/mpc/kernel.h +++ b/libspu/mpc/kernel.h @@ -120,4 +120,11 @@ class CastTypeKernel : public Kernel { const Type& to_type) const = 0; }; +class SimpleSortKernel : public Kernel { + void evaluate(KernelEvalContext* ctx) const override; + + virtual std::vector proc( + KernelEvalContext* ctx, absl::Span in) const = 0; +}; + } // namespace spu::mpc diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 7a600a40..efcd9afc 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -61,6 +61,28 @@ class Ref2kCommonTypeS : public Kernel { } }; +class Ref2kCommonTypeV : public Kernel { + public: + static constexpr char kBindName[] = "common_type_v"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + SPU_ENFORCE(lhs.isa(), "invalid type, got={}", lhs); + SPU_ENFORCE(rhs.isa(), "invalid type, got={}", rhs); + + const auto* lhs_v = lhs.as(); + const auto* rhs_v = rhs.as(); + + ctx->setOutput( + makeType(std::max(lhs_v->field(), rhs_v->field()))); + } +}; + class Ref2kCastTypeS : public CastTypeKernel { public: static constexpr char kBindName[] = "cast_type_s"; @@ -459,6 +481,8 @@ void regRef2kProtocol(SPUContext* ctx, // register compute kernels ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/securenn/conversion.cc b/libspu/mpc/securenn/conversion.cc index 288f66f9..804ff637 100644 --- a/libspu/mpc/securenn/conversion.cc +++ b/libspu/mpc/securenn/conversion.cc @@ -14,10 +14,12 @@ #include "libspu/mpc/securenn/conversion.h" +#include "libspu/core/trace.h" #include "libspu/core/vectorize.h" #include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/securenn/arithmetic.h" #include "libspu/mpc/securenn/type.h" #include "libspu/mpc/utils/ring_ops.h" @@ -190,4 +192,16 @@ NdArrayRef Msb_a2b::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return res; } +void CommonTypeV::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + + const auto* lhs_v = lhs.as(); + const auto* rhs_v = rhs.as(); + + ctx->setOutput(makeType(std::max(lhs_v->field(), rhs_v->field()))); +} + } // namespace spu::mpc::securenn diff --git a/libspu/mpc/securenn/conversion.h b/libspu/mpc/securenn/conversion.h index 60eacf93..4f4dbd79 100644 --- 
a/libspu/mpc/securenn/conversion.h +++ b/libspu/mpc/securenn/conversion.h @@ -95,4 +95,13 @@ class Msb_a2b : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; +class CommonTypeV : public Kernel { + public: + static constexpr char kBindName[] = "common_type_v"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; +}; + } // namespace spu::mpc::securenn diff --git a/libspu/mpc/securenn/io.cc b/libspu/mpc/securenn/io.cc index f558e0f5..28a8c3c2 100644 --- a/libspu/mpc/securenn/io.cc +++ b/libspu/mpc/securenn/io.cc @@ -46,10 +46,6 @@ std::vector SecurennIo::toShares(const NdArrayRef& raw, const auto share = raw.as(makeType(field)); return std::vector(world_size_, share); } else if (vis == VIS_SECRET) { -#if !defined(SPU_ENABLE_PRIVATE_TYPE) - owner_rank = -1; -#endif - if (owner_rank >= 0 && owner_rank < static_cast(world_size_)) { // indicates private std::vector shares; diff --git a/libspu/mpc/securenn/protocol.cc b/libspu/mpc/securenn/protocol.cc index 416a99c9..41b85be7 100644 --- a/libspu/mpc/securenn/protocol.cc +++ b/libspu/mpc/securenn/protocol.cc @@ -60,6 +60,8 @@ void regSecurennProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel index 9c2f4163..5c273a1a 100644 --- a/libspu/mpc/semi2k/BUILD.bazel +++ b/libspu/mpc/semi2k/BUILD.bazel @@ -82,6 +82,7 @@ spu_cc_library( ":arithmetic", ":boolean", ":conversion", + ":sort", ":state", "//libspu/mpc/common:prg_state", ], @@ -92,6 +93,7 @@ spu_cc_test( srcs = ["protocol_test.cc"], deps = [ ":protocol", + ":sort_test", "//libspu/mpc:ab_api_test", "//libspu/mpc:api_test", "//libspu/mpc/semi2k/beaver/ttp_server:beaver_server", @@ -135,3 +137,35 @@ spu_cc_test( ":type", ], ) + +spu_cc_library( + name = "sort", + srcs = ["sort.cc"], + hdrs = ["sort.h"], + deps = [ + ":state", + ":type", + "//libspu/core:vectorize", + "//libspu/mpc:ab_api", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "sort_test", + testonly = 1, + srcs = ["sort_test.cc"], + hdrs = ["sort_test.h"], + deps = [ + "//libspu/mpc:ab_api", + "//libspu/mpc:api", + "//libspu/mpc:api_test_params", + "//libspu/mpc/utils:permute", + "//libspu/mpc/utils:ring_ops", + "//libspu/mpc/utils:simulate", + "@com_google_googletest//:gtest", + ], + alwayslink = True, +) diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index af775e98..f87e06d2 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -14,10 +14,12 @@ #include "libspu/mpc/semi2k/conversion.h" +#include "libspu/core/trace.h" #include "libspu/core/vectorize.h" #include "libspu/mpc/ab_api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" @@ -183,4 +185,16 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { } } +void CommonTypeV::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + + const auto* lhs_v = lhs.as(); + const auto* rhs_v = rhs.as(); + + 
ctx->setOutput(makeType(std::max(lhs_v->field(), rhs_v->field()))); +} + } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/conversion.h b/libspu/mpc/semi2k/conversion.h index 76596407..dad01a57 100644 --- a/libspu/mpc/semi2k/conversion.h +++ b/libspu/mpc/semi2k/conversion.h @@ -93,4 +93,13 @@ class MsbA2B : public UnaryKernel { NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; +class CommonTypeV : public Kernel { + public: + static constexpr char kBindName[] = "common_type_v"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; +}; + } // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/io.cc b/libspu/mpc/semi2k/io.cc index a01bbb82..ef166fee 100644 --- a/libspu/mpc/semi2k/io.cc +++ b/libspu/mpc/semi2k/io.cc @@ -47,10 +47,6 @@ std::vector Semi2kIo::toShares(const NdArrayRef& raw, const auto share = raw.as(makeType(field)); return std::vector(world_size_, share); } else if (vis == VIS_SECRET) { -#if !defined(SPU_ENABLE_PRIVATE_TYPE) - owner_rank = -1; -#endif - if (owner_rank >= 0 && owner_rank < static_cast(world_size_)) { // indicates private std::vector shares; diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index d8d8f43c..fb5178eb 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -20,6 +20,7 @@ #include "libspu/mpc/semi2k/arithmetic.h" #include "libspu/mpc/semi2k/boolean.h" #include "libspu/mpc/semi2k/conversion.h" +#include "libspu/mpc/semi2k/sort.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" @@ -62,6 +63,8 @@ void regSemi2kProtocol(SPUContext* ctx, } ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); @@ -83,6 +86,7 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); + ctx->prot()->regKernel(); } std::unique_ptr makeSemi2kProtocol( diff --git a/libspu/mpc/semi2k/protocol_test.cc b/libspu/mpc/semi2k/protocol_test.cc index 439b1636..0d8accbd 100644 --- a/libspu/mpc/semi2k/protocol_test.cc +++ b/libspu/mpc/semi2k/protocol_test.cc @@ -19,6 +19,7 @@ #include "libspu/mpc/ab_api_test.h" #include "libspu/mpc/api_test.h" #include "libspu/mpc/semi2k/beaver/ttp_server/beaver_server.h" +#include "libspu/mpc/semi2k/sort_test.h" namespace spu::mpc::test { namespace { @@ -115,4 +116,19 @@ INSTANTIATE_TEST_SUITE_P( ; }); +INSTANTIATE_TEST_SUITE_P( + Semi2k, PermuteTest, + testing::Combine(testing::Values(CreateObjectFn(makeSemi2kProtocol, "tfp"), + CreateObjectFn(makeTTPSemi2kProtocol, + "ttp")), // + testing::Values(makeConfig(FieldType::FM32), // + makeConfig(FieldType::FM64), // + makeConfig(FieldType::FM128)), // + testing::Values(2, 3, 5)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), + std::get<1>(p.param).field(), std::get<2>(p.param)); + ; + }); + } // namespace spu::mpc::test diff --git a/libspu/mpc/semi2k/sort.cc b/libspu/mpc/semi2k/sort.cc new file mode 100644 index 00000000..fa5671c2 --- /dev/null +++ b/libspu/mpc/semi2k/sort.cc @@ -0,0 +1,588 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/semi2k/sort.h" + +#include "libspu/mpc/ab_api.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/utils/permute.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +namespace { + +NdArrayRef wrap_a2b(SPUContext* ctx, const NdArrayRef& x) { + return UnwrapValue(a2b(ctx, WrapValue(x))); +} + +NdArrayRef wrap_a2v(SPUContext* ctx, const NdArrayRef& x, size_t rank) { + return UnwrapValue(a2v(ctx, WrapValue(x), rank)); +} + +NdArrayRef wrap_a2p(SPUContext* ctx, const NdArrayRef& x) { + return UnwrapValue(a2p(ctx, WrapValue(x))); +} + +NdArrayRef wrap_mul_aa(SPUContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + return UnwrapValue(mul_aa(ctx, WrapValue(x), WrapValue(y))); +} + +NdArrayRef wrap_add_aa(SPUContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + return UnwrapValue(add_aa(ctx, WrapValue(x), WrapValue(y))); +} + +NdArrayRef wrap_sub_pa(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + SPU_ENFORCE(x.numel() == y.numel()); + + auto* comm = ctx->getState(); + + if (comm->getRank() == 0) { + return ring_sub(x, y).as(y.eltype()); + } else { + return ring_neg(y).as(y.eltype()); + } +} + +NdArrayRef wrap_sub_ap(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + SPU_ENFORCE(x.numel() == y.numel()); + + auto* comm = ctx->getState(); + + if (comm->getRank() == 0) { + return ring_sub(x, y).as(x.eltype()); + } else { + return x; + } +} + +NdArrayRef wrap_sub_aa(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + SPU_ENFORCE(x.numel() == y.numel()); + SPU_ENFORCE(x.eltype() == y.eltype()); + + return ring_sub(x, y).as(x.eltype()); +} + +// Reference: +// III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) +// +// Analysis: +// Online Latency: 1 (x_xor_r reveal) +// Communication: one element bits for one element +// Vectorization: yes +// +// HighLevel Intuition: +// Since: X = sum: Xi * 2^i +// If we have <Xi>A, then we can construct <X>A = sum: <Xi>A * 2^i. +// +// The problem is that we only have <Xi>B in hand. Details for how to +// construct <Xi>A from <Xi>B: +// - trusted third party chooses a random bit r, where r == 0 or r == 1. +// - trusted third party sends <r>A to parties +// - parties compute <r>B from <r>A +// - parties xor_open c = Xi ^ r = open(<Xi>B ^ <r>B), Xi is still safe due +// to protection from r. +// - parties compute: <x> = c + (1-2c)*<r> +// <Xi>A = 1 - <r>A if c == 1, i.e. Xi != r +// <Xi>A = <r>A if c == 0, i.e. Xi == r +// i.e.
<Xi>A = c + (1-2c) * <r>A +// +// Online Communication: +// = 1 (xor open) + +// Unassemble BShr to AShr bit-by-bit +// Input: BShr +// Return: a vector of k AShr, k is the valid bits of BShr +std::vector B2AUnassemble(KernelEvalContext* ctx, + const NdArrayRef& x) { + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + + const int64_t nbits = x.eltype().as()->nbits(); + SPU_ENFORCE((size_t)nbits > 0 && (size_t)nbits <= SizeOf(field) * 8, + "invalid nbits={}", nbits); + + auto numel = x.numel(); + + auto randbits = beaver->RandBit(field, {numel * static_cast(nbits)}); + + std::vector res; + for (int64_t idx = 0; idx < nbits; ++idx) { + res.emplace_back(makeType(field), x.shape()); + } + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = ring2k_t; + + NdArrayView _randbits(randbits); + NdArrayView _x(x); + + std::vector x_xor_r(numel); + + pforeach(0, numel, [&](int64_t idx) { + // use _r[i*nbits, (i+1)*nbits) to construct rb[i] + U mask = 0; + for (int64_t bit = 0; bit < nbits; ++bit) { + mask += (_randbits[idx * nbits + bit] & 0x1) << bit; + } + x_xor_r[idx] = _x[idx] ^ mask; + }); + + // open c = x ^ r + x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + + pforeach(0, numel, [&](int64_t idx) { + pforeach(0, nbits, [&](int64_t bit) { + NdArrayView _res(res[bit]); + auto c_i = (x_xor_r[idx] >> bit) & 0x1; + if (comm->getRank() == 0) { + _res[idx] = (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]); + } else { + _res[idx] = ((1 - c_i * 2) * _randbits[idx * nbits + bit]); + } + }); + }); + }); + + return res; +} + +// Input: AShare of x +// Output: a vector of AShare of each bit of x +std::vector BitDecompose(KernelEvalContext* ctx, + const NdArrayRef& x) { + auto b = wrap_a2b(ctx->sctx(), x); + return B2AUnassemble(ctx, b); +} + +// TODO(jimi): maybe support multiple keys in future +// Generate vector of bit decomposition +std::vector GenBvVector(KernelEvalContext* ctx, + const NdArrayRef& key) { + std::vector ret; + const auto& t = BitDecompose(ctx, key); + SPU_ENFORCE(t.size() > 0); + ret.insert(ret.end(), t.begin(), t.end() - 1); + const auto field = key.eltype().as()->field(); + ret.emplace_back(wrap_sub_pa(ctx, ring_ones(field, key.shape()), t.back())); + + return ret; +} + +// Secure inverse permutation of x by perm_rank's permutation pv +// The idea here is: +// Input permutation pv, beaver generates perm pair {<A>, <B>} such that +// InversePermute(A, pv) = B. So we can get <y> = InversePermute(open(<x> - +// <A>), pv) + <B> such that y = InversePermute(x, pv).
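// A plaintext model of the identity above, illustrative only and not part
// of the patch: if the beaver deals a pair (A, B) with
// InversePermute(A, pv) == B, then
//   InversePermute(x, pv) == InversePermute(x - A, pv) + B,
// so opening (x - A) to the permutation holder reveals nothing about x.
#include <cstdint>
#include <vector>
std::vector<int64_t> inversePermute(const std::vector<int64_t>& v,
                                    const std::vector<int64_t>& pv) {
  std::vector<int64_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    out[pv[i]] = v[i];  // element i moves to slot pv[i]
  }
  return out;
}
// e.g. inversePermute({7, 8, 9}, {2, 0, 1}) == {8, 9, 7}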
+NdArrayRef SecureInvPerm(KernelEvalContext* ctx, const NdArrayRef& x, + size_t perm_rank, absl::Span pv) { + const auto lctx = ctx->lctx(); + const auto field = x.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + + auto perm_pair = beaver->PermPair(field, x.shape(), perm_rank, pv); + + auto t = wrap_a2v(ctx->sctx(), ring_sub(x, perm_pair.first).as(x.eltype()), + perm_rank); + + if (lctx->Rank() == perm_rank) { + SPU_ENFORCE(pv.size()); + ring_add_(perm_pair.second, applyInvPerm(t, pv)); + return perm_pair.second.as(x.eltype()); + } else { + return perm_pair.second.as(x.eltype()); + } +} + +// Secure inverse permutation of a vector x by perm_rank's permutation pv +std::vector SecureInvPerm(KernelEvalContext* ctx, + absl::Span x, + size_t perm_rank, + absl::Span pv) { + std::vector v; + for (size_t i = 0; i < x.size(); ++i) { + auto t = SecureInvPerm(ctx, x[i], perm_rank, pv); + v.emplace_back(std::move(t)); + } + return v; +} + +// Secure shuffle of a vector x by each party's local permutation pv. +// The shuffle involves multiple rounds of secure inverse permutation. Each +// round uses a local permutation pv generated by one perm_rank. +std::vector Shuffle(KernelEvalContext* ctx, + absl::Span x, + absl::Span pv) { + std::vector v; + const auto lctx = ctx->lctx(); + SPU_ENFORCE(!x.empty(), "inputs should not be empty"); + + for (size_t i = 0; i < lctx->WorldSize(); ++i) { + if (i == 0) { + v = SecureInvPerm(ctx, x, i, pv); + } else { + v = SecureInvPerm(ctx, v, i, pv); + } + } + + return v; +} + +// Secure shuffle x by each party's local permutation pv. +NdArrayRef Shuffle(KernelEvalContext* ctx, const NdArrayRef& x, + absl::Span pv) { + auto vec = Shuffle(ctx, std::vector{x}, pv); + SPU_ENFORCE(vec.size() > 0); + return vec[0]; +} + +// Inverse a securely shuffled perm on shuffled x. +// x is a list of shared bit vectors, <perm> is a shared permutation, pv is +// a locally generated random permutation for secure shuffle, and random_perm +// is the revealed permutation of shuffled <perm>. +// +// The steps are as follows: +// 1) secure shuffle <perm> as <sp> +// 2) secure shuffle <x> as <sx> +// 3) reveal securely shuffled <sp> as random_perm +// 4) inverse permute <sx> by random_perm and return +std::vector InvShuffledPerm(KernelEvalContext* ctx, + absl::Span x, + const NdArrayRef& perm, + PermVector* random_perm, + absl::Span pv) { + // 1. <sp> = secure shuffle <perm> + auto sp = Shuffle(ctx, perm, pv); + + // 2. <sx> = secure shuffle <x> + auto sx = Shuffle(ctx, x, pv); + + // 3. m = reveal(<sp>) + auto m = wrap_a2p(ctx->sctx(), sp); + SPU_ENFORCE_EQ(m.shape().ndim(), 1U, "perm should be 1-d tensor"); + auto size = m.shape()[0]; + const auto field = m.eltype().as()->field(); + PermVector perm_vector(size); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _m(m); + pforeach(0, size, + [&](int64_t idx) { perm_vector[idx] = (int64_t)_m[idx]; }); + }); + SPU_ENFORCE(random_perm != nullptr); + *random_perm = perm_vector; + + // 4. <v> = InversePermute(<sx>, m) + std::vector v; + for (size_t i = 0; i < sx.size(); ++i) { + auto t = applyInvPerm(sx[i], perm_vector); + v.emplace_back(std::move(t)); + } + + return v; +} + +// Inverse a securely shuffled perm on shuffled x.
+NdArrayRef InvShuffledPerm(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& perm, PermVector* random_perm, + absl::Span pv) { + std::vector v{x}; + auto vec = InvShuffledPerm(ctx, v, perm, random_perm, pv); + SPU_ENFORCE(vec.size() > 0); + return vec[0]; +} + +// Process two bit vectors in one loop +// Reference: https://eprint.iacr.org/2019/695.pdf (5.2 Optimizations) +// +// perm = GenInvPermByTwoBitVectors(x, y) +// input: bit vector x, bit vector y +// bit vector y is more significant than x +// output: shared inverse permutation +// +// We can generate the inverse permutation from two bit vectors in one loop. +// It needs one extra mul op and 2x the memory for intermediate data compared +// to GenInvPermByBitVector. But the number of invocations of +// permutation-related protocols such as SecureInvPerm or Compose will be +// reduced to half. +// +// If we process three bit vectors in one loop, it needs at least four extra +// mul ops and 2^2 times the memory for intermediate data. The number of +// invocations of permutation-related protocols such as SecureInvPerm or +// Compose will be reduced to 1/3. It's latency friendly but not bandwidth +// friendly. +// +// Example: +// 1) x = [0, 1], y = [1, 0] +// 2) rev_x = [1, 0], rev_y = [0, 1] +// 3) f0 = rev_x * rev_y = [0, 0] +// f1 = x * rev_y = [0, 1] +// f2 = rev_x * y = [1, 0] +// f3 = x * y = [0, 0] +// f = [f0, f1, f2, f3] = [0, 0, 0, 1, 1, 0, 0, 0] +// 4) s[i] = s[i - 1] + f[i], s[0] = f[0] +// s = [0, 0, 0, 1, 2, 2, 2, 2] +// 5) fs = f * s +// fs = [0, 0, 0, 1, 2, 0, 0, 0] +// 6) split fs into four vectors +// fsv[0] = [0, 0] +// fsv[1] = [0, 1] +// fsv[2] = [2, 0] +// fsv[3] = [0, 0] +// 7) r = fsv[0] + fsv[1] + fsv[2] + fsv[3] +// r = [2, 1] +// 8) get res by subtracting one from r +// res = [1, 0] +NdArrayRef GenInvPermByTwoBitVectors(KernelEvalContext* ctx, + const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE(x.shape() == y.shape(), "x and y should have the same shape"); + SPU_ENFORCE(x.shape().ndim() == 1, "x and y should be 1-d"); + + const auto field = x.eltype().as()->field(); + const int64_t numel = x.numel(); + auto ones = ring_ones(field, x.shape()); + auto rev_x = wrap_sub_pa(ctx, ones, x); + auto rev_y = wrap_sub_pa(ctx, ones, y); + auto f0 = wrap_mul_aa(ctx->sctx(), rev_x, rev_y); + auto f1 = wrap_sub_aa(ctx, rev_y, f0); + auto f2 = wrap_sub_aa(ctx, rev_x, f0); + auto f3 = wrap_sub_aa(ctx, y, f2); + + Shape new_shape = {1, numel}; + auto f = f0.reshape(new_shape).concatenate( + {f1.reshape(new_shape), f2.reshape(new_shape), f3.reshape(new_shape)}, 1); + auto s = f.clone(); + + // calculate prefix sum + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _s(s); + for (int64_t i = 1; i < s.numel(); ++i) { + _s[i] += _s[i - 1]; + } + }); + + // mul f and s + auto fs = wrap_mul_aa(ctx->sctx(), f, s); + + auto fs0 = fs.slice({0, 0}, {1, numel}, {}); + auto fs1 = fs.slice({0, numel}, {1, 2 * numel}, {}); + auto fs2 = fs.slice({0, 2 * numel}, {1, 3 * numel}, {}); + auto fs3 = fs.slice({0, 3 * numel}, {1, 4 * numel}, {}); + + // calculate result + auto s01 = wrap_add_aa(ctx->sctx(), fs0, fs1); + auto s23 = wrap_add_aa(ctx->sctx(), fs2, fs3); + auto r = wrap_add_aa(ctx->sctx(), s01, s23); + auto res = wrap_sub_ap(ctx, r.reshape(x.shape()), ones); + + return res; +} + +// Generate perm by bit vector +// input: bit vector generated by bit decomposition +// output: shared inverse permutation +// +// Example: +// 1) x = [1, 0, 1, 0, 0] +// 2) rev_x = [0, 1, 0, 1, 1] +// 3) f = [rev_x, x] +// f = [0, 1, 0, 1, 1, 1, 0, 1, 0, 0]
+// 4) s[i] = s[i - 1] + f[i], s[0] = f[0] +// s = [0, 1, 1, 2, 3, 4, 4, 5, 5, 5] +// 5) fs = f * s +// fs = [0, 1, 0, 2, 3, 4, 0, 5, 0, 0] +// 6) split fs into two vectors +// fsv[0] = [0, 1, 0, 2, 3] +// fsv[1] = [4, 0, 5, 0, 0] +// 7) r = fsv[0] + fsv[1] +// r = [4, 1, 5, 2, 3] +// 8) get res by subtracting one from r +// res = [3, 0, 4, 1, 2] +NdArrayRef GenInvPermByBitVector(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.shape().ndim() == 1, "x should be 1-d"); + + const auto field = x.eltype().as()->field(); + const int64_t numel = x.numel(); + auto ones = ring_ones(field, x.shape()); + auto rev_x = wrap_sub_pa(ctx, ones, x); + + Shape new_shape = {1, numel}; + auto f = rev_x.reshape(new_shape).concatenate({x.reshape(new_shape)}, 1); + auto s = f.clone(); + + // calculate prefix sum + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _s(s); + for (int64_t i = 1; i < s.numel(); ++i) { + _s[i] += _s[i - 1]; + } + }); + + // mul f and s + auto fs = wrap_mul_aa(ctx->sctx(), f, s); + + auto fs0 = fs.slice({0, 0}, {1, numel}, {}); + auto fs1 = fs.slice({0, numel}, {1, 2 * numel}, {}); + + // calculate result + auto r = wrap_add_aa(ctx->sctx(), fs0, fs1); + auto res = wrap_sub_ap(ctx, r.reshape(x.shape()), ones); + return res; +} + +// The inverse of secure shuffle +NdArrayRef Unshuffle(KernelEvalContext* ctx, const NdArrayRef& x, + absl::Span pv) { + const auto lctx = ctx->lctx(); + NdArrayRef ret(x); + + auto inv_pv = genInversePerm(pv); + for (int i = lctx->WorldSize() - 1; i >= 0; --i) { + ret = SecureInvPerm(ctx, ret, i, inv_pv); + } + + return ret; +} + +// This is the inverse of InvShuffledPerm. +// The input is a shared inverse permutation <perm>, a permutation public_pv +// known to all parties, and a locally generated permutation private_pv for +// secure unshuffle. +// +// The steps are as follows: +// 1) permute <perm> by public_pv as <sm> +// 2) secure unshuffle <sm> and return the result +// +// By doing InvShuffledPerm and UnshufflePerm, we get the shared inverse +// permutation of the initial shared bit vectors. +NdArrayRef UnshufflePerm(KernelEvalContext* ctx, const NdArrayRef& perm, + absl::Span public_pv, + absl::Span private_pv) { + auto sm = applyPerm(perm, public_pv); + auto res = Unshuffle(ctx, sm, private_pv); + return res; +} + +// Generate shared inverse permutation by key +NdArrayRef GenInvPerm(KernelEvalContext* ctx, const NdArrayRef& key) { + // key should be a 1-D tensor + SPU_ENFORCE(key.shape().ndim() == 1, "key should be 1-d"); + const auto field = key.eltype().as()->field(); + auto numel = key.numel(); + + // 1. generate bit decomposition vector of key + std::vector v = GenBvVector(ctx, key); + SPU_ENFORCE_GT(v.size(), 0U); + + // 2. generate natural permutation + NdArrayRef s(key.eltype(), key.shape()); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _s(s); + pforeach(0, numel, [&](int64_t idx) { + _s[idx] = (ctx->lctx()->Rank() == 0) ? idx : 0; + }); + }); + + // 3.
+
+// Generate the shared inverse permutation by key
+NdArrayRef GenInvPerm(KernelEvalContext* ctx, const NdArrayRef& key) {
+  // key should be a 1-d tensor
+  SPU_ENFORCE(key.shape().ndim() == 1, "key should be 1-d");
+  const auto field = key.eltype().as<Ring2k>()->field();
+  auto numel = key.numel();
+
+  // 1. generate the bit decomposition vector of key
+  std::vector<NdArrayRef> v = GenBvVector(ctx, key);
+  SPU_ENFORCE_GT(v.size(), 0U);
+
+  // 2. generate the natural permutation
+  NdArrayRef s(key.eltype(), key.shape());
+  DISPATCH_ALL_FIELDS(field, "_", [&]() {
+    NdArrayView<ring2k_t> _s(s);
+    pforeach(0, numel, [&](int64_t idx) {
+      _s[idx] = (ctx->lctx()->Rank() == 0) ? idx : 0;
+    });
+  });
+
+  // 3. generate the shared inverse permutation from the bit vectors,
+  //    processing two of them per iteration
+  PermVector random_perm;
+  size_t v_size = v.size();
+  size_t v_idx = 0;
+  for (; v_idx < v_size - 1; v_idx += 2) {
+    auto pv = genRandomPerm(s.shape()[0]);
+    auto t = InvShuffledPerm(
+        ctx, std::vector<NdArrayRef>{v[v_idx], v[v_idx + 1]}, s, &random_perm,
+        pv);
+    auto perm = GenInvPermByTwoBitVectors(ctx, t[0], t[1]);
+    s = UnshufflePerm(ctx, perm, random_perm, pv);
+  }
+
+  if (v_idx == v_size - 1) {
+    auto pv = genRandomPerm(s.shape()[0]);
+    auto t = InvShuffledPerm(ctx, v[v_idx], s, &random_perm, pv);
+    auto perm = GenInvPermByBitVector(ctx, t);
+    s = UnshufflePerm(ctx, perm, random_perm, pv);
+  }
+
+  return s;
+}
+
+// Apply the inverse permutation on each tensor of x by a shared inverse
+// permutation <perm>
+std::vector<NdArrayRef> ApplyInvPerm(KernelEvalContext* ctx,
+                                     absl::Span<const NdArrayRef> x,
+                                     const NdArrayRef& perm) {
+  // sanity check.
+  SPU_ENFORCE(!x.empty(), "inputs should not be empty");
+  SPU_ENFORCE(x[0].shape().ndim() == 1,
+              "inputs should be 1-d but actually have {} dimensions",
+              x[0].shape().ndim());
+  SPU_ENFORCE(std::all_of(x.begin(), x.end(),
+                          [&x](const NdArrayRef& input) {
+                            return input.shape() == x[0].shape();
+                          }),
+              "inputs shape mismatched");
+
+  // 1. <SP> = secure shuffle <perm>
+  auto pv = genRandomPerm(x[0].shape()[0]);
+  auto sp = Shuffle(ctx, perm, pv);
+
+  // 2. <SX> = secure shuffle <x>
+  auto sx = Shuffle(ctx, x, pv);
+
+  // 3. M = reveal(<SP>)
+  auto m = wrap_a2p(ctx->sctx(), sp);
+  SPU_ENFORCE_EQ(m.shape().ndim(), 1U, "perm should be 1-d tensor");
+  auto size = m.shape()[0];
+  const auto field = m.eltype().as<Ring2k>()->field();
+  PermVector perm_vector(size);
+  DISPATCH_ALL_FIELDS(field, "_", [&]() {
+    NdArrayView<ring2k_t> _m(m);
+    pforeach(0, size, [&](int64_t idx) {
+      perm_vector[idx] = static_cast<int64_t>(_m[idx]);
+    });
+  });
+
+  // 4. <T> = SP(<SX>)
+  std::vector<NdArrayRef> v;
+
+  for (size_t i = 0; i < sx.size(); ++i) {
+    auto t = applyInvPerm(sx[i], perm_vector);
+    v.emplace_back(std::move(t));
+  }
+
+  return v;
+}
+
+}  // namespace
+
+// Radix sort
+// Ref:
+//   https://eprint.iacr.org/2019/695.pdf
+// in[0] is the key; each tensor of in is a 1-d tensor
+std::vector<NdArrayRef> SimpleSortA::proc(
+    KernelEvalContext* ctx, absl::Span<const NdArrayRef> in) const {
+  SPU_ENFORCE(!in.empty(), "inputs should not be empty");
+
+  auto perm = GenInvPerm(ctx, in[0]);
+  auto res = ApplyInvPerm(ctx, in, perm);
+  return res;
+}
+
+}  // namespace spu::mpc::semi2k
\ No newline at end of file
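To make the bit-vector trick and the radix-sort loop above concrete, here is a cleartext sketch. It is illustration only, not part of the patch: the `*Clear` helper names are invented, everything runs on plain integers, and the real kernel performs the same arithmetic over additive shares while hiding every intermediate permutation behind a secure shuffle.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// Prefix-sum trick from the comments above: f = [1 - x, x],
// s = prefix_sum(f), res[i] = f[i] * s[i] + f[n + i] * s[n + i] - 1.
// res is the inverse permutation that stably moves 0-bits before 1-bits.
Vec GenInvPermByBitVectorClear(const Vec& x) {
  const size_t n = x.size();
  Vec f(2 * n);
  for (size_t i = 0; i < n; ++i) {
    f[i] = 1 - x[i];  // indicator of a 0-bit
    f[n + i] = x[i];  // indicator of a 1-bit
  }
  Vec s(f);
  for (size_t i = 1; i < s.size(); ++i) s[i] += s[i - 1];
  Vec res(n);
  for (size_t i = 0; i < n; ++i) {
    res[i] = f[i] * s[i] + f[n + i] * s[n + i] - 1;
  }
  return res;
}

int main() {
  // the worked example from the comment: x = [1, 0, 1, 0, 0]
  assert((GenInvPermByBitVectorClear({1, 0, 1, 0, 0}) == Vec{3, 0, 4, 1, 2}));

  // radix sort: stably partition by each bit, least significant first
  Vec keys = {5, 1, 7, 2, 4};
  for (int b = 0; b < 3; ++b) {
    Vec bit(keys.size());
    for (size_t i = 0; i < keys.size(); ++i) bit[i] = (keys[i] >> b) & 1;
    Vec inv = GenInvPermByBitVectorClear(bit);
    Vec next(keys.size());
    // applyInvPerm semantics: next[inv[i]] = keys[i]
    for (size_t i = 0; i < keys.size(); ++i) next[inv[i]] = keys[i];
    keys = next;
  }
  assert((keys == Vec{1, 2, 4, 5, 7}));
}
```

The two-bit-vector variant is the same idea with four indicator blocks instead of two, which is what halves the number of secure permutation invocations per key bit.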
diff --git a/libspu/mpc/semi2k/sort.h b/libspu/mpc/semi2k/sort.h
new file mode 100644
index 00000000..d8dca534
--- /dev/null
+++ b/libspu/mpc/semi2k/sort.h
@@ -0,0 +1,29 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "libspu/mpc/kernel.h"
+
+namespace spu::mpc::semi2k {
+
+class SimpleSortA : public SimpleSortKernel {
+ public:
+  static constexpr char kBindName[] = "sort_a";
+
+  std::vector<NdArrayRef> proc(KernelEvalContext* ctx,
+                               absl::Span<const NdArrayRef> in) const override;
+};
+
+}  // namespace spu::mpc::semi2k
\ No newline at end of file
diff --git a/libspu/mpc/semi2k/sort_test.cc b/libspu/mpc/semi2k/sort_test.cc
new file mode 100644
index 00000000..76c39a94
--- /dev/null
+++ b/libspu/mpc/semi2k/sort_test.cc
@@ -0,0 +1,70 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "libspu/mpc/semi2k/sort_test.h"
+
+#include "libspu/core/ndarray_ref.h"
+#include "libspu/core/prelude.h"
+#include "libspu/mpc/ab_api.h"
+#include "libspu/mpc/api.h"
+#include "libspu/mpc/utils/permute.h"
+#include "libspu/mpc/utils/ring_ops.h"
+#include "libspu/mpc/utils/simulate.h"
+
+namespace spu::mpc::test {
+
+namespace {
+
+#define EXPECT_VALUE_EQ(X, Y)                            \
+  {                                                      \
+    EXPECT_EQ((X).shape(), (Y).shape());                 \
+    EXPECT_TRUE(ring_all_equal((X).data(), (Y).data())); \
+  }
+
+Shape kShape = {20};
+const int64_t kInputsSize = 10;
+
+}  // namespace
+
+TEST_P(PermuteTest, RadixSort) {
+  const auto factory = std::get<0>(GetParam());
+  const RuntimeConfig& conf = std::get<1>(GetParam());
+  const size_t npc = std::get<2>(GetParam());
+
+  utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
+    auto obj = factory(conf, lctx);
+
+    /* GIVEN */
+    std::vector<Value> in_p(kInputsSize);
+    std::vector<Value> in_s(kInputsSize);
+    for (int64_t i = 0; i < kInputsSize; ++i) {
+      in_p[i] = rand_p(obj.get(), kShape);
+      in_s[i] = p2a(obj.get(), in_p[i]);
+    }
+
+    /* WHEN */
+    auto sorted_s = dynDispatch<std::vector<Value>>(obj.get(), "sort_a", in_s);
+
+    /* THEN */
+    const auto perm = genInversePerm(genPermBySort(UnwrapValue(in_p[0])));
+
+    for (int64_t i = 0; i < kInputsSize; ++i) {
+      auto expected_sorted = applyInvPerm(UnwrapValue(in_p[i]), perm);
+      auto actual_sorted = a2p(obj.get(), sorted_s[i]);
+      EXPECT_VALUE_EQ(WrapValue(expected_sorted), actual_sorted);
+    }
+  });
+}
+
+}  // namespace spu::mpc::test
\ No newline at end of file
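The `THEN` block rebuilds the expected ordering in the clear via `genInversePerm(genPermBySort(...))` followed by `applyInvPerm`. A standalone check of the identity the test relies on (invented names, plain integers; the library versions operate on `NdArrayRef`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

using Vec = std::vector<int64_t>;

int main() {
  Vec x = {30, 10, 40, 20};

  // genPermBySort: pv such that x[pv[0]] <= x[pv[1]] <= ... (stable)
  Vec pv(x.size());
  std::iota(pv.begin(), pv.end(), 0);
  std::stable_sort(pv.begin(), pv.end(),
                   [&x](int64_t a, int64_t b) { return x[a] < x[b]; });

  // genInversePerm: inv[pv[i]] = i, i.e. inv[j] is the rank of x[j]
  Vec inv(pv.size());
  for (size_t i = 0; i < pv.size(); ++i) inv[pv[i]] = i;

  // applyInvPerm: y[inv[i]] = x[i] sends every element to its rank
  Vec y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[inv[i]] = x[i];

  assert(std::is_sorted(y.begin(), y.end()));  // y == {10, 20, 30, 40}
}
```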
diff --git a/libspu/mpc/semi2k/sort_test.h b/libspu/mpc/semi2k/sort_test.h
new file mode 100644
index 00000000..748ae01e
--- /dev/null
+++ b/libspu/mpc/semi2k/sort_test.h
@@ -0,0 +1,24 @@
+// Copyright 2023 Ant Group Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "yacl/link/link.h"
+
+#include "libspu/mpc/api_test_params.h"
+
+namespace spu::mpc::test {
+
+class PermuteTest : public ::testing::TestWithParam<OpTestParams> {};
+
+}  // namespace spu::mpc::test
diff --git a/libspu/mpc/utils/linalg.cc b/libspu/mpc/utils/linalg.cc
index 478c44cc..4c9d956e 100644
--- a/libspu/mpc/utils/linalg.cc
+++ b/libspu/mpc/utils/linalg.cc
@@ -23,4 +23,4 @@ void setEigenParallelLevel(int64_t expected_threads) {
   Eigen::setNbThreads(nproc);
 }
 
-}  // namespace spu::mpc::linalg::detail
\ No newline at end of file
+}  // namespace spu::mpc::linalg::detail
diff --git a/libspu/mpc/utils/linalg.h b/libspu/mpc/utils/linalg.h
index ea50b101..9fb88578 100644
--- a/libspu/mpc/utils/linalg.h
+++ b/libspu/mpc/utils/linalg.h
@@ -96,4 +96,4 @@ void matmul(int64_t M, int64_t N, int64_t K, const T* A, int64_t LDA,
   c.noalias() = a * b;
 }
 
-}  // namespace spu::mpc::linalg
+}  // namespace spu::mpc::linalg
\ No newline at end of file
diff --git a/libspu/mpc/utils/permute.cc b/libspu/mpc/utils/permute.cc
index 396888ca..85bf06f6 100644
--- a/libspu/mpc/utils/permute.cc
+++ b/libspu/mpc/utils/permute.cc
@@ -37,6 +37,21 @@ NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span<const int64_t> pv) {
   return y;
 }
 
+NdArrayRef applyPerm(const NdArrayRef& x, absl::Span<const int64_t> pv) {
+  SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor");
+
+  NdArrayRef y(x.eltype(), x.shape());
+  const auto field = x.eltype().as<Ring2k>()->field();
+  DISPATCH_ALL_FIELDS(field, kPermModule, [&]() {
+    NdArrayView<ring2k_t> _x(x);
+    NdArrayView<ring2k_t> _y(y);
+    for (int64_t i = 0; i < y.numel(); i++) {
+      _y[i] = _x[pv[i]];
+    }
+  });
+  return y;
+}
+
 PermVector genRandomPerm(size_t size) {
   PermVector perm(size);
   std::iota(perm.begin(), perm.end(), 0);
@@ -47,4 +62,27 @@ PermVector genRandomPerm(size_t size) {
   return perm;
 }
 
+PermVector genInversePerm(absl::Span<const int64_t> pv) {
+  PermVector ret(pv.size());
+  for (size_t i = 0; i < pv.size(); ++i) {
+    ret[pv[i]] = i;
+  }
+  return ret;
+}
+
+PermVector genPermBySort(const NdArrayRef& x) {
+  SPU_ENFORCE_EQ(x.shape().ndim(), 1U, "x should be 1-d tensor");
+  PermVector perm(x.shape()[0]);
+  std::iota(perm.begin(), perm.end(), 0);
+  const auto field = x.eltype().as<Ring2k>()->field();
+  DISPATCH_ALL_FIELDS(field, kPermModule, [&]() {
+    using T = std::make_signed_t<ring2k_t>;
+
+    NdArrayView<T> _x(x);
+    auto cmp = [&_x](int64_t a, int64_t b) { return _x[a] < _x[b]; };
+    std::stable_sort(perm.begin(), perm.end(), cmp);
+  });
+  return perm;
+}
+
 }  // namespace spu::mpc
\ No newline at end of file
diff --git a/libspu/mpc/utils/permute.h b/libspu/mpc/utils/permute.h
index a75e983b..d076f9f7 100644
--- a/libspu/mpc/utils/permute.h
+++ b/libspu/mpc/utils/permute.h
@@ -24,8 +24,17 @@ using PermVector = std::vector<int64_t>;
 
 PermVector genRandomPerm(size_t size);
 
+PermVector genInversePerm(absl::Span<const int64_t> pv);
+
+// generate a permutation vector that makes x ordered
+PermVector genPermBySort(const NdArrayRef& x);
+
 // reorder 1-d tensor element by applying inverse permutation.
 // ret = ApplyInvPerm(x, pv) -> ret[pv[i]] = x[i]
 NdArrayRef applyInvPerm(const NdArrayRef& x, absl::Span<const int64_t> pv);
 
+// reorder 1-d tensor element by applying permutation.
+// ret = ApplyPerm(x, pv) -> ret[i] = x[pv[i]]
+NdArrayRef applyPerm(const NdArrayRef& x, absl::Span<const int64_t> pv);
+
 }  // namespace spu::mpc
\ No newline at end of file
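The contracts documented in permute.h compose in two ways that the sort kernel depends on. A cleartext check of both (illustration only; the `*Clear` helpers are invented stand-ins for the `NdArrayRef`-based library functions):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

Vec ApplyPermClear(const Vec& x, const Vec& pv) {     // ret[i] = x[pv[i]]
  Vec ret(x.size());
  for (size_t i = 0; i < x.size(); ++i) ret[i] = x[pv[i]];
  return ret;
}

Vec ApplyInvPermClear(const Vec& x, const Vec& pv) {  // ret[pv[i]] = x[i]
  Vec ret(x.size());
  for (size_t i = 0; i < x.size(); ++i) ret[pv[i]] = x[i];
  return ret;
}

Vec GenInversePermClear(const Vec& pv) {              // ret[pv[i]] = i
  Vec ret(pv.size());
  for (size_t i = 0; i < pv.size(); ++i) ret[pv[i]] = i;
  return ret;
}

int main() {
  const Vec x = {10, 20, 30, 40};
  const Vec pv = {2, 0, 3, 1};

  // applyPerm and applyInvPerm undo each other...
  assert(ApplyInvPermClear(ApplyPermClear(x, pv), pv) == x);
  // ...and applying the inverse permutation equals inverse-applying:
  assert(ApplyPermClear(x, GenInversePermClear(pv)) ==
         ApplyInvPermClear(x, pv));
}
```

The first identity is why `Unshuffle` can cancel a shuffle by walking the ranks in reverse with inverse permutations; the second is what lets the sort test build its expected ordering either way.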
diff --git a/libspu/spu.proto b/libspu/spu.proto
index 1800fd55..e35ea089 100644
--- a/libspu/spu.proto
+++ b/libspu/spu.proto
@@ -51,6 +51,7 @@ enum Visibility {
   VIS_INVALID = 0;
   VIS_SECRET = 1;   // Invisible(unknown) for all or some of the parties.
   VIS_PUBLIC = 2;   // Visible(public) for all parties.
+  VIS_PRIVATE = 3;  // Visible to only one party.
 }
 
 // @exclude
@@ -318,6 +319,8 @@ message RuntimeConfig {
   bool experimental_disable_vectorization = 103;
   // inter op concurrency.
   uint64 experimental_inter_op_concurrency = 104;
+  // Enable the use of private types.
+  bool experimental_enable_colocated_optimization = 105;
 }
 
 message TTPBeaverConfig {
diff --git a/spu/libspu.cc b/spu/libspu.cc
index 2e68f184..69114dbc 100644
--- a/spu/libspu.cc
+++ b/spu/libspu.cc
@@ -743,8 +743,10 @@ PYBIND11_MODULE(libspu, m) {
   // bind spu io suite.
   py::class_<IoWrapper>(m, "IoWrapper", "SPU VM IO")
       .def(py::init<size_t, const std::string&>())
-      .def("MakeShares", &IoWrapper::MakeShares)
-      .def("GetShareChunkCount", &IoWrapper::GetShareChunkCount)
+      .def("MakeShares", &IoWrapper::MakeShares, "Create secret shares",
+           py::arg("arr"), py::arg("visibility"), py::arg("owner_rank") = -1)
+      .def("GetShareChunkCount", &IoWrapper::GetShareChunkCount, py::arg("arr"),
+           py::arg("visibility"), py::arg("owner_rank") = -1)
       .def("Reconstruct", &IoWrapper::Reconstruct);
 
   // bind compiler.
diff --git a/spu/tests/spu_io_test.py b/spu/tests/spu_io_test.py
index 9953e4e0..d88eefd4 100644
--- a/spu/tests/spu_io_test.py
+++ b/spu/tests/spu_io_test.py
@@ -286,6 +286,30 @@
 
         npt.assert_almost_equal(x, y, decimal=5)
 
+    def test_colocated_io(self, wsize, prot, field, chunk_size):
+        if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3:
+            return
+
+        if prot == spu_pb2.ProtocolKind.REF2K:
+            return
+
+        config = spu_pb2.RuntimeConfig(
+            protocol=prot,
+            field=field,
+            share_max_chunk_size=chunk_size,
+            experimental_enable_colocated_optimization=True,
+        )
+        io = ppapi.Io(wsize, config)
+
+        # PrivINT
+        x = np.random.randint(10, size=())
+
+        xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET, owner_rank=1)
+        self.assertIn('Priv2k', _bytes_to_pb(xs[0].meta).storage_type)
+        y = io.reconstruct(xs)
+
+        npt.assert_equal(x, y)
+
 
 if __name__ == '__main__':
     unittest.main()
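For intuition about the new `VIS_PRIVATE` visibility exercised by `test_colocated_io`, here is a toy sketch of the intended semantics. It is not SPU's actual `Priv2k` storage type, just an invented illustration: under the colocated optimization, a private value lives in the clear at exactly one owner rank, so creating and reconstructing it needs no secret-sharing arithmetic.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy stand-in (not SPU's real representation) for a colocated private value.
struct PrivateView {
  int64_t owner_rank;              // the only party holding the payload
  std::optional<int64_t> payload;  // engaged on the owner, empty elsewhere
};

// Each party builds its local view of a private input owned by owner_rank.
PrivateView MakePrivate(int64_t my_rank, int64_t owner_rank, int64_t value) {
  if (my_rank == owner_rank) {
    return {owner_rank, value};
  }
  return {owner_rank, std::nullopt};
}

// Reconstruction just reads the owner's payload back.
int64_t Reconstruct(const std::vector<PrivateView>& views) {
  for (const auto& v : views) {
    if (v.payload.has_value()) {
      return *v.payload;
    }
  }
  return 0;  // unreachable for well-formed views
}

int main() {
  const int64_t kWorldSize = 3;
  std::vector<PrivateView> views;
  for (int64_t rank = 0; rank < kWorldSize; ++rank) {
    views.push_back(MakePrivate(rank, /*owner_rank=*/1, /*value=*/42));
  }
  assert(Reconstruct(views) == 42);
}
```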