210 lines
5.6 KiB
Go
210 lines
5.6 KiB
Go
// Copyright 2020 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package encrypt
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"math"
|
|
"math/big"
|
|
)
|
|
|
|
// errInvalidBlockSize is returned by NewCtrCipherWithBlockSize when the
// requested encrypt block size is rejected.
var errInvalidBlockSize = errors.New("invalid encrypt block size")

// defaultEncryptBlockSize is the encrypt block size, in bytes (1 KiB), used by
// NewCtrCipher. It is a multiple of aes.BlockSize.
const defaultEncryptBlockSize = 1 << 10
|
|
|
|
// CtrCipher holds the key material and layout parameters for encrypting data
// with AES in counter (CTR) mode. Data is processed in fixed-size encrypt
// blocks, each spanning aesBlockCount AES blocks of one continuous keystream.
type CtrCipher struct {
	// block is the keyed AES block cipher.
	block cipher.Block
	// nonce fills the high 8 bytes of every initial counter block.
	nonce uint64
	// encryptBlockSize is the size of one encrypt block, in bytes.
	encryptBlockSize int64
	// aesBlockCount is the number of AES blocks in one encrypt block
	// (encryptBlockSize / aes.BlockSize).
	aesBlockCount int64
}
|
|
|
|
// NewCtrCipher return a CtrCipher using the default encrypt block size
|
|
func NewCtrCipher() (ctr *CtrCipher, err error) {
|
|
return NewCtrCipherWithBlockSize(defaultEncryptBlockSize)
|
|
}
|
|
|
|
// NewCtrCipherWithBlockSize return a CtrCipher with the encrypt block size
|
|
func NewCtrCipherWithBlockSize(encryptBlockSize int64) (ctr *CtrCipher, err error) {
|
|
key := make([]byte, aes.BlockSize)
|
|
_, err = rand.Read(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if encryptBlockSize%aes.BlockSize != 0 {
|
|
return nil, errInvalidBlockSize
|
|
}
|
|
ctr = new(CtrCipher)
|
|
ctr.block = block
|
|
nonce, err := rand.Int(rand.Reader, big.NewInt(int64(math.MaxInt64)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ctr.nonce = nonce.Uint64()
|
|
ctr.encryptBlockSize = encryptBlockSize
|
|
ctr.aesBlockCount = encryptBlockSize / aes.BlockSize
|
|
return
|
|
}
|
|
|
|
// stream returns a cipher.Stream be use to encrypts/decrypts
|
|
func (ctr *CtrCipher) stream(counter uint64) cipher.Stream {
|
|
counterBuf := make([]byte, aes.BlockSize)
|
|
binary.BigEndian.PutUint64(counterBuf, ctr.nonce)
|
|
binary.BigEndian.PutUint64(counterBuf[8:], counter)
|
|
return cipher.NewCTR(ctr.block, counterBuf)
|
|
}
|
|
|
|
// Writer is an io.WriteCloser that buffers data, encrypts it with AES-CTR, and
// then writes the ciphertext to an underlying io.WriteCloser.
type Writer struct {
	// err is the sticky error from a failed Flush; once set, every later
	// Write, Flush and Close returns it.
	err error
	// w is the destination for the encrypted bytes.
	w io.WriteCloser
	// cipherStream is the keystream, which advances across successive flushes.
	cipherStream cipher.Stream
	// buf stages data before it is encrypted in place and flushed.
	buf []byte
	// flushedUserDataCnt counts the bytes accepted by w so far.
	flushedUserDataCnt int64
	// n is the number of bytes currently staged in buf.
	n int
}
|
|
|
|
// NewWriter returns a new Writer which encrypt data using AES before writing to the underlying object.
|
|
func NewWriter(w io.WriteCloser, ctrCipher *CtrCipher) *Writer {
|
|
writer := &Writer{w: w}
|
|
writer.buf = make([]byte, ctrCipher.encryptBlockSize)
|
|
writer.cipherStream = ctrCipher.stream(0)
|
|
return writer
|
|
}
|
|
|
|
// AvailableSize returns how many bytes are unused in the buffer.
|
|
func (w *Writer) AvailableSize() int { return len(w.buf) - w.n }
|
|
|
|
// Write implements the io.Writer interface.
|
|
func (w *Writer) Write(p []byte) (n int, err error) {
|
|
if w.err != nil {
|
|
return n, w.err
|
|
}
|
|
for len(p) > w.AvailableSize() && w.err == nil {
|
|
copiedNum := copy(w.buf[w.n:], p)
|
|
w.n += copiedNum
|
|
err = w.Flush()
|
|
if err != nil {
|
|
return
|
|
}
|
|
n += copiedNum
|
|
p = p[copiedNum:]
|
|
}
|
|
copiedNum := copy(w.buf[w.n:], p)
|
|
w.n += copiedNum
|
|
n += copiedNum
|
|
return
|
|
}
|
|
|
|
// Buffered returns the number of bytes that have been written into the current buffer.
|
|
func (w *Writer) Buffered() int { return w.n }
|
|
|
|
// Flush writes all the buffered data to the underlying object.
|
|
func (w *Writer) Flush() error {
|
|
if w.err != nil {
|
|
return w.err
|
|
}
|
|
if w.n == 0 {
|
|
return nil
|
|
}
|
|
w.cipherStream.XORKeyStream(w.buf[:w.n], w.buf[:w.n])
|
|
n, err := w.w.Write(w.buf[:w.n])
|
|
w.flushedUserDataCnt += int64(n)
|
|
if n < w.n && err == nil {
|
|
err = io.ErrShortWrite
|
|
}
|
|
if err != nil {
|
|
w.err = err
|
|
return err
|
|
}
|
|
w.n = 0
|
|
return nil
|
|
}
|
|
|
|
// GetCache returns the byte slice that holds the data not flushed to disk.
|
|
func (w *Writer) GetCache() []byte {
|
|
return w.buf[:w.n]
|
|
}
|
|
|
|
// GetCacheDataOffset return the user data offset in cache.
|
|
func (w *Writer) GetCacheDataOffset() int64 {
|
|
return w.flushedUserDataCnt
|
|
}
|
|
|
|
// Close implements the io.Closer interface.
|
|
func (w *Writer) Close() (err error) {
|
|
err = w.Flush()
|
|
if err != nil {
|
|
return
|
|
}
|
|
return w.w.Close()
|
|
}
|
|
|
|
// Reader implements an io.ReadAt, reading from the input source after decrypting.
|
|
type Reader struct {
|
|
r io.ReaderAt
|
|
cipher *CtrCipher
|
|
}
|
|
|
|
// NewReader returns a new Reader which can read from the input source after decrypting.
|
|
func NewReader(r io.ReaderAt, ctrCipher *CtrCipher) *Reader {
|
|
reader := &Reader{r: r, cipher: ctrCipher}
|
|
return reader
|
|
}
|
|
|
|
// ReadAt implements the io.ReadAt interface.
|
|
func (r *Reader) ReadAt(p []byte, off int64) (n int, err error) {
|
|
if len(p) == 0 {
|
|
return 0, nil
|
|
}
|
|
offset := off % r.cipher.encryptBlockSize
|
|
counter := (off / r.cipher.encryptBlockSize) * r.cipher.aesBlockCount
|
|
cursor := off - offset
|
|
|
|
buf := make([]byte, r.cipher.encryptBlockSize)
|
|
var readNum int
|
|
cipherStream := r.cipher.stream(uint64(counter))
|
|
for len(p) > 0 && err == nil {
|
|
readNum, err = r.r.ReadAt(buf, cursor)
|
|
if err != nil {
|
|
if readNum == 0 || err != io.EOF {
|
|
return n, err
|
|
}
|
|
err = nil
|
|
// continue if n > 0 and r.err is io.EOF
|
|
}
|
|
cursor += int64(readNum)
|
|
cipherStream.XORKeyStream(buf[:readNum], buf[:readNum])
|
|
copiedNum := copy(p, buf[offset:readNum])
|
|
n += copiedNum
|
|
p = p[copiedNum:]
|
|
offset = 0
|
|
}
|
|
return n, err
|
|
}
|