Skip to content

Commit

Permalink
The cluster supports shared subscriptions
Browse files Browse the repository at this point in the history
1.The cluster supports shared subscriptions
2.Log display and other logical optimization
  • Loading branch information
wind-c committed Dec 13, 2023
1 parent 0bbf3cc commit 1a9fa37
Show file tree
Hide file tree
Showing 12 changed files with 415 additions and 282 deletions.
149 changes: 99 additions & 50 deletions cluster/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ package cluster
import (
"bytes"
"context"
"math/rand"
"net"
"path"
"path/filepath"
"strconv"
"strings"

"github.com/panjf2000/ants/v2"
"github.com/wind-c/comqtt/v2/cluster/discovery"
Expand Down Expand Up @@ -156,18 +158,6 @@ func (a *Agent) initPool() error {
return nil
}

func (a *Agent) SubmitOutTask(pk *packets.Packet) {
a.OutPool.Submit(func() {
a.processOutboundPacket(pk)
})
}

func (a *Agent) SubmitRaftTask(msg *message.Message) {
a.raftPool.Submit(func() {
a.raftPropose(msg)
})
}

func (a *Agent) Stop() {
a.cancel()
a.OutPool.Release()
Expand Down Expand Up @@ -264,7 +254,6 @@ func (a *Agent) raftApplyListener() {
} else {
continue
}
log.Info("apply listening", "from", msg.NodeID, "filter", filter, "type", msg.Type)
case <-a.ctx.Done():
return
}
Expand All @@ -275,7 +264,7 @@ func (a *Agent) raftApplyListener() {
func (a *Agent) raftPropose(msg *message.Message) {
if a.raftPeer.IsApplyRight() {
err := a.raftPeer.Propose(msg)
OnApplyLog(a.GetLocalName(), msg.NodeID, msg.Type, msg.Payload, "apply raft log", err)
OnApplyLog(a.GetLocalName(), msg.NodeID, msg.Type, msg.Payload, "raft apply log", err)
} else { //send to leader apply
_, leaderId := a.raftPeer.GetLeader()
if leaderId == "" {
Expand All @@ -284,14 +273,14 @@ func (a *Agent) raftPropose(msg *message.Message) {
} else {
a.membership.SendToOthers(msg.MsgpackBytes())
}
OnApplyLog("unknown", msg.NodeID, msg.Type, msg.Payload, "broadcast raft log", nil)
OnApplyLog("unknown", msg.NodeID, msg.Type, msg.Payload, "raft broadcast log", nil)
} else {
if a.Config.GrpcEnable {
a.grpcClientManager.RelayRaftApply(leaderId, msg)
} else {
a.membership.SendToNode(leaderId, msg.MsgpackBytes())
}
OnApplyLog(leaderId, msg.NodeID, msg.Type, msg.Payload, "forward raft log", nil)
OnApplyLog(leaderId, msg.NodeID, msg.Type, msg.Payload, "raft forward log", nil)
}
}
}
Expand Down Expand Up @@ -362,7 +351,7 @@ func (a *Agent) processRelayMsg(msg *message.Message) {
}
offset := len(msg.Payload) - pk.FixedHeader.Remaining // Unpack fixedheader.
if err := pk.PublishDecode(msg.Payload[offset:]); err == nil { // Unpack skips fixedheader
a.mqttServer.PublishToSubscribers(pk)
a.mqttServer.PublishToSubscribers(pk, false)
OnPublishPacketLog(DirectionInbound, msg.NodeID, msg.ClientID, pk.TopicName, pk.PacketID)
}
case packets.Connect:
Expand Down Expand Up @@ -413,51 +402,111 @@ func (a *Agent) processInboundMsg() {
}
}

// processOutboundPacket process outbound msg
func (a *Agent) processOutboundPacket(pk *packets.Packet) {
func (a *Agent) SubmitOutPublishTask(pk *packets.Packet, sharedFilters map[string]bool) {
a.OutPool.Submit(func() {
a.processOutboundPublish(pk, sharedFilters)
})
}

func (a *Agent) SubmitOutConnectTask(pk *packets.Packet) {
a.OutPool.Submit(func() {
a.processOutboundConnect(pk)
})
}

func (a *Agent) SubmitRaftTask(msg *message.Message) {
a.raftPool.Submit(func() {
a.raftPropose(msg)
})
}

// processOutboundPublish process outbound publish msg
func (a *Agent) processOutboundPublish(pk *packets.Packet, sharedFilters map[string]bool) {
msg := message.Message{
NodeID: a.Config.NodeName,
ClientID: pk.Origin,
ProtocolVersion: pk.ProtocolVersion,
}
switch pk.FixedHeader.Type {
case packets.Publish:
var buf bytes.Buffer
pk.Mods.AllowResponseInfo = true
if err := pk.PublishEncode(&buf); err != nil {
return

var buf bytes.Buffer
pk.Mods.AllowResponseInfo = true
if err := pk.PublishEncode(&buf); err != nil {
return
}
msg.Type = packets.Publish
msg.Payload = buf.Bytes()
tmpFilters := a.subTree.Scan(pk.TopicName, make([]string, 0))
oldNodes := make([]string, 0)
filters := make([]string, 0)
for _, filter := range tmpFilters {
if !utils.Contains(filters, filter) {
filters = append(filters, filter)
}
msg.Type = packets.Publish
msg.Payload = buf.Bytes()
filters := a.subTree.Scan(pk.TopicName, make([]string, 0))
oldNodes := make([]string, 0)
for _, filter := range filters {
ns := a.raftPeer.Lookup(filter)
for _, node := range ns {
if node != a.GetLocalName() && !utils.Contains(oldNodes, node) {
if a.Config.GrpcEnable {
a.grpcClientManager.RelayPublishPacket(node, &msg)
} else {
bs := msg.MsgpackBytes()
a.membership.SendToNode(node, bs)
}
oldNodes = append(oldNodes, node)
OnPublishPacketLog(DirectionOutbound, node, pk.Origin, pk.TopicName, pk.PacketID)
}
for _, filter := range filters {
ns := a.pickNodes(filter, sharedFilters)
for _, node := range ns {
if node != a.GetLocalName() && !utils.Contains(oldNodes, node) {
if a.Config.GrpcEnable {
a.grpcClientManager.RelayPublishPacket(node, &msg)
} else {
bs := msg.MsgpackBytes()
a.membership.SendToNode(node, bs)
}
oldNodes = append(oldNodes, node)
OnPublishPacketLog(DirectionOutbound, node, pk.Origin, pk.TopicName, pk.PacketID)
}
}
case packets.Connect:
msg.Type = packets.Connect
if msg.ClientID == "" {
msg.ClientID = pk.Connect.ClientIdentifier
}
}

// processOutboundConnect process outbound connect msg
func (a *Agent) processOutboundConnect(pk *packets.Packet) {
msg := message.Message{
NodeID: a.Config.NodeName,
ClientID: pk.Origin,
ProtocolVersion: pk.ProtocolVersion,
}

msg.Type = packets.Connect
if msg.ClientID == "" {
msg.ClientID = pk.Connect.ClientIdentifier
}
if a.Config.GrpcEnable {
a.grpcClientManager.ConnectNotifyToOthers(&msg)
} else {
a.membership.SendToOthers(msg.MsgpackBytes())
}
OnConnectPacketLog(DirectionOutbound, a.GetLocalName(), msg.ClientID)
}

// pickNodes pick nodes, if the filter is shared, select a node at random
func (a *Agent) pickNodes(filter string, sharedFilters map[string]bool) (ns []string) {
tmpNs := a.raftPeer.Lookup(filter)
if len(tmpNs) == 0 {
return ns
}

if strings.HasPrefix(filter, topics.SharePrefix) {
if b, ok := sharedFilters[filter]; ok && b {
return ns
}
if a.Config.GrpcEnable {
a.grpcClientManager.ConnectNotifyToOthers(&msg)
} else {
a.membership.SendToOthers(msg.MsgpackBytes())

for _, n := range tmpNs {
// The shared subscription is local priority, indicating that it has been sent
if n == a.GetLocalName() {
return ns
}
}
OnConnectPacketLog(DirectionOutbound, a.GetLocalName(), msg.ClientID)
// Share subscription Select a node at random
n := tmpNs[rand.Intn(len(tmpNs))]
ns = []string{n}
return ns
}

// Not shared subscriptions are returned as is
ns = tmpNs
return
}

func OnJoinLog(nodeId, addr, prompt string, err error) {
Expand Down
23 changes: 17 additions & 6 deletions cluster/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (h *MqttEventHook) Provides(b byte) bool {
mqtt.OnSessionEstablished,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnPublished,
mqtt.OnPublishedWithSharedFilters,
mqtt.OnWillSent,
}, []byte{b})
}
Expand All @@ -54,23 +54,34 @@ func (h *MqttEventHook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet)
if cl.InheritWay != mqtt.InheritWayRemote {
return
}
h.agent.SubmitOutTask(&pk)
if pk.Connect.ClientIdentifier == "" && cl != nil {
pk.Connect.ClientIdentifier = cl.ID
}
h.agent.SubmitOutConnectTask(&pk)
}

// OnPublished is called when a client has published a message to subscribers.
func (h *MqttEventHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
//func (h *MqttEventHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
// if pk.Connect.ClientIdentifier == "" && cl != nil {
// pk.Connect.ClientIdentifier = cl.ID
// }
// h.agent.SubmitOutTask(&pk)
//}

// OnPublishedWithSharedFilters is called when a client has published a message to cluster.
func (h *MqttEventHook) OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool) {
if pk.Connect.ClientIdentifier == "" {
pk.Connect.ClientIdentifier = cl.ID
pk.Connect.ClientIdentifier = pk.Origin
}
h.agent.SubmitOutTask(&pk)
h.agent.SubmitOutPublishTask(&pk, sharedFilters)
}

// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *MqttEventHook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
if pk.Connect.ClientIdentifier == "" {
pk.Connect.ClientIdentifier = cl.ID
}
h.agent.SubmitOutTask(&pk)
h.agent.SubmitOutPublishTask(&pk, nil)
}

// OnSubscribed is called when a client subscribes to one or more filters.
Expand Down
86 changes: 85 additions & 1 deletion cluster/raft/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

package raft

import "github.com/wind-c/comqtt/v2/cluster/message"
import (
"github.com/wind-c/comqtt/v2/cluster/message"
"github.com/wind-c/comqtt/v2/cluster/utils"
"sync"
)

type IPeer interface {
Join(nodeID, addr string) error
Expand All @@ -16,3 +20,83 @@ type IPeer interface {
GenPeersFile(file string) error
Stop()
}

type data map[string][]string

type KV struct {
data data
sync.RWMutex
}

func NewKV() *KV {
return &KV{
data: make(map[string][]string),
}
}

func (k *KV) GetAll() *data {
return &k.data
}

func (k *KV) Get(key string) []string {
k.RLock()
defer k.RUnlock()
vs := k.data[key]
return vs
}

// Add return true if key is set for the first time
func (k *KV) Add(key, value string) (new bool) {
k.Lock()
defer k.Unlock()
if vs, ok := k.data[key]; ok {
if utils.Contains(vs, value) {
return
}
k.data[key] = append(vs, value)
} else {
k.data[key] = []string{value}
new = true
}
return
}

// Del return true if the array corresponding to key is deleted
func (k *KV) Del(key, value string) (empty bool) {
k.Lock()
defer k.Unlock()
if vs, ok := k.data[key]; ok {
if utils.Contains(vs, value) {
for i, item := range vs {
if item == value {
k.data[key] = append(vs[:i], vs[i+1:]...)
}
}
}

if value == "" || len(vs) == 0 {
delete(k.data, key)
empty = true
}
}
return
}

func (k *KV) DelByValue(value string) int {
k.Lock()
defer k.Unlock()
c := 0
for f, vs := range k.data {
for i, v := range vs {
if v == value {
if len(vs) == 1 {
delete(k.data, f)
} else {
k.data[f] = append(vs[:i], vs[i+1:]...)
}
c++
}
}
}
return c
}
Loading

0 comments on commit 1a9fa37

Please sign in to comment.