diff --git a/script/constructor.go b/script/constructor.go index 5d1ae9d..2ad62ba 100644 --- a/script/constructor.go +++ b/script/constructor.go @@ -1,16 +1,20 @@ package script -// TODO: generate constructrs from bitfield/script's source code +// TODO: generate constructors from bitfield/script's source code import ( "context" "net/http" + "sync" "github.com/bitfield/script" ) func newPipeFrom(pipe *script.Pipe) *Pipe { - return &Pipe{Pipe: pipe} + return &Pipe{ + Pipe: pipe, + mu: new(sync.Mutex), + } } func NewPipe() *Pipe { diff --git a/script/contextual.go b/script/contextual.go index bf7b566..dff6265 100644 --- a/script/contextual.go +++ b/script/contextual.go @@ -14,6 +14,7 @@ import ( "os/exec" "path/filepath" "strings" + "sync" "text/template" "github.com/bitfield/script" @@ -25,15 +26,30 @@ var NewReadAutoCloser = script.NewReadAutoCloser type Pipe struct { *script.Pipe - stderr io.Writer // captured from WithStderr - // wd is the working directory for current pipe. wd string + + mu *sync.Mutex // protects the following fields + + stderr io.Writer // captured from WithStderr + // env is the environment variables for current pipe. // Non-empty value will be set to the exec.Command instance. env []string } +func (p *Pipe) environments() []string { + p.mu.Lock() + defer p.mu.Unlock() + return p.env +} + +func (p *Pipe) stdErr() io.Writer { + p.mu.Lock() + defer p.mu.Unlock() + return p.stderr +} + func (p *Pipe) At(dir string) *Pipe { p.wd = dir return p @@ -46,23 +62,35 @@ func (p *Pipe) WithCurrentEnv() *Pipe { // WithEnv sets the environment variables for the current pipe. func (p *Pipe) WithEnv(env []string) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() + p.env = env return p } // AppendEnv appends the environment variables for the current pipe. func (p *Pipe) AppendEnv(env ...string) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() + p.env = append(p.env, env...) return p } // WithEnvKV sets the environment variable key-value pair for the current pipe. func (p *Pipe) WithEnvKV(key, value string) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() + p.env = append(p.env, key+"="+value) return p } func (p *Pipe) WithStderr(w io.Writer) *Pipe { + p.mu.Lock() + defer p.mu.Unlock() + p.stderr = w p.Pipe = p.Pipe.WithStderr(w) return p @@ -132,8 +160,8 @@ func (p *Pipe) execContext( if p.wd != "" { cmd.Dir = p.wd } - if len(p.env) > 0 { - cmd.Env = p.env + if envs := p.environments(); len(envs) > 0 { + cmd.Env = envs } return cmd @@ -150,8 +178,8 @@ func (p *Pipe) ExecContext(ctx context.Context, cmdLine string) *Pipe { cmd.Stdin = r cmd.Stdout = w cmd.Stderr = w - if p.stderr != nil { - cmd.Stderr = p.stderr + if stderr := p.stdErr(); stderr != nil { + cmd.Stderr = stderr } if err := cmd.Start(); err != nil { @@ -189,8 +217,8 @@ func (p *Pipe) ExecForEachContext(ctx context.Context, cmdLine string) *Pipe { cmd := p.execContext(ctx, args[0], args[1:]) cmd.Stdout = w cmd.Stderr = w - if p.stderr != nil { - cmd.Stderr = p.stderr + if stderr := p.stdErr(); stderr != nil { + cmd.Stderr = stderr } err = cmd.Start() if err != nil { diff --git a/script/contextual_test.go b/script/contextual_test.go new file mode 100644 index 0000000..8a9f710 --- /dev/null +++ b/script/contextual_test.go @@ -0,0 +1,15 @@ +package script + +import ( + "context" + "testing" +) + +func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) { + t.Parallel() + ctx := context.Background() + err := ExecContext(ctx, "echo").WithStderr(nil).Wait() + if err != nil { + t.Fatal(err) + } +} diff --git a/script/go.mod b/script/go.mod index d2fece6..d40343f 100644 --- a/script/go.mod +++ b/script/go.mod @@ -1,9 +1,9 @@ module github.com/b4fun/script-contextual/script -go 1.22.3 +go 1.22.7 require ( - github.com/bitfield/script v0.22.1 + github.com/bitfield/script v0.23.0 mvdan.cc/sh/v3 v3.7.0 ) diff --git a/script/go.sum b/script/go.sum index dd45551..c9d2af2 100644 --- a/script/go.sum +++ b/script/go.sum @@ -1,5 +1,5 @@ -github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4= -github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI= +github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28= +github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI= github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA= github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= diff --git a/tests/go.mod b/tests/go.mod index 6e4f13c..e166641 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module tests -go 1.22.3 +go 1.22.7 replace github.com/b4fun/script-contextual/script => ../script @@ -11,7 +11,7 @@ require ( ) require ( - github.com/bitfield/script v0.22.1 // indirect + github.com/bitfield/script v0.23.0 // indirect github.com/itchyny/gojq v0.12.13 // indirect github.com/itchyny/timefmt-go v0.1.5 // indirect golang.org/x/sys v0.10.0 // indirect diff --git a/tests/go.sum b/tests/go.sum index dd45551..c9d2af2 100644 --- a/tests/go.sum +++ b/tests/go.sum @@ -1,5 +1,5 @@ -github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4= -github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI= +github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28= +github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI= github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA= github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= diff --git a/tests/script_test.go b/tests/script_test.go index 8063710..7c8634c 100644 --- a/tests/script_test.go +++ b/tests/script_test.go @@ -1768,6 +1768,40 @@ func TestWithStdout_SetsSpecifiedWriterAsStdout(t *testing.T) { } } +func TestWithEnv_UnsetsAllEnvVarsGivenEmptySlice(t *testing.T) { + t.Parallel() + p := script.NewPipe().WithEnv([]string{"ENV1=test1"}).Exec("sh -c 'echo ENV1=$ENV1'") + want := "ENV1=test1\n" + got, err := p.String() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("want %q, got %q", want, got) + } + got, err = p.Echo("").WithEnv([]string{}).Exec("sh -c 'echo ENV1=$ENV1'").String() + if err != nil { + t.Fatal(err) + } + want = "ENV1=\n" + if got != want { + t.Errorf("want %q, got %q", want, got) + } +} + +func TestWithEnv_SetsGivenVariablesForSubsequentExec(t *testing.T) { + t.Parallel() + env := []string{"ENV1=test1", "ENV2=test2"} + got, err := script.NewPipe().WithEnv(env).Exec("sh -c 'echo ENV1=$ENV1 ENV2=$ENV2'").String() + if err != nil { + t.Fatal(err) + } + want := "ENV1=test1 ENV2=test2\n" + if got != want { + t.Errorf("want %q, got %q", want, got) + } +} + func TestErrorReturnsErrorSetByPreviousPipeStage(t *testing.T) { t.Parallel() p := script.File("testdata/nonexistent.txt") @@ -1850,6 +1884,135 @@ func TestReadReturnsErrorGivenReadErrorOnPipe(t *testing.T) { } } +func TestWait_ReturnsErrorPresentOnPipe(t *testing.T) { + t.Parallel() + p := script.Echo("a\nb\nc\n").ExecForEach("{{invalid template syntax}}") + if p.Wait() == nil { + t.Error("want error, got nil") + } +} + +func TestWait_DoesNotReturnErrorForValidExecution(t *testing.T) { + t.Parallel() + p := script.Echo("a\nb\nc\n").ExecForEach("echo \"{{.}}\"") + if err := p.Wait(); err != nil { + t.Fatal(err) + } +} + +var base64Cases = []struct { + name string + decoded string + encoded string +}{ + { + name: "empty string", + decoded: "", + encoded: "", + }, + { + name: "single line string", + decoded: "hello world", + encoded: "aGVsbG8gd29ybGQ=", + }, + { + name: "multi line string", + decoded: "hello\nthere\nworld\n", + encoded: "aGVsbG8KdGhlcmUKd29ybGQK", + }, +} + +func TestEncodeBase64_CorrectlyEncodes(t *testing.T) { + t.Parallel() + for _, tc := range base64Cases { + t.Run(tc.name, func(t *testing.T) { + got, err := script.Echo(tc.decoded).EncodeBase64().String() + if err != nil { + t.Fatal(err) + } + if got != tc.encoded { + t.Logf("input %q incorrectly encoded:", tc.decoded) + t.Error(cmp.Diff(tc.encoded, got)) + } + }) + } +} + +func TestDecodeBase64_CorrectlyDecodes(t *testing.T) { + t.Parallel() + for _, tc := range base64Cases { + t.Run(tc.name, func(t *testing.T) { + got, err := script.Echo(tc.encoded).DecodeBase64().String() + if err != nil { + t.Fatal(err) + } + if got != tc.decoded { + t.Logf("input %q incorrectly decoded:", tc.encoded) + t.Error(cmp.Diff(tc.decoded, got)) + } + }) + } +} + +func TestEncodeBase64_FollowedByDecodeRecoversOriginal(t *testing.T) { + t.Parallel() + for _, tc := range base64Cases { + t.Run(tc.name, func(t *testing.T) { + decoded, err := script.Echo(tc.decoded).EncodeBase64().DecodeBase64().String() + if err != nil { + t.Fatal(err) + } + if decoded != tc.decoded { + t.Error("encode-decode round trip failed:", cmp.Diff(tc.decoded, decoded)) + } + encoded, err := script.Echo(tc.encoded).DecodeBase64().EncodeBase64().String() + if err != nil { + t.Fatal(err) + } + if encoded != tc.encoded { + t.Error("decode-encode round trip failed:", cmp.Diff(tc.encoded, encoded)) + } + }) + } +} + +func TestDecodeBase64_CorrectlyDecodesInputToBytes(t *testing.T) { + t.Parallel() + input := "CAAAEA==" + got, err := script.Echo(input).DecodeBase64().Bytes() + if err != nil { + t.Fatal(err) + } + want := []byte{8, 0, 0, 16} + if !bytes.Equal(want, got) { + t.Logf("input %#v incorrectly decoded:", input) + t.Error(cmp.Diff(want, got)) + } +} + +func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) { + t.Parallel() + input := []byte{8, 0, 0, 16} + reader := bytes.NewReader(input) + want := "CAAAEA==" + got, err := script.NewPipe().WithReader(reader).EncodeBase64().String() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Logf("input %#v incorrectly encoded:", input) + t.Error(cmp.Diff(want, got)) + } +} + +func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) { + t.Parallel() + err := script.Exec("echo").WithStderr(nil).Wait() + if err != nil { + t.Fatal(err) + } +} + func ExampleArgs() { script.Args().Stdout() // prints command-line arguments @@ -1969,6 +2132,12 @@ func ExamplePipe_CountLines() { // 3 } +func ExamplePipe_DecodeBase64() { + script.Echo("SGVsbG8sIHdvcmxkIQ==").DecodeBase64().Stdout() + // Output: + // Hello, world! +} + func ExamplePipe_Do() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) @@ -2004,6 +2173,12 @@ func ExamplePipe_Echo() { // Hello, world! } +func ExamplePipe_EncodeBase64() { + script.Echo("Hello, world!").EncodeBase64().Stdout() + // Output: + // SGVsbG8sIHdvcmxkIQ== +} + func ExamplePipe_ExitStatus() { p := script.Exec("echo") fmt.Println(p.ExitStatus())