Skip to content

Commit

Permalink
Remove n^2 algorithm from signature/key aggregation
Browse files Browse the repository at this point in the history
CountEnabled and IndexOfNthEnabled are both O(n) in the size of the
mask, making this loop n^2. The BLS operations still tend to be the slow
part, but the n^2 factor will start to show up with thousands of keys.
  • Loading branch information
Stebalien committed Sep 5, 2024
1 parent 7738a1f commit 6f78267
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 20 deletions.
45 changes: 26 additions & 19 deletions sign/bdn/bdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package bdn
import (
"crypto/cipher"
"errors"
"fmt"
"math/big"

"go.dedis.ch/kyber/v4"
Expand Down Expand Up @@ -129,31 +130,36 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {
// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
if len(sigs) != mask.CountEnabled() {
return nil, errors.New("length of signatures and public keys must match")
}

coefs, err := hashPointToR(mask.Publics())
publics := mask.Publics()
coefs, err := hashPointToR(publics)
if err != nil {
return nil, err
}

agg := scheme.sigGroup.Point()
for i, buf := range sigs {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
// this should never happen as we check the lenths at the beginning
// an error here is probably a bug in the mask
return nil, errors.New("couldn't find the index")
for i := range publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

if len(sigs) == 0 {
return nil, errors.New("length of signatures and public keys must match")
}

buf := sigs[0]
sigs = sigs[1:]

sig := scheme.sigGroup.Point()
err = sig.UnmarshalBinary(buf)
if err != nil {
return nil, err
}

sigC := sig.Clone().Mul(coefs[peerIndex], sig)
sigC := sig.Clone().Mul(coefs[i], sig)
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
sigC = sigC.Add(sigC, sig)
agg = agg.Add(agg, sigC)
Expand All @@ -166,22 +172,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber
// AggregateSignatures for signatures) using the hash function
// H: keyGroup -> R with R = {1, ..., 2^128}.
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
coefs, err := hashPointToR(mask.Publics())
publics := mask.Publics()
coefs, err := hashPointToR(publics)
if err != nil {
return nil, err
}

agg := scheme.keyGroup.Point()
for i := 0; i < mask.CountEnabled(); i++ {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
for i, pub := range publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, errors.New("couldn't find the index")
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

pub := mask.Publics()[peerIndex]
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
pubC := pub.Clone().Mul(coefs[i], pub)
pubC = pubC.Add(pubC, pub)
agg = agg.Add(agg, pubC)
}
Expand Down
11 changes: 11 additions & 0 deletions sign/mask.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ func (m *Mask) SetMask(mask []byte) error {
return nil
}

// GetBit returns true if the given bit is set.
func (m *Mask) GetBit(i int) (bool, error) {
if i >= len(m.publics) || i < 0 {
return false, errors.New("index out of range")
}

byteIndex := i / 8
mask := byte(1) << uint(i&7)
return m.mask[byteIndex]&mask != 0, nil
}

// SetBit turns on or off the bit at the given index.
func (m *Mask) SetBit(i int, enable bool) error {
if i >= len(m.publics) || i < 0 {
Expand Down
43 changes: 42 additions & 1 deletion sign/mask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,63 @@ func TestMask_SetBit(t *testing.T) {
mask, err := NewMask(publics, publics[2])
require.NoError(t, err)

// Make sure the mask is initially as we'd expect.

bit, err := mask.GetBit(1)
require.NoError(t, err)
require.False(t, bit)

bit, err = mask.GetBit(2)
require.NoError(t, err)
require.True(t, bit)

// Set bit 1

err = mask.SetBit(1, true)
require.NoError(t, err)
require.Equal(t, uint8(0x6), mask.Mask()[0])
require.Equal(t, 2, len(mask.Participants()))

// Set it again, nothing should change.
bit, err = mask.GetBit(1)
require.NoError(t, err)
require.True(t, bit)

// Set bit 1 again, nothing should change

err = mask.SetBit(1, true)
require.NoError(t, err)
require.Equal(t, uint8(0x6), mask.Mask()[0])
require.Equal(t, 2, len(mask.Participants()))

bit, err = mask.GetBit(1)
require.NoError(t, err)
require.True(t, bit)

// Unset bit 2

err = mask.SetBit(2, false)
require.NoError(t, err)
require.Equal(t, uint8(0x2), mask.Mask()[0])
require.Equal(t, 1, len(mask.Participants()))

bit, err = mask.GetBit(2)
require.NoError(t, err)
require.False(t, bit)

// Set bit 10 (using byte 2 now)

err = mask.SetBit(10, true)
require.NoError(t, err)
require.Equal(t, uint8(0x2), mask.Mask()[0])
require.Equal(t, uint8(0x4), mask.Mask()[1])
require.Equal(t, 2, len(mask.Participants()))

bit, err = mask.GetBit(10)
require.NoError(t, err)
require.True(t, bit)

// And make sure the range limit works.

err = mask.SetBit(-1, true)
require.Error(t, err)
err = mask.SetBit(len(publics), true)
Expand Down

0 comments on commit 6f78267

Please sign in to comment.