Skip to content

Commit

Permalink
Merge pull request #8 from jxsl13/feat/player-count-channel-notification
Browse files Browse the repository at this point in the history
add sql for fetching
  • Loading branch information
jxsl13 authored Jan 22, 2024
2 parents 602d655 + 6f06e56 commit 93556d3
Show file tree
Hide file tree
Showing 30 changed files with 909 additions and 744 deletions.
116 changes: 114 additions & 2 deletions bot/async.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package bot

import (
"fmt"
"log"
"time"

"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/jxsl13/twstatus-bot/model"
)

Expand Down Expand Up @@ -64,7 +67,116 @@ loop:
log.Printf("goroutine %d: closed async goroutine for message updates", id)
}

func (b *Bot) cacheCleanup() {
func (b *Bot) notificationUpdater(id int) {
log.Printf("goroutine %d starting async goroutine for channel notifications", id)

loop:
for {
select {
case <-b.ctx.Done():
break loop
case notification, ok := <-b.n:
if !ok {
break loop
}
err := b.updateChannelNotification(notification)
if err != nil {
b.l.Errorf("goroutine %0d: failed to update channel notification %v: %v", id, notification, err)
}

}
}

log.Printf("goroutine %d: closed async goroutine for channel notifications", id)
}

func (b *Bot) updateChannelNotification(n model.PlayerCountNotificationMessage) (err error) {
dao, closer, err := b.TxDAO(b.ctx)
if err != nil {
return fmt.Errorf("failed to get transaction queries for channel notification: %w", err)
}
defer func() {
err = closer(err)
}()

// remove previous notification message if exists
if n.PrevMessageID != 0 {
// check if message still exists
err := b.state.DeleteMessage(
n.ChannelTarget.ChannelID,
n.PrevMessageID,
api.AuditLogReason("removing previous channel notification message"),
)
if err != nil && !ErrIsNotFound(err) {
b.l.Errorf("failed to delete previous notification message %s: %v", n.MessageTarget(n.PrevMessageID), err)
err = nil
}

// cleanup database if notification was deleted by some user/admin
err = dao.RemovePlayerCountNotificationMessage(b.ctx, n.ChannelTarget.ChannelID, n.PrevMessageID)
if err != nil {
return fmt.Errorf("failed to remove previous channel notification message from database: %w", err)
}
}

// remove all requests from database
for _, umt := range n.RemoveUserMessageReactions {
err = dao.RemovePlayerCountNotificationRequest(b.ctx, model.PlayerCountNotificationRequest{
MessageUserTarget: model.MessageUserTarget{
UserID: umt.UserID,
MessageTarget: model.MessageTarget{
ChannelTarget: model.ChannelTarget{
GuildID: n.ChannelTarget.GuildID,
ChannelID: n.ChannelTarget.ChannelID,
},
MessageID: umt.MessageID,
},
},
Threshold: umt.Threshold,
})
if err != nil {
return fmt.Errorf("failed to remove player count notification request from database: %w", err)
}
}

// delete all reactions from specified messages
for _, mt := range n.RemoveMessageReactions {
err = b.state.DeleteReactions(n.ChannelID, mt.MessageID, mt.Reaction())
if err != nil && !ErrIsNotFound(err) {
b.l.Errorf("failed to delete reaction %s from message %s: %v", mt.Reaction(), n.MessageTarget(mt.MessageID), err)
err = nil
}
}

mentionUsers := n.UserIDs
if len(n.UserIDs) > 100 {
// we do not expect more than 100 users to be mentioned anyway
mentionUsers = n.UserIDs[:100]
}

// send new message
msg, err := b.state.SendMessageComplex(n.ChannelTarget.ChannelID, api.SendMessageData{
Content: n.Format(),
Flags: discord.SuppressEmbeds,
AllowedMentions: &api.AllowedMentions{
Users: mentionUsers,
},
})
if err != nil {
return err
}

// update database to contain latest notification message for the current channel
err = dao.AddPlayerCountNotificationMessage(b.ctx, msg.ChannelID, msg.ID)
if err != nil {
return err
}

return nil
}

func (b *Bot) cacheCleanup(id int) {
log.Printf("goroutine %d starting async goroutine for cache cleanup", id)
var (
cleanupInterval = 20 * b.pollingInterval
timer = time.NewTimer(cleanupInterval)
Expand Down Expand Up @@ -96,7 +208,7 @@ func (b *Bot) cacheCleanup() {
log.Printf("cache contains %d entries after cleanup at %s", b.conflictMap.Size(), now)

case <-b.ctx.Done():
log.Println("closed async goroutine for cache cleanup")
log.Printf("goroutine %d: closed async goroutine for cache cleanup", id)
return
}
}
Expand Down
64 changes: 23 additions & 41 deletions bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,6 @@ const (
channelOptionName = "channel"
)

var (
reactionPlayerCountNotificationMap = map[discord.APIEmoji]int{
discord.NewAPIEmoji(0, "1️⃣"): 1,
discord.NewAPIEmoji(0, "2️⃣"): 2,
discord.NewAPIEmoji(0, "3️⃣"): 3,
discord.NewAPIEmoji(0, "4️⃣"): 4,
discord.NewAPIEmoji(0, "5️⃣"): 5,
discord.NewAPIEmoji(0, "6️⃣"): 6,
discord.NewAPIEmoji(0, "7️⃣"): 7,
discord.NewAPIEmoji(0, "8️⃣"): 8,
discord.NewAPIEmoji(0, "9️⃣"): 9,
discord.NewAPIEmoji(0, "🔟"): 10,
}
reactionPlayerCountNotificationReverseMap = map[int]discord.APIEmoji{
1: discord.NewAPIEmoji(0, "1️⃣"),
2: discord.NewAPIEmoji(0, "2️⃣"),
3: discord.NewAPIEmoji(0, "3️⃣"),
4: discord.NewAPIEmoji(0, "4️⃣"),
5: discord.NewAPIEmoji(0, "5️⃣"),
6: discord.NewAPIEmoji(0, "6️⃣"),
7: discord.NewAPIEmoji(0, "7️⃣"),
8: discord.NewAPIEmoji(0, "8️⃣"),
9: discord.NewAPIEmoji(0, "9️⃣"),
10: discord.NewAPIEmoji(0, "🔟"),
}
)

var ownerCommandList = []api.CreateCommandData{
{
Name: "list-guilds",
Expand Down Expand Up @@ -287,6 +260,7 @@ type Bot struct {
channelID discord.ChannelID
userID discord.UserID
c chan model.ChangedServerStatus
n chan model.PlayerCountNotificationMessage
pollingInterval time.Duration
conflictMap *xsync.MapOf[model.MessageTarget, Backoff]
l *logging.Logger
Expand Down Expand Up @@ -318,6 +292,7 @@ func New(
superAdmins: superAdmins,
useEmbeds: !legacyMessageFormat,
c: make(chan model.ChangedServerStatus, 1024),
n: make(chan model.PlayerCountNotificationMessage, 1024),
conflictMap: xsync.NewMapOf[model.MessageTarget, Backoff](),
pollingInterval: pollingInterval,
guildID: guildID,
Expand Down Expand Up @@ -357,11 +332,19 @@ func New(
log.Fatalf("failed to synchronize database with discord state: %v", err)
}

routines := 1

// start polling
go bot.cacheCleanup()
go bot.cacheCleanup(routines)
go bot.serverUpdater(pollingInterval)
for i := 0; i < max(2*runtime.NumCPU(), 5); i++ {
go bot.messageUpdater(i + 1)
routines++
go bot.messageUpdater(routines)
}

for i := 0; i < max(runtime.NumCPU(), 3); i++ {
routines++
go bot.notificationUpdater(routines)
}
})
})
Expand Down Expand Up @@ -452,15 +435,15 @@ func (b *Bot) TxDAO(ctx context.Context) (d *dao.DAO, closer func(error) error,
if err != nil {
return nil, nil, err
}
return dao.NewDAO(sqlc.New(tx)), closer, nil
return dao.NewDAO(sqlc.New(tx), b.l), closer, nil
}

func (b *Bot) ConnDAO(ctx context.Context) (d *dao.DAO, closer func(), err error) {
c, f, err := b.db.Conn(ctx)
if err != nil {
return nil, nil, err
}
return dao.NewDAO(sqlc.New(c)), f, nil
return dao.NewDAO(sqlc.New(c), b.l), f, nil
}

func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
Expand All @@ -469,7 +452,7 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
err = closer(err)
}()

err = dao.RemovePlayerCountNotifications(ctx)
err = dao.RemovePlayerCountNotificationRequests(ctx)
if err != nil {
return err
}
Expand All @@ -479,8 +462,7 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
return err
}

