Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it safe to call rungroup actor Interrupt multiple times #1342

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions ee/desktop/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ type DesktopUsersProcessesRunner struct {
menuRefreshInterval time.Duration
interrupt chan struct{}
// uidProcs is a map of uid to desktop process
uidProcs map[string]processRecord
uidProcs map[string]processRecord
uidProcsMutex sync.Mutex
// procsWg is a WaitGroup to wait for all desktop processes to finish during an interrupt
procsWg *sync.WaitGroup
// interruptTimeout is how long to wait for desktop processes to finish on interrupt
Expand Down Expand Up @@ -150,8 +151,9 @@ type processRecord struct {
func New(k types.Knapsack, opts ...desktopUsersProcessesRunnerOption) (*DesktopUsersProcessesRunner, error) {
runner := &DesktopUsersProcessesRunner{
logger: log.NewNopLogger(),
interrupt: make(chan struct{}),
interrupt: make(chan struct{}, 1),
uidProcs: make(map[string]processRecord),
uidProcsMutex: sync.Mutex{},
updateInterval: k.DesktopUpdateInterval(),
menuRefreshInterval: k.DesktopMenuRefreshInterval(),
procsWg: &sync.WaitGroup{},
Expand Down Expand Up @@ -222,8 +224,15 @@ func (r *DesktopUsersProcessesRunner) Execute() error {
// Interrupt stops creating launcher desktop processes and kills any existing ones.
// It also signals the execute loop to exit, so new desktop processes cease to spawn.
func (r *DesktopUsersProcessesRunner) Interrupt(_ error) {
// Tell the execute loop to stop checking, and exit
r.interrupt <- struct{}{}
// Non-blocking send to interrupt channel
select {
case r.interrupt <- struct{}{}:
// First time we've received an interrupt, so we've notified r.Execute.
default:
// The buffered channel already holds an unconsumed interrupt, so there's nothing more to signal
directionless marked this conversation as resolved.
Show resolved Hide resolved
}

time.Sleep(3 * time.Second)

// Kill any desktop processes that may exist
r.killDesktopProcesses()
Expand All @@ -241,6 +250,9 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) {

// killDesktopProcesses kills any existing desktop processes
func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
r.uidProcsMutex.Lock()
defer r.uidProcsMutex.Unlock()

wgDone := make(chan struct{})
go func() {
defer close(wgDone)
Expand Down Expand Up @@ -297,6 +309,9 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() {
}

func (r *DesktopUsersProcessesRunner) SendNotification(n notify.Notification) error {
r.uidProcsMutex.Lock()
defer r.uidProcsMutex.Unlock()

if len(r.uidProcs) == 0 {
return errors.New("cannot send notification, no child desktop processes")
}
Expand Down Expand Up @@ -380,6 +395,8 @@ func (r *DesktopUsersProcessesRunner) refreshMenu() {
}

// Tell any running desktop user processes that they should refresh the latest menu data
r.uidProcsMutex.Lock()
defer r.uidProcsMutex.Unlock()
for uid, proc := range r.uidProcs {
client := client.New(r.userServerAuthToken, proc.socketPath)
if err := client.Refresh(); err != nil {
Expand Down Expand Up @@ -531,6 +548,9 @@ func (r *DesktopUsersProcessesRunner) runConsoleUserDesktop() error {

// addProcessTrackingRecordForUser adds process information to the internal tracking state
func (r *DesktopUsersProcessesRunner) addProcessTrackingRecordForUser(uid string, socketPath string, osProcess *os.Process) error {
r.uidProcsMutex.Lock()
defer r.uidProcsMutex.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()

Expand Down Expand Up @@ -592,6 +612,9 @@ func (r *DesktopUsersProcessesRunner) determineExecutablePath() (string, error)
}

func (r *DesktopUsersProcessesRunner) userHasDesktopProcess(uid string) bool {
r.uidProcsMutex.Lock()
defer r.uidProcsMutex.Unlock()

// have no record of process
proc, ok := r.uidProcs[uid]
if !ok {
Expand Down
30 changes: 29 additions & 1 deletion ee/desktop/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) {
assert.NoError(t, r.Execute())
}()

// let is run a few interval
// let it run a few intervals
time.Sleep(r.updateInterval * 3)
r.Interrupt(nil)

Expand Down Expand Up @@ -185,6 +185,34 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) {
p.Process.Wait()
}
})

// Confirm we can call Interrupt multiple times without blocking
interruptComplete := make(chan struct{})
expectedInterrupts := 3
for i := 0; i < expectedInterrupts; i += 1 {
go func() {
r.Interrupt(nil)
interruptComplete <- struct{}{}
}()
}

receivedInterrupts := 0
for {
if receivedInterrupts >= expectedInterrupts {
break
}

select {
case <-interruptComplete:
receivedInterrupts += 1
continue
case <-time.After(5 * time.Second):
t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts)
t.FailNow()
}
}

require.Equal(t, expectedInterrupts, receivedInterrupts)
})
}
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/log/checkpoint/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ func (c *checkPointer) Run() error {
}

// Interrupt signals the checkpointer by sending an empty struct on c.interrupt.
// The error argument is ignored.
// NOTE(review): this listing is a rendered diff without +/- markers -- the bare
// send directly below appears to be the removed (blocking) line, and the select
// that follows its non-blocking replacement. Confirm against the actual file.
func (c *checkPointer) Interrupt(_ error) {
c.interrupt <- struct{}{}
// Non-blocking channel send: if the send cannot proceed immediately, the
// default case only logs at debug level and the call returns.
select {
case c.interrupt <- struct{}{}:
default:
level.Debug(c.logger).Log("msg", "received additional call to interrupt")
}
}

func (c *checkPointer) Once() {
Expand Down
90 changes: 90 additions & 0 deletions pkg/log/checkpoint/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,102 @@ package checkpoint

import (
"net/url"
"os"
"reflect"
"testing"
"time"

"github.com/go-kit/kit/log"
"github.com/kolide/launcher/pkg/agent/types/mocks"
"github.com/stretchr/testify/require"
"go.etcd.io/bbolt"
)

// TestInterrupt confirms that the checkpointer's Run returns once Interrupt is
// called, and that additional calls to Interrupt afterwards return promptly
// rather than blocking.
func TestInterrupt(t *testing.T) {
	t.Parallel()

	// Set up temp db. Only the path is needed -- bbolt opens the file itself --
	// so close our own handle right away instead of leaking it.
	file, err := os.CreateTemp("", "kolide_launcher_test")
	if err != nil {
		t.Fatalf("creating temp file: %s", err.Error())
	}
	// Registered before opening the DB so the file is removed even if Open fails.
	defer os.Remove(file.Name())

	if err := file.Close(); err != nil {
		t.Fatalf("closing temp file: %s", err.Error())
	}

	db, err := bbolt.Open(file.Name(), 0600, nil)
	if err != nil {
		t.Fatalf("opening bolt DB: %s", err.Error())
	}
	// Deferred calls run last-in-first-out, so the DB is closed before the file is removed.
	defer db.Close()

	// Set up knapsack
	k := mocks.NewKnapsack(t)
	k.On("BboltDB").Return(db)
	k.On("KolideHosted").Return(true)
	k.On("InModernStandby").Return(false).Maybe()
	k.On("KolideServerURL").Return("")
	k.On("InsecureTransportTLS").Return(false)
	k.On("Autoupdate").Return(true)
	k.On("MirrorServerURL").Return("")
	k.On("NotaryServerURL").Return("")
	k.On("TufServerURL").Return("")
	k.On("ControlServerURL").Return("")

	// Start the checkpointer, let it run, interrupt it, and confirm it can return from the interrupt
	testCheckpointer := New(log.NewNopLogger(), k)

	runInterruptReceived := make(chan struct{}, 1)

	go func() {
		// Use t.Errorf rather than require here: require calls t.FailNow, which
		// must only be called from the goroutine running the test function.
		// Reporting with Errorf also lets us still signal completion below, so a
		// failing Run is reported as such instead of as a timeout.
		if err := testCheckpointer.Run(); err != nil {
			t.Errorf("expected nil error from Run, got: %v", err)
		}
		runInterruptReceived <- struct{}{}
	}()

	// Give it a couple seconds to run before calling interrupt
	time.Sleep(3 * time.Second)

	testCheckpointer.Interrupt(nil)

	select {
	case <-runInterruptReceived:
	case <-time.After(5 * time.Second):
		t.Fatal("could not interrupt checkpointer within 5 seconds")
	}

	// Now call interrupt a couple more times
	expectedAdditionalInterrupts := 3
	// Buffered so the interrupting goroutines can always finish, even if the test has already failed
	additionalInterruptsReceived := make(chan struct{}, expectedAdditionalInterrupts)

	for i := 0; i < expectedAdditionalInterrupts; i += 1 {
		go func() {
			testCheckpointer.Interrupt(nil)
			additionalInterruptsReceived <- struct{}{}
		}()
	}

	// Wait for every additional Interrupt call to return, allowing up to 5 seconds for each
	receivedInterrupts := 0
	for receivedInterrupts < expectedAdditionalInterrupts {
		select {
		case <-additionalInterruptsReceived:
			receivedInterrupts += 1
		case <-time.After(5 * time.Second):
			t.Fatalf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts)
		}
	}

	require.Equal(t, expectedAdditionalInterrupts, receivedInterrupts)
}

func Test_urlsToTest(t *testing.T) {
t.Parallel()

Expand Down