297 lines
8.8 KiB
Go
297 lines
8.8 KiB
Go
// Copyright 2022 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package util_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pingcap/tidb/util"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestInvalidTLS(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
|
|
caPath := filepath.Join(tempDir, "ca.pem")
|
|
_, err := util.NewTLS(caPath, "", "", "localhost", nil)
|
|
require.Regexp(t, "could not read ca certificate:.*", err.Error())
|
|
|
|
err = os.WriteFile(caPath, []byte("invalid ca content"), 0644)
|
|
require.NoError(t, err)
|
|
_, err = util.NewTLS(caPath, "", "", "localhost", nil)
|
|
require.Regexp(t, "failed to append ca certs", err.Error())
|
|
|
|
certPath := filepath.Join(tempDir, "test.pem")
|
|
keyPath := filepath.Join(tempDir, "test.key")
|
|
_, err = util.NewTLS(caPath, certPath, keyPath, "localhost", nil)
|
|
require.Regexp(t, "could not load client key pair: open.*", err.Error())
|
|
|
|
err = os.WriteFile(certPath, []byte("invalid cert content"), 0644)
|
|
require.NoError(t, err)
|
|
err = os.WriteFile(keyPath, []byte("invalid key content"), 0600)
|
|
require.NoError(t, err)
|
|
_, err = util.NewTLS(caPath, certPath, keyPath, "localhost", nil)
|
|
require.Regexp(t, "could not load client key pair: tls.*", err.Error())
|
|
}
|
|
|
|
func TestVerifyCommonNameAndRotate(t *testing.T) {
|
|
caData, certs, keys := generateCerts(t, []string{"server", "client1", "client2"})
|
|
serverCert, serverKey := certs[0], keys[0]
|
|
client1Cert, client1Key := certs[1], keys[1]
|
|
client2Cert, client2Key := certs[2], keys[2]
|
|
|
|
// only allow client1 to visit
|
|
serverTLS, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData),
|
|
util.WithCertAndKeyContent(serverCert, serverKey),
|
|
util.WithVerifyCommonName([]string{"client1"}),
|
|
)
|
|
require.NoError(t, err)
|
|
port := 9292
|
|
url := fmt.Sprintf("https://127.0.0.1:%d", port)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
server := runServer(ctx, serverTLS, port, t)
|
|
defer func() {
|
|
cancel()
|
|
server.Close()
|
|
}()
|
|
|
|
clientTLS1, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData),
|
|
util.WithCertAndKeyContent(client1Cert, client1Key),
|
|
)
|
|
require.NoError(t, err)
|
|
resp, err := util.ClientWithTLS(clientTLS1).Get(url)
|
|
require.NoError(t, err)
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "This an example server", string(body))
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
// client2 can't visit server
|
|
dir := t.TempDir()
|
|
certPath := filepath.Join(dir, "client.pem")
|
|
keyPath := filepath.Join(dir, "client.key")
|
|
err = os.WriteFile(certPath, client2Cert, 0600)
|
|
require.NoError(t, err)
|
|
err = os.WriteFile(keyPath, client2Key, 0600)
|
|
require.NoError(t, err)
|
|
|
|
clientTLS2, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData),
|
|
util.WithCertAndKeyPath(certPath, keyPath),
|
|
)
|
|
require.NoError(t, err)
|
|
client2 := util.ClientWithTLS(clientTLS2)
|
|
resp, err = client2.Get(url)
|
|
require.ErrorContains(t, err, "tls: bad certificate")
|
|
if resp != nil {
|
|
require.NoError(t, resp.Body.Close())
|
|
}
|
|
|
|
// test certificate rotation
|
|
err = os.WriteFile(certPath, client1Cert, 0600)
|
|
require.NoError(t, err)
|
|
err = os.WriteFile(keyPath, client1Key, 0600)
|
|
require.NoError(t, err)
|
|
|
|
resp, err = client2.Get(url)
|
|
require.NoError(t, err)
|
|
body, err = io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "This an example server", string(body))
|
|
require.NoError(t, resp.Body.Close())
|
|
}
|
|
|
|
func TestCA(t *testing.T) {
|
|
caData, certs, keys := generateCerts(t, []string{"server", "client"})
|
|
serverCert, serverKey := certs[0], keys[0]
|
|
clientCert, clientKey := certs[1], keys[1]
|
|
|
|
caData2, _, _ := generateCerts(t, nil)
|
|
|
|
serverTLS, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData),
|
|
util.WithCertAndKeyContent(serverCert, serverKey),
|
|
)
|
|
require.NoError(t, err)
|
|
port := 9293
|
|
url := fmt.Sprintf("https://127.0.0.1:%d", port)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
server := runServer(ctx, serverTLS, port, t)
|
|
defer func() {
|
|
cancel()
|
|
server.Close()
|
|
}()
|
|
|
|
// test only CA
|
|
clientTLS1, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData),
|
|
)
|
|
require.NoError(t, err)
|
|
resp, err := util.ClientWithTLS(clientTLS1).Get(url)
|
|
require.NoError(t, err)
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "This an example server", string(body))
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
// test without CA
|
|
clientTLS2, err := util.NewTLSConfig(
|
|
util.WithCertAndKeyContent(clientCert, clientKey),
|
|
)
|
|
require.NoError(t, err)
|
|
// inject CA to imitate our generated CA is a trusted CA
|
|
clientTLS2.RootCAs = clientTLS1.RootCAs
|
|
resp, err = util.ClientWithTLS(clientTLS2).Get(url)
|
|
require.NoError(t, err)
|
|
body, err = io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "This an example server", string(body))
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
// test wrong CA should fail
|
|
clientTLS3, err := util.NewTLSConfig(
|
|
util.WithCAContent(caData2),
|
|
)
|
|
require.NoError(t, err)
|
|
// inject CA to imitate our generated CA is a trusted CA
|
|
certPool := x509.NewCertPool()
|
|
ok := certPool.AppendCertsFromPEM(caData)
|
|
require.True(t, ok)
|
|
ok = certPool.AppendCertsFromPEM(caData2)
|
|
require.True(t, ok)
|
|
clientTLS3.RootCAs = certPool
|
|
resp, err = util.ClientWithTLS(clientTLS3).Get(url)
|
|
require.ErrorContains(t, err, "different CA is used")
|
|
if resp != nil {
|
|
require.NoError(t, resp.Body.Close())
|
|
}
|
|
}
|
|
|
|
// handler answers every request with a fixed plain-text body. Tests compare
// against this body to confirm that a request made it through the TLS layer.
func handler(w http.ResponseWriter, _ *http.Request) {
	const body = "This an example server"
	w.Header().Set("Content-Type", "text/plain")
	// The handler has no way to report a write failure. A failed write shows
	// up on the test's side as a missing or short response body.
	_, _ = w.Write([]byte(body))
}
|
|
|
|
func runServer(ctx context.Context, tlsCfg *tls.Config, port int, t *testing.T) *http.Server {
|
|
http.HandleFunc("/", handler)
|
|
server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: nil}
|
|
|
|
conn, err := net.Listen("tcp", server.Addr)
|
|
if err != nil {
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
tlsListener := tls.NewListener(conn, tlsCfg)
|
|
go server.Serve(tlsListener)
|
|
return server
|
|
}
|
|
|
|
// generateCerts returns the PEM contents of a CA certificate and some certificates and private keys per Common Name in
|
|
// commonNames.
|
|
// thanks to https://shaneutt.com/blog/golang-ca-and-signed-cert-go/.
|
|
func generateCerts(t *testing.T, commonNames []string) (caCert []byte, certs [][]byte, keys [][]byte) {
|
|
caPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
ca := &x509.Certificate{
|
|
SerialNumber: big.NewInt(2019),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"test"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().AddDate(10, 0, 0),
|
|
IsCA: true,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
caPEM := new(bytes.Buffer)
|
|
err = pem.Encode(caPEM, &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caBytes,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
caPrivKeyPEM := new(bytes.Buffer)
|
|
err = pem.Encode(caPrivKeyPEM, &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
caBytes = caPEM.Bytes()
|
|
|
|
for _, cn := range commonNames {
|
|
cert := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1658),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"test"},
|
|
CommonName: cn,
|
|
},
|
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().AddDate(10, 0, 0),
|
|
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
}
|
|
|
|
certPrivKey, err2 := rsa.GenerateKey(rand.Reader, 4096)
|
|
require.NoError(t, err2)
|
|
|
|
certBytes, err2 := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
|
|
require.NoError(t, err2)
|
|
|
|
certPEM := new(bytes.Buffer)
|
|
err2 = pem.Encode(certPEM, &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certBytes,
|
|
})
|
|
require.NoError(t, err2)
|
|
|
|
certPrivKeyPEM := new(bytes.Buffer)
|
|
err2 = pem.Encode(certPrivKeyPEM, &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
|
})
|
|
require.NoError(t, err2)
|
|
certs = append(certs, certPEM.Bytes())
|
|
keys = append(keys, certPrivKeyPEM.Bytes())
|
|
}
|
|
|
|
return caBytes, certs, keys
|
|
}
|