diff --git a/README.md b/README.md index fc61117..9896f66 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# go-pipe [![GoDoc](https://pkg.go.dev/badge/github.com/github/docs)](https://pkg.go.dev/github.com/github/go-pipe) +# go-pipe [![GoDoc](https://pkg.go.dev/badge/github.com/github/docs)](https://pkg.go.dev/github.com/github/go-pipe/v2) A package used to easily build command pipelines in your Go applications # Important @@ -6,4 +6,4 @@ We have not thoroughly tested this package on OSs other than Linux, especially W # Links -* [Docs](https://pkg.go.dev/github.com/github/go-pipe) +* [Docs](https://pkg.go.dev/github.com/github/go-pipe/v2) diff --git a/go.mod b/go.mod index 6f69110..041809e 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/github/go-pipe +module github.com/github/go-pipe/v2 go 1.24.0 diff --git a/internal/ptree/ptree_test.go b/internal/ptree/ptree_test.go index 9c6d4e4..5c014c2 100644 --- a/internal/ptree/ptree_test.go +++ b/internal/ptree/ptree_test.go @@ -9,7 +9,7 @@ import ( "strconv" "testing" - "github.com/github/go-pipe/internal/ptree" + "github.com/github/go-pipe/v2/internal/ptree" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pipe/command.go b/pipe/command.go index 2113902..0daa91e 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -20,9 +20,13 @@ var errProcessInfoMissing = errors.New("cmd.Process is nil") // commandStage is a pipeline `Stage` based on running an external // command and piping the data through its stdin and stdout. type commandStage struct { - name string - stdin io.Closer - cmd *exec.Cmd + name string + cmd *exec.Cmd + + // lateClosers is a list of things that have to be closed once the + // command has finished. 
+ lateClosers []io.Closer + done chan struct{} wg errgroup.Group stderr bytes.Buffer @@ -32,6 +36,10 @@ type commandStage struct { ctxErr atomic.Value } +var ( + _ Stage = (*commandStage)(nil) +) + // Command returns a pipeline `Stage` based on the specified external // `command`, run with the given command-line `args`. Its stdin and // stdout are handled as usual, and its stderr is collected and @@ -61,33 +69,100 @@ func (s *commandStage) Name() string { return s.name } +func (s *commandStage) Preferences() StagePreferences { + return StagePreferences{ + StdinPreference: IOPreferenceFile, + StdoutPreference: IOPreferenceFile, + } +} + func (s *commandStage) Start( - ctx context.Context, env Env, stdin io.ReadCloser, -) (io.ReadCloser, error) { + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, _ StartOptions, +) error { if s.cmd.Dir == "" { s.cmd.Dir = env.Dir } s.setupEnv(ctx, env) + // Things that have to be closed as soon as the command has + // started: + var earlyClosers []io.Closer + + // See the type comment for `Stage` and the long comment in + // `Pipeline.WithStdin()` for the explanation of this unwrapping + // and closing behavior. + if stdin != nil { - // See the long comment in `Pipeline.Start()` for the - // explanation of this special case. switch stdin := stdin.(type) { - case nopCloser: - s.cmd.Stdin = stdin.Reader - case nopCloserWriterTo: - s.cmd.Stdin = stdin.Reader + case readerNopCloser: + // In this case, we shouldn't close it. 
But unwrap it for + // efficiency's sake: + s.cmd.Stdin = UnwrapReader(stdin) + case *os.File: + // In this case, we can close stdin as soon as the command + // has started: + s.cmd.Stdin = stdin + earlyClosers = append(earlyClosers, stdin) default: + // In this case, we need to close `stdin`, but we should + // only do so after the command has finished: s.cmd.Stdin = stdin + s.lateClosers = append(s.lateClosers, stdin) } - // Also keep a copy so that we can close it when the command exits: - s.stdin = stdin } - stdout, err := s.cmd.StdoutPipe() - if err != nil { - return nil, err + if stdout != nil { + // See the long comment in `Pipeline.Start()` for the + // explanation of this special case. + switch stdout := stdout.(type) { + case writerNopCloser: + // We shouldn't close the wrapped writer. Unwrap it; if + // it's an `*os.File`, exec.Cmd can pass the fd directly + // to the child. Otherwise route the copy through our own + // pipe so we can use a pooled buffer. + writer := UnwrapWriter(stdout) + if f, ok := writer.(*os.File); ok { + s.cmd.Stdout = f + } else { + ec, err := s.setupPooledStdout(writer) + if err != nil { + return err + } + earlyClosers = append(earlyClosers, ec) + } + case *os.File: + // In this case, we can close stdout as soon as the command + // has started: + s.cmd.Stdout = stdout + earlyClosers = append(earlyClosers, stdout) + default: + // In this case, we need to close `stdout`, but we should + // only do so after the command has finished. We also + // route the copy through our own pipe so we can use a + // pooled buffer rather than letting exec.Cmd allocate a + // fresh 32KB buffer for its internal io.Copy. 
+ ec, err := s.setupPooledStdout(stdout) + if err != nil { + return err + } + earlyClosers = append(earlyClosers, ec) + s.lateClosers = append(s.lateClosers, stdout) + } + } + + closeEarlyClosers := func() { + for _, closer := range earlyClosers { + _ = closer.Close() + } + } + + // On error, Close any pipes we created and wait for the goroutines to + // exit before propagating the error. + cleanupOnStartFailure := func() { + closeEarlyClosers() + _ = s.wg.Wait() + _ = s.closeLateClosers() } // If the caller hasn't arranged otherwise, read the command's @@ -99,7 +174,8 @@ func (s *commandStage) Start( // can be sure. p, err := s.cmd.StderrPipe() if err != nil { - return nil, err + cleanupOnStartFailure() + return err } s.wg.Go(func() error { _, err := io.Copy(&s.stderr, p) @@ -116,9 +192,12 @@ func (s *commandStage) Start( s.runInOwnProcessGroup() if err := s.cmd.Start(); err != nil { - return nil, err + cleanupOnStartFailure() + return err } + closeEarlyClosers() + // Arrange for the process to be killed (gently) if the context // expires before the command exits normally: go func() { @@ -130,7 +209,7 @@ func (s *commandStage) Start( } }() - return stdout, nil + return nil } // setupEnv sets or modifies the environment that will be passed to @@ -219,21 +298,55 @@ func (s *commandStage) Wait() error { // Make sure that any stderr is copied before `s.cmd.Wait()` // closes the read end of the pipe: - wErr := s.wg.Wait() + wgErr := s.wg.Wait() err := s.cmd.Wait() err = s.filterCmdError(err) - if err == nil && wErr != nil { - err = wErr + if err == nil && wgErr != nil { + err = wgErr } - if s.stdin != nil { - cErr := s.stdin.Close() - if cErr != nil && err == nil { - return cErr - } + if closeErr := s.closeLateClosers(); err == nil { + err = closeErr } return err } + +func (s *commandStage) closeLateClosers() error { + var err error + for _, closer := range s.lateClosers { + if closeErr := closer.Close(); closeErr != nil && err == nil { + err = closeErr + } + } + 
s.lateClosers = nil + return err +} + +// setupPooledStdout creates an `os.Pipe()`, sets it as `cmd.Stdout`, +// and starts a goroutine that copies from the read end to `dst` using +// a pooled buffer (or `dst.ReadFrom` when `dst` implements it). The +// returned closer is the write end of the pipe; the caller must add +// it to `earlyClosers` so it is closed once the command has started. +// +// The buffer-pool optimization works for command stages whose stdout is +// not an `*os.File`. Without it, `exec.Cmd` would set up its own pipe +// and run `io.Copy` with a freshly allocated 32KB buffer per invocation. +func (s *commandStage) setupPooledStdout(dst io.Writer) (io.Closer, error) { + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + s.cmd.Stdout = pw + s.wg.Go(func() error { + defer pr.Close() + _, err := pooledCopy(dst, pr) + if err != nil && !errors.Is(err, os.ErrClosed) { + return err + } + return nil + }) + return pw, nil +} diff --git a/pipe/command_linux.go b/pipe/command_linux.go index 997d2cd..987bffb 100644 --- a/pipe/command_linux.go +++ b/pipe/command_linux.go @@ -5,7 +5,7 @@ package pipe import ( "context" - "github.com/github/go-pipe/internal/ptree" + "github.com/github/go-pipe/v2/internal/ptree" ) // On linux, we can limit or observe memory usage in command stages. 
diff --git a/pipe/command_nil_panic_test.go b/pipe/command_nil_panic_test.go index 740af73..54c1508 100644 --- a/pipe/command_nil_panic_test.go +++ b/pipe/command_nil_panic_test.go @@ -33,7 +33,7 @@ func TestKillWithFailedStart(t *testing.T) { stage := Command("/this/path/does/not/exist/invalid_command_12345") - _, err := stage.Start(ctx, Env{}, nil) + err := stage.Start(ctx, Env{}, nil, nil, StartOptions{}) if err == nil { t.Fatal("Expected start to fail, but it succeeded") } diff --git a/pipe/command_starterror_test.go b/pipe/command_starterror_test.go new file mode 100644 index 0000000..6099d1d --- /dev/null +++ b/pipe/command_starterror_test.go @@ -0,0 +1,63 @@ +package pipe_test + +import ( + "bytes" + "context" + "os/exec" + "sync/atomic" + "testing" + + "github.com/github/go-pipe/v2/pipe" +) + +// TestCommandStageStartFailureNoRace verifies that when `cmd.Start()` +// fails (e.g. command not found), the goroutine that +// `setupPooledStdout` spawned does not leak past `Pipeline.Run()`. +// `bytes.Buffer.ReadFrom` writes to the buffer's slice header via +// `grow()` before its first `Read()`, so a leaked goroutine races +// with the caller's access to the destination buffer once Run +// returns the error. Run a tight loop so `-race` is likely to catch +// any regression. +func TestCommandStageStartFailureNoRace(t *testing.T) { + for i := 0; i < 50; i++ { + var buf bytes.Buffer + p := pipe.New(pipe.WithStdout(&buf)) + p.Add(pipe.CommandStage("nope", exec.Command("this-binary-does-not-exist-xyz123"))) + if err := p.Run(context.Background()); err == nil { + t.Fatalf("expected error from non-existent command, got nil") + } + _ = buf.String() + } +} + +// trackingWriteCloser is a non-`*os.File` `io.WriteCloser` that records +// whether it has been closed. Because it isn't an `*os.File`, a command +// stage routes it through `setupPooledStdout` and closes it as a "late +// closer" (i.e. only after the command finishes / cleanup runs). 
+type trackingWriteCloser struct { + closed atomic.Bool +} + +func (w *trackingWriteCloser) Write(p []byte) (int, error) { return len(p), nil } + +func (w *trackingWriteCloser) Close() error { + w.closed.Store(true) + return nil +} + +// TestCommandStageStartFailureClosesLateClosers verifies that a +// `WithStdoutCloser` on the last stage is closed even when `cmd.Start()` +// fails. The closer is registered as a "late closer," which is normally +// drained by `Wait()`; since `Wait()` never runs when `Start()` fails, +// the start-failure cleanup path must close it instead. +func TestCommandStageStartFailureClosesLateClosers(t *testing.T) { + w := &trackingWriteCloser{} + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.CommandStage("nope", exec.Command("this-binary-does-not-exist-xyz123"))) + if err := p.Run(context.Background()); err == nil { + t.Fatalf("expected error from non-existent command, got nil") + } + if !w.closed.Load() { + t.Fatalf("expected late closer to be closed after Start() failure") + } +} diff --git a/pipe/command_stdout_fastpath_test.go b/pipe/command_stdout_fastpath_test.go new file mode 100644 index 0000000..42f581a --- /dev/null +++ b/pipe/command_stdout_fastpath_test.go @@ -0,0 +1,118 @@ +package pipe + +import ( + "context" + "io" + "os" + "os/exec" + "testing" +) + +// TestCommandStageStdoutFastPath asserts that when a commandStage's stdout is +// an `*os.File`, the file is set as `cmd.Stdout` so that `exec.Cmd` dup's the +// fd into the child process directly. This is one of the optimizations enabled +// by the Stage interface redesign in #21: the subprocess writes straight to +// the caller's destination fd with no Go-side copy stage in between, and the +// subprocess can detect when that fd is closed. 
+func TestCommandStageStdoutFastPath(t *testing.T) { + cases := []struct { + name string + wrap func(*os.File) io.WriteCloser + }{ + { + name: "raw *os.File via WithStdoutCloser", + wrap: func(f *os.File) io.WriteCloser { return f }, + }, + { + name: "writerNopCloser{*os.File} via WithStdout", + wrap: func(f *os.File) io.WriteCloser { return writerNopCloser{f} }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f, err := os.CreateTemp(t.TempDir(), "stdout") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = f.Close() }) + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + if err := s.Start(ctx, Env{}, nil, tc.wrap(f), StartOptions{}); err != nil { + t.Fatalf("Start: %v", err) + } + t.Cleanup(func() { _ = s.Wait() }) + + gotFile, ok := s.cmd.Stdout.(*os.File) + if !ok { + t.Fatalf("expected cmd.Stdout to be *os.File, got %T", s.cmd.Stdout) + } + if gotFile != f { + t.Errorf("expected cmd.Stdout to be the user-provided *os.File "+ + "(fd %d), got a different *os.File (fd %d). The fd-pass "+ + "fast path is broken; sendfile/zero-copy will not apply.", + f.Fd(), gotFile.Fd()) + } + }) + } +} + +// TestCommandStageStdoutFastPathThroughPipeline is the same assertion +// but driven end-to-end through `Pipeline.Start()`, so it also +// exercises the `Pipeline.stdout` plumbing that hands the writer to +// the last stage. 
+func TestCommandStageStdoutFastPathThroughPipeline(t *testing.T) { + cases := []struct { + name string + option func(*os.File) Option + }{ + { + name: "WithStdoutCloser(*os.File)", + option: func(f *os.File) Option { return WithStdoutCloser(f) }, + }, + { + name: "WithStdout(*os.File)", + option: func(f *os.File) Option { return WithStdout(f) }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f, err := os.CreateTemp(t.TempDir(), "stdout") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = f.Close() }) + + cmd := exec.Command("true") + s := CommandStage("true", cmd).(*commandStage) + + p := New(tc.option(f)) + p.Add(s) + if err := p.Start(ctx); err != nil { + t.Fatalf("Start: %v", err) + } + stdoutAfterStart := s.cmd.Stdout + t.Cleanup(func() { _ = p.Wait() }) + + gotFile, ok := stdoutAfterStart.(*os.File) + if !ok { + t.Fatalf("expected cmd.Stdout to be *os.File, got %T", stdoutAfterStart) + } + if gotFile != f { + t.Errorf("expected cmd.Stdout to be the user-provided *os.File "+ + "(fd %d), got a different *os.File (fd %d). 
The fd-pass "+ + "fast path is broken; sendfile/zero-copy will not apply.", + f.Fd(), gotFile.Fd()) + } + }) + } +} diff --git a/pipe/command_test.go b/pipe/command_test.go index ca5a8c0..531f11f 100644 --- a/pipe/command_test.go +++ b/pipe/command_test.go @@ -78,7 +78,8 @@ func TestCopyEnvWithOverride(t *testing.T) { for _, ex := range examples { t.Run(ex.label, func(t *testing.T) { assert.ElementsMatch(t, ex.expectedResult, - copyEnvWithOverrides(ex.env, ex.overrides)) + copyEnvWithOverrides(ex.env, ex.overrides), + ) }) } } diff --git a/pipe/copy_pool.go b/pipe/copy_pool.go new file mode 100644 index 0000000..69f8a21 --- /dev/null +++ b/pipe/copy_pool.go @@ -0,0 +1,40 @@ +package pipe + +import ( + "io" + "sync" +) + +// copyBufPool reuses 32KB buffers across `io.CopyBuffer` calls, +// avoiding a fresh heap allocation per copy. This matters in +// high-throughput pipelines where many command stages run +// concurrently and stdout is not an `*os.File` that can be passed +// directly through `exec.Cmd`. +var copyBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 32*1024) + return &b + }, +} + +// readerOnly wraps an `io.Reader`, hiding any other interfaces (such +// as `io.WriterTo`) so that `io.CopyBuffer` is forced to use the +// provided buffer. Without this, `*os.File`'s `WriterTo` (added in +// Go 1.22) causes `CopyBuffer` to call `File.WriteTo`, which can +// fall back to `io.Copy` with a fresh allocation, bypassing the pool +// entirely. +type readerOnly struct{ io.Reader } + +// pooledCopy copies from `src` to `dst`. If `dst` implements +// `io.ReaderFrom` (e.g. `*net.TCPConn`, `*os.File`), it delegates to +// `ReadFrom` so platform fast paths like splice can be used. +// Otherwise it falls back to `io.CopyBuffer` with a pooled 32KB +// buffer. 
+func pooledCopy(dst io.Writer, src io.Reader) (int64, error) { + if rf, ok := dst.(io.ReaderFrom); ok { + return rf.ReadFrom(src) + } + bp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bp) + return io.CopyBuffer(dst, readerOnly{src}, *bp) +} diff --git a/pipe/export_test.go b/pipe/export_test.go new file mode 100644 index 0000000..92862cc --- /dev/null +++ b/pipe/export_test.go @@ -0,0 +1,4 @@ +package pipe + +// This file exports a function to be used only for testing. +var UnwrapNopCloser = unwrapNopCloser diff --git a/pipe/function.go b/pipe/function.go index e8d9522..00ca595 100644 --- a/pipe/function.go +++ b/pipe/function.go @@ -4,12 +4,13 @@ import ( "context" "fmt" "io" + "strings" ) // StageFunc is a function that can be used to power a `goStage`. It // should read its input from `stdin` and write its output to // `stdout`. `stdin` and `stdout` will be closed automatically (if -// necessary) once the function returns. +// non-nil) once the function returns. // // Neither `stdin` nor `stdout` are necessarily buffered. If the // `StageFunc` requires buffering, it needs to arrange that itself. @@ -32,29 +33,51 @@ func Function(name string, f StageFunc) Stage { // goStage is a `Stage` that does its work by running an arbitrary // `stageFunc` in a goroutine. 
type goStage struct { - name string - f StageFunc - done chan struct{} - err error - panicHandler StagePanicHandler + name string + f StageFunc + done chan struct{} + err error } +var _ Stage = (*goStage)(nil) + func (s *goStage) Name() string { return s.name } -func (s *goStage) SetPanicHandler(ph StagePanicHandler) { - s.panicHandler = ph +func (s *goStage) Preferences() StagePreferences { + return StagePreferences{ + StdinPreference: IOPreferenceUndefined, + StdoutPreference: IOPreferenceUndefined, + } } -func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - r, w := io.Pipe() +func (s *goStage) Start( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, opts StartOptions, +) error { + r := UnwrapReader(stdin) + if r == nil { + // treat nil as empty input. + r = strings.NewReader("") + } + + w := UnwrapWriter(stdout) + if w == nil { + // treat nil output as /dev/null + w = io.Discard + } go func() { defer func() { - // Cleanup resources on exit - if err := w.Close(); err != nil && s.err == nil { - s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err) + if opts.PanicHandler != nil { + if p := recover(); p != nil { + s.err = opts.PanicHandler(p) + } + } + if stdout != nil { + if err := stdout.Close(); err != nil && s.err == nil { + s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err) + } } if stdin != nil { if err := stdin.Close(); err != nil && s.err == nil { @@ -63,26 +86,13 @@ func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.R } close(s.done) }() - - defer s.recoverPanic() - - s.err = s.f(ctx, env, stdin, w) + s.err = s.f(ctx, env, r, w) }() - return r, nil + return nil } func (s *goStage) Wait() error { <-s.done return s.err } - -func (s *goStage) recoverPanic() { - if s.panicHandler == nil { - return - } - - if p := recover(); p != nil { - s.err = s.panicHandler(p) - } -} diff --git a/pipe/function_panic_test.go 
b/pipe/function_panic_test.go new file mode 100644 index 0000000..143abf3 --- /dev/null +++ b/pipe/function_panic_test.go @@ -0,0 +1,68 @@ +package pipe_test + +import ( + "context" + "io" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/github/go-pipe/v2/pipe" +) + +const panicChildEnv = "GO_PIPE_FUNCTION_PANIC_CHILD" +const panicSentinel = "function-panic-sentinel" + +// TestFunctionPanicWithoutHandlerPropagates verifies that when a +// `Function` stage panics and no panic handler is installed, the panic +// propagates (crashing the process) rather than being silently +// swallowed and reported as a successful run. Because a propagating +// panic would crash the test binary itself, the actual pipeline is run +// in a re-exec'd subprocess and this test asserts on its outcome. +func TestFunctionPanicWithoutHandlerPropagates(t *testing.T) { + if os.Getenv(panicChildEnv) == "1" { + runPanicChild() + return + } + + cmd := exec.Command(os.Args[0], "-test.run=^TestFunctionPanicWithoutHandlerPropagates$", "-test.v") //nolint:gosec // re-exec of this test binary with constant arguments. + cmd.Env = append(os.Environ(), panicChildEnv+"=1") + out, err := cmd.CombinedOutput() + output := string(out) + + if err == nil { + t.Fatalf("expected subprocess to crash from a propagated panic, but it exited 0\noutput:\n%s", output) + } + if strings.Contains(output, "SURVIVED") { + t.Fatalf("panic was swallowed: Run returned instead of propagating\noutput:\n%s", output) + } + if !strings.Contains(output, "panic:") || !strings.Contains(output, panicSentinel) { + t.Fatalf("expected a propagated panic mentioning %q, got:\n%s", panicSentinel, output) + } +} + +// runPanicChild runs a pipeline whose only stage is a `Function` that panics, +// with no panic handler configured. The panic, being unhandled, should crash +// the process before the sleep elapses; if it is swallowed (the regression), +// Run returns and we print SURVIVED so the parent can detect the failure. 
+func runPanicChild() { + p := pipe.New(pipe.WithStdout(io.Discard)) + p.Add(pipe.Function("boom", func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + panic(panicSentinel) + })) + + err := p.Run(context.Background()) + + // reaching this point at all indicates the panic was swallowed. + time.Sleep(2 * time.Second) + os.Stdout.WriteString("SURVIVED: Run returned err=") + if err != nil { + os.Stdout.WriteString(err.Error()) + } else { + os.Stdout.WriteString("") + } + os.Stdout.WriteString("\n") + os.Exit(0) +} diff --git a/pipe/iocopier.go b/pipe/iocopier.go deleted file mode 100644 index 78a9143..0000000 --- a/pipe/iocopier.go +++ /dev/null @@ -1,62 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "io" - "os" -) - -// ioCopier is a stage that copies its stdin to a specified -// `io.Writer`. It generates no stdout itself. -type ioCopier struct { - w io.WriteCloser - done chan struct{} - err error -} - -func newIOCopier(w io.WriteCloser) *ioCopier { - return &ioCopier{ - w: w, - done: make(chan struct{}), - } -} - -func (s *ioCopier) Name() string { - return "ioCopier" -} - -// This method always returns `nil, nil`. -func (s *ioCopier) Start(_ context.Context, _ Env, r io.ReadCloser) (io.ReadCloser, error) { - go func() { - _, err := io.Copy(s.w, r) - // We don't consider `ErrClosed` an error (FIXME: is this - // correct?): - if err != nil && !errors.Is(err, os.ErrClosed) { - s.err = err - } - if err := r.Close(); err != nil && s.err == nil { - s.err = err - } - if err := s.w.Close(); err != nil && s.err == nil { - s.err = err - } - close(s.done) - }() - - // FIXME: if `s.w.Write()` is blocking (e.g., because there is a - // downstream process that is not reading from the other side), - // there's no way to terminate the copy when the context expires. - // This is not too bad, because the `io.Copy()` call will exit by - // itself when its input is closed. 
- // - // We could, however, be smarter about exiting more quickly if the - // context expires but `s.w.Write()` is not blocking. - - return nil, nil -} - -func (s *ioCopier) Wait() error { - <-s.done - return s.err -} diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go index 8e91dc1..df362c6 100644 --- a/pipe/memorylimit.go +++ b/pipe/memorylimit.go @@ -11,12 +11,12 @@ import ( const memoryPollInterval = time.Second -// ErrMemoryLimitExceeded is the error that will be used to kill a process, if -// necessary, from MemoryLimit. +// ErrMemoryLimitExceeded is the error that will be used to kill a +// process, if necessary, from MemoryLimit. var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") -// LimitableStage is the superset of Stage that must be implemented by stages -// passed to MemoryLimit and MemoryObserver. +// LimitableStage is the superset of `Stage` that must be implemented +// by stages passed to MemoryLimit and MemoryObserver. type LimitableStage interface { Stage @@ -26,6 +26,11 @@ type LimitableStage interface { // MemoryLimit watches the memory usage of the stage and stops it if it // exceeds the given limit. +// +// If the event handler panics while reporting the over-limit event, the +// stage is still killed. A panic in any other event-handler call (an +// RSS-read error) is recovered via StartOptions.PanicHandler and the +// stage keeps running unmonitored; see StartOptions.PanicHandler. 
func MemoryLimit(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { limitableStage, ok := stage.(LimitableStage) @@ -73,29 +78,35 @@ func killAtLimit(byteLimit uint64, eventHandler func(e *Event)) memoryWatchFunc if rss < byteLimit { continue } - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - stage.Kill(ErrMemoryLimitExceeded) + func() { + // Guarantee the over-limit stage is killed even if + // the user's event handler panics. + defer stage.Kill(ErrMemoryLimitExceeded) + eventHandler(&Event{ + Command: stage.Name(), + Msg: "stage exceeded allowed memory use", + Err: fmt.Errorf("stage exceeded allowed memory use"), + Context: map[string]interface{}{ + "limit": byteLimit, + "used": rss, + }, + }) + }() return } } } } -// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver into a single -// stage that uses one goroutine instead of two. It watches the memory usage of -// the stage, kills the process if it exceeds byteLimit, and logs peak memory -// usage when the stage exits. +// MemoryLimitWithObserver combines MemoryLimit and MemoryObserver in +// one goroutine. It watches the memory usage of the stage, stops it +// if it exceeds the given limit, and logs the peak memory usage when +// the stage exits. // -// Use this instead of MemoryLimit(MemoryObserver(stage, h), limit, h) to save -// one goroutine per pipeline stage. +// Its event-handler panic behavior matches MemoryLimit: the over-limit +// kill always happens, while a panic in the RSS-error or peak-usage +// handler is recovered via StartOptions.PanicHandler and the stage keeps +// running unmonitored. See StartOptions.PanicHandler. 
func MemoryLimitWithObserver(stage Stage, byteLimit uint64, eventHandler func(e *Event)) Stage { limitableStage, ok := stage.(LimitableStage) if !ok { @@ -168,16 +179,20 @@ func killAtLimitAndObserve(byteLimit uint64, eventHandler func(e *Event)) memory } if rss >= byteLimit { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]interface{}{ - "limit": byteLimit, - "used": rss, - }, - }) - stage.Kill(ErrMemoryLimitExceeded) + func() { + // Guarantee the over-limit stage is killed even if + // the user's event handler panics. + defer stage.Kill(ErrMemoryLimitExceeded) + eventHandler(&Event{ + Command: stage.Name(), + Msg: "stage exceeded allowed memory use", + Err: fmt.Errorf("stage exceeded allowed memory use"), + Context: map[string]interface{}{ + "limit": byteLimit, + "used": rss, + }, + }) + }() killed = true } } @@ -261,6 +276,7 @@ type memoryWatchStage struct { watch memoryWatchFunc cancel context.CancelFunc wg sync.WaitGroup + watchErr error } type memoryWatchFunc func(context.Context, LimitableStage) @@ -271,27 +287,50 @@ func (m *memoryWatchStage) Name() string { return m.stage.Name() + m.nameSuffix } -func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) { - io, err := m.stage.Start(ctx, env, stdin) - if err != nil { - return nil, err +func (m *memoryWatchStage) Preferences() StagePreferences { + return m.stage.Preferences() +} + +func (m *memoryWatchStage) Start( + ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, opts StartOptions, +) error { + if err := m.stage.Start(ctx, env, stdin, stdout, opts); err != nil { + return err } + m.monitor(ctx, opts.PanicHandler) + + return nil +} + +// monitor starts up a goroutine that monitors the memory of `m`. If +// panicHandler is set, any panic that escapes the user-supplied event handler +// (via m.watch) is recovered. 
+func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicHandler) { ctx, cancel := context.WithCancel(ctx) m.cancel = cancel m.wg.Add(1) go func() { + defer m.wg.Done() + defer func() { + if p := recover(); p != nil { + if panicHandler == nil { + panic(p) + } + m.watchErr = panicHandler(p) + } + }() m.watch(ctx, m.stage) - m.wg.Done() }() - - return io, nil } func (m *memoryWatchStage) Wait() error { err := m.stage.Wait() m.stopWatching() + if err == nil { + err = m.watchErr // non-nil if panicHandler() returned anything + } return err } diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go new file mode 100644 index 0000000..da9977e --- /dev/null +++ b/pipe/memorylimit_panic_test.go @@ -0,0 +1,172 @@ +package pipe + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "strings" + "testing" + "time" +) + +const memWatchPanicSentinel = "memwatch-panic-sentinel" +const memWatchPanicChildEnv = "GO_PIPE_MEMWATCH_PANIC_CHILD" + +// fakeLimitableStage is a minimal LimitableStage whose Wait returns +// immediately, letting a memoryWatchStage test exercise its watch +// goroutine in isolation. 
+type fakeLimitableStage struct{} + +func (fakeLimitableStage) Name() string { return "fake" } +func (fakeLimitableStage) Preferences() StagePreferences { return StagePreferences{} } +func (fakeLimitableStage) Start( + context.Context, Env, io.ReadCloser, io.WriteCloser, StartOptions, +) error { + return nil +} +func (fakeLimitableStage) Wait() error { return nil } +func (fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { return 0, nil } +func (fakeLimitableStage) Kill(error) {} + +func panickingWatchStage() *memoryWatchStage { + return &memoryWatchStage{ + stage: fakeLimitableStage{}, + watch: func(context.Context, LimitableStage) { panic(memWatchPanicSentinel) }, + } +} + +// TestMemoryWatchStagePanicWithHandlerSurfaced verifies that a panic +// escaping the memory-watch goroutine (where the user-supplied event +// handler runs) is recovered via the configured panic handler and +// surfaced as the stage's Wait error. +func TestMemoryWatchStagePanicWithHandlerSurfaced(t *testing.T) { + ms := panickingWatchStage() + opts := StartOptions{ + PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, + } + + if err := ms.Start(context.Background(), Env{}, nil, nil, opts); err != nil { + t.Fatalf("Start returned unexpected error: %v", err) + } + + err := ms.Wait() + if err == nil { + t.Fatal("expected Wait to surface the recovered panic, got nil") + } + if !strings.Contains(err.Error(), memWatchPanicSentinel) { + t.Fatalf("expected error to mention %q, got: %v", memWatchPanicSentinel, err) + } +} + +// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that when +// the memory-watch goroutine panics and no panic handler is installed, +// the panic propagates (crashing the process) rather than being +// silently swallowed. Because that would crash the test binary, the +// scenario runs in a re-exec'd subprocess. 
+func TestMemoryWatchStagePanicWithoutHandlerPropagates(t *testing.T) { + if os.Getenv(memWatchPanicChildEnv) == "1" { + runMemWatchPanicChild() + return + } + + cmd := exec.Command(os.Args[0], "-test.run=^TestMemoryWatchStagePanicWithoutHandlerPropagates$", "-test.v") //nolint:gosec // re-exec of this test binary with constant arguments. + cmd.Env = append(os.Environ(), memWatchPanicChildEnv+"=1") + out, err := cmd.CombinedOutput() + output := string(out) + + if err == nil { + t.Fatalf("expected subprocess to crash from a propagated panic, but it exited 0\noutput:\n%s", output) + } + if strings.Contains(output, "SURVIVED") { + t.Fatalf("panic was swallowed: Wait returned instead of propagating\noutput:\n%s", output) + } + if !strings.Contains(output, "panic:") || !strings.Contains(output, memWatchPanicSentinel) { + t.Fatalf("expected a propagated panic mentioning %q, got:\n%s", memWatchPanicSentinel, output) + } +} + +func runMemWatchPanicChild() { + ms := panickingWatchStage() + + if err := ms.Start(context.Background(), Env{}, nil, nil, StartOptions{}); err != nil { + os.Stdout.WriteString("SURVIVED: Start returned err=" + err.Error() + "\n") + os.Exit(0) + } + + _ = ms.Wait() + + // Reaching this point at all indicates the panic was swallowed. + time.Sleep(2 * time.Second) + os.Stdout.WriteString("SURVIVED: Wait returned\n") + os.Exit(0) +} + +// killTrackingStage is a LimitableStage that reports an over-limit RSS +// and blocks in Wait until it is killed, recording that the kill +// happened. It lets a test assert that the memory limit is enforced. 
+type killTrackingStage struct { + killed chan struct{} + done chan struct{} +} + +func newKillTrackingStage() *killTrackingStage { + return &killTrackingStage{ + killed: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (*killTrackingStage) Name() string { return "kill-tracking" } +func (*killTrackingStage) Preferences() StagePreferences { return StagePreferences{} } +func (*killTrackingStage) Start( + context.Context, Env, io.ReadCloser, io.WriteCloser, StartOptions, +) error { + return nil +} +func (s *killTrackingStage) Wait() error { <-s.done; return ErrMemoryLimitExceeded } +func (*killTrackingStage) GetRSSAnon(context.Context) (uint64, error) { + return 1 << 30, nil +} + +func (s *killTrackingStage) Kill(error) { + select { + case <-s.killed: + // already killed + default: + close(s.killed) + close(s.done) + } +} + +// TestMemoryLimitKillsEvenIfEventHandlerPanics verifies that an over-limit +// stage is still killed (the limit enforced) even when the user's event +// handler panics and that panic is recovered by the configured handler. +// Without the kill being guaranteed during unwinding, the runaway stage +// would never be killed and Wait would hang. +func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { + stage := newKillTrackingStage() + ms := &memoryWatchStage{ + stage: stage, + watch: killAtLimit(1, func(*Event) { panic(memWatchPanicSentinel) }), + } + opts := StartOptions{ + PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, + } + + if err := ms.Start(context.Background(), Env{}, nil, nil, opts); err != nil { + t.Fatalf("Start returned unexpected error: %v", err) + } + + select { + case <-stage.killed: + // expected: the limit was enforced despite the handler panic. 
+ case <-time.After(5 * time.Second): + t.Fatal("over-limit stage was not killed after the event handler panicked") + } + + if err := ms.Wait(); err != ErrMemoryLimitExceeded { + t.Fatalf("Wait = %v, want %v", err, ErrMemoryLimitExceeded) + } +} diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go index 37d3edd..582421d 100644 --- a/pipe/memorylimit_test.go +++ b/pipe/memorylimit_test.go @@ -8,10 +8,11 @@ import ( "log" "os" "strings" + "syscall" "testing" "time" - "github.com/github/go-pipe/pipe" + "github.com/github/go-pipe/v2/pipe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -112,15 +113,6 @@ func TestMemoryLimitTreeMem(t *testing.T) { require.ErrorContains(t, err, "memory limit exceeded") } -type closeWrapper struct { - io.Writer - close func() error -} - -func (w closeWrapper) Close() error { - return w.close() -} - func TestMemoryLimitWithObserverSimple(t *testing.T) { t.Parallel() msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) @@ -137,6 +129,14 @@ func TestMemoryLimitWithObserverTreeMem(t *testing.T) { require.ErrorContains(t, err, "memory limit exceeded") } +func TestMemoryLimitWithObserverLogsPeakOnKill(t *testing.T) { + t.Parallel() + msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) + assert.Contains(t, msg, "exceeded allowed memory") + assert.Contains(t, msg, "peak memory usage") + require.ErrorContains(t, err, "memory limit exceeded") +} + func TestMemoryLimitWithObserverBelowLimit(t *testing.T) { t.Parallel() rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("less")) @@ -149,16 +149,13 @@ func TestMemoryLimitWithObserverBelowLimitTreeMem(t *testing.T) { require.Greater(t, rss, 400_000_000) } -func TestMemoryLimitWithObserverLogsPeakOnKill(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - // Verify both limit-exceeded AND peak memory are logged (matching - 
// the behavior of MemoryLimit(MemoryObserver(...))) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "peak memory usage") - require.ErrorContains(t, err, "memory limit exceeded") -} - +// testMemoryLimitWithObserverBelowLimit exercises the observer half of +// `MemoryLimitWithObserver` when the memory limit is never hit: with a +// 100GiB limit, less should never be killed, but the wrapper should +// still poll RSS and emit a "peak memory usage" event when the stage +// exits normally. Mirrors `testMemoryObserver` in structure — we hold +// stdin open across at least one poll interval so RSS samples are +// guaranteed to be taken before the stage is allowed to exit. func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Stage) int { ctx := context.Background() @@ -168,9 +165,8 @@ func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Sta require.NoError(t, err) buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) + logger := log.New(buf, "testMemoryLimitWithObserverBelowLimit", log.Ldate|log.Ltime) - // Use a high limit so it won't be hit — we want to verify the observer part p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) p.Add(pipe.MemoryLimitWithObserver(stage, 100*1024*1024*1024, LogEventHandler(logger))) require.NoError(t, p.Start(ctx)) @@ -182,95 +178,84 @@ func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Sta require.Equal(t, len(bytes), n) } + // Wrapper polls once per second; sleep long enough to guarantee at + // least one sample is taken before the stage is allowed to exit. 
time.Sleep(2 * time.Second) require.NoError(t, stdinWriter.Close()) require.NoError(t, p.Wait()) - // Verify that peak memory usage was logged (the observer part) output := buf.String() assert.Contains(t, output, "peak memory usage") return maxBytes(output) } -func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { +func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { ctx := context.Background() - stdinReader, stdinWriter := io.Pipe() - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) require.NoError(t, err) - closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) + logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger))) + p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) + p.Add( + pipe.Function( + "write-to-less", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + // Write some nonsense data to less. 
+ var bytes [1_000_000]byte + for i := 0; i < mbs; i++ { + _, err := stdout.Write(bytes[:]) + if err != nil { + assert.ErrorIs(t, err, syscall.EPIPE) + return nil + } + } + + return nil + }, + ), + pipe.MemoryLimit(stage, limit, LogEventHandler(logger)), + ) require.NoError(t, p.Start(ctx)) - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) err = p.Wait() return buf.String(), err } -func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { +func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { ctx := context.Background() - stdinReader, stdinWriter := io.Pipe() - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) require.NoError(t, err) - // io.Pipe doesn't know if anything is listening on the other end, so once - // our process is expectedly killed then we'll end up blocked trying to - // write to it. To workaround this, make sure we close the pipe reader when - // we've detected that the process has exited (i.e. when stdout has been - // closed). This will cause our write to immediately fail with this error. 
- closedErr := fmt.Errorf("stdout was closed") - stdout := closeWrapper{ - Writer: devNull, - close: func() error { - require.NoError(t, stdinReader.CloseWithError(closedErr)) - return nil - }, - } - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) + logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdoutCloser(stdout)) - p.Add(pipe.MemoryLimit(stage, limit, LogEventHandler(logger))) + p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) + p.Add( + pipe.Function( + "write-to-less", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + var bytes [1_000_000]byte + for i := 0; i < mbs; i++ { + _, err := stdout.Write(bytes[:]) + if err != nil { + assert.ErrorIs(t, err, syscall.EPIPE) + return nil + } + } + return nil + }, + ), + pipe.MemoryLimitWithObserver(stage, limit, LogEventHandler(logger)), + ) require.NoError(t, p.Start(ctx)) - // Write some nonsense data to less. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdinWriter.Write(bytes[:]) - if err != nil { - require.ErrorIs(t, err, closedErr) - } - } - - require.NoError(t, stdinWriter.Close()) err = p.Wait() return buf.String(), err diff --git a/pipe/nop_closer.go b/pipe/nop_closer.go index d435d0a..739497f 100644 --- a/pipe/nop_closer.go +++ b/pipe/nop_closer.go @@ -6,29 +6,73 @@ package pipe import "io" -// newNopCloser returns a ReadCloser with a no-op Close method wrapping -// the provided io.Reader r. -// If r implements io.WriterTo, the returned io.ReadCloser will implement io.WriterTo -// by forwarding calls to r. -func newNopCloser(r io.Reader) io.ReadCloser { - if _, ok := r.(io.WriterTo); ok { - return nopCloserWriterTo{r} - } - return nopCloser{r} +// newReaderNopCloser returns a ReadCloser with a no-op Close method, wrapping +// the provided io.Reader `r`. 
The wrapper deliberately hides `r`'s concrete +// type; use [UnwrapReader] to recover the underlying reader before use in +// situations where eg. use of io.WriterTo is important for performance. +func newReaderNopCloser(r io.Reader) io.ReadCloser { + return readerNopCloser{r} } -type nopCloser struct { +// readerNopCloser is a ReadCloser that wraps a provided `io.Reader`, +// but whose `Close()` method does nothing. It should be unwrapped (via +// [UnwrapReader]) before use. +type readerNopCloser struct { io.Reader } -func (nopCloser) Close() error { return nil } +func (readerNopCloser) Close() error { + return nil +} -type nopCloserWriterTo struct { - io.Reader +// writerNopCloser is a WriteCloser that wraps a provided `io.Writer`, but +// whose `Close()` method does nothing. It should be unwrapped (via +// [UnwrapWriter]) before use where fast-path interfaces such as +// `io.ReaderFrom` are relevant. +type writerNopCloser struct { + io.Writer } -func (nopCloserWriterTo) Close() error { return nil } +func (w writerNopCloser) Close() error { + return nil +} -func (c nopCloserWriterTo) WriteTo(w io.Writer) (n int64, err error) { - return c.Reader.(io.WriterTo).WriteTo(w) +// UnwrapReader returns the underlying [io.Reader] that go-pipe wrapped around +// the `stdin` it passes to a [Stage]'s `Start` method. [Stage] implementations +// should call this before reading from `stdin` so it sees the caller's +// concrete reader type, because that allows for use of fast-path interfaces +// (e.g. `io.WriterTo`) and identity (e.g. `*os.File`, for direct fd passing). +// +// If `r` is not a go-pipe wrapper (including nil), it is returned unchanged. +func UnwrapReader(r io.Reader) io.Reader { + if w, ok := r.(readerNopCloser); ok { + return w.Reader + } + return r +} + +// UnwrapWriter returns the underlying [io.Writer] that go-pipe wrapped around +// the `stdout` it passes to a [Stage]'s `Start` method. 
[Stage] +// implementations should call this before writing to `stdout` (see above). +// +// If `w` is not a go-pipe wrapper (including nil), it is returned unchanged. +func UnwrapWriter(w io.Writer) io.Writer { + if n, ok := w.(writerNopCloser); ok { + return n.Writer + } + return w +} + +// unwrapNopCloser unwraps the object if it is some kind of nop +// closer, and returns the underlying object. This function is used +// only for testing. +func unwrapNopCloser(obj any) (any, bool) { + switch obj := obj.(type) { + case readerNopCloser: + return obj.Reader, true + case writerNopCloser: + return obj.Writer, true + default: + return nil, false + } } diff --git a/pipe/nop_closer_test.go b/pipe/nop_closer_test.go new file mode 100644 index 0000000..13f9935 --- /dev/null +++ b/pipe/nop_closer_test.go @@ -0,0 +1,70 @@ +package pipe + +import ( + "bytes" + "context" + "io" + "testing" +) + +func TestUnwrapReader(t *testing.T) { + src := bytes.NewReader([]byte("payload")) + + if got := UnwrapReader(newReaderNopCloser(src)); got != io.Reader(src) { + t.Errorf("UnwrapReader(wrapped) = %T %p, want %p", got, got, src) + } + + // A non-wrapped reader passes through unchanged. + if got := UnwrapReader(src); got != io.Reader(src) { + t.Errorf("UnwrapReader(plain) = %T %p, want %p", got, got, src) + } + + if got := UnwrapReader(nil); got != nil { + t.Errorf("UnwrapReader(nil) = %v, want nil", got) + } +} + +func TestUnwrapWriter(t *testing.T) { + dst := &bytes.Buffer{} + + if got := UnwrapWriter(writerNopCloser{dst}); got != io.Writer(dst) { + t.Errorf("UnwrapWriter(wrapped) = %T %p, want %p", got, got, dst) + } + + // A non-wrapped writer passes through unchanged. 
+ if got := UnwrapWriter(dst); got != io.Writer(dst) { + t.Errorf("UnwrapWriter(plain) = %T %p, want %p", got, got, dst) + } + + if got := UnwrapWriter(nil); got != nil { + t.Errorf("UnwrapWriter(nil) = %v, want nil", got) + } +} + +// TestGoStageUnwrapsWriterToStdin verifies that a Function stage +// receives its stdin already unwrapped to the caller's concrete type, +// so fast-path interfaces such as io.WriterTo survive. This guards +// against the regression where goStage only unwrapped one of the +// internal nop-closer wrapper types. +func TestGoStageUnwrapsWriterToStdin(t *testing.T) { + src := bytes.NewReader([]byte("hello")) + + var got io.Reader + p := New(WithStdin(src), WithStdout(io.Discard)) + p.Add(Function("capture", func(_ context.Context, _ Env, stdin io.Reader, _ io.Writer) error { + got = stdin + _, err := io.Copy(io.Discard, stdin) + return err + })) + + if err := p.Run(context.Background()); err != nil { + t.Fatalf("Run: %v", err) + } + + if got != io.Reader(src) { + t.Fatalf("StageFunc stdin = %T %p, want unwrapped *bytes.Reader %p", got, got, src) + } + if _, ok := got.(io.WriterTo); !ok { + t.Fatalf("unwrapped stdin %T does not expose io.WriterTo fast path", got) + } +} diff --git a/pipe/panic.go b/pipe/panic.go deleted file mode 100644 index e0ca600..0000000 --- a/pipe/panic.go +++ /dev/null @@ -1,12 +0,0 @@ -package pipe - -// StagePanicHandlerAware is an interface that Stages can implement to receive -// a panic handler from the pipeline. This is particularly useful for stages -// that execute work in a separate goroutine and need to manage panics occurring -// within that goroutine. -type StagePanicHandlerAware interface { - SetPanicHandler(StagePanicHandler) -} - -// StagePanicHandler is a function that handles panics in a pipeline's stages. 
-type StagePanicHandler func(p any) error diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go new file mode 100644 index 0000000..4ac206c --- /dev/null +++ b/pipe/pipe_matching_test.go @@ -0,0 +1,380 @@ +package pipe_test + +import ( + "context" + "fmt" + "io" + "os" + "testing" + + "github.com/github/go-pipe/v2/pipe" + "github.com/stretchr/testify/assert" +) + +// Tests that `Pipeline.Start()` uses the correct types of pipes in +// various situations. +// +// The type of pipe to use depends on both the source and the consumer +// of the data, including the overall pipeline's stdin and stdout. So +// there are a lot of possibilities to consider. + +// Additional values used for the expected types of stdin/stdout: +const ( + IOPreferenceUndefinedNopCloser pipe.IOPreference = iota + 100 + IOPreferenceFileNopCloser + + // expectNil is a test-only expectation token meaning that the + // stage should be passed a `nil` stdin / stdout (which happens at + // the beginning / end of a pipeline when no overall stdin / stdout + // is configured). It is not a real `IOPreference`. 
+ expectNil +) + +func file(t *testing.T) *os.File { + f, err := os.Open(os.DevNull) + assert.NoError(t, err) + return f +} + +func readCloser() io.ReadCloser { + r, w := io.Pipe() + w.Close() + return r +} + +func writeCloser() io.WriteCloser { + r, w := io.Pipe() + r.Close() + return w +} + +func newPipeSniffingStage( + stdinPreference, stdinExpectation pipe.IOPreference, + stdoutPreference, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return &pipeSniffingStage{ + prefs: pipe.StagePreferences{ + StdinPreference: stdinPreference, + StdoutPreference: stdoutPreference, + }, + expect: pipe.StagePreferences{ + StdinPreference: stdinExpectation, + StdoutPreference: stdoutExpectation, + }, + } +} + +func newPipeSniffingFunc( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.IOPreferenceUndefined, stdinExpectation, + pipe.IOPreferenceUndefined, stdoutExpectation, + ) +} + +func newPipeSniffingCmd( + stdinExpectation, stdoutExpectation pipe.IOPreference, +) *pipeSniffingStage { + return newPipeSniffingStage( + pipe.IOPreferenceFile, stdinExpectation, + pipe.IOPreferenceFile, stdoutExpectation, + ) +} + +type pipeSniffingStage struct { + prefs pipe.StagePreferences + expect pipe.StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + +func (*pipeSniffingStage) Name() string { + return "pipe-sniffer" +} + +func (s *pipeSniffingStage) Preferences() pipe.StagePreferences { + return s.prefs +} + +func (s *pipeSniffingStage) Start( + _ context.Context, _ pipe.Env, stdin io.ReadCloser, stdout io.WriteCloser, _ pipe.StartOptions, +) error { + s.stdin = stdin + if stdin != nil { + _ = stdin.Close() + } + s.stdout = stdout + if stdout != nil { + _ = stdout.Close() + } + return nil +} + +func (s *pipeSniffingStage) check(t *testing.T, i int) { + t.Helper() + + checkStdinExpectation(t, i, s.expect.StdinPreference, s.stdin) + checkStdoutExpectation(t, i, s.expect.StdoutPreference, 
s.stdout) +} + +func (s *pipeSniffingStage) Wait() error { + return nil +} + +var _ pipe.Stage = (*pipeSniffingStage)(nil) + +func ioTypeString(f any) string { + if f == nil { + return "nil" + } + if f, ok := pipe.UnwrapNopCloser(f); ok { + return fmt.Sprintf("nopCloser(%s)", ioTypeString(f)) + } + switch f := f.(type) { + case *os.File: + return "*os.File" + case io.Reader: + return "other" + case io.Writer: + return "other" + default: + return fmt.Sprintf("%T", f) + } +} + +func prefString(pref pipe.IOPreference) string { + switch pref { + case pipe.IOPreferenceUndefined: + return "other" + case pipe.IOPreferenceFile: + return "*os.File" + case expectNil: + return "nil" + case IOPreferenceUndefinedNopCloser: + return "nopCloser(other)" + case IOPreferenceFileNopCloser: + return "nopCloser(*os.File)" + default: + panic(fmt.Sprintf("invalid IOPreference: %d", pref)) + } +} + +func checkStdinExpectation(t *testing.T, i int, pref pipe.IOPreference, stdin io.ReadCloser) { + t.Helper() + + ioType := ioTypeString(stdin) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdin: expected %s, got %s (%T)", i, expType, ioType, stdin, + ) +} + +type WriterNopCloser interface { + NopCloserWriter() io.Writer +} + +func checkStdoutExpectation(t *testing.T, i int, pref pipe.IOPreference, stdout io.WriteCloser) { + t.Helper() + + ioType := ioTypeString(stdout) + expType := prefString(pref) + assert.Equalf( + t, expType, ioType, + "stage %d stdout: expected %s, got %s (%T)", i, expType, ioType, stdout, + ) +} + +type checker interface { + check(t *testing.T, i int) +} + +func TestPipeTypes(t *testing.T) { + ctx := context.Background() + + t.Parallel() + + for _, tc := range []struct { + name string + opts []pipe.Option + stages []pipe.Stage + stdin io.Reader + stdout io.Writer + }{ + { + name: "func", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, expectNil), + }, + }, + { + name: "func-file-stdin", + opts: 
[]pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceFileNopCloser, expectNil), + }, + }, + { + name: "func-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, IOPreferenceFileNopCloser), + }, + }, + { + name: "func-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, pipe.IOPreferenceFile), + }, + }, + { + name: "func-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, expectNil), + }, + }, + { + name: "cmd-file-stdin", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceFileNopCloser, expectNil), + }, + }, + { + name: "cmd-file-stdout", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd-file-stdout-closer", + opts: []pipe.Option{ + pipe.WithStdoutCloser(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, pipe.IOPreferenceFile), + }, + }, + { + name: "cmd-file-stdin-other-stdout-closer-other", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func-func", + opts: []pipe.Option{ + pipe.WithStdin(file(t)), + pipe.WithStdoutCloser(writeCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(IOPreferenceFileNopCloser, pipe.IOPreferenceUndefined), + 
newPipeSniffingFunc(pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined), + }, + }, + { + name: "func-cmd", + opts: []pipe.Option{ + pipe.WithStdout(file(t)), + }, + stages: []pipe.Stage{ + newPipeSniffingFunc(expectNil, pipe.IOPreferenceFile), + newPipeSniffingCmd(pipe.IOPreferenceFile, IOPreferenceFileNopCloser), + }, + }, + { + name: "cmd-func", + opts: []pipe.Option{ + pipe.WithStdin(readCloser()), + }, + stages: []pipe.Stage{ + newPipeSniffingCmd(IOPreferenceUndefinedNopCloser, pipe.IOPreferenceFile), + newPipeSniffingFunc(pipe.IOPreferenceFile, expectNil), + }, + }, + { + name: "cmd-cmd", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingCmd(expectNil, pipe.IOPreferenceFile), + newPipeSniffingCmd(pipe.IOPreferenceFile, expectNil), + }, + }, + { + name: "hybrid1", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.IOPreferenceUndefined, expectNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, expectNil, + ), + }, + }, + { + name: "hybrid2", + opts: []pipe.Option{}, + stages: []pipe.Stage{ + newPipeSniffingStage( + pipe.IOPreferenceUndefined, expectNil, + pipe.IOPreferenceUndefined, pipe.IOPreferenceFile, + ), + newPipeSniffingStage( + pipe.IOPreferenceFile, pipe.IOPreferenceFile, + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + ), + newPipeSniffingStage( + pipe.IOPreferenceUndefined, pipe.IOPreferenceUndefined, + pipe.IOPreferenceUndefined, expectNil, + ), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := pipe.New(tc.opts...) + p.Add(tc.stages...) 
+ assert.NoError(t, p.Run(ctx)) + for i, s := range tc.stages { + s.(checker).check(t, i) + } + }) + } +} diff --git a/pipe/pipeline.go b/pipe/pipeline.go index 8bc3a37..2ec44b5 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "os" "sync/atomic" ) @@ -54,7 +55,7 @@ type ContextValuesFunc func(context.Context) []EnvVar type Pipeline struct { env Env - stdin io.Reader + stdin io.ReadCloser stdout io.WriteCloser stages []Stage cancel func() @@ -70,14 +71,6 @@ type Pipeline struct { var emptyEventHandler = func(_ *Event) {} -type nopWriteCloser struct { - io.Writer -} - -func (w nopWriteCloser) Close() error { - return nil -} - type NewPipeFn func(opts ...Option) *Pipeline // NewPipeline returns a Pipeline struct with all of the `options` @@ -107,14 +100,58 @@ func WithDir(dir string) Option { // WithStdin assigns stdin to the first command in the pipeline. func WithStdin(stdin io.Reader) Option { return func(p *Pipeline) { - p.stdin = stdin + // We don't want the first stage to close `stdin`, and it is + // not even necessarily an `io.ReadCloser`. So wrap it in a + // fake `io.ReadCloser` whose `Close()` method doesn't do + // anything. + // + // We could use `io.NopCloser()` for this purpose, but that + // would have a subtle problem. If the first stage is a + // `Command`, then it wants to set the `exec.Cmd`'s `Stdin` to + // an `io.Reader` corresponding to `p.stdin`. If `Cmd.Stdin` + // is an `*os.File`, then `exec.Cmd` will pass the file + // descriptor to the subcommand directly; there is no need to + // create a pipe and copy the data into the input side of the + // pipe. But if `p.stdin` is not an `*os.File`, then this + // optimization is prevented. And even worse, it also has the + // side effect that the goroutine that copies from `Cmd.Stdin` + // into the pipe doesn't terminate until that fd is closed by + // the writing side. + // + // That isn't always what we want. 
Consider, for example, the + // following snippet, where the subcommand's stdin is set to + // the stdin of the enclosing Go program, but wrapped with + // `io.NopCloser`: + // + // cmd := exec.Command("ls") + // cmd.Stdin = io.NopCloser(os.Stdin) + // cmd.Stdout = os.Stdout + // cmd.Stderr = os.Stderr + // cmd.Run() + // + // In this case, we don't want the Go program to wait for + // `os.Stdin` to close (because `ls` isn't even trying to read + // from its stdin). But it does: `exec.Cmd` doesn't recognize + // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and + // copies the data itself, and this goroutine doesn't + // terminate until `cmd.Stdin` (i.e., the Go program's own + // stdin) is closed. But if, for example, the Go program is + // run from an interactive shell session, that might never + // happen, in which case the program will fail to terminate, + // even after `ls` exits. + // + // So instead, in this special case, we wrap `stdin` in our + // own `nopCloser`, which behaves like `io.NopCloser`, except + // that `pipe.CommandStage` knows how to unwrap it before + // passing it to `exec.Cmd`. + p.stdin = newReaderNopCloser(stdin) } } // WithStdout assigns stdout to the last command in the pipeline. func WithStdout(stdout io.Writer) Option { return func(p *Pipeline) { - p.stdout = nopWriteCloser{stdout} + p.stdout = writerNopCloser{stdout} } } @@ -184,11 +221,6 @@ func WithEventHandler(handler func(e *Event)) Option { // WithStagePanicHandler sets a panic handler for the stages within a pipeline. // When a pipeline stage panics, the provided handler will be invoked, allowing // the client to handle the panic in whatever way they see fit. -// -// Note: -// - Only the Function stage supports this functionality. -// - The client is responsible for deciding whether to recover from the panic or panicking again. -// - If a panic handler is not set, the panic will be propagated normally. 
func WithStagePanicHandler(ph StagePanicHandler) Option { return func(p *Pipeline) { p.panicHandler = ph @@ -220,6 +252,12 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } +type stageStarter struct { + prefs StagePreferences + stdin io.ReadCloser + stdout io.WriteCloser +} + // Start starts the commands in the pipeline. If `Start()` exits // without an error, `Wait()` must also be called, to allow all // resources to be freed. @@ -231,93 +269,116 @@ func (p *Pipeline) Start(ctx context.Context) error { atomic.StoreUint32(&p.started, 1) ctx, p.cancel = context.WithCancel(ctx) - var nextStdin io.ReadCloser - if p.stdin != nil { - // We don't want the first stage to actually close this, and - // `p.stdin` is not even necessarily an `io.ReadCloser`. So - // wrap it in a fake `io.ReadCloser` whose `Close()` method - // doesn't do anything. - // - // We could use `io.NopCloser()` for this purpose, but it has - // a subtle problem. If the first stage is a `Command`, then - // it wants to set the `exec.Cmd`'s `Stdin` to an `io.Reader` - // corresponding to `p.stdin`. If `Cmd.Stdin` is an - // `*os.File`, then the file descriptor can be passed to the - // subcommand directly; there is no need for this process to - // create a pipe and copy the data into the input side of the - // pipe. But if `p.stdin` is not an `*os.File`, then this - // optimization is prevented. And even worse, it also has the - // side effect that the goroutine that copies from `Cmd.Stdin` - // into the pipe doesn't terminate until that fd is closed by - // the writing side. - // - // That isn't always what we want. 
Consider, for example, the - // following snippet, where the subcommand's stdin is set to - // the stdin of the enclosing Go program, but wrapped with - // `io.NopCloser`: - // - // cmd := exec.Command("ls") - // cmd.Stdin = io.NopCloser(os.Stdin) - // cmd.Stdout = os.Stdout - // cmd.Stderr = os.Stderr - // cmd.Run() - // - // In this case, we don't want the Go program to wait for - // `os.Stdin` to close (because `ls` isn't even trying to read - // from its stdin). But it does: `exec.Cmd` doesn't recognize - // that `Cmd.Stdin` is an `*os.File`, so it sets up a pipe and - // copies the data itself, and this goroutine doesn't - // terminate until `cmd.Stdin` (i.e., the Go program's own - // stdin) is closed. But if, for example, the Go program is - // run from an interactive shell session, that might never - // happen, in which case the program will fail to terminate, - // even after `ls` exits. - // - // So instead, in this special case, we wrap `p.stdin` in our - // own `nopCloser`, which behaves like `io.NopCloser`, except - // that `pipe.CommandStage` knows how to unwrap it before - // passing it to `exec.Cmd`. - nextStdin = newNopCloser(p.stdin) + if len(p.stages) == 0 { + if p.stdout == nil { + // No stages and no destination: there is nothing to do + // and nowhere to put `p.stdin` even if it was set. + return nil + } + // No stages but a destination was configured: synthesize an + // identity-copy stage so that `WithStdin()` is drained into + // `WithStdout()`/`WithStdoutCloser()` and the destination + // closer (if any) is invoked. + p.stages = append(p.stages, Function( + "identity", + func(_ context.Context, _ Env, stdin io.Reader, stdout io.Writer) error { + if stdin == nil { + return nil + } + _, err := io.Copy(stdout, stdin) + return err + }, + )) } + // We need to decide how to start the stages, especially what + // pipes to use to connect adjacent stages (`os.Pipe()` vs. + // `io.Pipe()`) based on the two stages' preferences. 
+ stageStarters := make([]stageStarter, len(p.stages), len(p.stages)+1) + + // Collect information about each stage's type and preferences: for i, s := range p.stages { - if phs, ok := s.(StagePanicHandlerAware); ok && p.panicHandler != nil { - phs.SetPanicHandler(p.panicHandler) + stageStarters[i].prefs = s.Preferences() + } + + if p.stdin != nil { + // Arrange for the input of the 0th stage to come from + // `p.stdin`: + stageStarters[0].stdin = p.stdin + } + + if p.stdout != nil { + i := len(p.stages) - 1 + ss := &stageStarters[i] + ss.stdout = p.stdout + } + + // Clean up any processes and pipes that have been created. `i` is + // the index of the stage that failed to start (whose output pipe + // has already been cleaned up if necessary). + abort := func(i int, err error) error { + // Close the pipe that the previous stage was writing to. + // That should cause it to exit even if it's not minding + // its context. + if stageStarters[i].stdin != nil { + _ = stageStarters[i].stdin.Close() } - var err error - stdout, err := s.Start(ctx, p.env, nextStdin) - if err != nil { - // Close the pipe that the previous stage was writing to. - // That should cause it to exit even if it's not minding - // its context. - if nextStdin != nil { - _ = nextStdin.Close() - } + // Kill and wait for any stages that have been started + // already to finish: + p.cancel() + for _, s := range p.stages[:i] { + _ = s.Wait() + } + p.eventHandler(&Event{ + Command: p.stages[i].Name(), + Msg: "failed to start pipeline stage", + Err: err, + }) + return fmt.Errorf( + "starting pipeline stage %q: %w", p.stages[i].Name(), err, + ) + } - // Kill and wait for any stages that have been started - // already to finish: - p.cancel() - for _, s := range p.stages[:i] { - _ = s.Wait() + // Loop over all but the last stage, starting them. 
By the time we + // get to a stage, its stdin will have already been determined, + // but we still need to figure out its stdout and set the stdin + // that will be used for the subsequent stage. + for i, s := range p.stages[:len(p.stages)-1] { + ss := &stageStarters[i] + nextSS := &stageStarters[i+1] + + // We need to generate a pipe pair for this stage to use + // to communicate with its successor: + if ss.prefs.StdoutPreference == IOPreferenceFile || + nextSS.prefs.StdinPreference == IOPreferenceFile { + // Use an OS-level pipe for the communication: + var err error + nextSS.stdin, ss.stdout, err = os.Pipe() + if err != nil { + return abort(i, err) } - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, - }) - return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err) + } else { + nextSS.stdin, ss.stdout = io.Pipe() + } + if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, StartOptions{PanicHandler: p.panicHandler}); err != nil { + nextSS.stdin.Close() + ss.stdout.Close() + return abort(i, err) } - nextStdin = stdout } - // If the pipeline was configured with a `stdout`, add a synthetic - // stage to copy the last stage's stdout to that writer: - if p.stdout != nil { - c := newIOCopier(p.stdout) - p.stages = append(p.stages, c) - // `ioCopier.Start()` never fails: - _, _ = c.Start(ctx, p.env, nextStdin) + // The last stage needs special handling, because its stdout + // doesn't need to flow into another stage (it's already set in + // `ss.stdout` if it's needed). 
+ { + i := len(p.stages) - 1 + s := p.stages[i] + ss := &stageStarters[i] + + if err := s.Start(ctx, p.env, ss.stdin, ss.stdout, StartOptions{PanicHandler: p.panicHandler}); err != nil { + return abort(i, err) + } } return nil @@ -325,7 +386,7 @@ func (p *Pipeline) Start(ctx context.Context) error { func (p *Pipeline) Output(ctx context.Context) ([]byte, error) { var buf bytes.Buffer - p.stdout = nopWriteCloser{&buf} + p.stdout = writerNopCloser{&buf} err := p.Run(ctx) return buf.Bytes(), err } diff --git a/pipe/pipeline_test.go b/pipe/pipeline_test.go index bebc931..f8712b4 100644 --- a/pipe/pipeline_test.go +++ b/pipe/pipeline_test.go @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "github.com/github/go-pipe/pipe" + "github.com/github/go-pipe/v2/pipe" ) // Check whether this package's test suite leaks any goroutines: @@ -26,15 +26,66 @@ func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } -func TestPipelineFirstStageFailsToStart(t *testing.T) { +func TestPipelineEmpty(t *testing.T) { + t.Parallel() + p := pipe.New() + assert.NoError(t, p.Run(context.Background())) +} + +func TestPipelineEmptyWithStdinAndStdout(t *testing.T) { t.Parallel() ctx := context.Background() + stdout := &bytes.Buffer{} + p := pipe.New( + pipe.WithStdin(strings.NewReader("hello world\n")), + pipe.WithStdout(stdout), + ) + if assert.NoError(t, p.Run(ctx)) { + assert.Equal(t, "hello world\n", stdout.String()) + } +} - dir := t.TempDir() +func TestPipelineEmptyOutput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New(pipe.WithStdin(strings.NewReader("hello world\n"))) + out, err := p.Output(ctx) + if assert.NoError(t, err) { + assert.Equal(t, "hello world\n", string(out)) + } +} + +func TestPipelineEmptyWithStdoutCloser(t *testing.T) { + t.Parallel() + ctx := context.Background() + stdout := &closeTrackingWriter{} + p := pipe.New( + pipe.WithStdin(strings.NewReader("hello world\n")), + pipe.WithStdoutCloser(stdout), + ) + 
if assert.NoError(t, p.Run(ctx)) { + assert.Equal(t, "hello world\n", stdout.buf.String()) + assert.True(t, stdout.closed, "WithStdoutCloser destination should be closed") + } +} + +type closeTrackingWriter struct { + buf bytes.Buffer + closed bool +} +func (w *closeTrackingWriter) Write(p []byte) (int, error) { return w.buf.Write(p) } +func (w *closeTrackingWriter) Close() error { + w.closed = true + return nil +} + +func TestPipelineFirstStageFailsToStart(t *testing.T) { + t.Parallel() + ctx := context.Background() startErr := errors.New("foo") - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( ErrorStartingStage{startErr}, ErrorStartingStage{errors.New("this error should never happen")}, @@ -45,12 +96,9 @@ func TestPipelineFirstStageFailsToStart(t *testing.T) { func TestPipelineSecondStageFailsToStart(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - startErr := errors.New("foo") - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( seqFunction(20000), ErrorStartingStage{startErr}, @@ -61,10 +109,7 @@ func TestPipelineSecondStageFailsToStart(t *testing.T) { func TestPipelineSingleCommandOutput(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -75,19 +120,16 @@ func TestPipelineSingleCommandOutput(t *testing.T) { func TestPipelineSingleCommandWithStdout(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("echo", "hello world")) if assert.NoError(t, p.Run(ctx)) { assert.Equal(t, "hello world\n", stdout.String()) } } -func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { +func TestPipelineStdinOSPipeThatIsNeverClosed(t *testing.T) { t.Parallel() // Make sure 
that the subprocess terminates on its own, as opposed @@ -105,7 +147,10 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { var stdout bytes.Buffer - p := pipe.New(pipe.WithStdin(r), pipe.WithStdout(&stdout)) + p := pipe.New( + pipe.WithStdin(r), + pipe.WithStdout(&stdout), + ) // Note that this command doesn't read from its stdin, so it will // terminate regardless of whether `w` gets closed: p.Add(pipe.Command("true")) @@ -115,7 +160,7 @@ func TestPipelineStdinFileThatIsNeverClosed(t *testing.T) { assert.NoError(t, p.Run(ctx)) } -func TestPipelineStdinThatIsNeverClosed(t *testing.T) { +func TestPipelineIOPipeStdinThatIsNeverClosed(t *testing.T) { t.Skip("test not run because it currently deadlocks") t.Parallel() @@ -158,10 +203,7 @@ func TestPipelineStdinThatIsNeverClosed(t *testing.T) { func TestNontrivialPipeline(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("sed", "s/hello/goodbye/"), @@ -172,7 +214,33 @@ func TestNontrivialPipeline(t *testing.T) { } } -func TestPipelineReadFromSlowly(t *testing.T) { +func TestOSPipePipelineReadFromSlowly(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + r, w, err := os.Pipe() + require.NoError(t, err) + + var buf []byte + readErr := make(chan error, 1) + + go func() { + time.Sleep(200 * time.Millisecond) + var err error + buf, err = io.ReadAll(r) + readErr <- err + }() + + p := pipe.New(pipe.WithStdoutCloser(w)) + p.Add(pipe.Command("echo", "hello world")) + assert.NoError(t, p.Run(ctx)) + + assert.NoError(t, <-readErr) + assert.Equal(t, "hello world\n", string(buf)) +} + +func TestIOPipePipelineReadFromSlowly(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -210,9 +278,6 @@ func TestPipelineReadFromSlowly2(t *testing.T) { 
t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - - dir := t.TempDir() - r, w := io.Pipe() var buf []byte @@ -236,7 +301,7 @@ func TestPipelineReadFromSlowly2(t *testing.T) { } }() - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(w)) + p := pipe.New(pipe.WithStdout(w)) p.Add(pipe.Command("seq", "100")) assert.NoError(t, p.Run(ctx)) @@ -252,10 +317,7 @@ func TestPipelineReadFromSlowly2(t *testing.T) { func TestPipelineTwoCommandsPiping(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Command("echo", "hello world")) assert.Panics(t, func() { p.Add(pipe.Command("")) }) out, err := p.Output(ctx) @@ -282,10 +344,7 @@ func TestPipelineDir(t *testing.T) { func TestPipelineExit(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("false"), pipe.Command("true"), @@ -316,11 +375,10 @@ func TestPipelineInterrupted(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) @@ -339,11 +397,10 @@ func TestPipelineCanceled(t *testing.T) { } t.Parallel() - dir := t.TempDir() stdout := &bytes.Buffer{} - p := pipe.New(pipe.WithDir(dir), pipe.WithStdout(stdout)) + p := pipe.New(pipe.WithStdout(stdout)) p.Add(pipe.Command("sleep", "10")) ctx, cancel := context.WithCancel(context.Background()) @@ -367,9 +424,8 @@ func TestLittleEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("sh", "-c", "sleep 1; echo foo"), pipe.Command("true"), @@ -391,9 +447,8 @@ func TestBigEPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := 
pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("seq", "100000"), pipe.Command("true"), @@ -415,9 +470,8 @@ func TestIgnoredSIGPIPE(t *testing.T) { } t.Parallel() - dir := t.TempDir() - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE), pipe.Command("echo", "foo"), @@ -433,11 +487,8 @@ func TestIgnoredSIGPIPE(t *testing.T) { func TestFunction(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - t.Run("successful function", func(t *testing.T) { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Print("hello world"), pipe.Function( @@ -463,10 +514,8 @@ func TestFunction(t *testing.T) { t.Run("panic with handler", func(t *testing.T) { p := pipe.New( - pipe.WithDir(dir), pipe.WithStagePanicHandler(func(p any) error { - err := fmt.Errorf("panic handled: %v", p) - return err + return fmt.Errorf("panic handled: %v", p) }), ) p.Add( @@ -483,15 +532,36 @@ func TestFunction(t *testing.T) { assert.ErrorContains(t, err, "panic handled") assert.Empty(t, out) }) + + t.Run("panic with handler through IgnoreError", func(t *testing.T) { + p := pipe.New( + pipe.WithStagePanicHandler(func(p any) error { + return fmt.Errorf("panic handled: %v", p) + }), + ) + p.Add( + pipe.Print("hello world"), + pipe.IgnoreError( + pipe.Function( + "farewell", + func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { + panic("this is a panic") + }, + ), + func(_ error) bool { return false }, + ), + ) + + out, err := p.Output(ctx) + assert.ErrorContains(t, err, "panic handled") + assert.Empty(t, out) + }) } func TestPipelineWithFunction(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "-n", "hello world"), pipe.Function( @@ -524,10 +594,23 @@ func (s ErrorStartingStage) Name() string { return "errorStartingStage" } +func (s ErrorStartingStage) 
Preferences() pipe.StagePreferences { + return pipe.StagePreferences{ + StdinPreference: pipe.IOPreferenceUndefined, + StdoutPreference: pipe.IOPreferenceUndefined, + } +} + func (s ErrorStartingStage) Start( - _ context.Context, _ pipe.Env, _ io.ReadCloser, -) (io.ReadCloser, error) { - return io.NopCloser(&bytes.Buffer{}), s.err + _ context.Context, _ pipe.Env, stdin io.ReadCloser, stdout io.WriteCloser, _ pipe.StartOptions, +) error { + if stdin != nil { + _ = stdin.Close() + } + if stdout != nil { + _ = stdout.Close() + } + return s.err } func (s ErrorStartingStage) Wait() error { @@ -552,10 +635,7 @@ func seqFunction(n int) pipe.Stage { func TestPipelineWithLinewiseFunction(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() // Print the numbers from 1 to 20 (generated from scratch): p.Add( seqFunction(20), @@ -694,10 +774,7 @@ func TestScannerFinishEarly(t *testing.T) { func TestPrintln(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Println("Look Ma, no hands!")) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -708,10 +785,7 @@ func TestPrintln(t *testing.T) { func TestPrintf(t *testing.T) { t.Parallel() ctx := context.Background() - - dir := t.TempDir() - - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add(pipe.Printf("Strangely recursive: %T", p)) out, err := p.Output(ctx) if assert.NoError(t, err) { @@ -719,6 +793,26 @@ func TestPrintf(t *testing.T) { } } +func TestPrintlnNoOutput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() + p.Add(pipe.Println("Look Ma, no output!")) + assert.NoError(t, p.Run(ctx)) +} + +func TestFunctionNoInput(t *testing.T) { + t.Parallel() + ctx := context.Background() + p := pipe.New() + p.Add(pipe.Function("read-all", func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error { + n, err := 
io.Copy(io.Discard, stdin) + assert.Equal(t, int64(0), n) + return err + })) + assert.NoError(t, p.Run(ctx)) +} + func TestErrors(t *testing.T) { t.Parallel() ctx := context.Background() @@ -903,11 +997,8 @@ func TestErrors(t *testing.T) { func BenchmarkSingleProgram(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("true"), ) @@ -917,11 +1008,8 @@ func BenchmarkSingleProgram(b *testing.B) { func BenchmarkTenPrograms(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Command("cat"), @@ -943,16 +1031,13 @@ func BenchmarkTenPrograms(b *testing.B) { func BenchmarkTenFunctions(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Println("hello world"), pipe.Function("copy1", cp), @@ -974,16 +1059,13 @@ func BenchmarkTenFunctions(b *testing.B) { func BenchmarkTenMixedStages(b *testing.B) { ctx := context.Background() - - dir := b.TempDir() - cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { _, err := io.Copy(stdout, stdin) return err } for i := 0; i < b.N; i++ { - p := pipe.New(pipe.WithDir(dir)) + p := pipe.New() p.Add( pipe.Command("echo", "hello world"), pipe.Function("copy1", cp), @@ -1003,6 +1085,97 @@ func BenchmarkTenMixedStages(b *testing.B) { } } +func BenchmarkMoreDataUnbuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + 
p.Add( + pipe.Function( + "seq", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + for i := 1; i <= 100000; i++ { + fmt.Fprintln(stdout, i) + } + return nil + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(_ context.Context, _ pipe.Env, _ []byte, _ *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 100000, count) + } + } +} + +func BenchmarkMoreDataBuffered(b *testing.B) { + ctx := context.Background() + + cp := func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error { + _, err := io.Copy(stdout, stdin) + return err + } + + for i := 0; i < b.N; i++ { + count := 0 + p := pipe.New() + p.Add( + pipe.Function( + "seq", + func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { + out := bufio.NewWriter(stdout) + for i := 1; i <= 1000000; i++ { + fmt.Fprintln(out, i) + } + return out.Flush() + }, + ), + pipe.Command("cat"), + pipe.Function("copy2", cp), + pipe.Command("cat"), + pipe.Function("copy3", cp), + pipe.Command("cat"), + pipe.Function("copy4", cp), + pipe.Command("cat"), + pipe.Function("copy5", cp), + pipe.Command("cat"), + pipe.LinewiseFunction( + "count", + func(_ context.Context, _ pipe.Env, _ []byte, _ *bufio.Writer) error { + count++ + return nil + }, + ), + ) + err := p.Run(ctx) + if assert.NoError(b, err) { + assert.EqualValues(b, 1000000, count) + } + } +} + func genErr(err error) pipe.StageFunc { return func(_ context.Context, _ pipe.Env, _ io.Reader, _ io.Writer) error { return err diff --git a/pipe/scanner.go b/pipe/scanner.go index b56b58c..5ec16e8 100644 --- a/pipe/scanner.go +++ b/pipe/scanner.go @@ -56,11 +56,7 @@ func ScannerFunction( return err } } - if err := 
scanner.Err(); err != nil { - return err - } - - return nil + return scanner.Err() // `p.AddFunction()` arranges for `stdout` to be closed. }, ) diff --git a/pipe/stage.go b/pipe/stage.go index f3d74d9..611c9dd 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -5,30 +5,152 @@ import ( "io" ) -// Stage is an element of a `Pipeline`. +// Stage is an element of a `Pipeline`. It reads from standard input +// and writes to standard output. +// +// Who closes stdin and stdout? +// +// A `Stage` as a whole needs to be responsible for closing its end of +// stdin and stdout (assuming that `Start()` returns successfully). +// Its doing so tells the previous/next stage that it is done +// reading/writing data, which can affect their behavior. Therefore, +// it should close each one as soon as it is done with it. If the +// caller wants to suppress the closing of stdin/stdout, it can always +// wrap the corresponding argument in a "nopCloser". +// +// How this should be done depends on whether stdin/stdout are of type +// `*os.File`. +// +// If a stage is an external command, then the subprocess ultimately +// needs its own copies of `*os.File` file descriptors for its stdin +// and stdout. The external command will "always" [1] close those when +// it exits. +// +// If the stage is an external command and one of the arguments is an +// `*os.File`, then it can set the corresponding field of `exec.Cmd` +// to that argument directly. This has the result that `exec.Cmd` +// duplicates that file descriptor and passes the dup to the +// subprocess. Therefore, the stage must close its copy of that +// argument as soon as the external command has started, because the +// external command will keep its own copy open as long as necessary +// (and no longer!). 
It should use roughly the following sequence: +// +// cmd.Stdin = f // Similarly for stdout +// cmd.Start(…) +// f.Close() // close our copy +// cmd.Wait() +// +// If the stage is an external command and one of its arguments is not +// an `*os.File`, then `exec.Cmd` will take care of creating an +// `os.Pipe()`, copying from the provided argument in/out of the pipe, +// and eventually closing both ends of the pipe. The stage must close +// the argument itself, but only _after_ the external command has +// finished, like so: +// +// cmd.Stdin = r // Similarly for stdout +// cmd.Start(…) +// cmd.Wait() +// r.Close() +// +// If the stage is a Go function, then it holds the only copy of +// stdin/stdout, so it must wait until the function is done before +// closing them (regardless of their underlying type), like so: +// +// go func() { +// f(…, stdin, stdout) +// stdin.Close() +// stdout.Close() +// }() +// +// From the point of view of the pipeline as a whole, if stdin is +// provided by the user (`WithStdin()`), then we don't want to close +// it at all, whether it's an `*os.File` or not. For this reason, +// stdin has to be wrapped using a `readerNopCloser` before being +// passed into the first stage. For efficiency reasons, the first +// stage should ideally unwrap its stdin argument (using +// [UnwrapReader]) before actually using it. If the wrapped value is +// an `*os.File` and the stage is a command stage, then unwrapping is +// also important to get the right semantics. +// +// For stdout, it depends on whether the user supplied it using +// `WithStdout()` or `WithStdoutCloser()`. If the former, then the +// considerations are the same as for stdin. +// +// [1] It's theoretically possible for a command to pass the open file +// descriptor to another, longer-lived process, in which case the +// file descriptor wouldn't necessarily get closed when the +// command finishes.
But that's ill-behaved in a command that is +// being used in a pipeline, so we'll ignore that possibility. + type Stage interface { // Name returns the name of the stage. Name() string + // Preferences returns this stage's preferences regarding how it + // should be run. + Preferences() StagePreferences + // Start starts the stage in the background, in the environment - // described by `env`, and using `stdin` as input. (`stdin` should - // be set to `nil` if the stage is to receive no input, which - // might be the case for the first stage in a pipeline.) It - // returns an `io.ReadCloser` from which the stage's output can be - // read (or `nil` if it generates no output, which should only be - // the case for the last stage in a pipeline). It is the stages' - // responsibility to close `stdin` (if it is not nil) when it has - // read all of the input that it needs, and to close the write end - // of its output reader when it is done, as that is generally how - // the subsequent stage knows that it has received all of its - // input and can finish its work, too. + // described by `env`, using `stdin` to provide its input and + // `stdout` to collect its output. (`stdin`/`stdout` might be set + // to `nil` if the stage is to receive no input/produce no output, + // as may be the case for the first/last stage in a pipeline.) See the + // `Stage` type comment for more information about responsibility + // for closing stdin and stdout. // // If `Start()` returns without an error, `Wait()` must also be // called, to allow all resources to be freed. - Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) + Start(ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser, opts StartOptions) error // Wait waits for the stage to be done, either because it has // finished or because it has been killed due to the expiration of // the context passed to `Start()`.
Wait() error } + +// StartOptions carries run-scoped options passed to `Stage.Start`. +// It is a struct (rather than positional parameters) so that future +// options can be added without breaking the `Stage` interface. +type StartOptions struct { + // PanicHandler, if non-nil, is invoked to recover a panic that escapes + // user code that a stage runs in a library-spawned goroutine (a + // Function stage's StageFunc, or a memory-limit stage's event + // handler), converting it into an error. Stage types that don't run + // user code in a library-spawned goroutine ignore it. + PanicHandler StagePanicHandler +} + +// StagePanicHandler is a function that handles panics in a pipeline's stages. +type StagePanicHandler func(p any) error + +// StagePreferences is the way that a `Stage` indicates its +// preferences about how it is run. This is used within +// `pipe.Pipeline` to decide when to use `os.Pipe()` vs. `io.Pipe()` +// for creating the pipes between stages. +type StagePreferences struct { + StdinPreference IOPreference + StdoutPreference IOPreference +} + +// IOPreference describes what type of stdin / stdout a stage would +// prefer. +// +// External commands prefer `*os.File`s (such as those produced by +// `os.Pipe()`) as their stdin and stdout, because those can be passed +// directly by the external process without any extra copying and also +// simplify the semantics around process termination. Go function +// stages are typically happy with any `io.ReadCloser` (such as one +// produced by `io.Pipe()`), which can be more efficient because +// traffic through an `io.Pipe()` happens entirely in userspace. +type IOPreference int + +const ( + // IOPreferenceUndefined indicates that the stage doesn't care + // what form the specified stdin / stdout takes (i.e., any old + // `io.ReadCloser` / `io.WriteCloser` is just fine). 
+ IOPreferenceUndefined IOPreference = iota + + // IOPreferenceFile indicates that the stage would prefer for the + // specified stdin / stdout to be an `*os.File`, to avoid copying. + IOPreferenceFile +)