diff --git a/internal/batcher/buffer_batcher.go b/internal/batcher/buffer_batcher.go
new file mode 100644
index 0000000..bd0a471
--- /dev/null
+++ b/internal/batcher/buffer_batcher.go
@@ -0,0 +1,164 @@
+package batcher
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"time"
+)
+
+const (
+	// DefaultMaxBytes is 500KB; this is typically 250 to 500 lines.
+	DefaultMaxBytes = 500000
+	// DefaultInterval will flush every 10 seconds.
+	DefaultInterval = 10 * time.Second
+)
+
+var (
+	// ErrLineTooLong is the error returned when reading a line that exceeds MaxLineLength.
+	ErrLineTooLong = errors.New("batcher: line too long")
+)
+
+// BufferBatcher batches lines of line protocol before sending them to an output.
+type BufferBatcher struct {
+	MaxFlushBytes    int           // MaxFlushBytes is the maximum number of bytes to buffer before flushing
+	MaxFlushInterval time.Duration // MaxFlushInterval is the maximum amount of time to wait before flushing
+	MaxLineLength    int           // MaxLineLength specifies the maximum length of a single line
+}
+
+// WriteBatches reads lines from r, batches them, and passes each batch to an arbitrary writeFn.
+func (b *BufferBatcher) WriteBatches(ctx context.Context, r io.Reader, writeFn func(batch []byte) error) error {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	lines := make(chan []byte)
+
+	errC := make(chan error, 2)
+	go b.write(ctx, writeFn, lines, errC)
+	go b.read(ctx, r, lines, errC)
+
+	// we loop twice to check whether both read and write have an error. if read exits
+	// cleanly, then we still want to wait for write.
+	for i := 0; i < 2; i++ {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case err := <-errC:
+			// exit immediately on the first error.
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// read will close the lines channel when there is no more data, or an error occurs.
+// it is possible for an io.Reader to block forever; the WriteBatches context can be
+// used to cancel, but it is possible there will be dangling read goroutines.
+func (b *BufferBatcher) read(ctx context.Context, r io.Reader, lines chan<- []byte, errC chan<- error) {
+	defer close(lines)
+	scanner := bufio.NewScanner(r)
+	scanner.Split(ScanLines)
+
+	maxLineLength := bufio.MaxScanTokenSize
+	if b.MaxLineLength > 0 {
+		maxLineLength = b.MaxLineLength
+	}
+	scanner.Buffer(nil, maxLineLength)
+
+	for scanner.Scan() {
+		// exit early if the context is done
+		select {
+		// NOTE: We purposefully don't use scanner.Bytes() here because it returns a slice
+		// pointing to an array which is reused / overwritten on every call to Scan().
+		case lines <- []byte(scanner.Text()):
+		case <-ctx.Done():
+			errC <- ctx.Err()
+			return
+		}
+	}
+	err := scanner.Err()
+	if errors.Is(err, bufio.ErrTooLong) {
+		err = ErrLineTooLong
+	}
+	errC <- err
+}
+
+// write buffers lines into batches and flushes them via writeFn. It finishes when
+// the lines channel is closed or the context is done. If an error occurs while
+// writing a batch, the error is sent on the errC channel and the function returns.
+func (b *BufferBatcher) write(ctx context.Context, writeFn func(batch []byte) error, lines <-chan []byte, errC chan<- error) {
+	flushInterval := b.MaxFlushInterval
+	if flushInterval == 0 {
+		flushInterval = DefaultInterval
+	}
+
+	maxBytes := b.MaxFlushBytes
+	if maxBytes == 0 {
+		maxBytes = DefaultMaxBytes
+	}
+
+	timer := time.NewTimer(flushInterval)
+	defer func() { _ = timer.Stop() }()
+
+	buf := make([]byte, 0, maxBytes)
+
+	var line []byte
+	var more = true
+	// if read closes the channel normally, exit the loop
+	for more {
+		select {
+		case line, more = <-lines:
+			if more && string(line) != "\n" {
+				buf = append(buf, line...)
+			}
+			// flush if we exceed the max bytes OR the read goroutine has finished
+			if len(buf) >= maxBytes || (!more && len(buf) > 0) {
+				timer.Reset(flushInterval)
+				if err := writeFn(buf); err != nil {
+					errC <- err
+					return
+				}
+				buf = buf[:0]
+			}
+		case <-timer.C:
+			if len(buf) > 0 {
+				timer.Reset(flushInterval)
+				if err := writeFn(buf); err != nil {
+					errC <- err
+					return
+				}
+				buf = buf[:0]
+			}
+		case <-ctx.Done():
+			errC <- ctx.Err()
+			return
+		}
+	}
+
+	errC <- nil
+}
+
+// ScanLines is a bufio.SplitFunc, passed to bufio.Scanner.Split, that splits lines of line protocol.
+func ScanLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if atEOF && len(data) == 0 {
+		return 0, nil, nil
+	}
+
+	if i := bytes.IndexByte(data, '\n'); i >= 0 {
+		// We have a full newline-terminated line.
+		return i + 1, data[0 : i+1], nil
+
+	}
+
+	// If we're at EOF, we have a final, non-terminated line. Return it.
+	if atEOF {
+		return len(data), data, nil
+	}
+
+	// Request more data.
+	return 0, nil, nil
+}
diff --git a/internal/batcher/buffer_batcher_internal_test.go b/internal/batcher/buffer_batcher_internal_test.go
new file mode 100644
index 0000000..8f601b7
--- /dev/null
+++ b/internal/batcher/buffer_batcher_internal_test.go
@@ -0,0 +1,261 @@
+package batcher
+
+import (
+	"bufio"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// errorReader implements io.Reader but always returns an error.
+type errorReader struct{}
+
+func (r *errorReader) Read(p []byte) (n int, err error) {
+	return 0, fmt.Errorf("error")
+}
+
+func TestBatcher_read(t *testing.T) {
+	type args struct {
+		cancel bool
+		r      io.Reader
+		max    int
+	}
+	tests := []struct {
+		name   string
+		args   args
+		want   []string
+		expErr error
+	}{
+		{
+			name: "reading two lines produces two lines",
+			args: args{
+				r: strings.NewReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2"),
+			},
+			want: []string{"m1,t1=v1 f1=1\n", "m2,t2=v2 f2=2"},
+		},
+		{
+			name: "canceling returns no lines",
+			args: args{
+				cancel: true,
+				r:      strings.NewReader("m1,t1=v1 f1=1"),
+			},
+			want:   nil,
+			expErr: context.Canceled,
+		},
+		{
+			name: "error from reader returns error",
+			args: args{
+				r: &errorReader{},
+			},
+			want:   nil,
+			expErr: fmt.Errorf("error"),
+		},
+		{
+			name: "error when input exceeds max line length",
+			args: args{
+				r:   strings.NewReader("m1,t1=v1 f1=1"),
+				max: 5,
+			},
+			want:   nil,
+			expErr: ErrLineTooLong,
+		},
+		{
+			name: "lines greater than MaxScanTokenSize are allowed",
+			args: args{
+				r:   strings.NewReader(strings.Repeat("a", bufio.MaxScanTokenSize+1)),
+				max: bufio.MaxScanTokenSize + 2,
+			},
+			want: []string{strings.Repeat("a", bufio.MaxScanTokenSize+1)},
+		},
+		{
+			name: "lines greater than MaxScanTokenSize by default are not allowed",
+			args: args{
+				r: strings.NewReader(strings.Repeat("a", bufio.MaxScanTokenSize+1)),
+			},
+			want:   nil,
+			expErr: ErrLineTooLong,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := context.Background()
+			var cancel context.CancelFunc
+			if tt.args.cancel {
+				ctx, cancel = context.WithCancel(ctx)
+				cancel()
+			}
+
+			b := &BufferBatcher{MaxLineLength: tt.args.max}
+			var got []string
+
+			lines := make(chan []byte)
+			errC := make(chan error, 1)
+
+			go b.read(ctx, tt.args.r, lines, errC)
+
+			if cancel == nil {
+				for line := range lines {
+					got = append(got, string(line))
+				}
+			}
+
+			err := <-errC
+			assert.Equal(t, tt.expErr, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+func TestBatcher_write(t *testing.T) {
+	type fields struct {
+		MaxFlushBytes    int
+		MaxFlushInterval time.Duration
+	}
+	type args struct {
+		cancel     bool
+		writeError bool
+		line       string
+		lines      chan []byte
+		errC       chan error
+	}
+	tests := []struct {
+		name       string
+		fields     fields
+		args       args
+		want       string
+		wantErr    bool
+		wantNoCall bool
+	}{
+		{
+			name: "sending a single line will send a line to the service",
+			fields: fields{
+				MaxFlushBytes: 1,
+			},
+			args: args{
+				line:  "m1,t1=v1 f1=1",
+				lines: make(chan []byte),
+				errC:  make(chan error),
+			},
+			want: "m1,t1=v1 f1=1",
+		},
+		{
+			name: "sending a single line and waiting for a timeout will send a line to the service",
+			fields: fields{
+				MaxFlushInterval: time.Millisecond,
+			},
+			args: args{
+				line:  "m1,t1=v1 f1=1",
+				lines: make(chan []byte),
+				errC:  make(chan error),
+			},
+			want: "m1,t1=v1 f1=1",
+		},
+		{
+			name: "writeFn returning an error stops the batcher after timeout",
+			fields: fields{
+				MaxFlushInterval: time.Millisecond,
+			},
+			args: args{
+				writeError: true,
+				line:       "m1,t1=v1 f1=1",
+				lines:      make(chan []byte),
+				errC:       make(chan error),
+			},
+			wantErr: true,
+		},
+		{
+			name: "canceling writes no data to the service",
+			fields: fields{
+				MaxFlushBytes: 1,
+			},
+			args: args{
+				cancel: true,
+				line:   "m1,t1=v1 f1=1",
+				lines:  make(chan []byte, 1),
+				errC:   make(chan error, 1),
+			},
+			wantErr:    true,
+			wantNoCall: true,
+		},
+		{
+			name: "writeFn returning an error stops the batcher",
+			fields: fields{
+				MaxFlushBytes: 1,
+			},
+			args: args{
+				writeError: true,
+				line:       "m1,t1=v1 f1=1",
+				lines:      make(chan []byte),
+				errC:       make(chan error),
+			},
+			wantErr: true,
+		},
+		{
+			name: "blank line is not sent to service",
+			fields: fields{
+				MaxFlushBytes: 1,
+			},
+			args: args{
+				line:  "\n",
+				lines: make(chan []byte),
+				errC:  make(chan error),
+			},
+			wantNoCall: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ctx := context.Background()
+			var cancel context.CancelFunc
+			if tt.args.cancel {
+				ctx, cancel = context.WithCancel(ctx)
+			}
+
+			// mock writeFn here to either return an error
+			// or capture the bytes flushed by the batcher.
+			writeCalled := false
+			b := &BufferBatcher{
+				MaxFlushBytes:    tt.fields.MaxFlushBytes,
+				MaxFlushInterval: tt.fields.MaxFlushInterval,
+			}
+			var got string
+			writeFn := func(batch []byte) error {
+				writeCalled = true
+				if tt.wantErr {
+					return errors.New("I broke")
+				}
+				got = string(batch)
+				return nil
+			}
+
+			go b.write(ctx, writeFn, tt.args.lines, tt.args.errC)
+
+			if cancel != nil {
+				cancel()
+				time.Sleep(500 * time.Millisecond)
+			}
+
+			tt.args.lines <- []byte(tt.args.line)
+			// if the max flush interval is not zero, we are testing to see
+			// if the data is flushed via the timer rather than forced by
+			// closing the channel.
+			if tt.fields.MaxFlushInterval != 0 {
+				time.Sleep(tt.fields.MaxFlushInterval * 100)
+			}
+			close(tt.args.lines)
+
+			err := <-tt.args.errC
+			require.Equal(t, tt.wantErr, err != nil)
+
+			require.Equal(t, tt.wantNoCall, !writeCalled)
+			require.Equal(t, tt.want, got)
+		})
+	}
+}
diff --git a/internal/batcher/buffer_batcher_test.go b/internal/batcher/buffer_batcher_test.go
new file mode 100644
index 0000000..2758654
--- /dev/null
+++ b/internal/batcher/buffer_batcher_test.go
@@ -0,0 +1,192 @@
+package batcher_test
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"io"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/influxdata/influx-cli/v2/internal/batcher"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestScanLines(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		want    []string
+		wantErr bool
+	}{
+		{
+			name:  "3 lines produced including their newlines",
+			input: "m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3",
+			want:  []string{"m1,t1=v1 f1=1\n", "m2,t2=v2 f2=2\n", "m3,t3=v3 f3=3"},
+		},
+		{
+			name:  "single line without newline",
+			input: "m1,t1=v1 f1=1",
+			want:  []string{"m1,t1=v1 f1=1"},
+		},
+		{
+			name:  "single line with newline",
+			input: "m1,t1=v1 f1=1\n",
+			want:  []string{"m1,t1=v1 f1=1\n"},
+		},
+		{
+			name:  "no lines",
+			input: "",
+			want:  []string{},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			scanner := bufio.NewScanner(strings.NewReader(tt.input))
+			scanner.Split(batcher.ScanLines)
+			got := []string{}
+			for scanner.Scan() {
+				got = append(got, scanner.Text())
+			}
+			err := scanner.Err()
+
+			if (err != nil) != tt.wantErr {
+				t.Errorf("ScanLines() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+// errorReader implements io.Reader but always returns an error.
+type errorReader struct{}
+
+func (r *errorReader) Read(p []byte) (n int, err error) {
+	return 0, fmt.Errorf("error")
+}
+
+func TestBatcher_WriteTo(t *testing.T) {
+	createReader := func(data string) func() io.Reader {
+		if data == "error" {
+			return func() io.Reader {
+				return &errorReader{}
+			}
+		}
+		return func() io.Reader {
+			return strings.NewReader(data)
+		}
+	}
+
+	type fields struct {
+		MaxFlushBytes    int
+		MaxFlushInterval time.Duration
+	}
+	type args struct {
+		r func() io.Reader
+	}
+	tests := []struct {
+		name        string
+		fields      fields
+		args        args
+		want        string
+		wantFlushes int
+		wantErr     bool
+	}{
+		{
+			name: "a line of line protocol is sent to the service",
+			fields: fields{
+				MaxFlushBytes: 1,
+			},
+			args: args{
+				r: createReader("m1,t1=v1 f1=1"),
+			},
+			want:        "m1,t1=v1 f1=1",
+			wantFlushes: 1,
+		},
+		{
+			name: "multiple lines cause multiple flushes",
+			fields: fields{
+				MaxFlushBytes: len([]byte("m1,t1=v1 f1=1\n")),
+			},
+			args: args{
+				r: createReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2\nm3,t3=v3 f3=3"),
+			},
+			want:        "m3,t3=v3 f3=3",
+			wantFlushes: 3,
+		},
+		{
+			name:   "errors during read return error",
+			fields: fields{},
+			args: args{
+				r: createReader("error"),
+			},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			b := &batcher.BufferBatcher{
+				MaxFlushBytes:    tt.fields.MaxFlushBytes,
+				MaxFlushInterval: tt.fields.MaxFlushInterval,
+			}
+
+			// capture each batch flushed by the batcher along
+			// with the total number of flushes.
+			var (
+				got        string
+				gotFlushes int
+			)
+			err := b.WriteBatches(context.Background(), tt.args.r(), func(batch []byte) error {
+				got = string(batch)
+				gotFlushes++
+				return nil
+			})
+			require.Equal(t, tt.wantErr, err != nil)
+			require.Equal(t, tt.wantFlushes, gotFlushes)
+			require.Equal(t, tt.want, got)
+		})
+		// run the same cases a second time to confirm WriteBatches behaves identically
+		t.Run("WriteTo_"+tt.name, func(t *testing.T) {
+			b := &batcher.BufferBatcher{
+				MaxFlushBytes:    tt.fields.MaxFlushBytes,
+				MaxFlushInterval: tt.fields.MaxFlushInterval,
+			}
+
+			// capture each batch flushed by the batcher along
+			// with the total number of flushes.
+			var (
+				got        string
+				gotFlushes int
+			)
+			err := b.WriteBatches(context.Background(), tt.args.r(), func(batch []byte) error {
+				got = string(batch)
+				gotFlushes++
+				return nil
+			})
+			require.Equal(t, tt.wantErr, err != nil)
+			require.Equal(t, tt.wantFlushes, gotFlushes)
+			require.Equal(t, tt.want, got)
+		})
+	}
+}
+
+func TestBatcher_WriteTimeout(t *testing.T) {
+	b := &batcher.BufferBatcher{}
+
+	// this mimics a reader like stdin that may never return data.
+	r, _ := io.Pipe()
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+
+	var got string
+	require.Equal(t, context.DeadlineExceeded, b.WriteBatches(ctx, r, func(batch []byte) error {
+		got = string(batch)
+		return nil
+	}))
+	require.Empty(t, got, "BufferBatcher.WriteBatches() with timeout received data")
+}
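Usage sketch (not part of the diff): a minimal example of driving the new BufferBatcher from another package in the influx-cli module (the batcher package is internal, so it cannot be imported from outside the module). The reader and the printing writeFn below are illustrative stand-ins for a real line-protocol source and write service.

package main

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/influxdata/influx-cli/v2/internal/batcher"
)

func main() {
	// In practice the reader might be os.Stdin or an opened file of line protocol.
	r := strings.NewReader("m1,t1=v1 f1=1\nm2,t2=v2 f2=2\n")

	b := &batcher.BufferBatcher{
		MaxFlushBytes:    batcher.DefaultMaxBytes,
		MaxFlushInterval: batcher.DefaultInterval,
		// MaxLineLength is left at zero, so bufio.MaxScanTokenSize is the limit.
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// writeFn receives each flushed batch; a real caller would post it to a write API.
	err := b.WriteBatches(ctx, r, func(batch []byte) error {
		fmt.Printf("flushing %d bytes\n", len(batch))
		return nil
	})
	if err != nil {
		fmt.Println("write failed:", err)
	}
}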