From bfd929f4449eba51ee17d026ac707ae641ee9f4b Mon Sep 17 00:00:00 2001 From: Daniel Moran Date: Wed, 21 Jul 2021 17:03:41 -0400 Subject: [PATCH] fix: detect non-interactive stdio and use "normal" IO operations (#204) --- clients/apply/apply.go | 4 +- internal/mock/stdio.gen.go | 12 +- pkg/stdio/console.go | 118 ------------- pkg/stdio/interactive.go | 100 +++++++++++ pkg/stdio/noninteractive.go | 73 ++++++++ pkg/stdio/stdio.go | 41 ++++- pkg/stdio/stdio_internal_test.go | 284 +++++++++++++++++++++++++++++++ 7 files changed, 503 insertions(+), 129 deletions(-) delete mode 100644 pkg/stdio/console.go create mode 100644 pkg/stdio/interactive.go create mode 100644 pkg/stdio/noninteractive.go create mode 100644 pkg/stdio/stdio_internal_test.go diff --git a/clients/apply/apply.go b/clients/apply/apply.go index d5676bf..cf15181 100644 --- a/clients/apply/apply.go +++ b/clients/apply/apply.go @@ -98,7 +98,7 @@ func (c Client) Apply(ctx context.Context, params *Params) error { return fmt.Errorf("failed to check template impact: %w", err) } - if c.StdIO.InputIsInteractive() && (len(res.Summary.MissingEnvRefs) > 0 || len(res.Summary.MissingSecrets) > 0) { + if c.StdIO.IsInteractive() && (len(res.Summary.MissingEnvRefs) > 0 || len(res.Summary.MissingSecrets) > 0) { for _, e := range res.Summary.MissingEnvRefs { val, err := c.StdIO.GetStringInput(fmt.Sprintf("Please provide environment reference value for key %s", e), "") if err != nil { @@ -129,7 +129,7 @@ func (c Client) Apply(ctx context.Context, params *Params) error { } } - if c.StdIO.InputIsInteractive() && !params.Force { + if c.StdIO.IsInteractive() && !params.Force { if confirmed := c.StdIO.GetConfirm("Confirm application of the above resources"); !confirmed { return errors.New("aborted application of template") } diff --git a/internal/mock/stdio.gen.go b/internal/mock/stdio.gen.go index 9b825c9..46f9507 100644 --- a/internal/mock/stdio.gen.go +++ b/internal/mock/stdio.gen.go @@ -120,18 +120,18 @@ func (mr *MockStdIOMockRecorder) GetStringInput(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStringInput", reflect.TypeOf((*MockStdIO)(nil).GetStringInput), arg0, arg1) } -// InputIsInteractive mocks base method. -func (m *MockStdIO) InputIsInteractive() bool { +// IsInteractive mocks base method. +func (m *MockStdIO) IsInteractive() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InputIsInteractive") + ret := m.ctrl.Call(m, "IsInteractive") ret0, _ := ret[0].(bool) return ret0 } -// InputIsInteractive indicates an expected call of InputIsInteractive. -func (mr *MockStdIOMockRecorder) InputIsInteractive() *gomock.Call { +// IsInteractive indicates an expected call of IsInteractive. +func (mr *MockStdIOMockRecorder) IsInteractive() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InputIsInteractive", reflect.TypeOf((*MockStdIO)(nil).InputIsInteractive)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsInteractive", reflect.TypeOf((*MockStdIO)(nil).IsInteractive)) } // Write mocks base method. diff --git a/pkg/stdio/console.go b/pkg/stdio/console.go deleted file mode 100644 index b7120ee..0000000 --- a/pkg/stdio/console.go +++ /dev/null @@ -1,118 +0,0 @@ -package stdio - -import ( - "errors" - "io" - "os" - - "github.com/AlecAivazis/survey/v2" - "github.com/AlecAivazis/survey/v2/terminal" - "github.com/mattn/go-isatty" -) - -// terminalStdio interacts with the user via an interactive terminal. -type terminalStdio struct { - Stdin terminal.FileReader - Stdout terminal.FileWriter - Stderr io.Writer -} - -// TerminalStdio interacts with users over stdin/stdout/stderr. -var TerminalStdio StdIO = &terminalStdio{ - Stdin: os.Stdin, - Stdout: os.Stdout, - Stderr: os.Stderr, -} - -// Write prints some bytes to stdout. -func (t *terminalStdio) Write(p []byte) (int, error) { - return t.Stdout.Write(p) -} - -// WriteErr prints some bytes to stderr. -func (t *terminalStdio) WriteErr(p []byte) (int, error) { - return t.Stderr.Write(p) -} - -type bannerTemplateData struct { - Message string -} - -var bannerTemplate = `{{color "cyan+hb"}}> {{ .Message }}{{color "reset"}} -` - -// Banner displays informational text to the user. -func (t *terminalStdio) Banner(message string) error { - r := survey.Renderer{} - r.WithStdio(terminal.Stdio{In: t.Stdin, Out: t.Stdout, Err: t.Stderr}) - return r.Render(bannerTemplate, &bannerTemplateData{Message: message}) -} - -// Error displays an error message to the user. -func (t *terminalStdio) Error(message string) error { - r := survey.Renderer{} - r.WithStdio(terminal.Stdio{In: t.Stdin, Out: t.Stdout, Err: t.Stderr}) - cfg := survey.PromptConfig{Icons: survey.IconSet{Error: survey.Icon{Text: "X", Format: "red"}}} - return r.Error(&cfg, errors.New(message)) -} - -func (t *terminalStdio) InputIsInteractive() bool { - return isatty.IsTerminal(t.Stdin.Fd()) || isatty.IsCygwinTerminal(t.Stdin.Fd()) -} - -// GetStringInput prompts the user for arbitrary input. -func (t *terminalStdio) GetStringInput(prompt, defaultValue string) (input string, err error) { - question := survey.Input{ - Message: prompt, - Default: defaultValue, - } - err = survey.AskOne(&question, &input, - survey.WithStdio(t.Stdin, t.Stdout, t.Stderr), - survey.WithValidator(survey.Required)) - return -} - -// GetSecret prompts the user for a secret. -func (t *terminalStdio) GetSecret(prompt string, minLen int) (password string, err error) { - question := survey.Password{Message: prompt} - opts := []survey.AskOpt{survey.WithStdio(t.Stdin, t.Stdout, t.Stderr)} - if minLen > 0 { - opts = append(opts, survey.WithValidator(survey.MinLength(minLen))) - } - err = survey.AskOne(&question, &password, opts...) - question.NewCursor().HorizontalAbsolute(0) - return -} - -// GetPassword prompts the user for a secret twice, and inputs must match. -// Uses stdio.MinPasswordLen as the minimum input length -func (t *terminalStdio) GetPassword(prompt string) (string, error) { - for { - pass1, err := t.GetSecret(prompt, MinPasswordLen) - if err != nil { - return "", err - } - // Don't bother with the length check the 2nd time, since we check equality to pass1. - pass2, err := t.GetSecret(prompt+" again", 0) - if err != nil { - return "", err - } - if pass1 == pass2 { - return pass1, nil - } - if err := t.Error("Passwords do not match"); err != nil { - return "", err - } - } -} - -// GetConfirm asks the user for a y/n answer to a prompt. -func (t *terminalStdio) GetConfirm(prompt string) (answer bool) { - question := survey.Confirm{ - Message: prompt, - } - if err := survey.AskOne(&question, &answer, survey.WithStdio(t.Stdin, t.Stdout, t.Stderr)); err != nil { - answer = false - } - return -} diff --git a/pkg/stdio/interactive.go b/pkg/stdio/interactive.go new file mode 100644 index 0000000..76172fe --- /dev/null +++ b/pkg/stdio/interactive.go @@ -0,0 +1,100 @@ +package stdio + +import ( + "errors" + "io" + + "github.com/AlecAivazis/survey/v2" + "github.com/AlecAivazis/survey/v2/terminal" +) + +// interactiveStdio interacts with the user via an interactive terminal. +type interactiveStdio struct { + in terminal.FileReader + out terminal.FileWriter + err io.Writer +} + +func (t *interactiveStdio) Write(p []byte) (int, error) { + return t.out.Write(p) +} + +func (t *interactiveStdio) WriteErr(p []byte) (int, error) { + return t.err.Write(p) +} + +type bannerTemplateData struct { + Message string +} + +var bannerTemplate = `{{color "cyan+hb"}}> {{ .Message }}{{color "reset"}} +` + +func (t *interactiveStdio) Banner(message string) error { + r := survey.Renderer{} + r.WithStdio(terminal.Stdio{In: t.in, Out: t.out, Err: t.err}) + return r.Render(bannerTemplate, &bannerTemplateData{Message: message}) +} + +func (t *interactiveStdio) Error(message string) error { + r := survey.Renderer{} + r.WithStdio(terminal.Stdio{In: t.in, Out: t.out, Err: t.err}) + cfg := survey.PromptConfig{Icons: survey.IconSet{Error: survey.Icon{Text: "X", Format: "red"}}} + return r.Error(&cfg, errors.New(message)) +} + +func (t *interactiveStdio) IsInteractive() bool { + return true +} + +func (t *interactiveStdio) GetStringInput(prompt, defaultValue string) (input string, err error) { + question := survey.Input{ + Message: prompt, + Default: defaultValue, + } + err = survey.AskOne(&question, &input, + survey.WithStdio(t.in, t.out, t.err), + survey.WithValidator(survey.Required)) + return +} + +func (t *interactiveStdio) GetSecret(prompt string, minLen int) (password string, err error) { + question := survey.Password{Message: prompt} + opts := []survey.AskOpt{survey.WithStdio(t.in, t.out, t.err)} + if minLen > 0 { + opts = append(opts, survey.WithValidator(survey.MinLength(minLen))) + } + err = survey.AskOne(&question, &password, opts...) + question.NewCursor().HorizontalAbsolute(0) + return +} + +func (t *interactiveStdio) GetPassword(prompt string) (string, error) { + for { + pass1, err := t.GetSecret(prompt, MinPasswordLen) + if err != nil { + return "", err + } + // Don't bother with the length check the 2nd time, since we check equality to pass1. + pass2, err := t.GetSecret(prompt+" again", 0) + if err != nil { + return "", err + } + if pass1 == pass2 { + return pass1, nil + } + if err := t.Error("Passwords do not match"); err != nil { + return "", err + } + } +} + +func (t *interactiveStdio) GetConfirm(prompt string) (answer bool) { + question := survey.Confirm{ + Message: prompt, + } + if err := survey.AskOne(&question, &answer, survey.WithStdio(t.in, t.out, t.err)); err != nil { + answer = false + } + return +} diff --git a/pkg/stdio/noninteractive.go b/pkg/stdio/noninteractive.go new file mode 100644 index 0000000..e2e7d9c --- /dev/null +++ b/pkg/stdio/noninteractive.go @@ -0,0 +1,73 @@ +package stdio + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +// noninteractiveStdio interacts with stdin/stdout/stderr as files, with no user interaction. +type noninteractiveStdio struct { + in *bufio.Scanner + out io.Writer + err io.Writer +} + +func (t *noninteractiveStdio) Write(p []byte) (int, error) { + return t.out.Write(p) +} + +func (t *noninteractiveStdio) WriteErr(p []byte) (int, error) { + return t.err.Write(p) +} + +func (t *noninteractiveStdio) Banner(message string) error { + _, err := fmt.Fprintf(t.out, "> %s\n", message) + return err +} + +func (t *noninteractiveStdio) Error(message string) error { + _, err := fmt.Fprintf(t.err, "X %s\n", message) + return err +} + +func (t *noninteractiveStdio) IsInteractive() bool { + return false +} + +func (t *noninteractiveStdio) GetStringInput(prompt, defaultValue string) (string, error) { + if inLine := t.readStdinLine(); inLine != "" { + return inLine, nil + } + if defaultValue != "" { + return defaultValue, nil + } + return "", fmt.Errorf("couldn't get input for prompt %q: no data on stdin", prompt) +} + +func (t *noninteractiveStdio) GetSecret(prompt string, minLen int) (string, error) { + inLine := t.readStdinLine() + if len(inLine) >= minLen { + return inLine, nil + } else if minLen > 0 { + return "", fmt.Errorf("value for prompt %q is too short: min length is %d", prompt, minLen) + } + return "", nil +} + +func (t *noninteractiveStdio) GetPassword(prompt string) (string, error) { + return t.GetSecret(prompt, MinPasswordLen) +} + +func (t *noninteractiveStdio) GetConfirm(prompt string) (answer bool) { + return strings.HasPrefix(t.readStdinLine(), "y") +} + +// readStdinLine returns the first line of text on stdin, or empty string if stdin is at EOF. +func (t *noninteractiveStdio) readStdinLine() string { + if !t.in.Scan() { + return "" + } + return t.in.Text() +} diff --git a/pkg/stdio/stdio.go b/pkg/stdio/stdio.go index c937d9d..801129f 100644 --- a/pkg/stdio/stdio.go +++ b/pkg/stdio/stdio.go @@ -1,17 +1,52 @@ package stdio -import "io" +import ( + "bufio" + "io" + "os" + + "github.com/AlecAivazis/survey/v2/terminal" + "github.com/mattn/go-isatty" +) const MinPasswordLen = 8 type StdIO interface { - io.Writer + // Write prints some bytes to stdout. + Write(p []byte) (n int, err error) + // WriteErr prints some bytes to stderr. WriteErr(p []byte) (n int, err error) + // Banner displays informational text to the user. Banner(message string) error + // Error displays an error message to the user. Error(message string) error - InputIsInteractive() bool + // IsInteractive signals whether interactive I/O is supported. + IsInteractive() bool + // GetStringInput prompts the user for arbitrary input. GetStringInput(prompt, defaultValue string) (string, error) + // GetSecret prompts the user for a secret. GetSecret(prompt string, minLen int) (string, error) + // GetPassword prompts the user for a secret twice, and inputs must match. + // Uses stdio.MinPasswordLen as the minimum input length GetPassword(prompt string) (string, error) + // GetConfirm asks the user for a y/n answer to a prompt. GetConfirm(prompt string) bool } + +func newTerminalStdio(in terminal.FileReader, out terminal.FileWriter, err io.Writer) StdIO { + interactiveIn := isatty.IsTerminal(in.Fd()) || isatty.IsCygwinTerminal(in.Fd()) + interactiveOut := isatty.IsTerminal(out.Fd()) || isatty.IsCygwinTerminal(out.Fd()) + + if interactiveIn && interactiveOut { + return &interactiveStdio{in: in, out: out, err: err} + } + + return &noninteractiveStdio{ + in: bufio.NewScanner(in), + out: out, + err: err, + } +} + +// TerminalStdio interacts with users over stdin/stdout/stderr. +var TerminalStdio = newTerminalStdio(os.Stdin, os.Stdout, os.Stderr) diff --git a/pkg/stdio/stdio_internal_test.go b/pkg/stdio/stdio_internal_test.go new file mode 100644 index 0000000..c84068a --- /dev/null +++ b/pkg/stdio/stdio_internal_test.go @@ -0,0 +1,284 @@ +package stdio + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTerminalStdIO_Banner(t *testing.T) { + t.Parallel() + + tmp, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(tmp) + + stdin, err := os.Create(filepath.Join(tmp, "stdin")) + require.NoError(t, err) + defer stdin.Close() + stdout, err := os.Create(filepath.Join(tmp, "stdout")) + require.NoError(t, err) + defer stdout.Close() + stderr, err := os.Create(filepath.Join(tmp, "stderr")) + require.NoError(t, err) + defer stderr.Close() + + io := newTerminalStdio(stdin, stdout, stderr) + require.NoError(t, io.Banner("Hello world!")) + outBytes, err := os.ReadFile(stdout.Name()) + require.NoError(t, err) + require.Equal(t, "> Hello world!\n", string(outBytes)) +} + +func TestTerminalStdIO_Error(t *testing.T) { + t.Parallel() + + tmp, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(tmp) + + stdin, err := os.Create(filepath.Join(tmp, "stdin")) + require.NoError(t, err) + defer stdin.Close() + stdout, err := os.Create(filepath.Join(tmp, "stdout")) + require.NoError(t, err) + defer stdout.Close() + stderr, err := os.Create(filepath.Join(tmp, "stderr")) + require.NoError(t, err) + defer stderr.Close() + + io := newTerminalStdio(stdin, stdout, stderr) + require.NoError(t, io.Error("Oh no")) + errBytes, err := os.ReadFile(stderr.Name()) + require.NoError(t, err) + require.Equal(t, "X Oh no\n", string(errBytes)) +} + +func TestTerminalStdIO_GetStringInput(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + defVal string + lines []string + }{ + { + name: "empty, no default", + }, + { + name: "empty with default", + defVal: "foo", + }, + { + name: "one line", + lines: []string{"foo"}, + }, + { + name: "multi-line", + lines: []string{"foo", "bar"}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tmp, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(tmp) + + stdin, err := os.Create(filepath.Join(tmp, "stdin")) + require.NoError(t, err) + defer stdin.Close() + for _, l := range tc.lines { + _, err := stdin.WriteString(fmt.Sprintln(l)) + require.NoError(t, err) + } + _, err = stdin.Seek(0, 0) + require.NoError(t, err) + + stdout, err := os.Create(filepath.Join(tmp, "stdout")) + require.NoError(t, err) + defer stdout.Close() + stderr, err := os.Create(filepath.Join(tmp, "stderr")) + require.NoError(t, err) + defer stderr.Close() + + io := newTerminalStdio(stdin, stdout, stderr) + + if len(tc.lines) == 0 { + val, err := io.GetStringInput("my prompt", tc.defVal) + if tc.defVal != "" { + require.NoError(t, err) + require.Equal(t, tc.defVal, val) + + } else { + require.Error(t, err) + } + return + } + + for _, l := range tc.lines { + val, err := io.GetStringInput("my prompt", tc.defVal) + require.NoError(t, err) + require.Equal(t, l, val) + } + }) + } +} + +func TestTerminalStdIO_GetSecret(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + minLength int + lines []string + }{ + { + name: "empty, no min length", + }, + { + name: "empty with min length", + minLength: 1, + }, + { + name: "non-empty, too short", + minLength: 3, + lines: []string{"oh"}, + }, + { + name: "one line", + lines: []string{"foo"}, + }, + { + name: "multi-line", + minLength: 3, + lines: []string{"foo", "bar"}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tmp, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(tmp) + + stdin, err := os.Create(filepath.Join(tmp, "stdin")) + require.NoError(t, err) + defer stdin.Close() + for _, l := range tc.lines { + _, err := stdin.WriteString(fmt.Sprintln(l)) + require.NoError(t, err) + } + _, err = stdin.Seek(0, 0) + require.NoError(t, err) + + stdout, err := os.Create(filepath.Join(tmp, "stdout")) + require.NoError(t, err) + defer stdout.Close() + stderr, err := os.Create(filepath.Join(tmp, "stderr")) + require.NoError(t, err) + defer stderr.Close() + + io := newTerminalStdio(stdin, stdout, stderr) + + if len(tc.lines) == 0 { + val, err := io.GetSecret("my prompt", tc.minLength) + if tc.minLength == 0 { + require.NoError(t, err) + require.Empty(t, val) + } else { + require.Error(t, err) + } + return + } + + for _, l := range tc.lines { + val, err := io.GetSecret("my prompt", tc.minLength) + if len(l) < tc.minLength { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, l, val) + } + } + }) + } +} + +func TestTerminalStdIO_GetConfirm(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + answer string + expected bool + }{ + { + name: "empty answer", + expected: false, + }, + { + name: "short affirmative", + answer: "y", + expected: true, + }, + { + name: "short negative", + answer: "n", + }, + { + name: "long affirmative", + answer: "yes", + expected: true, + }, + { + name: "long negative", + answer: "no", + expected: false, + }, + { + name: "nonsense answer", + answer: "I dunno", + expected: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tmp, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(tmp) + + stdin, err := os.Create(filepath.Join(tmp, "stdin")) + require.NoError(t, err) + defer stdin.Close() + _, err = stdin.WriteString(fmt.Sprintln(tc.answer)) + require.NoError(t, err) + _, err = stdin.Seek(0, 0) + require.NoError(t, err) + + stdout, err := os.Create(filepath.Join(tmp, "stdout")) + require.NoError(t, err) + defer stdout.Close() + stderr, err := os.Create(filepath.Join(tmp, "stderr")) + require.NoError(t, err) + defer stderr.Close() + + io := newTerminalStdio(stdin, stdout, stderr) + confirmed := io.GetConfirm("?") + require.Equal(t, tc.expected, confirmed) + }) + } +}