From 7d6ea73c335c6361e7fd47f30d56bf0404842b44 Mon Sep 17 00:00:00 2001 From: Daniel Moran Date: Mon, 26 Apr 2021 09:51:15 -0400 Subject: [PATCH] feat: use signal-wrapped contexts in CLI commands (#38) --- cmd/influx/main.go | 11 +++++- pkg/signals/context.go | 31 +++++++++++++++ pkg/signals/context_test.go | 79 +++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 pkg/signals/context.go create mode 100644 pkg/signals/context_test.go diff --git a/cmd/influx/main.go b/cmd/influx/main.go index 290753d..76b223b 100644 --- a/cmd/influx/main.go +++ b/cmd/influx/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "fmt" "net/http" @@ -13,6 +14,7 @@ import ( "github.com/influxdata/influx-cli/v2/internal/api" "github.com/influxdata/influx-cli/v2/internal/config" "github.com/influxdata/influx-cli/v2/internal/stdio" + "github.com/influxdata/influx-cli/v2/pkg/signals" "github.com/urfave/cli/v2" ) @@ -101,6 +103,11 @@ func newApiClient(ctx *cli.Context, cli *internal.CLI, injectToken bool) (*api.A return api.NewAPIClient(apiConfig), nil } +// standardCtx returns a context that will cancel on SIGINT and SIGTERM. +func standardCtx(ctx *cli.Context) context.Context { + return signals.WithStandardSignals(ctx.Context) +} + func main() { if len(date) == 0 { date = time.Now().UTC().Format(time.RFC3339) @@ -182,7 +189,7 @@ func main() { if err != nil { return err } - return cli.Ping(ctx.Context, client.HealthApi) + return cli.Ping(standardCtx(ctx), client.HealthApi) }, }, { @@ -253,7 +260,7 @@ func main() { if err != nil { return err } - return cli.Setup(ctx.Context, client.SetupApi, &internal.SetupParams{ + return cli.Setup(standardCtx(ctx), client.SetupApi, &internal.SetupParams{ Username: ctx.String("username"), Password: ctx.String("password"), AuthToken: ctx.String(tokenFlag), diff --git a/pkg/signals/context.go b/pkg/signals/context.go new file mode 100644 index 0000000..7541ac0 --- /dev/null +++ b/pkg/signals/context.go @@ -0,0 +1,31 @@ +package signals + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +// WithSignals returns a context that is canceled with any signal in sigs. +func WithSignals(ctx context.Context, sigs ...os.Signal) context.Context { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, sigs...) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + select { + case <-ctx.Done(): + return + case <-sigCh: + return + } + }() + return ctx +} + +// WithStandardSignals cancels the context on os.Interrupt, syscall.SIGTERM. +func WithStandardSignals(ctx context.Context) context.Context { + return WithSignals(ctx, os.Interrupt, syscall.SIGTERM) +} diff --git a/pkg/signals/context_test.go b/pkg/signals/context_test.go new file mode 100644 index 0000000..222e398 --- /dev/null +++ b/pkg/signals/context_test.go @@ -0,0 +1,79 @@ +package signals + +import ( + "context" + "fmt" + "os" + "syscall" + "testing" + "time" +) + +func ExampleWithSignals() { + ctx := WithSignals(context.Background(), syscall.SIGUSR1) + go func() { + time.Sleep(500 * time.Millisecond) // after some time SIGUSR1 is sent + // mimicking a signal from the outside + syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) + }() + + <-ctx.Done() + fmt.Println("finished") + // Output: + // finished +} + +func Example_withUnregisteredSignals() { + dctx, cancel := context.WithTimeout(context.TODO(), time.Millisecond*100) + defer cancel() + + ctx := WithSignals(dctx, syscall.SIGUSR1) + go func() { + time.Sleep(10 * time.Millisecond) // after some time SIGUSR2 is sent + // mimicking a signal from the outside, WithSignals will not handle it + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + }() + + <-ctx.Done() + fmt.Println("finished") + // Output: + // finished +} + +func TestWithSignals(t *testing.T) { + tests := []struct { + name string + ctx context.Context + sigs []os.Signal + wantSignal bool + }{ + { + name: "sending signal SIGUSR2 should exit context.", + ctx: context.Background(), + sigs: []os.Signal{syscall.SIGUSR2}, + wantSignal: true, + }, + { + name: "sending signal SIGUSR2 should NOT exit context.", + ctx: context.Background(), + sigs: []os.Signal{syscall.SIGUSR1}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := WithSignals(tt.ctx, tt.sigs...) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-ctx.Done(): + if !tt.wantSignal { + t.Errorf("unexpected exit with signal") + } + case <-timer.C: + if tt.wantSignal { + t.Errorf("expected to exit with signal but did not") + } + } + }) + } +}