From 6f78267ef68607fa5a2f8f26a1789b0c9dc3f2c6 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Fri, 9 Aug 2024 14:51:09 -0700 Subject: [PATCH] Remove n^2 algorithm from signature/key aggregation CountEnabled and IndexOfNthEnabled are both O(n) in the size of the mask, making this loop n^2. The BLS operations still tend to be the slow part, but the n^2 factor will start to show up with thousands of keys. --- sign/bdn/bdn.go | 45 ++++++++++++++++++++++++++------------------- sign/mask.go | 11 +++++++++++ sign/mask_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 20 deletions(-) diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index b7716f6ce..d722bcd13 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -12,6 +12,7 @@ package bdn import ( "crypto/cipher" "errors" + "fmt" "math/big" "go.dedis.ch/kyber/v4" @@ -129,31 +130,36 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - if len(sigs) != mask.CountEnabled() { - return nil, errors.New("length of signatures and public keys must match") - } - - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.sigGroup.Point() - for i, buf := range sigs { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { - // this should never happen as we check the lenths at the beginning - // an error here is probably a bug in the mask - return nil, errors.New("couldn't find the index") + for i := range publics { + if enabled, err := mask.GetBit(i); err != nil { + // this should never happen because of the loop boundary + // an error here is probably a bug in the mask implementation + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } + if len(sigs) == 0 { + return nil, errors.New("length of signatures and public keys must match") + } + + buf := sigs[0] + sigs = sigs[1:] + sig := scheme.sigGroup.Point() err = sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(coefs[peerIndex], sig) + sigC := sig.Clone().Mul(coefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -166,22 +172,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - coefs, err := hashPointToR(mask.Publics()) + publics := mask.Publics() + coefs, err := hashPointToR(publics) if err != nil { return nil, err } agg := scheme.keyGroup.Point() - for i := 0; i < mask.CountEnabled(); i++ { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { + for i, pub := range publics { + if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation - return nil, errors.New("couldn't find the index") + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } - pub := mask.Publics()[peerIndex] - pubC := pub.Clone().Mul(coefs[peerIndex], pub) + pubC := pub.Clone().Mul(coefs[i], pub) pubC = pubC.Add(pubC, pub) agg = agg.Add(agg, pubC) } diff --git a/sign/mask.go b/sign/mask.go index 7c68c9132..d02692c32 100644 --- a/sign/mask.go +++ b/sign/mask.go @@ -58,6 +58,17 @@ func (m *Mask) SetMask(mask []byte) error { return nil } +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + return m.mask[byteIndex]&mask != 0, nil +} + // SetBit turns on or off the bit at the given index. func (m *Mask) SetBit(i int, enable bool) error { if i >= len(m.publics) || i < 0 { diff --git a/sign/mask_test.go b/sign/mask_test.go index 7a3eeb118..5eb911755 100644 --- a/sign/mask_test.go +++ b/sign/mask_test.go @@ -49,22 +49,63 @@ func TestMask_SetBit(t *testing.T) { mask, err := NewMask(publics, publics[2]) require.NoError(t, err) + // Make sure the mask is initially as we'd expect. + + bit, err := mask.GetBit(1) + require.NoError(t, err) + require.False(t, bit) + + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) - // Set it again, nothing should change. + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 again, nothing should change + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Unset bit 2 + err = mask.SetBit(2, false) require.NoError(t, err) require.Equal(t, uint8(0x2), mask.Mask()[0]) require.Equal(t, 1, len(mask.Participants())) + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.False(t, bit) + + // Set bit 10 (using byte 2 now) + + err = mask.SetBit(10, true) + require.NoError(t, err) + require.Equal(t, uint8(0x2), mask.Mask()[0]) + require.Equal(t, uint8(0x4), mask.Mask()[1]) + require.Equal(t, 2, len(mask.Participants())) + + bit, err = mask.GetBit(10) + require.NoError(t, err) + require.True(t, bit) + + // And make sure the range limit works. + err = mask.SetBit(-1, true) require.Error(t, err) err = mask.SetBit(len(publics), true)