Skip to content

Commit

Permalink
Optimized matrix multiplication in Hemi.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Sep 17, 2021
1 parent 5c6f101 commit 799929b
Show file tree
Hide file tree
Showing 151 changed files with 5,261 additions and 747 deletions.
5 changes: 2 additions & 3 deletions BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
auto& MC = this->MC;

this->_id = online_opts.playerno + 1;
Server* server = Server::start_networking(N, online_opts.playerno, nparties,
Server::start_networking(N, online_opts.playerno, nparties,
network_opts.hostname, network_opts.portnum_base);
if (T::dishonest_majority)
P = new PlainPlayer(N, 0);
Expand Down Expand Up @@ -159,8 +159,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();

if (server)
delete server;
this->machine.write_memory(this->N.my_num());
}

template<class T>
Expand Down
5 changes: 5 additions & 0 deletions BMR/network/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ static void throw_bad_ip(const char* ip) {
throw std::invalid_argument( "bad ip" );
}

namespace BIU
{

Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable, unsigned int max_message_size)
:_max_msg_sz(max_message_size),
_numservers(numservers),
Expand Down Expand Up @@ -205,3 +208,5 @@ void Client::_send_blocking(SendBuffer& msg, int id) {
fflush(0);
#endif
}

}
4 changes: 4 additions & 0 deletions BMR/network/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class ClientUpdatable {



namespace BIU
{

class Client {
public:
Expand Down Expand Up @@ -61,4 +63,6 @@ class Client {
boost::thread_group threads;
};

}

#endif /* NETWORK_INC_CLIENT_H_ */
2 changes: 1 addition & 1 deletion BMR/network/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num
_ready_nodes = new bool[_numparties](); //initialized to false
_clients_connected = new bool[_numparties]();
_server = new BIU::Server(_port, _numparties-1, this, max_message_size);
_client = new Client(_endpoints, _numparties-1, this, max_message_size);
_client = new BIU::Client(_endpoints, _numparties-1, this, max_message_size);
}

Node::~Node() {
Expand Down
2 changes: 1 addition & 1 deletion BMR/network/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Node : public ServerUpdatable, public ClientUpdatable {
int _numparties;

endpoint_t* _endpoints;
Client* _client;
BIU::Client* _client;
BIU::Server* _server;
bool* _ready_nodes;
volatile bool _connected_to_servers;
Expand Down
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.7 (Sep 17, 2021)

- Optimized matrix multiplication in Hemi
- Improved client communication
- Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`
- Compiler option to translate some Python control flow instructions
to run-time instructions
- Functionality to break out of run-time loops
- Run-time range check of data structure accesses
- Improved documentation of network infrastructure

## 0.2.6 (Aug 6, 2021)

- [ATLAS](https://eprint.iacr.org/2021/833)
Expand Down
13 changes: 12 additions & 1 deletion Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ClearBitsAF(base.RegisterArgFormat):
CONVCBIT2S = 0x249,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
CONVCINT = 0x213,
REVEAL = 0x214,
STMSDCI = 0x215,
Expand Down Expand Up @@ -190,6 +191,16 @@ class nots(BinaryVectorInstruction):
code = opcodes['NOTS']
arg_format = ['int','sbw','sb']

class notcb(BinaryVectorInstruction):
""" Bitwise NOT of secret register vector.
:param: number of bits
:param: result (cbit)
:param: operand (cbit)
"""
code = opcodes['NOTCB']
arg_format = ['int','cbw','cb']

class addcb(NonVectorInstruction):
""" Integer addition two single clear bit registers.
Expand Down Expand Up @@ -617,4 +628,4 @@ class cond_print_strb(base.IOInstruction):
arg_format = ['cb', 'int']

def __init__(self, cond, val):
super(cond_print_str, self).__init__(cond, self.str_to_int(val))
super(cond_print_strb, self).__init__(cond, self.str_to_int(val))
58 changes: 38 additions & 20 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,33 @@ def load_other(self, other):
(str(other), repr(other), type(other), type(self)))
def long_one(self):
return 2**self.n - 1 if self.n != None else None
def is_long_one(self, other):
return util.is_all_ones(other, self.n) or \
(other is None and self.n == None)
def res_type(self, other):
if self.n == None and other.n == None:
n = None
else:
n = max(self.n, other.n)
return self.get_type(n)
@read_mem_value
def __and__(self, other):
if util.is_zero(other):
return 0
elif self.is_long_one(other):
return self
else:
return self._and(other)
@read_mem_value
def __xor__(self, other):
if util.is_zero(other):
return self
elif self.is_long_one(other):
return ~self
else:
return self._xor(other)
__rand__ = __and__
__rxor__ = __xor__
def __repr__(self):
if self.n != None:
suffix = '%d' % self.n
Expand Down Expand Up @@ -245,19 +272,20 @@ def clear_op(self, other, c_inst, ci_inst, op):
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
__sub__ = lambda self, other: \
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
def __xor__(self, other):
def _xor(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
elif isinstance(other, cbits):
res = cbits.get_type(max(self.n, other.n))()
res = self.res_type(other)()
assert res.size == self.size
assert res.size == other.size
inst.xorcb(res.n, res, self, other)
return res
else:
return self.clear_op(other, None, inst.xorcbi, operator.xor)
def _and(self, other):
return NotImplemented
__radd__ = __add__
__rxor__ = __xor__
def __mul__(self, other):
if isinstance(other, cbits):
return NotImplemented
Expand All @@ -278,7 +306,9 @@ def __lshift__(self, other):
inst.shlcbi(res, self, other)
return res
def __invert__(self):
return self ^ self.long_one()
res = type(self)()
inst.notcb(self.n, res, self)
return res
def print_reg(self, desc=''):
inst.print_regb(self, desc)
def print_reg_plain(self):
Expand All @@ -287,6 +317,8 @@ def print_reg_plain(self):
def print_if(self, string):
inst.cond_print_strb(self, string)
def output_if(self, cond):
if Program.prog.options.binary:
raise CompilerError('conditional output not supported')
cint(self).output_if(cond)
def reveal(self):
return self
Expand Down Expand Up @@ -423,8 +455,7 @@ def __add__(self, other):
__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__xor__ = __add__
__rxor__ = __add__
_xor = __add__
@read_mem_value
def __mul__(self, other):
if isinstance(other, int):
Expand All @@ -440,13 +471,7 @@ def __mul__(self, other):
except AttributeError:
return NotImplemented
__rmul__ = __mul__
@read_mem_value
def __and__(self, other):
if util.is_zero(other):
return 0
elif util.is_all_ones(other, self.n) or \
(other is None and self.n == None):
return self
def _and(self, other):
res = self.new(n=self.n)
if not isinstance(other, sbits):
other = cbits.get_type(self.n).conv(other)
Expand All @@ -456,7 +481,6 @@ def __and__(self, other):
assert(self.n == other.n)
inst.ands(self.n, res, self, other)
return res
__rand__ = __and__
def xor_int(self, other):
if other == 0:
return self
Expand Down Expand Up @@ -551,12 +575,6 @@ def bit_adder(*args, **kwargs):
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
def to_sint(self, n_bits):
""" Convert the :py:obj:`n_bits` least significant bits to
:py:obj:`~Compiler.types.sint`. """
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
bits = sint(bits, size=n_bits)
return sint.bit_compose(bits)

class sbitvec(_vec):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
Expand Down
2 changes: 1 addition & 1 deletion Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def run(self, instructions):
elif isinstance(inst, IndirectMemoryInstruction):
if inst.args[1] in self.cache:
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
elif isinstance(inst, convint_class):
elif type(inst) == convint_class:
if inst.args[1] in self.cache:
res = self.cache[inst.args[1]]
self.cache[inst.args[0]] = res
Expand Down
56 changes: 54 additions & 2 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .GC import types as GC_types

import sys
import re, tempfile, os


def run(args, options):
Expand All @@ -21,11 +22,62 @@ def run(args, options):
del VARS[i]

print('Compiling file', prog.infile)

f = open(prog.infile, 'rb')

changed = False
if options.flow_optimization:
output = []
if_stack = []
for line in open(prog.infile):
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
line)
if m:
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
m.group(3)))
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
changed = True
continue
m = re.match('(\s*)if(\W.*):', line)
if m:
if_stack.append((m.group(1), len(output)))
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
output.append('%sdef _():\n' % (m.group(1)))
changed = True
continue
m = re.match('(\s*)elif\s+', line)
if m:
raise CompilerError('elif not supported')
if if_stack:
m = re.match('%selse:' % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
output[start])
output.append('%s@else_\n' % ws)
output.append('%sdef _():\n' % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile('w+', delete=False)
for line in output:
infile.write(line)
infile.seek(0)
else:
infile = open(prog.infile)
else:
infile = open(prog.infile)

# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
exec(compile(infile.read(), infile.name, 'exec'), VARS)

if changed and not options.debug:
os.unlink(infile.name)

prog.finalize()

Expand Down
4 changes: 2 additions & 2 deletions Compiler/floatingpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
Expand Down Expand Up @@ -642,7 +642,7 @@ def BitDecFull(a, maybe_mixed=False):
b, bbits = sint.get_edabit(logp, True, size=a.size)
if logp != bit_length:
from .GC.types import sbits
bbits += [sbits.get_type(a.size)(0)]
bbits += [0]
else:
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
b = sint.bit_compose(bbits)
Expand Down
Loading

0 comments on commit 799929b

Please sign in to comment.