//msgs := make([]*discord.Message, 0, len(trackings))
notifications := make(map[model.MessageUserTarget]model.PlayerCountNotification)
notifications := make(map[model.MessageUserTarget]model.PlayerCountNotificationRequest)

for _, t := range trackings {
log.Printf("fetching message %s for notification tracking", t.MessageTarget)
Expand All @@ -500,7 +482,7 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
// iterate over all message reactions
for _, reaction := range m.Reactions {
emoji := reaction.Emoji.APIString()
if _, ok := reactionPlayerCountNotificationMap[emoji]; !ok {
if _, ok := model.ReactionPlayerCountNotificationMap[emoji]; !ok {
// none of the ones that we want to look at
continue
}
Expand All @@ -512,7 +494,7 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
}
return err
}
val := reactionPlayerCountNotificationMap[emoji]
val := model.ReactionPlayerCountNotificationMap[emoji]

log.Printf("found %d users for emoji %s of message %s", len(users), emoji, t.MessageTarget)
for _, user := range users {
Expand All @@ -528,7 +510,7 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
n.ChannelID,
n.MessageID,
n.UserID,
reactionPlayerCountNotificationReverseMap[n.Threshold],
model.ReactionPlayerCountNotificationReverseMap[n.Threshold],
)
if err != nil {
return err
Expand All @@ -543,14 +525,14 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
n.ChannelID,
n.MessageID,
n.UserID,
reactionPlayerCountNotificationReverseMap[val],
model.ReactionPlayerCountNotificationReverseMap[val],
)
if err != nil {
return err
}
}
} else {
notifications[userTarget] = model.PlayerCountNotification{
notifications[userTarget] = model.PlayerCountNotificationRequest{
MessageUserTarget: userTarget,
Threshold: val,
}
Expand All @@ -560,9 +542,9 @@ func (b *Bot) syncDatabaseState(ctx context.Context) (err error) {
}

values := utils.Values(notifications)
sort.Sort(model.ByPlayerCountNotificationIDs(values))
sort.Sort(model.ByPlayerCountNotificationRequestIDs(values))

err = dao.SetPlayerCountNotificationList(ctx, values)
err = dao.SetPlayerCountNotificationRequestList(ctx, values)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion bot/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (b *Bot) logWriter() {
Flags: discord.SuppressEmbeds,
})
if err != nil {
b.l.Errorf("failed to send log message: %v", err)
b.l.Errorf("failed to send log message to %s: %v", b.channelID, err)
continue
}
case <-b.ctx.Done():
Expand Down
12 changes: 9 additions & 3 deletions bot/message_deleter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ import (
)

func (b *Bot) handleMessageDeletion(e *gateway.MessageDeleteEvent) {
dao, closer, err := b.ConnDAO(b.ctx)
dao, closer, err := b.TxDAO(b.ctx)
if err != nil {
b.l.Errorf("failed to get connection queries for message deletion: %v", err)
b.l.Errorf("failed to get transaction dao for message deletion: %v", err)
return
}
defer closer()
defer func() {
err = closer(err)
if err != nil {
b.l.Errorf("failed to close transaction dao for message deletion: %v", err)
}
}()

// delete tracking messages from db in case someone deletes any message
err = dao.RemoveTrackingByMessageID(b.ctx, e.GuildID, e.ID)
if err != nil {
b.l.Errorf("failed to remove tracking of guild %s and message id: %s: %v", e.GuildID, e.ID, err)
}

}
Loading

0 comments on commit 93556d3

Please sign in to comment.