mirror of
https://github.com/rclone/rclone.git
synced 2025-04-16 16:18:52 +08:00
Merge 7b3cf413986c15db004e697e7e40314a8478f9b2 into 0b9671313b14ffe839ecbd7dd2ae5ac7f6f05db8
This commit is contained in:
commit
171e11db09
@ -2550,6 +2550,70 @@ The `--client-key` flag is required too when using this.
|
||||
This loads the PEM encoded client side private key used for mutual TLS
|
||||
authentication. Used in conjunction with `--client-cert`.
|
||||
|
||||
### --wincrypt string
|
||||
|
||||
Available on Windows only.
|
||||
|
||||
Load certificate/key pairs from Windows certificate store using
|
||||
Cryptography API: Next Generation (CNG) for
|
||||
[mutual TLS authentication](https://en.wikipedia.org/wiki/Mutual_authentication),
|
||||
mainly designed for Smartcard scenario.
|
||||
|
||||
Certificates expired or not yet valid will be filtered out.
|
||||
|
||||
The input string is case-insensitive in most cases except said otherwise.
|
||||
Spaces at the beginning and end of the string will be discarded.
|
||||
The string must be one of the following format:
|
||||
|
||||
- `all`
|
||||
|
||||
Means load all possible certificate/key pair from user personal
|
||||
certificate store.
|
||||
|
||||
- `select`
|
||||
|
||||
Means open a certificate selection menu for user to choose one of
|
||||
certificates desired.
|
||||
|
||||
- `hash:`*`hashstr1`*`,`*`hashstr2`*`,...`
|
||||
|
||||
Means choose certificates using the specified hex hash strings. May not
|
||||
be a empty hex hash string. Multiple hash strings are sperated with `,`
|
||||
character. Hash strings referring to the same certificates are ignored.
|
||||
Spaces at the beginning and end of a single hash string are discarded.
|
||||
Hash strings with spaces in the middle of them will not be accepted.
|
||||
Supported hash algorithms are MD5, SHA-1, SHA-256, SHA-384 and SHA-512.
|
||||
Any hash string specified other than the supported one will not be accepted.
|
||||
Hash string matches no certificate is accepted. If any unaccepted string
|
||||
found, no certificate/key pair will be loaded and result in error.
|
||||
|
||||
- *OpenSSL style certificate subject string*
|
||||
|
||||
Match certificate by specified subject string in OpenSSL style. Attribute
|
||||
values are case sensitive. Special characters `/+\=` among the string can
|
||||
be escaped by `\` character. Spaces in the subject string is keeped as is.
|
||||
Attributes with multiple value can be specified by sperating them
|
||||
with `+` character. Since spaces in begin and end of the input string are
|
||||
discarded, if the last value of the last attribute value has spaces in the
|
||||
end, you should add postfix `/` at the end of the input string. Attribute
|
||||
name can not be empty but values can be. The matching rules described as
|
||||
follows:
|
||||
|
||||
- Unknown attribute is ignored.
|
||||
- Empty attribute means the attribute always matches regardless of whether
|
||||
it's present and the value in the certificate.
|
||||
- For non-empty single value attribute, it matches if the attribute value
|
||||
of the certificate contains the value specified in the subject string.
|
||||
- For non-empty attribute with multiple values, it matches if all values
|
||||
of the attribute specified in the subject string can be found and correspond
|
||||
to one of the values in the corresponding certificate fields. For example,
|
||||
let there be a certificate with subject `/OU=O1+O2` and the specified subject
|
||||
string is `/OU=O1+O1` it will not match as one of the values of organization
|
||||
unit specified (the second "O1") does not correspond to one of organization
|
||||
unit values in the certificate subject.
|
||||
- If a attribute is specified multiple times, if any of its value
|
||||
(or values, if it can have multiple values) matches, the attribute matches.
|
||||
|
||||
### --no-check-certificate=true/false ###
|
||||
|
||||
`--no-check-certificate` controls whether a client verifies the
|
||||
|
@ -570,6 +570,10 @@ uploads or downloads to limit the number of total connections.
|
||||
`, "|", "`"),
|
||||
Default: 0,
|
||||
Advanced: true,
|
||||
}, {
|
||||
Name: "wincrypt",
|
||||
Help: "Load certificate/key pairs from Windows certificate store for mutual TLS authentication",
|
||||
Groups: "Networking",
|
||||
}}
|
||||
|
||||
// ConfigInfo is filesystem config options
|
||||
@ -681,6 +685,7 @@ type ConfigInfo struct {
|
||||
PartialSuffix string `config:"partial_suffix"`
|
||||
MetadataMapper SpaceSepList `config:"metadata_mapper"`
|
||||
MaxConnections int `config:"max_connections"`
|
||||
WinCrypt string `config:"wincrypt"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/fs/accounting"
|
||||
"github.com/rclone/rclone/lib/structs"
|
||||
"github.com/rclone/rclone/lib/wincrypt"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
@ -68,6 +69,17 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
|
||||
}
|
||||
|
||||
// Load client certs
|
||||
if ci.WinCrypt != "" {
|
||||
crypts, err := wincrypt.LoadWinCryptCerts(ci.WinCrypt)
|
||||
if err != nil || crypts == nil {
|
||||
fs.Panicf(nil, "Failed to load WinCrypt certificate: %v", err)
|
||||
} else {
|
||||
for _, crypt := range crypts {
|
||||
t.TLSClientConfig.Certificates = append(t.TLSClientConfig.Certificates, crypt.TLSCertificate())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ci.ClientCert != "" || ci.ClientKey != "" {
|
||||
if ci.ClientCert == "" || ci.ClientKey == "" {
|
||||
fs.Fatalf(nil, "Both --client-cert and --client-key must be set")
|
||||
@ -83,7 +95,7 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
|
||||
fs.Fatalf(nil, "Failed to parse the certificate")
|
||||
}
|
||||
}
|
||||
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
t.TLSClientConfig.Certificates = append(t.TLSClientConfig.Certificates, cert)
|
||||
}
|
||||
|
||||
// Load CA certs
|
||||
@ -297,20 +309,22 @@ func (t *Transport) reloadCertificates() {
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
|
||||
if err != nil {
|
||||
fs.Fatalf(nil, "Failed to load --client-cert/--client-key pair: %v", err)
|
||||
}
|
||||
// Check if we need to parse the certificate again, we need it
|
||||
// for checking the expiration date
|
||||
if cert.Leaf == nil {
|
||||
// Leaf is always the first certificate
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if t.clientCert != "" && t.clientKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
|
||||
if err != nil {
|
||||
fs.Fatalf(nil, "Failed to parse the certificate")
|
||||
fs.Fatalf(nil, "Failed to load --client-cert/--client-key pair: %v", err)
|
||||
}
|
||||
// Check if we need to parse the certificate again, we need it
|
||||
// for checking the expiration date
|
||||
if cert.Leaf == nil {
|
||||
// Leaf is always the first certificate
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
fs.Fatalf(nil, "Failed to parse the certificate")
|
||||
}
|
||||
}
|
||||
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
|
21
lib/wincrypt/export_test.go
Normal file
21
lib/wincrypt/export_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package wincrypt
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
|
||||
NCrypt = ncrypt
|
||||
Crypt32 = crypt32
|
||||
)
|
||||
|
||||
func NCryptFreeObject(obj syscall.Handle) error {
|
||||
return ncryptFreeObject(obj)
|
||||
}
|
||||
|
||||
func WrapError(prefix string, err error, args ...any) error {
|
||||
return wrapError(prefix, err, args...)
|
||||
}
|
35
lib/wincrypt/rdnattrtype_string.go
Normal file
35
lib/wincrypt/rdnattrtype_string.go
Normal file
@ -0,0 +1,35 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
// Code generated by "stringer -type=rdnAttrType"; DO NOT EDIT.
|
||||
|
||||
package wincrypt
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[COMMONNAME-0]
|
||||
_ = x[SERIALNUMBER-1]
|
||||
_ = x[COUNTRYNAME-2]
|
||||
_ = x[LOCALITYNAME-3]
|
||||
_ = x[STATEORPROVINCENAME-4]
|
||||
_ = x[STREETADDRESS-5]
|
||||
_ = x[ORGANIZATIONNAME-6]
|
||||
_ = x[ORGANIZATIONALUNITNAME-7]
|
||||
_ = x[POSTALCODE-8]
|
||||
_ = x[ATTRUNSPECIFIED-9]
|
||||
}
|
||||
|
||||
const _rdnAttrType_name = "COMMONNAMESERIALNUMBERCOUNTRYNAMELOCALITYNAMESTATEORPROVINCENAMESTREETADDRESSORGANIZATIONNAMEORGANIZATIONALUNITNAMEPOSTALCODEATTRUNSPECIFIED"
|
||||
|
||||
var _rdnAttrType_index = [...]uint8{0, 10, 22, 33, 45, 64, 77, 93, 115, 125, 140}
|
||||
|
||||
func (i rdnAttrType) String() string {
|
||||
if i < 0 || i >= rdnAttrType(len(_rdnAttrType_index)-1) {
|
||||
return "rdnAttrType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _rdnAttrType_name[_rdnAttrType_index[i]:_rdnAttrType_index[i+1]]
|
||||
}
|
676
lib/wincrypt/wincrypt_http_test.go
Normal file
676
lib/wincrypt/wincrypt_http_test.go
Normal file
@ -0,0 +1,676 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package wincrypt_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/fs/fshttp"
|
||||
"github.com/rclone/rclone/lib/atexit"
|
||||
"github.com/rclone/rclone/lib/wincrypt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//nolint:revive // Windows constants
|
||||
const (
|
||||
CERT_NCRYPT_KEY_HANDLE_PROP_ID uint32 = 78
|
||||
CERT_KEY_PROV_INFO_PROP_ID uint32 = 2
|
||||
CERT_CLOSE_STORE_FORCE_FLAG uint32 = 1
|
||||
NCRYPT_OVERWRITE_KEY_FLAG uint32 = 0x00000080
|
||||
BCRYPT_RSAPRIVATE_MAGIC uint32 = 0x32415352
|
||||
BCRYPT_ECDSA_PRIVATE_P256_MAGIC uint32 = 0x32534345
|
||||
BCRYPT_ECDSA_PRIVATE_P384_MAGIC uint32 = 0x34534345
|
||||
BCRYPT_ECDSA_PRIVATE_P521_MAGIC uint32 = 0x36534345
|
||||
NCRYPTBUFFER_PKCS_KEY_NAME uint32 = 45
|
||||
BCRYPTBUFFER_VERSION uint32 = 0
|
||||
NCRYPT_SILENT_FLAG uint32 = windows.CRYPT_SILENT
|
||||
)
|
||||
|
||||
//nolint:revive // Windows constants
|
||||
var (
|
||||
MS_KEY_STORAGE_PROVIDER = [...]uint16{'M', 'i', 'c', 'r', 'o', 's', 'o', 'f', 't', ' ', 'S', 'o', 'f', 't', 'w', 'a', 'r', 'e', ' ', 'K', 'e', 'y', ' ', 'S', 't', 'o', 'r', 'a', 'g', 'e', ' ', 'P', 'r', 'o', 'v', 'i', 'd', 'e', 'r', 0}
|
||||
BCRYPT_RSAPRIVATE_BLOB = [...]uint16{'R', 'S', 'A', 'P', 'R', 'I', 'V', 'A', 'T', 'E', 'B', 'L', 'O', 'B', 0}
|
||||
BCRYPT_ECCPRIVATE_BLOB = [...]uint16{'E', 'C', 'C', 'P', 'R', 'I', 'V', 'A', 'T', 'E', 'B', 'L', 'O', 'B', 0}
|
||||
)
|
||||
|
||||
var (
|
||||
procCertAddEncodedCertificateToStore = wincrypt.Crypt32.NewProc("CertAddEncodedCertificateToStore")
|
||||
procCertDeleteCertificateFromStore = wincrypt.Crypt32.NewProc("CertDeleteCertificateFromStore")
|
||||
procCertSetCertificateContextProperty = wincrypt.Crypt32.NewProc("CertSetCertificateContextProperty")
|
||||
procNCryptOpenStorageProvider = wincrypt.NCrypt.NewProc("NCryptOpenStorageProvider")
|
||||
procNCryptImportKey = wincrypt.NCrypt.NewProc("NCryptImportKey")
|
||||
procNCryptDeleteKey = wincrypt.NCrypt.NewProc("NCryptDeleteKey")
|
||||
|
||||
RNG = rand.New(rand.NewSource(time.Now().Unix()))
|
||||
|
||||
TestHashers = []crypto.Hash{
|
||||
crypto.MD5,
|
||||
crypto.SHA1,
|
||||
crypto.SHA256,
|
||||
crypto.SHA384,
|
||||
crypto.SHA512,
|
||||
}
|
||||
)
|
||||
|
||||
//nolint:revive // Windows structures
|
||||
type (
|
||||
NCRYPT_KEY_HANDLE uintptr
|
||||
NCRYPT_PROV_HANDLE uintptr
|
||||
|
||||
BCryptBuffer struct {
|
||||
cbBuffer uint32
|
||||
bufferType uint32
|
||||
pvBuffer uintptr
|
||||
}
|
||||
|
||||
BCryptBufferDesc struct {
|
||||
ulVersion uint32
|
||||
cBuffers uint32
|
||||
pBuffers uintptr
|
||||
}
|
||||
|
||||
CRYPT_KEY_PROV_INFO struct {
|
||||
pwszContainerName *uint16
|
||||
pwszProvName *uint16
|
||||
dwProvType uint32
|
||||
dwFlags uint32
|
||||
cProvParam uint32
|
||||
rgProvParam uintptr
|
||||
dwKeySpec uint32
|
||||
}
|
||||
)
|
||||
|
||||
type keyPair struct {
|
||||
cert *x509.Certificate
|
||||
key crypto.Signer
|
||||
certCtx *syscall.CertContext
|
||||
privHandle NCRYPT_KEY_HANDLE
|
||||
}
|
||||
|
||||
type testHash struct {
|
||||
algo crypto.Hash
|
||||
hash string
|
||||
}
|
||||
|
||||
type testData struct {
|
||||
name string
|
||||
isECDSA bool
|
||||
rsaKeySize uint32 // if RSA
|
||||
curve elliptic.Curve // if ECDSA
|
||||
hashes []testHash
|
||||
}
|
||||
|
||||
type testCert struct {
|
||||
data *testData
|
||||
pair *keyPair
|
||||
}
|
||||
|
||||
func castTo[T any](ptr *T) uintptr {
|
||||
return uintptr(unsafe.Pointer(ptr))
|
||||
}
|
||||
|
||||
func ncryptOpenStorageProvider(handle *NCRYPT_PROV_HANDLE, provider []uint16, flags uint32) (err error) {
|
||||
errno, _, _ := procNCryptOpenStorageProvider.Call(castTo(handle), castTo(&provider[0]), uintptr(flags))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ncryptImportKey(prov NCRYPT_PROV_HANDLE, importKey NCRYPT_KEY_HANDLE, blobType []uint16, params uintptr, handle *NCRYPT_KEY_HANDLE, data []byte, flags uint32) (err error) {
|
||||
errno, _, _ := procNCryptImportKey.Call(uintptr(prov), uintptr(importKey), castTo(&blobType[0]), params, castTo(handle), castTo(&data[0]), uintptr(len(data)), uintptr(flags))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ncryptDeleteKey(key NCRYPT_KEY_HANDLE, flags uint32) (err error) {
|
||||
errno, _, _ := procNCryptDeleteKey.Call(uintptr(key), uintptr(flags))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getBCryptBlobType(priv crypto.Signer) ([]uint16, error) {
|
||||
|
||||
if _, ok := priv.(*rsa.PrivateKey); ok {
|
||||
return BCRYPT_RSAPRIVATE_BLOB[:], nil
|
||||
}
|
||||
if eckey, ok := priv.(*ecdsa.PrivateKey); ok {
|
||||
switch eckey.Curve {
|
||||
case elliptic.P256():
|
||||
fallthrough
|
||||
case elliptic.P384():
|
||||
fallthrough
|
||||
case elliptic.P521():
|
||||
return BCRYPT_ECCPRIVATE_BLOB[:], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("unsupported private key type")
|
||||
}
|
||||
|
||||
func certAddEncodedCertificateToStore(store syscall.Handle, encodingType uint32, encoded []byte, addDisposition uint32) (ctx *syscall.CertContext, err error) {
|
||||
result, _, err := procCertAddEncodedCertificateToStore.Call(uintptr(store), uintptr(encodingType), castTo(&encoded[0]), uintptr(len(encoded)), uintptr(addDisposition), castTo(&ctx))
|
||||
if result != 0 {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func certDeleteCertificateFromStore(ctx *syscall.CertContext) (err error) {
|
||||
result, _, err := procCertDeleteCertificateFromStore.Call(castTo(ctx))
|
||||
if result != 0 {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func certSetCertificateContextProperty(ctx *syscall.CertContext, propID uint32, flags uint32, data uintptr) (err error) {
|
||||
result, _, err := procCertSetCertificateContextProperty.Call(castTo(ctx), uintptr(propID), uintptr(flags), data)
|
||||
if result != 0 {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getX509CertHash(hasher hash.Hash, cert *x509.Certificate) string {
|
||||
hasher.Reset()
|
||||
hasher.Write(cert.Raw)
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
func makeCert(caCert *x509.Certificate, caPriv crypto.Signer, priv crypto.Signer, tmpl *x509.Certificate, validity time.Duration) (*x509.Certificate, error) {
|
||||
notBefore := time.Now().UTC()
|
||||
notAfter := notBefore.UTC().Add(validity)
|
||||
clientTmpl := x509.Certificate{
|
||||
SerialNumber: tmpl.SerialNumber,
|
||||
Subject: tmpl.Subject,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
IPAddresses: tmpl.IPAddresses,
|
||||
DNSNames: tmpl.DNSNames,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(RNG, &clientTmpl, caCert, priv.Public(), caPriv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x509.ParseCertificate(derBytes)
|
||||
}
|
||||
|
||||
func encodeBNBE(bn *big.Int, size int) []byte {
|
||||
data := make([]byte, size)
|
||||
return bn.FillBytes(data)
|
||||
}
|
||||
|
||||
func toBCryptRSABlob(key *rsa.PrivateKey) ([]byte, error) {
|
||||
bnPubExp := big.NewInt(int64(key.PublicKey.E))
|
||||
pubExp := encodeBNBE(bnPubExp, len(bnPubExp.Bytes()))
|
||||
cbN := (key.N.BitLen() + 7) / 8
|
||||
N := encodeBNBE(key.N, cbN)
|
||||
Prime1 := encodeBNBE(key.Primes[0], len(key.Primes[0].Bytes()))
|
||||
Prime2 := encodeBNBE(key.Primes[1], len(key.Primes[1].Bytes()))
|
||||
// See BCRYPT_RSAKEY_BLOB structure and remarks in MS documentation
|
||||
buf := new(bytes.Buffer)
|
||||
err := binary.Write(buf, binary.NativeEndian, BCRYPT_RSAPRIVATE_MAGIC)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(key.N.BitLen()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(len(pubExp)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(cbN))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(len(Prime1)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(len(Prime2)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Endianness does not apply to []byte
|
||||
err = binary.Write(buf, binary.NativeEndian, pubExp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, N)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, Prime1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, Prime2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func toBCryptECCBlob(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
// See BCRYPT_ECCKEY_BLOB structure and remarks in MS documentation
|
||||
buf := new(bytes.Buffer)
|
||||
var err error
|
||||
switch key.Curve {
|
||||
case elliptic.P256():
|
||||
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P256_MAGIC)
|
||||
case elliptic.P384():
|
||||
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P384_MAGIC)
|
||||
case elliptic.P521():
|
||||
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P521_MAGIC)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported ECDSA curve \"%s\"", key.Params().Name)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cbKey := (key.Params().BitSize + 7) / 8
|
||||
|
||||
err = binary.Write(buf, binary.NativeEndian, uint32(cbKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ditto
|
||||
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.X, cbKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.Y, cbKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.D, cbKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func encodeBCryptPrivateKey(key crypto.Signer) ([]byte, error) {
|
||||
if k, ok := key.(*rsa.PrivateKey); ok {
|
||||
return toBCryptRSABlob(k)
|
||||
}
|
||||
if k, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
return toBCryptECCBlob(k)
|
||||
}
|
||||
return nil, errors.New("unsupported key type")
|
||||
}
|
||||
|
||||
var certSerial int64
|
||||
|
||||
func makeCACert(commonName string, serialNumber int64) (*x509.Certificate, crypto.Signer, error) {
|
||||
caPriv, err := rsa.GenerateKey(RNG, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate CA pricate key: %v", err)
|
||||
}
|
||||
caTmp := x509.Certificate{
|
||||
SerialNumber: big.NewInt(serialNumber),
|
||||
Subject: pkix.Name{
|
||||
CommonName: commonName,
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(10, 0, 0),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
caBytes, err := x509.CreateCertificate(RNG, &caTmp, &caTmp, &caPriv.PublicKey, caPriv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create and sign CA certificate: %v", err)
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to parse created CA certificate: %v", err)
|
||||
}
|
||||
return caCert, caPriv, nil
|
||||
}
|
||||
|
||||
func makeTestCerts(t *testing.T, data []*testCert, prov NCRYPT_PROV_HANDLE, store syscall.Handle) (*x509.Certificate, crypto.Signer) {
|
||||
caCert, caPriv, err := makeCACert("WinCryptTestClientCA", 0xCA1)
|
||||
require.NoError(t, err)
|
||||
for _, item := range data {
|
||||
certSerial += 1
|
||||
var privKey crypto.Signer
|
||||
if !item.data.isECDSA {
|
||||
privKey, err = rsa.GenerateKey(RNG, int(item.data.rsaKeySize))
|
||||
} else {
|
||||
privKey, err = ecdsa.GenerateKey(item.data.curve, RNG)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
blobType, err := getBCryptBlobType(privKey)
|
||||
require.NoError(t, err)
|
||||
serial := big.NewInt(certSerial)
|
||||
clientCertTmpl := x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: item.data.name + "Cert",
|
||||
SerialNumber: fmt.Sprintf("%X", serial),
|
||||
},
|
||||
SerialNumber: serial,
|
||||
}
|
||||
clientCert, err := makeCert(caCert, caPriv, privKey, &clientCertTmpl, 60*time.Second)
|
||||
require.NoError(t, err)
|
||||
cngPrivBytes, err := encodeBCryptPrivateKey(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare key container name
|
||||
hasher := crypto.SHA1.New()
|
||||
hasher.Reset()
|
||||
hasher.Write(cngPrivBytes)
|
||||
name, err := windows.UTF16FromString(strings.ToUpper(hex.EncodeToString(hasher.Sum(nil))))
|
||||
require.NoError(t, err)
|
||||
|
||||
var keyHandle NCRYPT_KEY_HANDLE
|
||||
buffer := &BCryptBuffer{
|
||||
cbBuffer: uint32(len(name) * 2),
|
||||
bufferType: NCRYPTBUFFER_PKCS_KEY_NAME,
|
||||
pvBuffer: castTo(&name[0]),
|
||||
}
|
||||
bufferDesc := &BCryptBufferDesc{
|
||||
ulVersion: BCRYPTBUFFER_VERSION,
|
||||
cBuffers: 1,
|
||||
pBuffers: castTo(buffer),
|
||||
}
|
||||
err = ncryptImportKey(prov, NCRYPT_KEY_HANDLE(0), blobType, castTo(bufferDesc), &keyHandle, cngPrivBytes, NCRYPT_SILENT_FLAG|NCRYPT_OVERWRITE_KEY_FLAG)
|
||||
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to import private key", err))
|
||||
ctx, err := certAddEncodedCertificateToStore(store, windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, clientCert.Raw, windows.CERT_STORE_ADD_REPLACE_EXISTING)
|
||||
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to add certificate to store", err))
|
||||
keyInfo := CRYPT_KEY_PROV_INFO{
|
||||
pwszContainerName: &name[0],
|
||||
pwszProvName: &MS_KEY_STORAGE_PROVIDER[0],
|
||||
dwProvType: 0,
|
||||
dwFlags: NCRYPT_SILENT_FLAG,
|
||||
cProvParam: 0,
|
||||
rgProvParam: 0,
|
||||
dwKeySpec: 0,
|
||||
}
|
||||
err = certSetCertificateContextProperty(ctx, CERT_KEY_PROV_INFO_PROP_ID, 0, castTo(&keyInfo))
|
||||
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to associate certficate with private key info", err))
|
||||
err = certSetCertificateContextProperty(ctx, CERT_NCRYPT_KEY_HANDLE_PROP_ID, 0, castTo(&keyHandle))
|
||||
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to associate certficate with private key", err))
|
||||
item.pair = &keyPair{
|
||||
cert: clientCert,
|
||||
key: privKey,
|
||||
certCtx: ctx,
|
||||
privHandle: keyHandle,
|
||||
}
|
||||
for _, hasher := range TestHashers {
|
||||
item.data.hashes = append(item.data.hashes, testHash{algo: hasher, hash: getX509CertHash(hasher.New(), clientCert)})
|
||||
}
|
||||
}
|
||||
|
||||
return caCert, caPriv
|
||||
}
|
||||
|
||||
func TestWinCryptCertificate(t *testing.T) {
|
||||
store, err := syscall.CertOpenSystemStore(syscall.Handle(0), &wincrypt.USER_STORE_PERSONAL[0])
|
||||
require.NoError(t, err)
|
||||
var prov NCRYPT_PROV_HANDLE
|
||||
err = ncryptOpenStorageProvider(&prov, MS_KEY_STORAGE_PROVIDER[:], 0)
|
||||
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to open microsoft software key storage provider", err))
|
||||
serverCA, serverCAPriv, err := makeCACert("WinCryptTestServerCA", 0xCA2)
|
||||
require.NoError(t, err)
|
||||
serverPriv, err := rsa.GenerateKey(RNG, 2048)
|
||||
require.NoError(t, err)
|
||||
tmpl := x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
SerialNumber: "ABC",
|
||||
},
|
||||
SerialNumber: big.NewInt(0xABC),
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
serverCert, err := makeCert(serverCA, serverCAPriv, serverPriv, &tmpl, time.Hour*24*365)
|
||||
require.NoError(t, err)
|
||||
var ts *httptest.Server
|
||||
certs := []*testCert{
|
||||
{
|
||||
data: &testData{
|
||||
name: "RSA2048TEST",
|
||||
isECDSA: false,
|
||||
rsaKeySize: 2048,
|
||||
},
|
||||
}, {
|
||||
data: &testData{
|
||||
name: "RSA4096TEST",
|
||||
isECDSA: false,
|
||||
rsaKeySize: 4096,
|
||||
},
|
||||
}, {
|
||||
data: &testData{
|
||||
name: "P256TEST",
|
||||
isECDSA: true,
|
||||
curve: elliptic.P256(),
|
||||
},
|
||||
}, {
|
||||
data: &testData{
|
||||
name: "P384TEST",
|
||||
isECDSA: true,
|
||||
curve: elliptic.P384(),
|
||||
},
|
||||
}, {
|
||||
data: &testData{
|
||||
name: "P521TEST",
|
||||
isECDSA: true,
|
||||
curve: elliptic.P521(),
|
||||
},
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
// Cleanup
|
||||
if ts != nil {
|
||||
ts.Close()
|
||||
}
|
||||
atexit.Run()
|
||||
for _, cert := range certs {
|
||||
if cert.pair != nil {
|
||||
if cert.pair.privHandle != 0 {
|
||||
err := ncryptDeleteKey(cert.pair.privHandle, NCRYPT_SILENT_FLAG)
|
||||
if err != nil {
|
||||
_ = wincrypt.NCryptFreeObject(syscall.Handle(cert.pair.privHandle))
|
||||
}
|
||||
}
|
||||
if cert.pair.certCtx != nil {
|
||||
_ = certDeleteCertificateFromStore(cert.pair.certCtx)
|
||||
}
|
||||
}
|
||||
}
|
||||
if prov != 0 {
|
||||
_ = wincrypt.NCryptFreeObject(syscall.Handle(prov))
|
||||
}
|
||||
if store != 0 {
|
||||
_ = syscall.CertCloseStore(store, CERT_CLOSE_STORE_FORCE_FLAG)
|
||||
}
|
||||
})
|
||||
caCert, _ := makeTestCerts(t, certs, prov, store)
|
||||
|
||||
clientCAs := x509.NewCertPool()
|
||||
clientCAs.AddCert(caCert)
|
||||
ts = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
peerCerts := r.TLS.PeerCertificates
|
||||
require.Greater(t, len(peerCerts), 0, "No certificate received")
|
||||
t.Logf("Received %d certificates", len(peerCerts))
|
||||
for _, receivedCert := range peerCerts {
|
||||
var testcert *testCert
|
||||
for _, cert := range certs {
|
||||
if receivedCert.SerialNumber.Cmp(cert.pair.cert.SerialNumber) == 0 {
|
||||
testcert = cert
|
||||
for _, hash := range testcert.data.hashes {
|
||||
assert.Equalf(t, hash.hash, getX509CertHash(hash.algo.New(), receivedCert), "%s hash of test certificate \"%s\" does not match expected one", hash.algo, cert.data.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.NotNilf(t, testcert, "Recevied certificate with subject \"%s\" serialNumber \"%X\" does not belong to test certificates", receivedCert.Subject, receivedCert.SerialNumber)
|
||||
if testcert != nil {
|
||||
t.Logf("Received certificate subject \"%s\", serialNumber \"%X\"", receivedCert.Subject, receivedCert.SerialNumber)
|
||||
}
|
||||
}
|
||||
|
||||
chains, err := peerCerts[0].Verify(x509.VerifyOptions{Roots: clientCAs})
|
||||
assert.NoErrorf(t, err, "Failed to verify certificate: %v", err)
|
||||
if err == nil {
|
||||
t.Logf("Verified, constructed %d chains", len(chains))
|
||||
for i, chain := range chains {
|
||||
t.Logf("Chain %d:", i)
|
||||
for j, cert := range chain {
|
||||
t.Logf("Certificate %d: Subject: \"%s\", SerialNumber: \"%X\"", j, cert.Subject, cert.SerialNumber)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write some test data to fulfill the request
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = fmt.Fprintln(w, "test data")
|
||||
}))
|
||||
serverRoot := x509.NewCertPool()
|
||||
serverRoot.AddCert(serverCA)
|
||||
ts.TLS = &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: clientCAs,
|
||||
RootCAs: serverRoot,
|
||||
MaxVersion: tls.VersionTLS12, // TLS V1.3 forces RSA PSS signing
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{{
|
||||
Certificate: [][]byte{serverCert.Raw},
|
||||
PrivateKey: serverPriv,
|
||||
SupportedSignatureAlgorithms: []tls.SignatureScheme{
|
||||
tls.PKCS1WithSHA1,
|
||||
tls.PKCS1WithSHA256,
|
||||
tls.PKCS1WithSHA384,
|
||||
tls.PKCS1WithSHA512,
|
||||
tls.PSSWithSHA256,
|
||||
tls.PSSWithSHA384,
|
||||
tls.PSSWithSHA512,
|
||||
},
|
||||
}},
|
||||
}
|
||||
ts.StartTLS()
|
||||
|
||||
serverCAs := x509.NewCertPool()
|
||||
serverCAs.AddCert(serverCA)
|
||||
|
||||
ctx := context.TODO()
|
||||
ci := fs.GetConfig(ctx)
|
||||
ci.WinCrypt = "all"
|
||||
client := fshttp.NewClient(ctx)
|
||||
tt := client.Transport.(*fshttp.Transport)
|
||||
tt.TLSClientConfig.RootCAs = serverCAs
|
||||
_, err = client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var builder strings.Builder
|
||||
for i, cert := range certs {
|
||||
ctx := context.TODO()
|
||||
ci := fs.GetConfig(ctx)
|
||||
ci.WinCrypt = fmt.Sprintf("/CN=%s/serialNumber=%d", cert.data.name+"Cert", i+1)
|
||||
client := fshttp.NewClient(ctx)
|
||||
tt := client.Transport.(*fshttp.Transport)
|
||||
tt.TLSClientConfig.RootCAs = serverCAs
|
||||
_, err = client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var loadedCerts1 []*x509.Certificate
|
||||
for _, c := range tt.TLSClientConfig.Certificates {
|
||||
loadedCerts1 = append(loadedCerts1, c.Leaf)
|
||||
}
|
||||
|
||||
builder.Reset()
|
||||
builder.WriteString("hash: ")
|
||||
for i, hash := range cert.data.hashes {
|
||||
builder.WriteString(hash.hash)
|
||||
if i != len(cert.data.hashes)-1 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
}
|
||||
hashQuery := builder.String()
|
||||
|
||||
// Test RSA PSS and ECDSA signing
|
||||
ci.WinCrypt = hashQuery
|
||||
client = fshttp.NewClient(ctx)
|
||||
tt = client.Transport.(*fshttp.Transport)
|
||||
tt.TLSClientConfig.RootCAs = serverCAs
|
||||
_, err = client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var loadedCerts2 []*x509.Certificate
|
||||
for _, c := range tt.TLSClientConfig.Certificates {
|
||||
loadedCerts2 = append(loadedCerts2, c.Leaf)
|
||||
}
|
||||
assert.Equal(t, loadedCerts1, loadedCerts2)
|
||||
|
||||
if !cert.data.isECDSA {
|
||||
ci.WinCrypt = ""
|
||||
client = fshttp.NewClient(ctx)
|
||||
tt = client.Transport.(*fshttp.Transport)
|
||||
tt.TLSClientConfig.RootCAs = serverCAs
|
||||
|
||||
crypts, err := wincrypt.LoadWinCryptCerts(hashQuery)
|
||||
assert.NoError(t, err)
|
||||
var loadedCerts []*x509.Certificate
|
||||
for _, crypt := range crypts {
|
||||
tlsCert := crypt.TLSCertificate()
|
||||
|
||||
// Test RSA PKCS v1.5 signing
|
||||
tlsCert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
|
||||
tls.PKCS1WithSHA1,
|
||||
tls.PKCS1WithSHA256,
|
||||
tls.PKCS1WithSHA384,
|
||||
tls.PKCS1WithSHA512,
|
||||
}
|
||||
|
||||
loadedCerts = append(loadedCerts, tlsCert.Leaf)
|
||||
tt.TLSClientConfig.Certificates = append(tt.TLSClientConfig.Certificates, tlsCert)
|
||||
}
|
||||
assert.ElementsMatch(t, loadedCerts1, loadedCerts)
|
||||
_, err = client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
27
lib/wincrypt/wincrypt_other.go
Normal file
27
lib/wincrypt/wincrypt_other.go
Normal file
@ -0,0 +1,27 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
// Code generated by "dummy"; DO NOT EDIT. // Add autogenerated comment to make it ignored by linter
|
||||
|
||||
// Dummy implementation to make it build on non-Windows platform
|
||||
package wincrypt
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type WINCRYPT struct {
|
||||
}
|
||||
|
||||
func (*WINCRYPT) TLSCertificate() tls.Certificate {
|
||||
return tls.Certificate{}
|
||||
}
|
||||
|
||||
func LoadWinCryptCerts(string) ([]*WINCRYPT, error) {
|
||||
return nil, errors.New("Cryptography API: Next Generation (CNG) is only available on Windows")
|
||||
}
|
||||
|
||||
func (*WINCRYPT) Close() error {
|
||||
return nil
|
||||
}
|
112
lib/wincrypt/wincrypt_test.go
Normal file
112
lib/wincrypt/wincrypt_test.go
Normal file
@ -0,0 +1,112 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package wincrypt
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
reference = []*subjectAttr{
|
||||
{
|
||||
attrType: COMMONNAME,
|
||||
values: [][]string{{"Test CN1"}, {"Test CN2"}},
|
||||
},
|
||||
{
|
||||
attrType: SERIALNUMBER,
|
||||
values: [][]string{{"Test serial"}},
|
||||
},
|
||||
{
|
||||
attrType: COUNTRYNAME,
|
||||
values: [][]string{{"Test C1", "Test C2"}},
|
||||
},
|
||||
{
|
||||
attrType: LOCALITYNAME,
|
||||
values: [][]string{{"Test L1", "Test L2"}},
|
||||
},
|
||||
{
|
||||
attrType: STATEORPROVINCENAME,
|
||||
values: [][]string{{"Test ST1", "Test ST2"}},
|
||||
},
|
||||
{
|
||||
attrType: STREETADDRESS,
|
||||
values: [][]string{{"Test street1", "Test street2"}},
|
||||
},
|
||||
{
|
||||
attrType: ORGANIZATIONNAME,
|
||||
values: [][]string{{"Test O1", "Test O2"}},
|
||||
},
|
||||
{
|
||||
attrType: ORGANIZATIONALUNITNAME,
|
||||
values: [][]string{{"Test OU1", "Test OU2"}},
|
||||
},
|
||||
{
|
||||
attrType: POSTALCODE,
|
||||
values: [][]string{{"Test postalcode1", "Test postalcode2"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestParseSubject(t *testing.T) {
|
||||
subject := "/CN=Test CN1/OU=Test OU1+Test OU2/O=Test O1+Test O2/L=Test L1+Test L2/C=Test C1+Test C2/ST=Test ST1+Test ST2/street=Test street1+Test street2/serialNumber=Test serial/postalCode=Test postalcode1+Test postalcode2/CN=Test CN2"
|
||||
subj1, err := parseOpenSSLSubject(subject)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, reference, subj1)
|
||||
subject += "/"
|
||||
subj2, err := parseOpenSSLSubject(subject)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, subj1, subj2)
|
||||
subject += "EMAIL=Test Email"
|
||||
subj2, err = parseOpenSSLSubject(subject)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, reference, subj2)
|
||||
_, err = parseOpenSSLSubject("/CN=CN=")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCertSubjectMatch(t *testing.T) {
|
||||
attrs, err := parseOpenSSLSubject("/CN=TestCN/OU=TestOU1+TestOU1")
|
||||
require.NoError(t, err)
|
||||
cert := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "xTestCN1",
|
||||
OrganizationalUnit: []string{"xTestOU1", "xTestOU2"},
|
||||
},
|
||||
}
|
||||
assert.False(t, isCertAttributesMatch(cert, attrs))
|
||||
cert.Subject.OrganizationalUnit[1] = "xTestOU1"
|
||||
assert.True(t, isCertAttributesMatch(cert, attrs))
|
||||
attrs, err = parseOpenSSLSubject("/CN=NoMatch")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isCertAttributesMatch(cert, attrs))
|
||||
attrs, err = parseOpenSSLSubject("/CN=NoMatch/CN=")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isCertAttributesMatch(cert, attrs))
|
||||
attrs, err = parseOpenSSLSubject("/CN=NoMatch/CN=Test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isCertAttributesMatch(cert, attrs))
|
||||
cert = &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: reference[COMMONNAME].values[0][0],
|
||||
SerialNumber: reference[SERIALNUMBER].values[0][0],
|
||||
Country: reference[COUNTRYNAME].values[0],
|
||||
Locality: reference[LOCALITYNAME].values[0],
|
||||
Province: reference[STATEORPROVINCENAME].values[0],
|
||||
StreetAddress: reference[STREETADDRESS].values[0],
|
||||
Organization: reference[ORGANIZATIONNAME].values[0],
|
||||
OrganizationalUnit: reference[ORGANIZATIONALUNITNAME].values[0],
|
||||
PostalCode: reference[POSTALCODE].values[0],
|
||||
},
|
||||
}
|
||||
assert.True(t, isCertAttributesMatch(cert, reference))
|
||||
subject := "/CN=Test CN2/OU=Test OU1+Test OU2/O=Test O1+Test O2/L=Test L1+Test L2/C=Test C1+Test C2/ST=Test ST1+Test ST2/street=Test street1+Test street2/serialNumber=Test serial/postalCode=Test postalcode1+Test postalcode2"
|
||||
attrs, err = parseOpenSSLSubject(subject)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isCertAttributesMatch(cert, attrs))
|
||||
}
|
844
lib/wincrypt/wincrypt_windows.go
Normal file
844
lib/wincrypt/wincrypt_windows.go
Normal file
@ -0,0 +1,844 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
//go:generate stringer -type=rdnAttrType
|
||||
|
||||
// Package wincrypt implements loading certificate/key pairs from Windows certificate store for Mutual TLS authentication
|
||||
package wincrypt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"maps"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/lib/atexit"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//nolint:revive // Windows constants
|
||||
var (
|
||||
NCRYPT_KEY_USAGE_PROPERTY = [...]uint16{'K', 'e', 'y', ' ', 'U', 's', 'a', 'g', 'e', 0}
|
||||
NCRYPT_ALGORITHM_GROUP_PROPERTY = [...]uint16{'A', 'l', 'g', 'o', 'r', 'i', 't', 'h', 'm', ' ', 'G', 'r', 'o', 'u', 'p', 0}
|
||||
NCRYPT_ECDH_ALGORITHM_GROUP = [...]uint16{'E', 'C', 'D', 'H', 0}
|
||||
NCRYPT_RSA_ALGORITHM_GROUP = [...]uint16{'R', 'S', 'A', 0}
|
||||
NCRYPT_ECDSA_ALGORITHM_GROUP = [...]uint16{'E', 'C', 'D', 'S', 'A', 0}
|
||||
BCRYPT_SHA1_ALGORITHM = [...]uint16{'S', 'H', 'A', '1', 0}
|
||||
BCRYPT_SHA256_ALGORITHM = [...]uint16{'S', 'H', 'A', '2', '5', '6', 0}
|
||||
BCRYPT_SHA384_ALGORITHM = [...]uint16{'S', 'H', 'A', '3', '8', '4', 0}
|
||||
BCRYPT_SHA512_ALGORITHM = [...]uint16{'S', 'H', 'A', '5', '1', '2', 0}
|
||||
USER_STORE_PERSONAL = [...]uint16{'M', 'Y', 0}
|
||||
)
|
||||
|
||||
var (
|
||||
ncrypt = syscall.NewLazyDLL("ncrypt.dll")
|
||||
crypt32 = syscall.NewLazyDLL("crypt32.dll")
|
||||
cryptui = syscall.NewLazyDLL("cryptui.dll")
|
||||
procCryptUIDlgSelectCertificateFromStore = cryptui.NewProc("CryptUIDlgSelectCertificateFromStore")
|
||||
procCryptAcquireCertificatePrivateKey = crypt32.NewProc("CryptAcquireCertificatePrivateKey")
|
||||
procCertEnumCertificatesInStore = crypt32.NewProc("CertEnumCertificatesInStore")
|
||||
procCertDuplicateCertificateContext = crypt32.NewProc("CertDuplicateCertificateContext")
|
||||
procNCryptGetProperty = ncrypt.NewProc("NCryptGetProperty")
|
||||
procNCryptSignHash = ncrypt.NewProc("NCryptSignHash")
|
||||
procNCryptFreeObject = ncrypt.NewProc("NCryptFreeObject")
|
||||
|
||||
attrAltNames = map[rdnAttrType]string{
|
||||
COMMONNAME: "CN",
|
||||
COUNTRYNAME: "C",
|
||||
LOCALITYNAME: "L",
|
||||
STATEORPROVINCENAME: "ST",
|
||||
STREETADDRESS: "STREET",
|
||||
ORGANIZATIONNAME: "O",
|
||||
ORGANIZATIONALUNITNAME: "OU",
|
||||
}
|
||||
)
|
||||
|
||||
//nolint:revive // Windows constants
|
||||
const (
|
||||
NCRYPT_ALLOW_ALL_USAGES uint32 = 0x00FFFFFF
|
||||
NCRYPT_ALLOW_DECRYPT_FLAG uint32 = 0x00000001
|
||||
NCRYPT_ALLOW_SIGNING_FLAG uint32 = 0x00000002
|
||||
NCRYPT_ALLOW_KEY_AGREEMENT_FLAG uint32 = 0x00000004
|
||||
BCRYPT_PAD_PKCS1 uint32 = 0x00000002
|
||||
BCRYPT_PAD_PSS uint32 = 0x00000008
|
||||
)
|
||||
|
||||
type keyType int
|
||||
type rdnAttrType int
|
||||
|
||||
//nolint:revive // Constants in uppercase used for matching
|
||||
const (
|
||||
COMMONNAME rdnAttrType = iota
|
||||
SERIALNUMBER
|
||||
COUNTRYNAME
|
||||
LOCALITYNAME
|
||||
STATEORPROVINCENAME
|
||||
STREETADDRESS
|
||||
ORGANIZATIONNAME
|
||||
ORGANIZATIONALUNITNAME
|
||||
POSTALCODE
|
||||
ATTRUNSPECIFIED
|
||||
)
|
||||
|
||||
type subjectAttr struct {
|
||||
attrType rdnAttrType
|
||||
values [][]string
|
||||
}
|
||||
|
||||
const (
|
||||
keyTypeUnspecified keyType = iota
|
||||
keyTypeRSA
|
||||
keyTypeECDSA
|
||||
)
|
||||
|
||||
//nolint:revive // Windows structures
|
||||
type (
|
||||
BCRYPT_PKCS1_PADDING_INFO struct {
|
||||
pszAlgId *uint16
|
||||
}
|
||||
|
||||
BCRYPT_PSS_PADDING_INFO struct {
|
||||
pszAlgId *uint16
|
||||
cbSalt uint32
|
||||
}
|
||||
)
|
||||
|
||||
// WINCRYPT is a structure encapsulates CNG key handle
|
||||
type WINCRYPT struct {
|
||||
cert tls.Certificate
|
||||
priv syscall.Handle
|
||||
keyType keyType
|
||||
shouldFree bool
|
||||
}
|
||||
|
||||
func wrapError(prefix string, err error, args ...any) error {
|
||||
prefix = fmt.Sprintf(prefix, args...)
|
||||
if errno, ok := err.(syscall.Errno); ok {
|
||||
return fmt.Errorf("%s, errorCode=\"0x%08X\", errorMessage=\"%w\"", prefix, uint32(errno), err)
|
||||
}
|
||||
return fmt.Errorf("%s: %w", prefix, err)
|
||||
}
|
||||
|
||||
func isKeySuitableForSigning(prov syscall.Handle, keyType keyType) (bool, error) {
|
||||
var keyUsage uint32
|
||||
var keyUsageSize uint32 = 4
|
||||
err := ncryptGetProperty(prov, NCRYPT_KEY_USAGE_PROPERTY[:], castTo(&keyUsage), keyUsageSize, &keyUsageSize, 0)
|
||||
if err != nil {
|
||||
return false, wrapError("failed to get Key Usage information", err)
|
||||
}
|
||||
switch {
|
||||
case keyUsage == NCRYPT_ALLOW_ALL_USAGES:
|
||||
fs.Debug(nil, "All usage flag set")
|
||||
return true, nil
|
||||
case keyUsage&NCRYPT_ALLOW_SIGNING_FLAG != 0:
|
||||
fs.Debug(nil, "Signing usage flag set")
|
||||
return true, nil
|
||||
case (keyUsage&NCRYPT_ALLOW_KEY_AGREEMENT_FLAG != 0) && (keyType == keyTypeECDSA):
|
||||
fs.Debug(nil, "Allow ECC key with key agreement usage flag set")
|
||||
return true, nil
|
||||
default:
|
||||
return false, errors.New("provided key is not suitable for signing purpose")
|
||||
}
|
||||
}
|
||||
|
||||
func freeCertificates(certs []*syscall.CertContext) {
|
||||
if len(certs) > 0 {
|
||||
for _, cert := range certs {
|
||||
_ = syscall.CertFreeCertificateContext(cert)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func castFrom[T any](ptr uintptr) *T {
|
||||
// Add another indirect level to please go vet
|
||||
return (*T)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr)))
|
||||
}
|
||||
|
||||
func castTo[T any](ptr *T) uintptr {
|
||||
return uintptr(unsafe.Pointer(ptr))
|
||||
}
|
||||
|
||||
func enumerateCertificates(store syscall.Handle) ([]*syscall.CertContext, error) {
|
||||
currentCtx, _, err := procCertEnumCertificatesInStore.Call(uintptr(store), 0)
|
||||
if err == syscall.Errno(windows.CRYPT_E_NOT_FOUND) {
|
||||
return nil, nil
|
||||
} else if err != windows.ERROR_SUCCESS {
|
||||
return nil, err
|
||||
}
|
||||
var certs []*syscall.CertContext
|
||||
for {
|
||||
copyCtx, _, err := procCertDuplicateCertificateContext.Call(currentCtx)
|
||||
if err != windows.ERROR_SUCCESS {
|
||||
certs = append(certs, castFrom[syscall.CertContext](currentCtx))
|
||||
freeCertificates(certs)
|
||||
return nil, err
|
||||
}
|
||||
certs = append(certs, castFrom[syscall.CertContext](copyCtx))
|
||||
currentCtx, _, err = procCertEnumCertificatesInStore.Call(uintptr(store), currentCtx)
|
||||
if err == syscall.Errno(windows.CRYPT_E_NOT_FOUND) {
|
||||
return certs, nil
|
||||
} else if err != windows.ERROR_SUCCESS {
|
||||
freeCertificates(certs)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unprintable string, used as map key only
|
||||
func calculateCertHash(hasher hash.Hash, cert *x509.Certificate) string {
|
||||
hasher.Reset()
|
||||
hasher.Write(cert.Raw)
|
||||
return string(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
func parseCertContext(ctx *syscall.CertContext) (*x509.Certificate, error) {
|
||||
encoded := unsafe.Slice(ctx.EncodedCert, ctx.Length)
|
||||
buf := make([]byte, ctx.Length)
|
||||
copy(buf, encoded)
|
||||
return x509.ParseCertificate(buf)
|
||||
}
|
||||
|
||||
func getRDNAttrType(name string) rdnAttrType {
|
||||
str := strings.ToUpper(name)
|
||||
for attrType := rdnAttrType(0); attrType < ATTRUNSPECIFIED; attrType++ {
|
||||
if attrType.String() == str {
|
||||
return attrType
|
||||
} else if altName, ok := attrAltNames[attrType]; ok && altName == str {
|
||||
return attrType
|
||||
}
|
||||
}
|
||||
return ATTRUNSPECIFIED
|
||||
}
|
||||
|
||||
func parseOpenSSLSubject(subjectStr string) ([]*subjectAttr, error) {
|
||||
subject := []rune(subjectStr)
|
||||
if len(subject) == 0 {
|
||||
return nil, errors.New("no subject string specified")
|
||||
}
|
||||
if subject[0] != '/' {
|
||||
return nil, fmt.Errorf("invalid first charater of subject string, expects '/', got '%c'", subject[0])
|
||||
}
|
||||
attrs := make(map[rdnAttrType]*subjectAttr, int(ATTRUNSPECIFIED))
|
||||
i := 1
|
||||
var str strings.Builder
|
||||
noTypeError := errors.New("no RDN type string specified")
|
||||
parseName := func() (string, error) {
|
||||
if i >= len(subject) {
|
||||
return "", noTypeError
|
||||
}
|
||||
str.Reset()
|
||||
err:
|
||||
for ; i < len(subject); i++ {
|
||||
switch subject[i] {
|
||||
case '/':
|
||||
if str.Len() == 0 {
|
||||
return "", noTypeError
|
||||
}
|
||||
break err
|
||||
case '=':
|
||||
if str.Len() == 0 {
|
||||
return "", noTypeError
|
||||
}
|
||||
i += 1
|
||||
return str.String(), nil
|
||||
case '\\':
|
||||
if i+1 >= len(subject) {
|
||||
return "", fmt.Errorf("unexpected EOF during parsing escape character in RDN type string at index %d", i)
|
||||
}
|
||||
i += 1
|
||||
fallthrough
|
||||
default:
|
||||
str.WriteRune(subject[i])
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("missing '=' after RDN type string \"%s\" in subject name string at index %d", str.String(), i)
|
||||
}
|
||||
parseValue := func() ([]string, error) {
|
||||
var values []string
|
||||
// attributes without value are allowed
|
||||
if i >= len(subject) {
|
||||
return values, nil
|
||||
}
|
||||
str.Reset()
|
||||
out:
|
||||
for ; i < len(subject); i++ {
|
||||
switch subject[i] {
|
||||
case '=':
|
||||
return nil, fmt.Errorf("unexpected character '=' found at index %d. If it's intended, please escape it with '\\'", i)
|
||||
case '/':
|
||||
i += 1
|
||||
break out
|
||||
case '+':
|
||||
if str.Len() != 0 {
|
||||
values = append(values, str.String())
|
||||
str.Reset()
|
||||
}
|
||||
case '\\':
|
||||
if i+1 >= len(subject) {
|
||||
return nil, fmt.Errorf("unexpected EOF during parsing escape character in value at index %d", i)
|
||||
}
|
||||
i += 1
|
||||
fallthrough
|
||||
default:
|
||||
str.WriteRune(subject[i])
|
||||
}
|
||||
}
|
||||
if str.Len() != 0 {
|
||||
values = append(values, str.String())
|
||||
str.Reset()
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
for i < len(subject) {
|
||||
name, err := parseName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values, err := parseValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attrType := getRDNAttrType(name)
|
||||
if attrType == ATTRUNSPECIFIED {
|
||||
fs.Errorf(nil, "Unknown RDN attribute type \"%s\", ignoring", name)
|
||||
continue
|
||||
}
|
||||
if attrType == COMMONNAME || attrType == SERIALNUMBER {
|
||||
if len(values) > 1 {
|
||||
return nil, fmt.Errorf("attribute type \"%s\" should be single value", attrType.String())
|
||||
}
|
||||
}
|
||||
if attr, ok := attrs[attrType]; ok {
|
||||
attr.values = append(attr.values, values)
|
||||
} else {
|
||||
attrs[attrType] = &subjectAttr{attrType: attrType, values: [][]string{values}}
|
||||
}
|
||||
}
|
||||
return slices.Collect(maps.Values(attrs)), nil
|
||||
}
|
||||
|
||||
type certPair struct {
|
||||
cert *x509.Certificate
|
||||
ctx *syscall.CertContext
|
||||
}
|
||||
|
||||
func matchCertificatesByHash(availablePairs []*certPair, hashes string) ([]*certPair, error) {
|
||||
if len(availablePairs) == 0 {
|
||||
return availablePairs, nil
|
||||
}
|
||||
hashStrs := strings.Split(hashes, ",")
|
||||
|
||||
hashTypes := make(map[crypto.Hash]bool)
|
||||
|
||||
validPairs := make(map[*certPair]bool, len(availablePairs))
|
||||
|
||||
for i := 0; i < len(hashStrs); i++ {
|
||||
hashStrs[i] = strings.TrimSpace(hashStrs[i])
|
||||
hashStr := hashStrs[i]
|
||||
if hashStr == "" {
|
||||
return nil, fmt.Errorf("invalid hash string format: \"%s\"", hashes)
|
||||
}
|
||||
switch len(hashStr) {
|
||||
case 32:
|
||||
hashTypes[crypto.MD5] = true
|
||||
case 40:
|
||||
hashTypes[crypto.SHA1] = true
|
||||
case 64:
|
||||
hashTypes[crypto.SHA256] = true
|
||||
case 96:
|
||||
hashTypes[crypto.SHA384] = true
|
||||
case 128:
|
||||
hashTypes[crypto.SHA512] = true
|
||||
default:
|
||||
return nil, fmt.Errorf("hash string \"%s\" is not a MD5/SHA-1/SHA-256/SHA-384/SHA-512 hex string", hashStr)
|
||||
}
|
||||
}
|
||||
certHashTable := make(map[string][]*certPair)
|
||||
var hashers []hash.Hash
|
||||
for k := range hashTypes {
|
||||
hashers = append(hashers, k.New())
|
||||
}
|
||||
for _, pair := range availablePairs {
|
||||
for _, hasher := range hashers {
|
||||
hashBytes := calculateCertHash(hasher, pair.cert)
|
||||
entries, ok := certHashTable[hashBytes]
|
||||
if !ok {
|
||||
certHashTable[hashBytes] = []*certPair{pair}
|
||||
} else {
|
||||
// Duplicate hash found
|
||||
entries = append(entries, pair)
|
||||
certHashTable[hashBytes] = entries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, hashStr := range hashStrs {
|
||||
|
||||
hash, err := hex.DecodeString(hashStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, ok := certHashTable[string(hash)]
|
||||
if !ok {
|
||||
fs.Errorf(nil, "No certificate has hash \"%s\"", hashStr)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if _, ok := validPairs[entry]; !ok {
|
||||
validPairs[entry] = true
|
||||
} else {
|
||||
fs.LogPrintf(fs.LogLevelWarning, nil, "Found duplicate certificate with hash \"%s\", ignoring", hashStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return slices.Collect(maps.Keys(validPairs)), nil
|
||||
}
|
||||
|
||||
func isCertAttributesMatch(cert *x509.Certificate, attrs []*subjectAttr) bool {
|
||||
valid := true
|
||||
for _, attr := range attrs {
|
||||
|
||||
// For attribute specified multiple times, as long as one of them match, the attribute matches
|
||||
attrMatch := false
|
||||
for _, attrValues := range attr.values {
|
||||
// empty attributes always match
|
||||
if len(attrValues) == 0 {
|
||||
attrMatch = true
|
||||
break
|
||||
}
|
||||
|
||||
var targetStrs []string
|
||||
switch attr.attrType {
|
||||
case COMMONNAME:
|
||||
targetStrs = []string{cert.Subject.CommonName}
|
||||
case SERIALNUMBER:
|
||||
targetStrs = []string{cert.Subject.SerialNumber}
|
||||
case COUNTRYNAME:
|
||||
targetStrs = cert.Subject.Country
|
||||
case LOCALITYNAME:
|
||||
targetStrs = cert.Subject.Locality
|
||||
case STATEORPROVINCENAME:
|
||||
targetStrs = cert.Subject.Province
|
||||
case STREETADDRESS:
|
||||
targetStrs = cert.Subject.StreetAddress
|
||||
case ORGANIZATIONNAME:
|
||||
targetStrs = cert.Subject.Organization
|
||||
case ORGANIZATIONALUNITNAME:
|
||||
targetStrs = cert.Subject.OrganizationalUnit
|
||||
case POSTALCODE:
|
||||
targetStrs = cert.Subject.PostalCode
|
||||
}
|
||||
|
||||
// always won't match
|
||||
if len(attrValues) > len(targetStrs) || len(targetStrs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
attrValueMatch := true
|
||||
matched := make(map[int]bool, len(targetStrs))
|
||||
|
||||
for _, value := range attrValues {
|
||||
match := false
|
||||
for i, target := range targetStrs {
|
||||
if _, ok := matched[i]; ok {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(target, value) {
|
||||
matched[i] = true
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// if any attribute value does not match, the whole match fails
|
||||
if !match {
|
||||
attrValueMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if attrValueMatch {
|
||||
attrMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !attrMatch {
|
||||
valid = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
func matchCertificatesBySubject(availablePairs []*certPair, subject string) ([]*certPair, error) {
|
||||
if len(availablePairs) == 0 {
|
||||
return availablePairs, nil
|
||||
}
|
||||
attrs, err := parseOpenSSLSubject(subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validPairs []*certPair
|
||||
for _, pair := range availablePairs {
|
||||
|
||||
if isCertAttributesMatch(pair.cert, attrs) {
|
||||
validPairs = append(validPairs, pair)
|
||||
}
|
||||
}
|
||||
return validPairs, nil
|
||||
}
|
||||
|
||||
func filterInvalidCerts(pairs []*certPair) (result []*certPair) {
|
||||
now := time.Now().UTC()
|
||||
for _, pair := range pairs {
|
||||
if pair.cert.NotAfter.UTC().After(now) && pair.cert.NotBefore.UTC().Before(now) {
|
||||
result = append(result, pair)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
LoadWinCryptCerts loads certificate/key pairs from Windows certificate store matche specified criteria.
|
||||
Please check documentation for the detailed format of "criteria" string.
|
||||
*/
|
||||
func LoadWinCryptCerts(criteria string) (crypts []*WINCRYPT, err error) {
|
||||
criteria = strings.TrimSpace(criteria)
|
||||
store, err := syscall.CertOpenSystemStore(syscall.Handle(0), &USER_STORE_PERSONAL[0])
|
||||
if err != nil {
|
||||
return nil, wrapError("failed to open user personal certificate store", err)
|
||||
}
|
||||
defer func() { _ = syscall.CertCloseStore(store, 0) }()
|
||||
|
||||
var availableCerts []*certPair
|
||||
availableCertCtxs := []*syscall.CertContext{}
|
||||
defer func() { freeCertificates(availableCertCtxs) }()
|
||||
matchFunc := func(pairs []*certPair, _ string) ([]*certPair, error) {
|
||||
return pairs, nil
|
||||
}
|
||||
if strings.EqualFold(criteria, "select") {
|
||||
ctx, err := selectCertificateFromUserStore(store)
|
||||
if err != nil {
|
||||
return nil, wrapError("error occurred trying to get selected certificate handle", err)
|
||||
}
|
||||
// User canceled selection
|
||||
if ctx == nil {
|
||||
return nil, errors.New("no WinCrypt certificate was chosen")
|
||||
}
|
||||
availableCertCtxs = append(availableCertCtxs, ctx)
|
||||
} else {
|
||||
certCtxs, err := enumerateCertificates(store)
|
||||
if err != nil {
|
||||
return nil, wrapError("failed to enumerate certificate in user personal certificate store", err)
|
||||
}
|
||||
if len(certCtxs) == 0 {
|
||||
return nil, errors.New("no certificate found in user personal certificate store")
|
||||
}
|
||||
availableCertCtxs = append(availableCertCtxs, certCtxs...)
|
||||
|
||||
if !strings.EqualFold(criteria, "all") {
|
||||
if len(criteria) >= 5 && strings.EqualFold(criteria[0:5], "hash:") {
|
||||
if len(criteria) < 6 {
|
||||
return nil, errors.New("no hash specified")
|
||||
}
|
||||
criteria = criteria[5:]
|
||||
matchFunc = matchCertificatesByHash
|
||||
} else {
|
||||
matchFunc = matchCertificatesBySubject
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ctx := range availableCertCtxs {
|
||||
cert, err := parseCertContext(ctx)
|
||||
if err != nil {
|
||||
fs.Errorf(nil, "failed to parse the certificate, skiping: %s", err)
|
||||
continue
|
||||
}
|
||||
availableCerts = append(availableCerts, &certPair{cert, ctx})
|
||||
}
|
||||
|
||||
availableCerts, err = matchFunc(availableCerts, criteria)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
availableCerts = filterInvalidCerts(availableCerts)
|
||||
if len(availableCerts) == 0 {
|
||||
return nil, fmt.Errorf("no certificate found matches the specified criteria \"%s\"", criteria)
|
||||
}
|
||||
for _, cert := range availableCerts {
|
||||
crypt := new(WINCRYPT)
|
||||
defer func() {
|
||||
if crypt != nil {
|
||||
_ = crypt.Close()
|
||||
}
|
||||
}()
|
||||
fs.Debugf(nil, "Certificate Subject: %s, SerialNumber: %X", cert.cert.Subject.String(), cert.cert.SerialNumber)
|
||||
var keyFlags uint32
|
||||
err = cryptAcquireCertificatePrivateKey(cert.ctx, windows.CRYPT_ACQUIRE_COMPARE_KEY_FLAG|windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, &(crypt.priv), &keyFlags, &crypt.shouldFree)
|
||||
if err != nil {
|
||||
if errors.Unwrap(err) != nil {
|
||||
err = errors.Unwrap(err)
|
||||
}
|
||||
if err == syscall.Errno(windows.CRYPT_E_NO_KEY_PROPERTY) {
|
||||
fs.LogPrintf(fs.LogLevelNotice, nil, "The certificate has no associated private key, skiping")
|
||||
} else if err == syscall.Errno(windows.NTE_BAD_PUBLIC_KEY) {
|
||||
fs.Errorf(nil, "The public key of this certificate does not match the private key, skiping")
|
||||
} else {
|
||||
fs.Errorf(nil, "%s", wrapError("Unknown error occurred during private key handle acquisition, skiping", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if keyFlags != windows.CERT_NCRYPT_KEY_SPEC {
|
||||
fs.LogPrintf(fs.LogLevelNotice, nil, "The certificate has no associated NCrypt key handle, skiping")
|
||||
continue
|
||||
}
|
||||
crypt.keyType, err = ncryptGetPrivateKeyType(crypt.priv)
|
||||
if err != nil {
|
||||
fs.Errorf(nil, "Failed to determine private key type, skiping: %s", err)
|
||||
continue
|
||||
}
|
||||
v, err := isKeySuitableForSigning(crypt.priv, crypt.keyType)
|
||||
if err != nil {
|
||||
fs.Errorf(nil, "Failed to determine key pair capability, skiping: %s", err)
|
||||
continue
|
||||
} else if !v {
|
||||
fs.Error(nil, "The key pair is not suitable for signing purpose, skiping")
|
||||
continue
|
||||
}
|
||||
crypt.cert = tls.Certificate{
|
||||
PrivateKey: crypt,
|
||||
Leaf: cert.cert,
|
||||
Certificate: [][]byte{cert.cert.Raw},
|
||||
}
|
||||
if crypt.keyType == keyTypeECDSA {
|
||||
crypt.cert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
|
||||
tls.ECDSAWithSHA1,
|
||||
tls.ECDSAWithP256AndSHA256,
|
||||
tls.ECDSAWithP384AndSHA384,
|
||||
tls.ECDSAWithP521AndSHA512,
|
||||
}
|
||||
} else {
|
||||
crypt.cert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
|
||||
tls.PKCS1WithSHA1,
|
||||
tls.PKCS1WithSHA256,
|
||||
tls.PKCS1WithSHA384,
|
||||
tls.PKCS1WithSHA512,
|
||||
tls.PSSWithSHA256,
|
||||
tls.PSSWithSHA384,
|
||||
tls.PSSWithSHA512,
|
||||
}
|
||||
}
|
||||
crypts = append(crypts, crypt)
|
||||
crypt = nil
|
||||
}
|
||||
if len(crypts) == 0 {
|
||||
return nil, errors.New("no valid certificate/private key pair available")
|
||||
}
|
||||
atexit.Register(func() {
|
||||
for _, crypt := range crypts {
|
||||
if crypt != nil {
|
||||
_ = crypt.Close()
|
||||
}
|
||||
}
|
||||
})
|
||||
return crypts, nil
|
||||
}
|
||||
|
||||
// Public returns public key of loaded certificate
|
||||
func (w *WINCRYPT) Public() crypto.PublicKey {
|
||||
return w.cert.Leaf.PublicKey
|
||||
}
|
||||
|
||||
// TLSCertificate returns client TLS certificate with CNG private key
|
||||
func (w *WINCRYPT) TLSCertificate() tls.Certificate {
|
||||
return w.cert
|
||||
}
|
||||
|
||||
func goHashToNCryptHash(h crypto.Hash) (alg *uint16, err error) {
|
||||
err = nil
|
||||
switch h {
|
||||
case crypto.SHA1:
|
||||
alg = &BCRYPT_SHA1_ALGORITHM[0]
|
||||
case crypto.SHA256:
|
||||
alg = &BCRYPT_SHA256_ALGORITHM[0]
|
||||
case crypto.SHA384:
|
||||
alg = &BCRYPT_SHA384_ALGORITHM[0]
|
||||
case crypto.SHA512:
|
||||
alg = &BCRYPT_SHA512_ALGORITHM[0]
|
||||
default:
|
||||
err = fmt.Errorf("no suitable hash algorithm identifier found for %v", h)
|
||||
}
|
||||
return alg, err
|
||||
}
|
||||
|
||||
// Sign signs "digest" by CNG private key represented by CNG key handle
|
||||
func (w *WINCRYPT) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
||||
var paddingInfo uintptr
|
||||
var signFlags uint32
|
||||
signingType := "ECDSA"
|
||||
if w.keyType == keyTypeRSA {
|
||||
pss, ok := opts.(*rsa.PSSOptions)
|
||||
|
||||
alg, err := goHashToNCryptHash(opts.HashFunc())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
signingType = "RSA PSS"
|
||||
pssPad := BCRYPT_PSS_PADDING_INFO{}
|
||||
pssPad.pszAlgId = alg
|
||||
var cbSalt uint32
|
||||
if pss.SaltLength == rsa.PSSSaltLengthEqualsHash || pss.SaltLength == rsa.PSSSaltLengthAuto {
|
||||
cbSalt = uint32(opts.HashFunc().Size())
|
||||
} else {
|
||||
cbSalt = uint32(pss.SaltLength)
|
||||
}
|
||||
pssPad.cbSalt = cbSalt
|
||||
signFlags = BCRYPT_PAD_PSS
|
||||
paddingInfo = castTo(&pssPad)
|
||||
} else {
|
||||
signingType = "RSA PKCS1"
|
||||
pkcs1Pad := BCRYPT_PKCS1_PADDING_INFO{}
|
||||
pkcs1Pad.pszAlgId = alg
|
||||
signFlags = BCRYPT_PAD_PKCS1
|
||||
paddingInfo = castTo(&pkcs1Pad)
|
||||
}
|
||||
} else if w.keyType != keyTypeECDSA {
|
||||
return nil, errors.New("unsupported private key type")
|
||||
}
|
||||
signature, err = ncryptSignHash(w.priv, paddingInfo, digest, uint32(len(digest)), signFlags)
|
||||
if err == syscall.Errno(windows.SCARD_W_CANCELLED_BY_USER) {
|
||||
return nil, errors.New("signing operation canceled by user")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, wrapError(signingType+" signing failed", err)
|
||||
}
|
||||
if w.keyType == keyTypeECDSA {
|
||||
signature, err = ecdsaConvertIEEEP1363ToASN1(signature)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
fs.Debugf(nil, "%s signed successfully", signingType)
|
||||
return
|
||||
}
|
||||
|
||||
// Close frees CNG key handle
|
||||
func (w *WINCRYPT) Close() error {
|
||||
if w.shouldFree && w.cert.PrivateKey != nil {
|
||||
w.cert.PrivateKey = nil
|
||||
_ = ncryptFreeObject(w.priv)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectCertificateFromUserStore(store syscall.Handle) (ctx *syscall.CertContext, err error) {
|
||||
pCtx, _, err := procCryptUIDlgSelectCertificateFromStore.Call((uintptr)(store), 0x0, 0x0, 0x0, 0x0, 0x0, 0x0)
|
||||
ctx = castFrom[syscall.CertContext](pCtx)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ncryptGetProperty(prov syscall.Handle, prop []uint16, pbOut uintptr, cbOut uint32, pcbOut *uint32, dwFlags uint32) (err error) {
|
||||
errno, _, _ := procNCryptGetProperty.Call(uintptr(prov), castTo(&prop[0]), pbOut, uintptr(cbOut), castTo(pcbOut), uintptr(dwFlags))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func cryptAcquireCertificatePrivateKey(ctx *syscall.CertContext, dwFlags uint32, prov *syscall.Handle, pdwKeySpec *uint32, shouldFree *bool) (err error) {
|
||||
var callerFree uint32
|
||||
result, _, err := procCryptAcquireCertificatePrivateKey.Call(castTo(ctx), uintptr(dwFlags), 0x0, castTo(prov), castTo(pdwKeySpec), castTo(&callerFree))
|
||||
if result == 0 {
|
||||
return err
|
||||
}
|
||||
err = nil
|
||||
*shouldFree = (callerFree == 1)
|
||||
return
|
||||
}
|
||||
|
||||
func ncryptFreeObject(obj syscall.Handle) (err error) {
|
||||
errno, _, _ := procNCryptFreeObject.Call(uintptr(obj))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ncryptSignHash(prov syscall.Handle, paddingInfo uintptr, hash []byte, cbHash uint32, dwFlags uint32) (signature []byte, err error) {
|
||||
var sigSize uint32
|
||||
errno, _, _ := procNCryptSignHash.Call(uintptr(prov), 0x0, castTo(&hash[0]), uintptr(cbHash), 0x0, 0x0, castTo(&sigSize), uintptr(dwFlags))
|
||||
err = syscall.Errno(errno)
|
||||
if err != windows.ERROR_SUCCESS {
|
||||
return nil, wrapError("failed to get signature size", err)
|
||||
}
|
||||
signature = make([]byte, sigSize)
|
||||
errno, _, _ = procNCryptSignHash.Call(uintptr(prov), paddingInfo, castTo(&hash[0]), uintptr(cbHash), castTo(&signature[0]), uintptr(sigSize), castTo(&sigSize), uintptr(dwFlags))
|
||||
err = syscall.Errno(errno)
|
||||
if err == windows.ERROR_SUCCESS {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ecdsaConvertIEEEP1363ToASN1(src []byte) ([]byte, error) {
|
||||
// R and S
|
||||
var sigs = [2]*big.Int{new(big.Int), new(big.Int)}
|
||||
sigs[0].SetBytes(src[:len(src)/2])
|
||||
sigs[1].SetBytes(src[len(src)/2:])
|
||||
return asn1.Marshal(sigs[:])
|
||||
}
|
||||
|
||||
func ncryptGetPrivateKeyType(prov syscall.Handle) (ktype keyType, err error) {
|
||||
var propertySize uint32
|
||||
ktype = keyTypeUnspecified
|
||||
err = ncryptGetProperty(prov, NCRYPT_ALGORITHM_GROUP_PROPERTY[:], 0x0, 0x0, &propertySize, 0)
|
||||
if err != nil {
|
||||
return ktype, wrapError("failed to query algorithm group size", err)
|
||||
}
|
||||
if propertySize == 0 || propertySize&0x1 == 0x1 {
|
||||
return ktype, fmt.Errorf("invalid property size: %d", propertySize)
|
||||
}
|
||||
var alg = make([]uint16, propertySize/2)
|
||||
err = ncryptGetProperty(prov, NCRYPT_ALGORITHM_GROUP_PROPERTY[:], castTo(&alg[0]), propertySize, &propertySize, 0)
|
||||
if err != nil {
|
||||
return ktype, wrapError("failed to query algorithm group", err)
|
||||
}
|
||||
str := windows.UTF16PtrToString(&alg[0])
|
||||
fs.Debugf(nil, "Algorithm Group: %s", str)
|
||||
if reflect.DeepEqual(alg, NCRYPT_ECDH_ALGORITHM_GROUP[:]) || reflect.DeepEqual(alg, NCRYPT_ECDSA_ALGORITHM_GROUP[:]) {
|
||||
ktype = keyTypeECDSA
|
||||
} else if reflect.DeepEqual(alg, NCRYPT_RSA_ALGORITHM_GROUP[:]) {
|
||||
ktype = keyTypeRSA
|
||||
} else {
|
||||
return ktype, fmt.Errorf("unsupported private key algorithm group: %v", str)
|
||||
}
|
||||
return ktype, nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user