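// Source listing: tidb/vendor/github.com/blacktear23/go-proxyprotocol/proxy_protocol.go
// (file-browser metadata: 2018-01-08 11:00:06 +08:00, 367 lines, 9.8 KiB, Go)
//
// This comment is separated from the package clause by a blank line so it
// is not treated as package documentation.
//
//
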
package proxyprotocol
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"strings"
"sync/atomic"
"time"
"errors"
)
// Ref: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt .

// proxyProtocolV1MaxHeaderLen is the size of the buffer used for the first
// header read (the spec's bound for a full v1 line plus a trailing NUL).
const proxyProtocolV1MaxHeaderLen = 108

// Header versions reported by readHeader.
const (
	unknownProtocol = 0
	proxyProtocolV1 = 1
	proxyProtocolV2 = 2
)

// Byte offsets inside a v2 header: the 12-byte signature is followed by the
// version/command byte, the address-family byte, a 2-byte big-endian length
// and then the address block.
const (
	v2CmdPos   = 12
	v2FamlyPos = 13
	v2LenPos   = 14
	v2AddrsPos = 16
)
var (
	// ErrProxyProtocolV1HeaderInvalid is returned when a v1 (text) header is malformed.
	ErrProxyProtocolV1HeaderInvalid = errors.New("PROXY Protocol v1 header is invalid")
	// ErrProxyProtocolV2HeaderInvalid is returned when a v2 (binary) header is malformed.
	ErrProxyProtocolV2HeaderInvalid = errors.New("PROXY Protocol v2 header is invalid")
	// ErrProxyAddressNotAllowed is returned when the connecting peer is not in
	// the listener's allowed address list.
	ErrProxyAddressNotAllowed = errors.New("Proxy address is not allowed")
	// ErrHeaderReadTimeout is returned when reading the header from the
	// connection fails; every read error during the header is reported this way.
	ErrHeaderReadTimeout = errors.New("Header read timeout")
)

// proxyProtocolV2Sig is the fixed 12-byte signature that starts every v2
// header (0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A).
var proxyProtocolV2Sig = []byte("\r\n\r\n\x00\r\nQUIT\n")

// Compile-time checks that the wrappers satisfy the net interfaces.
var (
	_ net.Conn     = (*proxyProtocolConn)(nil)
	_ net.Listener = (*proxyProtocolListener)(nil)
)
type (
	// connErr carries one outcome to Accept: either a connection whose
	// header was accepted or the error to report.
	connErr struct {
		conn net.Conn
		err  error
	}

	// proxyProtocolListener wraps a net.Listener and requires a PROXY
	// protocol header on every accepted connection.
	proxyProtocolListener struct {
		listener          net.Listener
		allowAll          bool          // true when allowedIPs was "*"
		allowedNets       []*net.IPNet  // peers allowed to send a header
		headerReadTimeout int           // Unit is second
		acceptQueue       chan *connErr // outcomes handed to Accept
		runningFlag       int32         // 1 until Close; accessed atomically
	}
)
// IsProxyProtocolError reports whether err is one of the sentinel errors this
// package returns for a rejected connection: an invalid v1 or v2 header, a
// disallowed peer address, or a failed header read. Only direct equality is
// checked; wrapped errors are not unwrapped.
func IsProxyProtocolError(err error) bool {
	switch err {
	case ErrProxyProtocolV1HeaderInvalid,
		ErrProxyProtocolV2HeaderInvalid,
		ErrProxyAddressNotAllowed,
		ErrHeaderReadTimeout:
		return true
	default:
		return false
	}
}
// NewListener creates a PROXY protocol listener on top of listener and starts
// its background accept loop.
//   - listener is the underlying TCP listener.
//   - allowedIPs is a comma-separated list of addresses or CIDRs that are
//     allowed to send a PROXY header; "*" allows any address.
//   - headerReadTimeout is the time, in seconds, allowed for reading the
//     PROXY header of each connection.
//
// If allowedIPs cannot be parsed, NewListener returns a nil net.Listener
// together with the parse error.
func NewListener(listener net.Listener, allowedIPs string, headerReadTimeout int) (net.Listener, error) {
	ppl, err := newListener(listener, allowedIPs, headerReadTimeout)
	if err != nil {
		// Return an untyped nil: returning the nil *proxyProtocolListener
		// would yield a non-nil net.Listener interface value.
		return nil, err
	}
	go ppl.acceptLoop()
	return ppl, nil
}
// newListener parses allowedIPs and builds the listener without starting its
// accept loop.
//
// allowedIPs is either "*" (allow every peer) or a comma-separated list of
// CIDRs and/or bare addresses. A bare address is treated as a single-host
// network: /32 for IPv4 and /128 for IPv6. The first entry that cannot be
// parsed is returned as an error.
func newListener(listener net.Listener, allowedIPs string, headerReadTimeout int) (*proxyProtocolListener, error) {
	allowAll := false
	allowedNets := []*net.IPNet{}
	if allowedIPs == "*" {
		allowAll = true
	} else {
		for _, aip := range strings.Split(allowedIPs, ",") {
			saip := strings.TrimSpace(aip)
			_, ipnet, err := net.ParseCIDR(saip)
			if err == nil {
				allowedNets = append(allowedNets, ipnet)
				continue
			}
			// The host prefix length depends on the address family. A
			// "/32" suffix on an IPv6 address would allow a whole /32
			// network (2^96 addresses) rather than the single host.
			hostBits := 32
			if strings.Contains(saip, ":") {
				hostBits = 128
			}
			_, ipnet, err = net.ParseCIDR(fmt.Sprintf("%s/%d", saip, hostBits))
			if err != nil {
				return nil, err
			}
			allowedNets = append(allowedNets, ipnet)
		}
	}
	return &proxyProtocolListener{
		listener:          listener,
		allowAll:          allowAll,
		allowedNets:       allowedNets,
		headerReadTimeout: headerReadTimeout,
		acceptQueue:       make(chan *connErr, 1),
		runningFlag:       1,
	}, nil
}
// checkAllowed reports whether raddr may connect. Every address is accepted
// when the listener was built with "*". Otherwise only *net.TCPAddr values
// inside one of the configured networks are accepted.
func (l *proxyProtocolListener) checkAllowed(raddr net.Addr) bool {
	if l.allowAll {
		return true
	}
	if tcpAddr, isTCP := raddr.(*net.TCPAddr); isTCP {
		for _, network := range l.allowedNets {
			if network.Contains(tcpAddr.IP) {
				return true
			}
		}
	}
	return false
}
// createProxyProtocolConn wraps conn and consumes its PROXY protocol header.
// If the header cannot be read or parsed, the wrapped connection (and
// therefore conn) is closed and the header error is returned.
func (l *proxyProtocolListener) createProxyProtocolConn(conn net.Conn) (*proxyProtocolConn, error) {
	wrapped := &proxyProtocolConn{Conn: conn, headerReadTimeout: l.headerReadTimeout}
	if err := wrapped.readClientAddrBehindProxy(conn.RemoteAddr()); err != nil {
		wrapped.Close()
		return nil, err
	}
	return wrapped, nil
}
// running reports whether runningFlag is still 1, i.e. Close has not been
// called yet. The flag is read atomically because Close may run concurrently.
func (l *proxyProtocolListener) running() bool {
	return atomic.LoadInt32(&l.runningFlag) == 1
}
// acceptLoop accepts connections from the underlying listener until Close is
// called.
//   - Accept errors are forwarded to acceptQueue as-is.
//   - Peers outside the allowed list are closed immediately and reported as
//     ErrProxyAddressNotAllowed (only while the listener is running).
//   - Every other connection has its header read by wrapConn in a separate
//     goroutine, so slow headers do not block further accepts.
func (l *proxyProtocolListener) acceptLoop() {
	for l.running() {
		conn, err := l.listener.Accept()
		if err != nil {
			l.acceptQueue <- &connErr{conn, err}
			// Accept closes acceptQueue when it receives this error, so
			// sending again would panic. Stop even if Close was not called,
			// e.g. when the wrapped listener was closed directly.
			if opErr, ok := err.(*net.OpError); ok && opErr.Err != nil &&
				opErr.Err.Error() == "use of closed network connection" {
				return
			}
			continue
		}
		if !l.checkAllowed(conn.RemoteAddr()) {
			// Never read a header from a disallowed peer, whether or not
			// the listener is still running.
			conn.Close()
			if l.running() {
				l.acceptQueue <- &connErr{nil, ErrProxyAddressNotAllowed}
			}
			continue
		}
		go l.wrapConn(conn)
	}
}
// wrapConn reads the PROXY header of conn and hands the outcome to Accept
// through acceptQueue. acceptLoop runs it in its own goroutine.
//
// On a header error, createProxyProtocolConn has already closed conn. The
// error is queued with an untyped nil conn, so Accept never returns a non-nil
// net.Conn that wraps a nil pointer. A successful connection that is produced
// after Close is closed instead of queued.
//
// NOTE(review): the running() check and the channel send are not atomic with
// respect to Close/Accept closing acceptQueue, so a send can still race with
// that close; fixing this needs a lock shared with Accept.
func (l *proxyProtocolListener) wrapConn(conn net.Conn) {
	wconn, err := l.createProxyProtocolConn(conn)
	if err != nil {
		if l.running() {
			l.acceptQueue <- &connErr{nil, err}
		}
		return
	}
	if l.running() {
		l.acceptQueue <- &connErr{wconn, nil}
		return
	}
	wconn.Close()
}
// Accept waits for and returns the next connection whose PROXY protocol
// header was read successfully. Its RemoteAddr is the client address taken
// from the header.
//
// Always check the returned error instead of panicking on it. As the PROXY
// protocol spec requires, when an invalid header is received or the peer's
// address is not allowed, Accept returns an error and that connection is
// closed. Whenever an error is returned, the returned net.Conn is nil.
//
// When the underlying listener reports "use of closed network connection",
// the internal queue is closed. Every later call returns an equivalent
// *net.OpError instead of blocking or panicking.
func (l *proxyProtocolListener) Accept() (net.Conn, error) {
	ce, ok := <-l.acceptQueue
	if !ok {
		// A closed channel yields nil; dereferencing it would panic.
		opErr := &net.OpError{Op: "accept", Err: errors.New("use of closed network connection")}
		if addr := l.listener.Addr(); addr != nil {
			opErr.Net = addr.Network()
			opErr.Addr = addr
		}
		return nil, opErr
	}
	if opErr, isOpErr := ce.err.(*net.OpError); isOpErr && opErr.Err != nil {
		if opErr.Err.Error() == "use of closed network connection" {
			close(l.acceptQueue)
		}
	}
	if ce.err != nil {
		// Do not return a possibly typed-nil connection alongside an error.
		return nil, ce.err
	}
	return ce.conn, nil
}
// Close marks the listener as stopped, so acceptLoop exits after its current
// Accept returns, and then closes the underlying listener, returning that
// listener's Close error.
func (l *proxyProtocolListener) Close() error {
	atomic.StoreInt32(&l.runningFlag, 0)
	return l.listener.Close()
}
// Addr returns the network address of the underlying listener.
func (l *proxyProtocolListener) Addr() net.Addr {
	addr := l.listener.Addr()
	return addr
}
// proxyProtocolConn is an accepted connection whose PROXY header has been
// consumed. It reports the client address from the header as RemoteAddr and
// replays any bytes read past the header before reading from the network.
type proxyProtocolConn struct {
	// The underlying connection; methods not overridden here delegate to it.
	net.Conn

	// Seconds allowed for reading the PROXY header.
	headerReadTimeout int
	// Client address announced by the header (or the peer address fallback).
	clientIP net.Addr

	// Bytes received together with the header that still belong to the
	// application stream, and the read position within them.
	exceedBuffer       []byte
	exceedBufferStart  int
	exceedBufferLen    int
	exceedBufferReaded bool // set once the buffered bytes are exhausted
}
// readClientAddrBehindProxy consumes the PROXY header and records the client
// address. connRemoteAddr is the fallback used when the header does not name
// a client.
func (c *proxyProtocolConn) readClientAddrBehindProxy(connRemoteAddr net.Addr) error {
	err := c.parseHeader(connRemoteAddr)
	return err
}
// parseHeader reads the PROXY header and stores the announced client address
// in c.clientIP. connRemoteAddr is passed to the version-specific parser as
// the fallback address. On error, c.clientIP is left untouched.
func (c *proxyProtocolConn) parseHeader(connRemoteAddr net.Addr) error {
	version, header, err := c.readHeader()
	if err != nil {
		return err
	}
	var addr net.Addr
	switch version {
	case proxyProtocolV1:
		addr, err = c.extractClientIPV1(header, connRemoteAddr)
	case proxyProtocolV2:
		addr, err = c.extraceClientIPV2(header, connRemoteAddr)
	default:
		// readHeader only pairs a nil error with v1 or v2.
		panic("Should not come here")
	}
	if err != nil {
		return err
	}
	c.clientIP = addr
	return nil
}
// extractClientIPV1 parses a PROXY protocol v1 header line, including its
// trailing CRLF, and returns the client (source) address it announces:
//
//	PROXY <TCP4|TCP6> <src ip> <dst ip> <src port> <dst port>\r\n
//	PROXY UNKNOWN[ anything]\r\n
//
// UNKNOWN headers return connRemoteAddr, the directly connected peer; the
// spec tells receivers to ignore whatever follows UNKNOWN. Every malformed
// header yields ErrProxyProtocolV1HeaderInvalid.
//
// The source address is parsed strictly, instead of with net.ResolveTCPAddr,
// for two reasons: header text never triggers a DNS lookup, and every
// failure is reported as a PROXY protocol error that IsProxyProtocolError
// recognizes.
func (c *proxyProtocolConn) extractClientIPV1(buffer []byte, connRemoteAddr net.Addr) (net.Addr, error) {
	header := strings.TrimSuffix(string(buffer), "\r\n")
	parts := strings.Split(header, " ")
	if len(parts) < 2 || parts[0] != "PROXY" {
		return nil, ErrProxyProtocolV1HeaderInvalid
	}
	if parts[1] == "UNKNOWN" {
		return connRemoteAddr, nil
	}
	if len(parts) != 6 {
		return nil, ErrProxyProtocolV1HeaderInvalid
	}
	clientIPStr, clientPortStr := parts[2], parts[4]

	ip := net.ParseIP(clientIPStr)
	if ip == nil {
		return nil, ErrProxyProtocolV1HeaderInvalid
	}
	isV6Literal := strings.Contains(clientIPStr, ":")
	switch parts[1] {
	case "TCP4":
		if isV6Literal {
			return nil, ErrProxyProtocolV1HeaderInvalid
		}
		ip = ip.To4()
	case "TCP6":
		if !isV6Literal {
			return nil, ErrProxyProtocolV1HeaderInvalid
		}
	default:
		return nil, ErrProxyProtocolV1HeaderInvalid
	}

	// The port must be a decimal number in 0-65535.
	if len(clientPortStr) == 0 || len(clientPortStr) > 5 {
		return nil, ErrProxyProtocolV1HeaderInvalid
	}
	port := 0
	for _, ch := range clientPortStr {
		if ch < '0' || ch > '9' {
			return nil, ErrProxyProtocolV1HeaderInvalid
		}
		port = port*10 + int(ch-'0')
	}
	if port > 65535 {
		return nil, ErrProxyProtocolV1HeaderInvalid
	}
	return &net.TCPAddr{IP: ip, Port: port}, nil
}
// extraceClientIPV2 extracts the client (source) address from a complete
// PROXY protocol v2 header: the 16-byte fixed part followed by exactly the
// address/TLV bytes announced in its length field.
//
// The LOCAL command returns connRemoteAddr unchanged. So does the PROXY
// command with an address family other than TCP over IPv4/IPv6. An unknown
// command, or an address block too short for the declared family, yields
// ErrProxyProtocolV2HeaderInvalid. The returned IP is a copy and does not
// alias buffer.
func (c *proxyProtocolConn) extraceClientIPV2(buffer []byte, connRemoteAddr net.Addr) (net.Addr, error) {
	verCmd := buffer[v2CmdPos]
	famly := buffer[v2FamlyPos]
	switch verCmd & 0x0F {
	case 0x01: /* PROXY command */
	case 0x00: /* LOCAL command */
		// keep local connection address for LOCAL
		return connRemoteAddr, nil
	default:
		// not a supported command
		return nil, ErrProxyProtocolV2HeaderInvalid
	}

	// Only the declared bytes belong to the header. buffer may be a prefix
	// of a larger read buffer, so slicing past len(buffer) can succeed and
	// would silently read application data; check the length explicitly.
	addrs := buffer[v2AddrsPos:]
	switch famly {
	case 0x11: /* TCPv4: src addr(4) dst addr(4) src port(2) dst port(2) */
		if len(addrs) < 12 {
			return nil, ErrProxyProtocolV2HeaderInvalid
		}
		srcAddrV4 := make(net.IP, net.IPv4len)
		copy(srcAddrV4, addrs[0:4])
		return &net.TCPAddr{
			IP:   srcAddrV4,
			Port: int(binary.BigEndian.Uint16(addrs[8:10])),
		}, nil
	case 0x21: /* TCPv6: src addr(16) dst addr(16) src port(2) dst port(2) */
		if len(addrs) < 36 {
			return nil, ErrProxyProtocolV2HeaderInvalid
		}
		srcAddrV6 := make(net.IP, net.IPv6len)
		copy(srcAddrV6, addrs[0:16])
		return &net.TCPAddr{
			IP:   srcAddrV6,
			Port: int(binary.BigEndian.Uint16(addrs[32:34])),
		}, nil
	default:
		// unsupported protocol, keep local connection address
		return connRemoteAddr, nil
	}
}
// RemoteAddr returns the client address recorded from the PROXY header,
// rather than the proxy's own address. For headers that carry no usable
// client address, this is the directly connected peer.
func (c *proxyProtocolConn) RemoteAddr() net.Addr {
	addr := c.clientIP
	return addr
}
// Read reads application data from the connection. Bytes received together
// with the PROXY header (buffered by readHeader) are returned first. Once
// they are consumed, Read delegates to the underlying connection.
//
// Buffered bytes are returned without touching the network, which the
// io.Reader contract allows (short reads). Blocking in c.Conn.Read while
// holding bytes the caller has not seen could deadlock a peer that sent its
// whole request in the same segment as the header.
func (c *proxyProtocolConn) Read(buffer []byte) (int, error) {
	if c.exceedBufferReaded {
		return c.Conn.Read(buffer)
	}
	if c.exceedBufferStart >= c.exceedBufferLen {
		// Everything buffered has been delivered; release the buffer.
		c.exceedBufferReaded = true
		c.exceedBuffer = nil
		return c.Conn.Read(buffer)
	}
	n := copy(buffer, c.exceedBuffer[c.exceedBufferStart:c.exceedBufferLen])
	c.exceedBufferStart += n
	return n, nil
}

// readHeader reads a PROXY protocol header from the underlying connection.
// It returns the detected version (proxyProtocolV1 or proxyProtocolV2) and
// the raw header bytes. A v1 header includes its CRLF; a v2 header is the
// 16-byte fixed part plus its declared length. Bytes read beyond the header
// are stored in c.exceedBuffer[:c.exceedBufferLen] for Read to return first.
//
// The header may arrive split across several TCP segments, so the
// connection is read repeatedly until the header is complete. All of it must
// arrive within headerReadTimeout seconds. Any read error, including deadline
// expiry and EOF, is reported as ErrHeaderReadTimeout. v2 headers longer than
// proxyProtocolV1MaxHeaderLen (e.g. carrying TLVs) are supported by growing
// the buffer.
func (c *proxyProtocolConn) readHeader() (int, []byte, error) {
	// The whole header must be read within headerReadTimeout seconds.
	// SetReadDeadline errors are ignored: on an unusable connection the
	// reads below fail anyway.
	c.Conn.SetReadDeadline(time.Now().Add(time.Duration(c.headerReadTimeout) * time.Second))
	// Clear the deadline on return so it does not apply to application reads.
	defer c.Conn.SetReadDeadline(time.Time{})

	buf := make([]byte, proxyProtocolV1MaxHeaderLen)
	n := 0 // number of valid bytes in buf

	// fill reads until at least want bytes are buffered, growing buf when
	// want exceeds its size.
	fill := func(want int) error {
		if want > len(buf) {
			grown := make([]byte, want)
			copy(grown, buf[:n])
			buf = grown
		}
		for n < want {
			m, err := c.Conn.Read(buf[n:])
			n += m
			if err != nil && n < want {
				return ErrHeaderReadTimeout
			}
		}
		return nil
	}
	// matchesSoFar reports whether the bytes received so far agree with sig.
	matchesSoFar := func(sig []byte) bool {
		k := n
		if k > len(sig) {
			k = len(sig)
		}
		return bytes.Equal(buf[:k], sig[:k])
	}

	if err := fill(1); err != nil {
		return unknownProtocol, nil, err
	}

	if matchesSoFar(proxyProtocolV2Sig) {
		// v2: signature, ver/cmd, family and 2-byte length, then the body.
		if err := fill(v2AddrsPos); err != nil {
			return unknownProtocol, nil, err
		}
		if !bytes.Equal(buf[:len(proxyProtocolV2Sig)], proxyProtocolV2Sig) || buf[v2CmdPos]&0xF0 != 0x20 {
			return unknownProtocol, nil, ErrProxyProtocolV2HeaderInvalid
		}
		endPos := v2AddrsPos + int(binary.BigEndian.Uint16(buf[v2LenPos:v2LenPos+2]))
		if err := fill(endPos); err != nil {
			return unknownProtocol, nil, err
		}
		c.exceedBuffer = buf[endPos:n]
		c.exceedBufferLen = n - endPos
		return proxyProtocolV2, buf[:endPos], nil
	}

	// v1: a single text line "PROXY ...\r\n".
	v1Sig := []byte("PROXY")
	for {
		if !matchesSoFar(v1Sig) {
			return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
		}
		if pos := bytes.IndexByte(buf[:n], '\n'); pos >= 0 {
			// The prefix check above rules out a '\n' within "PROXY",
			// so pos >= len(v1Sig) and buf[pos-1] is in range.
			if buf[pos-1] != '\r' {
				return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
			}
			endPos := pos + 1
			c.exceedBuffer = buf[endPos:n]
			c.exceedBufferLen = n - endPos
			return proxyProtocolV1, buf[:endPos], nil
		}
		if n >= len(buf) {
			// No line terminator within the maximum v1 header length.
			return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
		}
		if err := fill(n + 1); err != nil {
			return unknownProtocol, nil, err
		}
	}
}