Merge 7b3cf413986c15db004e697e7e40314a8478f9b2 into 0b9671313b14ffe839ecbd7dd2ae5ac7f6f05db8

This commit is contained in:
MkfsSion 2025-04-11 19:29:46 +05:30 committed by GitHub
commit 171e11db09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1810 additions and 12 deletions

View File

@ -2550,6 +2550,70 @@ The `--client-key` flag is required too when using this.
This loads the PEM encoded client side private key used for mutual TLS
authentication. Used in conjunction with `--client-cert`.
### --wincrypt string
Available on Windows only.
Load certificate/key pairs from Windows certificate store using
Cryptography API: Next Generation (CNG) for
[mutual TLS authentication](https://en.wikipedia.org/wiki/Mutual_authentication),
mainly designed for Smartcard scenario.
Certificates expired or not yet valid will be filtered out.
The input string is case-insensitive in most cases except said otherwise.
Spaces at the beginning and end of the string will be discarded.
The string must be one of the following format:
- `all`
Means load all possible certificate/key pair from user personal
certificate store.
- `select`
Means open a certificate selection menu for user to choose one of
certificates desired.
- `hash:`*`hashstr1`*`,`*`hashstr2`*`,...`
Means choose certificates using the specified hex hash strings. May not
be a empty hex hash string. Multiple hash strings are sperated with `,`
character. Hash strings referring to the same certificates are ignored.
Spaces at the beginning and end of a single hash string are discarded.
Hash strings with spaces in the middle of them will not be accepted.
Supported hash algorithms are MD5, SHA-1, SHA-256, SHA-384 and SHA-512.
Any hash string specified other than the supported one will not be accepted.
Hash string matches no certificate is accepted. If any unaccepted string
found, no certificate/key pair will be loaded and result in error.
- *OpenSSL style certificate subject string*
Match certificate by specified subject string in OpenSSL style. Attribute
values are case sensitive. Special characters `/+\=` among the string can
be escaped by `\` character. Spaces in the subject string is keeped as is.
Attributes with multiple value can be specified by sperating them
with `+` character. Since spaces in begin and end of the input string are
discarded, if the last value of the last attribute value has spaces in the
end, you should add postfix `/` at the end of the input string. Attribute
name can not be empty but values can be. The matching rules described as
follows:
- Unknown attribute is ignored.
- Empty attribute means the attribute always matches regardless of whether
it's present and the value in the certificate.
- For non-empty single value attribute, it matches if the attribute value
of the certificate contains the value specified in the subject string.
- For non-empty attribute with multiple values, it matches if all values
of the attribute specified in the subject string can be found and correspond
to one of the values in the corresponding certificate fields. For example,
let there be a certificate with subject `/OU=O1+O2` and the specified subject
string is `/OU=O1+O1` it will not match as one of the values of organization
unit specified (the second "O1") does not correspond to one of organization
unit values in the certificate subject.
- If a attribute is specified multiple times, if any of its value
(or values, if it can have multiple values) matches, the attribute matches.
### --no-check-certificate=true/false ###
`--no-check-certificate` controls whether a client verifies the

View File

@ -570,6 +570,10 @@ uploads or downloads to limit the number of total connections.
`, "|", "`"),
Default: 0,
Advanced: true,
}, {
Name: "wincrypt",
Help: "Load certificate/key pairs from Windows certificate store for mutual TLS authentication",
Groups: "Networking",
}}
// ConfigInfo is filesystem config options
@ -681,6 +685,7 @@ type ConfigInfo struct {
PartialSuffix string `config:"partial_suffix"`
MetadataMapper SpaceSepList `config:"metadata_mapper"`
MaxConnections int `config:"max_connections"`
WinCrypt string `config:"wincrypt"`
}
func init() {

View File

@ -17,6 +17,7 @@ import (
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/accounting"
"github.com/rclone/rclone/lib/structs"
"github.com/rclone/rclone/lib/wincrypt"
"golang.org/x/net/publicsuffix"
)
@ -68,6 +69,17 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
}
// Load client certs
if ci.WinCrypt != "" {
crypts, err := wincrypt.LoadWinCryptCerts(ci.WinCrypt)
if err != nil || crypts == nil {
fs.Panicf(nil, "Failed to load WinCrypt certificate: %v", err)
} else {
for _, crypt := range crypts {
t.TLSClientConfig.Certificates = append(t.TLSClientConfig.Certificates, crypt.TLSCertificate())
}
}
}
if ci.ClientCert != "" || ci.ClientKey != "" {
if ci.ClientCert == "" || ci.ClientKey == "" {
fs.Fatalf(nil, "Both --client-cert and --client-key must be set")
@ -83,7 +95,7 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
fs.Fatalf(nil, "Failed to parse the certificate")
}
}
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
t.TLSClientConfig.Certificates = append(t.TLSClientConfig.Certificates, cert)
}
// Load CA certs
@ -297,20 +309,22 @@ func (t *Transport) reloadCertificates() {
return
}
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
if err != nil {
fs.Fatalf(nil, "Failed to load --client-cert/--client-key pair: %v", err)
}
// Check if we need to parse the certificate again, we need it
// for checking the expiration date
if cert.Leaf == nil {
// Leaf is always the first certificate
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if t.clientCert != "" && t.clientKey != "" {
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
if err != nil {
fs.Fatalf(nil, "Failed to parse the certificate")
fs.Fatalf(nil, "Failed to load --client-cert/--client-key pair: %v", err)
}
// Check if we need to parse the certificate again, we need it
// for checking the expiration date
if cert.Leaf == nil {
// Leaf is always the first certificate
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
fs.Fatalf(nil, "Failed to parse the certificate")
}
}
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
}
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
}
// RoundTrip implements the RoundTripper interface.

View File

@ -0,0 +1,21 @@
//go:build windows
// +build windows
package wincrypt
import (
"syscall"
)
var (
NCrypt = ncrypt
Crypt32 = crypt32
)
func NCryptFreeObject(obj syscall.Handle) error {
return ncryptFreeObject(obj)
}
func WrapError(prefix string, err error, args ...any) error {
return wrapError(prefix, err, args...)
}

View File

