From 0ba27504389c08cefe1c050fee86cd3b540822b3 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 24 Sep 2024 10:30:33 +0000 Subject: [PATCH] Optimize BDN Signature/Key Aggregation (#546) * Add BDN test fixtures * Remove n^2 algorithm from signature/key aggregation CountEnabled and IndexOfNthEnabled are both O(n) in the size of the mask, making this loop n^2. The BLS operations still tend to be the slow part, but the n^2 factor will start to show up with thousands of keys. * Remove an unnecessary loop from hashPointToR * Introduce a new CachedMask for BDN This new mask will pre-compute reusable values, speeding up repeated verification and aggregation of aggregate signatures (mostly the former). * Ignore golangci lint * Move Mask into BDN and remove the interface * fix docs Co-authored-by: AnomalRoil * Document mutability of Mask fields --------- Co-authored-by: AnomalRoil --- sign/bdn/bdn.go | 84 ++++++++++++--------------- sign/bdn/bdn_test.go | 110 ++++++++++++++++++++++++++++++++++-- sign/{ => bdn}/mask.go | 55 +++++++++++++++++- sign/{ => bdn}/mask_test.go | 47 +++++++++++++-- 4 files changed, 236 insertions(+), 60 deletions(-) rename sign/{ => bdn}/mask.go (67%) rename sign/{ => bdn}/mask_test.go (74%) diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index b7716f6ce..71dc63267 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -12,6 +12,7 @@ package bdn import ( "crypto/cipher" "errors" + "fmt" "math/big" "go.dedis.ch/kyber/v4" @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI // We also use the entire roster so that the coefficient will vary for the same // public key used in different roster func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) { - peers := make([][]byte, len(pubs)) - for i, pub := range pubs { - peer, err := pub.MarshalBinary() - if err != nil { - return nil, err - } - - peers[i] = peer - } - h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil) if err != nil { return nil, err } - - for _, peer := range peers { - _, err := h.Write(peer) + for _, pub := range pubs { + peer, err := pub.MarshalBinary() + if err != nil { + return nil, err + } + _, err = h.Write(peer) if err != nil { return nil, err } @@ -128,62 +122,58 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} -func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { - if len(sigs) != mask.CountEnabled() { - return nil, errors.New("length of signatures and public keys must match") - } - - coefs, err := hashPointToR(mask.Publics()) - if err != nil { - return nil, err - } - +func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *Mask) (kyber.Point, error) { agg := scheme.sigGroup.Point() - for i, buf := range sigs { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { - // this should never happen as we check the lenths at the beginning - // an error here is probably a bug in the mask - return nil, errors.New("couldn't find the index") + for i := range mask.publics { + if enabled, err := mask.GetBit(i); err != nil { + // this should never happen because of the loop boundary + // an error here is probably a bug in the mask implementation + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue + } + + if len(sigs) == 0 { + return nil, errors.New("length of signatures and public keys must match") } + buf := sigs[0] + sigs = sigs[1:] + sig := scheme.sigGroup.Point() - err = sig.UnmarshalBinary(buf) + err := sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(coefs[peerIndex], sig) + sigC := sig.Clone().Mul(mask.publicCoefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) } + if len(sigs) > 0 { + return nil, errors.New("length of signatures and public keys must match") + } + return agg, nil } // AggregatePublicKeys aggregates a set of public keys (similarly to // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. -func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) { - coefs, err := hashPointToR(mask.Publics()) - if err != nil { - return nil, err - } - +func (scheme *Scheme) AggregatePublicKeys(mask *Mask) (kyber.Point, error) { agg := scheme.keyGroup.Point() - for i := 0; i < mask.CountEnabled(); i++ { - peerIndex := mask.IndexOfNthEnabled(i) - if peerIndex < 0 { + for i := range mask.publics { + if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation - return nil, errors.New("couldn't find the index") + return nil, fmt.Errorf("couldn't find the index %d: %w", i, err) + } else if !enabled { + continue } - pub := mask.Publics()[peerIndex] - pubC := pub.Clone().Mul(coefs[peerIndex], pub) - pubC = pubC.Add(pubC, pub) - agg = agg.Add(agg, pubC) + agg = agg.Add(agg, mask.publicTerms[i]) } return agg, nil @@ -217,7 +207,7 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128} // Deprecated: use the new scheme methods instead. -func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) { +func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask) } @@ -225,6 +215,6 @@ func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (k // AggregateSignatures for signatures) using the hash function // H: G2 -> R with R = {1, ..., 2^128}. // Deprecated: use the new scheme methods instead. -func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) { +func AggregatePublicKeys(suite pairing.Suite, mask *Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregatePublicKeys(mask) } diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index 46aa659d8..b686e4a66 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -1,13 +1,15 @@ package bdn import ( + "encoding" + "encoding/hex" "fmt" "testing" "github.com/stretchr/testify/require" "go.dedis.ch/kyber/v4" + "go.dedis.ch/kyber/v4/pairing/bls12381/kilic" "go.dedis.ch/kyber/v4/pairing/bn256" - "go.dedis.ch/kyber/v4/sign" "go.dedis.ch/kyber/v4/sign/bls" "go.dedis.ch/kyber/v4/util/random" ) @@ -30,7 +32,7 @@ func TestBDN_HashPointToR_BN256(t *testing.T) { require.Equal(t, "933f6013eb3f654f9489d6d45ad04eaf", coefs[2].String()) require.Equal(t, 16, coefs[0].MarshalSize()) - mask, _ := sign.NewMask([]kyber.Point{p1, p2, p3}, nil) + mask, _ := NewMask([]kyber.Point{p1, p2, p3}, nil) mask.SetBit(0, true) mask.SetBit(1, true) mask.SetBit(2, true) @@ -54,7 +56,7 @@ func TestBDN_AggregateSignatures(t *testing.T) { sig2, err := Sign(suite, private2, msg) require.NoError(t, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public2}, nil) mask.SetBit(0, true) mask.SetBit(1, true) @@ -92,7 +94,7 @@ func TestBDN_SubsetSignature(t *testing.T) { sig2, err := Sign(suite, private2, msg) require.NoError(t, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public3, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public3, public2}, nil) mask.SetBit(0, true) mask.SetBit(2, true) @@ -131,7 +133,7 @@ func TestBDN_RogueAttack(t *testing.T) { require.NoError(t, scheme.Verify(agg, msg, sig)) // New scheme that should detect - mask, _ := sign.NewMask(pubs, nil) + mask, _ := NewMask(pubs, nil) mask.SetBit(0, true) mask.SetBit(1, true) agg, err = AggregatePublicKeys(suite, mask) @@ -149,7 +151,7 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { sig2, err := Sign(suite, private2, msg) require.Nil(b, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public2}, nil) mask.SetBit(0, true) mask.SetBit(1, false) @@ -158,3 +160,99 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { AggregateSignatures(suite, [][]byte{sig1, sig2}, mask) } } + +func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) { + suite := kilic.NewBLS12381Suite() + schemeOnG2 := NewSchemeOnG2(suite) + + rng := random.New() + pubKeys := make([]kyber.Point, 3000) + privKeys := make([]kyber.Scalar, 3000) + for i := range pubKeys { + privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng) + } + + mask, err := NewMask(pubKeys, nil) + require.NoError(b, err) + for i := range pubKeys { + require.NoError(b, mask.SetBit(i, true)) + } + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sigs := make([][]byte, len(privKeys)) + for i, k := range privKeys { + s, err := schemeOnG2.Sign(k, msg) + require.NoError(b, err) + sigs[i] = s + } + + sig, err := schemeOnG2.AggregateSignatures(sigs, mask) + require.NoError(b, err) + sigb, err := sig.MarshalBinary() + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pk, err := schemeOnG2.AggregatePublicKeys(mask) + require.NoError(b, err) + require.NoError(b, schemeOnG2.Verify(pk, msg, sigb)) + } +} + +func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T { + t.Helper() + b, err := hex.DecodeString(s) + require.NoError(t, err) + require.NoError(t, into.UnmarshalBinary(b)) + return into +} + +// This tests exists to make sure we don't accidentally make breaking changes to signature +// aggregation by using checking against known aggregated signatures and keys. +func TestBDNFixtures(t *testing.T) { + suite := bn256.NewSuite() + schemeOnG1 := NewSchemeOnG1(suite) + + public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493") + private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38") + public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f") + private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581") + public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec") + private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa") + + sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8") + require.NoError(t, err) + sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484") + require.NoError(t, err) + sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195") + require.NoError(t, err) + + aggSigExp := unmarshalHex(t, suite.G1().Point(), "43c1d2ad5a7d71a08f3cd7495db6b3c81a4547af1b76438b2f215e85ec178fea048f93f6ffed65a69ea757b47761e7178103bb347fd79689652e55b6e0054af2") + aggKeyExp := unmarshalHex(t, suite.G2().Point(), "43b5161ede207b9a69fc93114b0c5022b76cc22e813ba739c7e622d826b132333cd637505399963b94e393ec7f5d4875f82391620b34be1fde1f232204fa4f723935d4dbfb725f059456bcf2557f846c03190969f7b800e904d25b0b5bcbdd421c9877d443f0313c3425dfc1e7e646b665d27b9e649faadef1129f95670d70e1") + + msg := []byte("Hello many times Boneh-Lynn-Shacham") + sig1, err := schemeOnG1.Sign(private1, msg) + require.Nil(t, err) + require.Equal(t, sig1Exp, sig1) + + sig2, err := schemeOnG1.Sign(private2, msg) + require.Nil(t, err) + require.Equal(t, sig2Exp, sig2) + + sig3, err := schemeOnG1.Sign(private3, msg) + require.Nil(t, err) + require.Equal(t, sig3Exp, sig3) + + mask, _ := NewMask([]kyber.Point{public1, public2, public3}, nil) + mask.SetBit(0, true) + mask.SetBit(1, false) + mask.SetBit(2, true) + + aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig3}, mask) + require.NoError(t, err) + require.True(t, aggSigExp.Equal(aggSig)) + + aggKey, err := schemeOnG1.AggregatePublicKeys(mask) + require.NoError(t, err) + require.True(t, aggKeyExp.Equal(aggKey)) +} diff --git a/sign/mask.go b/sign/bdn/mask.go similarity index 67% rename from sign/mask.go rename to sign/bdn/mask.go index 7c68c9132..a6a6fd78c 100644 --- a/sign/mask.go +++ b/sign/bdn/mask.go @@ -1,21 +1,36 @@ -// Package sign contains useful tools for the different signing algorithms. -package sign +package bdn import ( "errors" "fmt" + "slices" "go.dedis.ch/kyber/v4" ) // Mask is a bitmask of the participation to a collective signature. type Mask struct { - mask []byte + // The bitmask indicating which public keys are enabled/disabled for aggregation. This is + // the only mutable field. + mask []byte + + // The following fields are immutable and should not be changed after the mask is created. + // They may be shared between multiple masks. + + // Public keys for aggregation & signature verification. publics []kyber.Point + // Coefficients used when aggregating signatures. + publicCoefs []kyber.Scalar + // Terms used to aggregate public keys + publicTerms []kyber.Point } // NewMask creates a new mask from a list of public keys. If a key is provided, it // will set the bit of the key to 1 or return an error if it is not found. +// +// The returned Mask will contain pre-computed terms and coefficients for all provided public +// keys, so it should be re-used for optimal performance (e.g., by creating a "base" mask and +// cloning it whenever aggregating signatures and/or public keys). func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) { m := &Mask{ publics: publics, @@ -33,6 +48,18 @@ func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) { return nil, errors.New("key not found") } + var err error + m.publicCoefs, err = hashPointToR(publics) + if err != nil { + return nil, fmt.Errorf("failed to hash public keys: %w", err) + } + + m.publicTerms = make([]kyber.Point, len(publics)) + for i, pub := range publics { + pubC := pub.Clone().Mul(m.publicCoefs[i], pub) + m.publicTerms[i] = pubC.Add(pubC, pub) + } + return m, nil } @@ -58,6 +85,17 @@ func (m *Mask) SetMask(mask []byte) error { return nil } +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + return m.mask[byteIndex]&mask != 0, nil +} + // SetBit turns on or off the bit at the given index. func (m *Mask) SetBit(i int, enable bool) error { if i >= len(m.publics) || i < 0 { @@ -170,3 +208,14 @@ func (m *Mask) Merge(mask []byte) error { return nil } + +// Clone copies the mask while keeping the precomputed coefficients, etc. This method is thread safe +// and does not modify the original mask. Modifications to the new Mask will not affect the original. +func (m *Mask) Clone() *Mask { + return &Mask{ + mask: slices.Clone(m.mask), + publics: m.publics, + publicCoefs: m.publicCoefs, + publicTerms: m.publicTerms, + } +} diff --git a/sign/mask_test.go b/sign/bdn/mask_test.go similarity index 74% rename from sign/mask_test.go rename to sign/bdn/mask_test.go index 7a3eeb118..f87cf2706 100644 --- a/sign/mask_test.go +++ b/sign/bdn/mask_test.go @@ -1,4 +1,4 @@ -package sign +package bdn import ( "crypto/rand" @@ -6,13 +6,11 @@ import ( "github.com/stretchr/testify/require" "go.dedis.ch/kyber/v4" - "go.dedis.ch/kyber/v4/pairing/bn256" "go.dedis.ch/kyber/v4/util/key" ) const n = 17 -var suite = bn256.NewSuiteBn256() var publics []kyber.Point func init() { @@ -49,22 +47,63 @@ func TestMask_SetBit(t *testing.T) { mask, err := NewMask(publics, publics[2]) require.NoError(t, err) + // Make sure the mask is initially as we'd expect. + + bit, err := mask.GetBit(1) + require.NoError(t, err) + require.False(t, bit) + + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) - // Set it again, nothing should change. + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Set bit 1 again, nothing should change + err = mask.SetBit(1, true) require.NoError(t, err) require.Equal(t, uint8(0x6), mask.Mask()[0]) require.Equal(t, 2, len(mask.Participants())) + bit, err = mask.GetBit(1) + require.NoError(t, err) + require.True(t, bit) + + // Unset bit 2 + err = mask.SetBit(2, false) require.NoError(t, err) require.Equal(t, uint8(0x2), mask.Mask()[0]) require.Equal(t, 1, len(mask.Participants())) + bit, err = mask.GetBit(2) + require.NoError(t, err) + require.False(t, bit) + + // Set bit 10 (using byte 2 now) + + err = mask.SetBit(10, true) + require.NoError(t, err) + require.Equal(t, uint8(0x2), mask.Mask()[0]) + require.Equal(t, uint8(0x4), mask.Mask()[1]) + require.Equal(t, 2, len(mask.Participants())) + + bit, err = mask.GetBit(10) + require.NoError(t, err) + require.True(t, bit) + + // And make sure the range limit works. + err = mask.SetBit(-1, true) require.Error(t, err) err = mask.SetBit(len(publics), true)