diff --git a/.gitignore b/.gitignore index 0d770b1ec..e5e993922 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ callgrind.out.* Programs/Bytecode/* Programs/Schedules/* Programs/Public-Input/* +Programs/Functions *.com *.class *.dll diff --git a/.gitmodules b/.gitmodules index 9307e3292..84fbcef5f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "deps/SimplestOT_C"] path = deps/SimplestOT_C url = https://github.com/mkskeller/SimplestOT_C +[submodule "deps/sse2neon"] + path = deps/sse2neon + url = https://github.com/DLTcollab/sse2neon diff --git a/CHANGELOG.md b/CHANGELOG.md index adfd96802..19f864f50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.4.0 (November 21, 2024) + +- Functionality to call high-level code from C++ +- Matrix triples from file for all appropriate protocols +- Exit with message on errors instead of uncaught exceptions +- Reduce memory usage for binary memory +- Optimized cint-regint conversion in Dealer protocol +- Fixed security bug: missing MAC check in probabilistic truncation + ## 0.3.9 (July 9, 2024) - Inference with non-sequential PyTorch networks diff --git a/CONFIG b/CONFIG index 90da7e8f0..0c92c9330 100644 --- a/CONFIG +++ b/CONFIG @@ -71,6 +71,8 @@ CXX = clang++ # use CONFIG.mine to overwrite DIR settings -include CONFIG.mine +AVX_SIMPLEOT := $(AVX_OT) + ifeq ($(USE_GF2N_LONG),1) GF2N_LONG = -DUSE_GF2N_LONG endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 52b4db4cb..c69289fb3 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -540,8 +540,8 @@ class split(base.Instruction): :param: number of arguments to follow (number of bits times number of additive shares plus one) :param: source (sint) - :param: first share of least significant bit - :param: second share of least significant bit + :param: first share of least significant bit (sbit) + :param: second share of least significant bit (sbit) :param: (remaining share of least significant bit)... :param: (repeat from first share for bit one step higher)... """ diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 8c2050199..37626c112 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -737,6 +737,7 @@ def get_type(cls, n): :py:obj:`v` and the columns by calling :py:obj:`elements`. """ class sbitvecn(cls, _structure): + n_bits = n @staticmethod def get_type(n): return cls.get_type(n) @@ -757,17 +758,19 @@ def get_input_from(cls, player, size=1, f=0): :param: player (int) """ - v = [0] * n sbits._check_input_player(player) instructions_base.check_vector_size(size) - for i in range(size): - vv = [sbit() for i in range(n)] - inst.inputbvec(n + 3, f, player, *vv) - for j in range(n): - tmp = vv[j] << i - v[j] = tmp ^ v[j] - sbits._check_input_player(player) - return cls.from_vec(v) + if size == 1: + res = cls.from_vec(sbit() for i in range(n)) + inst.inputbvec(n + 3, f, player, *res.v) + return res + else: + elements = [] + for i in range(size): + v = sbits.get_type(n)() + inst.inputb(player, n, f, v) + elements.append(v) + return cls(elements) get_raw_input_from = get_input_from @classmethod def from_vec(cls, vector): diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 2b2c0ea1d..520c8d075 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -178,6 +178,8 @@ def alloc_reg(self, reg, free): dup = dup.vectorbase self.alloc[dup] = self.alloc[base] dup.i = self.alloc[base] + if not dup.dup_count: + dup.dup_count = len(base.duplicates) def dealloc_reg(self, reg, inst, free): if reg.vector: @@ -275,8 +277,9 @@ def finalize(self, options): for reg in self.alloc: for x in reg.get_all(): if x not in self.dealloc and reg not in self.dealloc \ - and len(x.duplicates) == 0: - print('Warning: read before write at register', x) + and len(x.duplicates) == x.dup_count: + print('Warning: read before write at register %s/%x' % + (x, id(x))) print('\tregister trace: %s' % format_trace(x.caller, '\t\t')) if options.stop: @@ -750,6 +753,8 @@ def eliminate(i): G.remove_node(i) merge_nodes.discard(i) stats[type(instructions[i]).__name__] += 1 + for reg in instructions[i].get_def(): + self.block.parent.program.base_addresses.pop(reg) instructions[i] = None if unused_result: eliminate(i) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 8a918f4ec..8c4f41b8d 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -13,8 +13,18 @@ class Compiler: + singleton = None + def __init__(self, custom_args=None, usage=None, execute=False, split_args=False): + if Compiler.singleton: + raise CompilerError( + "Cannot have more than one compiler instance. " + "It's not possible to run direct compilation programs with " + "compile.py or compile-run.py.") + else: + Compiler.singleton = self + if usage: self.usage = usage else: diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 286202466..c71f45ad4 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -1,3 +1,8 @@ +""" This module implements `Dijkstra's algorithm +`_ based on +oblivious RAM. """ + + from Compiler.oram import * from Compiler.program import Program @@ -222,7 +227,21 @@ def dump(self, msg=''): print_ln() print_ln() -def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None): +def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None, + debug=False): + """ Securely compute Dijstra's algorithm on a secret graph. See + :download:`../Programs/Source/dijkstra_example.mpc` for an + explanation of the required inputs. + + :param source: source node (secret or clear-text integer) + :param edges: ORAM representation of edges + :param e_index: ORAM representation of vertices + :param oram_type: ORAM type to use internally (default: + :py:func:`~Compiler.oram.OptimalORAM`) + :param n_loops: when to stop (default: number of edges) + :param int_type: secret integer type (default: sint) + + """ vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ @@ -267,27 +286,46 @@ def f(i): dist.access(v, (basic_type(alt), u), is_shorter) #previous.access(v, u, is_shorter) Q.update(v, basic_type(alt), is_shorter) - print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \ - u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \ - not_visited.reveal()) + if debug: + print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, ' + 'shorter: %s, running: %s, queue size: %s, last edge: %s', + u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), + not_visited.reveal(), is_shorter.reveal(), + running.reveal(), Q.size.reveal(), last_edge.reveal()) return dist def convert_graph(G): + """ Convert a `NetworkX directed graph + `_ + to the cleartext representation of what :py:func:`dijkstra` expects. """ + G = G.copy() + for u in G: + for v in G[u]: + G[u][v].setdefault('weight', 1) edges = [None] * (2 * G.size()) e_index = [None] * (len(G)) i = 0 - for v in G: + for v in sorted(G): e_index[v] = i - for u in G[v]: + for u in sorted(G[v]): edges[i] = [u, G[v][u]['weight'], 0] i += 1 + if not G[v]: + edges[i] = [v, 0, 0] + i += 1 edges[i-1][-1] = 1 - return edges, e_index + return list(filter(lambda x: x, edges)), e_index -def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint): - for u in G: - for v in G[u]: - G[u][v].setdefault('weight', 1) +def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, + int_type=sint): + """ Securely compute Dijstra's algorithm on a cleartext graph. + + :param G: directed graph with NetworkX interface + :param source: source node (secret or clear-text integer) + :param n_loops: when to stop (default: number of edges) + :param int_type: secret integer type (default: sint) + + """ edges_list, e_index_list = convert_graph(G) edges = oram_type(len(edges_list), \ entry_size=(log2(len(G)), log2(len(G)), 1), \ diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 30c5aea99..93f212666 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -399,7 +399,7 @@ class stop(base.Instruction): arg_format = ['i'] class use(base.Instruction): - """ Offline data usage. Necessary to avoid reusage while using + r""" Offline data usage. Necessary to avoid reusage while using preprocessing from files. Also used to multithreading for expensive preprocessing. @@ -419,7 +419,7 @@ def get_usage(cls, args): args[2].i} class use_inp(base.Instruction): - """ Input usage. Necessary to avoid reusage while using + r""" Input usage. Necessary to avoid reusage while using preprocessing from files. :param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit) @@ -1738,7 +1738,7 @@ class print_reg_plains(base.IOInstruction): arg_format = ['s'] class cond_print_plain(base.IOInstruction): - """ Conditionally output clear register (with precision). + r""" Conditionally output clear register (with precision). Outputs :math:`x \cdot 2^p` where :math:`p` is the precision. :param: condition (cint, no output if zero) @@ -1989,7 +1989,7 @@ class closeclientconnection(base.IOInstruction): code = base.opcodes['CLOSECLIENTCONNECTION'] arg_format = ['ci'] -class writesharestofile(base.IOInstruction): +class writesharestofile(base.VectorInstruction, base.IOInstruction): """ Write shares to ``Persistence/Transactions-P.data`` (appending at the end). @@ -2002,11 +2002,12 @@ class writesharestofile(base.IOInstruction): __slots__ = [] code = base.opcodes['WRITEFILESHARE'] arg_format = tools.chain(['ci'], itertools.repeat('s')) + vector_index = 1 def has_var_args(self): return True -class readsharesfromfile(base.IOInstruction): +class readsharesfromfile(base.VectorInstruction, base.IOInstruction): """ Read shares from ``Persistence/Transactions-P.data``. :param: number of arguments to follow / number of shares plus two (int) @@ -2018,6 +2019,7 @@ class readsharesfromfile(base.IOInstruction): __slots__ = [] code = base.opcodes['READFILESHARE'] arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw')) + vector_index = 2 def has_var_args(self): return True @@ -2341,7 +2343,7 @@ class convint(base.Instruction): @base.vectorize class convmodp(base.Instruction): - """ Convert clear integer register (vector) to clear register + r""" Convert clear integer register (vector) to clear register (vector). If the bit length is zero, the unsigned conversion is used, otherwise signed conversion is used. This makes a difference when computing modulo a prime :math:`p`. Signed conversion of @@ -2814,13 +2816,11 @@ class check(base.Instruction): @base.gf2n @base.vectorize class sqrs(base.CISC): - """ Secret squaring $s_i = s_j \cdot s_j$. """ + r""" Secret squaring $s_i = s_j \cdot s_j$. """ __slots__ = [] arg_format = ['sw', 's'] def expand(self): - if program.options.ring: - return muls(self.args[0], self.args[1], self.args[1]) s = [program.curr_block.new_reg('s') for i in range(6)] c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 6d36a480a..5abbf97bf 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -1200,9 +1200,11 @@ def has_var_args(self): class VectorInstruction(Instruction): __slots__ = [] is_vec = lambda self: True + vector_index = 0 def get_code(self): - return super(VectorInstruction, self).get_code(len(self.args[0])) + return super(VectorInstruction, self).get_code( + len(self.args[self.vector_index])) class Ciscable(Instruction): def copy(self, size, subs): diff --git a/Compiler/library.py b/Compiler/library.py index 7a52ce3f7..46113c8b6 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -402,7 +402,8 @@ class FunctionCallTape(FunctionTape): def __init__(self, *args, **kwargs): super(FunctionTape, self).__init__(*args, **kwargs) self.instances = {} - def __call__(self, *args, **kwargs): + @staticmethod + def get_key(args, kwargs): key = (get_program(),) def process_for_key(arg): nonlocal key @@ -419,6 +420,9 @@ def process_for_key(arg): for name, arg in sorted(kwargs.items()): key += (name, 'kw') process_for_key(arg) + return key + def __call__(self, *args, **kwargs): + key = self.get_key(args, kwargs) if key not in self.instances: my_args = [] def wrapped_function(): @@ -502,6 +506,60 @@ def on_call(self, tape_handle, result, inside_args, args): break_point('call-%s' % self.name) return untuplify(tuple(out_result)) +class ExportFunction(FunctionCallTape): + def __init__(self, function): + super(ExportFunction, self).__init__(function) + self.done = set() + def __call__(self, *args, **kwargs): + if kwargs: + raise CompilerError('keyword arguments not supported') + def arg_signature(arg): + if isinstance(arg, types._structure): + return '%s:%d' % (arg.arg_type(), arg.size) + elif isinstance(arg, types._vectorizable): + from .GC.types import sbitvec + if issubclass(arg.value_type, sbitvec): + return 'sbv:[%dx%d]' % (arg.total_size(), + arg.value_type.n_bits) + else: + return '%s:[%d]' % (arg.value_type.arg_type(), + arg.total_size()) + else: + raise CompilerError('argument not supported: %s' % arg) + signature = [] + for arg in args: + signature.append(arg_signature(arg)) + signature = tuple(signature) + key = self.get_key(args, kwargs) + if key in self.instances and signature not in self.done: + raise CompilerError('signature conflict') + super(ExportFunction, self).__call__(*args, **kwargs) + if signature not in self.done: + filename = '%s/%s/%s-%s' % (get_program().programs_dir, 'Functions', + self.name, '-'.join(signature)) + print('Writing to', filename) + out = open(filename, 'w') + print(get_program().name, file=out) + print(self.instances[key][0], file=out) + result = self.instances[key][1] + try: + if result is not None: + result = untuplify(result) + print(arg_signature(result), result.i, file=out) + else: + print('- 0', file=out) + except CompilerError: + raise CompilerError('return type not supported: %s' % result) + for arg in self.instances[key][2]: + if isinstance(arg, types._structure): + print(arg.i, end=' ', file=out) + elif isinstance(arg, types._vectorizable): + print(arg.address, end=' ', file=out) + else: + CompilerError('argument not supported: %s', arg) + print(file=out) + self.done.add(signature) + def function_tape(function): return FunctionTape(function) @@ -589,6 +647,9 @@ def f(x, y, z): """ return FunctionCallTape(function) +def export(function): + return ExportFunction(function) + def memorize(x, write=True): if isinstance(x, (tuple, list)): return tuple(memorize(i, write=write) for i in x) @@ -898,8 +959,10 @@ def _(i): """ def decorator(loop_body): + get_tape().unused_decorators.pop(decorator) range_loop(loop_body, start, stop, step) return loop_body + get_tape().unused_decorators[decorator] = 'for_range' return decorator def for_range_parallel(n_parallel, n_loops): @@ -948,7 +1011,7 @@ def for_range_opt(start, stop=None, step=None, budget=None): :param start/stop/step: int/regint/cint (used as in :py:func:`range`) or :py:obj:`start` only as list/tuple of int (see below) :param budget: number of instructions after which to start optimization - (default is 100,000) + (default is 1000 or as given with ``--budget``) Example: @@ -1487,7 +1550,7 @@ def decorator(loop_body): return loop_body return decorator -def _run_and_link(function, g=None, lock_lists=True): +def _run_and_link(function, g=None, lock_lists=True, allow_return=False): if g is None: g = function.__globals__ if lock_lists: @@ -1504,6 +1567,9 @@ def __setitem__(*args): g[x] = A(g[x]) pre = copy.copy(g) res = function() + if res is not None and not allow_return: + raise CompilerError('Conditional blocks cannot return values. ' + 'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else') _link(pre, g) return res @@ -1536,7 +1602,7 @@ def _(): name='begin-loop') get_tape().loop_breaks.append([]) loop_block = instructions.program.curr_block - condition = _run_and_link(loop_fn, g) + condition = _run_and_link(loop_fn, g, allow_return=True) if callable(condition): condition = condition() branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False) @@ -1558,7 +1624,9 @@ class State: pass condition = condition() try: if not condition.is_clear: - raise CompilerError('cannot branch on secret values') + raise CompilerError( + 'cannot branch on secret values, use if_else instead: ' + 'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.if_else') except AttributeError: pass state.condition = regint.conv(condition) diff --git a/Compiler/ml.py b/Compiler/ml.py index 86dc2d9bb..01fb86ae5 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -849,7 +849,7 @@ def compute_f_input(self, batch): prod = MultiArray([N, self.d, self.d_out], sfix) else: prod = self.f_input - max_size = get_program().budget // self.d_out + max_size = get_program().budget @multithread(self.n_threads, N, max_size) def _(base, size): X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) @@ -1316,12 +1316,12 @@ def __init__(self, inputs): self.inputs = inputs def _forward(self, batch=[0]): - assert len(batch) == 1 @multithread(self.n_threads, self.Y[0].total_size()) def _(base, size): - tmp = sum(inp.Y[batch[0]].get_vector(base, size) - for inp in self.inputs) - self.Y[batch[0]].assign_vector(tmp, base) + for bb in batch: + tmp = sum(inp.Y[bb].get_vector(base, size) + for inp in self.inputs) + self.Y[bb].assign_vector(tmp, base) class FusedBatchNorm(Layer): """ Fixed-point fused batch normalization layer (inference only). @@ -1400,11 +1400,11 @@ def reset(self): def _output(self, batch, mu, var): factor = sfix.Array(len(mu)) - factor[:] = self.InvertSqrt(var[:] + self.epsilon) + factor[:] = self.InvertSqrt(var[:] + self.epsilon) * self.weights[:] @for_range_opt_multithread(self.n_threads, [len(batch), self.X.sizes[1]]) def _(i, j): - tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:] + tmp = (self.X[i][j][:] - mu[:]) * factor[:] self.my_Y[i][j][:] = self.bias[:] + tmp @_layer_method_call_tape @@ -2233,7 +2233,7 @@ def from_args(program, layers): res.output_stats = 'output_stats' in program.args return res - def __init__(self, layers=[], report_loss=None): + def __init__(self, layers=[], report_loss=None, time_layers=False): if get_program().options.binary: raise CompilerError( 'machine learning code not compatible with binary circuits') @@ -2248,6 +2248,10 @@ def __init__(self, layers=[], report_loss=None): self.stopped_on_loss = MemValue(0) self.stopped_on_low_loss = MemValue(0) self.layers = layers + self.time_layers = time_layers + if time_layers: + for i, layer in enumerate(layers): + print('Timer %d: %s' % (100 + i, repr(layer))) @property def layers(self): @@ -2667,8 +2671,12 @@ class A: if model_input: for layer in self.layers: layer.input_from(0) - elif reset: + elif reset and not 'no_reset' in program.args: self.reset() + else: + for layer in self.layers: + for theta in layer.thetas(): + theta.alloc() if 'one_iter' in program.args: print_float_prec(16) self.output_weights() @@ -2689,6 +2697,8 @@ class A: if 'bench10' in program.args or 'bench1' in program.args: n = 1 if 'bench1' in program.args else 10 print('benchmarking %s iterations' % n) + # force allocatoin + self.layers[0].X, self.layers[-1].Y @for_range(n) def _(i): batch = Array.create_from(regint.inc(batch_size)) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 2208fc212..d4f0d951d 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -422,7 +422,7 @@ def mux_exp(x, y, block_size=8): @types.vectorize @instructions_base.sfix_cisc def log2_fx(x, use_division=True): - """ + r""" Returns the result of :math:`\log_2(x)` for any unbounded number. This is achieved by changing :py:obj:`x` into :math:`f \cdot 2^n` where f is bounded by :math:`[0.5, 1]`. Then the @@ -463,7 +463,7 @@ def log2_fx(x, use_division=True): def pow_fx(x, y, zero_output=False): - """ + r""" Returns the value of the expression :math:`x^y` where both inputs are secret shared. It uses :py:func:`log2_fx` together with :py:func:`exp2_fx` to calculate the expression :math:`2^{y \log_2(x)}`. @@ -487,7 +487,7 @@ def pow_fx(x, y, zero_output=False): def log_fx(x, b): - """ + r""" Returns the value of the expression :math:`\log_b(x)` where :py:obj:`x` is secret shared. It uses :py:func:`log2_fx` to calculate the expression :math:`\log_b(2) \cdot \log_2(x)`. @@ -859,7 +859,7 @@ def atan(x): def asin(x): - """ + r""" Returns the arcsine (sfix) of any given fractional value. :param x: fractional input (sfix). valid interval is :math:`-1 \le x \le 1` @@ -875,7 +875,7 @@ def asin(x): def acos(x): - """ + r""" Returns the arccosine (sfix) of any given fractional value. :param x: fractional input (sfix). :math:`-1 \le x \le 1` @@ -887,7 +887,7 @@ def acos(x): def tanh(x): - """ + r""" Hyperbolic tangent. For efficiency, accuracy is diminished around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and :math:`f` denote the fixed-point parameters. diff --git a/Compiler/oram.py b/Compiler/oram.py index d4862e669..20f189844 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -9,6 +9,11 @@ i = sint.get_input_from(0) a[i] = sint.get_input_from(1) +`The introductory book by Evans et +al. `_ contains `a chapter dedicated to +oblivious RAM +`_. + """ import random @@ -41,6 +46,7 @@ crash_on_overflow = False use_insecure_randomness = False debug_ram_size = False +single_thread = False def maybe_start_timer(n): if detailed_timing: @@ -844,7 +850,7 @@ def __init__(self, size, value_type=None, value_length=1, index_size=None, \ start_timer() def get_n_threads(n_loops): - if n_threads is None: + if n_threads is None and not single_thread: if n_loops > 2048: return 8 else: @@ -1038,7 +1044,7 @@ def output(self): __getitem__ = lambda self,index: List.__getitem__(self, index)[0] def get_n_threads_for_tree(size): - if n_threads_for_tree is None: + if n_threads_for_tree is None and not single_thread: if size >= 2**13: return 8 else: diff --git a/Compiler/path_oblivious_heap.py b/Compiler/path_oblivious_heap.py index 54e0542c9..ccc4444ae 100644 --- a/Compiler/path_oblivious_heap.py +++ b/Compiler/path_oblivious_heap.py @@ -1,4 +1,4 @@ -"""This module contains an implementation of the "Path Oblivious Heap" +r"""This module contains an implementation of the "Path Oblivious Heap" oblivious priority queue as proposed by `Shi `_. diff --git a/Compiler/program.py b/Compiler/program.py index a837fee68..5d68948ef 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -260,7 +260,7 @@ def init_names(self, args): os.mkdir(dirname) # create extra directories if needed - for dirname in ["Public-Input", "Bytecode", "Schedules"]: + for dirname in ["Public-Input", "Bytecode", "Schedules", "Functions"]: if not os.path.exists(self.programs_dir + "/" + dirname): os.mkdir(self.programs_dir + "/" + dirname) @@ -859,6 +859,7 @@ def __init__(self, name, program, thread_pool=None): self.warned_about_mem = False self.return_values = [] self.ran_threads = False + self.unused_decorators = {} class BasicBlock(object): def __init__(self, parent, name, scope, exit_condition=None, @@ -1054,6 +1055,10 @@ def optimize(self, options): print() raise CompilerError("Unclosed if/else blocks, see tracebacks above") + if self.unused_decorators: + raise CompilerError("Unused branching decorators, make sure to write " + ",".join( + "'@%s' instead of '%s'" % (x, x) for x in set(self.unused_decorators.values()))) + if self.program.verbose: print( "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) @@ -1561,10 +1566,7 @@ class _no_truth(object): def __bool__(self): raise CompilerError( "Cannot derive truth value from register. " - "This is a catch-all error appearing if you try to use a " - "run-time value where the compiler expects a compile-time " - "value, most likely a Python integer. " - "In some cases, you can fix this by using 'compile.py -l'." + "See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-derive-truth-value-from-register" ) def __int__(self): @@ -1599,6 +1601,7 @@ class Register(_no_truth): "caller", "can_eliminate", "duplicates", + "dup_count", "block", ] maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 @@ -1631,6 +1634,7 @@ def __init__(self, reg_type, program, size=None, i=None): self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) + self.dup_count = None if Program.prog.DEBUG: self.caller = [frame[1:] for frame in inspect.stack()[1:]] else: diff --git a/Compiler/sorting.py b/Compiler/sorting.py index f4f38caba..7779c7489 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -11,7 +11,7 @@ def dest_comp(B): return sum(Tt) - 1 def reveal_sort(k, D, reverse=False): - """ Sort in place according to "perfect" key. The name hints at the fact + r""" Sort in place according to "perfect" key. The name hints at the fact that a random order of the keys is revealed. :param k: vector or Array of sint containing exactly :math:`0,\dots,n-1` diff --git a/Compiler/types.py b/Compiler/types.py index 263aac90a..1e1113410 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -428,7 +428,7 @@ def bit_or(self, other): :return: type depends on inputs (secret if any of them is) """ if util.is_constant(other): if other: - return self + return 1 else: return 0 return self + other - self * other @@ -502,14 +502,14 @@ def cond_swap(self, a, b): return a ^ prod, b ^ prod class _gf2n(_bit): - """ :math:`\mathrm{GF}(2^n)` functionality. """ + r""" :math:`\mathrm{GF}(2^n)` functionality. """ def if_else(self, a, b): - """ MUX in :math:`\mathrm{GF}(2^n)` circuits. Similar to :py:meth:`_int.if_else`. """ + r""" MUX in :math:`\mathrm{GF}(2^n)` circuits. Similar to :py:meth:`_int.if_else`. """ return b ^ self * self.hard_conv(a ^ b) def cond_swap(self, a, b, t=None): - """ Swapping in :math:`\mathrm{GF}(2^n)`. Similar to :py:meth:`_int.if_else`. """ + r""" Swapping in :math:`\mathrm{GF}(2^n)`. Similar to :py:meth:`_int.if_else`. """ prod = self * self.hard_conv(a ^ b) res = a ^ prod, b ^ prod if t is None: @@ -518,7 +518,7 @@ def cond_swap(self, a, b, t=None): return tuple(t.conv(r) for r in res) def bit_xor(self, other): - """ XOR in :math:`\mathrm{GF}(2^n)` circuits. + r""" XOR in :math:`\mathrm{GF}(2^n)` circuits. :param self/other: 0 or 1 (any compatible type) :rtype: depending on inputs (secret if any of them is) """ @@ -583,6 +583,14 @@ def mem_size(): def size_for_mem(self): return self.size + @classmethod + def arg_type(cls): + if issubclass(cls, _register): + return cls.reg_type + if issubclass(cls, (cfix, sfix)): + return cls.int_type.reg_type + raise CompilerError('type not supported as argument: %s' % cls) + class _secret_structure(_structure): @classmethod def input_tensor_from(cls, player, shape): @@ -1362,7 +1370,7 @@ def output_if(self, cond): class cgf2n(_clear, _gf2n): - """ + r""" Clear :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A number operators are supported (``+, -, *, /, **, ^, &, |, ~, ==, !=, <<, >>``), returning either :py:class:`cgf2n` if the other @@ -1381,7 +1389,7 @@ class cgf2n(_clear, _gf2n): @classmethod def bit_compose(cls, bits, step=None): - """ Clear :math:`\mathrm{GF}(2^n)` bit composition. + r""" Clear :math:`\mathrm{GF}(2^n)` bit composition. :param bits: list of cgf2n :param step: set every :py:obj:`step`-th bit in output (defaults to 1) """ @@ -1478,7 +1486,7 @@ def __ne__(self, other): @vectorize def bit_decompose(self, bit_length=None, step=None): - """ Clear bit decomposition. + r""" Clear bit decomposition. :param bit_length: number of bits (defaults to global :math:`\mathrm{GF}(2^n)` bit length) :param step: extract every :py:obj:`step`-th bit (defaults to 1) """ @@ -2685,19 +2693,23 @@ def write_shares_to_socket(cls, client_id, values, writesocketshare(client_id, message_type, values[0].size, *values) @classmethod - def read_from_file(cls, start, n_items): + def read_from_file(cls, start, n_items=1, crash_if_missing=True, size=1): """ Read shares from ``Persistence/Transactions-P.data``. See :ref:`this section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) :param n_items: number of items (int) + :param crash_if_missing: crash if file not found (default) + :param size: vector size (int) :returns: destination for final position, -1 for eof reached, or -2 for file not found (regint) :returns: list of shares """ - shares = [cls(size=1) for i in range(n_items)] + shares = [cls(size=size) for i in range(n_items)] stop = regint() readsharesfromfile(regint.conv(start), stop, *shares) + if crash_if_missing: + library.runtime_error_if(stop == -2, 'Persistence not found') return stop, shares @staticmethod @@ -2710,9 +2722,11 @@ def write_to_file(shares, position=None): :param position: start position (int/regint/cint), defaults to end of file """ + if isinstance(shares, sint): + shares = [shares] for share in shares: assert isinstance(share, sint) - assert share.size == 1 + assert share.size == shares[0].size if position is None: position = -1 writesharestofile(regint.conv(position), *shares) @@ -3246,7 +3260,7 @@ def __rsub__(self, other): __ror__ = __or__ class sgf2n(_secret, _gf2n): - """ + r""" Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A number operators are supported (``+, -, *, /, **, ^, ~, ==, !=, <<``), :py:class:`sgf2n`. Operators generally work with @@ -3271,7 +3285,7 @@ def get_raw_input_from(cls, player): return res def add(self, other): - """ Secret :math:`\mathrm{GF}(2^n)` addition (XOR). + r""" Secret :math:`\mathrm{GF}(2^n)` addition (XOR). :param other: sg2fn/cgf2n/regint/int """ if isinstance(other, sgf2nint): @@ -3280,7 +3294,7 @@ def add(self, other): return super(sgf2n, self).add(other) def mul(self, other): - """ Secret :math:`\mathrm{GF}(2^n)` multiplication. + r""" Secret :math:`\mathrm{GF}(2^n)` multiplication. :param other: sg2fn/cgf2n/regint/int """ if isinstance(other, (sgf2nint)): @@ -3349,7 +3363,7 @@ def __lshift__(self, other): @vectorize def right_shift(self, other, bit_length=None): - """ Secret right shift by public value: + r""" Secret right shift by public value: :param other: compile-time (int) :param bit_length: number of bits of :py:obj:`self` (defaults to :math:`\mathrm{GF}(2^n)` bit length) """ @@ -4458,6 +4472,7 @@ def read_from_file(cls, *args, **kwargs): :param start: starting position in number of shares from beginning (int/regint/cint) :param n_items: number of items (int) + :param crash_if_missing: crash if file not found (default) :returns: destination for final position, -1 for eof reached, or -2 for file not found (regint) :returns: list of shares @@ -4980,11 +4995,21 @@ def multipliable(v, k, f, size): return cfix._new(cint.conv(v, size=size), k, f) def dot(self, other): - """ Dot product with :py:class:`sint`. """ + """ Dot product with any vector or iterable. """ if isinstance(other, sint): return self._new(sint.dot_product(self.v, other), k=self.k, f=self.f) + elif isinstance(other, sfix): + assert self.k == other.k + assert self.f == other.f + return self._new(sint.dot_product(self.v, other.v).round( + self.k + other.f, self.f, nearest=self.round_nearest, + signed=True), k=self.k, f=self.f) + elif isinstance(other, (_int, cfix)): + return (self * other).sum() else: - raise NotImplementedError() + other = list(other) + assert len(self) == len(other) + return sum(a * b for a, b in zip(self, other)) def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -5280,7 +5305,7 @@ def reduce(self, unreduced): return squant._new(shifted, params=self) class sfloat(_number, _secret_structure): - """ + r""" Secret floating-point number. Represents :math:`(1 - 2s) \cdot (1 - z)\cdot v \cdot 2^p`. @@ -5769,6 +5794,23 @@ def _get_type(t): return t class _vectorizable: + @classmethod + def check(cls, index, length, sizes): + if isinstance(index, _clear): + index = regint.conv(index) + if length is not None: + from .GC.types import cbits + if isinstance(index, int): + index += length * (index < 0) + if index >= length or index < 0: + raise IndexError('index %s, length %s' % \ + (str(index), str(length))) + elif cls.check_indices and not isinstance(index, cbits): + library.runtime_error_if( + (index >= length).bit_or(index < 0), + 'overflow: %s/%s', index, sizes) + return index + def reveal_to_clients(self, clients): """ Reveal contents to list of clients. @@ -5807,11 +5849,11 @@ class Array(_vectorizable): @classmethod def create_from(cls, l): - """ Convert Python iterator or vector to array. Basic type will be taken - from first element, further elements must to be convertible to - that. + """ Convert Python iterator or vector to array or copy another array. + Basic type will be taken from first element, further elements + must to be convertible to that. - :param l: Python iterable or register vector + :param l: Python iterable, register vector, or array :returns: :py:class:`Array` of appropriate type containing the contents of :py:obj:`l` @@ -5871,19 +5913,7 @@ def get_address(self, index, size=None): if isinstance(index, (_secret, _single)): raise CompilerError('need cleartext index') key = str(index), size or 1 - if isinstance(index, _clear): - index = regint.conv(index) - if self.length is not None: - from .GC.types import cbits - if isinstance(index, int): - index += self.length * (index < 0) - if index >= self.length or index < 0: - raise IndexError('index %s, length %s' % \ - (str(index), str(self.length))) - elif self.check_indices and not isinstance(index, cbits): - library.runtime_error_if( - (index >= self.length).bit_or(index < 0), - 'overflow: %s/%s', index, self.length) + index = self.check(index, self.length, self.length) if (program.curr_block, key) not in self.address_cache: n = self.value_type.n_elements() length = self.length @@ -6178,13 +6208,14 @@ def _(base, size): def _(i): self[i] = input_from(player, **kwargs) - def read_from_file(self, start): + def read_from_file(self, start, *args, **kwargs): """ Read content from ``Persistence/Transactions-P.data``. Precision must be the same as when storing if applicable. See :ref:`this section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) + :param crash_if_missing: crash if file not found (default) :returns: destination for final position, -1 for eof reached, or -2 for file not found (regint) """ @@ -6192,8 +6223,9 @@ def read_from_file(self, start): res = MemValue(0) @library.multithread(None, len(self), max_size=program.budget) def _(base, size): - stop, shares = self.value_type.read_from_file(start, size) - self.assign(shares, base=base) + stop, shares = self.value_type.read_from_file( + start, *args, size=size, **kwargs) + self.assign(shares[0], base=base) start.iadd(size) res.write(stop) return res @@ -6402,7 +6434,7 @@ def reveal_to(self, player): return personal(player, self.create_from(self[:].reveal_to(player)._v)) def sort(self, n_threads=None, batcher=False, n_bits=None): - """ + r""" Sort in place using `radix sort `_ with complexity :math:`O(n \log n)` for :py:class:`sint` and :py:class:`sfix`, @@ -6495,13 +6527,7 @@ def __getitem__(self, index): key = program.curr_tape, tuple( (x, x.has_else) for x in program.curr_tape.if_states), str(index) if key not in self.sub_cache: - if util.is_constant(index) and \ - (index >= self.sizes[0] or index < 0): - raise CompilerError('index out of range') - elif self.check_indices: - library.runtime_error_if(index >= self.sizes[0], - 'overflow: %s/%s', - index, self.sizes) + index = self.check(index, self.sizes[0], self.sizes) if len(self.sizes) == 2: self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ @@ -6769,20 +6795,21 @@ def _(i): my_pos = position + i * self[i].total_size() self[i].write_to_file(my_pos) - def read_from_file(self, start): + def read_from_file(self, start, *args, **kwargs): """ Read content from ``Persistence/Transactions-P.data``. Precision must be the same as when storing if applicable. See :ref:`this section ` for details on the data format. :param start: starting position in number of shares from beginning (int/regint/cint) + :param crash_if_missing: crash if file not found (default) :returns: destination for final position, -1 for eof reached, or -2 for file not found (regint) """ start = MemValue(start) @library.for_range(len(self)) def _(i): - start.write(self[i].read_from_file(start)) + start.write(self[i].read_from_file(start, *args, **kwargs)) return start def write_to_socket(self, socket, debug=False): diff --git a/Compiler/util.py b/Compiler/util.py index 6e9f43554..aa83f063b 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -292,6 +292,9 @@ def keys(self): def __iter__(self): return self.keys() + def pop(self, key): + return self.content.pop(id(key), None) + class defaultdict_by_id(dict_by_id): def __init__(self, default): dict_by_id.__init__(self) diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 1ce82ff35..ff7da2ed3 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -10,10 +10,12 @@ #include "Math/gfp.h" #include "ECDSA/P256Element.h" #include "GC/VectorInput.h" +#include "Protocols/SPDZ.h" #include "ECDSA/preprocessing.hpp" #include "ECDSA/sign.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/Hemi.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/Share.hpp" #include "Protocols/MAC_Check.hpp" diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 8bd7acb28..cb8ee445a 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -32,7 +32,7 @@ #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include "Machines/MalRep.hpp" #include "Machines/Rep.hpp" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 550c0ac8a..1a6aa0cf1 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -8,6 +8,7 @@ #include "Math/gfp.h" #include "ECDSA/P256Element.h" #include "Protocols/SemiShare.h" +#include "Protocols/SPDZ.h" #include "Processor/BaseMachine.h" #include "ECDSA/preprocessing.hpp" @@ -15,6 +16,7 @@ #include "Protocols/Beaver.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/MascotPrep.hpp" +#include "Protocols/Hemi.hpp" #include "Processor/Processor.hpp" #include "Processor/Data_Files.hpp" #include "Processor/Input.hpp" diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp index d7a4d8836..5498b4c43 100644 --- a/ECDSA/semi-ecdsa-party.cpp +++ b/ECDSA/semi-ecdsa-party.cpp @@ -6,6 +6,7 @@ #include "GC/SemiSecret.h" #include "GC/SemiPrep.h" +#include "Protocols/Hemi.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/SemiPrep.hpp" #include "Protocols/SemiInput.hpp" diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp index df73da106..1caa1e02c 100644 --- a/ExternalIO/bankers-bonus-client.cpp +++ b/ExternalIO/bankers-bonus-client.cpp @@ -17,11 +17,10 @@ * - share of winning unique id * random value [w] * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] * - * To run with 2 parties / SPDZ engines: - * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * To run: * ./Scripts/setup-clients.sh to create SSL keys and certificates for clients - * ./compile.py bankers_bonus - * ./Scripts/run-online.sh bankers_bonus to run the engines. + * ./Scripts/compile-run.py bankers_bonus to compile and run the engines. + * (See https://github.com/data61/MP-SPDZ/?tab=readme-ov-file#protocols for options.) * * ./bankers-bonus-client.x 0 2 100 0 * ./bankers-bonus-client.x 1 2 200 0 diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index b0a287444..f28222fac 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -16,14 +16,14 @@ template class AddableVector: public vector { public: - AddableVector() {} - AddableVector(size_t n, const T& x = T()) : vector(n, x) {} + AddableVector() {} + AddableVector(size_t n, const T& x = T()) : vector(n, x) {} template - AddableVector(const Plaintext& other) : - AddableVector(other.get_poly()) {} + AddableVector(const Plaintext& other) : + AddableVector(other.get_poly()) {} template - AddableVector(const vector& other) + AddableVector(const vector& other) { this->assign(other.begin(), other.end()); } @@ -129,29 +129,41 @@ class AddableVector: public vector (*this)[i].pack(os); } - void unpack_size(octetStream& os, const T& init = T()) + size_t unpack_size(octetStream& os) { unsigned int size; os.get(size); - this->resize(size, init); + this->reserve(size); + return size; } void unpack(octetStream& os, const T& init = T()) { - unpack_size(os, init); - for (unsigned int i = 0; i < this->size(); i++) - (*this)[i].unpack(os); + size_t new_size = unpack_size(os); + this->clear(); + for (unsigned int i = 0; i < new_size; i++) + { + this->push_back(init); + this->back().unpack(os); + } } void add(octetStream& os, T& tmp) { - unpack_size(os, tmp); + size_t new_size = unpack_size(os); + T init = tmp; T& item = tmp; for (unsigned int i = 0; i < this->size(); i++) { item.unpack(os); (*this)[i] += item; } + for (size_t i = this->size(); i < new_size; i++) + { + item.unpack(os); + this->push_back(init); + this->back() += item; + } } T infinity_norm() const diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 4fe98e58b..aba201ca0 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -61,9 +61,7 @@ bigint FHE_Params::Q() const void FHE_Params::pack(octetStream& o) const { - o.store(FFTData.size()); - for(auto& fd: FFTData) - fd.pack(o); + o.store(FFTData); Chi.pack(o); Bval.pack(o); o.store(sec_p); @@ -73,11 +71,7 @@ void FHE_Params::pack(octetStream& o) const void FHE_Params::unpack(octetStream& o) { - size_t size; - o.get(size); - FFTData.resize(size); - for (auto& fd : FFTData) - fd.unpack(o); + o.get(FFTData); Chi.unpack(o); Bval.unpack(o); o.get(sec_p); diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp index 8815fb356..066be5e19 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -311,22 +311,18 @@ void imatrix::hash(octetStream& o) const void imatrix::pack(octetStream& o) const { - o.store(size()); for (auto& x : *this) { assert(x.size() == size()); - x.pack(o); } + o.store(static_cast(*this)); } void imatrix::unpack(octetStream& o) { - size_t size; - o.get(size); - resize(size); + o.get(static_cast(*this)); for (auto& x : *this) { - x.resize(size); - x.unpack(o); + assert(x.size() == size()); } } diff --git a/FHE/Matrix.h b/FHE/Matrix.h index 34c915e09..ad35eab76 100644 --- a/FHE/Matrix.h +++ b/FHE/Matrix.h @@ -13,6 +13,8 @@ typedef vector< vector > matrix; class imatrix : public vector< BitVector > { + typedef vector super; + public: bool operator!=(const imatrix& other) const; diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 85b630ee6..33f3af343 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -13,6 +13,8 @@ #include "FHEOffline/Proof.h" +#include "Processor/OnlineOptions.h" + #include using namespace std; @@ -735,9 +737,9 @@ void load_or_generate(P2Data& P2D, const Ring& R) { P2D.load(R); } - catch (...) + catch (exception& e) { - cout << "Loading failed" << endl; + cerr << "Loading parameters failed, generating (" << e.what() << ")" << endl; init(P2D,R); P2D.store(R); } diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index ac4ae6f16..d081eebe6 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -2,6 +2,7 @@ #include "FHE/P2Data.h" #include "Math/Setup.h" #include "Math/fixint.h" +#include "Processor/OnlineOptions.h" #include @@ -74,7 +75,6 @@ bool P2Data::operator!=(const P2Data& other) const void P2Data::hash(octetStream& o) const { - check_dimensions(); o.store(gf2n_short::degree()); o.store(slots); A.hash(o); @@ -113,17 +113,18 @@ string get_filename(const Ring& Rg) void P2Data::load(const Ring& Rg) { string filename = get_filename(Rg); - cout << "Loading from " << filename << endl; - ifstream s(filename); + if (OnlineOptions::singleton.verbose) + cerr << "Loading from " << filename << endl; octetStream os; - os.input(s); + os.input(filename); unpack(os); } void P2Data::store(const Ring& Rg) const { string filename = get_filename(Rg); - cout << "Storing in " << filename << endl; + if (OnlineOptions::singleton.verbose) + cerr << "Storing in " << filename << endl; ofstream s(filename); octetStream os; pack(os); diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 8adb3d34a..b205cfaa9 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -562,22 +562,17 @@ template void Plaintext::pack(octetStream& o) const { to_poly(); - o.store((unsigned int)b.size()); - for (unsigned int i = 0; i < b.size(); i++) - o.store(b[i]); + o.store(b); } template void Plaintext::unpack(octetStream& o) { - type = Polynomial; - unsigned int size; - o.get(size); - allocate(); + o.get(b); + auto size = b.size(); + allocate(Polynomial); if (size != b.size() and size != 0) throw length_error("unexpected length received"); - for (unsigned int i = 0; i < size; i++) - b[i] = o.get(); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 4c53d28da..ea54fa6f7 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -512,9 +512,7 @@ modp Ring_Element::get_constant() const void store(octetStream& o,const vector& v,const Zp_Data& ZpD) { ZpD.pack(o); - o.store((int)v.size()); - for (unsigned int i=0; i& v,const Zp_Data& ZpD) throw runtime_error( "mismatch: " + to_string(check_Zpd.pr_bit_length) + "/" + to_string(ZpD.pr_bit_length)); - unsigned int length; - o.get(length); - v.clear(); - v.reserve(length); - modp tmp; - for (unsigned int i=0; i typedef MaliciousRepMC MC; typedef MC MAC_Check; + typedef HashMaliciousRepMC DefaultMC; typedef ReplicatedInput Input; typedef RepPrep LivePrep; diff --git a/GC/Program.hpp b/GC/Program.hpp index d493e4e1e..109942bb2 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -71,9 +71,9 @@ void Program::parse(istream& s) CALLGRIND_STOP_INSTRUMENTATION; while (!s.eof()) { + instr.parse(s, pos); if (s.bad() or s.fail()) throw runtime_error("error reading program"); - instr.parse(s, pos); p.push_back(instr); //cerr << "\t" << instr << endl; s.peek(); diff --git a/GC/Secret.h b/GC/Secret.h index 54addd235..e4877ed8e 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -65,6 +65,8 @@ class Secret typedef typename T::out_type out_type; + typedef void DefaultMC; + static string type_string() { return "evaluation secret"; } static string phase_name() { return T::name(); } @@ -179,6 +181,7 @@ class Secret void finalize_input(U& inputter, int from, int n_bits); int size() const { return registers.size(); } + size_t maximum_size() const { return registers.size(); } RegVector& get_regs() { return registers; } const RegVector& get_regs() const { return registers; } diff --git a/GC/Semi.cpp b/GC/Semi.cpp index 1efdaf3f1..d149028aa 100644 --- a/GC/Semi.cpp +++ b/GC/Semi.cpp @@ -17,7 +17,8 @@ namespace GC void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, bool repeat) { - if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1)) + if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1) + and P.num_players() == 2) { this->triples.push_back({{}}); auto& triple = this->triples.back(); diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index 30d7dfdf2..1ae7a0f82 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -51,6 +51,11 @@ class SemiSecretBase : public V, public ShareSecret static void trans(Processor& processor, int n_outputs, const vector& args); + static size_t maximum_size() + { + return default_length; + } + SemiSecretBase() { } diff --git a/GC/SemiSecret.hpp b/GC/SemiSecret.hpp index b869c87a1..5334ac168 100644 --- a/GC/SemiSecret.hpp +++ b/GC/SemiSecret.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_SEMISECRET_HPP_ +#define GC_SEMISECRET_HPP_ + #include "GC/ShareParty.h" #include "GC/ShareSecret.hpp" #include "Protocols/MAC_Check_Base.hpp" @@ -161,3 +164,5 @@ void SemiSecretBase::reveal(size_t n_bits, Clear& x) } } /* namespace GC */ + +#endif diff --git a/GC/ShareParty.h b/GC/ShareParty.h index ceda2f01f..8a125fff5 100644 --- a/GC/ShareParty.h +++ b/GC/ShareParty.h @@ -42,7 +42,7 @@ inline ShareParty& ShareParty::s() if (singleton) return *singleton; else - throw runtime_error("no singleton"); + throw runtime_error("no ShareParty singleton"); } } diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 6deea9c80..9721cd83d 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -125,6 +125,8 @@ class RepSecretBase : public FixedVec, public ShareSecret typedef NoShare bit_type; + typedef void DefaultMC; + static const int N_BITS = clear::N_BITS; static const bool dishonest_majority = false; @@ -166,6 +168,11 @@ class RepSecretBase : public FixedVec, public ShareSecret return T::fake_opts(); } + static size_t maximum_size() + { + return default_length; + } + RepSecretBase() { } @@ -203,7 +210,7 @@ class ReplicatedSecret : public RepSecretBase typedef ReplicatedBase Protocol; static ReplicatedSecret constant(const typename super::clear& value, - int my_num, typename super::mac_key_type, int = -1) + int my_num, typename super::mac_key_type = {}, int = -1) { ReplicatedSecret res; if (my_num < 2) diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 70aae69b3..c24f62ba1 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -67,7 +67,7 @@ inline ShareThread& ShareThread::s() if (singleton and T::is_real) return *singleton; else - throw runtime_error("no singleton"); + throw runtime_error("no ShareThread singleton"); } } /* namespace GC */ diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 88ca4fa40..0732e20ce 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -144,6 +144,9 @@ void ShareThread::and_(Processor& processor, res.mask(res, n); } } + + if (OnlineOptions::singleton.has_option("always_check")) + protocol->check(); } template @@ -195,6 +198,9 @@ void ShareThread::andrsvec(Processor& processor, const vector& args) } it += 2 * n_args + 1; } + + if (OnlineOptions::singleton.has_option("always_check")) + protocol->check(); } template diff --git a/GC/Thread.h b/GC/Thread.h index 6631ad723..b510120ed 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -48,6 +48,7 @@ class Thread Thread(int thread_num, ThreadMaster& master); virtual ~Thread(); + void start(); void run(); virtual void pre_run() {} virtual void run(Program& program); diff --git a/GC/Thread.hpp b/GC/Thread.hpp index d0b515cbf..8fd294ad8 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -31,7 +31,12 @@ template Thread::Thread(int thread_num, ThreadMaster& master) : master(master), machine(master.machine), processor(machine), N(master.N), P(0), - thread_num(thread_num) + thread_num(thread_num), thread(0) +{ +} + +template +void Thread::start() { pthread_create(&thread, 0, run_thread, this); } diff --git a/GC/ThreadMaster.h b/GC/ThreadMaster.h index e198c0f5d..53c473477 100644 --- a/GC/ThreadMaster.h +++ b/GC/ThreadMaster.h @@ -60,6 +60,7 @@ class ThreadMaster : public ThreadMasterBase virtual Thread* new_thread(int i); void run(); + void run_with_error(); virtual void post_run() {} }; diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index abcec91ec..d657dc906 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -59,6 +59,25 @@ Thread* ThreadMaster::new_thread(int i) template void ThreadMaster::run() +{ + if (opts.has_option("throw_exceptions")) + run_with_error(); + else + { + try + { + run_with_error(); + } + catch (exception& e) + { + cerr << "Fatal error: " << e.what() << endl; + exit(1); + } + } +} + +template +void ThreadMaster::run_with_error() { if (not opts.live_prep) { @@ -72,6 +91,9 @@ void ThreadMaster::run() for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); + // must start after constructor due to virtual functions + for (auto thread : threads) + thread->start(); for (auto thread : threads) thread->join_tape(); diff --git a/GC/TinyMC.h b/GC/TinyMC.h index 2a35f6e6b..98fb79589 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -50,7 +50,7 @@ class TinyMC : public MAC_Check_Base return *part_MC; } - void init_open(const Player& P, int n) + void init_open(const Player& P, int n = 0) { part_MC->init_open(P); sizes.clear(); diff --git a/Machines/Atlas.hpp b/Machines/Atlas.hpp index 045b69b9e..bb3e7a7ee 100644 --- a/Machines/Atlas.hpp +++ b/Machines/Atlas.hpp @@ -10,7 +10,8 @@ #include "Protocols/AtlasPrep.h" #include "GC/AtlasSecret.h" -#include "ShamirMachine.hpp" #include "Protocols/Atlas.hpp" +#include "Shamir.hpp" + #endif /* MACHINES_ATLAS_HPP_ */ diff --git a/Machines/BMR/mal-shamir-bmr-party.cpp b/Machines/BMR/mal-shamir-bmr-party.cpp index e6264dda5..5e84fc51a 100644 --- a/Machines/BMR/mal-shamir-bmr-party.cpp +++ b/Machines/BMR/mal-shamir-bmr-party.cpp @@ -4,7 +4,7 @@ */ #include "BMR/RealProgramParty.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include "Math/Z2k.hpp" #include "Machines/MalRep.hpp" diff --git a/Machines/BMR/shamir-bmr-party.cpp b/Machines/BMR/shamir-bmr-party.cpp index e6fe0ac8f..9633c1241 100644 --- a/Machines/BMR/shamir-bmr-party.cpp +++ b/Machines/BMR/shamir-bmr-party.cpp @@ -4,7 +4,7 @@ */ #include "BMR/RealProgramParty.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include "Math/Z2k.hpp" int main(int argc, const char** argv) diff --git a/Machines/SPDZ.cpp b/Machines/SPDZ.cpp index a93d1a366..6e16f34b5 100644 --- a/Machines/SPDZ.cpp +++ b/Machines/SPDZ.cpp @@ -5,3 +5,5 @@ #include "Math/gfp.hpp" template class FieldMachine; + +template class Machine>, Share>; diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index a221b087a..724d17c03 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -6,6 +6,9 @@ #ifndef MACHINES_SPDZ_HPP_ #define MACHINES_SPDZ_HPP_ +#include "Protocols/MAC_Check.h" +#include "Protocols/SPDZ.h" + #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" diff --git a/Machines/Shamir.hpp b/Machines/Shamir.hpp new file mode 100644 index 000000000..bf9f70ad1 --- /dev/null +++ b/Machines/Shamir.hpp @@ -0,0 +1,50 @@ +/* + * ShamirMachine.cpp + * + */ + +#ifndef MACHINE_SHAMIR_HPP_ +#define MACHINE_SHAMIR_HPP_ + +#include "Protocols/ShamirOptions.h" +#include "Protocols/ShamirShare.h" +#include "Protocols/MaliciousShamirShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "GC/VectorProtocol.h" +#include "GC/CcdPrep.h" +#include "GC/TinyMC.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/VectorInput.h" + +#include "Processor/FieldMachine.hpp" + +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/ShamirInput.hpp" +#include "Protocols/Shamir.hpp" +#include "Protocols/ShamirMC.hpp" +#include "Protocols/MaliciousShamirMC.hpp" +#include "Protocols/MaliciousShamirPO.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Beaver.hpp" +#include "Protocols/Spdz2kPrep.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/VectorProtocol.hpp" +#include "GC/Secret.hpp" +#include "GC/CcdPrep.hpp" +#include "Math/gfp.hpp" + +template class T> +ShamirMachineSpec::ShamirMachineSpec(int argc, const char** argv) +{ + auto& opts = ShamirOptions::singleton; + ez::ezOptionParser opt; + opts = {opt, argc, argv}; + HonestMajorityFieldMachine(argc, argv, opt, opts.nparties); +} + +#endif diff --git a/Machines/ccd-party.cpp b/Machines/ccd-party.cpp index 433aaf261..c5c42f48b 100644 --- a/Machines/ccd-party.cpp +++ b/Machines/ccd-party.cpp @@ -13,7 +13,7 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "GC/CcdPrep.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" int main(int argc, const char** argv) { diff --git a/Machines/export-atlas.cpp b/Machines/export-atlas.cpp new file mode 100644 index 000000000..a28d62db1 --- /dev/null +++ b/Machines/export-atlas.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>>; diff --git a/Machines/export-cowgear.cpp b/Machines/export-cowgear.cpp new file mode 100644 index 000000000..cd7e81e4f --- /dev/null +++ b/Machines/export-cowgear.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>>; diff --git a/Machines/export-dealer.cpp b/Machines/export-dealer.cpp new file mode 100644 index 000000000..5c9b0d4f6 --- /dev/null +++ b/Machines/export-dealer.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>>; diff --git a/Machines/export-hemi.cpp b/Machines/export-hemi.cpp new file mode 100644 index 000000000..6c861a8a6 --- /dev/null +++ b/Machines/export-hemi.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>>; diff --git a/Machines/export-rep4-ring.cpp b/Machines/export-rep4-ring.cpp new file mode 100644 index 000000000..3bc581f68 --- /dev/null +++ b/Machines/export-rep4-ring.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>; diff --git a/Machines/export-ring.cpp b/Machines/export-ring.cpp new file mode 100644 index 000000000..cbd540b8b --- /dev/null +++ b/Machines/export-ring.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>; diff --git a/Machines/export-semi2k.cpp b/Machines/export-semi2k.cpp new file mode 100644 index 000000000..ec5f14a91 --- /dev/null +++ b/Machines/export-semi2k.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>; diff --git a/Machines/export-sy-rep-ring.cpp b/Machines/export-sy-rep-ring.cpp new file mode 100644 index 000000000..75471b7c0 --- /dev/null +++ b/Machines/export-sy-rep-ring.cpp @@ -0,0 +1,8 @@ +/* + * export-vm.cpp + * + */ + +#include "maximal.hpp" + +template class Machine>; diff --git a/Machines/h-files.h b/Machines/h-files.h new file mode 100644 index 000000000..6d7a697ab --- /dev/null +++ b/Machines/h-files.h @@ -0,0 +1,377 @@ +#include "BMR/AndJob.h" +#include "BMR/BooleanCircuit.h" +#include "BMR/common.h" +#include "BMR/CommonParty.h" +#include "BMR/config.h" +#include "BMR/GarbledGate.h" +#include "BMR/Gate.h" +#include "BMR/Key.h" +#include "BMR/msg_types.h" +#include "BMR/Party.h" +#include "BMR/prf.h" +#include "BMR/proto_utils.h" +#include "BMR/RealGarbleWire.h" +#include "BMR/RealProgramParty.h" +#include "BMR/Register.h" +#include "BMR/Register_inline.h" +#include "BMR/SpdzWire.h" +#include "BMR/TrustedParty.h" +#include "BMR/Wire.h" +#include "ECDSA/CurveElement.h" +#include "ECDSA/EcdsaOptions.h" +#include "ECDSA/P256Element.h" +#include "ExternalIO/Client.h" +#include "FHE/AddableVector.h" +#include "FHE/Ciphertext.h" +#include "FHE/Diagonalizer.h" +#include "FHE/DiscreteGauss.h" +#include "FHE/FFT_Data.h" +#include "FHE/FFT.h" +#include "FHE/FHE_Keys.h" +#include "FHE/FHE_Params.h" +#include "FHE/Generator.h" +#include "FHE/Matrix.h" +#include "FHE/NoiseBounds.h" +#include "FHE/NTL-Subs.h" +#include "FHEOffline/config.h" +#include "FHEOffline/CutAndChooseMachine.h" +#include "FHEOffline/DataSetup.h" +#include "FHEOffline/DistDecrypt.h" +#include "FHEOffline/DistKeyGen.h" +#include "FHEOffline/EncCommit.h" +#include "FHEOffline/Multiplier.h" +#include "FHEOffline/PairwiseGenerator.h" +#include "FHEOffline/PairwiseMachine.h" +#include "FHEOffline/PairwiseSetup.h" +#include "FHEOffline/Producer.h" +#include "FHEOffline/Proof.h" +#include "FHEOffline/Prover.h" +#include "FHEOffline/Reshare.h" +#include "FHEOffline/Sacrificing.h" +#include "FHEOffline/SimpleDistDecrypt.h" +#include "FHEOffline/SimpleEncCommit.h" +#include "FHEOffline/SimpleGenerator.h" +#include "FHEOffline/SimpleMachine.h" +#include "FHEOffline/TemiSetup.h" +#include "FHEOffline/Verifier.h" +#include "FHE/P2Data.h" +#include "FHE/Plaintext.h" +#include "FHE/QGroup.h" +#include "FHE/Random_Coins.h" +#include "FHE/Ring_Element.h" +#include "FHE/Ring.h" +#include "FHE/Rq_Element.h" +#include "FHE/Subroutines.h" +#include "FHE/tools.h" +#include "GC/Access.h" +#include "GC/ArgTuples.h" +#include "GC/AtlasSecret.h" +#include "GC/AtlasShare.h" +#include "GC/BitAdder.h" +#include "GC/BitPrepFiles.h" +#include "GC/CcdPrep.h" +#include "GC/CcdSecret.h" +#include "GC/CcdShare.h" +#include "GC/Clear.h" +#include "GC/config.h" +#include "GC/DealerPrep.h" +#include "GC/FakeSecret.h" +#include "GC/Instruction.h" +#include "GC/Instruction_inline.h" +#include "GC/instructions.h" +#include "GC/Machine.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/MaliciousCcdShare.h" +#include "GC/MaliciousRepSecret.h" +#include "GC/Memory.h" +#include "GC/NoShare.h" +#include "GC/PersonalPrep.h" +#include "GC/PostSacriBin.h" +#include "GC/PostSacriSecret.h" +#include "GC/Processor.h" +#include "GC/Program.h" +#include "GC/Rep4Prep.h" +#include "GC/Rep4Secret.h" +#include "GC/RepPrep.h" +#include "GC/RuntimeBranching.h" +#include "GC/Secret.h" +#include "GC/Secret_inline.h" +#include "GC/Semi.h" +#include "GC/SemiHonestRepPrep.h" +#include "GC/SemiPrep.h" +#include "GC/SemiSecret.h" +#include "GC/ShareParty.h" +#include "GC/ShareSecret.h" +#include "GC/ShareThread.h" +#include "GC/ShiftableTripleBuffer.h" +#include "GC/square64.h" +#include "GC/Thread.h" +#include "GC/ThreadMaster.h" +#include "GC/TinierSecret.h" +#include "GC/TinierShare.h" +#include "GC/TinierSharePrep.h" +#include "GC/TinyMC.h" +#include "GC/TinySecret.h" +#include "GC/TinyShare.h" +#include "GC/VectorInput.h" +#include "GC/VectorProtocol.h" +#include "Machines/OTMachine.h" +#include "Machines/OutputCheck.h" +#include "Math/bigint.h" +#include "Math/Bit.h" +#include "Math/BitVec.h" +#include "Math/config.h" +#include "Math/field_types.h" +#include "Math/FixedVec.h" +#include "Math/fixint.h" +#include "Math/gf2n.h" +#include "Math/gf2nlong.h" +#include "Math/gfp.h" +#include "Math/gfpvar.h" +#include "Math/Integer.h" +#include "Math/modp.h" +#include "Math/mpn_fixed.h" +#include "Math/Setup.h" +#include "Math/Square.h" +#include "Math/ValueInterface.h" +#include "Math/Z2k.h" +#include "Math/Zp_Data.h" +#include "Networking/AllButLastPlayer.h" +#include "Networking/CryptoPlayer.h" +#include "Networking/data.h" +#include "Networking/Exchanger.h" +#include "Networking/PlayerBuffer.h" +#include "Networking/PlayerCtSocket.h" +#include "Networking/Player.h" +#include "Networking/Receiver.h" +#include "Networking/Sender.h" +#include "Networking/Server.h" +#include "Networking/ServerSocket.h" +#include "Networking/sockets.h" +#include "Networking/ssl_sockets.h" +#include "OT/BaseOT.h" +#include "OT/BitDiagonal.h" +#include "OT/BitMatrix.h" +#include "OT/config.h" +#include "OT/MamaRectangle.h" +#include "OT/MascotParams.h" +#include "OT/NPartyTripleGenerator.h" +#include "OT/OTExtension.h" +#include "OT/OTExtensionWithMatrix.h" +#include "OT/OTMultiplier.h" +#include "OT/OTTripleSetup.h" +#include "OT/OTVole.h" +#include "OT/Rectangle.h" +#include "OT/Row.h" +#include "OT/Tools.h" +#include "OT/TripleMachine.h" +#include "Processor/BaseMachine.h" +#include "Processor/Binary_File_IO.h" +#include "Processor/config.h" +#include "Processor/Conv2dTuple.h" +#include "Processor/Data_Files.h" +#include "Processor/DummyProtocol.h" +#include "Processor/EdabitBuffer.h" +#include "Processor/ExternalClients.h" +#include "Processor/FieldMachine.h" +#include "Processor/FixInput.h" +#include "Processor/FloatInput.h" +#include "Processor/FunctionArgument.h" +#include "Processor/HonestMajorityMachine.h" +#include "Processor/Input.h" +#include "Processor/InputTuple.h" +#include "Processor/Instruction.h" +#include "Processor/instructions.h" +#include "Processor/IntInput.h" +#include "Processor/Machine.h" +#include "Processor/Memory.h" +#include "Processor/NoFilePrep.h" +#include "Processor/OfflineMachine.h" +#include "Processor/OnlineMachine.h" +#include "Processor/OnlineOptions.h" +#include "Processor/Online-Thread.h" +#include "Processor/PrepBase.h" +#include "Processor/PrepBuffer.h" +#include "Processor/PrivateOutput.h" +#include "Processor/ProcessorBase.h" +#include "Processor/Processor.h" +#include "Processor/Program.h" +#include "Processor/RingMachine.h" +#include "Processor/RingOptions.h" +#include "Processor/SpecificPrivateOutput.h" +#include "Processor/ThreadJob.h" +#include "Processor/ThreadQueue.h" +#include "Processor/ThreadQueues.h" +#include "Processor/TruncPrTuple.h" +#include "Protocols/Atlas.h" +#include "Protocols/AtlasPrep.h" +#include "Protocols/AtlasShare.h" +#include "Protocols/Beaver.h" +#include "Protocols/BrainPrep.h" +#include "Protocols/BrainShare.h" +#include "Protocols/BufferScope.h" +#include "Protocols/ChaiGearPrep.h" +#include "Protocols/ChaiGearShare.h" +#include "Protocols/config.h" +#include "Protocols/CowGearOptions.h" +#include "Protocols/CowGearPrep.h" +#include "Protocols/CowGearShare.h" +#include "Protocols/dabit.h" +#include "Protocols/DabitSacrifice.h" +#include "Protocols/Dealer.h" +#include "Protocols/DealerInput.h" +#include "Protocols/DealerMatrixPrep.h" +#include "Protocols/DealerMC.h" +#include "Protocols/DealerPrep.h" +#include "Protocols/DealerShare.h" +#include "Protocols/DummyMatrixPrep.h" +#include "Protocols/edabit.h" +#include "Protocols/FakeInput.h" +#include "Protocols/FakeMC.h" +#include "Protocols/FakePrep.h" +#include "Protocols/FakeProtocol.h" +#include "Protocols/FakeShare.h" +#include "Protocols/fake-stuff.h" +#include "Protocols/Hemi.h" +#include "Protocols/HemiMatrixPrep.h" +#include "Protocols/HemiOptions.h" +#include "Protocols/HemiPrep.h" +#include "Protocols/HemiShare.h" +#include "Protocols/HighGearKeyGen.h" +#include "Protocols/HighGearShare.h" +#include "Protocols/LimitedPrep.h" +#include "Protocols/LowGearKeyGen.h" +#include "Protocols/LowGearShare.h" +#include "Protocols/MAC_Check_Base.h" +#include "Protocols/MAC_Check.h" +#include "Protocols/MaliciousRep3Share.h" +#include "Protocols/MaliciousRepMC.h" +#include "Protocols/MaliciousRepPO.h" +#include "Protocols/MaliciousRepPrep.h" +#include "Protocols/MaliciousShamirMC.h" +#include "Protocols/MaliciousShamirPO.h" +#include "Protocols/MaliciousShamirShare.h" +#include "Protocols/MalRepRingOptions.h" +#include "Protocols/MalRepRingPrep.h" +#include "Protocols/MalRepRingShare.h" +#include "Protocols/MamaPrep.h" +#include "Protocols/MamaShare.h" +#include "Protocols/MascotPrep.h" +#include "Protocols/MatrixFile.h" +#include "Protocols/NoLivePrep.h" +#include "Protocols/NoProtocol.h" +#include "Protocols/NoShare.h" +#include "Protocols/Opener.h" +#include "Protocols/PostSacrifice.h" +#include "Protocols/PostSacriRepFieldShare.h" +#include "Protocols/PostSacriRepRingShare.h" +#include "Protocols/ProtocolSet.h" +#include "Protocols/ProtocolSetup.h" +#include "Protocols/Rep3Share2k.h" +#include "Protocols/Rep3Share.h" +#include "Protocols/Rep3Shuffler.h" +#include "Protocols/Rep4.h" +#include "Protocols/Rep4Input.h" +#include "Protocols/Rep4MC.h" +#include "Protocols/Rep4Prep.h" +#include "Protocols/Rep4Share2k.h" +#include "Protocols/Rep4Share.h" +#include "Protocols/Replicated.h" +#include "Protocols/ReplicatedInput.h" +#include "Protocols/ReplicatedMC.h" +#include "Protocols/ReplicatedPO.h" +#include "Protocols/ReplicatedPrep.h" +#include "Protocols/RepRingOnlyEdabitPrep.h" +#include "Protocols/RingOnlyPrep.h" +#include "Protocols/SecureShuffle.h" +#include "Protocols/Semi2kShare.h" +#include "Protocols/Semi.h" +#include "Protocols/SemiInput.h" +#include "Protocols/SemiMC.h" +#include "Protocols/SemiPrep2k.h" +#include "Protocols/SemiPrep.h" +#include "Protocols/SemiRep3Prep.h" +#include "Protocols/SemiShare.h" +#include "Protocols/Shamir.h" +#include "Protocols/ShamirInput.h" +#include "Protocols/ShamirMC.h" +#include "Protocols/ShamirOptions.h" +#include "Protocols/ShamirShare.h" +#include "Protocols/Share.h" +#include "Protocols/ShareInterface.h" +#include "Protocols/ShareMatrix.h" +#include "Protocols/ShareVector.h" +#include "Protocols/ShuffleSacrifice.h" +#include "Protocols/SohoPrep.h" +#include "Protocols/SohoShare.h" +#include "Protocols/SPDZ2k.h" +#include "Protocols/Spdz2kPrep.h" +#include "Protocols/Spdz2kShare.h" +#include "Protocols/SPDZ.h" +#include "Protocols/SpdzWise.h" +#include "Protocols/SpdzWiseInput.h" +#include "Protocols/SpdzWiseMC.h" +#include "Protocols/SpdzWisePrep.h" +#include "Protocols/SpdzWiseRep3Shuffler.h" +#include "Protocols/SpdzWiseRing.h" +#include "Protocols/SpdzWiseRingPrep.h" +#include "Protocols/SpdzWiseRingShare.h" +#include "Protocols/SpdzWiseShare.h" +#include "Protocols/SquarePrep.h" +#include "Protocols/TemiPrep.h" +#include "Protocols/TemiShare.h" +#include "Tools/aes.h" +#include "Tools/avx_memcpy.h" +#include "Tools/benchmarking.h" +#include "Tools/BitVector.h" +#include "Tools/Buffer.h" +#include "Tools/Bundle.h" +#include "Tools/callgrind.h" +#include "Tools/CheckVector.h" +#include "Tools/Commit.h" +#include "Tools/Coordinator.h" +#include "Tools/cpu_support.h" +#include "Tools/DiskVector.h" +#include "Tools/Exceptions.h" +#include "Tools/ExecutionStats.h" +#include "Tools/FixedVector.h" +#include "Tools/FlexBuffer.h" +#include "Tools/Hash.h" +#include "Tools/int.h" +#include "Tools/intrinsics.h" +#include "Tools/Lock.h" +#include "Tools/MemoryUsage.h" +#include "Tools/mkpath.h" +#include "Tools/MMO.h" +#include "Tools/NamedStats.h" +#include "Tools/NetworkOptions.h" +#include "Tools/octetStream.h" +#include "Tools/oct.h" +#include "Tools/OfflineMachineBase.h" +#include "Tools/parse.h" +#include "Tools/PointerVector.h" +#include "Tools/pprint.h" +#include "Tools/random.h" +#include "Tools/Signal.h" +#include "Tools/Subroutines.h" +#include "Tools/SwitchableOutput.h" +#include "Tools/time-func.h" +#include "Tools/TimerWithComm.h" +#include "Tools/WaitQueue.h" +#include "Tools/Waksman.h" +#include "Tools/Worker.h" +#include "Yao/config.h" +#include "Yao/YaoAndJob.h" +#include "Yao/YaoCommon.h" +#include "Yao/YaoEvalInput.h" +#include "Yao/YaoEvalMaster.h" +#include "Yao/YaoEvaluator.h" +#include "Yao/YaoEvalWire.h" +#include "Yao/YaoGarbleInput.h" +#include "Yao/YaoGarbleMaster.h" +#include "Yao/YaoGarbler.h" +#include "Yao/YaoGarbleWire.h" +#include "Yao/YaoGate.h" +#include "Yao/YaoHalfGate.h" +#include "Yao/YaoPlayer.h" +#include "Yao/YaoWire.h" diff --git a/Machines/mal-shamir-offline.cpp b/Machines/mal-shamir-offline.cpp index 1f585f41a..2294cce05 100644 --- a/Machines/mal-shamir-offline.cpp +++ b/Machines/mal-shamir-offline.cpp @@ -3,7 +3,7 @@ * */ -#include "ShamirMachine.hpp" +#include #include "MalRep.hpp" #include "Processor/OfflineMachine.hpp" diff --git a/Machines/malicious-ccd-party.cpp b/Machines/malicious-ccd-party.cpp index 55ec8b99b..0adf68a32 100644 --- a/Machines/malicious-ccd-party.cpp +++ b/Machines/malicious-ccd-party.cpp @@ -13,7 +13,7 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "GC/CcdPrep.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include "Machines/MalRep.hpp" int main(int argc, const char** argv) diff --git a/Machines/malicious-shamir-party.cpp b/Machines/malicious-shamir-party.cpp index 06a01ee93..a7c99ce3f 100644 --- a/Machines/malicious-shamir-party.cpp +++ b/Machines/malicious-shamir-party.cpp @@ -3,11 +3,11 @@ * */ -#include "Machines/ShamirMachine.h" #include "Protocols/MaliciousShamirShare.h" +#include "Protocols/ShamirOptions.h" #include "Machines/MalRep.hpp" -#include "ShamirMachine.hpp" +#include "Shamir.hpp" int main(int argc, const char** argv) { diff --git a/Machines/maximal.hpp b/Machines/maximal.hpp new file mode 100644 index 000000000..55437e168 --- /dev/null +++ b/Machines/maximal.hpp @@ -0,0 +1,37 @@ +/* + * maximal.hpp + * + */ + +#ifndef MACHINES_MAXIMAL_HPP_ +#define MACHINES_MAXIMAL_HPP_ + +#include "minimal.hpp" + +#include "Atlas.hpp" +#include "MalRep.hpp" +#include "Rep4.hpp" +#include "Rep.hpp" +#include "RepRing.hpp" +#include "Semi2k.hpp" +#include "SPDZ2k.hpp" +#include "SPDZ.hpp" + +#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/TinyPrep.hpp" +#include "Protocols/ChaiGearPrep.hpp" +#include "Protocols/CowGearPrep.hpp" +#include "Protocols/DealerPrep.hpp" +#include "Protocols/DealerInput.hpp" +#include "Protocols/DealerMC.hpp" +#include "Protocols/DealerMatrixPrep.hpp" +#include "Protocols/SpdzWise.hpp" +#include "Protocols/SpdzWiseRing.hpp" +#include "Protocols/SpdzWiseInput.hpp" +#include "Protocols/SpdzWisePrep.hpp" +#include "Protocols/SpdzWiseShare.hpp" +#include "Protocols/SpdzWiseRep3Shuffler.hpp" +#include "Protocols/TemiPrep.hpp" + +#endif /* MACHINES_MAXIMAL_HPP_ */ diff --git a/Machines/minimal.hpp b/Machines/minimal.hpp new file mode 100644 index 000000000..d013edd69 --- /dev/null +++ b/Machines/minimal.hpp @@ -0,0 +1,26 @@ +/* + * minimal.hpp + * + */ + +// minimal header file to make all C++ code compile +// but not produce all templated code in binary +// use maximal.hpp for that + +// please report if otherwise + +#ifndef MACHINES_MINIMAL_HPP_ +#define MACHINES_MINIMAL_HPP_ + +#include "h-files.h" + +// some h files depend on hpp files + +#include "GC/Secret.hpp" +#include "GC/SemiSecret.hpp" +#include "Protocols/DealerMC.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/ShamirMC.hpp" + +#endif /* MACHINES_MINIMAL_HPP_ */ diff --git a/Machines/shamir-party.cpp b/Machines/shamir-party.cpp index c63c3d029..a78d59baa 100644 --- a/Machines/shamir-party.cpp +++ b/Machines/shamir-party.cpp @@ -3,10 +3,10 @@ * */ -#include "Machines/ShamirMachine.h" +#include "Protocols/ShamirOptions.h" #include "Protocols/ShamirShare.h" -#include "ShamirMachine.hpp" +#include "Shamir.hpp" int main(int argc, const char** argv) { diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index ced64919f..4f8fc5420 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -4,6 +4,7 @@ */ #include "Protocols/SohoShare.h" +#include "Protocols/SPDZ.h" #include "Math/gfp.h" #include "Math/gf2n.h" #include "FHE/P2Data.h" diff --git a/Machines/sy-shamir-party.cpp b/Machines/sy-shamir-party.cpp index d251e7cdc..34a0072ef 100644 --- a/Machines/sy-shamir-party.cpp +++ b/Machines/sy-shamir-party.cpp @@ -3,7 +3,6 @@ * */ -#include "ShamirMachine.h" #include "Protocols/SpdzWiseShare.h" #include "Protocols/MaliciousShamirShare.h" #include "Protocols/SpdzWiseMC.h" @@ -19,7 +18,7 @@ #include "Protocols/SpdzWisePrep.hpp" #include "Protocols/SpdzWiseInput.hpp" #include "Protocols/SpdzWiseShare.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include "Machines/MalRep.hpp" template diff --git a/Makefile b/Makefile index 9f5cbf3f8..ff04f1e1e 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp)) NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp)) -PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) +PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o @@ -59,7 +59,7 @@ DEPS := $(wildcard */*.d */*/*.d) .SECONDARY: $(OBJS) -all: arithmetic binary gen_input online offline externalIO bmr ecdsa +all: arithmetic binary gen_input online offline externalIO bmr ecdsa export vm: arithmetic binary .PHONY: doc @@ -124,7 +124,7 @@ tldr: setup mkdir Player-Data 2> /dev/null; true ifeq ($(ARM), 1) -$(patsubst %.cpp,%.o,$(wildcard */*.cpp)): deps/simde/simde +$(patsubst %.cpp,%.o,$(wildcard */*.cpp */*/*.cpp)): deps/simde/simde deps/sse2neon/sse2neon.h endif shamir: shamir-party.x malicious-shamir-party.x atlas-party.x galois-degree.x @@ -162,6 +162,16 @@ static-dir: static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) $(patsubst Machines/BMR/%.cpp, static/%.x, $(wildcard Machines/BMR/*-party.cpp)) static/emulate.x +EXPORT_VM = $(patsubst %.cpp, %.o, $(wildcard Machines/export-*.cpp)) +.SECONDARY: $(EXPORT_VM) + +export-trunc.x: Machines/export-ring.o +export-sort.x: Machines/export-ring.o +export-a2b.x: GC/AtlasSecret.o Machines/SPDZ.o Machines/SPDZ2^64+64.o $(GC_SEMI) $(TINIER) $(EXPORT_VM) GC/Rep4Secret.o GC/Rep4Prep.o $(FHEOFFLINE) +export-b2a.x: Machines/export-ring.o + +export: $(patsubst Utils/%.cpp, %.x, $(wildcard Utils/export*.cpp)) + Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o $(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS) @@ -367,8 +377,11 @@ mac-machine-setup: deps/simde/simde: git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde +deps/sse2neon/sse2neon.h: + git submodule update --init deps/sse2neon || git clone https://github.com/DLTcollab/sse2neon deps/sse2neon + clean-deps: - -rm -rf local/lib/liblibOTe.* deps/libOTe/out deps/SimplestOT_C + -rm -rf local/lib/liblibOTe.* deps/libOTe/out deps/SimplestOT_C deps/SimpleOT clean: clean-deps -rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so diff --git a/Math/FixedVec.h b/Math/FixedVec.h index d38dc0e4f..c066b6e53 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -63,25 +63,25 @@ class FixedVec return res; } - FixedVec(const T& other = {}) + FixedVec(const T& other = {}) { for (auto& x : v) x = other; } - FixedVec(long other) : - FixedVec(T(other)) + FixedVec(long other) : + FixedVec(T(other)) { } template - FixedVec(const FixedVec& other) + FixedVec(const FixedVec& other) { for (int i = 0; i < L; i++) v[i] = other[i]; } - FixedVec(const array& other) + FixedVec(const array& other) { v = other; } diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 2a54575c2..1a69156e3 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -71,14 +71,14 @@ int default_m(int& lgp, int& idx) return m; } -bigint generate_prime(int lgp, int m) +bigint generate_prime(int lgp, int m, bool force_degree) { bigint p; - generate_prime(p, lgp, m); + generate_prime(p, lgp, m, force_degree); return p; } -void generate_prime(bigint& p, int lgp, int m) +void generate_prime(bigint& p, int lgp, int m, bool force_degree) { if (OnlineOptions::singleton.prime > 0) { @@ -100,7 +100,8 @@ void generate_prime(bigint& p, int lgp, int m) } int idx; - m = max(m, default_m(lgp, idx)); + if (not force_degree) + m = max(m, default_m(lgp, idx)); bigint u; int ex; diff --git a/Math/Setup.h b/Math/Setup.h index 27724b58f..4004093e6 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -33,8 +33,8 @@ void check_setup(string dirname, bigint p); // Chooses a p of at least lgp bits bigint SPDZ_Data_Setup_Primes(int lgp); void SPDZ_Data_Setup_Primes(bigint& p,int lgp,int& idx,int& m); -void generate_prime(bigint& p, int lgp, int m); -bigint generate_prime(int lgp, int m); +void generate_prime(bigint& p, int lgp, int m, bool force_degree = false); +bigint generate_prime(int lgp, int m, bool force_degree = false); int default_m(int& lgp, int& idx); string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 64035090f..e8e85fa93 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -35,7 +35,7 @@ class int128 word get_upper() const { return _mm_cvtsi128_si64(_mm_unpackhi_epi64(a, a)); } word get_half(bool upper) const { return upper ? get_upper() : get_lower(); } -#ifdef __SSE41__ +#ifdef __SSE4_1__ bool operator==(const int128& other) const { return _mm_test_all_zeros(a ^ other.a, a ^ other.a); } #else bool operator==(const int128& other) const { return get_lower() == other.get_lower() and get_upper() == other.get_upper(); } @@ -152,6 +152,8 @@ class gf2n_long : public gf2n_ gf2n_long(const super& g) : super(g) {} gf2n_long(const int128& g) : super(g) {} gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {} + gf2n_long(long g) : gf2n_long(int128(g)) {} + gf2n_long(word g) : gf2n_long(int128(g)) {} template gf2n_long(IntBase g) : super(g.get()) {} template diff --git a/Math/gfp.h b/Math/gfp.h index 7cd7351bd..cc3c92f92 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -311,6 +311,11 @@ typedef gfp_<1, GFP_MOD_SZ> gfp1; template Zp_Data gfp_::ZpD; +template +gfp_ gfp_::two; + +template +const true_type gfp_::prime_field; template thread_local vector> gfp_::powers; diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 2a2e27785..c590e22f6 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -13,11 +13,7 @@ template const true_type gfp_::invertible; template -const true_type gfp_::prime_field; -template const int gfp_::MAX_N_BITS; -template -gfp_ gfp_::two; template inline void gfp_::read_or_generate_setup(string dir, diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 53bc35879..5dbdfcc11 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -95,8 +95,10 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) : continue; } - senders[i] = new Sender(i < my_num() ? sockets[i] : other_sockets[i]); - receivers[i] = new Receiver(i < my_num() ? other_sockets[i] : sockets[i]); + senders[i] = new Sender( + i < my_num() ? sockets[i] : other_sockets[i], i); + receivers[i] = new Receiver( + i < my_num() ? other_sockets[i] : sockets[i], i); } } diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 620ea0f3c..13767f408 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -80,11 +80,14 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante } if (nplayers_wanted > 0 and nplayers_wanted != nplayers) exit_error("not enough hosts in " + filename); -#ifdef DEBUG_NETWORKING - cerr << "Got list of " << nplayers << " players from file: " << endl; - for (unsigned int i = 0; i < names.size(); i++) - cerr << " " << names[i] << ":" << ports[i] << endl; -#endif + + if (OnlineOptions::singleton.has_option("debug_networking")) + { + cerr << "Got list of " << nplayers << " players from file: " << endl; + for (unsigned int i = 0; i < names.size(); i++) + cerr << " " << names[i] << ":" << ports[i] << endl; + } + setup_server(); } @@ -140,9 +143,9 @@ void Names::setup_names(const char *servername, int my_port) } octetStream("P" + to_string(player_no)).Send(socket_num); -#ifdef DEBUG_NETWORKING - cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; -#endif + + if (OnlineOptions::singleton.has_option("debug_networking")) + cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; // Send my name sockaddr_in address; @@ -151,10 +154,12 @@ void Names::setup_names(const char *servername, int my_port) char* my_name = inet_ntoa(address.sin_addr); octetStream(my_name).Send(socket_num); send(socket_num,(octet*)&my_port,4); -#ifdef DEBUG_NETWORKING - fprintf(stderr, "My Name = %s\n",my_name); - cerr << "My number = " << player_no << endl; -#endif + + if (OnlineOptions::singleton.has_option("debug_networking")) + { + fprintf(stderr, "My Name = %s\n",my_name); + cerr << "My number = " << player_no << endl; + } // Now get the set of names try @@ -172,10 +177,12 @@ void Names::setup_names(const char *servername, int my_port) if (names.size() != ports.size()) exit_error("invalid network setup"); nplayers = names.size(); -#ifdef VERBOSE - for (int i = 0; i < nplayers; i++) - cerr << "Player " << i << " is running on machine " << names[i] << endl; -#endif + + + if (OnlineOptions::singleton.has_option("debug_networking")) + for (int i = 0; i < nplayers; i++) + cerr << "Player " << i << " is running on machine " << names[i] << endl; + close_client_socket(socket_num); } @@ -640,8 +647,8 @@ ThreadPlayer::ThreadPlayer(const Names& Nms, const string& id_base) : { for (int i = 0; i < Nms.num_players(); i++) { - receivers.push_back(new Receiver(sockets[i])); - senders.push_back(new Sender(socket_to_send(i))); + receivers.push_back(new Receiver(sockets[i], i)); + senders.push_back(new Sender(socket_to_send(i), i)); } } @@ -845,10 +852,15 @@ Timer& CommStatsWithName::add_length_only(size_t length) } Timer& CommStatsWithName::add(const octetStream& os) +{ + return add(os.get_length()); +} + +Timer& CommStatsWithName::add(size_t length) { if (OnlineOptions::singleton.has_option("verbose_comm")) - fprintf(stderr, "%s %zu bytes\n", name.c_str(), os.get_length()); - return stats.add(os); + fprintf(stderr, "%s %zu bytes\n", name.c_str(), length); + return stats.add(length); } void Player::reset_stats() diff --git a/Networking/Player.h b/Networking/Player.h index 40e113bc1..7d7ea9d9b 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -160,6 +160,7 @@ class CommStatsWithName Timer& add_length_only(size_t length); Timer& add(const octetStream& os); + Timer& add(size_t length); void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } }; diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index 7e8c93fe9..efe58bccb 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -5,6 +5,7 @@ #include "Receiver.h" #include "ssl_sockets.h" +#include "Processor/OnlineOptions.h" #include using namespace std; @@ -19,8 +20,14 @@ void* Receiver::run_thread(void* receiver) return 0; } +CommunicationThread::CommunicationThread(int other) : + other(other) +{ +} + template -Receiver::Receiver(T socket) : socket(socket), thread(0) +Receiver::Receiver(T socket, int other) : + CommunicationThread(other), socket(socket), thread(0) { start(); } @@ -44,8 +51,28 @@ void Receiver::stop() pthread_join(thread, 0); } +void CommunicationThread::run() +{ + if (OnlineOptions::singleton.has_option("throw_exceptions")) + run_with_error(); + else + { + try + { + run_with_error(); + } + catch (exception& e) + { + cerr << "Fatal error in communication: " << e.what() << endl; + cerr << "This is probably because party " << other + << " encountered a problem." << endl; + exit(1); + } + } +} + template -void Receiver::run() +void Receiver::run_with_error() { octetStream* os = 0; while (in.pop(os)) diff --git a/Networking/Receiver.h b/Networking/Receiver.h index 98e2313b0..a3e2df5b1 100644 --- a/Networking/Receiver.h +++ b/Networking/Receiver.h @@ -12,8 +12,20 @@ #include "Tools/WaitQueue.h" #include "Tools/time-func.h" +class CommunicationThread +{ + int other; + +protected: + CommunicationThread(int other); + virtual ~CommunicationThread() {} + + void run(); + virtual void run_with_error() = 0; +}; + template -class Receiver +class Receiver : CommunicationThread { T socket; WaitQueue in; @@ -27,12 +39,12 @@ class Receiver void start(); void stop(); - void run(); + void run_with_error(); public: Timer timer; - Receiver(T socket); + Receiver(T socket, int other); ~Receiver(); T get_socket() diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 4e4b98810..f703c5c31 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -17,7 +17,8 @@ void* Sender::run_thread(void* sender) } template -Sender::Sender(T socket) : socket(socket), thread(0) +Sender::Sender(T socket, int other) : + CommunicationThread(other), socket(socket), thread(0) { start(); } @@ -42,7 +43,7 @@ void Sender::stop() } template -void Sender::run() +void Sender::run_with_error() { const octetStream* os = 0; while (in.pop(os)) diff --git a/Networking/Sender.h b/Networking/Sender.h index 699e1b920..dbd4cbc70 100644 --- a/Networking/Sender.h +++ b/Networking/Sender.h @@ -8,12 +8,14 @@ #include +#include "Receiver.h" + #include "Tools/octetStream.h" #include "Tools/WaitQueue.h" #include "Tools/time-func.h" template -class Sender +class Sender : CommunicationThread { T socket; WaitQueue in; @@ -27,12 +29,12 @@ class Sender void start(); void stop(); - void run(); + void run_with_error(); public: Timer timer; - Sender(T socket); + Sender(T socket, int other); ~Sender(); T get_socket() diff --git a/Networking/Server.cpp b/Networking/Server.cpp index 8f6d8d01a..33eacca38 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -2,6 +2,7 @@ #include "Networking/sockets.h" #include "Networking/ServerSocket.h" #include "Networking/Server.h" +#include "Processor/OnlineOptions.h" #include #include @@ -30,26 +31,23 @@ void Server::get_ip(int num) names[num] = ipstr; -#ifdef DEBUG_NETWORKING - cerr << "Client IP address: " << names[num] << endl; -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + cerr << "IP address of party " << num << ": " << names[num] << endl; } void Server::get_name(int num) { -#ifdef DEBUG_NETWORKING - cerr << "Player " << num << " started." << endl; -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + cerr << "Player " << num << " started." << endl; // Receive name sent by client (legacy) - not used here octetStream os; os.Receive(socket_num[num]); receive(socket_num[num],(octet*)&ports[num],4); -#ifdef DEBUG_NETWORKING - cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":" - << ports[num] << endl; -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + cerr << "Player " << num << " listening on " << os.str() << ":" + << ports[num] << endl; // Get client IP get_ip(num); @@ -121,13 +119,13 @@ void Server::start() // set up connections for (i=0; i= nplayers) { cerr << "Player number " << my_num << " outside range: 0-" diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index c57657a46..5beebf15a 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -8,6 +8,7 @@ #include "Tools/Exceptions.h" #include "Tools/time-func.h" #include "Tools/octetStream.h" +#include "Processor/OnlineOptions.h" #include #include @@ -60,10 +61,9 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum), thread(0) << "), trying again in a second ..." << endl; sleep(1); } -#ifdef DEBUG_NETWORKING else - { cerr << "ServerSocket is bound on port " << Portnum << endl; } -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + cerr << "ServerSocket is bound on port " << Portnum << endl; } if (fl<0) { error("set_up_socket:bind"); } @@ -121,11 +121,12 @@ void ServerSocket::wait_for_client_id(int socket, struct sockaddr dest) } catch (closed_connection&) { -#ifdef DEBUG_NETWORKING - auto& conn = *(sockaddr_in*) &dest; - fprintf(stderr, "client on %s:%d left without identification\n", - inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + { + auto& conn = *(sockaddr_in*) &dest; + fprintf(stderr, "client on %s:%d left without identification\n", + inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); + } } } diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index 6572b12a5..09488fbe6 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -2,19 +2,23 @@ #include "sockets.h" #include "Tools/Exceptions.h" #include "Tools/time-func.h" +#include "Processor/OnlineOptions.h" #include #include using namespace std; -void error(const char *str) +void error(const char *str, bool interrupted) { int old_errno = errno; char err[1000]; - gethostname(err,1000); - strcat(err," : "); - strcat(err,str); - exit_error(string() + err + " : " + strerror(old_errno)); + gethostname(err, 1000); + err[999] = 0; + string message = string() + "Fatal communication error on " + err + ": " + + str + " (" + strerror(old_errno) + ")"; + if (interrupted) + message += "\nThis is probably because another party encountered a problem."; + exit_error(message); } void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) @@ -91,12 +95,13 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) { close(mysocket); usleep(wait < 1000 ? wait *= 2 : wait); -#ifdef DEBUG_NETWORKING - string msg = "Connecting to " + string(hostname) + ":" + - to_string(Portnum) + " failed"; - errno = connect_errno; - perror(msg.c_str()); -#endif + if (OnlineOptions::singleton.has_option("debug_networking")) + { + string msg = "Connecting to " + string(hostname) + ":" + + to_string(Portnum) + " failed"; + errno = connect_errno; + perror(msg.c_str()); + } } errno = connect_errno; } diff --git a/Networking/sockets.h b/Networking/sockets.h index 4ea85174d..1bada90cb 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -28,7 +28,7 @@ using namespace std; #define CONNECTION_TIMEOUT 60 #endif -void error(const char *str); +void error(const char *str, bool interrupted = false); void set_up_client_socket(int& mysocket,const char* hostname,int Portnum); void close_client_socket(int socket); @@ -51,7 +51,7 @@ inline size_t send_non_blocking(int socket, octet* msg, size_t len) { if (errno != EINTR and errno != EAGAIN and errno != EWOULDBLOCK and errno != ENOBUFS) - { error("Send error - 1 "); } + { error("Sending error", true); } else return 0; } @@ -103,14 +103,14 @@ inline void receive(int socket,octet *msg,size_t len) if (errno == EAGAIN or errno == EINTR) { if (++fail > 25) - error("Unavailable too many times"); + error("Unavailable too many times", true); else { usleep(wait *= 2); } } else - { error("Receiving error - 1"); } + { error("Receiving error", true); } } else throw closed_connection(); @@ -130,7 +130,7 @@ inline ssize_t check_non_blocking_result(ssize_t res) if (res < 0) { if (errno != EWOULDBLOCK) - error("Non-blocking receiving error"); + error("Non-blocking receiving error", true); return 0; } return res; @@ -149,7 +149,7 @@ inline ssize_t receive_all_or_nothing(int socket, octet *msg, ssize_t len) if (res == len) { if (recv(socket, msg, len, 0) != len) - error("All or nothing receiving error"); + error("All or nothing receiving error", true); return len; } else diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index d432da465..e0ca9b907 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -111,6 +111,16 @@ void receiver_keygen(ref10_RECEIVER* r, unsigned char (*keys)[HASHBYTES]) ref10_receiver_keygen(r, keys); } +void BaseOT::allocate() +{ + for (int i = 0; i < nOT; i++) + { + sender_inputs[i][0] = BitVector(8 * AES_BLK_SIZE); + sender_inputs[i][1] = BitVector(8 * AES_BLK_SIZE); + receiver_outputs[i] = BitVector(8 * AES_BLK_SIZE); + } +} + int BaseOT::avx = -1; bool BaseOT::use_avx() @@ -186,6 +196,7 @@ void BaseOT::exec_base(bool new_receiver_inputs) } os[0].reset_write_head(); + allocate(); for (i = 0; i < nOT; i += 4) { @@ -409,6 +420,8 @@ void FakeOT::exec_base(bool new_receiver_inputs) vector os(2); vector bv(2, 128); + allocate(); + if ((ot_role & RECEIVER) && new_receiver_inputs) { for (int i = 0; i < nOT; i++) diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 14f71ed6e..987b85164 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -61,13 +61,6 @@ class BaseOT receiver_outputs.resize(nOT); G_sender.resize(nOT); G_receiver.resize(nOT); - - for (int i = 0; i < nOT; i++) - { - sender_inputs[i][0] = BitVector(8 * AES_BLK_SIZE); - sender_inputs[i][1] = BitVector(8 * AES_BLK_SIZE); - receiver_outputs[i] = BitVector(8 * AES_BLK_SIZE); - } } BaseOT(TwoPartyPlayer* player, OT_ROLE role) : @@ -118,6 +111,8 @@ class BaseOT bool is_sender() { return (bool) (ot_role & SENDER); } bool is_receiver() { return (bool) (ot_role & RECEIVER); } + void allocate(); + bool use_avx(); /// CPU-specific instantiation of Simplest OT using Curve25519 diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 4d723d1f7..5ab3f640e 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -9,6 +9,7 @@ #include "Protocols/MAC_Check.h" #include "GC/SemiSecret.h" #include "GC/SemiPrep.h" +#include "Processor/OnlineOptions.h" #include "OT/Triple.hpp" #include "OT/OTMultiplier.hpp" @@ -25,7 +26,21 @@ template void* run_ot_thread(void* ptr) { bigint::init_thread(); - ((OTMultiplierBase*)ptr)->multiply(); + auto multiplier = (OTMultiplierBase*) ptr; + if (OnlineOptions::singleton.has_option("throw_exceptions")) + multiplier->multiply(); + else + { + try + { + multiplier->multiply(); + } + catch (exception& e) + { + cerr << "Fatal error in OT thread: " << e.what() << endl; + exit(1); + } + } return NULL; } diff --git a/OT/Row.hpp b/OT/Row.hpp index b43923244..aad79911f 100644 --- a/OT/Row.hpp +++ b/OT/Row.hpp @@ -110,20 +110,13 @@ void DeferredPlus::pack(octetStream& o) const template void Row::pack(octetStream& o) const { - o.store(this->size()); - for (size_t i = 0; i < this->size(); i++) - rows[i].pack(o); + o.store(rows); } template void Row::unpack(octetStream& o) { - size_t size; - o.get(size); - rows.clear(); - rows.reserve(size); - for (size_t i = 0; i < size; i++) - rows.push_back(o.get()); + o.get(rows); } template diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index c2005afbc..32d3824bf 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -30,7 +30,7 @@ BaseMachine& BaseMachine::s() if (singleton) return *singleton; else - throw runtime_error("no singleton"); + throw runtime_error("no BaseMachine singleton"); } bool BaseMachine::has_program() @@ -72,7 +72,8 @@ int BaseMachine::bucket_size(size_t usage) int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols) { - unsigned res = min(100, OnlineOptions::singleton.batch_size); + int limit = max(1., 1e6 / (max(n_rows * n_inner, n_inner * n_cols))); + unsigned res = min(limit, OnlineOptions::singleton.batch_size); if (has_program()) res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols)); return res; @@ -93,7 +94,8 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols) return -1; } -BaseMachine::BaseMachine() : nthreads(0) +BaseMachine::BaseMachine() : + nthreads(0), multithread(false) { if (sodium_init() == -1) throw runtime_error("couldn't initialize libsodium"); @@ -147,7 +149,12 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) #endif long size = load_program(threadname, filename); if (expected >= 0 and expected != size) - throw runtime_error("broken bytecode file"); + { + stringstream os; + os << "broken bytecode file, found " << size + << " instructions, expected " << expected; + throw runtime_error(os.str()); + } } } diff --git a/Processor/Binary_File_IO.h b/Processor/Binary_File_IO.h index c19a129af..5ed62ff0f 100644 --- a/Processor/Binary_File_IO.h +++ b/Processor/Binary_File_IO.h @@ -37,7 +37,8 @@ class Binary_File_IO * Throws file_error. */ template - void read_from_file(const string filename, vector< T >& buffer, const int start_posn, int &end_posn); + void read_from_file(const string filename, vector& buffer, + const long start_posn, long& end_posn); }; #endif diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index 147ea42b0..8b74d9d5c 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -44,13 +44,22 @@ void Binary_File_IO::write_to_file(const string filename, } template -void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, const int start_posn, int &end_posn) +void Binary_File_IO::read_from_file(const string filename, vector& buffer, + const long start_posn, long& end_posn) { ifstream inf; inf.open(filename, ios::in | ios::binary); if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } - check_file_signature(inf, filename).get_length(); + try + { + check_file_signature(inf, filename).get_length(); + } + catch (exception& e) + { + throw persistence_error(e.what()); + } + auto data_start = inf.tellg(); int size_in_bytes = T::size() * buffer.size(); @@ -68,13 +77,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes from " << (long(data_start) + start_posn * T::size()) << ")."; - throw file_error(ss.str()); + throw persistence_error(ss.str()); } if (inf.fail()) { stringstream ss; ss << "IO problem when reading from disk"; - throw file_error(ss.str()); + throw persistence_error(ss.str()); } } while (n_read < size_in_bytes); diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index b93990619..e44cb0836 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -137,13 +137,16 @@ void DataPositions::print_cost() const cerr << " edaBits" << endl; for (auto it = edabits.begin(); it != edabits.end(); it++) { - if (print_verbose) - cerr << setw(13) << ""; - cerr << " " << setw(10) << it->second << " of length " - << it->first.second; - if (it->first.first) - cerr << " (strict)"; - cerr << endl; + if (it->second) + { + if (print_verbose) + cerr << setw(13) << ""; + cerr << " " << setw(10) << it->second << " of length " + << it->first.second; + if (it->first.first) + cerr << " (strict)"; + cerr << endl; + } } } @@ -235,3 +238,14 @@ long long DataPositions::total_edabits(int n_bits) const auto usage = edabits; return usage[{false, n_bits}] + usage[{true, n_bits}]; } + +long long DataPositions::triples_for_matmul() +{ + long long res = 0; + for (auto& x : matmuls) + { + auto dim = x.first; + res += x.second * dim[0] * dim[1] * dim[2]; + } + return res; +} diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 4c20e09ab..a0ad31200 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -89,6 +89,8 @@ class DataPositions bool any_more(const DataPositions& other) const; long long total_edabits(int n_bits) const; + + long long triples_for_matmul(); }; template class Processor; @@ -192,6 +194,8 @@ class Preprocessing : public PrepBase virtual void buffer_inverses() {} virtual Preprocessing& get_part() { throw runtime_error("no part"); } + + virtual int minimum_batch() { return 0; } }; template diff --git a/Processor/FunctionArgument.cpp b/Processor/FunctionArgument.cpp new file mode 100644 index 000000000..d1e1f5685 --- /dev/null +++ b/Processor/FunctionArgument.cpp @@ -0,0 +1,44 @@ +/* + * FunctionArgument.cpp + * + */ + +#include "FunctionArgument.h" + +#include + +void FunctionArgument::open(ifstream& file, const string& name, + vector& arguments) +{ + string signature; + for (auto& arg : arguments) + signature += "-" + arg.get_type_string(); + string filename = "Programs/Functions/" + name + signature; + + file.open(filename); + if (not file.good()) + { + string python_call = name + "("; + for (auto& arg : arguments) + { + python_call += arg.get_python_arg(); + if (&arg != &arguments[arguments.size() - 1]) + python_call += ", "; + } + python_call += ")"; + throw runtime_error( + "Cannot open " + filename + ", have you compiled '" + + python_call + + "' and added '@export' to the function '" + name + + "'?"); + } +} + +void FunctionArgument::check_type(const string& type_string) +{ + if (type_string != get_type_string() + and get_type_string() != "-") + throw runtime_error( + "return type mismatch: " + get_type_string() + "/" + + type_string); +} diff --git a/Processor/FunctionArgument.h b/Processor/FunctionArgument.h new file mode 100644 index 000000000..eaf35c17c --- /dev/null +++ b/Processor/FunctionArgument.h @@ -0,0 +1,128 @@ +/* + * FunctionArgument.h + * + */ + +#ifndef PROCESSOR_FUNCTIONARGUMENT_H_ +#define PROCESSOR_FUNCTIONARGUMENT_H_ + +#include "Protocols/ShareInterface.h" + +#include + +/** + * Inputs and outputs for functions exported in high-level code. + */ +class FunctionArgument +{ + void* data; + size_t size, n_bits; + string reg_type; + bool memory; + +public: + static void open(ifstream& file, const string& name, + vector& arguments); + + /** + * Argument with integer secret shares. + * + * @param values shares + * @param memory whether shares are in a (multi-)array (true) or vector (false) + */ + template + FunctionArgument(vector& values, bool memory = false) : + FunctionArgument(values.data(), values.size(), memory) + { + assert(not T::clear::characteristic_two); + } + + FunctionArgument(ShareInterface* data, size_t size, bool memory) : + data(data), size(size), n_bits(0), reg_type("s"), memory(memory) + { + } + + /** + * Void argument. + */ + FunctionArgument() : FunctionArgument(0, 0, false) + { + } + + /** + * Argument with binary secret shares (always in array). + * + * @param n_bits number of bits + * @param values shares (vector of vectors of bit_type) + */ + template + FunctionArgument(size_t n_bits, vector>& values) : + data(values.data()), size(values.size()), n_bits(n_bits), + reg_type("sbv"), memory(true) + { + assert(T::clear::binary); + assert(not values.empty()); + assert(n_bits > 0); + size_t n_limbs = DIV_CEIL(n_bits, T::default_length); + size_t dl = T::default_length; + for (auto& x : values) + { + assert(x.size() == n_limbs); + for (size_t i = 0; i < n_limbs; i++) + assert(size_t(x[i].maximum_size()) >= min(dl, n_bits - i * dl)); + } + } + + size_t get_size() + { + return size; + } + + size_t get_n_bits() + { + return n_bits; + } + + string get_type_string() + { + if (data == 0) + return "-"; + + if (memory) + if (reg_type == "sbv") + return reg_type + ":[" + to_string(get_size()) + "x" + + to_string(n_bits) + "]"; + else + return reg_type + ":[" + to_string(get_size()) + "]"; + else + return reg_type + ":" + to_string(get_size()); + } + + string get_python_arg() + { + assert(data); + if (memory) + if (reg_type == "sbv") + return "sbitvec.get_type(" + to_string(n_bits) + ").Array(" + + to_string(get_size()) + ")"; + else + return "sint.Array(" + to_string(get_size()) + ")"; + else + return "sint(0, size=" + to_string(get_size()) + ")"; + } + + bool get_memory() + { + return memory; + } + + template + T& get_value(size_t index) + { + return ((T*) data)[index]; + } + + void check_type(const string& type_string); +}; + +#endif /* PROCESSOR_FUNCTIONARGUMENT_H_ */ diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 512507d7d..483c183d3 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -93,19 +93,27 @@ void Instruction::bitdecint(ArithmeticProcessor& Proc) const } } -ostream& operator<<(ostream& s, const Instruction& instr) +string BaseInstruction::get_name() const { - switch (instr.get_opcode()) + switch (get_opcode()) { #define X(NAME, PRE, CODE) \ - case NAME: s << #NAME; break; + case NAME: return #NAME; ALL_INSTRUCTIONS #undef X #define X(NAME, CODE) \ - case NAME: s << #NAME; break; + case NAME: return #NAME; COMBI_INSTRUCTIONS + default: + stringstream ss; + ss << hex << get_opcode(); + return ss.str(); } +} +ostream& operator<<(ostream& s, const Instruction& instr) +{ + s << instr.get_name(); s << " size=" << instr.get_size(); s << " n=" << instr.get_n(); s << " r=("; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index c89880624..965789e69 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -378,6 +378,8 @@ class BaseInstruction // Returns the maximal register used unsigned get_max_reg(int reg_type) const; + + string get_name() const; }; class DataPositions; diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index dc71ac5cd..77ea1a4c4 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -507,7 +507,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) default: ostringstream os; os << "Invalid instruction " << showbase << hex << opcode << " at " << dec - << pos << "/" << hex << file_pos << dec << endl; + << pos << "/" << hex << file_pos << dec; throw Invalid_Instruction(os.str()); } } @@ -944,20 +944,24 @@ inline void Instruction::execute(Processor& Proc) const switch (opcode) { case CONVMODP: - if (n == 0) - { - for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, - Proc.sync( - Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get())); - } - else if (n <= 64) - for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, - Proc.sync(Integer(Proc.read_Cp(r[1] + i), n).get())); - else - throw Processor_Error(to_string(n) + "-bit conversion impossible; " - "integer registers only have 64 bits"); + vector values; + values.reserve(size); + for (int i = 0; i < size; i++) + { + auto source = Proc.read_Cp(r[1] + i); + Integer tmp; + if (n == 0) + tmp = Integer::convert_unsigned(source); + else if (n <= 64) + tmp = Integer(source, n); + else + throw Processor_Error(to_string(n) + "-bit conversion impossible; " + "integer registers only have 64 bits"); + values.push_back(tmp); + } + sync(values, Proc.P); + for (int i = 0; i < size; i++) + Proc.write_Ci(r[0] + i, values[i].get()); return; } @@ -1371,12 +1375,12 @@ inline void Instruction::execute(Processor& Proc) const break; case WRITEFILESHARE: // Write shares to file system - Proc.write_shares_to_file(Proc.read_Ci(r[0]), start); - break; + Proc.write_shares_to_file(Proc.read_Ci(r[0]), start, size); + return; case READFILESHARE: // Read shares from file system - Proc.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start); - break; + Proc.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start, size); + return; case PUBINPUT: Proc.get_Cp_ref(r[0]) = Proc.template get_input>( @@ -1438,6 +1442,26 @@ inline void Instruction::execute(Processor& Proc) const template void Program::execute(Processor& Proc) const +{ + if (OnlineOptions::singleton.has_option("throw_exceptions")) + execute_with_errors(Proc); + else + { + try + { + execute_with_errors(Proc); + } + catch (exception& e) + { + cerr << "Fatal error at " << name << ":" << Proc.last_PC << " (" + << p[Proc.last_PC].get_name() << "): " << e.what() << endl; + exit(1); + } + } +} + +template +void Program::execute_with_errors(Processor& Proc) const { unsigned int size = p.size(); Proc.PC=0; @@ -1452,6 +1476,7 @@ void Program::execute(Processor& Proc) const while (Proc.PC #include #include using namespace std; -template +#include "OnlineOptions.hpp" + +template> class Machine : public BaseMachine { /* The mutex's lock the C-threads and then only release @@ -101,6 +106,9 @@ class Machine : public BaseMachine void run_step(const string& progname); pair stop_threads(); + void run_function(const string& name, FunctionArgument& result, + vector& arguments); + string memory_filename(); template @@ -109,8 +117,9 @@ class Machine : public BaseMachine void reqbl(int n); void active(int n); - typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } - typename sint::mac_key_type get_sint_mac_key() { return alphapi; } + typename sint::bit_type::mac_key_type get_bit_mac_key() const + { return alphabi; } + typename sint::mac_key_type get_sint_mac_key() const { return alphapi; } Player& get_player() { return *P; } diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 97151cc20..2b31b950f 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -213,6 +213,8 @@ void Machine::prepare(const string& progname_str) template Machine::~Machine() { + stop_threads(); + sint::LivePrep::teardown(); sgf2n::LivePrep::teardown(); @@ -221,8 +223,6 @@ Machine::~Machine() sgf2n::MAC_Check::teardown(); delete P; - for (auto& queue : queues) - delete queue; } template @@ -349,7 +349,8 @@ void Machine::fill_matmul(int thread_number, int tape_number, subdim, tinfo[thread_number].processor->Procp); if (not source_proto.use_plain_matmul(subdim, source_proc)) for (int i = 0; i < it->second; i++) - dest.push_triple(source.get_triple_no_count(-1)); + dynamic_cast>&>(dest).push_triple( + source.get_triple_no_count(-1)); } } } @@ -409,6 +410,72 @@ void Machine::run_step(const string& progname) join_tape(0); } +template +void Machine::run_function(const string& name, + FunctionArgument& result, vector& arguments) +{ + ifstream file; + FunctionArgument::open(file, name, arguments); + + string progname, return_type; + int tape_number, return_reg; + file >> progname >> tape_number >> return_type >> return_reg; + + result.check_type(return_type); + + vector arg_regs(arguments.size()); + for (auto& arg_reg : arg_regs) + file >> arg_reg; + + prepare(progname); + auto& processor = *tinfo.at(0).processor; + processor.reset(progs.at(tape_number), 0); + + for (size_t i = 0; i < arguments.size(); i++) + for (size_t j = 0; j < arguments[i].get_size(); j++) + { + if (arguments[i].get_n_bits()) + { + size_t n_limbs = DIV_CEIL(arguments[i].get_n_bits(), + sint::bit_type::default_length); + for (size_t k = 0; k < n_limbs; k++) + bit_memories.MS[arg_regs.at(i) + j * n_limbs + k] = + arguments[i].get_value>(j).at( + k); + } + else + { + auto& value = arguments[i].get_value(j); + if (arguments[i].get_memory()) + Mp.MS[arg_regs.at(i) + j] = value; + else + processor.Procp.get_S()[arg_regs.at(i) + j] = value; + } + } + + run_tape(0, tape_number, 0, N.num_players()); + join_tape(0); + + for (size_t j = 0; j < result.get_size(); j++) + result.get_value(j) = processor.Procp.get_S()[return_reg + j]; + + for (size_t i = 0; i < arguments.size(); i++) + if (arguments[i].get_memory()) + for (size_t j = 0; j < arguments[i].get_size(); j++) + { + if (arguments[i].get_n_bits()) + { + size_t n_limbs = DIV_CEIL(arguments[i].get_n_bits(), + sint::bit_type::default_length); + for (size_t k = 0; k < n_limbs; k++) + arguments[i].get_value>(j).at(k) = + bit_memories.MS[arg_regs.at(i) + j * n_limbs + k]; + } + else + arguments[i].get_value(j) = Mp.MS[arg_regs.at(i) + j]; + } +} + template pair Machine::stop_threads() { diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 5781061c3..9123f2ad1 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -14,8 +14,8 @@ void MemoryPart::indirect_read(const Instruction& inst, #ifndef NO_CHECK_SIZE assert(start + n <= indices.end()); assert(dest + n <= regs.end()); -#endif size_t size = this->size(); +#endif const T* data = this->data(); for (auto it = start; it < start + n; it++) { @@ -38,8 +38,8 @@ void MemoryPart::indirect_write(const Instruction& inst, #ifndef NO_CHECK_SIZE assert(start + n <= indices.end()); assert(source + n <= regs.end()); -#endif size_t size = this->size(); +#endif T* data = this->data(); for (auto it = start; it < start + n; it++) { diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 045e9f3e7..55b8b8522 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -357,6 +357,7 @@ void thread_info::Sub_Main_Func() NamedStats stats; stats["integer multiplications"] = Proc.Procp.protocol.counter; stats["integer multiplication rounds"] = Proc.Procp.protocol.rounds; + stats["integer dot products"] = Proc.Procp.protocol.dot_counter; stats["probabilistic truncations"] = Proc.Procp.protocol.trunc_pr_counter; stats["probabilistic truncation rounds"] = Proc.Procp.protocol.trunc_rounds; stats["ANDs"] = Proc.share_thread.protocol->bit_counter; diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index b337c3a7a..5ae16b8b9 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -174,7 +174,7 @@ Player* OnlineMachine::new_player(const string& id_base) template int OnlineMachine::run() { - if (online_opts.has_option("throw_exception")) + if (online_opts.has_option("throw_exceptions")) return run_with_error(); else { diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index b1677e89f..1a370fe8d 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -12,6 +12,8 @@ #include "Math/gfp.hpp" +#include + using namespace std; OnlineOptions OnlineOptions::singleton; @@ -164,6 +166,9 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, } opt.resetArgs(); + + if (argc > 0) + executable = boost::filesystem::path(argv[0]).filename().string(); } OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, @@ -354,20 +359,6 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, exit(1); } - if (opt.get("-lgp")) - { - bigint schedule_prime = BaseMachine::prime_from_schedule(progname); - if (prime != 0 and prime != schedule_prime and schedule_prime != 0) - { - cerr << "Different prime for compilation and computation." << endl; - cerr << "Run with '--prime " << schedule_prime - << "' or compile with '--prime " << prime << "'." << endl; - exit(1); - } - if (schedule_prime != 0) - prime = schedule_prime; - } - for (size_t i = name_index + 1; i < allArgs.size(); i++) { try @@ -384,6 +375,38 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, } } + if (has_option("throw_exceptions")) + finalize_with_error(opt); + else + { + try + { + finalize_with_error(opt); + } + catch (exception& e) + { + cerr << "Fatal error in option processing: " << e.what() << endl; + exit(1); + } + } +} + +void OnlineOptions::finalize_with_error(ez::ezOptionParser& opt) +{ + if (opt.get("-lgp")) + { + bigint schedule_prime = BaseMachine::prime_from_schedule(progname); + if (prime != 0 and prime != schedule_prime and schedule_prime != 0) + { + cerr << "Different prime for compilation and computation." << endl; + cerr << "Run with '--prime " << schedule_prime + << "' or compile with '--prime " << prime << "'." << endl; + exit(1); + } + if (schedule_prime != 0) + prime = schedule_prime; + } + // ignore program if length explicitly set from command line if (opt.get("-lgp") and not opt.isSet("-lgp")) { diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 00408c7af..39c58a541 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -12,6 +12,8 @@ class OnlineOptions { + void finalize_with_error(ez::ezOptionParser& opt); + public: static OnlineOptions singleton; @@ -38,6 +40,7 @@ class OnlineOptions std::string disk_memory; vector args; vector options; + string executable; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 0d467614c..a619832c5 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -40,6 +40,13 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, + to_string(my_num) + get_suffix(thread_num); } +string PrepBase::get_matrix_prefix(const string& prep_data_dir, + const array& dim) +{ + return prep_data_dir + "Matrix-" + to_string(dim[0]) + "x" + + to_string(dim[1]) + "x" + to_string(dim[2]); +} + PrepBase::PrepBase(DataPositions& usage) : usage(usage) { diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index 78aa7332c..447beda22 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -30,6 +30,9 @@ class PrepBase static string get_edabit_filename(const string& prep_data_dir, int n_bits, int my_num, int thread_num = 0); + static string get_matrix_prefix(const string& prep_data_dir, + const array& dim); + TimerWithComm prep_timer; PrepBase(DataPositions& usage); diff --git a/Processor/Processor.h b/Processor/Processor.h index 612ff00f1..f5ba8a00d 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -45,6 +45,8 @@ class SubProcessor void matmulsm_finalize(int i, int j, const vector& dim, typename vector::iterator C); + void maybe_check(); + template friend class Processor; template friend class SPDZ; template friend class ProtocolBase; @@ -223,7 +225,7 @@ class Processor : public ArithmeticProcessor SubProcessor Proc2; SubProcessor Procp; - unsigned int PC; + unsigned int PC, last_PC; TempVars temp; ExternalClients& external_clients; @@ -289,8 +291,10 @@ class Processor : public ArithmeticProcessor int size, bool send_macs); // Read and write secret numeric data to file (name hardcoded at present) - void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); - void write_shares_to_file(long start_pos, const vector& data_registers); + void read_shares_from_file(long start_file_pos, int end_file_pos_register, + const vector& data_registers, size_t vector_size); + void write_shares_to_file(long start_pos, const vector& data_registers, + size_t vector_size); cint get_inverse2(unsigned m); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d6cdd7567..d09e67114 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -389,7 +389,10 @@ void Processor::read_socket_private(int client_id, // file_pos_register is written with new file position (-1 is eof). // Tolerent to no file if no shares yet persisted. template -void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { +void Processor::read_shares_from_file(long start_file_posn, + int end_file_pos_register, const vector& data_registers, + size_t vector_size) +{ if (not sint::real_shares(P)) return; @@ -398,22 +401,24 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ unsigned int size = data_registers.size(); - vector< sint > outbuf(size); + PointerVector outbuf(size * vector_size); - int end_file_posn = start_file_posn; + auto end_file_posn = start_file_posn; try { binary_file_io.read_from_file(filename, outbuf, start_file_posn, end_file_posn); for (unsigned int i = 0; i < size; i++) { - get_Sp_ref(data_registers[i]) = outbuf[i]; + for (size_t j = 0; j < vector_size; j++) + get_Sp_ref(data_registers[i] + j) = outbuf.next(); } write_Ci(end_file_pos_register, (long)end_file_posn); } catch (file_missing& e) { - cerr << "Got file missing error, will return -2. " << e.what() << endl; + if (OnlineOptions::singleton.has_option("verbose_persistence")) + cerr << "Got file missing error, will return -2. " << e.what() << endl; write_Ci(end_file_pos_register, (long)-2); } } @@ -421,7 +426,7 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ // Append share data in data_registers to end of file. Expects Persistence directory to exist. template void Processor::write_shares_to_file(long start_pos, - const vector& data_registers) + const vector& data_registers, size_t vector_size) { if (not sint::real_shares(P)) return; @@ -430,16 +435,24 @@ void Processor::write_shares_to_file(long start_pos, unsigned int size = data_registers.size(); - vector< sint > inpbuf (size); + PointerVector inpbuf(size * vector_size); for (unsigned int i = 0; i < size; i++) { - inpbuf[i] = get_Sp_ref(data_registers[i]); + for (size_t j = 0; j < vector_size; j++) + inpbuf.next() = get_Sp_ref(data_registers[i] + j); } binary_file_io.write_to_file(filename, inpbuf, start_pos); } +template +void SubProcessor::maybe_check() +{ + if (OnlineOptions::singleton.has_option("always_check")) + check(); +} + template void SubProcessor::POpen(const Instruction& inst) { @@ -465,6 +478,8 @@ void SubProcessor::POpen(const Instruction& inst) Proc->sent += sz * size; Proc->rounds++; } + + maybe_check(); } template @@ -489,8 +504,10 @@ void SubProcessor::muls(const vector& reg) { proc.S[reg[4 * i + 1] + j] = protocol.finalize_mul(); } - protocol.counter += n * reg[4 * i]; + protocol.counter += reg[4 * i]; } + + maybe_check(); } template @@ -517,6 +534,8 @@ void SubProcessor::mulrs(const vector& reg) } protocol.counter += reg[4 * i]; } + + maybe_check(); } template @@ -550,6 +569,8 @@ void SubProcessor::dotprods(const vector& reg, int size) it = next; } } + + maybe_check(); } template @@ -590,6 +611,8 @@ void SubProcessor::matmuls(const StackedVector& source, for (int j = 0; j < dim[2]; j++) *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); } + + maybe_check(); } @@ -664,6 +687,8 @@ void SubProcessor::matmulsm(const MemoryPart& source, auto lastMatrixColumns = lastMatmulsArgs[5]; matmulsm_finalize_batch(batchStartMatrix, batchStartI, batchStartJ, lastMatmulsArgs, lastMatrixRows - 1, lastMatrixColumns - 1); + + maybe_check(); } template @@ -771,6 +796,8 @@ void SubProcessor::conv2ds(const Instruction& instruction) for (; done < i; done++) tuples[done].post(S, protocol); } + + maybe_check(); } inline @@ -869,6 +896,8 @@ void SubProcessor::secure_shuffle(const Instruction& instruction) typename T::Protocol::Shuffler(S, instruction.get_size(), instruction.get_n(), instruction.get_r(0), instruction.get_r(1), *this); + + maybe_check(); } template @@ -886,12 +915,14 @@ void SubProcessor::apply_shuffle(const Instruction& instruction, int handle, instruction.get_start()[0], instruction.get_start()[1], shuffle_store.get(handle), instruction.get_start()[4]); + maybe_check(); } template void SubProcessor::inverse_permutation(const Instruction& instruction) { shuffler.inverse_permutation(S, instruction.get_size(), instruction.get_start()[0], instruction.get_start()[1]); + maybe_check(); } template diff --git a/Processor/Program.cpp b/Processor/Program.cpp index 6774b8813..57c8f33e6 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -47,10 +47,26 @@ void Program::parse(string filename) void Program::parse_with_error(string filename) { + name = boost::filesystem::path(filename).stem().string(); ifstream pinp(filename); if (pinp.fail()) throw file_error(filename); - parse(pinp); + + try + { + parse(pinp); + } + catch (bytecode_error& e) + { + stringstream os; + os << "Cannot parse " << filename << " (" << e.what() << ")" << endl; + os << "Does the compiler version match the virtual machine? " + << "If in doubt, recompile the VM"; + if (not OnlineOptions::singleton.executable.empty()) + os << " using 'make " << OnlineOptions::singleton.executable << "'"; + os << "."; + throw bytecode_error(os.str()); + } // compute hash pinp.clear(); @@ -71,9 +87,26 @@ void Program::parse(istream& s) Instruction instr; s.peek(); while (!s.eof()) - { instr.parse(s, p.size()); - if (s.fail()) - throw runtime_error("error while parsing " + to_string(instr.opcode)); + { + bool fail = false; + try + { + instr.parse(s, p.size()); + } + catch (bad_alloc&) + { + fail = true; + } + fail |= s.fail(); + + if (fail) + { + stringstream os; + os << "error while parsing " << hex << showbase << instr.opcode + << " at " << dec << p.size(); + throw bytecode_error(os.str()); + } + p.push_back(instr); //cerr << "\t" << instr << endl; s.peek(); diff --git a/Processor/Program.h b/Processor/Program.h index 7783da687..45aecdb5b 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -28,6 +28,8 @@ class Program string hash; + string name; + void compute_constants(); public: @@ -66,6 +68,8 @@ class Program template void execute(Processor& Proc) const; + template + void execute_with_errors(Processor& Proc) const; }; #endif diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp index 2ac542e9a..9878bad4e 100644 --- a/Processor/RingOptions.cpp +++ b/Processor/RingOptions.cpp @@ -30,6 +30,15 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) int RingOptions::ring_size_from_opts_or_schedule(string progname) { + if (BaseMachine::prime_from_schedule(progname) + or BaseMachine::prime_length_from_schedule(progname)) + { + cerr << "Program was compiled for a prime field, " + << "not a ring modulo a power of two. " + << "Use './compile.py -R '." << endl; + exit(1); + } + int r = BaseMachine::ring_size_from_schedule(progname); if (R_is_set) { diff --git a/Processor/TruncPrTuple.cpp b/Processor/TruncPrTuple.cpp new file mode 100644 index 000000000..05a6aae20 --- /dev/null +++ b/Processor/TruncPrTuple.cpp @@ -0,0 +1,17 @@ +/* + * TruncPrTuple.cpp + * + */ + +#include "TruncPrTuple.h" + +void trunc_pr_check(int k, int m, int n_bits) +{ + if (not (m < k and 0 < m and k <= n_bits)) + { + stringstream ss; + ss << "invalid trunc_pr parameters, need 0 < m=" << m << " < k=" << k + << " <= n_bits=" << n_bits; + throw Processor_Error(ss.str()); + } +} diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h index 267acae48..d12a17cf2 100644 --- a/Processor/TruncPrTuple.h +++ b/Processor/TruncPrTuple.h @@ -12,6 +12,8 @@ using namespace std; #include "OnlineOptions.h" +void trunc_pr_check(int k, int m, int n_bits); + template class TruncPrTuple { @@ -36,6 +38,7 @@ class TruncPrTuple k = *it++; m = *it++; n_shift = T::N_BITS - 1 - k; + trunc_pr_check(k, m, T::n_bits()); assert(m < k); assert(0 < k); assert(m < T::n_bits()); diff --git a/Programs/Source/breast_logistic.mpc b/Programs/Source/breast_logistic.mpc index 28ee6be61..e5e60d11d 100644 --- a/Programs/Source/breast_logistic.mpc +++ b/Programs/Source/breast_logistic.mpc @@ -22,16 +22,17 @@ elif 'vertical' in program.args: b = sfix.input_tensor_via(1, X_train[:,X_train.shape[1] // 2:]) X_train = a.concat_columns(b) y_train = sint.input_tensor_via(0, y_train) -elif 'party0' in program.args: - a = sfix.input_tensor_via(0, X_train[:,:X_train.shape[1] // 2]) - b = sfix.input_tensor_via(1, shape=X_train[:,X_train.shape[1] // 2:].shape) - X_train = a.concat_columns(b) - y_train = sint.input_tensor_via(0, y_train) -elif 'party1' in program.args: - a = sfix.input_tensor_via(0, shape=X_train[:,:X_train.shape[1] // 2].shape) - b = sfix.input_tensor_via(1, X_train[:,X_train.shape[1] // 2:]) +elif 'party0' in program.args or 'party1' in program.args: + party = int('party1' in program.args) + a = sfix.input_tensor_via( + 0, X_train[:,:X_train.shape[1] // 2] if party == 0 else None, + shape=X_train[:,:X_train.shape[1] // 2].shape) + b = sfix.input_tensor_via( + 1, X_train[:,X_train.shape[1] // 2:] if party == 1 else None, + shape=X_train[:,X_train.shape[1] // 2:].shape) X_train = a.concat_columns(b) - y_train = sint.input_tensor_via(0, shape=y_train.shape) + y_train = sint.input_tensor_via(0, y_train if party == 0 else None, + shape=y_train.shape) else: X_train = sfix.input_tensor_via(0, X_train) y_train = sint.input_tensor_via(0, y_train) diff --git a/Programs/Source/export-a2b.py b/Programs/Source/export-a2b.py new file mode 100644 index 000000000..381cac773 --- /dev/null +++ b/Programs/Source/export-a2b.py @@ -0,0 +1,7 @@ +@export +def a2b(x, res): + print_ln('x=%s', x.reveal()) + res[:] = sbitvec(x, length=16) + print_ln('res=%s', x.reveal()) + +a2b(sint(size=10), sbitvec.get_type(16).Array(10)) diff --git a/Programs/Source/export-b2a.py b/Programs/Source/export-b2a.py new file mode 100644 index 000000000..d18ecb29f --- /dev/null +++ b/Programs/Source/export-b2a.py @@ -0,0 +1,7 @@ +@export +def b2a(res, x): + print_ln('x=%s', x.reveal()) + res[:] = sint(x[:]) + print_ln('res=%s', x.reveal()) + +b2a(sint.Array(size=10), sbitvec.get_type(16).Array(10)) diff --git a/Programs/Source/export-sort.py b/Programs/Source/export-sort.py new file mode 100644 index 000000000..a6ede5ac0 --- /dev/null +++ b/Programs/Source/export-sort.py @@ -0,0 +1,7 @@ +@export +def sort(x): + print_ln('x=%s', x.reveal()) + res = x.sort() + print_ln('res=%s', x.reveal()) + +sort(sint.Array(1000)) diff --git a/Programs/Source/export-trunc.py b/Programs/Source/export-trunc.py new file mode 100644 index 000000000..a8c123e17 --- /dev/null +++ b/Programs/Source/export-trunc.py @@ -0,0 +1,8 @@ +@export +def trunc_pr(x): + print_ln('x=%s', x.reveal()) + res = x.round(32, 2) + print_ln('res=%s', res.reveal()) + return res + +trunc_pr(sint(0, size=1000)) diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index caca22140..308465e3a 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -63,9 +63,6 @@ else: if 'nearest' in program.args: sfix.round_nearest = True -if program.options.ring: - assert sfix.f * 4 == int(program.options.ring) - debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2) if '1dense' in program.args: diff --git a/Programs/Source/test_flow_optimization.mpc b/Programs/Source/test_flow_optimization.mpc index ba7af6507..26e4a85f8 100644 --- a/Programs/Source/test_flow_optimization.mpc +++ b/Programs/Source/test_flow_optimization.mpc @@ -21,3 +21,13 @@ test(a, 10000, 10000) test(b, 10000, 20000) test(a, 1000000, 1000000) test(b, 1000000, 2000000) + +a = 1 +if True: + if True: + a = 2 + if True: + a = 3 +else: + a = 4 + crash() diff --git a/Programs/Source/torch_resnet.py b/Programs/Source/torch_resnet.py index 5a0e72e6f..08140c04e 100644 --- a/Programs/Source/torch_resnet.py +++ b/Programs/Source/torch_resnet.py @@ -41,7 +41,9 @@ layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0) -optimizer = ml.Optimizer(layers) +optimizer = ml.Optimizer(layers, time_layers='time_layers' in program.args) +start_timer(1) print_ln('Secure computation says %s', optimizer.eval(secret_input, top=True)[0].reveal()) +stop_timer(1) diff --git a/Programs/Source/torch_vgg.py b/Programs/Source/torch_vgg.py new file mode 100644 index 000000000..d5afe8c9d --- /dev/null +++ b/Programs/Source/torch_vgg.py @@ -0,0 +1,42 @@ +# this tests the pretrained VGG in secure computation + +program.options_from_args() + +from Compiler import ml + +try: + ml.set_n_threads(int(program.args[2])) +except: + pass + +import torchvision +import torch +import numpy +import requests +import io +import PIL + +from torchvision import transforms + +name = 'vgg' + program.args[1] +model = getattr(torchvision.models, name)(weights='DEFAULT') + +r = requests.get('https://github.com/pytorch/hub/raw/master/images/dog.jpg') +input_image = PIL.Image.open(io.BytesIO(r.content)) +input_tensor = transforms._presets.ImageClassification(crop_size=32)(input_image) +input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + +with torch.no_grad(): + output = int(model(input_batch).argmax()) + print('Model says %d' % output) + +secret_input = sfix.input_tensor_via( + 0, numpy.moveaxis(input_batch.numpy(), 1, -1)) + +layers = ml.layers_from_torch(model, secret_input.shape, 1, input_via=0) + +optimizer = ml.Optimizer(layers) +optimizer.time_layers = True + +print_ln('Secure computation says %s', + optimizer.eval(secret_input, top=True)[0].reveal()) diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 47f5c93d1..29f9a27cc 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -8,6 +8,8 @@ #include "Replicated.h" +#include "Tools/Bundle.h" + /** * ATLAS protocol (simple version). * Uses double sharings to reduce degree of Shamir secret sharing. diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index 29d76807d..5f6757cd9 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -7,6 +7,7 @@ #define PROTOCOLS_BRAINSHARE_H_ #include "Rep3Share.h" +#include "SemiShare.h" template class HashMaliciousRepMC; template class Beaver; @@ -20,18 +21,21 @@ class MaliciousRepSecret; template class BrainShare : public Rep3Share> { + typedef BrainShare This; typedef SignedZ2 T; typedef Rep3Share super; public: typedef T clear; - typedef Beaver Protocol; + typedef Beaver BasicProtocol; typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; typedef ::PrivateOutput PrivateOutput; typedef BrainPrep LivePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/ChaiGearShare.h b/Protocols/ChaiGearShare.h index e8d8b4491..21d6f9a05 100644 --- a/Protocols/ChaiGearShare.h +++ b/Protocols/ChaiGearShare.h @@ -21,6 +21,8 @@ class ChaiGearShare : public Share typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; + typedef DummyMatrixPrep MatrixPrep; typedef SPDZ Protocol; typedef ChaiGearPrep LivePrep; typedef Share prep_check_type; diff --git a/Protocols/CowGearShare.h b/Protocols/CowGearShare.h index 4007f73af..6ef4ab288 100644 --- a/Protocols/CowGearShare.h +++ b/Protocols/CowGearShare.h @@ -21,8 +21,10 @@ class CowGearShare : public Share typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; typedef SPDZ Protocol; typedef CowGearPrep LivePrep; + typedef DummyMatrixPrep MatrixPrep; typedef Share prep_check_type; const static bool needs_ot = false; diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index db1ed813b..474154212 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -25,7 +25,7 @@ class DealerMC : public MAC_Check_Base void prepare_open(const T& secret, int n_bits = -1); void exchange(const Player& P); typename T::open_type finalize_raw(); - array finalize_several(int n); + array finalize_several(size_t n); DealerMC& get_part_MC() { diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index 08b4b4587..b0467e851 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -74,7 +74,7 @@ typename T::open_type DealerMC::finalize_raw() } template -array DealerMC::finalize_several(int n) +array DealerMC::finalize_several(size_t n) { assert(sub_player); return internal.finalize_several(n); diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h index 417fdbac7..459d6dfd7 100644 --- a/Protocols/DealerPrep.h +++ b/Protocols/DealerPrep.h @@ -7,6 +7,7 @@ #define PROTOCOLS_DEALERPREP_H_ #include "ReplicatedPrep.h" +#include "DealerMatrixPrep.h" template class DealerPrep : virtual public BitPrep diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index e70edecb5..bdd83dca7 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -217,7 +217,7 @@ void DealerPrep::buffer_edabits(int length, false_type) { vector as; vector bs; - plain_edabits(as, bs, length, G); + plain_edabits(as, bs, length, G, edabitvec::MAX_SIZE); for (auto& a : as) { make_share(shares.data(), a, P.num_players() - 1, 0, G); diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h index e59e19494..4c66e14f6 100644 --- a/Protocols/DealerShare.h +++ b/Protocols/DealerShare.h @@ -8,6 +8,9 @@ #include "Math/Z2k.h" #include "SemiShare.h" +#include "DealerMC.h" +#include "Dealer.h" +#include "DealerInput.h" template class DealerPrep; template class DealerInput; diff --git a/Protocols/DummyMatrixPrep.h b/Protocols/DummyMatrixPrep.h new file mode 100644 index 000000000..e02447b59 --- /dev/null +++ b/Protocols/DummyMatrixPrep.h @@ -0,0 +1,27 @@ +/* + * DummyMatrixPrep.h + * + */ + +#ifndef PROTOCOLS_DUMMYMATRIXPREP_H_ +#define PROTOCOLS_DUMMYMATRIXPREP_H_ + +#include "Processor/Data_Files.h" +#include "ShareMatrix.h" + +class no_matrix_prep : public exception +{ +}; + +template +class DummyMatrixPrep : public Preprocessing> +{ +public: + DummyMatrixPrep(int, int, int, Preprocessing&, DataPositions& usage) : + Preprocessing>(usage) + { + throw no_matrix_prep(); + } +}; + +#endif /* PROTOCOLS_DUMMYMATRIXPREP_H_ */ diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 623663b85..47a45a027 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -15,19 +15,23 @@ template class Hemi : public T::BasicProtocol { - map, typename T::MatrixPrep*> matrix_preps; + typedef Preprocessing> matrix_prep; + + map, matrix_prep*> matrix_preps; DataPositions matrix_usage; - MatrixMC mc; + MatrixMC* mc; + + bool warned = false; public: Hemi(Player& P) : - T::BasicProtocol(P) + T::BasicProtocol(P), mc(0) { } ~Hemi(); - typename T::MatrixPrep& get_matrix_prep(const array& dimensions, + matrix_prep& get_matrix_prep(const array& dimensions, SubProcessor& processor); bool use_plain_matmul(const array dimensions, diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index bc4b4d3a5..db6a27ca6 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -9,6 +9,9 @@ #include "Hemi.h" #include "ShareMatrix.h" #include "HemiOptions.h" +#include "MatrixFile.h" +#include "DummyMatrixPrep.h" +#include "Processor/Conv2dTuple.h" #include "HemiMatrixPrep.hpp" #include "HemiPrep.hpp" @@ -18,17 +21,25 @@ Hemi::~Hemi() { for (auto& x : matrix_preps) delete x.second; + if (mc) + delete mc; } template -typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, +Preprocessing>& Hemi::get_matrix_prep(const array& dims, SubProcessor& processor) { if (matrix_preps.find(dims) == matrix_preps.end()) - matrix_preps.insert(pair, typename T::MatrixPrep*>(dims, - new typename T::MatrixPrep(dims[0], dims[1], dims[2], + { + Preprocessing>* prep; + if (OnlineOptions::singleton.live_prep) + prep = new typename T::MatrixPrep(dims[0], dims[1], dims[2], dynamic_cast(processor.DataF), - matrix_usage))); + matrix_usage); + else + prep = new MatrixFile(dims, matrix_usage, this->P); + matrix_preps.insert({dims, prep}); + } return *matrix_preps.at(dims); } @@ -38,6 +49,48 @@ bool Hemi::use_plain_matmul(const array dim, SubProcessor& process if (OnlineOptions::singleton.has_option("force_matrix_triples")) return false; + if (OnlineOptions::singleton.live_prep) + { + try + { + get_matrix_prep(dim, processor); + } + catch (no_matrix_prep&) + { + return true; + } + } + else + { + int found = false; + + try + { + get_matrix_prep(dim, processor); + found = true; + } + catch (signature_mismatch&) + { + if (not warned) + { + cerr << "Cannot find matrix triples on disk, " + << "reverting to plain triples" << endl; + cerr << "Use './Fake-Offline.x -p ...'" + << " to generate matrix triples" << endl; + warned = true; + } + } + + Bundle os(processor.P); + os.mine.store(found); + processor.P.Broadcast_Receive(os); + os.mine.reset_read_head(); + + for (auto& o : os) + if (not o.get_int(4)) + return true; + } + auto& prep = get_matrix_prep(dim, processor); int savings = (dim[0] * dim[2]) / (dim[0] + dim[2]) + 1; int requirement = BaseMachine::matrix_requirement(dim[0], dim[1], dim[2]); @@ -47,7 +100,6 @@ bool Hemi::use_plain_matmul(const array dim, SubProcessor& process prep.minimum_batch(), requirement); return HemiOptions::singleton.plain_matmul - or not OnlineOptions::singleton.live_prep or prep.minimum_batch() / savings > requirement; } @@ -135,12 +187,26 @@ template ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor) { + if (mc == 0) + { + mc = new MatrixMC(processor.MC); + } + Beaver> beaver(this->P); array dims = {{A.n_rows, A.n_cols, B.n_cols}}; ShareMatrix C(A.n_rows, B.n_cols); + bool verbose = OnlineOptions::singleton.has_option("verbose_matmul"); + int max_inner = OnlineOptions::singleton.batch_size; int max_cols = OnlineOptions::singleton.batch_size; + + if (not OnlineOptions::singleton.live_prep) + { + max_inner = A.n_cols; + max_cols = B.n_cols; + } + for (int i = 0; i < A.n_cols; i += max_inner) { for (int j = 0; j < B.n_cols; j += max_cols) @@ -149,19 +215,37 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[1] = min(max_inner, A.n_cols - i); subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); - beaver.init(prep, mc); + beaver.init(prep, *mc); beaver.init_mul(); bool for_real = T::real_shares(processor.P); - beaver.prepare_mul(A.from(0, i, subdim.data(), for_real), - B.from(i, j, subdim.data() + 1, for_real)); + if (verbose) + fprintf(stderr, "matmul prepare\n"); + auto AA = A.from(0, i, subdim.data(), for_real); + auto BB = B.from(i, j, subdim.data() + 1, for_real); + beaver.prepare_mul(AA, BB); if (for_real) { + if (verbose) + fprintf(stderr, "matmul exchange\n"); + for (size_t k = 0; k < AA.entries.size() + BB.entries.size(); + k++) + mc->inner.set_random_element({}); beaver.exchange(); C.add_from_col(j, beaver.finalize_mul()); } } } + if (OnlineOptions::singleton.has_option("debug_matmul")) + { + mc->inner.Check(processor.P); + auto opened = mc->open(C, processor.P); + for (auto& x: opened.entries) + cout << x << " "; + cout << endl; + mc->inner.Check(processor.P); + } + return C; } diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 781941eed..437a9b1da 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_HEMIMATRIXPREP_HPP_ +#define PROTOCOLS_HEMIMATRIXPREP_HPP_ + #include "HemiMatrixPrep.h" #include "MAC_Check.h" #include "FHE/Diagonalizer.h" @@ -226,3 +229,5 @@ void HemiMatrixPrep::buffer_triples() fflush(stderr); #endif } + +#endif diff --git a/Protocols/HighGearShare.h b/Protocols/HighGearShare.h index faa7fa267..a1449557a 100644 --- a/Protocols/HighGearShare.h +++ b/Protocols/HighGearShare.h @@ -19,6 +19,8 @@ class HighGearShare : public ChaiGearShare typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; + typedef DummyMatrixPrep MatrixPrep; typedef SPDZ Protocol; typedef ChaiGearPrep LivePrep; diff --git a/Protocols/LowGearShare.h b/Protocols/LowGearShare.h index afb2b7dac..88ddaa8ee 100644 --- a/Protocols/LowGearShare.h +++ b/Protocols/LowGearShare.h @@ -19,6 +19,8 @@ class LowGearShare : public CowGearShare typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; + typedef DummyMatrixPrep MatrixPrep; typedef SPDZ Protocol; typedef CowGearPrep LivePrep; diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 2fd32c85b..037fb42ed 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -107,9 +107,6 @@ class Tree_MAC_Check : public TreeSum, public MAC_Check_B virtual void AddToCheck(const U& share, const T& value, const Player& P); virtual void Check(const Player& P) = 0; - - // compatibility - void set_random_element(const U& random_element) { (void) random_element; } }; template diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index a38a1c921..ee6b03852 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -110,7 +110,11 @@ void Tree_MAC_Check::exchange(const Player& P) this->values_opened += this->values.size(); popen_cnt += this->values.size(); - CheckIfNeeded(P); + + if (OnlineOptions::singleton.has_option("always_check")) + Check(P); + else + CheckIfNeeded(P); } diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index 5b8553c7e..f2905c428 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -65,7 +65,7 @@ class MAC_Check_Base /// Get next opened value virtual typename T::clear finalize_open(); virtual typename T::open_type finalize_raw(); - array finalize_several(size_t n); + virtual array finalize_several(size_t n); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); @@ -73,6 +73,8 @@ class MAC_Check_Base virtual const Player& get_check_player(const Player& P) const { return P; } virtual void set_prep(Preprocessing&) {} + + void set_random_element(const T&) {} }; #endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */ diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index 4c263fd72..dc3e0cfd8 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -23,7 +23,7 @@ class MalRepRingShare : public MaliciousRep3Share> const static int BIT_LENGTH = K; const static int SECURITY = S; - typedef Beaver Protocol; + typedef Beaver BasicProtocol; typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; @@ -34,6 +34,8 @@ class MalRepRingShare : public MaliciousRep3Share> typedef Z2 random_type; typedef MalRepRingShare SquareToBitShare; typedef MalRepRingPrep SquarePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; MalRepRingShare() { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 2a153e5f2..718d8a69b 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -7,6 +7,7 @@ #define PROTOCOLS_MALICIOUSREP3SHARE_H_ #include "Rep3Share.h" +#include "SemiShare.h" template class HashMaliciousRepMC; template class Beaver; @@ -27,7 +28,7 @@ class MaliciousRep3Share : public Rep3Share typedef MaliciousRep3Share This; public: - typedef Beaver> Protocol; + typedef Beaver> BasicProtocol; typedef HashMaliciousRepMC> MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; @@ -39,6 +40,8 @@ class MaliciousRep3Share : public Rep3Share typedef MaliciousRep3Share prep_type; typedef T random_type; typedef This Scalar; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index e023945b1..53377605e 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -15,14 +15,29 @@ class MaliciousRepMC : public ReplicatedMC typedef ReplicatedMC super; public: - virtual void POpen(vector& values, - const vector& S, const Player& P); + MaliciousRepMC(typename T::mac_key_type = {}) + { + } + + virtual void POpen(vector&, + const vector&, const Player&) + { + throw runtime_error("use subclass"); + } + virtual void POpen_Begin(vector& values, const vector& S, const Player& P); - virtual void POpen_End(vector& values, - const vector& S, const Player& P); - virtual void Check(const Player& P); + virtual void POpen_End(vector&, + const vector&, const Player&) + { + throw runtime_error("use subclass"); + } + + virtual void Check(const Player&) + { + throw runtime_error("use subclass"); + } MaliciousRepMC& get_part_MC() { diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 631ef7667..e301ee9c0 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -22,28 +22,6 @@ void MaliciousRepMC::POpen_Begin(vector& values, super::POpen_Begin(values, S, P); } -template -void MaliciousRepMC::POpen_End(vector& values, - const vector& S, const Player& P) -{ - (void)values, (void)S, (void)P; - throw runtime_error("use subclass"); -} - -template -void MaliciousRepMC::POpen(vector&, - const vector&, const Player&) -{ - throw runtime_error("use subclass"); -} - -template -void MaliciousRepMC::Check(const Player& P) -{ - (void)P; - throw runtime_error("use subclass"); -} - template HashMaliciousRepMC::HashMaliciousRepMC() { diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 38e3100e2..2192ee83b 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -7,7 +7,8 @@ #define PROTOCOLS_MALICIOUS_SHAMIR_M_C_HPP_ #include "MaliciousShamirMC.h" -#include "Machines/ShamirMachine.h" + +#include "ShamirOptions.h" template MaliciousShamirMC::MaliciousShamirMC() diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 332996ddd..d72ffc303 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -27,7 +27,7 @@ class MaliciousShamirShare : public ShamirShare typedef MaliciousShamirShare This; public: - typedef Beaver> Protocol; + typedef Beaver> BasicProtocol; typedef MaliciousShamirMC MAC_Check; typedef MAC_Check Direct_MC; typedef ShamirInput Input; @@ -37,6 +37,8 @@ class MaliciousShamirShare : public ShamirShare typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; typedef T random_type; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; // indicate security relevance of field size typedef T mac_key_type; diff --git a/Protocols/MamaPrep.h b/Protocols/MamaPrep.h index b47bf0867..1c229dce6 100644 --- a/Protocols/MamaPrep.h +++ b/Protocols/MamaPrep.h @@ -18,7 +18,7 @@ class MamaPrep : public MascotInputPrep, public MaliciousRingPrep static void basic_setup(Player&) {}; static void teardown() {}; - MamaPrep(SubProcessor* proc, DataPositions& usage); + MamaPrep(SubProcessor* proc, DataPositions& usage); void buffer_triples(); }; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index 706c18139..280320531 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -53,11 +53,13 @@ class MamaShare : public Share_, MamaMac> typedef FixedVec, N> mac_key_type; typedef Share_, MamaMac> super; - typedef Beaver Protocol; + typedef Beaver BasicProtocol; typedef MAC_Check_ MAC_Check; typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef MamaPrep LivePrep; typedef MamaShare prep_type; diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 754efec0c..f199713c3 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -17,7 +17,7 @@ class OTPrep : public virtual BitPrep MascotParams params; - OTPrep(SubProcessor* proc, DataPositions& usage); + OTPrep(SubProcessor* proc, DataPositions& usage); ~OTPrep(); void set_protocol(typename T::Protocol& protocol); @@ -117,7 +117,7 @@ template class MascotFieldPrep : public virtual MascotPrep { public: - MascotFieldPrep(SubProcessor* proc, DataPositions& usage) : + MascotFieldPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), MaliciousDabitOnlyPrep(proc, usage), diff --git a/Protocols/MatrixFile.h b/Protocols/MatrixFile.h new file mode 100644 index 000000000..eb883526f --- /dev/null +++ b/Protocols/MatrixFile.h @@ -0,0 +1,45 @@ +/* + * MatrixFile.h + * + */ + +#ifndef PROTOCOLS_MATRIXFILE_H_ +#define PROTOCOLS_MATRIXFILE_H_ + +#include "Processor/Data_Files.h" +#include "ShareMatrix.h" + +template +class MatrixFile : public Preprocessing> +{ + typedef Preprocessing> super; + + array dims; + + ifstream file; + +public: + MatrixFile(array dims, DataPositions& usage, Player& P) : + super(usage), dims(dims) + { + string filename = PrepBase::get_matrix_prefix( + get_prep_sub_dir(P.num_players()), dims) + "-P" + + to_string(P.my_num()); + file.open(filename); + check_file_signature(file, filename); + } + + void get_three_no_count(Dtype type, ShareMatrix& A, ShareMatrix& B, + ShareMatrix& C) + { + assert(type == DATA_TRIPLE); + A = {dims[0], dims[1]}; + B = {dims[1], dims[2]}; + C = {dims[0], dims[2]}; + A.input(file); + B.input(file); + C.input(file); + } +}; + +#endif /* PROTOCOLS_MATRIXFILE_H_ */ diff --git a/Protocols/Opener.h b/Protocols/Opener.h new file mode 100644 index 000000000..4bc7d13cb --- /dev/null +++ b/Protocols/Opener.h @@ -0,0 +1,57 @@ +/* + * Opener.h + * + */ + +#ifndef PROTOCOLS_OPENER_H_ +#define PROTOCOLS_OPENER_H_ + +#include + +template +class Opener +{ + typedef typename conditional::value, + typename T::MAC_Check, typename T::DefaultMC>::type inner_type; + + inner_type inner; + Player& P; + +public: + Opener(Player& P, typename T::mac_key_type mac_key) : + inner(mac_key), P(P) + { + } + + ~Opener() + { + Check(); + } + + void Check() + { + inner.Check(P); + } + + void init_open(int n = 0) + { + inner.init_open(P, n); + } + + void prepare_open(const T& secret) + { + inner.prepare_open(secret); + } + + void exchange() + { + inner.exchange(P); + } + + typename T::clear finalize_open() + { + return inner.finalize_open(); + } +}; + +#endif /* PROTOCOLS_OPENER_H_ */ diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index 06196762b..cad03bdbd 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -20,13 +20,15 @@ class PostSacriRepFieldShare : public MaliciousRep3Share public: typedef typename super::clear clear; - typedef PostSacrifice Protocol; + typedef PostSacrifice BasicProtocol; typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; typedef MaliciousRepPrepWithBits LivePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; PostSacriRepFieldShare() { diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 30f0574b2..682eb7444 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -32,13 +32,15 @@ class PostSacriRepRingShare : public Rep3Share2 typedef Z2 random_type; typedef MalRepRingShare SquareToBitShare; - typedef PostSacrifice Protocol; + typedef PostSacrifice BasicProtocol; typedef HashMaliciousRepMC MAC_Check; typedef MAC_Check Direct_MC; typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; typedef MalRepRingPrepWithBits LivePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/PostSacrifice.hpp b/Protocols/PostSacrifice.hpp index 0f72f4e81..e24f9479e 100644 --- a/Protocols/PostSacrifice.hpp +++ b/Protocols/PostSacrifice.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_POSTSACRIFIE_HPP_ +#define PROTOCOLS_POSTSACRIFIE_HPP_ + #include "PostSacrifice.h" template @@ -81,3 +84,5 @@ void PostSacrifice::check() operands.erase(operands.begin(), operands.begin() + buffer_size); results.clear(); } + +#endif diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h index 09be88cb1..41e83714e 100644 --- a/Protocols/ProtocolSet.h +++ b/Protocols/ProtocolSet.h @@ -7,9 +7,13 @@ #define PROTOCOLS_PROTOCOLSET_H_ #include "Processor/Processor.h" +#include "Processor/Machine.h" #include "GC/ShareThread.h" #include "ProtocolSetup.h" +#include +#include + /** * Input, multiplication, and output protocol instance * for an arithmetic share type @@ -42,6 +46,16 @@ class ProtocolSet { } + /** + * @param P communication instance + * @param machine virtual machine instance + */ + template + ProtocolSet(Player& P, const Machine& machine) : + ProtocolSet(P, machine.get_sint_mac_key()) + { + } + /** * Run all protocol checks */ @@ -79,6 +93,19 @@ class BinaryProtocolSet { } + /** + * @param P communication instance + * @param machine virtual machine instance + */ + template + BinaryProtocolSet(Player& P, const Machine& machine) : + usage(P.num_players()), prep(usage), thread(prep, P, + machine.get_bit_mac_key()), output(*thread.MC), protocol( + *thread.protocol), input(output, prep, P) + { + assert((is_same())); + } + /** * Run all protocol checks */ @@ -117,6 +144,18 @@ class MixedProtocolSet { } + /** + * @param P communication instance + * @param machine virtual machine instance + */ + template + MixedProtocolSet(Player& P, const Machine& machine) : + arithmetic(P, machine), binary(P, machine), output( + arithmetic.output), preprocessing(arithmetic.preprocessing), protocol( + arithmetic.protocol), input(arithmetic.input) + { + } + /** * Run all protocol checks */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 1e2852a11..d902fbce8 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -124,7 +124,7 @@ class Rep3Share : public RepShare const static bool needs_ot = false; const static bool dishonest_majority = false; const static bool expensive = false; - const static bool variable_players = false; + static false_type variable_players; static const bool has_trunc_pr = true; static const bool malicious = false; @@ -175,4 +175,7 @@ class Rep3Share : public RepShare } }; +template +false_type Rep3Share::variable_players; + #endif /* PROTOCOLS_REP3SHARE_H_ */ diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index 10dc5e1f4..62955f607 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -266,7 +266,7 @@ void Rep4::exchange() template T Rep4::finalize_mul(int n_bits) { - this->counter++; + this->add_mul(n_bits); if (n_bits == -1) return results.next().res; else diff --git a/Protocols/Rep4MC.hpp b/Protocols/Rep4MC.hpp index 7b2fffda2..1035b8118 100644 --- a/Protocols/Rep4MC.hpp +++ b/Protocols/Rep4MC.hpp @@ -27,6 +27,9 @@ void Rep4MC::exchange(const Player& P) this->values[i] = this->secrets[i].sum() + a; } receive_hash.update(right); + + if (OnlineOptions::singleton.has_option("always_check")) + Check(P); } template diff --git a/Protocols/RepRingOnlyEdabitPrep.hpp b/Protocols/RepRingOnlyEdabitPrep.hpp index da2fab0b7..9dafbec9d 100644 --- a/Protocols/RepRingOnlyEdabitPrep.hpp +++ b/Protocols/RepRingOnlyEdabitPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_REPRINGONLYEDABITPREP_HPP_ +#define PROTOCOLS_REPRINGONLYEDABITPREP_HPP_ + #include "RepRingOnlyEdabitPrep.h" #include "GC/BitAdder.h" #include "Processor/Instruction.h" @@ -52,3 +55,5 @@ void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) this->push_edabits(this->edabits[{false, n_bits}], wholes, sums); } + +#endif diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index fd807364e..c494ceb34 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -216,7 +216,8 @@ void Replicated::stop_exchange() template void ProtocolBase::add_mul(int n) { - this->counter++; + // counted in SubProcessor + // this->counter++; this->bit_counter += n < 0 ? T::default_length : n; } diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index 17916a2e9..6648d7f4e 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -35,6 +35,7 @@ class ReplicatedMC : public MAC_Check_Base virtual void exchange(const Player& P); virtual typename T::open_type finalize_raw(); + virtual array finalize_several(size_t n); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 195b0dc83..8c47c8d5b 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -72,4 +72,12 @@ typename T::open_type ReplicatedMC::finalize_raw() return a + o.get(); } +template +array ReplicatedMC::finalize_several(size_t n) +{ + if (this->values.empty()) + finalize(this->values, this->secrets); + return MAC_Check_Base::finalize_several(n); +} + #endif diff --git a/Protocols/RingOnlyPrep.h b/Protocols/RingOnlyPrep.h index 2d2f1928f..2f088a8fa 100644 --- a/Protocols/RingOnlyPrep.h +++ b/Protocols/RingOnlyPrep.h @@ -15,7 +15,7 @@ template class RingOnlyPrep : public virtual RingPrep { protected: - RingOnlyPrep(SubProcessor* proc, DataPositions& usage) : + RingOnlyPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage) { diff --git a/Protocols/SPDZ.h b/Protocols/SPDZ.h index bd804ea0a..1d654bcd3 100644 --- a/Protocols/SPDZ.h +++ b/Protocols/SPDZ.h @@ -6,7 +6,8 @@ #ifndef PROTOCOLS_SPDZ_H_ #define PROTOCOLS_SPDZ_H_ -#include "Beaver.h" +#include "Hemi.h" +#include "SemiShare.h" #include using namespace std; @@ -19,22 +20,13 @@ class Player; * SPDZ protocol */ template -class SPDZ : public Beaver +class SPDZ : public MaybeHemi { public: - SPDZ(Player& P) : Beaver(P) + SPDZ(Player& P) : MaybeHemi(P) { } - static void assign(typename T::open_type& share, - const typename T::open_type& clear, int my_num) - { - if (my_num == 0) - share = clear; - else - share = 0; - } - int get_n_relevant_players() { return this->P.num_players(); diff --git a/Protocols/Semi.h b/Protocols/Semi.h index f73dfd9d5..4293aec8f 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -6,20 +6,20 @@ #ifndef PROTOCOLS_SEMI_H_ #define PROTOCOLS_SEMI_H_ -#include "SPDZ.h" +#include "Beaver.h" #include "Processor/TruncPrTuple.h" /** * Dishonest-majority protocol for computation modulo a power of two */ template -class Semi : public SPDZ +class Semi : public Beaver { SeededPRNG G; public: Semi(Player& P) : - SPDZ(P) + Beaver(P) { } diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index f88969d5b..d202c2a7f 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -19,14 +19,17 @@ template class Semi2kShare : public SemiShare> { typedef SignedZ2 T; + typedef Semi2kShare This; public: typedef SemiMC MAC_Check; typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef Semi Protocol; + typedef Semi BasicProtocol; typedef SemiPrep2k LivePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef Semi2kShare prep_type; typedef SemiMultiplier Multiplier; diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index 464d793a3..5fc2c4f05 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SEMIINPUT_H_ #include "ReplicatedInput.h" +#include "Hemi.h" template class SemiMC; diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 39d242708..eb5004d7f 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -19,11 +19,17 @@ template class Input; template class SemiMC; template class DirectSemiMC; template class Semi; +template class Hemi; template class SemiPrep; template class SemiInput; template class PrivateOutput; template class SemiMultiplier; template class OTTripleGenerator; +template class DummyMatrixPrep; + +template +using MaybeHemi = typename conditional>::type; namespace GC { @@ -50,6 +56,7 @@ template class SemiShare : public T, public ShareInterface { typedef T super; + typedef SemiShare This; public: typedef T open_type; @@ -60,9 +67,11 @@ class SemiShare : public T, public ShareInterface typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef Semi Protocol; + typedef Semi BasicProtocol; typedef SemiPrep LivePrep; typedef LivePrep TriplePrep; + typedef MaybeHemi Protocol; + typedef DummyMatrixPrep MatrixPrep; typedef SemiShare prep_type; typedef SemiMultiplier Multiplier; @@ -91,7 +100,10 @@ class SemiShare : public T, public ShareInterface static SemiShare constant(const open_type& other, int my_num, mac_key_type = {}, int = -1) { - return SemiShare(other, my_num); + if (my_num == 0) + return other; + else + return {}; } SemiShare() @@ -101,11 +113,6 @@ class SemiShare : public T, public ShareInterface SemiShare(const U& other) : T(other) { } - SemiShare(const open_type& other, int my_num, const T& alphai = {}) - { - (void) alphai; - Protocol::assign(*this, other, my_num); - } void assign(const char* buffer) { diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 815377a4a..20ed0bc8d 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -8,8 +8,8 @@ #include "Shamir.h" #include "ShamirInput.h" +#include "ShamirOptions.h" #include "ShamirShare.h" -#include "Machines/ShamirMachine.h" #include "Tools/benchmarking.h" template diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index d5e975b35..a889aae94 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -9,7 +9,7 @@ #include "Processor/Input.h" #include "Shamir.h" #include "SemiInput.h" -#include "Machines/ShamirMachine.h" +#include "ShamirOptions.h" /** * Base class for input protocols where the inputting player sends a share diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 1fcaf8476..aec6d1e80 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -7,8 +7,7 @@ #define PROTOCOLS_SHAMIRINPUT_HPP_ #include "ShamirInput.h" -#include "Machines/ShamirMachine.h" - +#include "ShamirOptions.h" #include "Protocols/ReplicatedInput.hpp" #include "Protocols/SemiInput.hpp" diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 8e2bafdbe..30db1de74 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -8,7 +8,7 @@ #include "MAC_Check_Base.h" #include "Protocols/ShamirShare.h" -#include "Machines/ShamirMachine.h" +#include "ShamirOptions.h" #include "Tools/Bundle.h" /** @@ -70,6 +70,7 @@ class ShamirMC : public IndirectShamirMC virtual void prepare_open(const T& secret, int = -1); virtual void exchange(const Player& P); virtual typename T::open_type finalize_raw(); + virtual array finalize_several(size_t n); void Check(const Player& P) { (void)P; } diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index db50c327a..59110110e 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -8,6 +8,9 @@ #include "ShamirMC.h" +#include "MAC_Check_Base.hpp" +#include "Shamir.hpp" + template ShamirMC::ShamirMC(int t) : os(0), player(0), threshold() @@ -116,6 +119,14 @@ void ShamirMC::finalize(vector& values, values.push_back(finalize_raw()); } +template +array ShamirMC::finalize_several(size_t n) +{ + this->values.clear(); + finalize(this->values, vector(n)); + return MAC_Check_Base::finalize_several(n); +} + template typename T::open_type ShamirMC::finalize_raw() { diff --git a/Machines/ShamirMachine.hpp b/Protocols/ShamirOptions.cpp similarity index 56% rename from Machines/ShamirMachine.hpp rename to Protocols/ShamirOptions.cpp index a4f22206e..d6b377482 100644 --- a/Machines/ShamirMachine.hpp +++ b/Protocols/ShamirOptions.cpp @@ -3,40 +3,10 @@ * */ -#ifndef MACHINE_SHAMIR_MACHINE_HPP_ -#define MACHINE_SHAMIR_MACHINE_HPP_ +#include "ShamirOptions.h" -#include -#include "Protocols/ShamirShare.h" -#include "Protocols/MaliciousShamirShare.h" -#include "Math/gfp.h" -#include "Math/gf2n.h" -#include "GC/VectorProtocol.h" -#include "GC/CcdPrep.h" -#include "GC/TinyMC.h" -#include "GC/MaliciousCcdSecret.h" -#include "GC/VectorInput.h" - -#include "Processor/FieldMachine.hpp" - -#include "Processor/Data_Files.hpp" -#include "Processor/Instruction.hpp" -#include "Processor/Machine.hpp" -#include "Protocols/ShamirInput.hpp" -#include "Protocols/Shamir.hpp" -#include "Protocols/ShamirMC.hpp" -#include "Protocols/MaliciousShamirMC.hpp" -#include "Protocols/MaliciousShamirPO.hpp" -#include "Protocols/MAC_Check_Base.hpp" -#include "Protocols/Beaver.hpp" -#include "Protocols/Spdz2kPrep.hpp" -#include "Protocols/ReplicatedPrep.hpp" -#include "Protocols/MalRepRingPrep.hpp" -#include "GC/ShareSecret.hpp" -#include "GC/VectorProtocol.hpp" -#include "GC/Secret.hpp" -#include "GC/CcdPrep.hpp" -#include "Math/gfp.hpp" +#include +using namespace std; ShamirOptions ShamirOptions::singleton; @@ -99,14 +69,3 @@ void ShamirOptions::set_threshold(ez::ezOptionParser& opt) exit(1); } } - -template class T> -ShamirMachineSpec::ShamirMachineSpec(int argc, const char** argv) -{ - auto& opts = ShamirOptions::singleton; - ez::ezOptionParser opt; - opts = {opt, argc, argv}; - HonestMajorityFieldMachine(argc, argv, opt, opts.nparties); -} - -#endif diff --git a/Machines/ShamirMachine.h b/Protocols/ShamirOptions.h similarity index 83% rename from Machines/ShamirMachine.h rename to Protocols/ShamirOptions.h index 6a4609279..1a80f7642 100644 --- a/Machines/ShamirMachine.h +++ b/Protocols/ShamirOptions.h @@ -3,8 +3,8 @@ * */ -#ifndef MACHINES_SHAMIRMACHINE_H_ -#define MACHINES_SHAMIRMACHINE_H_ +#ifndef PROTOCOLS_SHAMIROPTIONS_H_ +#define PROTOCOLS_SHAMIROPTIONS_H_ #include "Tools/ezOptionParser.h" @@ -34,4 +34,4 @@ class ShamirMachineSpec ShamirMachineSpec(int argc, const char** argv); }; -#endif /* MACHINES_SHAMIRMACHINE_H_ */ +#endif /* PROTOCOLS_SHAMIROPTIONS_H_ */ diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 9f6a0129f..2de9d0af6 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -8,8 +8,8 @@ #include "Protocols/Shamir.h" #include "Protocols/ShamirInput.h" -#include "Machines/ShamirMachine.h" #include "GC/NoShare.h" +#include "ShamirOptions.h" #include "ShareInterface.h" template class ReplicatedPrep; @@ -49,7 +49,7 @@ class ShamirShare : public T, public ShareInterface const static bool needs_ot = false; const static bool dishonest_majority = false; - const static bool variable_players = true; + static true_type variable_players; const static bool expensive = false; const static bool malicious = false; const static int bit_generation_threshold = 3; @@ -136,4 +136,7 @@ class ShamirShare : public T, public ShareInterface } }; +template +true_type ShamirShare::variable_players; + #endif /* PROTOCOLS_SHAMIRSHARE_H_ */ diff --git a/Protocols/Share.h b/Protocols/Share.h index 1bff3b487..69d6bf5d3 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -9,7 +9,6 @@ using namespace std; #include "Math/gf2n.h" -#include "Protocols/SPDZ.h" #include "Protocols/SemiShare.h" #include "ShareInterface.h" @@ -156,6 +155,8 @@ class Share_ : public ShareInterface template class Share : public Share_, SemiShare> { + typedef Share This; + public: typedef Share_, SemiShare> super; @@ -177,10 +178,12 @@ class Share : public Share_, SemiShare> typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; typedef SPDZ Protocol; typedef MascotFieldPrep LivePrep; typedef MascotPrep RandomPrep; typedef MascotTriplePrep TriplePrep; + typedef DummyMatrixPrep MatrixPrep; static const bool expensive = true; diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 79ec8c44b..7e94e216f 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -35,6 +35,8 @@ class ShareInterface typedef GC::NoShare mac_type; typedef GC::NoShare mac_share_type; + typedef void DefaultMC; + static const bool needs_ot = false; static const bool expensive = false; static const bool expensive_triples = false; diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index d67ce7123..1238419b8 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -200,17 +200,27 @@ class ValueMatrix : public ValueInterface This operator*(const This& other) const { - assert(n_cols == other.n_rows); - This res(n_rows, other.n_cols); - if (entries.v.empty() or other.entries.v.empty()) - return res; + This res; + res.mul(*this, other); + return res; + } + + template + void mul(const ValueMatrix& a, const ValueMatrix& b) + { + assert(a.n_cols == b.n_rows); + auto& res = *this; + res = {a.n_rows, b.n_cols}; + if (a.entries.v.empty() or b.entries.v.empty()) + return; res.entries.init(); - for (int i = 0; i < n_rows; i++) - for (int j = 0; j < other.n_cols; j++) - for (int k = 0; k < n_cols; k++) - res[{i, j}] += (*this)[{i, k}] * other[{k, j}]; + for (int i = 0; i < a.n_rows; i++) + { + for (int j = 0; j < b.n_cols; j++) + for (int k = 0; k < a.n_cols; k++) + res[{i, j}] += a[{i, k}] * b[{k, j}]; + } res.check(); - return res; } bool operator!=(const This& other) const @@ -234,6 +244,13 @@ class ValueMatrix : public ValueInterface return res; } + void input(istream& is) + { + entries.init(); + for (auto& x: entries) + x.input(is, false); + } + friend ostream& operator<<(ostream& o, const This&) { return o; @@ -253,7 +270,7 @@ class ShareMatrix : public ValueMatrix, public ShareInterface typedef DummyLivePrep LivePrep; typedef ValueMatrix clear; - typedef clear open_type; + typedef ValueMatrix open_type; typedef typename T::mac_key_type mac_key_type; static string type_string() @@ -323,39 +340,38 @@ class ShareMatrix : public ValueMatrix, public ShareInterface }; template -ShareMatrix operator*(const ValueMatrix& a, +ShareMatrix operator*(const typename ShareMatrix::open_type& a, const ShareMatrix& b) { - assert(a.n_cols == b.n_rows); - ShareMatrix res(a.n_rows, b.n_cols); - if (a.entries.v.empty() or b.entries.v.empty()) - return res; - res.entries.init(); - for (int i = 0; i < a.n_rows; i++) - for (int j = 0; j < b.n_cols; j++) - for (int k = 0; k < a.n_cols; k++) - res[{i, j}] += a[{i, k}] * b[{k, j}]; - res.check(); + ShareMatrix res; + res.mul(a, b); + return res; +} + +template +ShareMatrix operator*(const ShareMatrix& b, + const ValueMatrix& a) +{ + ShareMatrix res; + res.mul(b, a); return res; } template class MatrixMC : public MAC_Check_Base> { + friend class Hemi; + typename T::MAC_Check& inner; public: - MatrixMC() : - inner( - *(OnlineOptions::singleton.direct ? - new typename T::Direct_MC : - new typename T::MAC_Check)) + MatrixMC(typename T::MAC_Check& inner) : + MAC_Check_Base>(inner.get_alphai()), inner(inner) { } ~MatrixMC() { - delete &inner; } void exchange(const Player& P) diff --git a/Protocols/SohoShare.h b/Protocols/SohoShare.h index 7525c97c5..96c5cf8d7 100644 --- a/Protocols/SohoShare.h +++ b/Protocols/SohoShare.h @@ -21,8 +21,10 @@ class SohoShare : public SemiShare typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef Beaver BasicProtocol; + typedef MaybeHemi Protocol; typedef SohoPrep LivePrep; + typedef DummyMatrixPrep MatrixPrep; static const bool needs_ot = false; diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index 85f9ecb91..68b0dbce6 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -28,6 +28,8 @@ template class TinySecret; template class Spdz2kShare : public Share> { + typedef Spdz2kShare This; + public: typedef Z2 tmp_type; typedef Share super; @@ -51,6 +53,8 @@ class Spdz2kShare : public Share> typedef Direct_MAC_Check_Z2k Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; + typedef Beaver BasicProtocol; + typedef DummyMatrixPrep MatrixPrep; typedef SPDZ2k Protocol; typedef Spdz2kPrep LivePrep; diff --git a/Protocols/SpdzWisePrep.h b/Protocols/SpdzWisePrep.h index 6b4df251f..292029206 100644 --- a/Protocols/SpdzWisePrep.h +++ b/Protocols/SpdzWisePrep.h @@ -23,12 +23,11 @@ class SpdzWisePrep : public MaliciousRingPrep void buffer_inputs(int player); - template - void buffer_bits(MaliciousRep3Share>); - template - void buffer_bits(MaliciousShamirShare>); - template - void buffer_bits(U); + void buffer_bits(false_type, true_type, false_type); + void buffer_bits(true_type, true_type, false_type); + void buffer_bits(true_type, false_type, true_type); + void buffer_bits(false_type, false_type, false_type); + void buffer_bits(false_type, false_type, true_type); public: SpdzWisePrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index 522104a13..d1a945d77 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -29,14 +29,13 @@ void SpdzWisePrep::buffer_triples() } template -template -void SpdzWisePrep::buffer_bits(MaliciousRep3Share>) +void SpdzWisePrep::buffer_bits(false_type, true_type, false_type) { MaliciousRingPrep::buffer_bits(); } -template<> -void SpdzWisePrep>>::buffer_bits() +template +void SpdzWisePrep::buffer_bits(false_type, false_type, true_type) { typedef MaliciousRep3Share part_type; vector bits; @@ -84,19 +83,24 @@ void SpdzWiseRingPrep::buffer_bits() template void SpdzWisePrep::buffer_bits() { - buffer_bits(typename T::share_type()); + buffer_bits(T::share_type::variable_players, T::clear::prime_field, + T::clear::characteristic_two); } template -template -void SpdzWisePrep::buffer_bits(MaliciousShamirShare>) +void SpdzWisePrep::buffer_bits(true_type, true_type, false_type) { buffer_bits_from_squares(*this); } template -template -void SpdzWisePrep::buffer_bits(U) +void SpdzWisePrep::buffer_bits(false_type, false_type, false_type) +{ + super::buffer_bits(); +} + +template +void SpdzWisePrep::buffer_bits(true_type, false_type, true_type) { super::buffer_bits(); } diff --git a/Protocols/SpdzWiseRing.hpp b/Protocols/SpdzWiseRing.hpp index 36e638d14..b0d2afd90 100644 --- a/Protocols/SpdzWiseRing.hpp +++ b/Protocols/SpdzWiseRing.hpp @@ -5,6 +5,8 @@ #include "SpdzWiseRing.h" +#include "PostSacrifice.hpp" + template SpdzWiseRing::SpdzWiseRing(Player& P) : SpdzWise(P), zero_prep(0, zero_usage), zero_proc(zero_output, diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h index 2b6d8b87d..6a5cd9140 100644 --- a/Protocols/TemiPrep.h +++ b/Protocols/TemiPrep.h @@ -38,7 +38,7 @@ class TemiMultiplier * Semi-honest triple generation with semi-homomorphic encryption */ template -class TemiPrep : public SemiHonestRingPrep +class TemiPrep : public HemiPrep { friend class HemiMatrixPrep; @@ -62,7 +62,7 @@ class TemiPrep : public SemiHonestRingPrep TemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), - SemiHonestRingPrep(proc, usage) + HemiPrep(proc, usage) { } diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 0fc851137..730439dff 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -54,9 +54,28 @@ class KeySetup } }; +class FilesBase +{ +public: + virtual ~FilesBase() {} + virtual void output_shares(word a) = 0; + + void make_AES(int n, bool zero, PRNG& G); + void make_DES(int n, bool zero, PRNG& G); +}; + template -class Files +class Files : public FilesBase { + void open(int i, const string& filename) + { + cout << "Opening " << filename << endl; + outf[i].open(filename,ios::out | ios::binary); + file_signature(key.get(i)).output(outf[i]); + if (outf[i].fail()) + throw file_error(filename); + } + public: ofstream* outf; int N; @@ -81,17 +100,26 @@ class Files stringstream filename; filename << prefix << "-P" << i; filename << PrepBase::get_suffix(thread_num); - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); - file_signature(key.get(i)).output(outf[i]); - if (outf[i].fail()) - throw file_error(filename.str().c_str()); + open(i, filename.str()); } } + Files(const KeySetup& key, const vector& filenames, PRNG& G) : + N(filenames.size()), key(key), G(G) + { + insecure_fake(false); + outf = new ofstream[N]; + for (int i = 0; i < N; i++) + open(i, filenames[i]); + } ~Files() { delete[] outf; } + + void output_shares(word a) + { + output_shares(typename T::open_type(a)); + } template void output_shares(const typename U::open_type& a) { diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 22da13753..85a96fd75 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -33,50 +33,10 @@ template class MaliciousCcdSecret; template void make_share(Share_* Sa,const U& a,int N,const V& key,PRNG& G) { - T x; - W mac, y; - mac = a * key; - Share_ S; - S.set_share(a); - S.set_mac(mac); - - for (int i=0; i -void make_share(SpdzWiseShare>* Sa,const U& a,int N,const T& key,PRNG& G) -{ - auto mac = a * key; - FixedVec shares, macs; - shares.randomize_to_sum(a, G); - macs.randomize_to_sum(mac, G); - - for (int i = 0; i < N; i++) - { - MaliciousRep3Share share, mac; - share[0] = shares[i]; - share[1] = shares[positive_modulo(i - 1, 3)]; - mac[0] = macs[i]; - mac[1] = macs[positive_modulo(i - 1, 3)]; - Sa[i].set_share(share); - Sa[i].set_mac(mac); - } -} - -template -void make_share(SpdzWiseShare>* Sa, const U& a, int N, - const V& key, PRNG& G) -{ - vector> shares(N), macs(N); - make_share(shares.data(), a, N, {}, G); - make_share(macs.data(), a * key, N, {}, G); + vector shares(N); + vector macs(N); + make_share(shares.data(), a, N, GC::NoValue(), G); + make_share(macs.data(), a * key, N, GC::NoValue(), G); for (int i = 0; i < N; i++) { Sa[i].set_share(shares[i]); @@ -610,16 +570,15 @@ void make_inverse(const KeySetup& key, int N, int ntrip, bool zero, check_files(files.outf, N); } -template -void plain_edabits(vector& as, - vector& bs, int length, PRNG& G, +template +void plain_edabits(vector& as, + vector& bs, int length, PRNG& G, int max_size, bool zero = false) { - int max_size = edabitvec::MAX_SIZE; as.resize(max_size); bs.clear(); bs.resize(length); - Z2 value; + Z2 value; for (int j = 0; j < max_size; j++) { if (not zero) diff --git a/README.md b/README.md index 8bea89228..05ef8f0ab 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ parties and malicious security. On Linux, this requires a working toolchain and [all requirements](#requirements). On Ubuntu, the following might suffice: ``` -sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-iostreams-dev libboost-thread-dev libgmp-dev libntl-dev libsodium-dev libssl-dev libtool python3 +sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-filesystem-dev libboost-iostreams-dev libboost-thread-dev libgmp-dev libntl-dev libsodium-dev libssl-dev libtool python3 ``` On MacOS, this requires [brew](https://brew.sh) to be installed, which will be used for all dependencies. @@ -257,11 +257,9 @@ compute the preprocessing time for a particular computation. #### Requirements - - GCC 7 or later (tested with up to 11) or LLVM/clang 6 or later + - GCC 7 or later (tested with up to 14) or LLVM/clang 10 or later (tested with up to 19). The default is to use clang because it performs - better. clang 9 doesn't support libOTe, so you - need to deactivate its use for these compilers (see the next - section). + better. - For protocols using oblivious transfer, libOTe with [the necessary patches](https://github.com/mkskeller/softspoken-implementation) but without SimplestOT. The easiest way is to run `make libote`, @@ -566,6 +564,13 @@ This is particularly useful if want to add new command line arguments specifical Note that when using this approach, all objects provided in the high level interface (e.g. sint, print_ln) need to be imported, because the `.mpc` file is interpreted directly by Python (instead of being read by `compile.py`.) +Furthermore, this only covers compilation, so you will need to run execution separately, for example: +``` +Scripts/mascot.sh hello_world +``` + +Also note that programs in the above form are not compatible with `compile.py` and `compile-run.py`. + #### Compiling and running programs from external directories Programs can also be edited, compiled and run from any directory with @@ -1127,25 +1132,6 @@ same player number in the preprocessing and the online phase. ## Benchmarking offline phases -#### SPDZ-2 offline phase - -This implementation is suitable to generate the preprocessed data used in the online phase. -You need to compile with `USE_NTL = 1` in `CONFIG.mine` to run this. - -For quick run on one machine, you can call the following: - -`./spdz2-offline.x -p 0 & ./spdz2-offline.x -p 1` - -More generally, run the following on every machine: - -`./spdz2-offline.x -p -N -h -c ` - -The number of parties are counted from 0. As seen in the quick example, you can omit the total number of parties if it is 2 and the hostname if all parties run on the same machine. Invoke `./spdz2-offline.x` for more explanation on the options. - -`./spdz2-offline.x` provides covert security according to some parameter c (at least 2). A malicious adversary will get caught with probability 1-1/c. There is a linear correlation between c and the running time, that is, running with 2c takes twice as long as running with c. The default for c is 10. - -The program will generate every kind of randomness required by the online phase except input tuples until you stop it. You can shut it down gracefully pressing Ctrl-c (or sending the interrupt signal `SIGINT`), but only after an initial phase, the end of which is marked by the output `Starting to produce gf2n`. Note that the initial phase has been reported to take up to an hour. Furthermore, 3 GB of RAM are required per party. - #### Benchmarking the MASCOT or SPDZ2k offline phase These implementations are not suitable to generate the preprocessed diff --git a/Scripts/list-field-protocols.sh b/Scripts/list-field-protocols.sh index c303c9775..22f27b67a 100755 --- a/Scripts/list-field-protocols.sh +++ b/Scripts/list-field-protocols.sh @@ -1,4 +1,5 @@ #!/bin/bash +dir="$(dirname $0)" echo rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - atlas mal-shamir sy-shamir semi hemi temi mascot soho cowgear chaigear + atlas mal-shamir sy-shamir semi mascot `$dir/list-he-protocols.sh` diff --git a/Scripts/list-he-protocols.sh b/Scripts/list-he-protocols.sh new file mode 100755 index 000000000..43c9e0932 --- /dev/null +++ b/Scripts/list-he-protocols.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +echo hemi temi soho cowgear chaigear diff --git a/Scripts/prep-usage.py b/Scripts/prep-usage.py index cb8ca6198..b2697582f 100755 --- a/Scripts/prep-usage.py +++ b/Scripts/prep-usage.py @@ -14,7 +14,11 @@ res = collections.defaultdict(lambda: 0) m = 0 -tapename = next(Program.read_tapes(sys.argv[1])) +if os.path.isfile(sys.argv[1]): + tapename = re.sub(r'\.bc', '', os.path.basename(sys.argv[1])) +else: + tapename = next(Program.read_tapes(sys.argv[1])) + res = Tape.ReqNum() for inst in Tape.read_instructions(tapename): res.update(inst.get_usage()) diff --git a/Scripts/run-online.sh b/Scripts/run-online.sh index 75513e17e..ebee9f069 100755 --- a/Scripts/run-online.sh +++ b/Scripts/run-online.sh @@ -5,4 +5,7 @@ SPDZROOT=$HERE/.. . $HERE/run-common.sh +echo NOTE: This runs the SPDZ online phase, requiring a prior preprocessing generation with Fake-Offline.x +echo See https://github.com/data61/MP-SPDZ/?tab=readme-ov-file#protocols for all protocols. + run_player Player-Online.x $* || exit 1 diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 3acc2fd06..481b2a012 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -36,11 +36,10 @@ void BufferBase::seekg(int pos) { assert(not is_pipe()); -#ifdef DEBUG_BUFFER - if (pos != 0) + if (pos != 0 and OnlineOptions::singleton.has_option("verbose_buffer")) printf("seek %d %s thread %d\n", pos, filename.c_str(), BaseMachine::thread_num); -#endif + if (not file) { if (pos == 0) @@ -65,6 +64,7 @@ void BufferBase::seekg(int pos) void BufferBase::try_rewind() { + assert(not OnlineOptions::singleton.has_option("no_rewind")); assert(not is_pipe()); #ifndef INSECURE diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index 88a5d3636..7e6021e33 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -16,19 +16,19 @@ void exit_error(const string& message) exit(1); } -IO_Error::IO_Error(const string& m) +IO_Error::IO_Error(const string& m) : + ans(m) { - ans = "IO-Error : " + m; } -file_error::file_error(const string& m) +file_error::file_error(const string& m) : + ans(m) { - ans = "File Error : " + m; } -Processor_Error::Processor_Error(const string& m) +Processor_Error::Processor_Error(const string& m) : + msg(m) { - msg = "Processor-Error : " + m; } Processor_Error::Processor_Error(const char* m) : @@ -46,7 +46,11 @@ wrong_gfp_size::wrong_gfp_size(const char* name, const bigint& p, } overflow::overflow(const string& name, size_t i, size_t n) : - runtime_error(name + " overflow: " + to_string(i) + "/" + to_string(n)) + runtime_error( + name + " overflow: " + to_string(long(i)) + "/" + to_string(n) + + ((long(i) < 0) ? ". A negative value indicates that " + "the computation modulus might be too small" : + "")) { } @@ -123,3 +127,17 @@ insufficient_shares::insufficient_shares(int expected, int actual, exception& e) + to_string(actual) + " (" + e.what() + ")") { } + +persistence_error::persistence_error(const string& error) : + runtime_error( + "Error while reading from persistence file. " + "You need to write to it first. " + "See https://mp-spdz.readthedocs.io/en/latest/io.html#persistence. " + "Details: " + error) +{ +} + +bytecode_error::bytecode_error(const string& error) : + runtime_error(error) +{ +} diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index f5bf59001..9b282f912 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -158,11 +158,7 @@ class Processor_Error: public exception return msg.c_str(); } }; -class Invalid_Instruction : public Processor_Error - { - public: - Invalid_Instruction(string m) : Processor_Error(m) {} - }; + class max_mod_sz_too_small : public exception { string msg; @@ -193,9 +189,9 @@ class not_enough_to_buffer : public runtime_error }; class needs_cleaning : public exception {}; -class closed_connection +class closed_connection : public exception { - const char* what() const + const char* what() const throw() { return "connection closed down"; } @@ -210,9 +206,9 @@ class no_singleton : public runtime_error } }; -class ran_out +class ran_out : public exception { - const char* what() const + const char* what() const throw() { return "insufficient preprocessing"; } @@ -304,4 +300,18 @@ class insufficient_shares : public runtime_error insufficient_shares(int expected, int actual, exception& e); }; +class persistence_error : public runtime_error +{ +public: + persistence_error(const string& error); +}; + +class bytecode_error : public runtime_error +{ +public: + bytecode_error(const string& error); +}; + +typedef bytecode_error Invalid_Instruction; + #endif diff --git a/Tools/Hash.h b/Tools/Hash.h index 1fe899039..135efbbae 100644 --- a/Tools/Hash.h +++ b/Tools/Hash.h @@ -36,7 +36,8 @@ class Hash void update(const vector& v, const vector& bit_lengths) { assert(v.size() == bit_lengths.size()); - octetStream tmp(v.size() * sizeof(T)); + octetStream tmp; + tmp.reserve(v.size() * sizeof(T)); for (size_t i = 0; i < v.size(); i++) v[i].pack(tmp, bit_lengths[i]); tmp.append(0); diff --git a/Tools/Waksman.cpp b/Tools/Waksman.cpp index 46812730e..99173247e 100644 --- a/Tools/Waksman.cpp +++ b/Tools/Waksman.cpp @@ -8,6 +8,7 @@ #include #include #include +#include template void append(vector& x, const vector& y) @@ -80,6 +81,25 @@ vector > Waksman::configure(const vector& perm) append(res.back(), p1_config.at(i)); } +#ifdef DEBUG_WAKSMAN + for (auto& x : perm) + std::cout << x << " "; + std::cout << endl; + for (auto x : I) + cout << x << " "; + cout << endl; + for (auto& x : O) + cout << int(x) << " "; + cout << endl; + for (auto& x : res) + { + for (auto y : x) + std::cout << y << " "; + std::cout << endl; + } + cout << endl; +#endif + assert(res.size() == Waksman(perm.size()).n_rounds()); return res; } diff --git a/Tools/aes-arm.h b/Tools/aes-arm.h deleted file mode 100644 index 33f24e883..000000000 --- a/Tools/aes-arm.h +++ /dev/null @@ -1,328 +0,0 @@ -// This file is reduced to functionality necessary for AES in order to avoid -// conflicts with simde. - -/* - * sse2neon is freely redistributable under the MIT License. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#if defined(__GNUC__) || defined(__clang__) -#pragma push_macro("FORCE_INLINE") -#pragma push_macro("ALIGN_STRUCT") -#define FORCE_INLINE static inline __attribute__((always_inline)) -#define ALIGN_STRUCT(x) __attribute__((aligned(x))) -#else -#error "Macro name collisions may happen with unsupported compiler." -#ifdef FORCE_INLINE -#undef FORCE_INLINE -#endif -#define FORCE_INLINE static inline -#ifndef ALIGN_STRUCT -#define ALIGN_STRUCT(x) __declspec(align(x)) -#endif -#endif - -#define vreinterpretq_m128i_u8(x) vreinterpretq_s64_u8(x) -#define vreinterpretq_m128i_u32(x) vreinterpretq_s64_u32(x) - -#define vreinterpretq_u8_m128i(x) vreinterpretq_u8_s64(x) - -// A struct is defined in this header file called 'SIMDVec' which can be used -// by applications which attempt to access the contents of an _m128 struct -// directly. It is important to note that accessing the __m128 struct directly -// is bad coding practice by Microsoft: @see: -// https://msdn.microsoft.com/en-us/library/ayeb3ayc.aspx -// -// However, some legacy source code may try to access the contents of an __m128 -// struct directly so the developer can use the SIMDVec as an alias for it. Any -// casting must be done manually by the developer, as you cannot cast or -// otherwise alias the base NEON data type for intrinsic operations. -// -// union intended to allow direct access to an __m128 variable using the names -// that the MSVC compiler provides. This union should really only be used when -// trying to access the members of the vector as integer values. GCC/clang -// allow native access to the float members through a simple array access -// operator (in C since 4.6, in C++ since 4.8). -// -// Ideally direct accesses to SIMD vectors should not be used since it can cause -// a performance hit. If it really is needed however, the original __m128 -// variable can be aliased with a pointer to this union and used to access -// individual components. The use of this union should be hidden behind a macro -// that is used throughout the codebase to access the members instead of always -// declaring this type of variable. -typedef union ALIGN_STRUCT(16) SIMDVec { - float m128_f32[4]; // as floats - DON'T USE. Added for convenience. - int8_t m128_i8[16]; // as signed 8-bit integers. - int16_t m128_i16[8]; // as signed 16-bit integers. - int32_t m128_i32[4]; // as signed 32-bit integers. - int64_t m128_i64[2]; // as signed 64-bit integers. - uint8_t m128_u8[16]; // as unsigned 8-bit integers. - uint16_t m128_u16[8]; // as unsigned 16-bit integers. - uint32_t m128_u32[4]; // as unsigned 32-bit integers. - uint64_t m128_u64[2]; // as unsigned 64-bit integers. -} SIMDVec; - -// casting using SIMDVec -#define vreinterpretq_nth_u64_m128i(x, n) (((SIMDVec *) &x)->m128_u64[n]) -#define vreinterpretq_nth_u32_m128i(x, n) (((SIMDVec *) &x)->m128_u32[n]) -#define vreinterpretq_nth_u8_m128i(x, n) (((SIMDVec *) &x)->m128_u8[n]) - -/* Backwards compatibility for compilers with lack of specific type support */ - -// Older gcc does not define vld1q_u8_x4 type -#if defined(__GNUC__) && !defined(__clang__) && \ - ((__GNUC__ == 10 && (__GNUC_MINOR__ <= 1)) || \ - (__GNUC__ == 9 && (__GNUC_MINOR__ <= 3)) || \ - (__GNUC__ == 8 && (__GNUC_MINOR__ <= 4)) || __GNUC__ <= 7) -FORCE_INLINE uint8x16x4_t _sse2neon_vld1q_u8_x4(const uint8_t *p) -{ - uint8x16x4_t ret; - ret.val[0] = vld1q_u8(p + 0); - ret.val[1] = vld1q_u8(p + 16); - ret.val[2] = vld1q_u8(p + 32); - ret.val[3] = vld1q_u8(p + 48); - return ret; -} -#else -// Wraps vld1q_u8_x4 -FORCE_INLINE uint8x16x4_t _sse2neon_vld1q_u8_x4(const uint8_t *p) -{ - return vld1q_u8_x4(p); -} -#endif - -#if !defined(__ARM_FEATURE_CRYPTO) -/* clang-format off */ -#define SSE2NEON_AES_DATA(w) \ - { \ - w(0x63), w(0x7c), w(0x77), w(0x7b), w(0xf2), w(0x6b), w(0x6f), \ - w(0xc5), w(0x30), w(0x01), w(0x67), w(0x2b), w(0xfe), w(0xd7), \ - w(0xab), w(0x76), w(0xca), w(0x82), w(0xc9), w(0x7d), w(0xfa), \ - w(0x59), w(0x47), w(0xf0), w(0xad), w(0xd4), w(0xa2), w(0xaf), \ - w(0x9c), w(0xa4), w(0x72), w(0xc0), w(0xb7), w(0xfd), w(0x93), \ - w(0x26), w(0x36), w(0x3f), w(0xf7), w(0xcc), w(0x34), w(0xa5), \ - w(0xe5), w(0xf1), w(0x71), w(0xd8), w(0x31), w(0x15), w(0x04), \ - w(0xc7), w(0x23), w(0xc3), w(0x18), w(0x96), w(0x05), w(0x9a), \ - w(0x07), w(0x12), w(0x80), w(0xe2), w(0xeb), w(0x27), w(0xb2), \ - w(0x75), w(0x09), w(0x83), w(0x2c), w(0x1a), w(0x1b), w(0x6e), \ - w(0x5a), w(0xa0), w(0x52), w(0x3b), w(0xd6), w(0xb3), w(0x29), \ - w(0xe3), w(0x2f), w(0x84), w(0x53), w(0xd1), w(0x00), w(0xed), \ - w(0x20), w(0xfc), w(0xb1), w(0x5b), w(0x6a), w(0xcb), w(0xbe), \ - w(0x39), w(0x4a), w(0x4c), w(0x58), w(0xcf), w(0xd0), w(0xef), \ - w(0xaa), w(0xfb), w(0x43), w(0x4d), w(0x33), w(0x85), w(0x45), \ - w(0xf9), w(0x02), w(0x7f), w(0x50), w(0x3c), w(0x9f), w(0xa8), \ - w(0x51), w(0xa3), w(0x40), w(0x8f), w(0x92), w(0x9d), w(0x38), \ - w(0xf5), w(0xbc), w(0xb6), w(0xda), w(0x21), w(0x10), w(0xff), \ - w(0xf3), w(0xd2), w(0xcd), w(0x0c), w(0x13), w(0xec), w(0x5f), \ - w(0x97), w(0x44), w(0x17), w(0xc4), w(0xa7), w(0x7e), w(0x3d), \ - w(0x64), w(0x5d), w(0x19), w(0x73), w(0x60), w(0x81), w(0x4f), \ - w(0xdc), w(0x22), w(0x2a), w(0x90), w(0x88), w(0x46), w(0xee), \ - w(0xb8), w(0x14), w(0xde), w(0x5e), w(0x0b), w(0xdb), w(0xe0), \ - w(0x32), w(0x3a), w(0x0a), w(0x49), w(0x06), w(0x24), w(0x5c), \ - w(0xc2), w(0xd3), w(0xac), w(0x62), w(0x91), w(0x95), w(0xe4), \ - w(0x79), w(0xe7), w(0xc8), w(0x37), w(0x6d), w(0x8d), w(0xd5), \ - w(0x4e), w(0xa9), w(0x6c), w(0x56), w(0xf4), w(0xea), w(0x65), \ - w(0x7a), w(0xae), w(0x08), w(0xba), w(0x78), w(0x25), w(0x2e), \ - w(0x1c), w(0xa6), w(0xb4), w(0xc6), w(0xe8), w(0xdd), w(0x74), \ - w(0x1f), w(0x4b), w(0xbd), w(0x8b), w(0x8a), w(0x70), w(0x3e), \ - w(0xb5), w(0x66), w(0x48), w(0x03), w(0xf6), w(0x0e), w(0x61), \ - w(0x35), w(0x57), w(0xb9), w(0x86), w(0xc1), w(0x1d), w(0x9e), \ - w(0xe1), w(0xf8), w(0x98), w(0x11), w(0x69), w(0xd9), w(0x8e), \ - w(0x94), w(0x9b), w(0x1e), w(0x87), w(0xe9), w(0xce), w(0x55), \ - w(0x28), w(0xdf), w(0x8c), w(0xa1), w(0x89), w(0x0d), w(0xbf), \ - w(0xe6), w(0x42), w(0x68), w(0x41), w(0x99), w(0x2d), w(0x0f), \ - w(0xb0), w(0x54), w(0xbb), w(0x16) \ - } -/* clang-format on */ - -/* X Macro trick. See https://en.wikipedia.org/wiki/X_Macro */ -#define SSE2NEON_AES_H0(x) (x) -static const uint8_t SSE2NEON_sbox[256] = SSE2NEON_AES_DATA(SSE2NEON_AES_H0); -#undef SSE2NEON_AES_H0 - -// In the absence of crypto extensions, implement aesenc using regular neon -// intrinsics instead. See: -// https://www.workofard.com/2017/01/accelerated-aes-for-the-arm64-linux-kernel/ -// https://www.workofard.com/2017/07/ghash-for-low-end-cores/ and -// https://github.com/ColinIanKing/linux-next-mirror/blob/b5f466091e130caaf0735976648f72bd5e09aa84/crypto/aegis128-neon-inner.c#L52 -// for more information Reproduced with permission of the author. -FORCE_INLINE __m128i _mm_aesenc_si128(__m128i EncBlock, __m128i RoundKey) -{ -#if defined(__aarch64__) - static const uint8_t shift_rows[] = {0x0, 0x5, 0xa, 0xf, 0x4, 0x9, - 0xe, 0x3, 0x8, 0xd, 0x2, 0x7, - 0xc, 0x1, 0x6, 0xb}; - static const uint8_t ror32by8[] = {0x1, 0x2, 0x3, 0x0, 0x5, 0x6, 0x7, 0x4, - 0x9, 0xa, 0xb, 0x8, 0xd, 0xe, 0xf, 0xc}; - - uint8x16_t v; - uint8x16_t w = vreinterpretq_u8_m128i(EncBlock); - - // shift rows - w = vqtbl1q_u8(w, vld1q_u8(shift_rows)); - - // sub bytes - v = vqtbl4q_u8(_sse2neon_vld1q_u8_x4(SSE2NEON_sbox), w); - v = vqtbx4q_u8(v, _sse2neon_vld1q_u8_x4(SSE2NEON_sbox + 0x40), w - 0x40); - v = vqtbx4q_u8(v, _sse2neon_vld1q_u8_x4(SSE2NEON_sbox + 0x80), w - 0x80); - v = vqtbx4q_u8(v, _sse2neon_vld1q_u8_x4(SSE2NEON_sbox + 0xc0), w - 0xc0); - - // mix columns - w = (v << 1) ^ (uint8x16_t)(((int8x16_t) v >> 7) & 0x1b); - w ^= (uint8x16_t) vrev32q_u16((uint16x8_t) v); - w ^= vqtbl1q_u8(v ^ w, vld1q_u8(ror32by8)); - - // add round key - return vreinterpretq_m128i_u8(w) ^ RoundKey; - -#else /* ARMv7-A NEON implementation */ -#define SSE2NEON_AES_B2W(b0, b1, b2, b3) \ - (((uint32_t)(b3) << 24) | ((uint32_t)(b2) << 16) | ((uint32_t)(b1) << 8) | \ - (b0)) -#define SSE2NEON_AES_F2(x) ((x << 1) ^ (((x >> 7) & 1) * 0x011b /* WPOLY */)) -#define SSE2NEON_AES_F3(x) (SSE2NEON_AES_F2(x) ^ x) -#define SSE2NEON_AES_U0(p) \ - SSE2NEON_AES_B2W(SSE2NEON_AES_F2(p), p, p, SSE2NEON_AES_F3(p)) -#define SSE2NEON_AES_U1(p) \ - SSE2NEON_AES_B2W(SSE2NEON_AES_F3(p), SSE2NEON_AES_F2(p), p, p) -#define SSE2NEON_AES_U2(p) \ - SSE2NEON_AES_B2W(p, SSE2NEON_AES_F3(p), SSE2NEON_AES_F2(p), p) -#define SSE2NEON_AES_U3(p) \ - SSE2NEON_AES_B2W(p, p, SSE2NEON_AES_F3(p), SSE2NEON_AES_F2(p)) - static const uint32_t ALIGN_STRUCT(16) aes_table[4][256] = { - SSE2NEON_AES_DATA(SSE2NEON_AES_U0), - SSE2NEON_AES_DATA(SSE2NEON_AES_U1), - SSE2NEON_AES_DATA(SSE2NEON_AES_U2), - SSE2NEON_AES_DATA(SSE2NEON_AES_U3), - }; -#undef SSE2NEON_AES_B2W -#undef SSE2NEON_AES_F2 -#undef SSE2NEON_AES_F3 -#undef SSE2NEON_AES_U0 -#undef SSE2NEON_AES_U1 -#undef SSE2NEON_AES_U2 -#undef SSE2NEON_AES_U3 - - uint32_t x0 = _mm_cvtsi128_si32(EncBlock); - uint32_t x1 = _mm_cvtsi128_si32(_mm_shuffle_epi32(EncBlock, 0x55)); - uint32_t x2 = _mm_cvtsi128_si32(_mm_shuffle_epi32(EncBlock, 0xAA)); - uint32_t x3 = _mm_cvtsi128_si32(_mm_shuffle_epi32(EncBlock, 0xFF)); - - __m128i out = _mm_set_epi32( - (aes_table[0][x3 & 0xff] ^ aes_table[1][(x0 >> 8) & 0xff] ^ - aes_table[2][(x1 >> 16) & 0xff] ^ aes_table[3][x2 >> 24]), - (aes_table[0][x2 & 0xff] ^ aes_table[1][(x3 >> 8) & 0xff] ^ - aes_table[2][(x0 >> 16) & 0xff] ^ aes_table[3][x1 >> 24]), - (aes_table[0][x1 & 0xff] ^ aes_table[1][(x2 >> 8) & 0xff] ^ - aes_table[2][(x3 >> 16) & 0xff] ^ aes_table[3][x0 >> 24]), - (aes_table[0][x0 & 0xff] ^ aes_table[1][(x1 >> 8) & 0xff] ^ - aes_table[2][(x2 >> 16) & 0xff] ^ aes_table[3][x3 >> 24])); - - return _mm_xor_si128(out, RoundKey); -#endif -} - -// Perform the last round of an AES encryption flow on data (state) in a using -// the round key in RoundKey, and store the result in dst. -// https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_aesenclast_si128 -FORCE_INLINE __m128i _mm_aesenclast_si128(__m128i a, __m128i RoundKey) -{ - /* FIXME: optimized for NEON */ - uint8_t v[4][4] = { - {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 0)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 5)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 10)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 15)]}, - {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 4)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 9)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 14)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 3)]}, - {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 8)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 13)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 2)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 7)]}, - {SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 12)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 1)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 6)], - SSE2NEON_sbox[vreinterpretq_nth_u8_m128i(a, 11)]}, - }; - for (int i = 0; i < 16; i++) - vreinterpretq_nth_u8_m128i(a, i) = - v[i / 4][i % 4] ^ vreinterpretq_nth_u8_m128i(RoundKey, i); - return a; -} - -// Emits the Advanced Encryption Standard (AES) instruction aeskeygenassist. -// This instruction generates a round key for AES encryption. See -// https://kazakov.life/2017/11/01/cryptocurrency-mining-on-ios-devices/ -// for details. -// -// https://msdn.microsoft.com/en-us/library/cc714138(v=vs.120).aspx -FORCE_INLINE __m128i _mm_aeskeygenassist_si128(__m128i key, const int rcon) -{ - uint32_t X1 = _mm_cvtsi128_si32(_mm_shuffle_epi32(key, 0x55)); - uint32_t X3 = _mm_cvtsi128_si32(_mm_shuffle_epi32(key, 0xFF)); - for (int i = 0; i < 4; ++i) { - ((uint8_t *) &X1)[i] = SSE2NEON_sbox[((uint8_t *) &X1)[i]]; - ((uint8_t *) &X3)[i] = SSE2NEON_sbox[((uint8_t *) &X3)[i]]; - } - return _mm_set_epi32(((X3 >> 8) | (X3 << 24)) ^ rcon, X3, - ((X1 >> 8) | (X1 << 24)) ^ rcon, X1); -} -#undef SSE2NEON_AES_DATA - -#else /* __ARM_FEATURE_CRYPTO */ -// Implements equivalent of 'aesenc' by combining AESE (with an empty key) and -// AESMC and then manually applying the real key as an xor operation. This -// unfortunately means an additional xor op; the compiler should be able to -// optimize this away for repeated calls however. See -// https://blog.michaelbrase.com/2018/05/08/emulating-x86-aes-intrinsics-on-armv8-a -// for more details. -FORCE_INLINE __m128i _mm_aesenc_si128(__m128i a, __m128i b) -{ - return vreinterpretq_m128i_u8( - vaesmcq_u8(vaeseq_u8(vreinterpretq_u8_m128i(a), vdupq_n_u8(0))) ^ - vreinterpretq_u8_m128i(b)); -} - -// https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_aesenclast_si128 -FORCE_INLINE __m128i _mm_aesenclast_si128(__m128i a, __m128i RoundKey) -{ - return _mm_xor_si128(vreinterpretq_m128i_u8(vaeseq_u8( - vreinterpretq_u8_m128i(a), vdupq_n_u8(0))), - RoundKey); -} - -FORCE_INLINE __m128i _mm_aeskeygenassist_si128(__m128i a, const int rcon) -{ - // AESE does ShiftRows and SubBytes on A - uint8x16_t u8 = vaeseq_u8(vreinterpretq_u8_m128i(a), vdupq_n_u8(0)); - - uint8x16_t dest = { - // Undo ShiftRows step from AESE and extract X1 and X3 - u8[0x4], u8[0x1], u8[0xE], u8[0xB], // SubBytes(X1) - u8[0x1], u8[0xE], u8[0xB], u8[0x4], // ROT(SubBytes(X1)) - u8[0xC], u8[0x9], u8[0x6], u8[0x3], // SubBytes(X3) - u8[0x9], u8[0x6], u8[0x3], u8[0xC], // ROT(SubBytes(X3)) - }; - uint32x4_t r = {0, (unsigned) rcon, 0, (unsigned) rcon}; - return vreinterpretq_m128i_u8(dest) ^ vreinterpretq_m128i_u32(r); -} -#endif diff --git a/Tools/intrinsics.h b/Tools/intrinsics.h index 45f503c28..ecc29b725 100644 --- a/Tools/intrinsics.h +++ b/Tools/intrinsics.h @@ -13,12 +13,9 @@ #ifdef __aarch64__ #define SIMDE_X86_AVX_ENABLE_NATIVE_ALIASES #define SIMDE_X86_AVX2_ENABLE_NATIVE_ALIASES -#define SIMDE_X86_SSE2_ENABLE_NATIVE_ALIASES -#define SIMDE_X86_SSE4_1_ENABLE_NATIVE_ALIASES -#define SIMDE_X86_PCLMUL_ENABLE_NATIVE_ALIASES #include "simde/simde/x86/avx2.h" #include "simde/simde/x86/clmul.h" -#include "aes-arm.h" +#include "sse2neon/sse2neon.h" #endif #endif diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index 980c8c6aa..4a53537bb 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -90,7 +90,7 @@ string octetStream::str() const void octetStream::hash(octetStream& output) const { - assert(output.mxlen >= crypto_generichash_blake2b_BYTES_MIN); + output.resize(crypto_generichash_BYTES_MIN); crypto_generichash(output.data, crypto_generichash_BYTES_MIN, data, len, NULL, 0); output.len=crypto_generichash_BYTES_MIN; } @@ -221,6 +221,15 @@ void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive } + +void octetStream::input(const string& filename) +{ + ifstream s(filename); + if (not s.good()) + throw file_error("cannot read from " + filename); + input(s); +} + void octetStream::input(istream& s) { size_t size; diff --git a/Tools/octetStream.h b/Tools/octetStream.h index f36c3d864..485835be9 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -60,6 +60,9 @@ class octetStream // buffers for bit packing array bits; + // keep private to avoid confusing conversion from integers + octetStream(size_t maxlen); + void reset(); public: @@ -75,8 +78,6 @@ class octetStream void assign(const octetStream& os); octetStream() : len(0), mxlen(0), ptr(0), data(0) {} - /// Initial allocation - octetStream(size_t maxlen); /// Initial buffer octetStream(size_t len, const octet* source); /// Initial buffer @@ -114,7 +115,6 @@ class octetStream /// Hash content octetStream hash() const; - // output must have length at least HASH_SIZE void hash(octetStream& output) const; // The following produces a check sum for debugging purposes bigint check_sum(int req_bytes=crypto_hash_BYTES) const; @@ -250,6 +250,8 @@ class octetStream template void Receive(T socket_num); + /// Input from file, overwriting current content + void input(const string& filename); /// Input from stream, overwriting current content void input(istream& s); /// Output to stream diff --git a/Tools/parse.h b/Tools/parse.h index b4965909f..db9da854d 100644 --- a/Tools/parse.h +++ b/Tools/parse.h @@ -40,14 +40,19 @@ inline void get_ints(int* res, istream& s, int count) res[i] = be32toh(res[i]); } -inline void get_vector(int m, vector& start, istream& s) +inline void get_vector(unsigned m, vector& start, istream& s) { if (s.fail()) throw runtime_error("error when parsing vector"); - start.resize(m); - s.read((char*) start.data(), 4 * m); - for (int i = 0; i < m; i++) - start[i] = be32toh(start[i]); + int* buffer = new int[m]; + s.read((char*) buffer, 4 * m); + if (not s.fail()) + { + start.resize(m); + for (unsigned i = 0; i < m; i++) + start[i] = be32toh(buffer[i]); + } + delete[] buffer; } inline void get_string(string& res, istream& s) diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 12ceb7d01..37fee7a5f 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -36,7 +36,7 @@ #include "Math/Z2k.hpp" #include "Math/gfp.hpp" #include "GC/Secret.hpp" -#include "Machines/ShamirMachine.hpp" +#include "Machines/Shamir.hpp" #include #include @@ -52,6 +52,7 @@ class FakeParams public: ez::ezOptionParser opt; + DataPositions usage; template int generate(); @@ -72,7 +73,29 @@ class FakeParams template void make_basic(const KeySetup& key, int nplayers, int nitems, bool zero, PRNG& G, const KeySetup& bit_keys = {}); + template + void make_minimal(const KeySetup& key, int nplayers, int nitems, bool zero, PRNG& G); + template + void make_mult_triples(const KeySetup& key, int N, int ntrip, + bool zero, const string& prep_data_prefix, PRNG& G, int thread_num = -1); + template + void make_square_tuples(const KeySetup& key, int N, int ntrip, + const string& str, bool zero, PRNG& G); + template + void make_bits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, + int thread_num = -1); + template + void make_inverse(const KeySetup& key, int N, int ntrip, + bool zero, const string& prep_data_prefix, PRNG& G); + + template + void make_inputs(const KeySetup& key, int N, int ntrip, const string& str, + bool zero, PRNG& G); + + template + void make_dabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, + const KeySetup& bit_key = { }); template void make_edabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, false_type, const KeySetup& bit_key = {}); @@ -81,16 +104,59 @@ class FakeParams const KeySetup& = {}) { } + + template + void make_matrix_triples(const KeySetup& key, PRNG& G); + + template + int get_usage(Dtype type, int def) + { + auto field_type = T::clear::field_type(); + if (not usage.empty()) + { + auto res = usage.files[field_type][type] + BUFFER_SIZE; + if (type == DATA_TRIPLE and field_type == DATA_INT) + res -= usage.triples_for_matmul(); + return limit(res); + } + else + return def; + } + + long long limit(long long demand) + { + if (opt.isSet("--default") and default_num < demand) + return default_num; + else + return demand; + } }; +template +void FakeParams::make_mult_triples(const KeySetup& key, int N, + int ntrip, bool zero, const string& prep_data_prefix, PRNG& G, int thread_num) +{ + ::make_mult_triples(key, N, get_usage(DATA_TRIPLE, ntrip), zero, + prep_data_prefix, G, thread_num); +} + +template +void FakeParams::make_inverse(const KeySetup& key, int N, + int ntrip, bool zero, const string& prep_data_prefix, PRNG& G) +{ + ::make_inverse(key, N, get_usage(DATA_INVERSE, ntrip), zero, + prep_data_prefix, G); +} /* N = Number players * ntrip = Number tuples needed */ template -void make_square_tuples(const KeySetup& key,int N,int ntrip,const string& str,bool zero,PRNG& G) +void FakeParams::make_square_tuples(const KeySetup& key, int N, int ntrip, + const string& str, bool zero, PRNG& G) { (void) str; + ntrip = get_usage(DATA_SQUARE, ntrip); Files files(N, key, prep_data_prefix, DATA_SQUARE, G); typename T::clear a,c; /* Generate Squares */ @@ -109,9 +175,10 @@ void make_square_tuples(const KeySetup& key,int N,int ntrip,const string& str * ntrip = Number bits needed */ template -void make_bits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, - int thread_num = -1) +void FakeParams::make_bits(const KeySetup& key, int N, int ntrip, bool zero, + PRNG& G, int thread_num) { + ntrip = get_usage(DATA_BIT, ntrip); Files files(N, key, prep_data_prefix, DATA_BIT, G, thread_num); typename T::clear a; @@ -125,12 +192,11 @@ void make_bits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, } template -void make_dabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, - const KeySetup& bit_key = { }) +void FakeParams::make_dabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, + const KeySetup& bit_key) { - Files files(N, key, - get_prep_sub_dir(prep_data_prefix, N) - + DataPositions::dtype_names[DATA_DABIT] + "-" + T::type_short(), G); + ntrip = get_usage(DATA_DABIT, ntrip); + Files files(N, key, prep_data_prefix, DATA_DABIT, G); for (int i = 0; i < ntrip; i++) { bool bit = not zero && G.get_bit(); @@ -145,18 +211,37 @@ void FakeParams::make_edabits(const KeySetup& key, int N, int ntrip, bool zer { vector lengths; opt.get("-e")->getInts(lengths); + + if (not usage.empty()) + { + lengths.clear(); + for (auto& x : usage.edabits) + lengths.push_back(x.first.second); + } + for (auto length : lengths) { - Files files(N, key, - get_prep_sub_dir(prep_data_prefix, N) - + "edaBits-" + to_string(length), G); + vector filenames; + for (int i = 0; i < N; i++) + filenames.push_back( + PrepBase::get_edabit_filename( + get_prep_sub_dir(prep_data_prefix, N), length, i)); + Files files(key, filenames, G); bigint value; int max_size = edabitvec::MAX_SIZE; - for (int i = 0; i < ntrip / max_size; i++) + int n; + + if (usage.empty()) + n = ntrip / max_size; + else + n = limit(usage.edabits[{false, length}] + + usage.edabits[{true, length}]); + + for (int i = 0; i < n + 1; i++) { vector as; vector bs; - plain_edabits(as, bs, length, G, zero); + plain_edabits(as, bs, length, G, max_size, zero); for (auto& a : as) files.template output_shares(a); for (auto& b : bs) @@ -169,42 +254,41 @@ void FakeParams::make_edabits(const KeySetup& key, int N, int ntrip, bool zer * ntrip = Number inputs needed */ template -void make_inputs(const KeySetup& key,int N,int ntrip,const string& str,bool zero,PRNG& G) +void FakeParams::make_inputs(const KeySetup& key, int N, int ntrip, + const string& str, bool zero, PRNG& G) { (void) str; - ofstream* outf=new ofstream[N]; typename T::open_type a; vector Sa(N); /* Generate Inputs */ for (int player=0; player(DATA_OPEN, 0); + ntrip = limit(ntrip) + BUFFER_SIZE; + } + + vector filenames; + for (int i=0; i(prep_data_prefix, N), T::type_short(), player, i); - cout << "Opening " << filename << endl; - outf[i].open(filename, ios::out | ios::binary); - file_signature(key.get(i)).output(outf[i]); - if (outf[i].fail()) - throw file_error(filename); + filenames.push_back(filename); } + Files files(key, filenames, G); for (int i=0; i& key, int N, int ntrip, bool zero, PRNG& G) { stringstream ss; ss << get_prep_sub_dir(prep_data_prefix, N) << "Sbox-" << T::type_short(); Files files(N, key, ss.str(), G); + files.make_AES(ntrip, zero, G); +} + +void FilesBase::make_AES(int ntrip, bool zero, PRNG& G) +{ + auto& files = *this; gf2n_short x; for (int i = 0; i < ntrip; i++) @@ -266,12 +356,12 @@ void make_AES(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) { if (!zero) mask = G.get_uchar(); expand_byte(x, mask); - files.output_shares(x); + files.output_shares(x.get()); for (int j = 0; j < 256; j++) { expand_byte(x, sbox[mask ^ j]); - files.output_shares(x); + files.output_shares(x.get()); } } } @@ -295,6 +385,12 @@ void make_DES(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) stringstream ss; ss << get_prep_sub_dir(prep_data_prefix, N) << "SboxDes-" << T::type_short(); Files files(N, key, ss.str(), G); + files.make_DES(ntrip, zero, G); +} + +void FilesBase::make_DES(int ntrip, bool zero, PRNG& G) +{ + auto& files = *this; gf2n_short x; for (int i = 0; i < ntrip; i++) @@ -305,7 +401,7 @@ void make_DES(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) mask = G.get_uchar(); mask &= 63; //take only first 6 bits expand_byte(x, mask); - files.output_shares(x); + files.output_shares(x.get()); for (int j = 0; j < 64; j++) { files.output_shares(des_sbox[r][mask ^ j]); @@ -335,7 +431,50 @@ void make_Sbox(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G) } template -void make_minimal(const KeySetup& key, int nplayers, int nitems, bool zero, PRNG& G) +void FakeParams::make_matrix_triples(const KeySetup& key, PRNG& G) +{ + for (auto& x : usage.matmuls) + if (x.second > 0) + { + auto& dim = x.first; + Files files(nplayers, key, + PrepBase::get_matrix_prefix( + get_prep_sub_dir(prep_data_prefix, nplayers, + true), dim), G); + for (int i = 0; i < limit(x.second); i++) + { + ValueMatrix matrices[3] = {{dim[0], dim[1]}, {dim[1], + dim[2]}, {dim[0], dim[2]}}; + for (auto& matrix : matrices) + matrix.entries.init(); + if (zero) + { + for (int i = 0; i < 2; i++) + { + for (int j = 0; + j < min(matrices[i].n_rows, matrices[i].n_cols); + j++) + { + matrices[i][{j, j}] = 1; + } + } + } + else + { + matrices[0].randomize(G); + matrices[1].randomize(G); + } + matrices[2] = matrices[0] * matrices[1]; + for (auto& matrix : matrices) + for (auto& value : matrix.entries) + files.output_shares(value); + } + } +} + +template +void FakeParams::make_minimal(const KeySetup& key, int nplayers, int nitems, + bool zero, PRNG& G) { make_mult_triples(key, nplayers, nitems, zero, prep_data_prefix, G); make_bits(key, nplayers, nitems, zero, G); @@ -351,6 +490,8 @@ void FakeParams::make_basic(const KeySetup& key, int nplayers, make_dabits(key, nplayers, nitems, zero, G, bit_key); make_edabits(key, nplayers, nitems, zero, G, T::clear::characteristic_two, bit_key); + if (not T::clear::characteristic_two) + make_matrix_triples(key, G); if (T::clear::invertible) { make_inverse(key, nplayers, nitems, zero, prep_data_prefix, G); @@ -560,6 +701,15 @@ int main(int argc, const char** argv) "-seed", // Flag token. "--prngseed" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "", // Help description. + "-p", // Flag token. + "--program" // Flag token. + ); opt.parse(argc, argv); int lgp; @@ -635,6 +785,16 @@ int FakeParams::generate() ShamirOptions::singleton.set_threshold(opt); } + if (opt.isSet("--program")) + { + Program program(nplayers); + string name; + opt.get("--program") -> getString(name); + BaseMachine machine; + program.parse_with_error("Programs/Bytecode/" + name + "-0.bc"); + this->usage = program.get_offline_data_used(); + } + int ntrip2=0, ntripp=0, nbits2=0,nbitsp=0,nsqr2=0,nsqrp=0,ninp2=0,ninpp=0,ninv=0; vector list_options; int lg2, lgp; @@ -740,6 +900,7 @@ int FakeParams::generate() make_inverse(key2,nplayers,ninv,zero,prep_data_prefix,G); if (T::clear::invertible) make_inverse(keyp,nplayers,ninv,zero,prep_data_prefix,G); + make_matrix_triples(keyp, G); if (opt.isSet("-s")) { @@ -845,7 +1006,6 @@ void FakeParams::generate_field(true_type, PRNG& G) make_basic>({}, nplayers, default_num, zero, G); make_basic>({}, nplayers, default_num, zero, G); - make_basic>({}, nplayers, default_num, zero, G); if (nplayers > 2) { @@ -854,6 +1014,7 @@ void FakeParams::generate_field(true_type, PRNG& G) make_basic>({}, nplayers, default_num, zero, G); make_with_mac_key>>(nplayers, default_num, zero, G); + make_basic>({}, nplayers, default_num, zero, G); } } @@ -876,5 +1037,7 @@ inline void FakeParams::generate_ring(PRNG& G) make_basic>({}, nplayers, default_num, zero, G); make_basic>>({}, nplayers, default_num, zero, G); - make_basic>>({}, nplayers, default_num, zero, G); + + if (nplayers > 2) + make_basic>>({}, nplayers, default_num, zero, G); } diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index a07d74276..e852dd8ba 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -3,33 +3,7 @@ * */ -#include "GC/TinierSecret.h" -#include "GC/PostSacriSecret.h" -#include "GC/CcdSecret.h" -#include "GC/MaliciousCcdSecret.h" -#include "GC/AtlasSecret.h" -#include "GC/TinyMC.h" -#include "GC/VectorInput.h" -#include "GC/PostSacriBin.h" -#include "Protocols/ProtocolSet.h" - -#include "GC/ShareSecret.hpp" -#include "GC/CcdPrep.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/RepPrep.hpp" -#include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/ThreadMaster.hpp" -#include "GC/SemiSecret.hpp" -#include "Protocols/Atlas.hpp" -#include "Protocols/MaliciousRepPrep.hpp" -#include "Protocols/Share.hpp" -#include "Protocols/MaliciousRepMC.hpp" -#include "Protocols/Shamir.hpp" -#include "Protocols/fake-stuff.hpp" -#include "Machines/ShamirMachine.hpp" -#include "Machines/Rep4.hpp" -#include "Machines/Rep.hpp" +#include "Machines/maximal.hpp" template void run(int argc, char** argv); diff --git a/Utils/export-a2b.cpp b/Utils/export-a2b.cpp new file mode 100644 index 000000000..c6f92c6b4 --- /dev/null +++ b/Utils/export-a2b.cpp @@ -0,0 +1,96 @@ +/* + * export-a2b.cpp + * + */ + +// use maximal.hpp if changes cause compilation errors + +#include "Machines/minimal.hpp" + +template> +void run(int, const char**); + +int main(int argc, const char** argv) +{ + if (argc < 2 or argv[2] == string("ring")) + run>(argc, argv); + else if (argv[2] == string("atlas")) + run>>(argc, argv); + else if (argv[2] == string("mascot")) + run>, Share>(argc, argv); + else if (argv[2] == string("cowgear")) + run>>(argc, argv); + else if (argv[2] == string("dealer-ring")) + run>>(argc, argv); + else if (argv[2] == string("hemi")) + run>>(argc, argv); + else if (argv[2] == string("rep4-ring")) + run>(argc, argv); + else if (argv[2] == string("semi2k")) + run>(argc, argv); + else if (argv[2] == string("spdz2k")) + run, Share>(argc, argv); + else if (argv[2] == string("sy-rep-ring")) + run>(argc, argv); + else + { + cerr << "unsupported protocol: " << argv[2] << endl; + exit(1); + } +} + +template +void run(int argc, const char** argv) +{ + assert(argc > 3); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, atoi(argv[3]), "localhost", port_base); + + typedef typename share_type::bit_type bit_share_type; + + Machine machine(N); + Opener MC(machine.get_player(), machine.get_bit_mac_key()); + + int n = 10; + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back( + share_type::constant(i + 1, my_number, + machine.get_sint_mac_key())); + + vector> outputs(n, + vector(1, + bit_share_type::constant(0, my_number, + machine.get_bit_mac_key()))); + + vector args = {{inputs}, {16, outputs}}; + FunctionArgument res; + + machine.run_function("a2b", res, args); + + MC.init_open(); + for (auto& x : outputs) + MC.prepare_open(x.at(0)); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (x != i + 1 and share_type::real_shares(machine.get_player())) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } +} diff --git a/Utils/export-b2a.cpp b/Utils/export-b2a.cpp new file mode 100644 index 000000000..af5423ee8 --- /dev/null +++ b/Utils/export-b2a.cpp @@ -0,0 +1,64 @@ +/* + * export-b2a.cpp + * + */ + +#include "Machines/maximal.hpp" + +int main(int argc, const char** argv) +{ + assert(argc > 1); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + + typedef Rep3Share2<64> share_type; + Machine machine(N); + Opener MC(machine.get_player(), machine.get_sint_mac_key()); + MixedProtocolSet set(machine.get_player(), machine); + + int n = 10; + vector outputs(n); + vector> inputs(n); + + auto& inputter = set.binary.input; + inputter.reset(0); + for (int i = 0; i < n; i++) + if (my_number == 0) + inputter.add_mine(i + 1, 16); + else + inputter.add_other(0); + inputter.exchange(); + for (int i = 0; i < n; i++) + inputs.at(i).push_back(inputter.finalize(0, 16)); + + vector args = {{outputs, true}, {16, inputs}}; + FunctionArgument res; + + machine.run_function("b2a", res, args); + + MC.init_open(); + for (auto& x : outputs) + MC.prepare_open(x); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (x != i + 1) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } +} diff --git a/Utils/export-sort.cpp b/Utils/export-sort.cpp new file mode 100644 index 000000000..15ba8cf58 --- /dev/null +++ b/Utils/export-sort.cpp @@ -0,0 +1,65 @@ +/* + * export-sort.cpp + * + */ + +#include "Machines/maximal.hpp" + +int main(int argc, const char** argv) +{ + assert(argc > 1); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + + typedef Rep3Share2<64> share_type; + Machine machine(N); + + int n = 1000; + + ProtocolSet set(machine.get_player(), machine); + set.input.reset(0); + for (int i = 0; i < n; i++) + { + if (my_number == 0) + set.input.add_mine(n - i); + else + set.input.add_other(0); + } + set.input.exchange(); + + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back(set.input.finalize(0)); + + vector args = {{inputs, true}}; + FunctionArgument res; + + machine.run_function("sort", res, args); + + Opener MC(machine.get_player(), machine.get_sint_mac_key()); + MC.init_open(); + for (auto& x : inputs) + MC.prepare_open(x); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (x != i + 1) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } +} diff --git a/Utils/export-trunc.cpp b/Utils/export-trunc.cpp new file mode 100644 index 000000000..5bbe9a6b6 --- /dev/null +++ b/Utils/export-trunc.cpp @@ -0,0 +1,54 @@ +/* + * export-trunc.cpp + * + */ + +#include "Machines/minimal.hpp" + +int main(int argc, const char** argv) +{ + assert(argc > 1); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + + typedef Rep3Share2<64> share_type; + Machine machine(N); + + int n = 1000; + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back(share_type::constant(i, my_number)); + + vector args = {inputs}; + vector results(n); + FunctionArgument res(results); + + machine.run_function("trunc_pr", res, args); + + Opener MC(machine.get_player(), machine.get_sint_mac_key()); + MC.init_open(); + for (auto& x : results) + MC.prepare_open(x); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (not (x == (i / 4) or x == (i / 4 + 1))) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp index 6eda84e03..cdc84ceb7 100644 --- a/Utils/mixed-example.cpp +++ b/Utils/mixed-example.cpp @@ -3,14 +3,7 @@ * */ -#include "Protocols/ProtocolSet.h" - -#include "Machines/SPDZ.hpp" -#include "Machines/SPDZ2k.hpp" -#include "Machines/Semi2k.hpp" -#include "Machines/Rep.hpp" -#include "Machines/Rep4.hpp" -#include "Machines/Atlas.hpp" +#include "Machines/maximal.hpp" template void run(char** argv); diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index e5346ade6..e92423d83 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -7,15 +7,7 @@ #define NO_MIXED_CIRCUITS -#include "Math/gfp.hpp" -#include "Machines/SPDZ.hpp" -#include "Machines/SPDZ2k.hpp" -#include "Machines/MalRep.hpp" -#include "Machines/ShamirMachine.hpp" -#include "Machines/Semi2k.hpp" -#include "Protocols/CowGearShare.h" -#include "Protocols/CowGearPrep.hpp" -#include "Protocols/ProtocolSet.h" +#include "Machines/maximal.hpp" template void run(char** argv, int prime_length); diff --git a/Utils/prime.cpp b/Utils/prime.cpp index 3d9cd9c01..75f935f90 100644 --- a/Utils/prime.cpp +++ b/Utils/prime.cpp @@ -12,7 +12,7 @@ int main(int argc, char** argv) if (argc > 1) lgp = atoi(argv[1]); if (argc > 2) - cout << generate_prime(lgp, 1 << atoi(argv[2])) << endl; + cout << generate_prime(lgp, 1 << abs(atoi(argv[2])), atoi(argv[2]) <= 0) << endl; else cout << SPDZ_Data_Setup_Primes(lgp) << endl; } diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e4dc7a8e4..c5216cdfe 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -8,11 +8,11 @@ trigger: - to-publish pool: - vmImage: 'ubuntu-20.04' + vmImage: 'ubuntu-24.04' steps: - script: | - bash -c "sudo apt-get update && sudo apt-get install libsodium-dev libntl-dev python3-gmpy2 python3-networkx" + bash -c "sudo apt-get update && sudo apt-get install libboost-all-dev libsodium-dev libntl-dev python3-gmpy2 python3-networkx" - script: | make setup - script: diff --git a/bin/README.md b/bin/README.md index 2e6fdf208..b54cc0f94 100644 --- a/bin/README.md +++ b/bin/README.md @@ -1,5 +1,5 @@ In binary releases, this directory constains statically linked -binaries. They included code from the following projects, whose licenses +binaries. They include code from the following projects, whose licenses are thus provided in separate files: - Boost - glibc @@ -7,6 +7,12 @@ are thus provided in separate files: - GMP - OpenSSl - NTL +- SimpleOT +- SimplestOT_C +- libOTe +- cryptoTools +- simde +- sse2neon The binaries also include code from libstdc++ and libgcc. They have been produced using `Scripts/build.sh` and standard GCC from diff --git a/bin/SimpleOT-license.txt b/bin/SimpleOT-license.txt new file mode 100644 index 000000000..a84c39566 --- /dev/null +++ b/bin/SimpleOT-license.txt @@ -0,0 +1,25 @@ +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + diff --git a/bin/SimplestOT_C-license.txt b/bin/SimplestOT_C-license.txt new file mode 100644 index 000000000..a84c39566 --- /dev/null +++ b/bin/SimplestOT_C-license.txt @@ -0,0 +1,25 @@ +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + diff --git a/bin/cryptoTools-license.txt b/bin/cryptoTools-license.txt new file mode 100644 index 000000000..67e5dc2f2 --- /dev/null +++ b/bin/cryptoTools-license.txt @@ -0,0 +1,51 @@ +Dual-licensed under Unlicense or MIT. + + +----------------------- Unlicense --------------------------- + +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + + +----------------------- MIT --------------------------- +Copyright 2021 Peter Rindal + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, +merge, publish, distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be included in all copies +or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/bin/libOTe-license.txt b/bin/libOTe-license.txt new file mode 100644 index 000000000..eea383745 --- /dev/null +++ b/bin/libOTe-license.txt @@ -0,0 +1,51 @@ +Dual-licensed under Unlicense or MIT. + + +----------------------- Unlicense --------------------------- + +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to + + +----------------------- MIT --------------------------- +Copyright 2021 Peter Rindal + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, +merge, publish, distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be included in all copies +or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/bin/simde-license.txt b/bin/simde-license.txt new file mode 100644 index 000000000..c7f6b6d75 --- /dev/null +++ b/bin/simde-license.txt @@ -0,0 +1,20 @@ +Copyright (c) 2017 Evan Nemerson + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/bin/sse2neon-license.txt b/bin/sse2neon-license.txt new file mode 100644 index 000000000..71488b164 --- /dev/null +++ b/bin/sse2neon-license.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2015-2024 SSE2NEON Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/deps/SimpleOT b/deps/SimpleOT index 84d735226..c8447aecd 160000 --- a/deps/SimpleOT +++ b/deps/SimpleOT @@ -1 +1 @@ -Subproject commit 84d73522619f90ba2aabce8d660baef1442aa26d +Subproject commit c8447aecd26d73c37d96442c6127cbffde689c98 diff --git a/deps/SimplestOT_C b/deps/SimplestOT_C index 033e24280..3afe913fb 160000 --- a/deps/SimplestOT_C +++ b/deps/SimplestOT_C @@ -1 +1 @@ -Subproject commit 033e24280ae32d599c5b037b44edd7c7e8228b35 +Subproject commit 3afe913fbe28f2f9fdebfa6497caf6da87f24d1a diff --git a/deps/libOTe b/deps/libOTe index f613f2216..5dab5b1bb 160000 --- a/deps/libOTe +++ b/deps/libOTe @@ -1 +1 @@ -Subproject commit f613f221650144367e0fddce5ca07fc2dda09e32 +Subproject commit 5dab5b1bb45b2442b6d074ddde9bd6f3de42d5ca diff --git a/deps/sse2neon b/deps/sse2neon new file mode 160000 index 000000000..29716df95 --- /dev/null +++ b/deps/sse2neon @@ -0,0 +1 @@ +Subproject commit 29716df957401e9c357348e9b73c95a38cdcd34f diff --git a/doc/Compiler.rst b/doc/Compiler.rst index e8eb77097..0a7b4f780 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -144,6 +144,15 @@ Compiler.path_oblivious_heap module dprint_ln, dprint_ln_if, dprint_str, indent, outdent, +Compiler.dijkstra module +------------------------ + +.. automodule:: Compiler.dijkstra + :members: + :no-undoc-members: + :exclude-members: IntVectorArray, Matrix, Vector, VectorArray, VectorList + + Compiler.sorting module ----------------------- .. automodule:: Compiler.sorting diff --git a/doc/Doxyfile b/doc/Doxyfile index d816a9727..4bc260902 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h ../FHE/Ciphertext.h ../FHE/FHE_Keys.h ../FHE/FHE_Params.h ../FHE/Plaintext.h ../Tools/random.h ../Math/bigint.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h ../FHE/Ciphertext.h ../FHE/FHE_Keys.h ../FHE/FHE_Params.h ../FHE/Plaintext.h ../Tools/random.h ../Math/bigint.h ../Processor/FunctionArgument.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/compilation.rst b/doc/compilation.rst index 417bc7922..46bd1efae 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -117,7 +117,7 @@ computation: Set the budget for loop unrolling with :py:func:`~Compiler.library.for_range_opt` and similar. This means that loops are unrolled up to *budget* instructions. Default is - 100,000 instructions. + 1000 instructions. .. cmdoption:: -C --CISC diff --git a/doc/conf.py b/doc/conf.py index 2024f08bf..32b476a25 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -25,7 +25,7 @@ # -- Project information ----------------------------------------------------- project = u'MP-SPDZ' -copyright = u'2022, CSIRO\'s Data61' +copyright = u'2024, CSIRO\'s Data61' author = u'Marcel Keller' # The short X.Y version diff --git a/doc/function-export.rst b/doc/function-export.rst new file mode 100644 index 000000000..95c5ab3f1 --- /dev/null +++ b/doc/function-export.rst @@ -0,0 +1,277 @@ +.. _function-export: + +Using High-Level Functionality in C++ +===================================== + +The fact that most functionality is implemented on the high level (in +the compiler) but the virtual machine running high-level code is +somewhat limited makes it desirable to call high-level functionality +from C++. MP-SPDZ supports defining functions on the high level and +calling them from C++. Functions can have integer secrets +(:py:class:`~Compiler.types.sint`) and (multi-)arrays thereof as +inputs and return values as well as (multi-)arrays of binary secrets +(types created using :py:class:`~Compiler.GC.types.sbitvec` and +:py:class:`~Compiler.GC.types.sbitintvec`). + +As a simple example, consider +:download:`../Programs/Source/export-sort.py` and +:download:`../Utils/export-sort.cpp`. The Python part looks as follows:: + + @export + def sort(x): + print_ln('x=%s', x.reveal()) + res = x.sort() + print_ln('res=%s', x.reveal()) + + sort(sint.Array(1000)) + +This makes the sorting of integer arrays of length 1000 accessible to +C++. The corresponding C++ code starts similarly to the :ref:`low-level +code example `: + +.. code-block:: cpp + + #include "Machines/maximal.hpp" + + int main(int argc, const char** argv) + { + assert(argc > 1); + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + +This includes all necessary headers and makes contact with the other +parties on the same machine. The next step is to set up an instance of +the virtual machine: + +.. code-block:: cpp + + typedef Rep3Share2<64> share_type; + Machine machine(N); + +In this example, we use replicated secret sharing modulo +:math:`2^{64}`. Next, we prepare the inputs: + +.. code-block:: cpp + + int n = 1000; + + ProtocolSet set(machine.get_player(), machine); + set.input.reset(0); + for (int i = 0; i < n; i++) + { + if (my_number == 0) + set.input.add_mine(n - i); + else + set.input.add_other(0); + } + set.input.exchange(); + + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back(set.input.finalize(0)); + +This initializes a :cpp:class:`ProtocolSet` using the virtual machine +instead of a :cpp:class:`ProtocolSetup`. This is necessary to avoid +differing MAC keys and other setup variables. Then, party 0 inputs the +numbers 1 to 1000 in reverse other, and the resulting secret shares +are stored in :cpp:var:`inputs`. Now we're ready to call the +function: + +.. code-block:: cpp + + vector args = {{inputs, true}}; + FunctionArgument res; + + machine.run_function("sort", res, args); + +This indicates that the function takes one argument, which is an array +(as opposed to a vector, see below) and that we don't expect a return +value. Lastly, we open and check the array: + +.. code-block:: cpp + + Opener MC(machine.get_player(), machine.get_sint_mac_key()); + MC.init_open(); + for (auto& x : inputs) + MC.prepare_open(x); + MC.exchange(); + + if (my_number == 0) + { + cout << "res: "; + for (int i = 0; i < 10; i++) + cout << MC.finalize_open() << " "; + cout << endl; + } + else + { + for (int i = 0; i < n; i++) + { + auto x = MC.finalize_open(); + if (x != i + 1) + { + cerr << "error at " << i << ": " << x << endl; + exit(1); + } + } + } + +The :cpp:class:`Opener` class is convenience that is bound to a +communication instance (unlike :cpp:class:`MAC_Check_Base` instances, +which require the communication instance in several function calls). + +You can run the example as follows: + +.. code-block:: console + + ./compile.py -E ring export-sort + make export-sort.x + for i in 0 1 2; do ./export-sort.x $i & true; done + +This makes sure that all the optimizations of the protocol are used. + + +Vector arguments and return values +---------------------------------- + +Instead of arrays, it is also possible to use +:py:class:`~Compiler.types.sint` vectors as demonstrated in +:download:`../Programs/Source/export-trunc.py`:: + + @export + def trunc_pr(x): + print_ln('x=%s', x.reveal()) + res = x.round(32, 2) + print_ln('res=%s', res.reveal()) + return res + + trunc_pr(sint(0, size=1000)) + +The calling C++ code in :download:`../Utils/export-trunc.cpp` looks as +follows: + +.. code-block:: cpp + + int n = 1000; + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back(share_type::constant(i, my_number)); + + vector args = {inputs}; + vector results(n); + FunctionArgument res(results); + + machine.run_function("trunc_pr", res, args); + +This creates integer shares using public constants instead of the +input protocol as above. The :cpp:class:`FunctionArgument` instance +for both input and output are created using the vector of secret +shares without the extra ``true`` argument. + + +Binary values +------------- + +It is possible to input and output binary secrets with an +array. Consider :download:`../Programs/Source/export-b2a.py`, which +converts arithmetic to binary shares:: + + @export + def b2a(res, x): + print_ln('x=%s', x.reveal()) + res[:] = sint(x[:]) + print_ln('res=%s', x.reveal()) + + b2a(sint.Array(size=10), sbitvec.get_type(16).Array(10)) + +This demonstrates the requirement of using an array of an +:py:class:`sbitvec` type with a defined number of bits (16 in this +case). :py:class:`sbitintvec` is a sub-class and also permissible. + +The C++ calling code looks as follows: + +.. code-block:: cpp + + int n = 10; + vector outputs(n); + vector> inputs(n); + + auto& inputter = set.binary.input; + inputter.reset(0); + for (int i = 0; i < n; i++) + if (my_number == 0) + inputter.add_mine(i + 1, 16); + else + inputter.add_other(0); + inputter.exchange(); + for (int i = 0; i < n; i++) + inputs.at(i).push_back(inputter.finalize(0, 16)); + + vector args = {{outputs, true}, {16, inputs}}; + FunctionArgument res; + + machine.run_function("b2a", res, args); + +This inputs the values 1 to 10 as 16-bit numbers. Note the nested +vectors for the inputs. This is due to the fact +``share_type::bit_type`` can only hold up to 64 bits, so for longer +bit lengths several entries have to be used. + +Lastly, :download:`../Programs/Source/export-a2b.py` covers the other +direction:: + + @export + def a2b(x, res): + print_ln('x=%s', x.reveal()) + res[:] = sbitvec(x, length=16) + print_ln('res=%s', x.reveal()) + + a2b(sint(size=10), sbitvec.get_type(16).Array(10)) + +The calling C++ code in :download:`../Utils/export-a2b.cpp` has to +initialize the binary shares even when they are only used for output: + +.. code-block:: + + int n = 10; + vector inputs; + for (int i = 0; i < n; i++) + inputs.push_back( + share_type::constant(i + 1, my_number, + machine.get_sint_mac_key())); + + vector> outputs(n, + vector(1, + bit_share_type::constant(0, my_number, + machine.get_bit_mac_key()))); + + vector args = {{inputs}, {16, outputs}}; + FunctionArgument res; + + machine.run_function("a2b", res, args); + + +C++ compilation +--------------- + +The easiest way is to include ``Machines/maximal.hpp`` as in the first +example and put the C++ in code ``Utils/.cpp`` and calling +``make .x`` in the main directory. If using oblivious transfer +or homomorphic encryption, add the following line to ``Makefile``:: + + .x: $(FHEOFFLINE) $(OT) + +Most of the examples work slightly differently, however, in order to +distribute the compilation load. Most notably, +:download:`../Utils/export-a2b.cpp`, which supports several protocols, +only includes ``Machines/minimal.hpp`` and "outsources" the virtual +machine for the various protocols to ``Machines/export-*.cpp``, which +are all compiled separately. + + +Reference +--------- + +.. doxygenclass:: FunctionArgument + :members: diff --git a/doc/index.rst b/doc/index.rst index 2cecbe0da..4e48385f8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -26,6 +26,8 @@ If you're new to MP-SPDZ, consider the following: optimization instructions low-level + function-export + navigating-c++ ml-quickstart machine-learning networking diff --git a/doc/journey.rst b/doc/journey.rst index b6b7521a1..c418d5d4f 100644 --- a/doc/journey.rst +++ b/doc/journey.rst @@ -126,10 +126,12 @@ looks as follows:: 000000c8 It consist of the instructions codes and the arguments in big-endian -order. For example, 0x2 is the code for :py:obj:`lsdi`, 0xa5 is the -code for :py:obj:`asm_open`, 0xb3 is the code for -:py:obj:`print_reg_plain`, etc. You can also spot repeated occurrences -of ``1f ff``, which is the hexadecimal representation of 8191. +order. For example, 0x2 is the code for +:py:class:`~Compiler.instructions.ldsi`, 0xa5 is the code for +:py:obj:`~Compiler.instructions.asm_open`, 0xb3 is the code for +:py:class:`~Compiler.instructions.print_reg_plain`, etc. You can also +spot repeated occurrences of ``1f ff``, which is the hexadecimal +representation of 8191. Finally, the compilation creates :file:`Programs/Schedules/journey.sch`, which is a text file:: @@ -166,7 +168,7 @@ with Rep3 modulo :math:`2^{64}`. The main function in <../Processor/Machine.hpp>` with :class:`sint` being ``Rep3Share2<64>``. Then, the following happens: -1. :file:`Programs/Schedules/journey.sch` is parsed :func:`load_schedule`. +1. :file:`Programs/Schedules/journey.sch` is parsed in :func:`load_schedule`. 2. :file:`Programs/Bytecode/journey-0.bc` is parsed in :func:`Machine::load_program` where :func:`Program::parse`. This creates an internal representation of the diff --git a/doc/navigating-c++.rst b/doc/navigating-c++.rst new file mode 100644 index 000000000..cd4b29e19 --- /dev/null +++ b/doc/navigating-c++.rst @@ -0,0 +1,259 @@ +.. default-domain:: cpp + + +Navigating the C++ Code +======================= + +In this section, we will explain how the most important aspects of the +C++ codebase fit together and explain them with brief examples. + +MP-SPDZ heavily relies on templates (also called generic +programming). This is to achieve high efficiency while retaining +modularity as many MPC building blocks work in a number of contexts. A +notable example of this is `Beaver multiplication +`_. The +same code is used in more than 20 contexts without loss of efficiency. +See `this introduction +`_ if +you're new to C++ templates. + +Due to the size of the codebase (more than 100,000 lines of code in +hundreds of files), we recommend using an integrated development +environment (IDE) to navigate it. `Eclipse +`_ has very useful features such as +jumping to the definition of function or variable using F3 or find all +references to a function or variable with Strg-Shift-g. The latter +doesn't work well with templates, however, so global text search is +necessary for a more comprehensive view. + + +Mathematical Domains +-------------------- + +All protocols in MP-SPDZ are based on finite domains for secret +sharing, most importantly integers modulo a number. While basic CPU +arithmetic and existing libraries provide all required functionality +(e.g., modern CPUs generally work with integers modulo :math:`2^{64}` +and the `GNU Multiple Precision Arithmetic Library +`_ has provisions for a plethora of integer +operations), we have found these to be inefficient in the context of +MPC. The main reason is that an MPC protocol uses the same modulus +throughout but the variable-length nature of GMP incurs a considerable +cost for every operation, which is not necessary. MP-SPDZ therefore +provides a tailored data type for every mathematical domain. These +data types use operator overloading for easy use. For example, + +.. code-block:: + + cout << (Z2<5>(20) + Z2<5>(30)) << endl; + +should output:: + + 18 + +because ``Z2<5>`` represents computation modulo :math:`2^5=32`. See +the reference of :cpp:class:`Z2` and :cpp:class:`SignedZ2` for further +details. + +For computation modulo a prime on the other hand, the data type fixes +only a range at compile time, so the exact modulus has be given before +usage:: + + gfp_<0; 1>::init_field(13); + cout << (gfp_<0, 1>(8) + gfp_<0, 1>(7)) << endl; + +should output:: + + 2 + +The first parameter to :cpp:class:`gfp_` is a counter that allows +several moduli to be used at once, for example:: + + gfp_<0, 1>::init_field(13); + gfp_<1, 1>::init_field(17); + +The second parameter denotes the number of 64-bit limbs, that is, it +should be 1 for primes in :math:`[0,2^{64}]`, 2 for prime in +:math:`[2^{64},2^{128}]` etc. + +In addition to the fixed domains, :cpp:class:`bigint` is a sub-class +of :cpp:class:`mpz_class` `type in GMP +`_. It used +for conversions amongst other things. + + +Communication +------------- + +MP-SPDZ provides a communication interface that is more involved than +sending bytes via a socket. There are two reasons for doing +this. First, the structure of MPC goes far beyond the query-reply +pattern often found in online communication. For example, two parties +might need to exchange a large quantity of information +simultaneously. Second, the atomic quantity communicated in MPC (i.e., +the numbers) are usually so small that it is preferential to send them +in batches. The following example demonstrates the exchange of a +vector of 64-bit numbers in the two-party setting:: + + Player& P = ...; + vector> numbers; + // populate vector + ... + octetStream os; + os.store(numbers); + P.pass_around(1, os); + os.get(numbers); + // numbers now contains the ones from the other side + +No matter how many numbers there are, the framework makes sure to send +and receive them at the same time. The number given to +:cpp:func:`Player::pass_around` denotes an offset, that is, the +numbers are sent to "next" party and received from the "previous" one +(regarding player number with wrap-around). See :ref:`this section +` for more details. + + +Randomness +---------- + +Randomness is a crucial component of MPC (as for cryptography in +general). Random number generation in MP-SPDZ centers on the +:cpp:class:`PRNG` class. It implements optimized random number +generation based on hardware AES if available. This allows for local +as well as coordinated randomness generation. An exampled for the +first is as follows:: + + SeededPRNG G; + auto res = G.get>(); + +This initializes the PRNG with secure randomness from libsodium and +then generates a random 64-bit element. + +On the other hand, the following initializes a global PRNG securely, +that is, with a seed that cannot be influenced by any party, before +generating a random element modulo a prime:: + + // initialize at some point + gfp_<0, 1>::init_field(prime); + Player& P = ...; + ... + GlobalPRNG G(P); + auto res = G.get>(); + + +Protocols +--------- + +The implementation of protocols is centered on the share types. They +not only hold all values necessary to represent a secret value for one +party, they also provide local operations, refer to other classes +implementing protocols, and contain variables and static functions to +describes protocols. + +As an example, consider :cpp:class:`Rep3Share\` in +:download:`../Protocols/Rep3Share.h`. It implements a share for +three-party replicated secret sharing. It takes one template parameter +for the mathematical domain because the secret sharing and the +multiplication protocol work for any finite domain. The following +typedef makes the cleartext domain generally accessible:: + + typedef T clear; + +Further typedefs are used to indicate which class to use for inputs, +multiplications, and outputs:: + + typedef ReplicatedInput Input; + typedef Replicated Protocol; + typedef ReplicatedMC MAC_Check; + +The latter usually contains the name MAC_Check or MC because MAC +checking is a core function of the output protocol in SPDZ. + +These typedefs follow the general pattern that the *share type* is a +template argument to the *protocol type*. This makes everything +contained defined by the share type accessible to the protocol +type. As an example of this, :cpp:class:`ReplicatedMC\` is +a sub-class of :cpp:class:`MAC_Check_Base\`, which +implements the general interface for opening shares. On the functions +there is defined as follows:: + + virtual typename T::clear finalize_open(); + +Another important typedef in :cpp:class:`Rep3Share` is the +preprocessing type:: + + typedef typename conditional, SemiRep3Prep>::type LivePrep; + +It is more complicated because it uses meta-programming to assign +different types depending on whether mathematical domain has +characteristic two (i.e., it's :math:`\mathrm{GF}(2^n)`). This is to +avoid compiling code for a specific daBit generation that doesn't make +sense in said domain. The preprocessing classes use polymorphism to +mix and match the possible protocols. For example, +:cpp:func:`BitPrep::buffer_squares` to implements a generic +protocol to generate square tuples from multiplication triples, but +this isn't the most efficient way with replicated secret sharing, +which is why :cpp:func:`ReplicatedRingPrep::buffer_squares` +overrides this with a more specific protocol in our example. + +The four protocol types above are contained in an instance +:cpp:class:`ProtocolSet\` as documented in :ref:`low-level` where +:py:class:`T` is a share type. + +:cpp:class:`Rep3Share\` is a sub-class of +:cpp:class:`FixedVec\`. The latter contains a pair of values in the +cleartext domain as one would expect with this kind of secret sharing, +and it implements element-wise addition, subtraction, and +multiplication via operator overloading, which makes it +straight-forward to run local operations with share types. + +Lastly, :cpp:class:`Rep3Share` defines a few variables that describe +the protocols, for example:: + + const static bool dishonest_majority = false; + const static bool variable_players = false; + +These indicate that replicated secret sharing requires and honest +majority and fixed number of players. First is used to the set default +number of parties and the second to decide whether to offer the +``--nparties`` command-line option. + + +Virtual Machines +---------------- + +The main function for the protocol-specific virtual machines is +defined in the file of the appropriate name in the ``Machines`` +directory. For example, the virtual machine for three-party replicated +secret sharing over prime fields is defined in +:download:`../Machines/replicated-field-party.cpp`, and the main function +looks as follows:: + + int main(int argc, const char** argv) + { + HonestMajorityFieldMachine(argc, argv); + } + +Indirectly, this calls an instance of :cpp:class:`Machine\` where :cpp:class:`sint` and :cpp:class:`sgf2n` denote the +complete share type for integer and :math:`\mathrm{GF}(2^n)`, +respectively. The defaults are :cpp:class:`Rep3Share\>` and +:cpp:class:`Rep3Share\` in the example. To choose that, +the constructor of :cpp:class:`FieldMachine` (in +:download:`../Processor/FieldMachine.hpp`) contains code to the length +for :cpp:class:`gfp_` (the second parameter, the first is always +0). For protocols modulo a power of two other than SPDZ2k, this +happens in the constructor of :cpp:class:`RingMachine` or +:cpp:class:`HonestMajorityRingMachineWithSecurity` in +:download:`../Processor/RingMachine.hpp`. The purpose of all this is +to fix the mathematical domains throughout for maximum performance. + +The includes are structured in a way that all relevant templated code +is included in these files, so compiling it makes sure that the object +file contains most protocol-specific code. The main exceptions from +this are code related to homomorphic encryption (in ``libFHE.so``), +oblivious transfer (included via object files), and Tinier (in +``Machines/Tinier.o``). Furthermore, all general code is put in +``libSPDZ.so``. All this is to reduce the compilation time and/or the +binary size. diff --git a/doc/non-linear.rst b/doc/non-linear.rst index ec2a53c23..2f4762a93 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -26,7 +26,18 @@ Unknown prime modulus of the cleartext range. If you want to use this approach with a given prime, do *not* - specify the prime during compilation but during execution. + specify the prime during compilation but during execution, that + is:: + + -party.x -P ... + + or:: + + Scripts/.sh -P ... + + If using ``Scripts/compile-run.py``, put it after a double dash:: + + Scripts/compile-run.py -- -P Known prime modulus `Damgård et al. `_ have diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index b9f918615..8ff630db3 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -250,6 +250,19 @@ to `Catrina and de Hoogh See also the paragraph on unknown prime moduli in :ref:`nonlinear`. +Prime number not compatible with encryption scheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MP-SPDZ only supports homomorphic encryption based on the +number-theoretic transform, without it operations would expected to be +considerably. The requirement is that the prime number equals one +modulo a certain power of two. The exact power of two varies due to a +number of parameters, but for the standard choice it's usually +:math:`2^{14}` or :math:`2^{15}`. See `Gentry et +al. `_ for more details on the +underlying mathematics. + + Windows/VirtualBox performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -282,3 +295,59 @@ This is a catch-all failure in protocols with malicious protocols that can be caused by something being wrong at any level. Please file a bug report with the specifics of your case. + +Debugging errors in a virtual machine +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Unlike Python or Java, C++ gives limited information when something +goes wrong. On Linux, the `GNU Debugger (GDB) +`_ aims to mitigate this +by providing more introspection into where exactly something went +wrong. MP-SPDZ comes with a few scripts that facilitate its +use. First, you need to make sure gdb and `screen +`_ are installed. On Ubuntu, +you can run the following:: + + sudo apt-get install gdb screen + +You can then run the following script call:: + + prefix=gdb_screen Scripts/.sh ... -o throw_exceptions + +This runs every party in the background using the screen utility. You +can get a party to the foreground using:: + + screen -r : + +This will show the relevant running inside GDB. You can use the +sequence "Ctrl-a d" to return to your usual terminal. + +If running the different parties separately, you can also use:: + + . Scripts/run-common.sh + gdb_front ./-party.x ... -o throw_exceptions + +If the virtual machine aborts due to an error, GDB will indicate where +in the code this happened. For example, deactivating all range checks +on memory accesses and then running an illegal memory access triggers +a segfault and the following output:: + + Thread 13 "shamir-party.x" received signal SIGSEGV, Segmentation fault. + [Switching to Thread 0x7fffdffff640 (LWP 246396)] + 0x0000000000434c57 in MemoryPart > >::indirect_read > (this=, inst=..., regs=..., indices=...) at ./Processor/Memory.hpp:26 + 26 *dest++ = data[it->get()]; + +Entering ``bt`` (for backtrace) gives even more information as to +where the error happened:: + + (gdb) bt + #0 0x0000000000434c57 in MemoryPart > >::indirect_read > (this=, inst=..., regs=..., indices=...) at ./Processor/Memory.hpp:26 + #1 Program::execute >, ShamirShare > (this=0x620cc0, Proc=...) at ./Processor/Instruction.hpp:1486 + #2 0x0000000000428fd1 in thread_info >, ShamirShare >::Sub_Main_Func (this=, this@entry=0x656900) at ./Processor/Online-Thread.hpp:280 + #3 0x0000000000426e45 in thread_info >, ShamirShare >::Main_Func_With_Purge (this=0x656900) at ./Processor/Online-Thread.hpp:431 + #4 thread_info >, ShamirShare >::Main_Func (ptr=0x656900) at ./Processor/Online-Thread.hpp:410 + #5 0x00007ffff6bbaac3 in start_thread (arg=) at ./nptl/pthread_create.c:442 + #6 0x00007ffff6c4c850 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81 + +This information can be very useful to find the error and fix bugs, so +make sure to include it in GitHub issues etc. diff --git a/doc/utils.rst b/doc/utils.rst index 019bf2889..08d4bdcdc 100644 --- a/doc/utils.rst +++ b/doc/utils.rst @@ -10,6 +10,17 @@ of the minimum RAM usage per party. The range is relatively large due to fact the bytecode is independent of the secret sharing. +Preprocessing usage +------------------- + +``Scripts/prep-usage.py `` gives you an upper limit +for the usage of preprocessing data such as triples. Note that the +exact number depends on the protocol in various ways, and that the +usage is sometimes unpredictable resulting in ``inf`` given. For an +exact number, you have to run the virtual machine or script using +the ``--verbose`` argument. + + Human-readable bytecode/circuit representation ----------------------------------------------