fix: detect non-interactive stdio and use "normal" IO operations (#204)

This commit is contained in:
Daniel Moran 2021-07-21 17:03:41 -04:00 committed by GitHub
parent a111d83b5a
commit bfd929f444
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 503 additions and 129 deletions

View File

@ -98,7 +98,7 @@ func (c Client) Apply(ctx context.Context, params *Params) error {
return fmt.Errorf("failed to check template impact: %w", err)
}
if c.StdIO.InputIsInteractive() && (len(res.Summary.MissingEnvRefs) > 0 || len(res.Summary.MissingSecrets) > 0) {
if c.StdIO.IsInteractive() && (len(res.Summary.MissingEnvRefs) > 0 || len(res.Summary.MissingSecrets) > 0) {
for _, e := range res.Summary.MissingEnvRefs {
val, err := c.StdIO.GetStringInput(fmt.Sprintf("Please provide environment reference value for key %s", e), "")
if err != nil {
@ -129,7 +129,7 @@ func (c Client) Apply(ctx context.Context, params *Params) error {
}
}
if c.StdIO.InputIsInteractive() && !params.Force {
if c.StdIO.IsInteractive() && !params.Force {
if confirmed := c.StdIO.GetConfirm("Confirm application of the above resources"); !confirmed {
return errors.New("aborted application of template")
}

View File

@ -120,18 +120,18 @@ func (mr *MockStdIOMockRecorder) GetStringInput(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStringInput", reflect.TypeOf((*MockStdIO)(nil).GetStringInput), arg0, arg1)
}
// InputIsInteractive mocks base method.
func (m *MockStdIO) InputIsInteractive() bool {
// IsInteractive mocks base method.
func (m *MockStdIO) IsInteractive() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InputIsInteractive")
ret := m.ctrl.Call(m, "IsInteractive")
ret0, _ := ret[0].(bool)
return ret0
}
// InputIsInteractive indicates an expected call of InputIsInteractive.
func (mr *MockStdIOMockRecorder) InputIsInteractive() *gomock.Call {
// IsInteractive indicates an expected call of IsInteractive.
func (mr *MockStdIOMockRecorder) IsInteractive() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InputIsInteractive", reflect.TypeOf((*MockStdIO)(nil).InputIsInteractive))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsInteractive", reflect.TypeOf((*MockStdIO)(nil).IsInteractive))
}
// Write mocks base method.

View File

@ -1,118 +0,0 @@
package stdio
import (
"errors"
"io"
"os"
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/mattn/go-isatty"
)
// terminalStdio interacts with the user via an interactive terminal.
type terminalStdio struct {
Stdin terminal.FileReader
Stdout terminal.FileWriter
Stderr io.Writer
}
// TerminalStdio interacts with users over stdin/stdout/stderr.
var TerminalStdio StdIO = &terminalStdio{
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
}
// Write prints some bytes to stdout.
func (t *terminalStdio) Write(p []byte) (int, error) {
return t.Stdout.Write(p)
}
// WriteErr prints some bytes to stderr.
func (t *terminalStdio) WriteErr(p []byte) (int, error) {
return t.Stderr.Write(p)
}
type bannerTemplateData struct {
Message string
}
var bannerTemplate = `{{color "cyan+hb"}}> {{ .Message }}{{color "reset"}}
`
// Banner displays informational text to the user.
func (t *terminalStdio) Banner(message string) error {
r := survey.Renderer{}
r.WithStdio(terminal.Stdio{In: t.Stdin, Out: t.Stdout, Err: t.Stderr})
return r.Render(bannerTemplate, &bannerTemplateData{Message: message})
}
// Error displays an error message to the user.
func (t *terminalStdio) Error(message string) error {
r := survey.Renderer{}
r.WithStdio(terminal.Stdio{In: t.Stdin, Out: t.Stdout, Err: t.Stderr})
cfg := survey.PromptConfig{Icons: survey.IconSet{Error: survey.Icon{Text: "X", Format: "red"}}}
return r.Error(&cfg, errors.New(message))
}
func (t *terminalStdio) InputIsInteractive() bool {
return isatty.IsTerminal(t.Stdin.Fd()) || isatty.IsCygwinTerminal(t.Stdin.Fd())
}
// GetStringInput prompts the user for arbitrary input.
func (t *terminalStdio) GetStringInput(prompt, defaultValue string) (input string, err error) {
question := survey.Input{
Message: prompt,
Default: defaultValue,
}
err = survey.AskOne(&question, &input,
survey.WithStdio(t.Stdin, t.Stdout, t.Stderr),
survey.WithValidator(survey.Required))
return
}
// GetSecret prompts the user for a secret.
func (t *terminalStdio) GetSecret(prompt string, minLen int) (password string, err error) {
question := survey.Password{Message: prompt}
opts := []survey.AskOpt{survey.WithStdio(t.Stdin, t.Stdout, t.Stderr)}
if minLen > 0 {
opts = append(opts, survey.WithValidator(survey.MinLength(minLen)))
}
err = survey.AskOne(&question, &password, opts...)
question.NewCursor().HorizontalAbsolute(0)
return
}
// GetPassword prompts the user for a secret twice, and inputs must match.
// Uses stdio.MinPasswordLen as the minimum input length
func (t *terminalStdio) GetPassword(prompt string) (string, error) {
for {
pass1, err := t.GetSecret(prompt, MinPasswordLen)
if err != nil {
return "", err
}
// Don't bother with the length check the 2nd time, since we check equality to pass1.
pass2, err := t.GetSecret(prompt+" again", 0)
if err != nil {
return "", err
}
if pass1 == pass2 {
return pass1, nil
}
if err := t.Error("Passwords do not match"); err != nil {
return "", err
}
}
}
// GetConfirm asks the user for a y/n answer to a prompt.
func (t *terminalStdio) GetConfirm(prompt string) (answer bool) {
question := survey.Confirm{
Message: prompt,
}
if err := survey.AskOne(&question, &answer, survey.WithStdio(t.Stdin, t.Stdout, t.Stderr)); err != nil {
answer = false
}
return
}

100
pkg/stdio/interactive.go Normal file
View File

@ -0,0 +1,100 @@
package stdio
import (
"errors"
"io"
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
)
// interactiveStdio interacts with the user via an interactive terminal.
type interactiveStdio struct {
in terminal.FileReader
out terminal.FileWriter
err io.Writer
}
func (t *interactiveStdio) Write(p []byte) (int, error) {
return t.out.Write(p)
}
func (t *interactiveStdio) WriteErr(p []byte) (int, error) {
return t.err.Write(p)
}
type bannerTemplateData struct {
Message string
}
var bannerTemplate = `{{color "cyan+hb"}}> {{ .Message }}{{color "reset"}}
`
func (t *interactiveStdio) Banner(message string) error {
r := survey.Renderer{}
r.WithStdio(terminal.Stdio{In: t.in, Out: t.out, Err: t.err})
return r.Render(bannerTemplate, &bannerTemplateData{Message: message})
}
func (t *interactiveStdio) Error(message string) error {
r := survey.Renderer{}
r.WithStdio(terminal.Stdio{In: t.in, Out: t.out, Err: t.err})
cfg := survey.PromptConfig{Icons: survey.IconSet{Error: survey.Icon{Text: "X", Format: "red"}}}
return r.Error(&cfg, errors.New(message))
}
func (t *interactiveStdio) IsInteractive() bool {
return true
}
func (t *interactiveStdio) GetStringInput(prompt, defaultValue string) (input string, err error) {
question := survey.Input{
Message: prompt,
Default: defaultValue,
}
err = survey.AskOne(&question, &input,
survey.WithStdio(t.in, t.out, t.err),
survey.WithValidator(survey.Required))
return
}
func (t *interactiveStdio) GetSecret(prompt string, minLen int) (password string, err error) {
question := survey.Password{Message: prompt}
opts := []survey.AskOpt{survey.WithStdio(t.in, t.out, t.err)}
if minLen > 0 {
opts = append(opts, survey.WithValidator(survey.MinLength(minLen)))
}
err = survey.AskOne(&question, &password, opts...)
question.NewCursor().HorizontalAbsolute(0)
return
}
func (t *interactiveStdio) GetPassword(prompt string) (string, error) {
for {
pass1, err := t.GetSecret(prompt, MinPasswordLen)
if err != nil {
return "", err
}
// Don't bother with the length check the 2nd time, since we check equality to pass1.
pass2, err := t.GetSecret(prompt+" again", 0)
if err != nil {
return "", err
}
if pass1 == pass2 {
return pass1, nil
}
if err := t.Error("Passwords do not match"); err != nil {
return "", err
}
}
}
func (t *interactiveStdio) GetConfirm(prompt string) (answer bool) {
question := survey.Confirm{
Message: prompt,
}
if err := survey.AskOne(&question, &answer, survey.WithStdio(t.in, t.out, t.err)); err != nil {
answer = false
}
return
}

View File

@ -0,0 +1,73 @@
package stdio
import (
"bufio"
"fmt"
"io"
"strings"
)
// noninteractiveStdio interacts with stdin/stdout/stderr as files, with no user interaction.
type noninteractiveStdio struct {
in *bufio.Scanner
out io.Writer
err io.Writer
}
func (t *noninteractiveStdio) Write(p []byte) (int, error) {
return t.out.Write(p)
}
func (t *noninteractiveStdio) WriteErr(p []byte) (int, error) {
return t.err.Write(p)
}
func (t *noninteractiveStdio) Banner(message string) error {
_, err := fmt.Fprintf(t.out, "> %s\n", message)
return err
}
func (t *noninteractiveStdio) Error(message string) error {
_, err := fmt.Fprintf(t.err, "X %s\n", message)
return err
}
func (t *noninteractiveStdio) IsInteractive() bool {
return false
}
func (t *noninteractiveStdio) GetStringInput(prompt, defaultValue string) (string, error) {
if inLine := t.readStdinLine(); inLine != "" {
return inLine, nil
}
if defaultValue != "" {
return defaultValue, nil
}
return "", fmt.Errorf("couldn't get input for prompt %q: no data on stdin", prompt)
}
func (t *noninteractiveStdio) GetSecret(prompt string, minLen int) (string, error) {
inLine := t.readStdinLine()
if len(inLine) >= minLen {
return inLine, nil
} else if minLen > 0 {
return "", fmt.Errorf("value for prompt %q is too short: min length is %d", prompt, minLen)
}
return "", nil
}
func (t *noninteractiveStdio) GetPassword(prompt string) (string, error) {
return t.GetSecret(prompt, MinPasswordLen)
}
func (t *noninteractiveStdio) GetConfirm(prompt string) (answer bool) {
return strings.HasPrefix(t.readStdinLine(), "y")
}
// readStdinLine returns the first line of text on stdin, or empty string if stdin is at EOF.
func (t *noninteractiveStdio) readStdinLine() string {
if !t.in.Scan() {
return ""
}
return t.in.Text()
}

View File

@ -1,17 +1,52 @@
package stdio
import "io"
import (
"bufio"
"io"
"os"
"github.com/AlecAivazis/survey/v2/terminal"
"github.com/mattn/go-isatty"
)
const MinPasswordLen = 8
type StdIO interface {
io.Writer
// Write prints some bytes to stdout.
Write(p []byte) (n int, err error)
// WriteErr prints some bytes to stderr.
WriteErr(p []byte) (n int, err error)
// Banner displays informational text to the user.
Banner(message string) error
// Error displays an error message to the user.
Error(message string) error
InputIsInteractive() bool
// IsInteractive signals whether interactive I/O is supported.
IsInteractive() bool
// GetStringInput prompts the user for arbitrary input.
GetStringInput(prompt, defaultValue string) (string, error)
// GetSecret prompts the user for a secret.
GetSecret(prompt string, minLen int) (string, error)
// GetPassword prompts the user for a secret twice, and inputs must match.
// Uses stdio.MinPasswordLen as the minimum input length
GetPassword(prompt string) (string, error)
// GetConfirm asks the user for a y/n answer to a prompt.
GetConfirm(prompt string) bool
}
func newTerminalStdio(in terminal.FileReader, out terminal.FileWriter, err io.Writer) StdIO {
interactiveIn := isatty.IsTerminal(in.Fd()) || isatty.IsCygwinTerminal(in.Fd())
interactiveOut := isatty.IsTerminal(out.Fd()) || isatty.IsCygwinTerminal(out.Fd())
if interactiveIn && interactiveOut {
return &interactiveStdio{in: in, out: out, err: err}
}
return &noninteractiveStdio{
in: bufio.NewScanner(in),
out: out,
err: err,
}
}
// TerminalStdio interacts with users over stdin/stdout/stderr.
var TerminalStdio = newTerminalStdio(os.Stdin, os.Stdout, os.Stderr)

View File

@ -0,0 +1,284 @@
package stdio
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestTerminalStdIO_Banner(t *testing.T) {
t.Parallel()
tmp, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer os.RemoveAll(tmp)
stdin, err := os.Create(filepath.Join(tmp, "stdin"))
require.NoError(t, err)
defer stdin.Close()
stdout, err := os.Create(filepath.Join(tmp, "stdout"))
require.NoError(t, err)
defer stdout.Close()
stderr, err := os.Create(filepath.Join(tmp, "stderr"))
require.NoError(t, err)
defer stderr.Close()
io := newTerminalStdio(stdin, stdout, stderr)
require.NoError(t, io.Banner("Hello world!"))
outBytes, err := os.ReadFile(stdout.Name())
require.NoError(t, err)
require.Equal(t, "> Hello world!\n", string(outBytes))
}
func TestTerminalStdIO_Error(t *testing.T) {
t.Parallel()
tmp, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer os.RemoveAll(tmp)
stdin, err := os.Create(filepath.Join(tmp, "stdin"))
require.NoError(t, err)
defer stdin.Close()
stdout, err := os.Create(filepath.Join(tmp, "stdout"))
require.NoError(t, err)
defer stdout.Close()
stderr, err := os.Create(filepath.Join(tmp, "stderr"))
require.NoError(t, err)
defer stderr.Close()
io := newTerminalStdio(stdin, stdout, stderr)
require.NoError(t, io.Error("Oh no"))
errBytes, err := os.ReadFile(stderr.Name())
require.NoError(t, err)
require.Equal(t, "X Oh no\n", string(errBytes))
}
func TestTerminalStdIO_GetStringInput(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
defVal string
lines []string
}{
{
name: "empty, no default",
},
{
name: "empty with default",
defVal: "foo",
},
{
name: "one line",
lines: []string{"foo"},
},
{
name: "multi-line",
lines: []string{"foo", "bar"},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tmp, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer os.RemoveAll(tmp)
stdin, err := os.Create(filepath.Join(tmp, "stdin"))
require.NoError(t, err)
defer stdin.Close()
for _, l := range tc.lines {
_, err := stdin.WriteString(fmt.Sprintln(l))
require.NoError(t, err)
}
_, err = stdin.Seek(0, 0)
require.NoError(t, err)
stdout, err := os.Create(filepath.Join(tmp, "stdout"))
require.NoError(t, err)
defer stdout.Close()
stderr, err := os.Create(filepath.Join(tmp, "stderr"))
require.NoError(t, err)
defer stderr.Close()
io := newTerminalStdio(stdin, stdout, stderr)
if len(tc.lines) == 0 {
val, err := io.GetStringInput("my prompt", tc.defVal)
if tc.defVal != "" {
require.NoError(t, err)
require.Equal(t, tc.defVal, val)
} else {
require.Error(t, err)
}
return
}
for _, l := range tc.lines {
val, err := io.GetStringInput("my prompt", tc.defVal)
require.NoError(t, err)
require.Equal(t, l, val)
}
})
}
}
func TestTerminalStdIO_GetSecret(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
minLength int
lines []string
}{
{
name: "empty, no min length",
},
{
name: "empty with min length",
minLength: 1,
},
{
name: "non-empty, too short",
minLength: 3,
lines: []string{"oh"},
},
{
name: "one line",
lines: []string{"foo"},
},
{
name: "multi-line",
minLength: 3,
lines: []string{"foo", "bar"},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tmp, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer os.RemoveAll(tmp)
stdin, err := os.Create(filepath.Join(tmp, "stdin"))
require.NoError(t, err)
defer stdin.Close()
for _, l := range tc.lines {
_, err := stdin.WriteString(fmt.Sprintln(l))
require.NoError(t, err)
}
_, err = stdin.Seek(0, 0)
require.NoError(t, err)
stdout, err := os.Create(filepath.Join(tmp, "stdout"))
require.NoError(t, err)
defer stdout.Close()
stderr, err := os.Create(filepath.Join(tmp, "stderr"))
require.NoError(t, err)
defer stderr.Close()
io := newTerminalStdio(stdin, stdout, stderr)
if len(tc.lines) == 0 {
val, err := io.GetSecret("my prompt", tc.minLength)
if tc.minLength == 0 {
require.NoError(t, err)
require.Empty(t, val)
} else {
require.Error(t, err)
}
return
}
for _, l := range tc.lines {
val, err := io.GetSecret("my prompt", tc.minLength)
if len(l) < tc.minLength {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, l, val)
}
}
})
}
}
func TestTerminalStdIO_GetConfirm(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
answer string
expected bool
}{
{
name: "empty answer",
expected: false,
},
{
name: "short affirmative",
answer: "y",
expected: true,
},
{
name: "short negative",
answer: "n",
},
{
name: "long affirmative",
answer: "yes",
expected: true,
},
{
name: "long negative",
answer: "no",
expected: false,
},
{
name: "nonsense answer",
answer: "I dunno",
expected: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tmp, err := os.MkdirTemp("", "")
require.NoError(t, err)
defer os.RemoveAll(tmp)
stdin, err := os.Create(filepath.Join(tmp, "stdin"))
require.NoError(t, err)
defer stdin.Close()
_, err = stdin.WriteString(fmt.Sprintln(tc.answer))
require.NoError(t, err)
_, err = stdin.Seek(0, 0)
require.NoError(t, err)
stdout, err := os.Create(filepath.Join(tmp, "stdout"))
require.NoError(t, err)
defer stdout.Close()
stderr, err := os.Create(filepath.Join(tmp, "stderr"))
require.NoError(t, err)
defer stderr.Close()
io := newTerminalStdio(stdin, stdout, stderr)
confirmed := io.GetConfirm("?")
require.Equal(t, tc.expected, confirmed)
})
}
}