From 5ec4ce2aa463f010114e5287f02c0261700de1dc Mon Sep 17 00:00:00 2001 From: Wind <573966@qq.com> Date: Thu, 14 Dec 2023 13:37:59 +0800 Subject: [PATCH] Fix failed test cases --- cluster/agent.go | 2 +- cluster/raft/base.go | 18 +++++++---- cluster/raft/kv_test.go | 61 +++++++++++++++++++++++++++++++++++++ cluster/topics/trie_test.go | 18 +++++++++++ mqtt/server_test.go | 26 ++++++++-------- 5 files changed, 105 insertions(+), 20 deletions(-) create mode 100644 cluster/raft/kv_test.go diff --git a/cluster/agent.go b/cluster/agent.go index f482fb9..9247c1e 100644 --- a/cluster/agent.go +++ b/cluster/agent.go @@ -483,7 +483,7 @@ func (a *Agent) processOutboundConnect(pk *packets.Packet) { // pickNodes pick nodes, if the filter is shared, select a node at random func (a *Agent) pickNodes(filter string, sharedFilters map[string]bool) (ns []string) { tmpNs := a.raftPeer.Lookup(filter) - if len(tmpNs) == 0 { + if tmpNs == nil || len(tmpNs) == 0 { return ns } diff --git a/cluster/raft/base.go b/cluster/raft/base.go index ca28ec1..16f9f98 100644 --- a/cluster/raft/base.go +++ b/cluster/raft/base.go @@ -62,6 +62,7 @@ func (k *KV) Add(key, value string) (new bool) { } // Del return true if the array corresponding to key is deleted +// If the value is "", the key-values pair is deleted func (k *KV) Del(key, value string) (empty bool) { k.Lock() defer k.Unlock() @@ -74,7 +75,7 @@ func (k *KV) Del(key, value string) (empty bool) { } } - if value == "" || len(vs) == 0 { + if value == "" || len(k.data[key]) == 0 { delete(k.data, key) empty = true } @@ -82,21 +83,26 @@ func (k *KV) Del(key, value string) (empty bool) { return } +// DelByValue delete the specified value from the key-values array +// and delete the key-value pair if the key-values array is empty func (k *KV) DelByValue(value string) int { k.Lock() defer k.Unlock() c := 0 + if value == "" { + return c + } + for f, vs := range k.data { for i, v := range vs { if v == value { - if len(vs) == 1 { - delete(k.data, f) - } else { - k.data[f] = append(vs[:i], vs[i+1:]...) - } + k.data[f] = append(vs[:i], vs[i+1:]...) c++ } } + if len(k.data[f]) == 0 { + delete(k.data, f) + } } return c } diff --git a/cluster/raft/kv_test.go b/cluster/raft/kv_test.go new file mode 100644 index 0000000..f670180 --- /dev/null +++ b/cluster/raft/kv_test.go @@ -0,0 +1,61 @@ +package raft + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestKV_Add(t *testing.T) { + kv := NewKV() + + new := kv.Add("key1", "value1") + require.Equal(t, true, new) + new = kv.Add("key1", "value2") + require.Equal(t, false, new) + ln := len(kv.Get("key1")) + require.Equal(t, 2, ln) + + new = kv.Add("key2", "value1") + ln = len(kv.Get("key1")) + require.Equal(t, true, new) + require.Equal(t, 2, ln) +} + +func TestKV_Del(t *testing.T) { + kv := NewKV() + + kv.Add("key1", "value1") + kv.Add("key1", "value2") + ln := len(kv.Get("key1")) + require.Equal(t, 2, ln) + + empty := kv.Del("key1", "value3") + require.Equal(t, false, empty) + empty = kv.Del("key1", "value2") + require.Equal(t, false, empty) + empty = kv.Del("key1", "value1") + require.Equal(t, true, empty) +} + +func TestKV_DelByValue(t *testing.T) { + kv := NewKV() + + // Test for value exists in multiple keys + kv.Add("key1", "value1") + kv.Add("key2", "value1") + kv.Add("key3", "value1") + c := kv.DelByValue("value1") + require.Equal(t, 3, c) + vs := kv.Get("key1") + require.Empty(t, nil, vs) + ln := len(kv.data) + require.Equal(t, 0, ln) + + kv.Add("key4", "value4") + kv.Add("key5", "value5") + kv.DelByValue("value4") + ln = len(kv.data) + require.Equal(t, 1, ln) + vs = kv.Get("key5") + require.EqualValues(t, []string{"value5"}, vs) +} diff --git a/cluster/topics/trie_test.go b/cluster/topics/trie_test.go index 2b160c5..a471754 100644 --- a/cluster/topics/trie_test.go +++ b/cluster/topics/trie_test.go @@ -313,3 +313,21 @@ func BenchmarkIsolateParticle(b *testing.B) { isolateParticle("path/to/my/mqtt", 3) } } + +func TestConvertSharedFilter(t *testing.T) { + srcFilter := "$share/group1/filter/1" + + group, destFilter := convertSharedFilter(srcFilter) + require.Equal(t, "group1/", group) + require.Equal(t, "filter/1", destFilter) +} + +func TestRestoreShareFilter(t *testing.T) { + filter := "filter" + sharedGroups := []string{"Group1/", "Group2/", "Group3/"} + + result := restoreShareFilter(filter, sharedGroups) + expectedResult := []string{"$share/Group1/filter", "$share/Group2/filter", "$share/Group3/filter"} + + require.EqualValues(t, expectedResult, result) +} diff --git a/mqtt/server_test.go b/mqtt/server_test.go index 3171bb0..a67f87e 100644 --- a/mqtt/server_test.go +++ b/mqtt/server_test.go @@ -1689,7 +1689,7 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet pkx.Origin = cl.ID - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) _ = w.Close() }() @@ -1747,7 +1747,7 @@ func TestPublishToSubscribers(t *testing.T) { }() go func() { - s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) _ = w1.Close() _ = w2.Close() @@ -1791,7 +1791,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet pkx.Created = time.Now().Unix() - 30 - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) _ = w1.Close() }() @@ -1815,7 +1815,7 @@ func TestPublishToSubscribersIdentifiers(t *testing.T) { require.True(t, subbed) go func() { - s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) _ = w.Close() }() @@ -1840,7 +1840,7 @@ func TestPublishToSubscribersPkIgnore(t *testing.T) { go func() { pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet pk.Ignore = true - s.PublishToSubscribers(pk) + s.publishToSubscribers(pk) time.Sleep(time.Millisecond) _ = w.Close() }() @@ -2071,7 +2071,7 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) { _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) _ = w.Close() } @@ -2093,7 +2093,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { _ = r.Close() pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet pkx.PacketID = 0 - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) time.Sleep(time.Millisecond) _ = w.Close() } @@ -2109,7 +2109,7 @@ func TestPublishToSubscribersNoConnection(t *testing.T) { // coverage: subscriber publish errors are non-returnable // can we hook into zerolog ? _ = r.Close() - s.PublishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) + s.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet) time.Sleep(time.Millisecond) _ = w.Close() } @@ -2566,7 +2566,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { }() if i == 0 { - s.PublishToSubscribers(*tx.in.Packet) + s.publishToSubscribers(*tx.in.Packet) } else { err := s.processPacket(cl, *tx.in.Packet) require.NoError(t, err) @@ -3495,7 +3495,7 @@ func TestPublishToInlineSubscriber(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) }() require.Equal(t, true, <-finishCh) @@ -3528,10 +3528,10 @@ func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) pkx = *packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) }() for i := 0; i < subNumber; i++ { @@ -3566,7 +3566,7 @@ func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) { go func() { pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet - s.PublishToSubscribers(pkx) + s.publishToSubscribers(pkx) }() for i := 0; i < subNumber; i++ {