@ -0,0 +1,35 @@
//go:build windows
// +build windows
// Code generated by "stringer -type=rdnAttrType"; DO NOT EDIT.
package wincrypt
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[COMMONNAME-0]
_ = x[SERIALNUMBER-1]
_ = x[COUNTRYNAME-2]
_ = x[LOCALITYNAME-3]
_ = x[STATEORPROVINCENAME-4]
_ = x[STREETADDRESS-5]
_ = x[ORGANIZATIONNAME-6]
_ = x[ORGANIZATIONALUNITNAME-7]
_ = x[POSTALCODE-8]
_ = x[ATTRUNSPECIFIED-9]
}
const _rdnAttrType_name = "COMMONNAMESERIALNUMBERCOUNTRYNAMELOCALITYNAMESTATEORPROVINCENAMESTREETADDRESSORGANIZATIONNAMEORGANIZATIONALUNITNAMEPOSTALCODEATTRUNSPECIFIED"
var _rdnAttrType_index = [...]uint8{0, 10, 22, 33, 45, 64, 77, 93, 115, 125, 140}
func (i rdnAttrType) String() string {
if i < 0 || i >= rdnAttrType(len(_rdnAttrType_index)-1) {
return "rdnAttrType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _rdnAttrType_name[_rdnAttrType_index[i]:_rdnAttrType_index[i+1]]
}

View File

@ -0,0 +1,676 @@
//go:build windows
// +build windows
package wincrypt_test
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"hash"
"math/big"
"math/rand"
"net"
"net/http"
"net/http/httptest"
"strings"
"syscall"
"testing"
"time"
"unsafe"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/fshttp"
"github.com/rclone/rclone/lib/atexit"
"github.com/rclone/rclone/lib/wincrypt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
)
//nolint:revive // Windows constants
const (
CERT_NCRYPT_KEY_HANDLE_PROP_ID uint32 = 78
CERT_KEY_PROV_INFO_PROP_ID uint32 = 2
CERT_CLOSE_STORE_FORCE_FLAG uint32 = 1
NCRYPT_OVERWRITE_KEY_FLAG uint32 = 0x00000080
BCRYPT_RSAPRIVATE_MAGIC uint32 = 0x32415352
BCRYPT_ECDSA_PRIVATE_P256_MAGIC uint32 = 0x32534345
BCRYPT_ECDSA_PRIVATE_P384_MAGIC uint32 = 0x34534345
BCRYPT_ECDSA_PRIVATE_P521_MAGIC uint32 = 0x36534345
NCRYPTBUFFER_PKCS_KEY_NAME uint32 = 45
BCRYPTBUFFER_VERSION uint32 = 0
NCRYPT_SILENT_FLAG uint32 = windows.CRYPT_SILENT
)
//nolint:revive // Windows constants
var (
MS_KEY_STORAGE_PROVIDER = [...]uint16{'M', 'i', 'c', 'r', 'o', 's', 'o', 'f', 't', ' ', 'S', 'o', 'f', 't', 'w', 'a', 'r', 'e', ' ', 'K', 'e', 'y', ' ', 'S', 't', 'o', 'r', 'a', 'g', 'e', ' ', 'P', 'r', 'o', 'v', 'i', 'd', 'e', 'r', 0}
BCRYPT_RSAPRIVATE_BLOB = [...]uint16{'R', 'S', 'A', 'P', 'R', 'I', 'V', 'A', 'T', 'E', 'B', 'L', 'O', 'B', 0}
BCRYPT_ECCPRIVATE_BLOB = [...]uint16{'E', 'C', 'C', 'P', 'R', 'I', 'V', 'A', 'T', 'E', 'B', 'L', 'O', 'B', 0}
)
var (
procCertAddEncodedCertificateToStore = wincrypt.Crypt32.NewProc("CertAddEncodedCertificateToStore")
procCertDeleteCertificateFromStore = wincrypt.Crypt32.NewProc("CertDeleteCertificateFromStore")
procCertSetCertificateContextProperty = wincrypt.Crypt32.NewProc("CertSetCertificateContextProperty")
procNCryptOpenStorageProvider = wincrypt.NCrypt.NewProc("NCryptOpenStorageProvider")
procNCryptImportKey = wincrypt.NCrypt.NewProc("NCryptImportKey")
procNCryptDeleteKey = wincrypt.NCrypt.NewProc("NCryptDeleteKey")
RNG = rand.New(rand.NewSource(time.Now().Unix()))
TestHashers = []crypto.Hash{
crypto.MD5,
crypto.SHA1,
crypto.SHA256,
crypto.SHA384,
crypto.SHA512,
}
)
//nolint:revive // Windows structures
type (
NCRYPT_KEY_HANDLE uintptr
NCRYPT_PROV_HANDLE uintptr
BCryptBuffer struct {
cbBuffer uint32
bufferType uint32
pvBuffer uintptr
}
BCryptBufferDesc struct {
ulVersion uint32
cBuffers uint32
pBuffers uintptr
}
CRYPT_KEY_PROV_INFO struct {
pwszContainerName *uint16
pwszProvName *uint16
dwProvType uint32
dwFlags uint32
cProvParam uint32
rgProvParam uintptr
dwKeySpec uint32
}
)
type keyPair struct {
cert *x509.Certificate
key crypto.Signer
certCtx *syscall.CertContext
privHandle NCRYPT_KEY_HANDLE
}
type testHash struct {
algo crypto.Hash
hash string
}
type testData struct {
name string
isECDSA bool
rsaKeySize uint32 // if RSA
curve elliptic.Curve // if ECDSA
hashes []testHash
}
type testCert struct {
data *testData
pair *keyPair
}
func castTo[T any](ptr *T) uintptr {
return uintptr(unsafe.Pointer(ptr))
}
func ncryptOpenStorageProvider(handle *NCRYPT_PROV_HANDLE, provider []uint16, flags uint32) (err error) {
errno, _, _ := procNCryptOpenStorageProvider.Call(castTo(handle), castTo(&provider[0]), uintptr(flags))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return err
}
func ncryptImportKey(prov NCRYPT_PROV_HANDLE, importKey NCRYPT_KEY_HANDLE, blobType []uint16, params uintptr, handle *NCRYPT_KEY_HANDLE, data []byte, flags uint32) (err error) {
errno, _, _ := procNCryptImportKey.Call(uintptr(prov), uintptr(importKey), castTo(&blobType[0]), params, castTo(handle), castTo(&data[0]), uintptr(len(data)), uintptr(flags))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return err
}
func ncryptDeleteKey(key NCRYPT_KEY_HANDLE, flags uint32) (err error) {
errno, _, _ := procNCryptDeleteKey.Call(uintptr(key), uintptr(flags))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return err
}
func getBCryptBlobType(priv crypto.Signer) ([]uint16, error) {
if _, ok := priv.(*rsa.PrivateKey); ok {
return BCRYPT_RSAPRIVATE_BLOB[:], nil
}
if eckey, ok := priv.(*ecdsa.PrivateKey); ok {
switch eckey.Curve {
case elliptic.P256():
fallthrough
case elliptic.P384():
fallthrough
case elliptic.P521():
return BCRYPT_ECCPRIVATE_BLOB[:], nil
}
}
return nil, errors.New("unsupported private key type")
}
func certAddEncodedCertificateToStore(store syscall.Handle, encodingType uint32, encoded []byte, addDisposition uint32) (ctx *syscall.CertContext, err error) {
result, _, err := procCertAddEncodedCertificateToStore.Call(uintptr(store), uintptr(encodingType), castTo(&encoded[0]), uintptr(len(encoded)), uintptr(addDisposition), castTo(&ctx))
if result != 0 {
err = nil
}
return
}
func certDeleteCertificateFromStore(ctx *syscall.CertContext) (err error) {
result, _, err := procCertDeleteCertificateFromStore.Call(castTo(ctx))
if result != 0 {
err = nil
}
return
}
func certSetCertificateContextProperty(ctx *syscall.CertContext, propID uint32, flags uint32, data uintptr) (err error) {
result, _, err := procCertSetCertificateContextProperty.Call(castTo(ctx), uintptr(propID), uintptr(flags), data)
if result != 0 {
err = nil
}
return
}
func getX509CertHash(hasher hash.Hash, cert *x509.Certificate) string {
hasher.Reset()
hasher.Write(cert.Raw)
return hex.EncodeToString(hasher.Sum(nil))
}
func makeCert(caCert *x509.Certificate, caPriv crypto.Signer, priv crypto.Signer, tmpl *x509.Certificate, validity time.Duration) (*x509.Certificate, error) {
notBefore := time.Now().UTC()
notAfter := notBefore.UTC().Add(validity)
clientTmpl := x509.Certificate{
SerialNumber: tmpl.SerialNumber,
Subject: tmpl.Subject,
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
IPAddresses: tmpl.IPAddresses,
DNSNames: tmpl.DNSNames,
}
derBytes, err := x509.CreateCertificate(RNG, &clientTmpl, caCert, priv.Public(), caPriv)
if err != nil {
return nil, err
}
return x509.ParseCertificate(derBytes)
}
func encodeBNBE(bn *big.Int, size int) []byte {
data := make([]byte, size)
return bn.FillBytes(data)
}
func toBCryptRSABlob(key *rsa.PrivateKey) ([]byte, error) {
bnPubExp := big.NewInt(int64(key.PublicKey.E))
pubExp := encodeBNBE(bnPubExp, len(bnPubExp.Bytes()))
cbN := (key.N.BitLen() + 7) / 8
N := encodeBNBE(key.N, cbN)
Prime1 := encodeBNBE(key.Primes[0], len(key.Primes[0].Bytes()))
Prime2 := encodeBNBE(key.Primes[1], len(key.Primes[1].Bytes()))
// See BCRYPT_RSAKEY_BLOB structure and remarks in MS documentation
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.NativeEndian, BCRYPT_RSAPRIVATE_MAGIC)
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, uint32(key.N.BitLen()))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, uint32(len(pubExp)))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, uint32(cbN))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, uint32(len(Prime1)))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, uint32(len(Prime2)))
if err != nil {
return nil, err
}
// Endianness does not apply to []byte
err = binary.Write(buf, binary.NativeEndian, pubExp)
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, N)
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, Prime1)
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, Prime2)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func toBCryptECCBlob(key *ecdsa.PrivateKey) ([]byte, error) {
// See BCRYPT_ECCKEY_BLOB structure and remarks in MS documentation
buf := new(bytes.Buffer)
var err error
switch key.Curve {
case elliptic.P256():
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P256_MAGIC)
case elliptic.P384():
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P384_MAGIC)
case elliptic.P521():
err = binary.Write(buf, binary.NativeEndian, BCRYPT_ECDSA_PRIVATE_P521_MAGIC)
default:
err = fmt.Errorf("unsupported ECDSA curve \"%s\"", key.Params().Name)
}
if err != nil {
return nil, err
}
cbKey := (key.Params().BitSize + 7) / 8
err = binary.Write(buf, binary.NativeEndian, uint32(cbKey))
if err != nil {
return nil, err
}
// Ditto
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.X, cbKey))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.Y, cbKey))
if err != nil {
return nil, err
}
err = binary.Write(buf, binary.NativeEndian, encodeBNBE(key.D, cbKey))
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func encodeBCryptPrivateKey(key crypto.Signer) ([]byte, error) {
if k, ok := key.(*rsa.PrivateKey); ok {
return toBCryptRSABlob(k)
}
if k, ok := key.(*ecdsa.PrivateKey); ok {
return toBCryptECCBlob(k)
}
return nil, errors.New("unsupported key type")
}
var certSerial int64
func makeCACert(commonName string, serialNumber int64) (*x509.Certificate, crypto.Signer, error) {
caPriv, err := rsa.GenerateKey(RNG, 2048)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate CA pricate key: %v", err)
}
caTmp := x509.Certificate{
SerialNumber: big.NewInt(serialNumber),
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caBytes, err := x509.CreateCertificate(RNG, &caTmp, &caTmp, &caPriv.PublicKey, caPriv)
if err != nil {
return nil, nil, fmt.Errorf("failed to create and sign CA certificate: %v", err)
}
caCert, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse created CA certificate: %v", err)
}
return caCert, caPriv, nil
}
func makeTestCerts(t *testing.T, data []*testCert, prov NCRYPT_PROV_HANDLE, store syscall.Handle) (*x509.Certificate, crypto.Signer) {
caCert, caPriv, err := makeCACert("WinCryptTestClientCA", 0xCA1)
require.NoError(t, err)
for _, item := range data {
certSerial += 1
var privKey crypto.Signer
if !item.data.isECDSA {
privKey, err = rsa.GenerateKey(RNG, int(item.data.rsaKeySize))
} else {
privKey, err = ecdsa.GenerateKey(item.data.curve, RNG)
}
require.NoError(t, err)
blobType, err := getBCryptBlobType(privKey)
require.NoError(t, err)
serial := big.NewInt(certSerial)
clientCertTmpl := x509.Certificate{
Subject: pkix.Name{
CommonName: item.data.name + "Cert",
SerialNumber: fmt.Sprintf("%X", serial),
},
SerialNumber: serial,
}
clientCert, err := makeCert(caCert, caPriv, privKey, &clientCertTmpl, 60*time.Second)
require.NoError(t, err)
cngPrivBytes, err := encodeBCryptPrivateKey(privKey)
require.NoError(t, err)
// Prepare key container name
hasher := crypto.SHA1.New()
hasher.Reset()
hasher.Write(cngPrivBytes)
name, err := windows.UTF16FromString(strings.ToUpper(hex.EncodeToString(hasher.Sum(nil))))
require.NoError(t, err)
var keyHandle NCRYPT_KEY_HANDLE
buffer := &BCryptBuffer{
cbBuffer: uint32(len(name) * 2),
bufferType: NCRYPTBUFFER_PKCS_KEY_NAME,
pvBuffer: castTo(&name[0]),
}
bufferDesc := &BCryptBufferDesc{
ulVersion: BCRYPTBUFFER_VERSION,
cBuffers: 1,
pBuffers: castTo(buffer),
}
err = ncryptImportKey(prov, NCRYPT_KEY_HANDLE(0), blobType, castTo(bufferDesc), &keyHandle, cngPrivBytes, NCRYPT_SILENT_FLAG|NCRYPT_OVERWRITE_KEY_FLAG)
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to import private key", err))
ctx, err := certAddEncodedCertificateToStore(store, windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, clientCert.Raw, windows.CERT_STORE_ADD_REPLACE_EXISTING)
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to add certificate to store", err))
keyInfo := CRYPT_KEY_PROV_INFO{
pwszContainerName: &name[0],
pwszProvName: &MS_KEY_STORAGE_PROVIDER[0],
dwProvType: 0,
dwFlags: NCRYPT_SILENT_FLAG,
cProvParam: 0,
rgProvParam: 0,
dwKeySpec: 0,
}
err = certSetCertificateContextProperty(ctx, CERT_KEY_PROV_INFO_PROP_ID, 0, castTo(&keyInfo))
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to associate certficate with private key info", err))
err = certSetCertificateContextProperty(ctx, CERT_NCRYPT_KEY_HANDLE_PROP_ID, 0, castTo(&keyHandle))
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to associate certficate with private key", err))
item.pair = &keyPair{
cert: clientCert,
key: privKey,
certCtx: ctx,
privHandle: keyHandle,
}
for _, hasher := range TestHashers {
item.data.hashes = append(item.data.hashes, testHash{algo: hasher, hash: getX509CertHash(hasher.New(), clientCert)})
}
}
return caCert, caPriv
}
func TestWinCryptCertificate(t *testing.T) {
store, err := syscall.CertOpenSystemStore(syscall.Handle(0), &wincrypt.USER_STORE_PERSONAL[0])
require.NoError(t, err)
var prov NCRYPT_PROV_HANDLE
err = ncryptOpenStorageProvider(&prov, MS_KEY_STORAGE_PROVIDER[:], 0)
require.NoErrorf(t, err, "%w", wincrypt.WrapError("Failed to open microsoft software key storage provider", err))
serverCA, serverCAPriv, err := makeCACert("WinCryptTestServerCA", 0xCA2)
require.NoError(t, err)
serverPriv, err := rsa.GenerateKey(RNG, 2048)
require.NoError(t, err)
tmpl := x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
SerialNumber: "ABC",
},
SerialNumber: big.NewInt(0xABC),
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
serverCert, err := makeCert(serverCA, serverCAPriv, serverPriv, &tmpl, time.Hour*24*365)
require.NoError(t, err)
var ts *httptest.Server
certs := []*testCert{
{
data: &testData{
name: "RSA2048TEST",
isECDSA: false,
rsaKeySize: 2048,
},
}, {
data: &testData{
name: "RSA4096TEST",
isECDSA: false,
rsaKeySize: 4096,
},
}, {
data: &testData{
name: "P256TEST",
isECDSA: true,
curve: elliptic.P256(),
},
}, {
data: &testData{
name: "P384TEST",
isECDSA: true,
curve: elliptic.P384(),
},
}, {
data: &testData{
name: "P521TEST",
isECDSA: true,
curve: elliptic.P521(),
},
},
}
t.Cleanup(func() {
// Cleanup
if ts != nil {
ts.Close()
}
atexit.Run()
for _, cert := range certs {
if cert.pair != nil {
if cert.pair.privHandle != 0 {
err := ncryptDeleteKey(cert.pair.privHandle, NCRYPT_SILENT_FLAG)
if err != nil {
_ = wincrypt.NCryptFreeObject(syscall.Handle(cert.pair.privHandle))
}
}
if cert.pair.certCtx != nil {
_ = certDeleteCertificateFromStore(cert.pair.certCtx)
}
}
}
if prov != 0 {
_ = wincrypt.NCryptFreeObject(syscall.Handle(prov))
}
if store != 0 {
_ = syscall.CertCloseStore(store, CERT_CLOSE_STORE_FORCE_FLAG)
}
})
caCert, _ := makeTestCerts(t, certs, prov, store)
clientCAs := x509.NewCertPool()
clientCAs.AddCert(caCert)
ts = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
peerCerts := r.TLS.PeerCertificates
require.Greater(t, len(peerCerts), 0, "No certificate received")
t.Logf("Received %d certificates", len(peerCerts))
for _, receivedCert := range peerCerts {
var testcert *testCert
for _, cert := range certs {
if receivedCert.SerialNumber.Cmp(cert.pair.cert.SerialNumber) == 0 {
testcert = cert
for _, hash := range testcert.data.hashes {
assert.Equalf(t, hash.hash, getX509CertHash(hash.algo.New(), receivedCert), "%s hash of test certificate \"%s\" does not match expected one", hash.algo, cert.data.name)
}
}
}
assert.NotNilf(t, testcert, "Recevied certificate with subject \"%s\" serialNumber \"%X\" does not belong to test certificates", receivedCert.Subject, receivedCert.SerialNumber)
if testcert != nil {
t.Logf("Received certificate subject \"%s\", serialNumber \"%X\"", receivedCert.Subject, receivedCert.SerialNumber)
}
}
chains, err := peerCerts[0].Verify(x509.VerifyOptions{Roots: clientCAs})
assert.NoErrorf(t, err, "Failed to verify certificate: %v", err)
if err == nil {
t.Logf("Verified, constructed %d chains", len(chains))
for i, chain := range chains {
t.Logf("Chain %d:", i)
for j, cert := range chain {
t.Logf("Certificate %d: Subject: \"%s\", SerialNumber: \"%X\"", j, cert.Subject, cert.SerialNumber)
}
}
}
// Write some test data to fulfill the request
w.Header().Set("Content-Type", "text/plain")
_, _ = fmt.Fprintln(w, "test data")
}))
serverRoot := x509.NewCertPool()
serverRoot.AddCert(serverCA)
ts.TLS = &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCAs,
RootCAs: serverRoot,
MaxVersion: tls.VersionTLS12, // TLS V1.3 forces RSA PSS signing
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{{
Certificate: [][]byte{serverCert.Raw},
PrivateKey: serverPriv,
SupportedSignatureAlgorithms: []tls.SignatureScheme{
tls.PKCS1WithSHA1,
tls.PKCS1WithSHA256,
tls.PKCS1WithSHA384,
tls.PKCS1WithSHA512,
tls.PSSWithSHA256,
tls.PSSWithSHA384,
tls.PSSWithSHA512,
},
}},
}
ts.StartTLS()
serverCAs := x509.NewCertPool()
serverCAs.AddCert(serverCA)
ctx := context.TODO()
ci := fs.GetConfig(ctx)
ci.WinCrypt = "all"
client := fshttp.NewClient(ctx)
tt := client.Transport.(*fshttp.Transport)
tt.TLSClientConfig.RootCAs = serverCAs
_, err = client.Get(ts.URL)
assert.NoError(t, err)
var builder strings.Builder
for i, cert := range certs {
ctx := context.TODO()
ci := fs.GetConfig(ctx)
ci.WinCrypt = fmt.Sprintf("/CN=%s/serialNumber=%d", cert.data.name+"Cert", i+1)
client := fshttp.NewClient(ctx)
tt := client.Transport.(*fshttp.Transport)
tt.TLSClientConfig.RootCAs = serverCAs
_, err = client.Get(ts.URL)
assert.NoError(t, err)
var loadedCerts1 []*x509.Certificate
for _, c := range tt.TLSClientConfig.Certificates {
loadedCerts1 = append(loadedCerts1, c.Leaf)
}
builder.Reset()
builder.WriteString("hash: ")
for i, hash := range cert.data.hashes {
builder.WriteString(hash.hash)
if i != len(cert.data.hashes)-1 {
builder.WriteString(", ")
}
}
hashQuery := builder.String()
// Test RSA PSS and ECDSA signing
ci.WinCrypt = hashQuery
client = fshttp.NewClient(ctx)
tt = client.Transport.(*fshttp.Transport)
tt.TLSClientConfig.RootCAs = serverCAs
_, err = client.Get(ts.URL)
assert.NoError(t, err)
var loadedCerts2 []*x509.Certificate
for _, c := range tt.TLSClientConfig.Certificates {
loadedCerts2 = append(loadedCerts2, c.Leaf)
}
assert.Equal(t, loadedCerts1, loadedCerts2)
if !cert.data.isECDSA {
ci.WinCrypt = ""
client = fshttp.NewClient(ctx)
tt = client.Transport.(*fshttp.Transport)
tt.TLSClientConfig.RootCAs = serverCAs
crypts, err := wincrypt.LoadWinCryptCerts(hashQuery)
assert.NoError(t, err)
var loadedCerts []*x509.Certificate
for _, crypt := range crypts {
tlsCert := crypt.TLSCertificate()
// Test RSA PKCS v1.5 signing
tlsCert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
tls.PKCS1WithSHA1,
tls.PKCS1WithSHA256,
tls.PKCS1WithSHA384,
tls.PKCS1WithSHA512,
}
loadedCerts = append(loadedCerts, tlsCert.Leaf)
tt.TLSClientConfig.Certificates = append(tt.TLSClientConfig.Certificates, tlsCert)
}
assert.ElementsMatch(t, loadedCerts1, loadedCerts)
_, err = client.Get(ts.URL)
assert.NoError(t, err)
}
}
}

