Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updater.StopBot should close polling context #211

Merged
merged 4 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions ext/botmapping.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ext

import (
"context"
"encoding/json"
"errors"
"io"
Expand All @@ -19,6 +20,8 @@ type botData struct {

// updateChan represents the incoming updates channel.
updateChan chan json.RawMessage
// pollingContextCloser allows one to stop polling instantly without waiting for the return
pollingContextCloser context.CancelFunc
// updateWriterControl is used to count the number of current writers on the update channel.
// This is required to ensure that we can safely close the channel, and thus stop processing incoming updates.
// While this remains non-zero, it is unsafe to close the update channel.
Expand Down Expand Up @@ -47,12 +50,22 @@ type botMapping struct {
errorLog *log.Logger
}

var ErrBotAlreadyExists = errors.New("bot already exists in bot mapping")
var ErrBotUrlPathAlreadyExists = errors.New("url path already exists in bot mapping")
// Sentinel errors for the bot mapping; compare against them with errors.Is.
var (
// ErrBotAlreadyExists reports that the bot is already present in the bot mapping.
// NOTE(review): presumably returned by addBot on a duplicate add; the return site is not visible here.
ErrBotAlreadyExists = errors.New("bot already exists in bot mapping")
// ErrBotUrlPathAlreadyExists reports that the URL path is already present in the bot mapping.
// NOTE(review): presumably returned by addBot for a clashing webhook path; the return site is not visible here.
ErrBotUrlPathAlreadyExists = errors.New("url path already exists in bot mapping")
)

// addWebhookBot adds a webhook-driven bot to the mapping under the given
// urlPath and webhookSecret. It delegates to addBot without a polling context
// closer, and returns whatever addBot returns.
func (m *botMapping) addWebhookBot(b *gotgbot.Bot, urlPath string, webhookSecret string) (*botData, error) {
	// No polling context closer is supplied for webhook bots.
	var noPollingCloser context.CancelFunc
	return m.addBot(b, urlPath, webhookSecret, noPollingCloser)
}

// addPollingBot adds a long-polling bot to the mapping. ctxClose is the cancel
// function of the polling context and may be nil. It delegates to addBot with
// an empty URL path and webhook secret, and returns whatever addBot returns.
func (m *botMapping) addPollingBot(b *gotgbot.Bot, ctxClose context.CancelFunc) (*botData, error) {
	// Polling bots pass an empty urlPath/webhookSecret, as addBot expects.
	const (
		noURLPath       = ""
		noWebhookSecret = ""
	)
	return m.addBot(b, noURLPath, noWebhookSecret, ctxClose)
}

// addBot Adds a new bot to the botMapping structure.
// Pass an empty urlPath/webhookSecret if using polling instead of webhooks.
func (m *botMapping) addBot(b *gotgbot.Bot, urlPath string, webhookSecret string) (*botData, error) {
func (m *botMapping) addBot(b *gotgbot.Bot, urlPath string, webhookSecret string, ctxClose context.CancelFunc) (*botData, error) {
// Clean up the URLPath such that it remains consistent.
urlPath = strings.TrimPrefix(urlPath, "/")

Expand All @@ -75,12 +88,13 @@ func (m *botMapping) addBot(b *gotgbot.Bot, urlPath string, webhookSecret string
}

bData := botData{
bot: b,
updateChan: make(chan json.RawMessage),
stopUpdates: make(chan struct{}),
updateWriterControl: &sync.WaitGroup{},
urlPath: urlPath,
webhookSecret: webhookSecret,
bot: b,
updateChan: make(chan json.RawMessage),
pollingContextCloser: ctxClose,
stopUpdates: make(chan struct{}),
updateWriterControl: &sync.WaitGroup{},
urlPath: urlPath,
webhookSecret: webhookSecret,
}

m.mapping[bData.bot.Token] = bData
Expand Down Expand Up @@ -208,6 +222,11 @@ func (b *botData) stop() {
close(b.stopUpdates)
}

// If we have a context to close, close it. This will stop polling immediately.
if b.pollingContextCloser != nil {
b.pollingContextCloser()
}

// Wait for all writers to finish writing to the updateChannel
b.updateWriterControl.Wait()

Expand Down
12 changes: 9 additions & 3 deletions ext/botmapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func Test_botMapping(t *testing.T) {
t.Run("addBot", func(t *testing.T) {
// check that bots can be added fine
var err error
origBdata, err = bm.addBot(b, "", "")
origBdata, err = bm.addPollingBot(b, nil)
if err != nil {
t.Errorf("expected to be able to add a new bot fine: %s", err.Error())
t.FailNow()
Expand All @@ -31,7 +31,7 @@ func Test_botMapping(t *testing.T) {

t.Run("doubleAdd", func(t *testing.T) {
// Adding the same bot twice should fail
_, err := bm.addBot(b, "", "")
_, err := bm.addPollingBot(b, nil)
if err == nil {
t.Errorf("adding the same bot twice should throw an error")
t.FailNow()
Expand Down Expand Up @@ -84,7 +84,10 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) {
BotClient: &gotgbot.BaseBotClient{},
}

bData, err := bm.addBot(b, "", "")
ctxCancelled := false
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspected this test might fail if executed with -race, but a different one failed instead.

Note: this possibly does not impact the bot's behaviour in a real environment; it may just be a quirk of how the test accesses the data.

[9:15:47] 0:_thirdparty/gotgbot/ext (paul/close-polling-ctx)> go test -race ./...
==================
WARNING: DATA RACE
Write at 0x00c000214098 by goroutine 219:
  runtime.racewrite()
      <autogenerated>:1 +0x1e
  github.com/PaulSonOfLars/gotgbot/v2/ext.(*botData).stop()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/botmapping.go:231 +0xbd
  github.com/PaulSonOfLars/gotgbot/v2/ext.(*Updater).StopAllBots()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/updater.go:288 +0xa4
  github.com/PaulSonOfLars/gotgbot/v2/ext.(*Updater).Stop()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/updater.go:264 +0x10d
  github.com/PaulSonOfLars/gotgbot/v2/ext_test.TestUpdaterSupportsTwoPollingBots()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/updater_test.go:434 +0xa52
  testing.tRunner()
      /usr/local/Cellar/go/1.23.4/libexec/src/testing/testing.go:1690 +0x226
  testing.(*T).Run.gowrap1()
      /usr/local/Cellar/go/1.23.4/libexec/src/testing/testing.go:1743 +0x44

Previous read at 0x00c000214098 by goroutine 222:
  runtime.raceread()
      <autogenerated>:1 +0x1e
  github.com/PaulSonOfLars/gotgbot/v2/ext.(*Updater).pollingLoop()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/updater.go:174 +0x8d
  github.com/PaulSonOfLars/gotgbot/v2/ext.(*Updater).StartPolling.gowrap2()
      /Users/rustam/wp/_thirdparty/gotgbot/ext/updater.go:168 +0x79
<... elided for brevity ...>
--- FAIL: TestUpdaterSupportsTwoPollingBots (0.00s)
    testing.go:1399: race detected during execution of test

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're unable to get it to fail, running it like this helps:

go test ./... -race -count=10

It takes a long time to run, though.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting - there's definitely something funky going on there, not sure what exactly. Will take a look. Might be a different PR, depending on how long it takes! Thanks for the review and the report <3

bData, err := bm.addPollingBot(b, func() {
ctxCancelled = true
})
if err != nil {
t.Errorf("bot with token %s should not have failed to be added", b.Token)
return
Expand All @@ -99,4 +102,7 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) {
t.Errorf("bot with token %s should be stopped", b.Token)
return
}
if !ctxCancelled {
t.Errorf("bot with token %s should have a cancelled context ", b.Token)
}
}
29 changes: 20 additions & 9 deletions ext/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,27 @@ func (u *Updater) StartPolling(b *gotgbot.Bot, opts *PollingOpts) error {
}
}

bData, err := u.botMapping.addBot(b, "", "")
ctx, closeFn := context.WithCancel(context.Background())
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could reuse the opts logic to pass in an updater from #210 here - might be a good way to handle signals?

Though the signals approach would only close the updater, not the dispatcher; so it might be best to handle closure explicitly?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this approach is much better than in #210 - the library should not be concerned about any of the OS signals — it just needs to know one thing: "something wants the updater to stop". And technically it doesn't need the Context, as it's not going to use it for anything else but only to know when it's cancelled.

Handling specific OS signals and termination is the caller's concern. For example, this is how the echoBot could handle the os.Interrupt signal and tell the bot to terminate — works great on this branch:

diff --git a/samples/echoBot/main.go b/samples/echoBot/main.go
index bdffc88..2f390b1 100644
--- a/samples/echoBot/main.go
+++ b/samples/echoBot/main.go
@@ -1,10 +1,12 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"log"
 	"net/http"
 	"os"
+	"os/signal"
 	"time"
 
 	"github.com/PaulSonOfLars/gotgbot/v2"
@@ -34,6 +36,8 @@ func main() {
 	if err != nil {
 		panic("failed to create new bot: " + err.Error())
 	}
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+	defer cancel()
 
 	// Create updater and dispatcher.
 	dispatcher := ext.NewDispatcher(&ext.DispatcherOpts{
@@ -45,6 +49,11 @@ func main() {
 		MaxRoutines: ext.DefaultMaxRoutines,
 	})
 	updater := ext.NewUpdater(dispatcher, nil)
+	go func() {
+		<-ctx.Done()
+		log.Println("shutting down...")
+		updater.Stop()
+	}()
 
 	// Add echo handler to reply to all text messages.
 	dispatcher.AddHandler(handlers.NewMessage(message.Text, echo))


bData, err := u.botMapping.addPollingBot(b, closeFn)
if err != nil {
return fmt.Errorf("failed to add bot with long polling: %w", err)
}

go u.Dispatcher.Start(b, bData.updateChan)
go u.pollingLoop(bData, reqOpts, v)

bData.updateWriterControl.Add(1)
go func() {
// defer, so it gets called even in case of panics.
defer bData.updateWriterControl.Done()

u.pollingLoop(ctx, bData, reqOpts, v)
}()

return nil
}

func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[string]string) {
bData.updateWriterControl.Add(1)
defer bData.updateWriterControl.Done()

func (u *Updater) pollingLoop(ctx context.Context, bData *botData, opts *gotgbot.RequestOpts, v map[string]string) {
Comment on lines -166 to +180
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rusq This should fix it! An edge case for sure, but I can see this being triggered in live systems which are dynamically adding/removing bots.

TL;DR: Race condition caused by the waitgroup being started in a goroutine, rather than before it.

This meant that when the stars (mis)aligned, calling updater.StartPolling(...) followed by updater.Stop() could result in .Wait() (from the .Stop()) being called before .Add(1) (from the goroutine in .StartPolling), which then causes the race.

The fix makes sure that the .Add is done sequentially, thus will always run before the .Wait.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Concurrency bugs are always the hardest to debug, great breakdown!

for {
// Check if updater loop has been terminated.
if bData.shouldStopUpdates() {
Expand All @@ -180,8 +186,13 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s

// Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and
// unnecessary reallocation of url.Values in the polling loop.
r, err := bData.bot.Request("getUpdates", v, nil, opts)
r, err := bData.bot.RequestWithContext(ctx, "getUpdates", v, nil, opts)
if err != nil {
if errors.Is(err, context.Canceled) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful

// context cancelled; means the bot was stopped gracefully through Updater.StopBot
return
}

if u.UnhandledErrFunc != nil {
u.UnhandledErrFunc(err)
} else {
Expand Down Expand Up @@ -259,7 +270,7 @@ func (u *Updater) Stop() error {
// Stop the dispatcher from processing any further updates.
u.Dispatcher.Stop()

// Finally, atop idling.
// Finally, stop idling.
if u.stopIdling != nil {
close(u.stopIdling)
}
Expand Down Expand Up @@ -317,7 +328,7 @@ func (u *Updater) AddWebhook(b *gotgbot.Bot, urlPath string, opts *AddWebhookOpt
secretToken = opts.SecretToken
}

bData, err := u.botMapping.addBot(b, urlPath, secretToken)
bData, err := u.botMapping.addWebhookBot(b, urlPath, secretToken)
if err != nil {
return fmt.Errorf("failed to add webhook for bot: %w", err)
}
Expand Down
49 changes: 48 additions & 1 deletion ext/updater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,54 @@ func concurrentTest(t *testing.T) {
time.Sleep(delay * 2)
}

// TestUpdater_StopBot verifies that stopping bots cancels an in-flight
// getUpdates request immediately, rather than waiting for the server's reply
// or for the request timeout.
func TestUpdater_StopBot(t *testing.T) {
	// maxStopDuration is the longest StopAllBots may take before the test fails.
	// It is far below the server delay and request timeout, so only a cancelled
	// request can satisfy it.
	const maxStopDuration = 5 * time.Millisecond

	server := basicTestServer(t, map[string]*testEndpoint{
		"getUpdates": {
			delay: time.Second * 3, // the server takes 3s to reply, so an uncancelled poll blocks for ~3s
			reply: `{"ok": true, "result": []}`,
		},
	})
	defer server.Close()

	reqOpts := &gotgbot.RequestOpts{
		APIURL:  server.URL,
		Timeout: time.Second * 4,
	}

	b := &gotgbot.Bot{
		User:      gotgbot.User{},
		Token:     "SOME_TOKEN",
		BotClient: &gotgbot.BaseBotClient{},
	}

	d := ext.NewDispatcher(&ext.DispatcherOpts{MaxRoutines: 1})
	u := ext.NewUpdater(d, nil)

	err := u.StartPolling(b, &ext.PollingOpts{
		GetUpdatesOpts: &gotgbot.GetUpdatesOpts{
			RequestOpts: reqOpts,
		},
	})
	if err != nil {
		t.Errorf("failed to start polling: %v", err)
		return
	}

	// Sleep to ensure polling is fully underway, i.e. a getUpdates request is in flight.
	time.Sleep(time.Millisecond * 500)

	// Should still have lots of time until timeout, so... kill the bots.
	start := time.Now()
	u.StopAllBots()

	// If stopping took longer than maxStopDuration then something went wrong;
	// cancelling the polling context should have stopped the request immediately.
	since := time.Since(start)
	t.Logf("stopping took %dms", since.Milliseconds())
	if since > maxStopDuration {
		t.Errorf("stopping all bots took %s; should have taken less than %s", since, maxStopDuration)
	}
}

func TestUpdaterDisallowsEmptyWebhooks(t *testing.T) {
b := &gotgbot.Bot{
Token: "SOME_TOKEN",
Expand Down Expand Up @@ -348,7 +396,6 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
b1 := &gotgbot.Bot{
Token: "SOME_TOKEN",
BotClient: &gotgbot.BaseBotClient{

DefaultRequestOpts: reqOpts,
},
}
Expand Down
Loading