Skip to content

Commit

Permalink
Merge pull request #7 from b4fun/stderr-race
Browse files Browse the repository at this point in the history
fix: fix with stderr data race
  • Loading branch information
bcho authored Oct 3, 2024
2 parents f7543cf + 14ed58e commit 721eff6
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 18 deletions.
8 changes: 6 additions & 2 deletions script/constructor.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package script

// TODO: generate constructrs from bitfield/script's source code
// TODO: generate constructors from bitfield/script's source code

import (
"context"
"net/http"
"sync"

"github.com/bitfield/script"
)

func newPipeFrom(pipe *script.Pipe) *Pipe {
return &Pipe{Pipe: pipe}
return &Pipe{
Pipe: pipe,
mu: new(sync.Mutex),
}
}

func NewPipe() *Pipe {
Expand Down
44 changes: 36 additions & 8 deletions script/contextual.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os/exec"
"path/filepath"
"strings"
"sync"
"text/template"

"github.com/bitfield/script"
Expand All @@ -25,15 +26,30 @@ var NewReadAutoCloser = script.NewReadAutoCloser
type Pipe struct {
*script.Pipe

stderr io.Writer // captured from WithStderr

// wd is the working directory for current pipe.
wd string

mu *sync.Mutex // protects the following fields

stderr io.Writer // captured from WithStderr

// env is the environment variables for current pipe.
// Non-empty value will be set to the exec.Command instance.
env []string
}

func (p *Pipe) environments() []string {
p.mu.Lock()
defer p.mu.Unlock()
return p.env
}

func (p *Pipe) stdErr() io.Writer {
p.mu.Lock()
defer p.mu.Unlock()
return p.stderr
}

func (p *Pipe) At(dir string) *Pipe {
p.wd = dir
return p
Expand All @@ -46,23 +62,35 @@ func (p *Pipe) WithCurrentEnv() *Pipe {

// WithEnv sets the environment variables for the current pipe.
func (p *Pipe) WithEnv(env []string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = env
return p
}

// AppendEnv appends the environment variables for the current pipe.
func (p *Pipe) AppendEnv(env ...string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = append(p.env, env...)
return p
}

// WithEnvKV sets the environment variable key-value pair for the current pipe.
func (p *Pipe) WithEnvKV(key, value string) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.env = append(p.env, key+"="+value)
return p
}

func (p *Pipe) WithStderr(w io.Writer) *Pipe {
p.mu.Lock()
defer p.mu.Unlock()

p.stderr = w
p.Pipe = p.Pipe.WithStderr(w)
return p
Expand Down Expand Up @@ -132,8 +160,8 @@ func (p *Pipe) execContext(
if p.wd != "" {
cmd.Dir = p.wd
}
if len(p.env) > 0 {
cmd.Env = p.env
if envs := p.environments(); len(envs) > 0 {
cmd.Env = envs
}

return cmd
Expand All @@ -150,8 +178,8 @@ func (p *Pipe) ExecContext(ctx context.Context, cmdLine string) *Pipe {
cmd.Stdin = r
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
if stderr := p.stdErr(); stderr != nil {
cmd.Stderr = stderr
}

if err := cmd.Start(); err != nil {
Expand Down Expand Up @@ -189,8 +217,8 @@ func (p *Pipe) ExecForEachContext(ctx context.Context, cmdLine string) *Pipe {
cmd := p.execContext(ctx, args[0], args[1:])
cmd.Stdout = w
cmd.Stderr = w
if p.stderr != nil {
cmd.Stderr = p.stderr
if stderr := p.stdErr(); stderr != nil {
cmd.Stderr = stderr
}
err = cmd.Start()
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions script/contextual_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package script

import (
"context"
"testing"
)

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
ctx := context.Background()
err := ExecContext(ctx, "echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions script/go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module github.com/b4fun/script-contextual/script

go 1.22.3
go 1.22.7

require (
github.com/bitfield/script v0.22.1
github.com/bitfield/script v0.23.0
mvdan.cc/sh/v3 v3.7.0
)

Expand Down
4 changes: 2 additions & 2 deletions script/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4=
github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28=
github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand Down
4 changes: 2 additions & 2 deletions tests/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module tests

go 1.22.3
go 1.22.7

replace github.com/b4fun/script-contextual/script => ../script

Expand All @@ -11,7 +11,7 @@ require (
)

require (
github.com/bitfield/script v0.22.1 // indirect
github.com/bitfield/script v0.23.0 // indirect
github.com/itchyny/gojq v0.12.13 // indirect
github.com/itchyny/timefmt-go v0.1.5 // indirect
golang.org/x/sys v0.10.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions tests/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/bitfield/script v0.22.1 h1:DphxoC5ssYciwd0ZS+N0Xae46geAD/0mVWh6a2NUxM4=
github.com/bitfield/script v0.22.1/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/bitfield/script v0.23.0 h1:N0R5yLEl6wJIS9PR/A6xXwjMsplMubyxdi05N5l0X28=
github.com/bitfield/script v0.23.0/go.mod h1:fv+6x4OzVsRs6qAlc7wiGq8fq1b5orhtQdtW0dwjUHI=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
Expand Down
175 changes: 175 additions & 0 deletions tests/script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,40 @@ func TestWithStdout_SetsSpecifiedWriterAsStdout(t *testing.T) {
}
}

func TestWithEnv_UnsetsAllEnvVarsGivenEmptySlice(t *testing.T) {
t.Parallel()
p := script.NewPipe().WithEnv([]string{"ENV1=test1"}).Exec("sh -c 'echo ENV1=$ENV1'")
want := "ENV1=test1\n"
got, err := p.String()
if err != nil {
t.Fatal(err)
}
if got != want {
t.Fatalf("want %q, got %q", want, got)
}
got, err = p.Echo("").WithEnv([]string{}).Exec("sh -c 'echo ENV1=$ENV1'").String()
if err != nil {
t.Fatal(err)
}
want = "ENV1=\n"
if got != want {
t.Errorf("want %q, got %q", want, got)
}
}

func TestWithEnv_SetsGivenVariablesForSubsequentExec(t *testing.T) {
t.Parallel()
env := []string{"ENV1=test1", "ENV2=test2"}
got, err := script.NewPipe().WithEnv(env).Exec("sh -c 'echo ENV1=$ENV1 ENV2=$ENV2'").String()
if err != nil {
t.Fatal(err)
}
want := "ENV1=test1 ENV2=test2\n"
if got != want {
t.Errorf("want %q, got %q", want, got)
}
}

func TestErrorReturnsErrorSetByPreviousPipeStage(t *testing.T) {
t.Parallel()
p := script.File("testdata/nonexistent.txt")
Expand Down Expand Up @@ -1850,6 +1884,135 @@ func TestReadReturnsErrorGivenReadErrorOnPipe(t *testing.T) {
}
}

func TestWait_ReturnsErrorPresentOnPipe(t *testing.T) {
t.Parallel()
p := script.Echo("a\nb\nc\n").ExecForEach("{{invalid template syntax}}")
if p.Wait() == nil {
t.Error("want error, got nil")
}
}

func TestWait_DoesNotReturnErrorForValidExecution(t *testing.T) {
t.Parallel()
p := script.Echo("a\nb\nc\n").ExecForEach("echo \"{{.}}\"")
if err := p.Wait(); err != nil {
t.Fatal(err)
}
}

var base64Cases = []struct {
name string
decoded string
encoded string
}{
{
name: "empty string",
decoded: "",
encoded: "",
},
{
name: "single line string",
decoded: "hello world",
encoded: "aGVsbG8gd29ybGQ=",
},
{
name: "multi line string",
decoded: "hello\nthere\nworld\n",
encoded: "aGVsbG8KdGhlcmUKd29ybGQK",
},
}

func TestEncodeBase64_CorrectlyEncodes(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
got, err := script.Echo(tc.decoded).EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != tc.encoded {
t.Logf("input %q incorrectly encoded:", tc.decoded)
t.Error(cmp.Diff(tc.encoded, got))
}
})
}
}

func TestDecodeBase64_CorrectlyDecodes(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
got, err := script.Echo(tc.encoded).DecodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != tc.decoded {
t.Logf("input %q incorrectly decoded:", tc.encoded)
t.Error(cmp.Diff(tc.decoded, got))
}
})
}
}

func TestEncodeBase64_FollowedByDecodeRecoversOriginal(t *testing.T) {
t.Parallel()
for _, tc := range base64Cases {
t.Run(tc.name, func(t *testing.T) {
decoded, err := script.Echo(tc.decoded).EncodeBase64().DecodeBase64().String()
if err != nil {
t.Fatal(err)
}
if decoded != tc.decoded {
t.Error("encode-decode round trip failed:", cmp.Diff(tc.decoded, decoded))
}
encoded, err := script.Echo(tc.encoded).DecodeBase64().EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if encoded != tc.encoded {
t.Error("decode-encode round trip failed:", cmp.Diff(tc.encoded, encoded))
}
})
}
}

func TestDecodeBase64_CorrectlyDecodesInputToBytes(t *testing.T) {
t.Parallel()
input := "CAAAEA=="
got, err := script.Echo(input).DecodeBase64().Bytes()
if err != nil {
t.Fatal(err)
}
want := []byte{8, 0, 0, 16}
if !bytes.Equal(want, got) {
t.Logf("input %#v incorrectly decoded:", input)
t.Error(cmp.Diff(want, got))
}
}

func TestEncodeBase64_CorrectlyEncodesInputBytes(t *testing.T) {
t.Parallel()
input := []byte{8, 0, 0, 16}
reader := bytes.NewReader(input)
want := "CAAAEA=="
got, err := script.NewPipe().WithReader(reader).EncodeBase64().String()
if err != nil {
t.Fatal(err)
}
if got != want {
t.Logf("input %#v incorrectly encoded:", input)
t.Error(cmp.Diff(want, got))
}
}

func TestWithStdErr_IsConcurrencySafeAfterExec(t *testing.T) {
t.Parallel()
err := script.Exec("echo").WithStderr(nil).Wait()
if err != nil {
t.Fatal(err)
}
}

func ExampleArgs() {
script.Args().Stdout()
// prints command-line arguments
Expand Down Expand Up @@ -1969,6 +2132,12 @@ func ExamplePipe_CountLines() {
// 3
}

func ExamplePipe_DecodeBase64() {
script.Echo("SGVsbG8sIHdvcmxkIQ==").DecodeBase64().Stdout()
// Output:
// Hello, world!
}

func ExamplePipe_Do() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -2004,6 +2173,12 @@ func ExamplePipe_Echo() {
// Hello, world!
}

func ExamplePipe_EncodeBase64() {
script.Echo("Hello, world!").EncodeBase64().Stdout()
// Output:
// SGVsbG8sIHdvcmxkIQ==
}

func ExamplePipe_ExitStatus() {
p := script.Exec("echo")
fmt.Println(p.ExitStatus())
Expand Down

0 comments on commit 721eff6

Please sign in to comment.