View File

@ -0,0 +1,27 @@
//go:build !windows
// +build !windows
// Code generated by "dummy"; DO NOT EDIT. // Add autogenerated comment to make it ignored by linter
// Dummy implementation to make it build on non-Windows platform
package wincrypt
import (
"crypto/tls"
"errors"
)
type WINCRYPT struct {
}
func (*WINCRYPT) TLSCertificate() tls.Certificate {
return tls.Certificate{}
}
func LoadWinCryptCerts(string) ([]*WINCRYPT, error) {
return nil, errors.New("Cryptography API: Next Generation (CNG) is only available on Windows")
}
func (*WINCRYPT) Close() error {
return nil
}

View File

@ -0,0 +1,112 @@
//go:build windows
// +build windows
package wincrypt
import (
"crypto/x509"
"crypto/x509/pkix"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
reference = []*subjectAttr{
{
attrType: COMMONNAME,
values: [][]string{{"Test CN1"}, {"Test CN2"}},
},
{
attrType: SERIALNUMBER,
values: [][]string{{"Test serial"}},
},
{
attrType: COUNTRYNAME,
values: [][]string{{"Test C1", "Test C2"}},
},
{
attrType: LOCALITYNAME,
values: [][]string{{"Test L1", "Test L2"}},
},
{
attrType: STATEORPROVINCENAME,
values: [][]string{{"Test ST1", "Test ST2"}},
},
{
attrType: STREETADDRESS,
values: [][]string{{"Test street1", "Test street2"}},
},
{
attrType: ORGANIZATIONNAME,
values: [][]string{{"Test O1", "Test O2"}},
},
{
attrType: ORGANIZATIONALUNITNAME,
values: [][]string{{"Test OU1", "Test OU2"}},
},
{
attrType: POSTALCODE,
values: [][]string{{"Test postalcode1", "Test postalcode2"}},
},
}
)
func TestParseSubject(t *testing.T) {
subject := "/CN=Test CN1/OU=Test OU1+Test OU2/O=Test O1+Test O2/L=Test L1+Test L2/C=Test C1+Test C2/ST=Test ST1+Test ST2/street=Test street1+Test street2/serialNumber=Test serial/postalCode=Test postalcode1+Test postalcode2/CN=Test CN2"
subj1, err := parseOpenSSLSubject(subject)
require.NoError(t, err)
assert.ElementsMatch(t, reference, subj1)
subject += "/"
subj2, err := parseOpenSSLSubject(subject)
require.NoError(t, err)
assert.ElementsMatch(t, subj1, subj2)
subject += "EMAIL=Test Email"
subj2, err = parseOpenSSLSubject(subject)
require.NoError(t, err)
assert.ElementsMatch(t, reference, subj2)
_, err = parseOpenSSLSubject("/CN=CN=")
assert.Error(t, err)
}
func TestCertSubjectMatch(t *testing.T) {
attrs, err := parseOpenSSLSubject("/CN=TestCN/OU=TestOU1+TestOU1")
require.NoError(t, err)
cert := &x509.Certificate{
Subject: pkix.Name{
CommonName: "xTestCN1",
OrganizationalUnit: []string{"xTestOU1", "xTestOU2"},
},
}
assert.False(t, isCertAttributesMatch(cert, attrs))
cert.Subject.OrganizationalUnit[1] = "xTestOU1"
assert.True(t, isCertAttributesMatch(cert, attrs))
attrs, err = parseOpenSSLSubject("/CN=NoMatch")
require.NoError(t, err)
assert.False(t, isCertAttributesMatch(cert, attrs))
attrs, err = parseOpenSSLSubject("/CN=NoMatch/CN=")
require.NoError(t, err)
assert.True(t, isCertAttributesMatch(cert, attrs))
attrs, err = parseOpenSSLSubject("/CN=NoMatch/CN=Test")
require.NoError(t, err)
assert.True(t, isCertAttributesMatch(cert, attrs))
cert = &x509.Certificate{
Subject: pkix.Name{
CommonName: reference[COMMONNAME].values[0][0],
SerialNumber: reference[SERIALNUMBER].values[0][0],
Country: reference[COUNTRYNAME].values[0],
Locality: reference[LOCALITYNAME].values[0],
Province: reference[STATEORPROVINCENAME].values[0],
StreetAddress: reference[STREETADDRESS].values[0],
Organization: reference[ORGANIZATIONNAME].values[0],
OrganizationalUnit: reference[ORGANIZATIONALUNITNAME].values[0],
PostalCode: reference[POSTALCODE].values[0],
},
}
assert.True(t, isCertAttributesMatch(cert, reference))
subject := "/CN=Test CN2/OU=Test OU1+Test OU2/O=Test O1+Test O2/L=Test L1+Test L2/C=Test C1+Test C2/ST=Test ST1+Test ST2/street=Test street1+Test street2/serialNumber=Test serial/postalCode=Test postalcode1+Test postalcode2"
attrs, err = parseOpenSSLSubject(subject)
require.NoError(t, err)
assert.False(t, isCertAttributesMatch(cert, attrs))
}

