From 7bbbe51a05b7c5baa25e8ef740e5c6dc22b62d47 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 9 Sep 2024 11:03:16 -0700 Subject: [PATCH] Move Mask into BDN and remove the interface --- sign/bdn/bdn.go | 27 ++-- sign/bdn/bdn_test.go | 17 +-- sign/bdn/mask.go | 246 +++++++++++++++++++++++++----------- sign/{ => bdn}/mask_test.go | 4 +- sign/mask.go | 183 --------------------------- 5 files changed, 188 insertions(+), 289 deletions(-) rename sign/{ => bdn}/mask_test.go (97%) delete mode 100644 sign/mask.go diff --git a/sign/bdn/bdn.go b/sign/bdn/bdn.go index cd93ef73..71dc6326 100644 --- a/sign/bdn/bdn.go +++ b/sign/bdn/bdn.go @@ -122,13 +122,9 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128} -func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) { - bdnMask, err := newCachedMask(mask, false) - if err != nil { - return nil, err - } +func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *Mask) (kyber.Point, error) { agg := scheme.sigGroup.Point() - for i := range bdnMask.publics { + for i := range mask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -145,12 +141,12 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point sigs = sigs[1:] sig := scheme.sigGroup.Point() - err = sig.UnmarshalBinary(buf) + err := sig.UnmarshalBinary(buf) if err != nil { return nil, err } - sigC := sig.Clone().Mul(bdnMask.coefs[i], sig) + sigC := sig.Clone().Mul(mask.publicCoefs[i], sig) // c+1 because R is in the range [1, 2^128] and not [0, 2^128-1] sigC = sigC.Add(sigC, sig) agg = agg.Add(agg, sigC) @@ -166,14 +162,9 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point // AggregatePublicKeys aggregates a set of public keys (similarly to // AggregateSignatures for signatures) using the hash function // H: keyGroup -> R with R = {1, ..., 2^128}. -func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) { - bdnMask, err := newCachedMask(mask, false) - if err != nil { - return nil, err - } - +func (scheme *Scheme) AggregatePublicKeys(mask *Mask) (kyber.Point, error) { agg := scheme.keyGroup.Point() - for i := range bdnMask.publics { + for i := range mask.publics { if enabled, err := mask.GetBit(i); err != nil { // this should never happen because of the loop boundary // an error here is probably a bug in the mask implementation @@ -182,7 +173,7 @@ func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) { continue } - agg = agg.Add(agg, bdnMask.getOrComputePubC(i)) + agg = agg.Add(agg, mask.publicTerms[i]) } return agg, nil @@ -216,7 +207,7 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error { // AggregateSignatures aggregates the signatures using a coefficient for each // one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128} // Deprecated: use the new scheme methods instead. -func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) { +func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask) } @@ -224,6 +215,6 @@ func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.P // AggregateSignatures for signatures) using the hash function // H: G2 -> R with R = {1, ..., 2^128}. // Deprecated: use the new scheme methods instead. -func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) { +func AggregatePublicKeys(suite pairing.Suite, mask *Mask) (kyber.Point, error) { return NewSchemeOnG1(suite).AggregatePublicKeys(mask) } diff --git a/sign/bdn/bdn_test.go b/sign/bdn/bdn_test.go index 55f929aa..b686e4a6 100644 --- a/sign/bdn/bdn_test.go +++ b/sign/bdn/bdn_test.go @@ -10,7 +10,6 @@ import ( "go.dedis.ch/kyber/v4" "go.dedis.ch/kyber/v4/pairing/bls12381/kilic" "go.dedis.ch/kyber/v4/pairing/bn256" - "go.dedis.ch/kyber/v4/sign" "go.dedis.ch/kyber/v4/sign/bls" "go.dedis.ch/kyber/v4/util/random" ) @@ -33,7 +32,7 @@ func TestBDN_HashPointToR_BN256(t *testing.T) { require.Equal(t, "933f6013eb3f654f9489d6d45ad04eaf", coefs[2].String()) require.Equal(t, 16, coefs[0].MarshalSize()) - mask, _ := sign.NewMask([]kyber.Point{p1, p2, p3}, nil) + mask, _ := NewMask([]kyber.Point{p1, p2, p3}, nil) mask.SetBit(0, true) mask.SetBit(1, true) mask.SetBit(2, true) @@ -57,7 +56,7 @@ func TestBDN_AggregateSignatures(t *testing.T) { sig2, err := Sign(suite, private2, msg) require.NoError(t, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public2}, nil) mask.SetBit(0, true) mask.SetBit(1, true) @@ -95,7 +94,7 @@ func TestBDN_SubsetSignature(t *testing.T) { sig2, err := Sign(suite, private2, msg) require.NoError(t, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public3, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public3, public2}, nil) mask.SetBit(0, true) mask.SetBit(2, true) @@ -134,7 +133,7 @@ func TestBDN_RogueAttack(t *testing.T) { require.NoError(t, scheme.Verify(agg, msg, sig)) // New scheme that should detect - mask, _ := sign.NewMask(pubs, nil) + mask, _ := NewMask(pubs, nil) mask.SetBit(0, true) mask.SetBit(1, true) agg, err = AggregatePublicKeys(suite, mask) @@ -152,7 +151,7 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) { sig2, err := Sign(suite, private2, msg) require.Nil(b, err) - mask, _ := sign.NewMask([]kyber.Point{public1, public2}, nil) + mask, _ := NewMask([]kyber.Point{public1, public2}, nil) mask.SetBit(0, true) mask.SetBit(1, false) @@ -173,9 +172,7 @@ func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) { privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng) } - baseMask, err := sign.NewMask(pubKeys, nil) - require.NoError(b, err) - mask, err := NewCachedMask(baseMask) + mask, err := NewMask(pubKeys, nil) require.NoError(b, err) for i := range pubKeys { require.NoError(b, mask.SetBit(i, true)) @@ -246,7 +243,7 @@ func TestBDNFixtures(t *testing.T) { require.Nil(t, err) require.Equal(t, sig3Exp, sig3) - mask, _ := sign.NewMask([]kyber.Point{public1, public2, public3}, nil) + mask, _ := NewMask([]kyber.Point{public1, public2, public3}, nil) mask.SetBit(0, true) mask.SetBit(1, false) mask.SetBit(2, true) diff --git a/sign/bdn/mask.go b/sign/bdn/mask.go index 3cb90bc6..c83c9457 100644 --- a/sign/bdn/mask.go +++ b/sign/bdn/mask.go @@ -1,113 +1,209 @@ package bdn import ( + "errors" "fmt" + "slices" "go.dedis.ch/kyber/v4" - "go.dedis.ch/kyber/v4/sign" ) -//nolint:interfacebloat -type Mask interface { - GetBit(i int) (bool, error) - SetBit(i int, enable bool) error +// Mask is a bitmask of the participation to a collective signature. +type Mask struct { + mask []byte + publics []kyber.Point + + // Coefficients used when aggregating signatures. + publicCoefs []kyber.Scalar + // Terms used to aggregate public keys + publicTerms []kyber.Point +} - IndexOfNthEnabled(nth int) int - NthEnabledAtIndex(idx int) int +// NewMask creates a new mask from a list of public keys. If a key is provided, it +// will set the bit of the key to 1 or return an error if it is not found. +func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) { + m := &Mask{ + publics: publics, + } + m.mask = make([]byte, m.Len()) + + if myKey != nil { + for i, key := range publics { + if key.Equal(myKey) { + err := m.SetBit(i, true) + return m, err + } + } + + return nil, errors.New("key not found") + } + + var err error + m.publicCoefs, err = hashPointToR(publics) + if err != nil { + return nil, fmt.Errorf("failed to hash public keys: %w", err) + } - Publics() []kyber.Point - Participants() []kyber.Point + m.publicTerms = make([]kyber.Point, len(publics)) + for i, pub := range publics { + pubC := pub.Clone().Mul(m.publicCoefs[i], pub) + m.publicTerms[i] = pubC.Add(pubC, pub) + } - CountEnabled() int - CountTotal() int + return m, nil +} - Len() int - Mask() []byte - SetMask(mask []byte) error - Merge(mask []byte) error +// Mask returns the bitmask as a byte array. +func (m *Mask) Mask() []byte { + clone := make([]byte, len(m.mask)) + copy(clone, m.mask) + return clone } -var _ Mask = (*sign.Mask)(nil) +// Len returns the length of the byte array necessary to store the bitmask. +func (m *Mask) Len() int { + return (len(m.publics) + 7) / 8 +} -// We need to rename this, otherwise we have a public field named Mask (when we embed it) which -// conflicts with the function named Mask. It also makes it private, which is nice. -type maskI = Mask +// SetMask replaces the current mask by the new one if the length matches. +func (m *Mask) SetMask(mask []byte) error { + if m.Len() != len(mask) { + return fmt.Errorf("mismatching mask lengths") + } -type CachedMask struct { - maskI - coefs []kyber.Scalar - pubKeyC []kyber.Point - // We could call Mask.Publics() instead of keeping these here, but that function copies the - // slice and this field lets us avoid that copy. - publics []kyber.Point + m.mask = mask + return nil } -// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms. -// -// This cached mask will: -// -// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated, -// distinct sets of signatures can be aggregated without any BLAKE2S hashing. -// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated, -// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders -// of magnitude faster than aggregating from scratch. -func NewCachedMask(mask Mask) (*CachedMask, error) { - return newCachedMask(mask, true) +// GetBit returns true if the given bit is set. +func (m *Mask) GetBit(i int) (bool, error) { + if i >= len(m.publics) || i < 0 { + return false, errors.New("index out of range") + } + + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + return m.mask[byteIndex]&mask != 0, nil } -func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) { - if m, ok := mask.(*CachedMask); ok { - return m, nil +// SetBit turns on or off the bit at the given index. +func (m *Mask) SetBit(i int, enable bool) error { + if i >= len(m.publics) || i < 0 { + return errors.New("index out of range") } - publics := mask.Publics() - coefs, err := hashPointToR(publics) - if err != nil { - return nil, fmt.Errorf("failed to hash public keys: %w", err) + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + if enable { + m.mask[byteIndex] |= mask + } else { + m.mask[byteIndex] &^= mask } + return nil +} - cm := &CachedMask{ - maskI: mask, - coefs: coefs, - publics: publics, +// forEachBitEnabled is a helper to iterate over the bits set to 1 in the mask +// and to return the result of the callback only if it is positive. +func (m *Mask) forEachBitEnabled(f func(i, j, n int) int) int { + n := 0 + for i, b := range m.mask { + for j := uint(0); j < 8; j++ { + mm := byte(1) << (j & 7) + + if b&mm != 0 { + if res := f(i, int(j), n); res >= 0 { + return res + } + + n++ + } + } } - if precomputePubC { - pubKeyC := make([]kyber.Point, len(publics)) - for i := range publics { - pubKeyC[i] = cm.getOrComputePubC(i) + return -1 +} + +// IndexOfNthEnabled returns the index of the nth enabled bit or -1 if out of bounds. +func (m *Mask) IndexOfNthEnabled(nth int) int { + return m.forEachBitEnabled(func(i, j, n int) int { + if n == nth { + return i*8 + int(j) + } + + return -1 + }) +} + +// NthEnabledAtIndex returns the sum of bits set to 1 until the given index. In other +// words, it returns how many bits are enabled before the given index. +func (m *Mask) NthEnabledAtIndex(idx int) int { + return m.forEachBitEnabled(func(i, j, n int) int { + if i*8+int(j) == idx { + return n + } + + return -1 + }) +} + +// Publics returns a copy of the list of public keys. +func (m *Mask) Publics() []kyber.Point { + pubs := make([]kyber.Point, len(m.publics)) + copy(pubs, m.publics) + return pubs +} + +// Participants returns the list of public keys participating. +func (m *Mask) Participants() []kyber.Point { + pp := []kyber.Point{} + for i, p := range m.publics { + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + if (m.mask[byteIndex] & mask) != 0 { + pp = append(pp, p) } - cm.pubKeyC = pubKeyC } - return cm, err + return pp } -// Clone copies the BDN mask while keeping the precomputed coefficients, etc. -func (cm *CachedMask) Clone() *CachedMask { - newMask, err := sign.NewMask(cm.publics, nil) - if err != nil { - // Not possible given that we didn't pass our own key. - panic(fmt.Sprintf("failed to create mask: %s", err)) +// CountEnabled returns the number of bit set to 1 +func (m *Mask) CountEnabled() int { + count := 0 + for i := range m.publics { + byteIndex := i / 8 + mask := byte(1) << uint(i&7) + if (m.mask[byteIndex] & mask) != 0 { + count++ + } } - if err := newMask.SetMask(cm.Mask()); err != nil { - // Not possible given that we're using the same sized mask. - panic(fmt.Sprintf("failed to create mask: %s", err)) + return count +} + +// CountTotal returns the number of potential participants +func (m *Mask) CountTotal() int { + return len(m.publics) +} + +// Merge merges the given mask to the current one only if the length matches. +func (m *Mask) Merge(mask []byte) error { + if len(m.mask) != len(mask) { + return errors.New("mismatching mask length") } - return &CachedMask{ - maskI: newMask, - coefs: cm.coefs, - pubKeyC: cm.pubKeyC, - publics: cm.publics, + + for i := range m.mask { + m.mask[i] |= mask[i] } + + return nil } -func (cm *CachedMask) getOrComputePubC(i int) kyber.Point { - if cm.pubKeyC == nil { - // NOTE: don't cache here as we may be sharing this mask between threads. - pub := cm.publics[i] - pubC := pub.Clone().Mul(cm.coefs[i], pub) - return pubC.Add(pubC, pub) +// Clone copies the mask while keeping the precomputed coefficients, etc. +func (m *Mask) Clone() *Mask { + return &Mask{ + mask: slices.Clone(m.mask), + publics: m.publics, + publicCoefs: m.publicCoefs, + publicTerms: m.publicTerms, } - return cm.pubKeyC[i] } diff --git a/sign/mask_test.go b/sign/bdn/mask_test.go similarity index 97% rename from sign/mask_test.go rename to sign/bdn/mask_test.go index 5eb91175..f87cf270 100644 --- a/sign/mask_test.go +++ b/sign/bdn/mask_test.go @@ -1,4 +1,4 @@ -package sign +package bdn import ( "crypto/rand" @@ -6,13 +6,11 @@ import ( "github.com/stretchr/testify/require" "go.dedis.ch/kyber/v4" - "go.dedis.ch/kyber/v4/pairing/bn256" "go.dedis.ch/kyber/v4/util/key" ) const n = 17 -var suite = bn256.NewSuiteBn256() var publics []kyber.Point func init() { diff --git a/sign/mask.go b/sign/mask.go deleted file mode 100644 index d02692c3..00000000 --- a/sign/mask.go +++ /dev/null @@ -1,183 +0,0 @@ -// Package sign contains useful tools for the different signing algorithms. -package sign - -import ( - "errors" - "fmt" - - "go.dedis.ch/kyber/v4" -) - -// Mask is a bitmask of the participation to a collective signature. -type Mask struct { - mask []byte - publics []kyber.Point -} - -// NewMask creates a new mask from a list of public keys. If a key is provided, it -// will set the bit of the key to 1 or return an error if it is not found. -func NewMask(publics []kyber.Point, myKey kyber.Point) (*Mask, error) { - m := &Mask{ - publics: publics, - } - m.mask = make([]byte, m.Len()) - - if myKey != nil { - for i, key := range publics { - if key.Equal(myKey) { - err := m.SetBit(i, true) - return m, err - } - } - - return nil, errors.New("key not found") - } - - return m, nil -} - -// Mask returns the bitmask as a byte array. -func (m *Mask) Mask() []byte { - clone := make([]byte, len(m.mask)) - copy(clone, m.mask) - return clone -} - -// Len returns the length of the byte array necessary to store the bitmask. -func (m *Mask) Len() int { - return (len(m.publics) + 7) / 8 -} - -// SetMask replaces the current mask by the new one if the length matches. -func (m *Mask) SetMask(mask []byte) error { - if m.Len() != len(mask) { - return fmt.Errorf("mismatching mask lengths") - } - - m.mask = mask - return nil -} - -// GetBit returns true if the given bit is set. -func (m *Mask) GetBit(i int) (bool, error) { - if i >= len(m.publics) || i < 0 { - return false, errors.New("index out of range") - } - - byteIndex := i / 8 - mask := byte(1) << uint(i&7) - return m.mask[byteIndex]&mask != 0, nil -} - -// SetBit turns on or off the bit at the given index. -func (m *Mask) SetBit(i int, enable bool) error { - if i >= len(m.publics) || i < 0 { - return errors.New("index out of range") - } - - byteIndex := i / 8 - mask := byte(1) << uint(i&7) - if enable { - m.mask[byteIndex] |= mask - } else { - m.mask[byteIndex] &^= mask - } - return nil -} - -// forEachBitEnabled is a helper to iterate over the bits set to 1 in the mask -// and to return the result of the callback only if it is positive. -func (m *Mask) forEachBitEnabled(f func(i, j, n int) int) int { - n := 0 - for i, b := range m.mask { - for j := uint(0); j < 8; j++ { - mm := byte(1) << (j & 7) - - if b&mm != 0 { - if res := f(i, int(j), n); res >= 0 { - return res - } - - n++ - } - } - } - - return -1 -} - -// IndexOfNthEnabled returns the index of the nth enabled bit or -1 if out of bounds. -func (m *Mask) IndexOfNthEnabled(nth int) int { - return m.forEachBitEnabled(func(i, j, n int) int { - if n == nth { - return i*8 + int(j) - } - - return -1 - }) -} - -// NthEnabledAtIndex returns the sum of bits set to 1 until the given index. In other -// words, it returns how many bits are enabled before the given index. -func (m *Mask) NthEnabledAtIndex(idx int) int { - return m.forEachBitEnabled(func(i, j, n int) int { - if i*8+int(j) == idx { - return n - } - - return -1 - }) -} - -// Publics returns a copy of the list of public keys. -func (m *Mask) Publics() []kyber.Point { - pubs := make([]kyber.Point, len(m.publics)) - copy(pubs, m.publics) - return pubs -} - -// Participants returns the list of public keys participating. -func (m *Mask) Participants() []kyber.Point { - pp := []kyber.Point{} - for i, p := range m.publics { - byteIndex := i / 8 - mask := byte(1) << uint(i&7) - if (m.mask[byteIndex] & mask) != 0 { - pp = append(pp, p) - } - } - - return pp -} - -// CountEnabled returns the number of bit set to 1 -func (m *Mask) CountEnabled() int { - count := 0 - for i := range m.publics { - byteIndex := i / 8 - mask := byte(1) << uint(i&7) - if (m.mask[byteIndex] & mask) != 0 { - count++ - } - } - return count -} - -// CountTotal returns the number of potential participants -func (m *Mask) CountTotal() int { - return len(m.publics) -} - -// Merge merges the given mask to the current one only if -// the length matches -func (m *Mask) Merge(mask []byte) error { - if len(m.mask) != len(mask) { - return errors.New("mismatching mask length") - } - - for i := range m.mask { - m.mask[i] |= mask[i] - } - - return nil -}