diff --git a/resource/systemd/unit/dbus_linux.go b/resource/systemd/unit/dbus_linux.go index 897e864e0..f6a5c053a 100644 --- a/resource/systemd/unit/dbus_linux.go +++ b/resource/systemd/unit/dbus_linux.go @@ -49,4 +49,8 @@ type SystemdConnection interface { // KillUnit sends a unix signal to the process KillUnit(name string, signal int32) + + EnableUnitFiles(files []string, runtime bool, force bool) (bool, []dbus.EnableUnitFileChange, error) + + DisableUnitFiles(files []string, runtime bool) ([]dbus.DisableUnitFileChange, error) } diff --git a/resource/systemd/unit/executor.go b/resource/systemd/unit/executor.go index 959049246..f84c92160 100644 --- a/resource/systemd/unit/executor.go +++ b/resource/systemd/unit/executor.go @@ -42,6 +42,12 @@ type SystemdExecutor interface { // will only work on systemd-aware processes. ReloadUnit(*Unit) error + // EnableUnit will enable the unit file and return a list of any changes + EnableUnit(whichUnit *Unit, runtime, force bool) (bool, []*unitFileChange, error) + + // DisableUnit will disable the unit file and return a list of any changes + DisableUnit(whichUnit *Unit, runtime bool) ([]*unitFileChange, error) + // Send a unix signal to a process. SendSignal(u *Unit, signal Signal) } diff --git a/resource/systemd/unit/executor_mock_test.go b/resource/systemd/unit/executor_mock_test.go index 7789e02db..eda67ce61 100644 --- a/resource/systemd/unit/executor_mock_test.go +++ b/resource/systemd/unit/executor_mock_test.go @@ -73,3 +73,15 @@ func (m *ExecutorMock) SendSignal(u *Unit, signal Signal) { m.Called(u, signal) return } + +func (m *ExecutorMock) EnableUnit(u *Unit, runtime, force bool) (bool, []*unitFileChange, error) { + m.maybeSleep() + args := m.Called(u, runtime, force) + return args.Bool(0), args.Get(1).([]*unitFileChange), args.Error(2) +} + +func (m *ExecutorMock) DisableUnit(u *Unit, runtime bool) ([]*unitFileChange, error) { + m.maybeSleep() + args := m.Called(u, runtime) + return args.Get(0).([]*unitFileChange), args.Error(1) +} diff --git a/resource/systemd/unit/fsexecutor.go b/resource/systemd/unit/fsexecutor.go new file mode 100644 index 000000000..9e044f000 --- /dev/null +++ b/resource/systemd/unit/fsexecutor.go @@ -0,0 +1,37 @@ +// Copyright © 2017 Asteris, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unit + +import ( + "os" + "path/filepath" +) + +type fsexecutor interface { + EvalSymlinks(path string) (string, error) + Walk(root string, f func(string, os.FileInfo, error) error) error +} + +type realFsExecutor struct{} + +func (r realFsExecutor) EvalSymlinks(path string) (string, error) { + return filepath.EvalSymlinks(path) +} + +func (r realFsExecutor) Walk(root string, f func(string, os.FileInfo, error) error) error { + return filepath.Walk(root, f) +} + +func filesystemExecutor() fsexecutor { return realFsExecutor{} } diff --git a/resource/systemd/unit/fsexecutor_mock_test.go b/resource/systemd/unit/fsexecutor_mock_test.go new file mode 100644 index 000000000..c9ea647b4 --- /dev/null +++ b/resource/systemd/unit/fsexecutor_mock_test.go @@ -0,0 +1,74 @@ +// Copyright © 2017 Asteris, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unit + +import ( + "os" + "strings" + "time" + + "github.com/stretchr/testify/mock" +) + +type walkFuncArgs struct { + path string + info os.FileInfo + err error +} + +type mockFsExecutor struct { + mock.Mock + walkWith []walkFuncArgs +} + +func (m *mockFsExecutor) EvalSymlinks(path string) (string, error) { + args := m.Called(path) + return args.String(0), args.Error(1) +} + +func (m *mockFsExecutor) Walk(root string, f func(string, os.FileInfo, error) error) error { + args := m.Called(root, f) + for _, node := range m.walkWith { + if !strings.HasPrefix(node.path, root) { + continue + } + f(node.path, node.info, node.err) + } + return args.Error(0) +} + +func newMockWithPaths(path ...string) *mockFsExecutor { + var args []walkFuncArgs + for _, p := range path { + a := walkFuncArgs{ + path: p, + err: nil, + info: mockFileInfo{p}, + } + args = append(args, a) + } + return &mockFsExecutor{walkWith: args} +} + +type mockFileInfo struct { + path string +} + +func (n mockFileInfo) Name() string { s := strings.Split(n.path, "/"); return s[len(s)-1] } +func (n mockFileInfo) Size() int64 { return 0 } +func (n mockFileInfo) Mode() os.FileMode { return 0777 } +func (n mockFileInfo) ModTime() time.Time { return time.Now() } +func (n mockFileInfo) IsDir() bool { return false } +func (n mockFileInfo) Sys() interface{} { return nil } diff --git a/resource/systemd/unit/linux_mocks_test.go b/resource/systemd/unit/linux_mocks_test.go index 849964386..85da7ae40 100644 --- a/resource/systemd/unit/linux_mocks_test.go +++ b/resource/systemd/unit/linux_mocks_test.go @@ -39,6 +39,16 @@ func (m *DbusMock) ListUnits() ([]dbus.UnitStatus, error) { return args.Get(0).([]dbus.UnitStatus), args.Error(1) } +func (m *DbusMock) EnableUnitFiles(files []string, runtime, force bool) (bool, []dbus.EnableUnitFileChange, error) { + args := m.Called(files, runtime, force) + return args.Bool(0), args.Get(1).([]dbus.EnableUnitFileChange), args.Error(2) +} + +func (m *DbusMock) DisableUnitFiles(files []string, runtime bool) ([]dbus.DisableUnitFileChange, error) { + args := m.Called(files, runtime) + return args.Get(1).([]dbus.DisableUnitFileChange), args.Error(2) +} + // ListUnits mocks ListUnitsByNames func (m *DbusMock) ListUnitsByNames(names []string) ([]dbus.UnitStatus, error) { args := m.Called(names) diff --git a/resource/systemd/unit/preparer.go b/resource/systemd/unit/preparer.go index a1fc3d137..45c3bba42 100644 --- a/resource/systemd/unit/preparer.go +++ b/resource/systemd/unit/preparer.go @@ -25,10 +25,17 @@ import ( // UnitState configures loaded systemd units, allowing you to start, stop, or // restart them, reload configuration files, and send unix signals. type Preparer struct { - // The name of the unit. This may optionally include the unit type, - // e.g. "foo.service" and "foo" are both valid. + // The name of the unit. This may optionally omit the unit type if there is + // only a single unit type of the given name. e.g. "foo.service" and "foo" + // are both valid if, and only if, no other unit type named "foo" exists. Name string `hcl:"unit" required:"true"` + // The full path to the unit. If path is specified then it will be used when + // determining if the unit has been enabled or disabled. Note that this path + // must exist within one of the normal systemd search directories + // (e.g. `/lib/systemd/system`) + Path string `hcl:"path"` + // The desired state of the unit. This will affect the current unit job. Use // `systemd.unit.file` to enable and disable jobs, or `systemd.unit.config` to // set options. @@ -53,6 +60,16 @@ type Preparer struct { // an unsigned integer value between 1 and 31 inclusive. SignalNumber uint `hcl:"signal_number" mutually_exclusive:"signal_name,signal_num"` + // Specifies that a unit file should be persistently enabled or disabled. If + // true, enable the unit, if false, disable it, otherwise leave the current + // settings unmodified. + Enable *bool `hcl:"enabled"` + + // Specifies that a unit file should be temporarily enabled or disabled. If + // true, enable the unit, if false, disable it, otherwise leave the current + // settings unmodified. + EnableRuntime *bool `hcl:"enabled_runtime"` + executor SystemdExecutor } @@ -82,10 +99,13 @@ func (p *Preparer) Prepare(ctx context.Context, render resource.Renderer) (resou } r := &Resource{ - Reload: p.Reload, - Name: p.Name, - State: p.State, - systemdExecutor: p.executor, + Reload: p.Reload, + Name: p.Name, + State: p.State, + systemdExecutor: p.executor, + enableChange: p.Enable, + enableRuntimeChange: p.EnableRuntime, + fs: realFsExecutor{}, } if signal != nil { diff --git a/resource/systemd/unit/preparer_test.go b/resource/systemd/unit/preparer_test.go index 49d7513c3..7c37d9d4b 100644 --- a/resource/systemd/unit/preparer_test.go +++ b/resource/systemd/unit/preparer_test.go @@ -119,11 +119,14 @@ func TestPreparer(t *testing.T) { }) t.Run("sets-fields", func(t *testing.T) { t.Parallel() + untrue := false res, err := (&Preparer{ - Name: "test1", - State: "state1", - Reload: true, - executor: &ExecutorMock{}, + Name: "test1", + State: "state1", + Reload: true, + Enable: &untrue, + EnableRuntime: &untrue, + executor: &ExecutorMock{}, }).Prepare(context.Background(), fakerenderer.New()) require.NoError(t, err) assert.Equal(t, "test1", res.(*Resource).Name) @@ -132,6 +135,81 @@ func TestPreparer(t *testing.T) { assert.False(t, res.(*Resource).sendSignal) assert.Equal(t, "", res.(*Resource).SignalName) assert.Equal(t, uint(0), res.(*Resource).SignalNumber) + assert.False(t, *res.(*Resource).enableChange) + assert.False(t, *res.(*Resource).enableRuntimeChange) + }) + t.Run("handles-enable-disable", func(t *testing.T) { + t.Parallel() + valTrue := true + valFalse := false + t.Run("when-true-true", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valTrue, + EnableRuntime: &valTrue, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-false-true", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valFalse, + EnableRuntime: &valTrue, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-true-false", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valTrue, + EnableRuntime: &valFalse, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-false-false", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valFalse, + EnableRuntime: &valFalse, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-true-nil", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valTrue, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-false-nil", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + Enable: &valFalse, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-nil-true", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + EnableRuntime: &valTrue, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) + t.Run("when-nil-alse", func(t *testing.T) { + t.Parallel() + _, err := (&Preparer{ + EnableRuntime: &valFalse, + executor: &ExecutorMock{}, + }).Prepare(context.Background(), fakerenderer.New()) + assert.NoError(t, err) + }) }) } diff --git a/resource/systemd/unit/resource.go b/resource/systemd/unit/resource.go index 139b0d3ce..241f89d00 100644 --- a/resource/systemd/unit/resource.go +++ b/resource/systemd/unit/resource.go @@ -1,4 +1,4 @@ -// Copyright © 2016 Asteris, LLC +// Copyright © 2017 Asteris, LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package unit import ( "fmt" + "os" "github.com/pkg/errors" @@ -53,6 +54,12 @@ type Resource struct { // explanation of these signals. SignalNumber uint `export:"signal_number"` + // Set to true if the unit is enabled (symlinks exist in `/etc`) + Enabled bool `export:"enabled"` + + // Set to tru if the unit is enabled for runtime (symlinks exist in `/run`) + EnabledRuntime bool `export:"enabled_runtime"` + // The full path to the unit file on disk. This field will be empty if the // unit was not started from a systemd unit file on disk. Path string `export:"path"` @@ -146,9 +153,12 @@ type Resource struct { // for for more information. ScopeProperties *ScopeTypeProperties `re-export-as:"scope_properties"` - sendSignal bool - systemdExecutor SystemdExecutor - hasRun bool + enableChange *bool // enabled if true, disabled if false, unmodified if nil + enableRuntimeChange *bool // enabled if true, disabled if false, unmodified if nil + sendSignal bool + systemdExecutor SystemdExecutor + hasRun bool + fs fsexecutor } type response struct { @@ -214,12 +224,52 @@ func (r *Resource) runCheck() (resource.TaskStatus, error) { case "stopped": r.shouldStop(u, status) } + enabledRuntime, enabledPersistent, err := r.isEnabled(u) + if err != nil { + return nil, errors.Wrap(err, "unable to determine if the unit is enabled") + } + r.Enabled = enabledPersistent + r.EnabledRuntime = enabledRuntime + + status = r.updateEnableStatus("persistent", + status, + r.Enabled, + r.enableChange) + status = r.updateEnableStatus("runtime", + status, + r.EnabledRuntime, + r.enableRuntimeChange) r.hasRun = true return status, nil } +func (r *Resource) updateEnableStatus( + msg string, + status *resource.Status, + current bool, + want *bool) *resource.Status { + var tgt bool + if want == nil { + return status + } + tgt = *want + if tgt == current { + status.AddMessage(fmt.Sprintf("%s unit is already %s", msg, showEnabled(tgt))) + return status + } + status.RaiseLevel(resource.StatusWillChange) + status.AddDifference(msg, showEnabled(current), showEnabled(tgt), "") + return status +} + +func showEnabled(b bool) string { + if b { + return "enabled" + } + return "disabled" +} + func (r *Resource) runApply() (resource.TaskStatus, error) { - log.WithField("Unit Name: ", r.Name).Infof("calling runApply()....") status := resource.NewStatus() tempStatus := resource.NewStatus() u, err := r.systemdExecutor.QueryUnit(r.Name, false) @@ -258,9 +308,55 @@ func (r *Resource) runApply() (resource.TaskStatus, error) { case "restarted": runstateErr = r.systemdExecutor.RestartUnit(u) } + + enabledRuntime, enabledPersistent, err := r.isEnabled(u) + if err != nil { + return nil, errors.Wrap(err, "unable to determine if the unit is enabled") + } + r.Enabled = enabledPersistent + r.EnabledRuntime = enabledRuntime + + var symlinkChanges []*unitFileChange + + if r.enableChange != nil { + changes, err := r.toggleUnitEnabled(u, status, false, *r.enableChange, enabledPersistent) + if err != nil { + return status, err + } + symlinkChanges = append(symlinkChanges, changes...) + } + + if r.enableRuntimeChange != nil { + changes, err := r.toggleUnitEnabled(u, status, true, *r.enableRuntimeChange, enabledRuntime) + if err != nil { + return status, err + } + symlinkChanges = append(symlinkChanges, changes...) + } + + for _, ch := range symlinkChanges { + status.AddDifference(ch.Type, "", fmt.Sprintf("%s -> %s", ch.Filename, ch.Destination), "") + } + return status, runstateErr } +func (r *Resource) toggleUnitEnabled(u *Unit, status *resource.Status, runtime, shouldBeEnabled, isEnabled bool) ([]*unitFileChange, error) { + if shouldBeEnabled == isEnabled { + if isEnabled { + status.AddMessage("unit is already enabled") + } else { + status.AddMessage("unit is already disabled") + } + return []*unitFileChange{}, nil + } + if shouldBeEnabled { + _, c, e := r.systemdExecutor.EnableUnit(u, runtime, true) + return c, e + } + return r.systemdExecutor.DisableUnit(u, runtime) +} + // We copy data from the unit into the resource to make the UX nicer for users // who want to access systemd information. func (r *Resource) populateFromUnit(u *Unit) { @@ -362,6 +458,62 @@ func (r *Resource) shouldStop(u *Unit, st *resource.Status) bool { return true } +// isEnabled checks a unit file to see if it's enabled at runtime and/or +// persistently. It returns a thruple of the runtime enablement, system +// enablement, and an error +func (r *Resource) isEnabled(unit *Unit) (runtime bool, persistent bool, err error) { + runtime, err = r.existsInTree("/run/systemd", unit) + if err != nil { + return false, false, err + } + persistent, err = r.existsInTree("/etc/systemd", unit) + if err != nil { + return false, false, err + } + return runtime, persistent, nil +} + +func (r *Resource) existsInTree(root string, unit *Unit) (bool, error) { + var found bool + toFind := unit.Name + var checkSymlink bool + + fmt.Fprintf(os.Stderr, "existsInTree (fprintf to stderr)\n") + + if unit == nil { + fmt.Fprintf(os.Stderr, "unit is nil in existsInTree!") + return false, errors.New("unit is nil") + } + + if unit.Path != "" { + toFind = unit.Path + checkSymlink = true + } + + err := r.fs.Walk(root, func(path string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } + if checkSymlink { + matches, matchErr := r.symlinkTargetMatches(path, toFind) + found = found || matches + err = matchErr + } else if info.Name() == toFind { + found = true + } + return nil + }) + return found, err +} + +func (r *Resource) symlinkTargetMatches(symlinkPath, expectedPath string) (bool, error) { + canonical, err := r.fs.EvalSymlinks(symlinkPath) + if err != nil { + return false, err + } + return (canonical == expectedPath), nil +} + func getFailedReason(u *Unit) (string, error) { err := errors.New("unable to determine cause of failure: no properties available") var reason string diff --git a/resource/systemd/unit/resource_test.go b/resource/systemd/unit/resource_test.go index bbd0f5561..c59679e0f 100644 --- a/resource/systemd/unit/resource_test.go +++ b/resource/systemd/unit/resource_test.go @@ -27,6 +27,9 @@ import ( ) func TestCheck(t *testing.T) { + fs := &mockFsExecutor{} + fs.On("Walk", any, any).Return(nil) + t.Parallel() t.Run("send-signal", func(t *testing.T) { r := &Resource{ @@ -34,6 +37,7 @@ func TestCheck(t *testing.T) { SignalName: "SIGKILL", SignalNumber: 9, sendSignal: true, + fs: fs, } e := &ExecutorMock{} r.systemdExecutor = e @@ -47,6 +51,7 @@ func TestCheck(t *testing.T) { r := &Resource{ State: "running", Reload: true, + fs: fs, } e := &ExecutorMock{} r.systemdExecutor = e @@ -62,6 +67,7 @@ func TestCheck(t *testing.T) { r := &Resource{ Name: "resource1", State: "running", + fs: fs, } t.Run("query-unit-returns-error", func(t *testing.T) { expected := errors.New("error1") @@ -146,6 +152,7 @@ func TestCheck(t *testing.T) { r := &Resource{ Name: "resource1", State: "stopped", + fs: fs, } t.Run("when-status-active", func(t *testing.T) { unit := &Unit{ActiveState: "active"} @@ -210,6 +217,7 @@ func TestCheck(t *testing.T) { r := &Resource{ Name: "resource1", State: "restarted", + fs: fs, } e := &ExecutorMock{} r.systemdExecutor = e @@ -335,11 +343,14 @@ func TestGetFailedReason(t *testing.T) { // TestApply runs a test func TestApply(t *testing.T) { + fs := &mockFsExecutor{} + fs.On("Walk", any, any).Return(nil) + t.Parallel() t.Run("query-unit-returns-error", func(t *testing.T) { t.Parallel() expected := errors.New("error1") - r := &Resource{} + r := &Resource{fs: fs} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return((*Unit)(nil), expected) r.systemdExecutor = e @@ -350,7 +361,13 @@ func TestApply(t *testing.T) { t.Run("when-send-signal", func(t *testing.T) { t.Parallel() u := &Unit{ActiveState: "active"} - r := &Resource{ActiveState: "running", SignalName: "SIGKILL", SignalNumber: 9, sendSignal: true} + r := &Resource{ + ActiveState: "running", + SignalName: "SIGKILL", + SignalNumber: 9, + sendSignal: true, + fs: fs, + } e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) e.On("SendSignal", any, any).Return() @@ -369,6 +386,7 @@ func TestApply(t *testing.T) { r := &Resource{ State: "running", Reload: true, + fs: fs, } e := &ExecutorMock{} u := &Unit{ActiveState: "active"} @@ -385,6 +403,7 @@ func TestApply(t *testing.T) { r := &Resource{ State: "running", Reload: true, + fs: fs, } e := &ExecutorMock{} u := &Unit{ActiveState: "active"} @@ -400,7 +419,7 @@ func TestApply(t *testing.T) { t.Run("when-want-running", func(t *testing.T) { t.Parallel() t.Run("start-returns-error", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "inactive"} e := &ExecutorMock{} expected := errors.New("error1") @@ -413,7 +432,7 @@ func TestApply(t *testing.T) { assert.Equal(t, expected, err) }) t.Run("status-is-active", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "active"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -426,7 +445,7 @@ func TestApply(t *testing.T) { e.AssertNotCalled(t, "StartUnit", u) }) t.Run("status-is-reloading", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "reloading"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -439,7 +458,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StartUnit", u) }) t.Run("status-is-inactive", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "inactive"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -452,7 +471,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StartUnit", u) }) t.Run("status-is-failed", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "failed"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -465,7 +484,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StartUnit", u) }) t.Run("status-is-activating", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "activating"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -478,7 +497,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StartUnit", u) }) t.Run("status-is-deactivating", func(t *testing.T) { - r := &Resource{State: "running"} + r := &Resource{State: "running", fs: fs} u := &Unit{ActiveState: "deactivating"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -495,7 +514,7 @@ func TestApply(t *testing.T) { t.Run("when-want-stopped", func(t *testing.T) { t.Parallel() t.Run("stop-returns-error", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "active"} e := &ExecutorMock{} expected := errors.New("error1") @@ -508,7 +527,7 @@ func TestApply(t *testing.T) { assert.Equal(t, expected, err) }) t.Run("status-is-active", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "active"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -521,7 +540,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StopUnit", u) }) t.Run("status-is-reloading", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "reloading"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -534,7 +553,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StopUnit", u) }) t.Run("status-is-inactive", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "inactive"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -547,7 +566,7 @@ func TestApply(t *testing.T) { e.AssertNotCalled(t, "StopUnit", u) }) t.Run("status-is-failed", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "failed"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -560,7 +579,7 @@ func TestApply(t *testing.T) { e.AssertNotCalled(t, "StopUnit", u) }) t.Run("status-is-activating", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "activating"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -573,7 +592,7 @@ func TestApply(t *testing.T) { e.AssertCalled(t, "StopUnit", u) }) t.Run("status-is-deactivating", func(t *testing.T) { - r := &Resource{State: "stopped"} + r := &Resource{State: "stopped", fs: fs} u := &Unit{ActiveState: "deactivating"} e := &ExecutorMock{} e.On("QueryUnit", any, any).Return(u, nil) @@ -590,7 +609,7 @@ func TestApply(t *testing.T) { t.Parallel() t.Run("when-restart-returns-error", func(t *testing.T) { t.Parallel() - r := &Resource{State: "restarted"} + r := &Resource{State: "restarted", fs: fs} u := &Unit{ActiveState: "active"} e := &ExecutorMock{} r.systemdExecutor = e @@ -606,7 +625,7 @@ func TestApply(t *testing.T) { for _, st := range states { t.Run(st, func(t *testing.T) { u := &Unit{ActiveState: st} - r := &Resource{State: "restarted"} + r := &Resource{State: "restarted", fs: fs} e := &ExecutorMock{} e.On("RestartUnit", any).Return(nil) e.On("QueryUnit", any, any).Return(u, nil) @@ -618,11 +637,78 @@ func TestApply(t *testing.T) { } }) }) + t.Run("tries-to-enable-unit", func(t *testing.T) { + True := true + fs := newMockWithPaths() + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("", nil) + u := &Unit{ + Name: "name1.service", + Path: "/lib/systemd/system/name1.service", + } + e := &ExecutorMock{} + e.On("QueryUnit", any, any).Return(u, nil) + e.On("EnableUnit", any, any, any).Return(false, []*unitFileChange{}, nil) + r := &Resource{ + Name: "name1.service", + systemdExecutor: e, + enableChange: &True, + fs: fs, + } + _, err := r.Apply(context.Background()) + e.AssertCalled(t, "EnableUnit", any, any, any) + require.NoError(t, err) + }) + t.Run("tries-to-disable-unit", func(t *testing.T) { + fs := newMockWithPaths("/etc/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/etc/systemd/system/name1.service", nil) + u := &Unit{ + Name: "name1.service", + } + e := &ExecutorMock{} + e.On("QueryUnit", any, any).Return(u, nil) + e.On("EnableUnit", any, any, any).Return(false, []*unitFileChange{}, nil) + e.On("DisableUnit", any, any).Return([]*unitFileChange{}, nil) + r := &Resource{ + Name: "name1.service", + systemdExecutor: e, + enableChange: new(bool), + fs: fs, + } + _, err := r.Apply(context.Background()) + e.AssertCalled(t, "DisableUnit", any, any) + require.NoError(t, err) + }) + t.Run("updates-enablement", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths( + "/run/systemd/system/name1.service", + ) + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/lib/systemd/system/name1.service", nil) + u := &Unit{ + Name: "name1.service", + Path: "/lib/systemd/system/name1.service", + } + e := &ExecutorMock{} + e.On("QueryUnit", any, any).Return(u, nil) + r := &Resource{ + Name: "name1.service", + systemdExecutor: e, + fs: fs, + } + _, err := r.Apply(context.Background()) + require.NoError(t, err) + assert.True(t, r.EnabledRuntime) + }) } // TestCheckAfterApply runs a test func TestCheckAfterApply(t *testing.T) { t.Parallel() + fs := &mockFsExecutor{} + fs.On("Walk", any, any).Return(nil) t.Run("when-send-signal", func(t *testing.T) { t.Parallel() @@ -631,6 +717,7 @@ func TestCheckAfterApply(t *testing.T) { SignalName: "SIGKILL", SignalNumber: 9, sendSignal: true, + fs: fs, } u := &Unit{ActiveState: "active"} e := &ExecutorMock{} @@ -649,6 +736,7 @@ func TestCheckAfterApply(t *testing.T) { r := &Resource{ State: "running", Reload: true, + fs: fs, } u := &Unit{ActiveState: "active"} e := &ExecutorMock{} @@ -665,6 +753,9 @@ func TestCheckAfterApply(t *testing.T) { // TestHandlesContext runs a test func TestHandlesContext(t *testing.T) { + fs := &mockFsExecutor{} + fs.On("Walk", any, any).Return(nil) + t.Parallel() t.Run("Check", func(t *testing.T) { @@ -674,7 +765,7 @@ func TestHandlesContext(t *testing.T) { expected := "context was cancelled" ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) time.Sleep(2 * time.Millisecond) - r := &Resource{} + r := &Resource{fs: fs} e := &ExecutorMock{DoSleep: true, SleepFor: (10 * time.Millisecond)} e.On("QueryUnit", any, any).Return(&Unit{}, nil) r.systemdExecutor = e @@ -686,7 +777,7 @@ func TestHandlesContext(t *testing.T) { t.Run("when-canceled", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) - r := &Resource{} + r := &Resource{fs: fs} e := &ExecutorMock{DoSleep: true, SleepFor: (10 * time.Millisecond)} e.On("QueryUnit", any, any).Return(&Unit{}, nil) r.systemdExecutor = e @@ -702,7 +793,7 @@ func TestHandlesContext(t *testing.T) { expected := "context was cancelled" ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) time.Sleep(2 * time.Millisecond) - r := &Resource{} + r := &Resource{fs: fs} e := &ExecutorMock{DoSleep: true, SleepFor: (10 * time.Millisecond)} e.On("QueryUnit", any, any).Return(&Unit{}, nil) r.systemdExecutor = e @@ -713,7 +804,7 @@ func TestHandlesContext(t *testing.T) { t.Run("when-canceled", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) - r := &Resource{} + r := &Resource{fs: fs} e := &ExecutorMock{DoSleep: true, SleepFor: (10 * time.Millisecond)} e.On("QueryUnit", any, any).Return(&Unit{}, nil) r.systemdExecutor = e @@ -721,5 +812,136 @@ func TestHandlesContext(t *testing.T) { _, err := r.Apply(ctx) assert.Error(t, err) }) + + }) +} + +// TestIsEnabled tests validation of whether or not a unit is enabled +func TestIsEnabled(t *testing.T) { + t.Parallel() + t.Run("check-sets-fields", func(t *testing.T) { + t.Parallel() + t.Run("when-no-path-set", func(t *testing.T) { + t.Run("when-not-enabled", func(t *testing.T) { + t.Parallel() + fs := &mockFsExecutor{} + fs.On("Walk", any, any).Return(nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: ""} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.False(t, runtime) + assert.False(t, persistent) + }) + t.Run("when-enabled", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths("/etc/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: ""} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.False(t, runtime) + assert.True(t, persistent) + }) + t.Run("when-runtime-enabled", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: ""} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.True(t, runtime) + assert.False(t, persistent) + }) + t.Run("when-enabled-and-runtime-enabled", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths( + "/run/systemd/system/name1.service", + "/etc/systemd/system/name1.service", + ) + fs.On("Walk", any, any).Return(nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: ""} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.True(t, runtime) + assert.True(t, persistent) + }) + }) + t.Run("when-path-set", func(t *testing.T) { + t.Parallel() + t.Run("when-disabled", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths() + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/lib/systemd/system/name1.service"} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.False(t, runtime) + assert.False(t, persistent) + }) + t.Run("when-enabled", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/path1", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/path1"} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.True(t, runtime) + assert.False(t, persistent) + }) + t.Run("when-enabled-with-different-symlink", func(t *testing.T) { + t.Parallel() + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/path2", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/path1"} + runtime, persistent, err := r.isEnabled(u) + assert.NoError(t, err) + assert.False(t, runtime) + assert.False(t, persistent) + }) + }) + }) +} + +func TestExistsInTree(t *testing.T) { + t.Parallel() + t.Run("when-path-and-symlink-target-match", func(t *testing.T) { + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/lib/systemd/system/name-full.service", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/lib/systemd/system/name-full.service"} + inTree, err := r.existsInTree("/run", u) + require.NoError(t, err) + assert.True(t, inTree) + }) + t.Run("when-path-and-symlink-target-mismatch", func(t *testing.T) { + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(nil) + fs.On("EvalSymlinks", any).Return("/lib/systemd/system/name-full.service", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/etc/init.d/name-full"} + inTree, err := r.existsInTree("/run", u) + require.NoError(t, err) + assert.False(t, inTree) + }) + t.Run("when-symlink-error", func(t *testing.T) { + expected := errors.New("error1") + fs := newMockWithPaths("/run/systemd/system/name1.service") + fs.On("Walk", any, any).Return(expected) + fs.On("EvalSymlinks", any).Return("", nil) + r := &Resource{fs: fs} + u := &Unit{Name: "name1.service", Path: "/etc/init.d/name-full"} + _, err := r.existsInTree("/run", u) + assert.Equal(t, expected, err) }) } diff --git a/resource/systemd/unit/systemd_connection_mock_test.go b/resource/systemd/unit/systemd_connection_mock_test.go index bbd4fbcaf..22aa89599 100644 --- a/resource/systemd/unit/systemd_connection_mock_test.go +++ b/resource/systemd/unit/systemd_connection_mock_test.go @@ -40,6 +40,16 @@ func (m *SystemdMock) ListUnitsByNames(units []string) ([]dbus.UnitStatus, error return args.Get(0).([]dbus.UnitStatus), args.Error(1) } +func (m *SystemdMock) EnableUnitFiles(files []string, runtime, force bool) (bool, []dbus.EnableUnitFileChange, error) { + args := m.Called(files, runtime, force) + return args.Bool(0), args.Get(1).([]dbus.EnableUnitFileChange), args.Error(2) +} + +func (m *SystemdMock) DisableUnitFiles(files []string, runtime bool) ([]dbus.DisableUnitFileChange, error) { + args := m.Called(files, runtime) + return args.Get(1).([]dbus.DisableUnitFileChange), args.Error(2) +} + func (m *SystemdMock) GetUnitProperties(unit string) (map[string]interface{}, error) { args := m.Called(unit) return args.Get(0).(map[string]interface{}), args.Error(1) diff --git a/resource/systemd/unit/systemd_linux.go b/resource/systemd/unit/systemd_linux.go index d80cf4d9c..3fa10ddc4 100644 --- a/resource/systemd/unit/systemd_linux.go +++ b/resource/systemd/unit/systemd_linux.go @@ -97,6 +97,57 @@ func (l LinuxExecutor) SendSignal(u *Unit, signal Signal) { l.dbusConn.KillUnit(u.Name, int32(signal)) } +// EnableUnit will enable a unit file. u specifies the unit file to enable, +// runtime specifies whether the unit should be enabled at runtime (true) or +// persistently (false), and force specifies whether any existing symlinks +// should be overwritten. It returns a thruple of a bool, which specifies +// whether any enablement hooks (e.g. from an [Install] section) were run, a +// list of changes that were made on the filesystem, and an error. +func (l LinuxExecutor) EnableUnit(u *Unit, runtime, force bool) (bool, []*unitFileChange, error) { + var whatChanged []*unitFileChange + ranHooks, changes, err := l.dbusConn.EnableUnitFiles([]string{u.Name}, runtime, force) + for _, change := range changes { + convChanges, convErr := newUnitChange(&change) + if convErr != nil { + return false, []*unitFileChange{}, convErr + } + whatChanged = append(whatChanged, convChanges) + } + return ranHooks, whatChanged, err +} + +// DisableUnit will disable a unit file. u specidies the unit file to disable, +// and runtime determines whether the unit file should be disabled for the +// current run (true) or persistently (false). It returns a list of changes and +// an error. +func (l LinuxExecutor) DisableUnit(u *Unit, runtime bool) ([]*unitFileChange, error) { + var whatChanged []*unitFileChange + changes, err := l.dbusConn.DisableUnitFiles([]string{u.Name}, runtime) + for _, change := range changes { + convChanges, convErr := newUnitChange(&change) + if convErr != nil { + return []*unitFileChange{}, convErr + } + whatChanged = append(whatChanged, convChanges) + } + return whatChanged, err +} + +func newUnitChange(dbusChange interface{}) (*unitFileChange, error) { + switch in := dbusChange.(type) { + case dbus.EnableUnitFileChange: + return &unitFileChange{Type: in.Type, Filename: in.Filename, Destination: in.Destination}, nil + case dbus.DisableUnitFileChange: + return &unitFileChange{Type: in.Type, Filename: in.Filename, Destination: in.Destination}, nil + case *dbus.EnableUnitFileChange: + return newUnitChange(*in) + case *dbus.DisableUnitFileChange: + return newUnitChange(*in) + default: + return nil, fmt.Errorf("unsupported type: %T", dbusChange) + } +} + func runDbusCommand(f func(string, string, chan<- string) (int, error), name, mode, operation string) error { ch := make(chan string) defer close(ch) diff --git a/resource/systemd/unit/systemd_stub.go b/resource/systemd/unit/systemd_stub.go index bdfe52abb..9b2b3759f 100644 --- a/resource/systemd/unit/systemd_stub.go +++ b/resource/systemd/unit/systemd_stub.go @@ -61,6 +61,16 @@ func (s StubExecutor) SendSignal(*Unit, Signal) { return } +// EnableUnit is a stub +func (s StubExecutor) EnableUnit(*Unit, bool, bool) (bool, []*unitFileChange, error) { + return false, []*unitFileChange{}, ErrUnsupportedOS +} + +// DisableUnit is a stub +func (s StubExecutor) DisableUnit(*Unit, bool) ([]*unitFileChange, error) { + return []*unitFileChange{}, ErrUnsupportedOS +} + func realExecutor() (SystemdExecutor, error) { return StubExecutor{}, ErrUnsupportedOS } diff --git a/resource/systemd/unit/unitChange.go b/resource/systemd/unit/unitChange.go new file mode 100644 index 000000000..6cab51979 --- /dev/null +++ b/resource/systemd/unit/unitChange.go @@ -0,0 +1,24 @@ +// Copyright © 2016 Asteris, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unit + +// unitFileChange mirrors github.com/coreos/go-systemd/dbus.EnableUnitFileChange +// and github.com/coreos/go-systemd/dbus.DisableUnitFileChange, and is recreated +// here to avoid including a dependency on dbus for non-linux systems. +type unitFileChange struct { + Type string // one of 'link' or 'unlink' + Filename string // filename of the symlink + Destination string // destination of the symlink +}