From 1a9fa375e5fbc198a9f4475f2e3c88bfa25f5574 Mon Sep 17 00:00:00 2001
From: Wind <573966@qq.com>
Date: Wed, 13 Dec 2023 20:45:12 +0800
Subject: [PATCH] The cluster supports shared subscriptions

1. The cluster supports shared subscriptions.
2. Log display and other logic optimizations.
---
 cluster/agent.go               | 149 ++++++++++++++++++++++-----------
 cluster/events.go              |  23 +++--
 cluster/raft/base.go           |  86 ++++++++++++++++++-
 cluster/raft/etcd/kvstore.go   | 102 ++++++----------------
 cluster/raft/etcd/peer.go      |   6 +-
 cluster/raft/hashicorp/fsm.go  | 124 ++++++---------------------
 cluster/raft/hashicorp/peer.go |   2 +-
 cluster/topics/trie.go         | 129 +++++++++++++++++++---------
 mqtt/hooks.go                  |  14 ++++
 mqtt/server.go                 |  47 +++++++++--
 mqtt/topics.go                 |  13 ++-
 plugin/bridge/kafka/kafka.go   |   2 +
 12 files changed, 415 insertions(+), 282 deletions(-)

diff --git a/cluster/agent.go b/cluster/agent.go
index 31c62f1..f482fb9 100644
--- a/cluster/agent.go
+++ b/cluster/agent.go
@@ -7,10 +7,12 @@ package cluster
 import (
 	"bytes"
 	"context"
+	"math/rand"
 	"net"
 	"path"
 	"path/filepath"
 	"strconv"
+	"strings"
 
 	"github.com/panjf2000/ants/v2"
 	"github.com/wind-c/comqtt/v2/cluster/discovery"
@@ -156,18 +158,6 @@ func (a *Agent) initPool() error {
 	return nil
 }
 
-func (a *Agent) SubmitOutTask(pk *packets.Packet) {
-	a.OutPool.Submit(func() {
-		a.processOutboundPacket(pk)
-	})
-}
-
-func (a *Agent) SubmitRaftTask(msg *message.Message) {
-	a.raftPool.Submit(func() {
-		a.raftPropose(msg)
-	})
-}
-
 func (a *Agent) Stop() {
 	a.cancel()
 	a.OutPool.Release()
@@ -264,7 +254,6 @@ func (a *Agent) raftApplyListener() {
 				} else {
 					continue
 				}
-				log.Info("apply listening", "from", msg.NodeID, "filter", filter, "type", msg.Type)
 			case <-a.ctx.Done():
 				return
 			}
@@ -275,7 +264,7 @@ func (a *Agent) raftApplyListener() {
 func (a *Agent) raftPropose(msg *message.Message) {
 	if a.raftPeer.IsApplyRight() {
 		err := a.raftPeer.Propose(msg)
-		OnApplyLog(a.GetLocalName(), msg.NodeID, msg.Type, msg.Payload, "apply raft log", err)
+		OnApplyLog(a.GetLocalName(), msg.NodeID, msg.Type, msg.Payload, "raft apply log", err)
 	} else { //send to leader apply
 		_, leaderId := a.raftPeer.GetLeader()
 		if leaderId == "" {
@@ -284,14 +273,14 @@ func (a *Agent) raftPropose(msg *message.Message) {
 			} else {
 				a.membership.SendToOthers(msg.MsgpackBytes())
 			}
-			OnApplyLog("unknown", msg.NodeID, msg.Type, msg.Payload, "broadcast raft log", nil)
+			OnApplyLog("unknown", msg.NodeID, msg.Type, msg.Payload, "raft broadcast log", nil)
 		} else {
 			if a.Config.GrpcEnable {
 				a.grpcClientManager.RelayRaftApply(leaderId, msg)
 			} else {
 				a.membership.SendToNode(leaderId, msg.MsgpackBytes())
 			}
-			OnApplyLog(leaderId, msg.NodeID, msg.Type, msg.Payload, "forward raft log", nil)
+			OnApplyLog(leaderId, msg.NodeID, msg.Type, msg.Payload, "raft forward log", nil)
 		}
 	}
 }
@@ -362,7 +351,7 @@ func (a *Agent) processRelayMsg(msg *message.Message) {
 		}
 		offset := len(msg.Payload) - pk.FixedHeader.Remaining // Unpack fixedheader.
 		if err := pk.PublishDecode(msg.Payload[offset:]); err == nil { // Unpack skips fixedheader
-			a.mqttServer.PublishToSubscribers(pk)
+			a.mqttServer.PublishToSubscribers(pk, false)
 			OnPublishPacketLog(DirectionInbound, msg.NodeID, msg.ClientID, pk.TopicName, pk.PacketID)
 		}
 	case packets.Connect:
@@ -413,51 +402,111 @@ func (a *Agent) processInboundMsg() {
 	}
 }
 
-// processOutboundPacket process outbound msg
-func (a *Agent) processOutboundPacket(pk *packets.Packet) {
+func (a *Agent) SubmitOutPublishTask(pk *packets.Packet, sharedFilters map[string]bool) {
+	a.OutPool.Submit(func() {
+		a.processOutboundPublish(pk, sharedFilters)
+	})
+}
+
+func (a *Agent) SubmitOutConnectTask(pk *packets.Packet) {
+	a.OutPool.Submit(func() {
+		a.processOutboundConnect(pk)
+	})
+}
+
+func (a *Agent) SubmitRaftTask(msg *message.Message) {
+	a.raftPool.Submit(func() {
+		a.raftPropose(msg)
+	})
+}
+
+// processOutboundPublish processes an outbound publish message
+func (a *Agent) processOutboundPublish(pk *packets.Packet, sharedFilters map[string]bool) {
 	msg := message.Message{
 		NodeID:          a.Config.NodeName,
 		ClientID:        pk.Origin,
 		ProtocolVersion: pk.ProtocolVersion,
 	}
-	switch pk.FixedHeader.Type {
-	case packets.Publish:
-		var buf bytes.Buffer
-		pk.Mods.AllowResponseInfo = true
-		if err := pk.PublishEncode(&buf); err != nil {
-			return
+
+	var buf bytes.Buffer
+	pk.Mods.AllowResponseInfo = true
+	if err := pk.PublishEncode(&buf); err != nil {
+		return
+	}
+	msg.Type = packets.Publish
+	msg.Payload = buf.Bytes()
+	tmpFilters := a.subTree.Scan(pk.TopicName, make([]string, 0))
+	oldNodes := make([]string, 0)
+	filters := make([]string, 0)
+	for _, filter := range tmpFilters {
+		if !utils.Contains(filters, filter) {
+			filters = append(filters, filter)
 		}
-		msg.Type = packets.Publish
-		msg.Payload = buf.Bytes()
-		filters := a.subTree.Scan(pk.TopicName, make([]string, 0))
-		oldNodes := make([]string, 0)
-		for _, filter := range filters {
-			ns := a.raftPeer.Lookup(filter)
-			for _, node := range ns {
-				if node != a.GetLocalName() && !utils.Contains(oldNodes, node) {
-					if a.Config.GrpcEnable {
-						a.grpcClientManager.RelayPublishPacket(node, &msg)
-					} else {
-						bs := msg.MsgpackBytes()
-						a.membership.SendToNode(node, bs)
-					}
-					oldNodes = append(oldNodes, node)
-					OnPublishPacketLog(DirectionOutbound, node, pk.Origin, pk.TopicName, pk.PacketID)
+	}
+	for _, filter := range filters {
+		ns := a.pickNodes(filter, sharedFilters)
+		for _, node := range ns {
+			if node != a.GetLocalName() && !utils.Contains(oldNodes, node) {
+				if a.Config.GrpcEnable {
+					a.grpcClientManager.RelayPublishPacket(node, &msg)
+				} else {
+					bs := msg.MsgpackBytes()
+					a.membership.SendToNode(node, bs)
 				}
+				oldNodes = append(oldNodes, node)
+				OnPublishPacketLog(DirectionOutbound, node, pk.Origin, pk.TopicName, pk.PacketID)
 			}
 		}
-	case packets.Connect:
-		msg.Type = packets.Connect
-		if msg.ClientID == "" {
-			msg.ClientID = pk.Connect.ClientIdentifier
+	}
+}
+
+// processOutboundConnect processes an outbound connect message
+func (a *Agent) processOutboundConnect(pk *packets.Packet) {
+	msg := message.Message{
+		NodeID:          a.Config.NodeName,
+		ClientID:        pk.Origin,
+		ProtocolVersion: pk.ProtocolVersion,
+	}
+
+	msg.Type = packets.Connect
+	if msg.ClientID == "" {
+		msg.ClientID = pk.Connect.ClientIdentifier
+	}
+	if a.Config.GrpcEnable {
+		a.grpcClientManager.ConnectNotifyToOthers(&msg)
+	} else {
+		a.membership.SendToOthers(msg.MsgpackBytes())
+	}
+	OnConnectPacketLog(DirectionOutbound, a.GetLocalName(), msg.ClientID)
+}
+
+// pickNodes selects the target nodes; if the filter is shared, a single node is chosen at random
+func (a *Agent) pickNodes(filter string, sharedFilters map[string]bool) (ns []string) {
+	tmpNs := a.raftPeer.Lookup(filter)
+	if len(tmpNs) == 0 {
+		return ns
+	}
+
+	if strings.HasPrefix(filter, topics.SharePrefix) {
+		if b, ok := sharedFilters[filter]; ok && b {
+			return ns
 		}
-		if a.Config.GrpcEnable {
-			a.grpcClientManager.ConnectNotifyToOthers(&msg)
-		} else {
-			a.membership.SendToOthers(msg.MsgpackBytes())
+
+		for _, n := range tmpNs {
+			// Shared subscriptions give local delivery priority; if the local node is a candidate, the message has already been sent
+			if n == a.GetLocalName() {
+				return ns
+			}
 		}
-		OnConnectPacketLog(DirectionOutbound, a.GetLocalName(), msg.ClientID)
+		// Otherwise pick one node at random for the shared subscription
+		n := tmpNs[rand.Intn(len(tmpNs))]
+		ns = []string{n}
+		return ns
 	}
+
+	// Non-shared subscriptions return all matching nodes as is
+	ns = tmpNs
+	return
 }
 
 func OnJoinLog(nodeId, addr, prompt string, err error) {
diff --git a/cluster/events.go b/cluster/events.go
index b1fc3c4..c82c7f7 100644
--- a/cluster/events.go
+++ b/cluster/events.go
@@ -30,7 +30,7 @@ func (h *MqttEventHook) Provides(b byte) bool {
 		mqtt.OnSessionEstablished,
 		mqtt.OnSubscribed,
 		mqtt.OnUnsubscribed,
-		mqtt.OnPublished,
+		mqtt.OnPublishedWithSharedFilters,
 		mqtt.OnWillSent,
 	}, []byte{b})
 }
@@ -54,15 +54,26 @@ func (h *MqttEventHook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet)
 	if cl.InheritWay != mqtt.InheritWayRemote {
 		return
 	}
-	h.agent.SubmitOutTask(&pk)
+	if pk.Connect.ClientIdentifier == "" && cl != nil {
+		pk.Connect.ClientIdentifier = cl.ID
+	}
+	h.agent.SubmitOutConnectTask(&pk)
 }
 
 // OnPublished is called when a client has published a message to subscribers.
-func (h *MqttEventHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
+//func (h *MqttEventHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
+//	if pk.Connect.ClientIdentifier == "" && cl != nil {
+//		pk.Connect.ClientIdentifier = cl.ID
+//	}
+//	h.agent.SubmitOutTask(&pk)
+//}
+
+// OnPublishedWithSharedFilters is called when a client has published a message to the cluster.
+func (h *MqttEventHook) OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool) {
 	if pk.Connect.ClientIdentifier == "" {
-		pk.Connect.ClientIdentifier = cl.ID
+		pk.Connect.ClientIdentifier = pk.Origin
 	}
-	h.agent.SubmitOutTask(&pk)
+	h.agent.SubmitOutPublishTask(&pk, sharedFilters)
 }
 
 // OnWillSent is called when an LWT message has been issued from a disconnecting client.
@@ -70,7 +81,7 @@ func (h *MqttEventHook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
 	if pk.Connect.ClientIdentifier == "" {
 		pk.Connect.ClientIdentifier = cl.ID
 	}
-	h.agent.SubmitOutTask(&pk)
+	h.agent.SubmitOutPublishTask(&pk, nil)
 }
 
 // OnSubscribed is called when a client subscribes to one or more filters.
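As a reading aid for the routing change above, here is a self-contained Go sketch of the same selection rules as pickNodes. It is an illustration, not the patch's code: the node names and filters are made up, sharedDone stands in for the server-produced sharedFilters map, and no comqtt packages are imported.

package main

import (
	"fmt"
	"math/rand"
	"strings"
)

const sharePrefix = "$share/" // mirrors topics.SharePrefix

// pickNodes mirrors the policy in cluster/agent.go: non-shared filters fan out
// to every candidate node; shared filters resolve to at most one node,
// preferring local delivery.
func pickNodes(filter, localNode string, candidates []string, sharedDone map[string]bool) []string {
	if len(candidates) == 0 {
		return nil
	}
	if strings.HasPrefix(filter, sharePrefix) {
		if sharedDone[filter] { // a local group member already received it
			return nil
		}
		for _, n := range candidates {
			if n == localNode { // local priority: nothing to relay
				return nil
			}
		}
		return []string{candidates[rand.Intn(len(candidates))]}
	}
	return candidates
}

func main() {
	nodes := []string{"node2", "node3"}
	fmt.Println(pickNodes("sensor/+", "node1", nodes, nil))                         // both nodes
	fmt.Println(pickNodes("$share/g1/sensor/+", "node1", nodes, map[string]bool{})) // one random node
	fmt.Println(pickNodes("$share/g1/sensor/+", "node1", nodes,
		map[string]bool{"$share/g1/sensor/+": true})) // nil: already delivered locally
}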
diff --git a/cluster/raft/base.go b/cluster/raft/base.go
index fc59f60..ca28ec1 100644
--- a/cluster/raft/base.go
+++ b/cluster/raft/base.go
@@ -4,7 +4,11 @@
 
 package raft
 
-import "github.com/wind-c/comqtt/v2/cluster/message"
+import (
+	"github.com/wind-c/comqtt/v2/cluster/message"
+	"github.com/wind-c/comqtt/v2/cluster/utils"
+	"sync"
+)
 
 type IPeer interface {
 	Join(nodeID, addr string) error
@@ -16,3 +20,83 @@ type IPeer interface {
 	GenPeersFile(file string) error
 	Stop()
 }
+
+type data map[string][]string
+
+type KV struct {
+	data data
+	sync.RWMutex
+}
+
+func NewKV() *KV {
+	return &KV{
+		data: make(map[string][]string),
+	}
+}
+
+func (k *KV) GetAll() *data {
+	return &k.data
+}
+
+func (k *KV) Get(key string) []string {
+	k.RLock()
+	defer k.RUnlock()
+	vs := k.data[key]
+	return vs
+}
+
+// Add returns true if the key is set for the first time
+func (k *KV) Add(key, value string) (new bool) {
+	k.Lock()
+	defer k.Unlock()
+	if vs, ok := k.data[key]; ok {
+		if utils.Contains(vs, value) {
+			return
+		}
+		k.data[key] = append(vs, value)
+	} else {
+		k.data[key] = []string{value}
+		new = true
+	}
+	return
+}
+
+// Del returns true if the value list for the key is removed entirely
+func (k *KV) Del(key, value string) (empty bool) {
+	k.Lock()
+	defer k.Unlock()
+	if vs, ok := k.data[key]; ok {
+		if utils.Contains(vs, value) {
+			for i, item := range vs {
+				if item == value {
+					k.data[key] = append(vs[:i], vs[i+1:]...)
+					break
+				}
+			}
+		}
+
+		if value == "" || len(k.data[key]) == 0 {
+			delete(k.data, key)
+			empty = true
+		}
+	}
+	return
+}
+
+func (k *KV) DelByValue(value string) int {
+	k.Lock()
+	defer k.Unlock()
+	c := 0
+	for f, vs := range k.data {
+		for i, v := range vs {
+			if v == value {
+				if len(vs) == 1 {
+					delete(k.data, f)
+				} else {
+					k.data[f] = append(vs[:i], vs[i+1:]...)
+				}
+				c++
+				break
+			}
+		}
+	}
+	return c
+}
diff --git a/cluster/raft/etcd/kvstore.go b/cluster/raft/etcd/kvstore.go
index e2864dc..09b42b5 100644
--- a/cluster/raft/etcd/kvstore.go
+++ b/cluster/raft/etcd/kvstore.go
@@ -7,19 +7,18 @@ package etcd
 import (
 	"bytes"
 	"encoding/gob"
-	"sync"
-
 	"github.com/wind-c/comqtt/v2/cluster/log"
 	"github.com/wind-c/comqtt/v2/cluster/message"
+	base "github.com/wind-c/comqtt/v2/cluster/raft"
 	"github.com/wind-c/comqtt/v2/mqtt/packets"
 	"go.etcd.io/etcd/raft/v3/raftpb"
 	"go.etcd.io/etcd/server/v3/etcdserver/api/snap"
+	"strings"
 )
 
 // KVStore is a key-value store backed by raft
 type KVStore struct {
-	mu          sync.RWMutex
-	data        map[string][]string // current committed key-value pairs
+	*base.KV
 	snapshotter *snap.Snapshotter
 	commitC     <-chan *commit
 	errorC      <-chan error
@@ -28,7 +27,7 @@ type KVStore struct {
 
 func newKVStore(snapshotter *snap.Snapshotter, commitC <-chan *commit, errorC <-chan error, notifyCh chan<- *message.Message) *KVStore {
 	s := &KVStore{
-		data:        make(map[string][]string),
+		KV:          base.NewKV(),
 		snapshotter: snapshotter,
 		commitC:     commitC,
 		errorC:      errorC,
@@ -49,62 +48,12 @@ func newKVStore(snapshotter *snap.Snapshotter, commitC <-chan *commit, errorC <-chan error, notifyCh chan<- *message.Message) *KVStore {
 	return s
 }
 
-func (s *KVStore) Lookup(key string) ([]string, bool) {
-	s.mu.RLock()
-	defer s.mu.RUnlock()
-	v, ok := s.data[key]
-	return v, ok
+func (s *KVStore) Lookup(key string) []string {
+	return s.Get(key)
 }
 
-func (s *KVStore) Del(key, value string) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	if vs, ok := s.data[key]; ok {
-		if value == "" || len(vs) == 1 {
-			delete(s.data, key)
-			return
-		}
-
-		for i, item := range vs {
-			if item == value {
-				s.data[key] = append(vs[:i], vs[i+1:]...)
-			}
-		}
-	}
-}
-
-func (d *KVStore) DelByValue(value string) int {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-	c := 0
-	for k, vs := range d.data {
-		for i, v := range vs {
-			if v == value {
-				if len(vs) == 1 {
-					delete(d.data, k)
-				} else {
-					d.data[k] = append(vs[:i], vs[i+1:]...)
-				}
-				c++
-			}
-		}
-	}
-	return c
-}
-
-func (s *KVStore) Add(key, value string) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	if vs, ok := s.data[key]; ok {
-		for _, item := range vs {
-			if item == value {
-				return
-			}
-		}
-		s.data[key] = append(vs, value)
-	} else {
-		s.data[key] = []string{value}
-	}
+func (s *KVStore) DelByNode(node string) int {
+	return s.DelByValue(node)
 }
 
 func (s *KVStore) GetErrorC(key, value string) <-chan error {
@@ -133,15 +82,19 @@ func (s *KVStore) readCommits() {
 			if err := msg.MsgpackLoad(data); err != nil {
 				continue
 			}
+			filter := string(msg.Payload)
+			deliverable := false
 			if msg.Type == packets.Subscribe {
-				s.Add(string(msg.Payload), msg.NodeID)
+				deliverable = s.Add(filter, msg.NodeID)
 			} else if msg.Type == packets.Unsubscribe {
-				s.Del(string(msg.Payload), msg.NodeID)
+				deliverable = s.Del(filter, msg.NodeID)
 			} else {
 				continue
 			}
-
-			s.notifyCh <- &msg
+			log.Info("raft apply", "from", msg.NodeID, "filter", filter, "type", msg.Type)
+			if s.notifyCh != nil && deliverable {
+				s.notifyCh <- &msg
+			}
 		}
 		close(commit.applyDoneC)
 	}
@@ -151,10 +104,8 @@ func (s *KVStore) readCommits() {
 }
 
 func (s *KVStore) getSnapshot() ([]byte, error) {
-	s.mu.RLock()
-	defer s.mu.RUnlock()
 	var buffer bytes.Buffer
-	if err := gob.NewEncoder(&buffer).Encode(s.data); err != nil {
+	if err := gob.NewEncoder(&buffer).Encode(s.GetAll()); err != nil {
 		return nil, err
 	}
 	return buffer.Bytes(), nil
@@ -173,10 +124,8 @@ func (s *KVStore) loadSnapshot() (*raftpb.Snapshot, error) {
 }
 
 func (s *KVStore) recoverFromSnapshot(snapshot []byte) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	buffer := bytes.NewBuffer(snapshot)
-	if err := gob.NewDecoder(buffer).Decode(&s.data); err != nil {
+	if err := gob.NewDecoder(buffer).Decode(s.GetAll()); err != nil {
 		return err
 	}
 	s.notifyReplay()
@@ -184,14 +133,13 @@ func (s *KVStore) recoverFromSnapshot(snapshot []byte) error {
 }
 
 func (s *KVStore) notifyReplay() {
-	for filter, ns := range s.data {
-		for _, nodeId := range ns {
-			msg := message.Message{
-				Type:    packets.Subscribe,
-				NodeID:  nodeId,
-				Payload: []byte(filter),
-			}
-			s.notifyCh <- &msg
+	for filter, ns := range *s.GetAll() {
+		msg := message.Message{
+			Type:    packets.Subscribe,
+			NodeID:  strings.Join(ns, ","),
+			Payload: []byte(filter),
 		}
+		s.notifyCh <- &msg
+		log.Info("raft replay", "from", msg.NodeID, "filter", filter, "type", msg.Type)
 	}
 }
diff --git a/cluster/raft/etcd/peer.go b/cluster/raft/etcd/peer.go
index eefc37e..57d16ca 100644
--- a/cluster/raft/etcd/peer.go
+++ b/cluster/raft/etcd/peer.go
@@ -173,10 +173,14 @@ func (p *Peer) Propose(msg *message.Message) error {
 }
 
 func (p *Peer) Lookup(key string) []string {
-	rs, _ := p.kvStore.Lookup(key)
+	rs := p.kvStore.Lookup(key)
 	return rs
 }
 
+func (p *Peer) DelByNode(node string) int {
+	return p.kvStore.DelByNode(node)
+}
+
 func getZapLogger() *zap.Logger {
 	encoderCfg := zapcore.EncoderConfig{
 		MessageKey: "msg",
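Both raft backends (etcd above, hashicorp next) now delegate storage to the shared base.KV, whose return values gate cluster notifications: an apply event is forwarded only when a filter gains its first node or loses its last one. A minimal, lock-free sketch of that contract — the trimmed-down kv type below is an illustrative stand-in, not the patched type:

package main

import "fmt"

// kv is a non-concurrent stand-in for raft.KV.
type kv map[string][]string

// add reports true only when the key is created (first subscriber node).
func (k kv) add(key, value string) bool {
	for _, v := range k[key] {
		if v == value {
			return false
		}
	}
	_, existed := k[key]
	k[key] = append(k[key], value)
	return !existed
}

// del reports true only when the key's node list becomes empty.
func (k kv) del(key, value string) bool {
	vs := k[key]
	for i, v := range vs {
		if v == value {
			k[key] = append(vs[:i], vs[i+1:]...)
			break
		}
	}
	if len(k[key]) == 0 {
		delete(k, key)
		return true
	}
	return false
}

func main() {
	store := kv{}
	fmt.Println(store.add("sensor/+", "node1")) // true: first node, notify the subscription tree
	fmt.Println(store.add("sensor/+", "node2")) // false: filter already known to the cluster
	fmt.Println(store.del("sensor/+", "node1")) // false: node2 still holds the filter
	fmt.Println(store.del("sensor/+", "node2")) // true: last node gone, notify removal
}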
"github.com/wind-c/comqtt/v2/cluster/message" + base "github.com/wind-c/comqtt/v2/cluster/raft" "github.com/wind-c/comqtt/v2/mqtt/packets" - - "github.com/hashicorp/raft" + "io" + "strings" ) type Fsm struct { - kv KVStore + *base.KV notifyCh chan<- *message.Message } func NewFsm(notifyCh chan<- *message.Message) *Fsm { fsm := &Fsm{ - kv: NewKVStore(), + KV: base.NewKV(), notifyCh: notifyCh, } return fsm @@ -35,34 +35,36 @@ func (f *Fsm) Apply(l *raft.Log) interface{} { return nil } filter := string(msg.Payload) + deliverable := false if msg.Type == packets.Subscribe { - f.kv.Set(filter, msg.NodeID) + deliverable = f.Add(filter, msg.NodeID) } else if msg.Type == packets.Unsubscribe { - f.kv.Del(filter, msg.NodeID) + deliverable = f.Del(filter, msg.NodeID) } else { return nil } - if f.notifyCh != nil { + log.Info("raft apply", "from", msg.NodeID, "filter", filter, "type", msg.Type) + if f.notifyCh != nil && deliverable { f.notifyCh <- &msg } return nil } -func (f *Fsm) Search(key string) []string { - return f.kv.Get(key) +func (f *Fsm) Lookup(key string) []string { + return f.Get(key) } func (f *Fsm) DelByNode(node string) int { - return f.kv.DelByValue(node) + return f.DelByValue(node) } func (f *Fsm) Snapshot() (raft.FSMSnapshot, error) { - return &f.kv, nil + return f, nil } func (f *Fsm) Restore(ir io.ReadCloser) error { - if err := gob.NewDecoder(ir).Decode(&f.kv.Data); err != nil { + if err := gob.NewDecoder(ir).Decode(f.GetAll()); err != nil { return err } f.notifyReplay() @@ -70,92 +72,20 @@ func (f *Fsm) Restore(ir io.ReadCloser) error { } func (f *Fsm) notifyReplay() { - for filter, ns := range f.kv.Data { - for _, nodeId := range ns { - msg := message.Message{ - Type: packets.Subscribe, - NodeID: nodeId, - Payload: []byte(filter), - } - f.notifyCh <- &msg - } - } -} - -type KVStore struct { - Data map[string][]string - mu sync.RWMutex -} - -func NewKVStore() KVStore { - return KVStore{ - Data: make(map[string][]string), - } -} - -func (d *KVStore) Get(key string) []string { - d.mu.RLock() - defer d.mu.RUnlock() - vs := d.Data[key] - return vs -} - -func (d *KVStore) Set(key, value string) { - d.mu.Lock() - defer d.mu.Unlock() - if vs, ok := d.Data[key]; ok { - for _, item := range vs { - if item == value { - return - } - } - d.Data[key] = append(vs, value) - } else { - d.Data[key] = []string{value} - } -} - -func (d *KVStore) Del(key, value string) { - d.mu.Lock() - defer d.mu.Unlock() - if vs, ok := d.Data[key]; ok { - if value == "" || len(vs) == 1 { - delete(d.Data, key) - return - } - - for i, item := range vs { - if item == value { - d.Data[key] = append(vs[:i], vs[i+1:]...) - } - } - } -} - -func (d *KVStore) DelByValue(value string) int { - d.mu.Lock() - defer d.mu.Unlock() - c := 0 - for k, vs := range d.Data { - for i, v := range vs { - if v == value { - if len(vs) == 1 { - delete(d.Data, k) - } else { - d.Data[k] = append(vs[:i], vs[i+1:]...) 
- } - c++ - } + for filter, ns := range *f.GetAll() { + msg := message.Message{ + Type: packets.Subscribe, + NodeID: strings.Join(ns, ","), + Payload: []byte(filter), } + f.notifyCh <- &msg + log.Info("raft replay", "from", msg.NodeID, "filter", filter, "type", msg.Type) } - return c } -func (d *KVStore) Persist(sink raft.SnapshotSink) error { - d.mu.Lock() - defer d.mu.Unlock() +func (f *Fsm) Persist(sink raft.SnapshotSink) error { var buffer bytes.Buffer - err := gob.NewEncoder(&buffer).Encode(d.Data) + err := gob.NewEncoder(&buffer).Encode(f.GetAll()) if err != nil { return err } @@ -164,4 +94,4 @@ func (d *KVStore) Persist(sink raft.SnapshotSink) error { return nil } -func (d *KVStore) Release() {} +func (f *Fsm) Release() {} diff --git a/cluster/raft/hashicorp/peer.go b/cluster/raft/hashicorp/peer.go index 54ac4ca..de73189 100644 --- a/cluster/raft/hashicorp/peer.go +++ b/cluster/raft/hashicorp/peer.go @@ -216,7 +216,7 @@ func (p *Peer) Propose(msg *message.Message) error { } func (p *Peer) Lookup(key string) []string { - return p.fsm.Search(key) + return p.fsm.Lookup(key) } func (p *Peer) DelByNode(node string) int { diff --git a/cluster/topics/trie.go b/cluster/topics/trie.go index 76db927..9402761 100644 --- a/cluster/topics/trie.go +++ b/cluster/topics/trie.go @@ -9,6 +9,10 @@ import ( "sync" ) +const ( + SharePrefix = "$share/" // The lower prefix of the shared topic +) + // Subscriptions is a map of subscriptions keyed on client. type Subscriptions map[string]byte @@ -33,10 +37,7 @@ func New() *Index { func (x *Index) Subscribe(filter string) bool { x.mu.Lock() defer x.mu.Unlock() - n := x.poperate(filter) - n.Filter = filter - return n.Count > 0 } @@ -48,6 +49,41 @@ func (x *Index) Unsubscribe(filter string) bool { return x.unpoperate(filter) } +// poperate iterates and populates through a filter path, instantiating +// leaves as it goes and returning the final leaf in the branch. +// poperate is a more enjoyable word than iterpop. +func (x *Index) poperate(filter string) *Leaf { + var d int + var particle string + var hasNext = true + n := x.Root + group, filter := convertSharedFilter(filter) + for hasNext { + particle, hasNext = isolateParticle(filter, d) + d++ + + child, _ := n.Leaves[particle] + if child == nil { + child = &Leaf{ + Key: particle, + Parent: n, + Leaves: make(map[string]*Leaf), + Count: 0, + SharedGroups: make([]string, 0), + } + n.Leaves[particle] = child + } + n = child + } + n.Count++ + n.Filter = filter + if group != "" { + n.SharedGroups = append(n.SharedGroups, group) + } + + return n +} + // unpoperate steps backward through a trie sequence and removes any orphaned // nodes. If a client id is specified, it will unsubscribe a client. If message // is true, it will delete a retained message. @@ -56,6 +92,7 @@ func (x *Index) unpoperate(filter string) bool { var particle string var hasNext = true e := x.Root + group, filter := convertSharedFilter(filter) for hasNext { particle, hasNext = isolateParticle(filter, d) d++ @@ -80,6 +117,14 @@ func (x *Index) unpoperate(filter string) bool { if e.Count > 0 { e.Count-- } + if group != "" { + for i, v := range e.SharedGroups { + if v == group { + e.SharedGroups = append(e.SharedGroups[:i], e.SharedGroups[i+1:]...) + break + } + } + } end = false } @@ -98,35 +143,6 @@ func (x *Index) unpoperate(filter string) bool { return true } -// poperate iterates and populates through a topic/filter path, instantiating -// leaves as it goes and returning the final leaf in the branch. 
-// poperate is a more enjoyable word than iterpop. -func (x *Index) poperate(topic string) *Leaf { - var d int - var particle string - var hasNext = true - n := x.Root - for hasNext { - particle, hasNext = isolateParticle(topic, d) - d++ - - child, _ := n.Leaves[particle] - if child == nil { - child = &Leaf{ - Key: particle, - Parent: n, - Leaves: make(map[string]*Leaf), - Count: 0, - } - n.Leaves[particle] = child - } - n = child - } - n.Count++ - - return n -} - // Scan returns true if a matching filter exists func (x *Index) Scan(topic string, filters []string) []string { x.mu.RLock() @@ -136,11 +152,12 @@ func (x *Index) Scan(topic string, filters []string) []string { // Leaf is a child node on the tree. type Leaf struct { - Key string // the key that was used to create the leaf. - Parent *Leaf // a pointer to the parent node for the leaf. - Leaves map[string]*Leaf // a map of child nodes, keyed on particle id. - Filter string // the path of the topic filter being matched. - Count int + Key string // the key that was used to create the leaf. + Parent *Leaf // a pointer to the parent node for the leaf. + Leaves map[string]*Leaf // a map of child nodes, keyed on particle id. + Filter string // the path of the topic filter being matched. + Count int // the number of nodes subscribed to the topic. + SharedGroups []string // the shared topics of this leaf. } // scanSubscribers recursively steps through a branch of leaves finding clients who @@ -167,14 +184,26 @@ func (l *Leaf) scan(topic string, d int, filters []string) []string { if !hasNext || particle == "#" { // matching the topic. if child.Filter != "" { - filters = append(filters, child.Filter) + if child.Count > len(child.SharedGroups) { + filters = append(filters, child.Filter) + } + if len(child.SharedGroups) > 0 { + groups := restoreShareFilter(child.Filter, child.SharedGroups) + filters = append(filters, groups...) + } } // Make sure we also capture any client who are listening // to this topic via path/# if !hasNext { if extra, ok := child.Leaves["#"]; ok { - filters = append(filters, extra.Filter) + if extra.Count > len(extra.SharedGroups) { + filters = append(filters, extra.Filter) + } + if len(extra.SharedGroups) > 0 { + groups := restoreShareFilter(extra.Filter, extra.SharedGroups) + filters = append(filters, groups...) + } } } } @@ -211,6 +240,28 @@ func isolateParticle(filter string, d int) (particle string, hasNext bool) { return } +// convertSharedFilter converts a shared filter to a regular filter and a group. +func convertSharedFilter(srcFilter string) (group, destFilter string) { + if strings.HasPrefix(srcFilter, SharePrefix) { + prefixLen := len(SharePrefix) + end := strings.IndexRune(srcFilter[prefixLen:], '/') + group = srcFilter[prefixLen : end+prefixLen+1] + destFilter = srcFilter[end+prefixLen+1:] + } else { + destFilter = srcFilter + } + return +} + +// restoreShareFilter restores a filter to a shared filters. +func restoreShareFilter(filter string, sharedGroups []string) []string { + fs := make([]string, len(sharedGroups)) + for i, v := range sharedGroups { + fs[i] = SharePrefix + v + filter + } + return fs +} + // ReLeaf is a dev function for showing the trie leafs. 
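The trie stores a shared filter in its plain form and keeps the group names on the leaf. Here is a runnable round-trip sketch with the two helpers copied from the patch above; only main and its sample filters are invented. Note that group keeps its trailing slash, which is what makes the restore a simple concatenation:

package main

import (
	"fmt"
	"strings"
)

const SharePrefix = "$share/"

// convertSharedFilter splits "$share/<group>/<filter>" into the group
// (with its trailing slash) and the plain filter stored in the trie.
func convertSharedFilter(srcFilter string) (group, destFilter string) {
	if strings.HasPrefix(srcFilter, SharePrefix) {
		prefixLen := len(SharePrefix)
		end := strings.IndexRune(srcFilter[prefixLen:], '/')
		group = srcFilter[prefixLen : end+prefixLen+1]
		destFilter = srcFilter[end+prefixLen+1:]
	} else {
		destFilter = srcFilter
	}
	return
}

// restoreShareFilter rebuilds one shared filter per subscribed group.
func restoreShareFilter(filter string, sharedGroups []string) []string {
	fs := make([]string, len(sharedGroups))
	for i, v := range sharedGroups {
		fs[i] = SharePrefix + v + filter
	}
	return fs
}

func main() {
	group, filter := convertSharedFilter("$share/g1/sensor/+")
	fmt.Printf("group=%q filter=%q\n", group, filter) // group="g1/" filter="sensor/+"
	fmt.Println(restoreShareFilter(filter, []string{group, "g2/"}))
	// [$share/g1/sensor/+ $share/g2/sensor/+]
}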
diff --git a/mqtt/hooks.go b/mqtt/hooks.go
index 9746675..a607ada 100644
--- a/mqtt/hooks.go
+++ b/mqtt/hooks.go
@@ -50,6 +50,7 @@ const (
 	OnWillSent
 	OnClientExpired
 	OnRetainedExpired
+	OnPublishedWithSharedFilters
 	StoredClients
 	StoredSubscriptions
 	StoredInflightMessages
@@ -106,6 +107,7 @@ type Hook interface {
 	OnWillSent(cl *Client, pk packets.Packet)
 	OnClientExpired(cl *Client)
 	OnRetainedExpired(filter string)
+	OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool)
 	StoredClients() ([]storage.Client, error)
 	StoredSubscriptions() ([]storage.Subscription, error)
 	StoredInflightMessages() ([]storage.Message, error)
@@ -549,6 +551,15 @@ func (h *Hooks) OnRetainedExpired(filter string) {
 	}
 }
 
+// OnPublishedWithSharedFilters is called when a client has published a message to the cluster.
+func (h *Hooks) OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool) {
+	for _, hook := range h.GetAll() {
+		if hook.Provides(OnPublishedWithSharedFilters) {
+			hook.OnPublishedWithSharedFilters(pk, sharedFilters)
+		}
+	}
+}
+
 // StoredClients returns all clients, e.g. from a persistent store, is used to
 // populate the server clients list before start.
 func (h *Hooks) StoredClients() (v []storage.Client, err error) {
@@ -912,6 +923,9 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
 // OnRetainedExpired is called when a retained message for a topic has expired.
 func (h *HookBase) OnRetainedExpired(topic string) {}
 
+// OnPublishedWithSharedFilters is called when a client has published a message to the cluster.
+func (h *HookBase) OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool) {}
+
 // StoredClients returns all clients from a store.
 func (h *HookBase) StoredClients() (v []storage.Client, err error) {
 	return
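Plugins consume the new event like any other hook: embed mqtt.HookBase, advertise the id in Provides, and override the method. A minimal sketch under the hook API shown above — the hook name and log body are illustrative:

package hooks

import (
	"bytes"
	"fmt"

	mqtt "github.com/wind-c/comqtt/v2/mqtt"
	"github.com/wind-c/comqtt/v2/mqtt/packets"
)

// ClusterNotifyHook logs cluster publish events; HookBase supplies no-op
// defaults for every other hook method.
type ClusterNotifyHook struct {
	mqtt.HookBase
}

func (h *ClusterNotifyHook) ID() string { return "cluster-notify-example" }

func (h *ClusterNotifyHook) Provides(b byte) bool {
	return bytes.Contains([]byte{mqtt.OnPublishedWithSharedFilters}, []byte{b})
}

// OnPublishedWithSharedFilters receives the published packet plus a map from
// shared filter to whether a local group member already received the message.
func (h *ClusterNotifyHook) OnPublishedWithSharedFilters(pk packets.Packet, sharedFilters map[string]bool) {
	fmt.Printf("topic=%s shared=%v\n", pk.TopicName, sharedFilters)
}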
diff --git a/mqtt/server.go b/mqtt/server.go
index bf2fdf6..98c209d 100644
--- a/mqtt/server.go
+++ b/mqtt/server.go
@@ -14,6 +14,7 @@ import (
 	"runtime"
 	"sort"
 	"strconv"
+	"strings"
 	"sync/atomic"
 	"time"
 
@@ -265,7 +266,6 @@ func (s *Server) AddListener(l listeners.Listener) error {
 	}
 
 	s.Listeners.Add(l)
-	s.Log.Info("attached listener", "id", l.ID(), "protocol", l.Protocol(), "address", l.Address())
 
 	return nil
 }
@@ -872,7 +872,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
 	// When it publishes a package with a qos > 0, the server treats
 	// the package as qos=0, and the client receives it as qos=1 or 2.
 	if pk.FixedHeader.Qos == 0 || cl.Net.Inline {
-		s.PublishToSubscribers(pk)
+		s.publishToSubscribers(pk)
 		s.hooks.OnPublished(cl, pk)
 		return nil
 	}
@@ -901,7 +901,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
 		s.hooks.OnQosComplete(cl, ack)
 	}
 
-	s.PublishToSubscribers(pk)
+	s.publishToSubscribers(pk)
 	s.hooks.OnPublished(cl, pk)
 
 	return nil
@@ -921,7 +921,13 @@ func (s *Server) retainMessage(cl *Client, pk packets.Packet) {
 }
 
 // PublishToSubscribers publishes a publish packet to all subscribers with matching topic filters.
-func (s *Server) PublishToSubscribers(pk packets.Packet) {
+func (s *Server) publishToSubscribers(pk packets.Packet) {
+	s.PublishToSubscribers(pk, true)
+}
+
+// PublishToSubscribers publishes a publish packet to all subscribers with matching topic filters.
+// local: true indicates an in-process call; false indicates a packet forwarded from another node.
+func (s *Server) PublishToSubscribers(pk packets.Packet, local bool) {
 	if pk.Ignore {
 		return
 	}
@@ -935,13 +941,25 @@ func (s *Server) PublishToSubscribers(pk packets.Packet) {
 		pk.Expiry = pk.Created + int64(pk.Properties.MessageExpiryInterval)
 	}
 
+	sharedFilters := make(map[string]bool)
 	subscribers := s.Topics.Subscribers(pk.TopicName)
 	if len(subscribers.Shared) > 0 {
 		subscribers = s.hooks.OnSelectSubscribers(subscribers, pk)
 		if len(subscribers.SharedSelected) == 0 {
 			subscribers.SelectShared()
 		}
+
+		// record the selected shared subscription filter of each group
+		for _, sub := range subscribers.SharedSelected {
+			sharedFilters[sub.Filter] = false
+		}
+
 		subscribers.MergeSharedSelected()
+	} else {
+		// no shared subscriptions; notify the cluster directly
+		if !strings.HasPrefix(pk.TopicName, SysPrefix) && local {
+			s.hooks.OnPublishedWithSharedFilters(pk, sharedFilters)
+		}
 	}
 
 	for _, inlineSubscription := range subscribers.InlineSubscriptions {
@@ -950,12 +968,23 @@ func (s *Server) PublishToSubscribers(pk packets.Packet) {
 
 	for id, subs := range subscribers.Subscriptions {
 		if cl, ok := s.Clients.Get(id); ok {
-			_, err := s.publishToClient(cl, subs, pk)
-			if err != nil {
+			if _, err := s.publishToClient(cl, subs, pk); err != nil {
+				if strings.HasPrefix(subs.Filter, "$share") {
+					sharedFilters[subs.Filter] = false
+				}
 				s.Log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
+			} else {
+				if strings.HasPrefix(subs.Filter, "$share") {
+					sharedFilters[subs.Filter] = true
+				}
 			}
 		}
 	}
+
+	// report local shared-subscription delivery results to the cluster
+	if len(sharedFilters) > 0 && local {
+		s.hooks.OnPublishedWithSharedFilters(pk, sharedFilters)
+	}
 }
 
 func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
@@ -1423,7 +1452,7 @@ func (s *Server) publishSysTopics() {
 		pk.TopicName = topic
 		pk.Payload = []byte(payload)
 		s.Topics.RetainMessage(pk.Copy(false))
-		s.PublishToSubscribers(pk)
+		s.publishToSubscribers(pk)
 	}
 
 	s.hooks.OnSysInfoTick(s.Info)
@@ -1482,7 +1511,7 @@ func (s *Server) sendLWT(cl *Client) {
 		s.retainMessage(cl, pk)
 	}
 
-	s.PublishToSubscribers(pk) // [MQTT-3.1.2-8]
+	s.publishToSubscribers(pk) // [MQTT-3.1.2-8]
 	atomic.StoreUint32(&cl.Properties.Will.Flag, 0) // [MQTT-3.1.2-10]
 	s.hooks.OnWillSent(cl, pk)
 }
@@ -1666,7 +1695,7 @@ func (s *Server) clearExpiredInflights(now int64) {
 func (s *Server) sendDelayedLWT(dt int64) {
 	for id, pk := range s.loop.willDelayed.GetAll() {
 		if dt > pk.Expiry {
-			s.PublishToSubscribers(pk) // [MQTT-3.1.2-8]
+			s.publishToSubscribers(pk) // [MQTT-3.1.2-8]
 			if cl, ok := s.Clients.Get(id); ok {
 				if pk.FixedHeader.Retain {
 					s.retainMessage(cl, pk)
diff --git a/mqtt/topics.go b/mqtt/topics.go
index f9d122e..3a46f75 100644
--- a/mqtt/topics.go
+++ b/mqtt/topics.go
@@ -150,6 +150,17 @@ func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, o
 	return val, ok
 }
 
+// SubsInGroupLen returns the number of subscriptions in a shared subscription group.
+func (s *SharedSubscriptions) SubsInGroupLen(group string) int {
+	s.RLock()
+	defer s.RUnlock()
+	n := 0
+	if m, ok := s.internal[group]; ok {
+		n = len(m)
+	}
+	return n
+}
+
 // GroupLen returns the number of groups subscribed to the filter.
 func (s *SharedSubscriptions) GroupLen() int {
 	s.RLock()
@@ -410,7 +421,7 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
 		n := x.set(subscription.Filter, 2)
 		_, existed = n.shared.Get(group, client)
 		n.shared.Add(group, client, subscription)
-		count = n.shared.Len()
+		count = n.shared.SubsInGroupLen(group)
 	} else {
 		n := x.set(subscription.Filter, 0)
 		_, existed = n.subscriptions.Get(client)
diff --git a/plugin/bridge/kafka/kafka.go b/plugin/bridge/kafka/kafka.go
index fdaf052..be4b627 100644
--- a/plugin/bridge/kafka/kafka.go
+++ b/plugin/bridge/kafka/kafka.go
@@ -54,6 +54,7 @@ type Message struct {
 	ProtocolVersion byte     `json:"protocolVersion,omitempty"` // mqtt protocol version of the client
 	Clean           bool     `json:"clean,omitempty"`           // if the client requested a clean start/session
 	Timestamp       int64    `json:"ts"`                        // event time
+	PacketID        uint16   `json:"packetid,omitempty"`        // the packet id
 }
 
 // MarshalBinary encodes the values into a json string.
@@ -311,6 +312,7 @@ func (b *Bridge) OnPublished(cl *mqtt.Client, pk packets.Packet) {
 		Topics:    []string{pk.TopicName},
 		Payload:   pk.Payload,
 		Timestamp: timestamp,
+		PacketID:  pk.PacketID,
 	}
 	data, err := msg.MarshalBinary()
 	if err != nil {
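End to end, subscribers that join the same $share group through different cluster nodes now receive each matching message once per group. A client-side sketch using the Eclipse Paho Go client; broker addresses, client ids, and the topic are placeholders:

package main

import (
	"fmt"

	mqtt "github.com/eclipse/paho.mqtt.golang"
)

func main() {
	// Two subscribers in group g1, connected to different cluster nodes;
	// each message published to sensor/# is delivered to only one of them.
	for i, broker := range []string{"tcp://node1:1883", "tcp://node2:1883"} {
		opts := mqtt.NewClientOptions().AddBroker(broker).SetClientID(fmt.Sprintf("sub-%d", i))
		c := mqtt.NewClient(opts)
		if t := c.Connect(); t.Wait() && t.Error() != nil {
			panic(t.Error())
		}
		c.Subscribe("$share/g1/sensor/#", 1, func(_ mqtt.Client, m mqtt.Message) {
			fmt.Printf("got %s: %s\n", m.Topic(), m.Payload())
		})
	}
	select {} // block forever
}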