View File

@ -0,0 +1,844 @@
//go:build windows
// +build windows
//go:generate stringer -type=rdnAttrType
// Package wincrypt implements loading certificate/key pairs from Windows certificate store for Mutual TLS authentication
package wincrypt
import (
"crypto"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"maps"
"math/big"
"reflect"
"slices"
"strings"
"syscall"
"time"
"unsafe"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/lib/atexit"
"golang.org/x/sys/windows"
)
//nolint:revive // Windows constants
var (
NCRYPT_KEY_USAGE_PROPERTY = [...]uint16{'K', 'e', 'y', ' ', 'U', 's', 'a', 'g', 'e', 0}
NCRYPT_ALGORITHM_GROUP_PROPERTY = [...]uint16{'A', 'l', 'g', 'o', 'r', 'i', 't', 'h', 'm', ' ', 'G', 'r', 'o', 'u', 'p', 0}
NCRYPT_ECDH_ALGORITHM_GROUP = [...]uint16{'E', 'C', 'D', 'H', 0}
NCRYPT_RSA_ALGORITHM_GROUP = [...]uint16{'R', 'S', 'A', 0}
NCRYPT_ECDSA_ALGORITHM_GROUP = [...]uint16{'E', 'C', 'D', 'S', 'A', 0}
BCRYPT_SHA1_ALGORITHM = [...]uint16{'S', 'H', 'A', '1', 0}
BCRYPT_SHA256_ALGORITHM = [...]uint16{'S', 'H', 'A', '2', '5', '6', 0}
BCRYPT_SHA384_ALGORITHM = [...]uint16{'S', 'H', 'A', '3', '8', '4', 0}
BCRYPT_SHA512_ALGORITHM = [...]uint16{'S', 'H', 'A', '5', '1', '2', 0}
USER_STORE_PERSONAL = [...]uint16{'M', 'Y', 0}
)
var (
ncrypt = syscall.NewLazyDLL("ncrypt.dll")
crypt32 = syscall.NewLazyDLL("crypt32.dll")
cryptui = syscall.NewLazyDLL("cryptui.dll")
procCryptUIDlgSelectCertificateFromStore = cryptui.NewProc("CryptUIDlgSelectCertificateFromStore")
procCryptAcquireCertificatePrivateKey = crypt32.NewProc("CryptAcquireCertificatePrivateKey")
procCertEnumCertificatesInStore = crypt32.NewProc("CertEnumCertificatesInStore")
procCertDuplicateCertificateContext = crypt32.NewProc("CertDuplicateCertificateContext")
procNCryptGetProperty = ncrypt.NewProc("NCryptGetProperty")
procNCryptSignHash = ncrypt.NewProc("NCryptSignHash")
procNCryptFreeObject = ncrypt.NewProc("NCryptFreeObject")
attrAltNames = map[rdnAttrType]string{
COMMONNAME: "CN",
COUNTRYNAME: "C",
LOCALITYNAME: "L",
STATEORPROVINCENAME: "ST",
STREETADDRESS: "STREET",
ORGANIZATIONNAME: "O",
ORGANIZATIONALUNITNAME: "OU",
}
)
//nolint:revive // Windows constants
const (
NCRYPT_ALLOW_ALL_USAGES uint32 = 0x00FFFFFF
NCRYPT_ALLOW_DECRYPT_FLAG uint32 = 0x00000001
NCRYPT_ALLOW_SIGNING_FLAG uint32 = 0x00000002
NCRYPT_ALLOW_KEY_AGREEMENT_FLAG uint32 = 0x00000004
BCRYPT_PAD_PKCS1 uint32 = 0x00000002
BCRYPT_PAD_PSS uint32 = 0x00000008
)
type keyType int
type rdnAttrType int
//nolint:revive // Constants in uppercase used for matching
const (
COMMONNAME rdnAttrType = iota
SERIALNUMBER
COUNTRYNAME
LOCALITYNAME
STATEORPROVINCENAME
STREETADDRESS
ORGANIZATIONNAME
ORGANIZATIONALUNITNAME
POSTALCODE
ATTRUNSPECIFIED
)
type subjectAttr struct {
attrType rdnAttrType
values [][]string
}
const (
keyTypeUnspecified keyType = iota
keyTypeRSA
keyTypeECDSA
)
//nolint:revive // Windows structures
type (
BCRYPT_PKCS1_PADDING_INFO struct {
pszAlgId *uint16
}
BCRYPT_PSS_PADDING_INFO struct {
pszAlgId *uint16
cbSalt uint32
}
)
// WINCRYPT is a structure encapsulates CNG key handle
type WINCRYPT struct {
cert tls.Certificate
priv syscall.Handle
keyType keyType
shouldFree bool
}
func wrapError(prefix string, err error, args ...any) error {
prefix = fmt.Sprintf(prefix, args...)
if errno, ok := err.(syscall.Errno); ok {
return fmt.Errorf("%s, errorCode=\"0x%08X\", errorMessage=\"%w\"", prefix, uint32(errno), err)
}
return fmt.Errorf("%s: %w", prefix, err)
}
func isKeySuitableForSigning(prov syscall.Handle, keyType keyType) (bool, error) {
var keyUsage uint32
var keyUsageSize uint32 = 4
err := ncryptGetProperty(prov, NCRYPT_KEY_USAGE_PROPERTY[:], castTo(&keyUsage), keyUsageSize, &keyUsageSize, 0)
if err != nil {
return false, wrapError("failed to get Key Usage information", err)
}
switch {
case keyUsage == NCRYPT_ALLOW_ALL_USAGES:
fs.Debug(nil, "All usage flag set")
return true, nil
case keyUsage&NCRYPT_ALLOW_SIGNING_FLAG != 0:
fs.Debug(nil, "Signing usage flag set")
return true, nil
case (keyUsage&NCRYPT_ALLOW_KEY_AGREEMENT_FLAG != 0) && (keyType == keyTypeECDSA):
fs.Debug(nil, "Allow ECC key with key agreement usage flag set")
return true, nil
default:
return false, errors.New("provided key is not suitable for signing purpose")
}
}
func freeCertificates(certs []*syscall.CertContext) {
if len(certs) > 0 {
for _, cert := range certs {
_ = syscall.CertFreeCertificateContext(cert)
}
}
}
func castFrom[T any](ptr uintptr) *T {
// Add another indirect level to please go vet
return (*T)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr)))
}
func castTo[T any](ptr *T) uintptr {
return uintptr(unsafe.Pointer(ptr))
}
func enumerateCertificates(store syscall.Handle) ([]*syscall.CertContext, error) {
currentCtx, _, err := procCertEnumCertificatesInStore.Call(uintptr(store), 0)
if err == syscall.Errno(windows.CRYPT_E_NOT_FOUND) {
return nil, nil
} else if err != windows.ERROR_SUCCESS {
return nil, err
}
var certs []*syscall.CertContext
for {
copyCtx, _, err := procCertDuplicateCertificateContext.Call(currentCtx)
if err != windows.ERROR_SUCCESS {
certs = append(certs, castFrom[syscall.CertContext](currentCtx))
freeCertificates(certs)
return nil, err
}
certs = append(certs, castFrom[syscall.CertContext](copyCtx))
currentCtx, _, err = procCertEnumCertificatesInStore.Call(uintptr(store), currentCtx)
if err == syscall.Errno(windows.CRYPT_E_NOT_FOUND) {
return certs, nil
} else if err != windows.ERROR_SUCCESS {
freeCertificates(certs)
return nil, err
}
}
}
// Unprintable string, used as map key only
func calculateCertHash(hasher hash.Hash, cert *x509.Certificate) string {
hasher.Reset()
hasher.Write(cert.Raw)
return string(hasher.Sum(nil))
}
func parseCertContext(ctx *syscall.CertContext) (*x509.Certificate, error) {
encoded := unsafe.Slice(ctx.EncodedCert, ctx.Length)
buf := make([]byte, ctx.Length)
copy(buf, encoded)
return x509.ParseCertificate(buf)
}
func getRDNAttrType(name string) rdnAttrType {
str := strings.ToUpper(name)
for attrType := rdnAttrType(0); attrType < ATTRUNSPECIFIED; attrType++ {
if attrType.String() == str {
return attrType
} else if altName, ok := attrAltNames[attrType]; ok && altName == str {
return attrType
}
}
return ATTRUNSPECIFIED
}
func parseOpenSSLSubject(subjectStr string) ([]*subjectAttr, error) {
subject := []rune(subjectStr)
if len(subject) == 0 {
return nil, errors.New("no subject string specified")
}
if subject[0] != '/' {
return nil, fmt.Errorf("invalid first charater of subject string, expects '/', got '%c'", subject[0])
}
attrs := make(map[rdnAttrType]*subjectAttr, int(ATTRUNSPECIFIED))
i := 1
var str strings.Builder
noTypeError := errors.New("no RDN type string specified")
parseName := func() (string, error) {
if i >= len(subject) {
return "", noTypeError
}
str.Reset()
err:
for ; i < len(subject); i++ {
switch subject[i] {
case '/':
if str.Len() == 0 {
return "", noTypeError
}
break err
case '=':
if str.Len() == 0 {
return "", noTypeError
}
i += 1
return str.String(), nil
case '\\':
if i+1 >= len(subject) {
return "", fmt.Errorf("unexpected EOF during parsing escape character in RDN type string at index %d", i)
}
i += 1
fallthrough
default:
str.WriteRune(subject[i])
}
}
return "", fmt.Errorf("missing '=' after RDN type string \"%s\" in subject name string at index %d", str.String(), i)
}
parseValue := func() ([]string, error) {
var values []string
// attributes without value are allowed
if i >= len(subject) {
return values, nil
}
str.Reset()
out:
for ; i < len(subject); i++ {
switch subject[i] {
case '=':
return nil, fmt.Errorf("unexpected character '=' found at index %d. If it's intended, please escape it with '\\'", i)
case '/':
i += 1
break out
case '+':
if str.Len() != 0 {
values = append(values, str.String())
str.Reset()
}
case '\\':
if i+1 >= len(subject) {
return nil, fmt.Errorf("unexpected EOF during parsing escape character in value at index %d", i)
}
i += 1
fallthrough
default:
str.WriteRune(subject[i])
}
}
if str.Len() != 0 {
values = append(values, str.String())
str.Reset()
}
return values, nil
}
for i < len(subject) {
name, err := parseName()
if err != nil {
return nil, err
}
values, err := parseValue()
if err != nil {
return nil, err
}
attrType := getRDNAttrType(name)
if attrType == ATTRUNSPECIFIED {
fs.Errorf(nil, "Unknown RDN attribute type \"%s\", ignoring", name)
continue
}
if attrType == COMMONNAME || attrType == SERIALNUMBER {
if len(values) > 1 {
return nil, fmt.Errorf("attribute type \"%s\" should be single value", attrType.String())
}
}
if attr, ok := attrs[attrType]; ok {
attr.values = append(attr.values, values)
} else {
attrs[attrType] = &subjectAttr{attrType: attrType, values: [][]string{values}}
}
}
return slices.Collect(maps.Values(attrs)), nil
}
type certPair struct {
cert *x509.Certificate
ctx *syscall.CertContext
}
func matchCertificatesByHash(availablePairs []*certPair, hashes string) ([]*certPair, error) {
if len(availablePairs) == 0 {
return availablePairs, nil
}
hashStrs := strings.Split(hashes, ",")
hashTypes := make(map[crypto.Hash]bool)
validPairs := make(map[*certPair]bool, len(availablePairs))
for i := 0; i < len(hashStrs); i++ {
hashStrs[i] = strings.TrimSpace(hashStrs[i])
hashStr := hashStrs[i]
if hashStr == "" {
return nil, fmt.Errorf("invalid hash string format: \"%s\"", hashes)
}
switch len(hashStr) {
case 32:
hashTypes[crypto.MD5] = true
case 40:
hashTypes[crypto.SHA1] = true
case 64:
hashTypes[crypto.SHA256] = true
case 96:
hashTypes[crypto.SHA384] = true
case 128:
hashTypes[crypto.SHA512] = true
default:
return nil, fmt.Errorf("hash string \"%s\" is not a MD5/SHA-1/SHA-256/SHA-384/SHA-512 hex string", hashStr)
}
}
certHashTable := make(map[string][]*certPair)
var hashers []hash.Hash
for k := range hashTypes {
hashers = append(hashers, k.New())
}
for _, pair := range availablePairs {
for _, hasher := range hashers {
hashBytes := calculateCertHash(hasher, pair.cert)
entries, ok := certHashTable[hashBytes]
if !ok {
certHashTable[hashBytes] = []*certPair{pair}
} else {
// Duplicate hash found
entries = append(entries, pair)
certHashTable[hashBytes] = entries
}
}
}
for _, hashStr := range hashStrs {
hash, err := hex.DecodeString(hashStr)
if err != nil {
return nil, err
}
entries, ok := certHashTable[string(hash)]
if !ok {
fs.Errorf(nil, "No certificate has hash \"%s\"", hashStr)
continue
}
for _, entry := range entries {
if _, ok := validPairs[entry]; !ok {
validPairs[entry] = true
} else {
fs.LogPrintf(fs.LogLevelWarning, nil, "Found duplicate certificate with hash \"%s\", ignoring", hashStr)
}
}
}
return slices.Collect(maps.Keys(validPairs)), nil
}
func isCertAttributesMatch(cert *x509.Certificate, attrs []*subjectAttr) bool {
valid := true
for _, attr := range attrs {
// For attribute specified multiple times, as long as one of them match, the attribute matches
attrMatch := false
for _, attrValues := range attr.values {
// empty attributes always match
if len(attrValues) == 0 {
attrMatch = true
break
}
var targetStrs []string
switch attr.attrType {
case COMMONNAME:
targetStrs = []string{cert.Subject.CommonName}
case SERIALNUMBER:
targetStrs = []string{cert.Subject.SerialNumber}
case COUNTRYNAME:
targetStrs = cert.Subject.Country
case LOCALITYNAME:
targetStrs = cert.Subject.Locality
case STATEORPROVINCENAME:
targetStrs = cert.Subject.Province
case STREETADDRESS:
targetStrs = cert.Subject.StreetAddress
case ORGANIZATIONNAME:
targetStrs = cert.Subject.Organization
case ORGANIZATIONALUNITNAME:
targetStrs = cert.Subject.OrganizationalUnit
case POSTALCODE:
targetStrs = cert.Subject.PostalCode
}
// always won't match
if len(attrValues) > len(targetStrs) || len(targetStrs) == 0 {
continue
}
attrValueMatch := true
matched := make(map[int]bool, len(targetStrs))
for _, value := range attrValues {
match := false
for i, target := range targetStrs {
if _, ok := matched[i]; ok {
continue
}
if strings.Contains(target, value) {
matched[i] = true
match = true
break
}
}
// if any attribute value does not match, the whole match fails
if !match {
attrValueMatch = false
break
}
}
if attrValueMatch {
attrMatch = true
break
}
}
if !attrMatch {
valid = false
break
}
}
return valid
}
func matchCertificatesBySubject(availablePairs []*certPair, subject string) ([]*certPair, error) {
if len(availablePairs) == 0 {
return availablePairs, nil
}
attrs, err := parseOpenSSLSubject(subject)
if err != nil {
return nil, err
}
var validPairs []*certPair
for _, pair := range availablePairs {
if isCertAttributesMatch(pair.cert, attrs) {
validPairs = append(validPairs, pair)
}
}
return validPairs, nil
}
func filterInvalidCerts(pairs []*certPair) (result []*certPair) {
now := time.Now().UTC()
for _, pair := range pairs {
if pair.cert.NotAfter.UTC().After(now) && pair.cert.NotBefore.UTC().Before(now) {
result = append(result, pair)
}
}
return
}
/*
LoadWinCryptCerts loads certificate/key pairs from Windows certificate store matche specified criteria.
Please check documentation for the detailed format of "criteria" string.
*/
func LoadWinCryptCerts(criteria string) (crypts []*WINCRYPT, err error) {
criteria = strings.TrimSpace(criteria)
store, err := syscall.CertOpenSystemStore(syscall.Handle(0), &USER_STORE_PERSONAL[0])
if err != nil {
return nil, wrapError("failed to open user personal certificate store", err)
}
defer func() { _ = syscall.CertCloseStore(store, 0) }()
var availableCerts []*certPair
availableCertCtxs := []*syscall.CertContext{}
defer func() { freeCertificates(availableCertCtxs) }()
matchFunc := func(pairs []*certPair, _ string) ([]*certPair, error) {
return pairs, nil
}
if strings.EqualFold(criteria, "select") {
ctx, err := selectCertificateFromUserStore(store)
if err != nil {
return nil, wrapError("error occurred trying to get selected certificate handle", err)
}
// User canceled selection
if ctx == nil {
return nil, errors.New("no WinCrypt certificate was chosen")
}
availableCertCtxs = append(availableCertCtxs, ctx)
} else {
certCtxs, err := enumerateCertificates(store)
if err != nil {
return nil, wrapError("failed to enumerate certificate in user personal certificate store", err)
}
if len(certCtxs) == 0 {
return nil, errors.New("no certificate found in user personal certificate store")
}
availableCertCtxs = append(availableCertCtxs, certCtxs...)
if !strings.EqualFold(criteria, "all") {
if len(criteria) >= 5 && strings.EqualFold(criteria[0:5], "hash:") {
if len(criteria) < 6 {
return nil, errors.New("no hash specified")
}
criteria = criteria[5:]
matchFunc = matchCertificatesByHash
} else {
matchFunc = matchCertificatesBySubject
}
}
}
for _, ctx := range availableCertCtxs {
cert, err := parseCertContext(ctx)
if err != nil {
fs.Errorf(nil, "failed to parse the certificate, skiping: %s", err)
continue
}
availableCerts = append(availableCerts, &certPair{cert, ctx})
}
availableCerts, err = matchFunc(availableCerts, criteria)
if err != nil {
return nil, err
}
availableCerts = filterInvalidCerts(availableCerts)
if len(availableCerts) == 0 {
return nil, fmt.Errorf("no certificate found matches the specified criteria \"%s\"", criteria)
}
for _, cert := range availableCerts {
crypt := new(WINCRYPT)
defer func() {
if crypt != nil {
_ = crypt.Close()
}
}()
fs.Debugf(nil, "Certificate Subject: %s, SerialNumber: %X", cert.cert.Subject.String(), cert.cert.SerialNumber)
var keyFlags uint32
err = cryptAcquireCertificatePrivateKey(cert.ctx, windows.CRYPT_ACQUIRE_COMPARE_KEY_FLAG|windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, &(crypt.priv), &keyFlags, &crypt.shouldFree)
if err != nil {
if errors.Unwrap(err) != nil {
err = errors.Unwrap(err)
}
if err == syscall.Errno(windows.CRYPT_E_NO_KEY_PROPERTY) {
fs.LogPrintf(fs.LogLevelNotice, nil, "The certificate has no associated private key, skiping")
} else if err == syscall.Errno(windows.NTE_BAD_PUBLIC_KEY) {
fs.Errorf(nil, "The public key of this certificate does not match the private key, skiping")
} else {
fs.Errorf(nil, "%s", wrapError("Unknown error occurred during private key handle acquisition, skiping", err))
}
continue
}
if keyFlags != windows.CERT_NCRYPT_KEY_SPEC {
fs.LogPrintf(fs.LogLevelNotice, nil, "The certificate has no associated NCrypt key handle, skiping")
continue
}
crypt.keyType, err = ncryptGetPrivateKeyType(crypt.priv)
if err != nil {
fs.Errorf(nil, "Failed to determine private key type, skiping: %s", err)
continue
}
v, err := isKeySuitableForSigning(crypt.priv, crypt.keyType)
if err != nil {
fs.Errorf(nil, "Failed to determine key pair capability, skiping: %s", err)
continue
} else if !v {
fs.Error(nil, "The key pair is not suitable for signing purpose, skiping")
continue
}
crypt.cert = tls.Certificate{
PrivateKey: crypt,
Leaf: cert.cert,
Certificate: [][]byte{cert.cert.Raw},
}
if crypt.keyType == keyTypeECDSA {
crypt.cert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
tls.ECDSAWithSHA1,
tls.ECDSAWithP256AndSHA256,
tls.ECDSAWithP384AndSHA384,
tls.ECDSAWithP521AndSHA512,
}
} else {
crypt.cert.SupportedSignatureAlgorithms = []tls.SignatureScheme{
tls.PKCS1WithSHA1,
tls.PKCS1WithSHA256,
tls.PKCS1WithSHA384,
tls.PKCS1WithSHA512,
tls.PSSWithSHA256,
tls.PSSWithSHA384,
tls.PSSWithSHA512,
}
}
crypts = append(crypts, crypt)
crypt = nil
}
if len(crypts) == 0 {
return nil, errors.New("no valid certificate/private key pair available")
}
atexit.Register(func() {
for _, crypt := range crypts {
if crypt != nil {
_ = crypt.Close()
}
}
})
return crypts, nil
}
// Public returns public key of loaded certificate
func (w *WINCRYPT) Public() crypto.PublicKey {
return w.cert.Leaf.PublicKey
}
// TLSCertificate returns client TLS certificate with CNG private key
func (w *WINCRYPT) TLSCertificate() tls.Certificate {
return w.cert
}
func goHashToNCryptHash(h crypto.Hash) (alg *uint16, err error) {
err = nil
switch h {
case crypto.SHA1:
alg = &BCRYPT_SHA1_ALGORITHM[0]
case crypto.SHA256:
alg = &BCRYPT_SHA256_ALGORITHM[0]
case crypto.SHA384:
alg = &BCRYPT_SHA384_ALGORITHM[0]
case crypto.SHA512:
alg = &BCRYPT_SHA512_ALGORITHM[0]
default:
err = fmt.Errorf("no suitable hash algorithm identifier found for %v", h)
}
return alg, err
}
// Sign signs "digest" by CNG private key represented by CNG key handle
func (w *WINCRYPT) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
var paddingInfo uintptr
var signFlags uint32
signingType := "ECDSA"
if w.keyType == keyTypeRSA {
pss, ok := opts.(*rsa.PSSOptions)
alg, err := goHashToNCryptHash(opts.HashFunc())
if err != nil {
return nil, err
}
if ok {
signingType = "RSA PSS"
pssPad := BCRYPT_PSS_PADDING_INFO{}
pssPad.pszAlgId = alg
var cbSalt uint32
if pss.SaltLength == rsa.PSSSaltLengthEqualsHash || pss.SaltLength == rsa.PSSSaltLengthAuto {
cbSalt = uint32(opts.HashFunc().Size())
} else {
cbSalt = uint32(pss.SaltLength)
}
pssPad.cbSalt = cbSalt
signFlags = BCRYPT_PAD_PSS
paddingInfo = castTo(&pssPad)
} else {
signingType = "RSA PKCS1"
pkcs1Pad := BCRYPT_PKCS1_PADDING_INFO{}
pkcs1Pad.pszAlgId = alg
signFlags = BCRYPT_PAD_PKCS1
paddingInfo = castTo(&pkcs1Pad)
}
} else if w.keyType != keyTypeECDSA {
return nil, errors.New("unsupported private key type")
}
signature, err = ncryptSignHash(w.priv, paddingInfo, digest, uint32(len(digest)), signFlags)
if err == syscall.Errno(windows.SCARD_W_CANCELLED_BY_USER) {
return nil, errors.New("signing operation canceled by user")
}
if err != nil {
return nil, wrapError(signingType+" signing failed", err)
}
if w.keyType == keyTypeECDSA {
signature, err = ecdsaConvertIEEEP1363ToASN1(signature)
if err != nil {
return nil, err
}
}
fs.Debugf(nil, "%s signed successfully", signingType)
return
}
// Close frees CNG key handle
func (w *WINCRYPT) Close() error {
if w.shouldFree && w.cert.PrivateKey != nil {
w.cert.PrivateKey = nil
_ = ncryptFreeObject(w.priv)
}
return nil
}
func selectCertificateFromUserStore(store syscall.Handle) (ctx *syscall.CertContext, err error) {
pCtx, _, err := procCryptUIDlgSelectCertificateFromStore.Call((uintptr)(store), 0x0, 0x0, 0x0, 0x0, 0x0, 0x0)
ctx = castFrom[syscall.CertContext](pCtx)
if err == windows.ERROR_SUCCESS {
err = nil
}
return
}
func ncryptGetProperty(prov syscall.Handle, prop []uint16, pbOut uintptr, cbOut uint32, pcbOut *uint32, dwFlags uint32) (err error) {
errno, _, _ := procNCryptGetProperty.Call(uintptr(prov), castTo(&prop[0]), pbOut, uintptr(cbOut), castTo(pcbOut), uintptr(dwFlags))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return err
}
func cryptAcquireCertificatePrivateKey(ctx *syscall.CertContext, dwFlags uint32, prov *syscall.Handle, pdwKeySpec *uint32, shouldFree *bool) (err error) {
var callerFree uint32
result, _, err := procCryptAcquireCertificatePrivateKey.Call(castTo(ctx), uintptr(dwFlags), 0x0, castTo(prov), castTo(pdwKeySpec), castTo(&callerFree))
if result == 0 {
return err
}
err = nil
*shouldFree = (callerFree == 1)
return
}
func ncryptFreeObject(obj syscall.Handle) (err error) {
errno, _, _ := procNCryptFreeObject.Call(uintptr(obj))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return
}
func ncryptSignHash(prov syscall.Handle, paddingInfo uintptr, hash []byte, cbHash uint32, dwFlags uint32) (signature []byte, err error) {
var sigSize uint32
errno, _, _ := procNCryptSignHash.Call(uintptr(prov), 0x0, castTo(&hash[0]), uintptr(cbHash), 0x0, 0x0, castTo(&sigSize), uintptr(dwFlags))
err = syscall.Errno(errno)
if err != windows.ERROR_SUCCESS {
return nil, wrapError("failed to get signature size", err)
}
signature = make([]byte, sigSize)
errno, _, _ = procNCryptSignHash.Call(uintptr(prov), paddingInfo, castTo(&hash[0]), uintptr(cbHash), castTo(&signature[0]), uintptr(sigSize), castTo(&sigSize), uintptr(dwFlags))
err = syscall.Errno(errno)
if err == windows.ERROR_SUCCESS {
err = nil
}
return
}
func ecdsaConvertIEEEP1363ToASN1(src []byte) ([]byte, error) {
// R and S
var sigs = [2]*big.Int{new(big.Int), new(big.Int)}
sigs[0].SetBytes(src[:len(src)/2])
sigs[1].SetBytes(src[len(src)/2:])
return asn1.Marshal(sigs[:])
}
func ncryptGetPrivateKeyType(prov syscall.Handle) (ktype keyType, err error) {
var propertySize uint32
ktype = keyTypeUnspecified
err = ncryptGetProperty(prov, NCRYPT_ALGORITHM_GROUP_PROPERTY[:], 0x0, 0x0, &propertySize, 0)
if err != nil {
return ktype, wrapError("failed to query algorithm group size", err)
}
if propertySize == 0 || propertySize&0x1 == 0x1 {
return ktype, fmt.Errorf("invalid property size: %d", propertySize)
}
var alg = make([]uint16, propertySize/2)
err = ncryptGetProperty(prov, NCRYPT_ALGORITHM_GROUP_PROPERTY[:], castTo(&alg[0]), propertySize, &propertySize, 0)
if err != nil {
return ktype, wrapError("failed to query algorithm group", err)
}
str := windows.UTF16PtrToString(&alg[0])
fs.Debugf(nil, "Algorithm Group: %s", str)
if reflect.DeepEqual(alg, NCRYPT_ECDH_ALGORITHM_GROUP[:]) || reflect.DeepEqual(alg, NCRYPT_ECDSA_ALGORITHM_GROUP[:]) {
ktype = keyTypeECDSA
} else if reflect.DeepEqual(alg, NCRYPT_RSA_ALGORITHM_GROUP[:]) {
ktype = keyTypeRSA
} else {
return ktype, fmt.Errorf("unsupported private key algorithm group: %v", str)
}
return ktype, nil
}