diff --git a/ee/desktop/runner/runner.go b/ee/desktop/runner/runner.go index dc961cc29..dfcd8d7c5 100644 --- a/ee/desktop/runner/runner.go +++ b/ee/desktop/runner/runner.go @@ -108,7 +108,8 @@ type DesktopUsersProcessesRunner struct { menuRefreshInterval time.Duration interrupt chan struct{} // uidProcs is a map of uid to desktop process - uidProcs map[string]processRecord + uidProcs map[string]processRecord + uidProcsMutex sync.Mutex // procsWg is a WaitGroup to wait for all desktop processes to finish during an interrupt procsWg *sync.WaitGroup // interruptTimeout how long to wait for desktop proccesses to finish on interrupt @@ -150,8 +151,9 @@ type processRecord struct { func New(k types.Knapsack, opts ...desktopUsersProcessesRunnerOption) (*DesktopUsersProcessesRunner, error) { runner := &DesktopUsersProcessesRunner{ logger: log.NewNopLogger(), - interrupt: make(chan struct{}), + interrupt: make(chan struct{}, 1), uidProcs: make(map[string]processRecord), + uidProcsMutex: sync.Mutex{}, updateInterval: k.DesktopUpdateInterval(), menuRefreshInterval: k.DesktopMenuRefreshInterval(), procsWg: &sync.WaitGroup{}, @@ -222,8 +224,15 @@ func (r *DesktopUsersProcessesRunner) Execute() error { // Interrupt stops creating launcher desktop processes and kills any existing ones. // It also signals the execute loop to exit, so new desktop processes cease to spawn. func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { - // Tell the execute loop to stop checking, and exit - r.interrupt <- struct{}{} + // Non-blocking send to interrupt channel + select { + case r.interrupt <- struct{}{}: + // First time we've received an interrupt, so we've notified r.Execute. + default: + // Execute loop is no longer running, so there's nothing to interrupt + } + + time.Sleep(3 * time.Second) // Kill any desktop processes that may exist r.killDesktopProcesses() @@ -241,6 +250,9 @@ func (r *DesktopUsersProcessesRunner) Interrupt(_ error) { // killDesktopProcesses kills any existing desktop processes func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { + r.uidProcsMutex.Lock() + defer r.uidProcsMutex.Unlock() + wgDone := make(chan struct{}) go func() { defer close(wgDone) @@ -297,6 +309,9 @@ func (r *DesktopUsersProcessesRunner) killDesktopProcesses() { } func (r *DesktopUsersProcessesRunner) SendNotification(n notify.Notification) error { + r.uidProcsMutex.Lock() + defer r.uidProcsMutex.Unlock() + if len(r.uidProcs) == 0 { return errors.New("cannot send notification, no child desktop processes") } @@ -380,6 +395,8 @@ func (r *DesktopUsersProcessesRunner) refreshMenu() { } // Tell any running desktop user processes that they should refresh the latest menu data + r.uidProcsMutex.Lock() + defer r.uidProcsMutex.Unlock() for uid, proc := range r.uidProcs { client := client.New(r.userServerAuthToken, proc.socketPath) if err := client.Refresh(); err != nil { @@ -531,6 +548,9 @@ func (r *DesktopUsersProcessesRunner) runConsoleUserDesktop() error { // addProcessTrackingRecordForUser adds process information to the internal tracking state func (r *DesktopUsersProcessesRunner) addProcessTrackingRecordForUser(uid string, socketPath string, osProcess *os.Process) error { + r.uidProcsMutex.Lock() + defer r.uidProcsMutex.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() @@ -592,6 +612,9 @@ func (r *DesktopUsersProcessesRunner) determineExecutablePath() (string, error) } func (r *DesktopUsersProcessesRunner) userHasDesktopProcess(uid string) bool { + r.uidProcsMutex.Lock() + defer r.uidProcsMutex.Unlock() + // have no record of process proc, ok := r.uidProcs[uid] if !ok { diff --git a/ee/desktop/runner/runner_test.go b/ee/desktop/runner/runner_test.go index 36b1f54e6..2b934cb04 100644 --- a/ee/desktop/runner/runner_test.go +++ b/ee/desktop/runner/runner_test.go @@ -145,7 +145,7 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) { assert.NoError(t, r.Execute()) }() - // let is run a few interval + // let it run a few intervals time.Sleep(r.updateInterval * 3) r.Interrupt(nil) @@ -185,6 +185,34 @@ func TestDesktopUserProcessRunner_Execute(t *testing.T) { p.Process.Wait() } }) + + // Confirm we can call Interrupt multiple times without blocking + interruptComplete := make(chan struct{}) + expectedInterrupts := 3 + for i := 0; i < expectedInterrupts; i += 1 { + go func() { + r.Interrupt(nil) + interruptComplete <- struct{}{} + }() + } + + receivedInterrupts := 0 + for { + if receivedInterrupts >= expectedInterrupts { + break + } + + select { + case <-interruptComplete: + receivedInterrupts += 1 + continue + case <-time.After(5 * time.Second): + t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts) + t.FailNow() + } + } + + require.Equal(t, expectedInterrupts, receivedInterrupts) }) } } diff --git a/pkg/log/checkpoint/checkpoint.go b/pkg/log/checkpoint/checkpoint.go index c3765ab83..2fdc4caf1 100644 --- a/pkg/log/checkpoint/checkpoint.go +++ b/pkg/log/checkpoint/checkpoint.go @@ -97,7 +97,12 @@ func (c *checkPointer) Run() error { } func (c *checkPointer) Interrupt(_ error) { - c.interrupt <- struct{}{} + // Non-blocking channel send + select { + case c.interrupt <- struct{}{}: + default: + level.Debug(c.logger).Log("msg", "received additional call to interrupt") + } } func (c *checkPointer) Once() { diff --git a/pkg/log/checkpoint/checkpoint_test.go b/pkg/log/checkpoint/checkpoint_test.go index 6d0c06fb0..273391701 100644 --- a/pkg/log/checkpoint/checkpoint_test.go +++ b/pkg/log/checkpoint/checkpoint_test.go @@ -2,12 +2,102 @@ package checkpoint import ( "net/url" + "os" "reflect" "testing" + "time" + "github.com/go-kit/kit/log" "github.com/kolide/launcher/pkg/agent/types/mocks" + "github.com/stretchr/testify/require" + "go.etcd.io/bbolt" ) +func TestInterrupt(t *testing.T) { + t.Parallel() + + // Set up temp db + file, err := os.CreateTemp("", "kolide_launcher_test") + if err != nil { + t.Fatalf("creating temp file: %s", err.Error()) + } + + db, err := bbolt.Open(file.Name(), 0600, nil) + if err != nil { + t.Fatalf("opening bolt DB: %s", err.Error()) + } + + defer func() { + db.Close() + os.Remove(file.Name()) + }() + + // Set up knapsack + k := mocks.NewKnapsack(t) + k.On("BboltDB").Return(db) + k.On("KolideHosted").Return(true) + k.On("InModernStandby").Return(false).Maybe() + k.On("KolideServerURL").Return("") + k.On("InsecureTransportTLS").Return(false) + k.On("Autoupdate").Return(true) + k.On("MirrorServerURL").Return("") + k.On("NotaryServerURL").Return("") + k.On("TufServerURL").Return("") + k.On("ControlServerURL").Return("") + + // Start the checkpointer, let it run, interrupt it, and confirm it can return from the interrupt + testCheckpointer := New(log.NewNopLogger(), k) + + runInterruptReceived := make(chan struct{}, 1) + + go func() { + require.Nil(t, testCheckpointer.Run()) + runInterruptReceived <- struct{}{} + }() + + // Give it a couple seconds to run before calling interrupt + time.Sleep(3 * time.Second) + + testCheckpointer.Interrupt(nil) + + select { + case <-runInterruptReceived: + break + case <-time.After(5 * time.Second): + t.Error("could not interrupt checkpointer within 5 seconds") + t.FailNow() + } + + // Now call interrupt a couple more times + expectedAdditionalInterrupts := 3 + additionalInterruptsReceived := make(chan struct{}, expectedAdditionalInterrupts) + + for i := 0; i < expectedAdditionalInterrupts; i += 1 { + go func() { + testCheckpointer.Interrupt(nil) + additionalInterruptsReceived <- struct{}{} + }() + } + + receivedInterrupts := 0 + for { + if receivedInterrupts >= expectedAdditionalInterrupts { + break + } + + select { + case <-additionalInterruptsReceived: + receivedInterrupts += 1 + continue + case <-time.After(5 * time.Second): + t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts) + t.FailNow() + } + } + + require.Equal(t, expectedAdditionalInterrupts, receivedInterrupts) +} + func Test_urlsToTest(t *testing.T) { t.Parallel()