diff --git a/caddyhttp/httpserver/https.go b/caddyhttp/httpserver/https.go index 929897145..60b0c2429 100644 --- a/caddyhttp/httpserver/https.go +++ b/caddyhttp/httpserver/https.go @@ -79,7 +79,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error { cfg.TLS.Enabled = true cfg.Addr.Scheme = "https" if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) { - _, err := caddytls.CacheManagedCertificate(cfg.Addr.Host, cfg.TLS) + _, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host) if err != nil { return err } diff --git a/caddyhttp/httpserver/mitm.go b/caddyhttp/httpserver/mitm.go index acb23244f..d83fa71cb 100644 --- a/caddyhttp/httpserver/mitm.go +++ b/caddyhttp/httpserver/mitm.go @@ -35,6 +35,11 @@ type tlsHandler struct { // Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17): // https://jhalderm.com/pub/papers/interception-ndss17.pdf func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.listener == nil { + h.next.ServeHTTP(w, r) + return + } + h.listener.helloInfosMu.RLock() info := h.listener.helloInfos[r.RemoteAddr] h.listener.helloInfosMu.RUnlock() @@ -78,63 +83,62 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.next.ServeHTTP(w, r) } +// clientHelloConn reads the ClientHello +// and stores it in the attached listener. type clientHelloConn struct { net.Conn - readHello bool listener *tlsHelloListener + readHello bool // whether ClientHello has been read + buf *bytes.Buffer } +// Read reads from c.Conn (by letting the standard library +// do the reading off the wire), with the exception of +// getting a copy of the ClientHello so it can parse it. func (c *clientHelloConn) Read(b []byte) (n int, err error) { - if !c.readHello { - // Read the header bytes. - hdr := make([]byte, 5) - n, err := io.ReadFull(c.Conn, hdr) - if err != nil { - return n, err - } - - // Get the length of the ClientHello message and read it as well. - length := uint16(hdr[3])<<8 | uint16(hdr[4]) - hello := make([]byte, int(length)) - n, err = io.ReadFull(c.Conn, hello) - if err != nil { - return n, err - } - - // Parse the ClientHello and store it in the map. - rawParsed := parseRawClientHello(hello) - c.listener.helloInfosMu.Lock() - c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed - c.listener.helloInfosMu.Unlock() - - // Since we buffered the header and ClientHello, pretend we were - // never here by lining up the buffered values to be read with a - // custom connection type, followed by the rest of the actual - // underlying connection. - mr := io.MultiReader(bytes.NewReader(hdr), bytes.NewReader(hello), c.Conn) - mc := multiConn{Conn: c.Conn, reader: mr} - - c.Conn = mc - - c.readHello = true + // if we've already read the ClientHello, pass thru + if c.readHello { + return c.Conn.Read(b) } - return c.Conn.Read(b) -} -// multiConn is a net.Conn that reads from the -// given reader instead of the wire directly. This -// is useful when some of the connection has already -// been read (like the TLS Client Hello) and the -// reader is a io.MultiReader that starts with -// the contents of the buffer. -type multiConn struct { - net.Conn - reader io.Reader -} + // we let the standard lib read off the wire for us, and + // tee that into our buffer so we can read the ClientHello + tee := io.TeeReader(c.Conn, c.buf) + n, err = tee.Read(b) + if err != nil { + return + } + if c.buf.Len() < 5 { + return // need to read more bytes for header + } -// Read reads from mc.reader. -func (mc multiConn) Read(b []byte) (n int, err error) { - return mc.reader.Read(b) + // read the header bytes + hdr := make([]byte, 5) + _, err = io.ReadFull(c.buf, hdr) + if err != nil { + return // this would be highly unusual and sad + } + + // get length of the ClientHello message and read it + length := int(uint16(hdr[3])<<8 | uint16(hdr[4])) + if c.buf.Len() < length { + return // need to read more bytes + } + hello := make([]byte, length) + _, err = io.ReadFull(c.buf, hello) + if err != nil { + return + } + c.buf = nil // buffer no longer needed + + // parse the ClientHello and store it in the map + rawParsed := parseRawClientHello(hello) + c.listener.helloInfosMu.Lock() + c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed + c.listener.helloInfosMu.Unlock() + + c.readHello = true + return } // parseRawClientHello parses data which contains the raw @@ -279,7 +283,7 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) { if err != nil { return nil, err } - helloConn := &clientHelloConn{Conn: conn, listener: l} + helloConn := &clientHelloConn{Conn: conn, listener: l, buf: new(bytes.Buffer)} return tls.Server(helloConn, l.config), nil } diff --git a/caddyhttp/httpserver/mitm_test.go b/caddyhttp/httpserver/mitm_test.go index e5c75af8a..721926972 100644 --- a/caddyhttp/httpserver/mitm_test.go +++ b/caddyhttp/httpserver/mitm_test.go @@ -84,7 +84,7 @@ func TestHeuristicFunctions(t *testing.T) { // clientHello pairs a User-Agent string to its ClientHello message. type clientHello struct { userAgent string - helloHex string + helloHex string // do NOT include the header, just the ClientHello message } // clientHellos groups samples of true (real) ClientHellos by the @@ -158,7 +158,12 @@ func TestHeuristicFunctions(t *testing.T) { }, { // IE 11 on Windows 7, this connection was intercepted by Blue Coat - helloHex: "010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100", + helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`, + }, + { + // Firefox 51.0.1 being intercepted by burp 1.7.17 + userAgent: "(TODO)", + helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`, }, }, } diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index 2006ecfcb..538c2fe15 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -31,40 +31,47 @@ type Server struct { connTimeout time.Duration // max time to wait for a connection before force stop tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine vhosts *vhostTrie - tlsConfig caddytls.ConfigGroup } // ensure it satisfies the interface var _ caddy.GracefulServer = new(Server) +var defaultALPN = []string{"h2", "http/1.1"} + +// makeTLSConfig extracts TLS settings from each site config to +// build a tls.Config usable in Caddy HTTP servers. The returned +// config will be nil if TLS is disabled for these sites. +func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) { + var tlsConfigs []*caddytls.Config + for i := range group { + if HTTP2 && len(group[i].TLS.ALPN) == 0 { + // if no application-level protocol was configured up to now, + // default to HTTP/2, then HTTP/1.1 if necessary + group[i].TLS.ALPN = defaultALPN + } + tlsConfigs = append(tlsConfigs, group[i].TLS) + } + return caddytls.MakeTLSConfig(tlsConfigs) +} + // NewServer creates a new Server instance that will listen on addr // and will serve the sites configured in group. func NewServer(addr string, group []*SiteConfig) (*Server, error) { s := &Server{ - Server: makeHTTPServer(addr, group), + Server: makeHTTPServerWithTimeouts(addr, group), vhosts: newVHostTrie(), sites: group, connTimeout: GracefulTimeout, } - s.Server.Handler = s // this is weird, but whatever - tlsh := &tlsHandler{next: s.Server.Handler} - s.Server.ConnState = func(c net.Conn, cs http.ConnState) { - // when a connection closes or is hijacked, delete its entry - // in the map, because we are done with it. - if tlsh.listener != nil { - if cs == http.StateHijacked || cs == http.StateClosed { - tlsh.listener.helloInfosMu.Lock() - delete(tlsh.listener.helloInfos, c.RemoteAddr().String()) - tlsh.listener.helloInfosMu.Unlock() - } - } - } - // Disable HTTP/2 if desired - if !HTTP2 { - s.Server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + // extract TLS settings from each site config to build + // a tls.Config, which will not be nil if TLS is enabled + tlsConfig, err := makeTLSConfig(group) + if err != nil { + return nil, err } + s.Server.TLSConfig = tlsConfig // Enable QUIC if desired if QUIC { @@ -72,41 +79,36 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) } - // Set up TLS configuration - tlsConfigs := make(caddytls.ConfigGroup) - var allConfigs []*caddytls.Config - - for _, site := range group { - - if err := site.TLS.Build(tlsConfigs); err != nil { - return nil, err - } - - tlsConfigs[site.TLS.Hostname] = site.TLS - allConfigs = append(allConfigs, site.TLS) - } - - // Check if configs are valid - if err := caddytls.CheckConfigs(allConfigs); err != nil { - return nil, err - } - - s.tlsConfig = tlsConfigs - - if caddytls.HasTLSEnabled(allConfigs) { - s.Server.TLSConfig = &tls.Config{ - GetConfigForClient: s.tlsConfig.GetConfigForClient, - GetCertificate: s.tlsConfig.GetCertificate, - } - } - - // As of Go 1.7, HTTP/2 is enabled only if NextProtos includes the string "h2" - if HTTP2 && s.Server.TLSConfig != nil && len(s.Server.TLSConfig.NextProtos) == 0 { - s.Server.TLSConfig.NextProtos = []string{"h2"} - } - + // if TLS is enabled, make sure we prepare the Server accordingly if s.Server.TLSConfig != nil { - s.Server.Handler = tlsh + // wrap the HTTP handler with a handler that does MITM detection + tlsh := &tlsHandler{next: s.Server.Handler} + s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion + + // when Serve() creates the TLS listener later, that listener should + // be adding a reference the ClientHello info to a map; this callback + // will be sure to clear out that entry when the connection closes. + s.Server.ConnState = func(c net.Conn, cs http.ConnState) { + // when a connection closes or is hijacked, delete its entry + // in the map, because we are done with it. + if tlsh.listener != nil { + if cs == http.StateHijacked || cs == http.StateClosed { + tlsh.listener.helloInfosMu.Lock() + delete(tlsh.listener.helloInfos, c.RemoteAddr().String()) + tlsh.listener.helloInfosMu.Unlock() + } + } + } + + // As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only + // if TLSConfig.NextProtos includes the string "h2" + if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 { + // some experimenting shows that this NextProtos must have at least + // one value that overlaps with the NextProtos of any other tls.Config + // that is returned from GetConfigForClient; if there is no overlap, + // the connection will fail (as of Go 1.8, Feb. 2017). + s.Server.TLSConfig.NextProtos = defaultALPN + } } // Compile custom middleware for every site (enables virtual hosting) @@ -122,6 +124,61 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { return s, nil } +// makeHTTPServerWithTimeouts makes an http.Server from the group of +// configs in a way that configures timeouts (or, if not set, it uses +// the default timeouts) by combining the configuration of each +// SiteConfig in the group. (Timeouts are important for mitigating +// slowloris attacks.) +func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server { + // find the minimum duration configured for each timeout + var min Timeouts + for _, cfg := range group { + if cfg.Timeouts.ReadTimeoutSet && + (!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) { + min.ReadTimeoutSet = true + min.ReadTimeout = cfg.Timeouts.ReadTimeout + } + if cfg.Timeouts.ReadHeaderTimeoutSet && + (!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) { + min.ReadHeaderTimeoutSet = true + min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout + } + if cfg.Timeouts.WriteTimeoutSet && + (!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) { + min.WriteTimeoutSet = true + min.WriteTimeout = cfg.Timeouts.WriteTimeout + } + if cfg.Timeouts.IdleTimeoutSet && + (!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) { + min.IdleTimeoutSet = true + min.IdleTimeout = cfg.Timeouts.IdleTimeout + } + } + + // for the values that were not set, use defaults + if !min.ReadTimeoutSet { + min.ReadTimeout = defaultTimeouts.ReadTimeout + } + if !min.ReadHeaderTimeoutSet { + min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout + } + if !min.WriteTimeoutSet { + min.WriteTimeout = defaultTimeouts.WriteTimeout + } + if !min.IdleTimeoutSet { + min.IdleTimeout = defaultTimeouts.IdleTimeout + } + + // set the final values on the server and return it + return &http.Server{ + Addr: addr, + ReadTimeout: min.ReadTimeout, + ReadHeaderTimeout: min.ReadHeaderTimeout, + WriteTimeout: min.WriteTimeout, + IdleTimeout: min.IdleTimeout, + } +} + func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.quicServer.SetQuicHeaders(w.Header()) @@ -390,62 +447,6 @@ var defaultTimeouts = Timeouts{ IdleTimeout: 2 * time.Minute, } -// makeHTTPServer makes an http.Server from the group of configs -// in a way that configures timeouts (or, if not set, it uses the -// default timeouts) and other http.Server properties by combining -// the configuration of each SiteConfig in the group. (Timeouts -// are important for mitigating slowloris attacks.) -func makeHTTPServer(addr string, group []*SiteConfig) *http.Server { - s := &http.Server{Addr: addr} - - // find the minimum duration configured for each timeout - var min Timeouts - for _, cfg := range group { - if cfg.Timeouts.ReadTimeoutSet && - (!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) { - min.ReadTimeoutSet = true - min.ReadTimeout = cfg.Timeouts.ReadTimeout - } - if cfg.Timeouts.ReadHeaderTimeoutSet && - (!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) { - min.ReadHeaderTimeoutSet = true - min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout - } - if cfg.Timeouts.WriteTimeoutSet && - (!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) { - min.WriteTimeoutSet = true - min.WriteTimeout = cfg.Timeouts.WriteTimeout - } - if cfg.Timeouts.IdleTimeoutSet && - (!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) { - min.IdleTimeoutSet = true - min.IdleTimeout = cfg.Timeouts.IdleTimeout - } - } - - // for the values that were not set, use defaults - if !min.ReadTimeoutSet { - min.ReadTimeout = defaultTimeouts.ReadTimeout - } - if !min.ReadHeaderTimeoutSet { - min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout - } - if !min.WriteTimeoutSet { - min.WriteTimeout = defaultTimeouts.WriteTimeout - } - if !min.IdleTimeoutSet { - min.IdleTimeout = defaultTimeouts.IdleTimeout - } - - // set the final values on the server - s.ReadTimeout = min.ReadTimeout - s.ReadHeaderTimeout = min.ReadHeaderTimeout - s.WriteTimeout = min.WriteTimeout - s.IdleTimeout = min.IdleTimeout - - return s -} - // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually diff --git a/caddyhttp/httpserver/server_test.go b/caddyhttp/httpserver/server_test.go index dc926c596..69d2f7453 100644 --- a/caddyhttp/httpserver/server_test.go +++ b/caddyhttp/httpserver/server_test.go @@ -92,7 +92,7 @@ func TestMakeHTTPServer(t *testing.T) { }, }, } { - actual := makeHTTPServer("127.0.0.1:9005", tc.group) + actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group) if got, want := actual.Addr, "127.0.0.1:9005"; got != want { t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got) diff --git a/caddytls/certificates.go b/caddytls/certificates.go index bc46dd078..d3aee8f0d 100644 --- a/caddytls/certificates.go +++ b/caddytls/certificates.go @@ -89,8 +89,8 @@ func getCertificate(name string) (cert Certificate, matched, defaulted bool) { // cache, flagging it as Managed and, if onDemand is true, as "OnDemand" // (meaning that it was obtained or loaded during a TLS handshake). // -// This function is safe for concurrent use. -func CacheManagedCertificate(domain string, cfg *Config) (Certificate, error) { +// This method is safe for concurrent use. +func (cfg *Config) CacheManagedCertificate(domain string) (Certificate, error) { storage, err := cfg.StorageFor(cfg.CAUrl) if err != nil { return Certificate{}, err diff --git a/caddytls/config.go b/caddytls/config.go index 840cfebdf..d6d779997 100644 --- a/caddytls/config.go +++ b/caddytls/config.go @@ -109,11 +109,11 @@ type Config struct { // Add the must staple TLS extension to the CSR generated by lego/acme MustStaple bool - // Disables HTTP2 completely - DisableHTTP2 bool + // The list of protocols to choose from for Application Layer + // Protocol Negotiation (ALPN). + ALPN []string - // Holds final tls.Config - tlsConfig *tls.Config + tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig() } // OnDemandState contains some state relevant for providing @@ -223,33 +223,20 @@ func (c *Config) StorageFor(caURL string) (Storage, error) { return s, nil } -func (cfg *Config) Build(group ConfigGroup) error { - config, err := cfg.build() - - if err != nil { - return err - } - - if config != nil { - cfg.tlsConfig = config - cfg.tlsConfig.GetCertificate = group.GetCertificate - } - - return nil - -} - -func (cfg *Config) build() (*tls.Config, error) { - config := new(tls.Config) - +// buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config +// and stores it in cfg so it can be used in servers. If TLS is disabled, +// no tls.Config is created. +func (cfg *Config) buildStandardTLSConfig() error { if !cfg.Enabled { - return nil, nil + return nil } + config := new(tls.Config) + ciphersAdded := make(map[uint16]struct{}) curvesAdded := make(map[tls.CurveID]struct{}) - // Add cipher suites + // add cipher suites for _, ciph := range cfg.Ciphers { if _, ok := ciphersAdded[ciph]; !ok { ciphersAdded[ciph] = struct{}{} @@ -259,7 +246,7 @@ func (cfg *Config) build() (*tls.Config, error) { config.PreferServerCipherSuites = cfg.PreferServerCipherSuites - // Union curves + // add curve preferences for _, curv := range cfg.CurvePreferences { if _, ok := curvesAdded[curv]; !ok { curvesAdded[curv] = struct{}{} @@ -270,8 +257,10 @@ func (cfg *Config) build() (*tls.Config, error) { config.MinVersion = cfg.ProtocolMinVersion config.MaxVersion = cfg.ProtocolMaxVersion config.ClientAuth = cfg.ClientAuth + config.NextProtos = cfg.ALPN + config.GetCertificate = cfg.GetCertificate - // Set up client authentication if enabled + // set up client authentication if enabled if config.ClientAuth != tls.NoClientCert { pool := x509.NewCertPool() clientCertsAdded := make(map[string]struct{}) @@ -286,45 +275,51 @@ func (cfg *Config) build() (*tls.Config, error) { // Any client with a certificate from this CA will be allowed to connect caCrt, err := ioutil.ReadFile(caFile) if err != nil { - return nil, err + return err } if !pool.AppendCertsFromPEM(caCrt) { - return nil, fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) + return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) } } config.ClientCAs = pool } - // Default cipher suites + // default cipher suites if len(config.CipherSuites) == 0 { config.CipherSuites = defaultCiphers } - // For security, ensure TLS_FALLBACK_SCSV is always included first + // for security, ensure TLS_FALLBACK_SCSV is always included first if len(config.CipherSuites) == 0 || config.CipherSuites[0] != tls.TLS_FALLBACK_SCSV { config.CipherSuites = append([]uint16{tls.TLS_FALLBACK_SCSV}, config.CipherSuites...) } - if cfg.DisableHTTP2 { - config.NextProtos = []string{} - } else { - config.NextProtos = []string{"h2"} - } + // store the resulting new tls.Config + cfg.tlsConfig = config - return config, nil + return nil } -// CheckConfigs checks if multiple TLS configs does not collide with each other -func CheckConfigs(configs []*Config) error { +// MakeTLSConfig makes a tls.Config from configs. The returned +// tls.Config is programmed to load the matching caddytls.Config +// based on the hostname in SNI, but that's all. +func MakeTLSConfig(configs []*Config) (*tls.Config, error) { if len(configs) == 0 { - return nil + return nil, nil } - for i, cfg := range configs { + configMap := make(configGroup) - // Can't serve TLS and not-TLS on same port + for i, cfg := range configs { + if cfg == nil { + // avoid nil pointer dereference below this loop + configs[i] = new(Config) + continue + } + + // can't serve TLS and non-TLS on same port if i > 0 && cfg.Enabled != configs[i-1].Enabled { thisConfProto, lastConfProto := "not TLS", "not TLS" if cfg.Enabled { @@ -333,26 +328,33 @@ func CheckConfigs(configs []*Config) error { if configs[i-1].Enabled { lastConfProto = "TLS" } - return fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener", + return nil, fmt.Errorf("cannot multiplex %s (%s) and %s (%s) on same listener", configs[i-1].Hostname, lastConfProto, cfg.Hostname, thisConfProto) } - if !cfg.Enabled { - continue + // convert each caddytls.Config into a tls.Config + if err := cfg.buildStandardTLSConfig(); err != nil { + return nil, err } + + // Key this config by its hostname (overwriting + // configs with the same hostname pattern); during + // TLS handshakes, configs are loaded based on + // the hostname pattern, according to client's SNI. + configMap[cfg.Hostname] = cfg } - return nil -} - -func HasTLSEnabled(configs []*Config) bool { - for _, config := range configs { - if config.Enabled { - return true - } + // Is TLS disabled? By now, we know that all + // configs agree whether it is or not, so we + // can just look at the first one. If so, + // we're done here. + if len(configs) == 0 || !configs[0].Enabled { + return nil, nil } - return false + return &tls.Config{ + GetConfigForClient: configMap.GetConfigForClient, + }, nil } // ConfigGetter gets a Config keyed by key. diff --git a/caddytls/config_test.go b/caddytls/config_test.go index 6440de43a..5fccb1bad 100644 --- a/caddytls/config_test.go +++ b/caddytls/config_test.go @@ -8,50 +8,50 @@ import ( "testing" ) -func TestMakeTLSConfigProtocolVersions(t *testing.T) { +func TestConvertTLSConfigProtocolVersions(t *testing.T) { // same min and max protocol versions - config := Config{ + config := &Config{ Enabled: true, ProtocolMinVersion: tls.VersionTLS12, ProtocolMaxVersion: tls.VersionTLS12, } - result, err := config.build() + err := config.buildStandardTLSConfig() if err != nil { t.Fatalf("Did not expect an error, but got %v", err) } - if got, want := result.MinVersion, uint16(tls.VersionTLS12); got != want { + if got, want := config.tlsConfig.MinVersion, uint16(tls.VersionTLS12); got != want { t.Errorf("Expected min version to be %x, got %x", want, got) } - if got, want := result.MaxVersion, uint16(tls.VersionTLS12); got != want { + if got, want := config.tlsConfig.MaxVersion, uint16(tls.VersionTLS12); got != want { t.Errorf("Expected max version to be %x, got %x", want, got) } } -func TestMakeTLSConfigPreferServerCipherSuites(t *testing.T) { +func TestConvertTLSConfigPreferServerCipherSuites(t *testing.T) { // prefer server cipher suites config := Config{Enabled: true, PreferServerCipherSuites: true} - result, err := config.build() + err := config.buildStandardTLSConfig() if err != nil { t.Fatalf("Did not expect an error, but got %v", err) } - if got, want := result.PreferServerCipherSuites, true; got != want { + if got, want := config.tlsConfig.PreferServerCipherSuites, true; got != want { t.Errorf("Expected PreferServerCipherSuites==%v but got %v", want, got) } } -func TestMakeTLSConfigTLSEnabledDisabled(t *testing.T) { +func TestMakeTLSConfigTLSEnabledDisabledError(t *testing.T) { // verify handling when Enabled is true and false configs := []*Config{ {Enabled: true}, {Enabled: false}, } - err := CheckConfigs(configs) + _, err := MakeTLSConfig(configs) if err == nil { t.Fatalf("Expected an error, but got %v", err) } } -func TestMakeTLSConfigCipherSuites(t *testing.T) { +func TestConvertTLSConfigCipherSuites(t *testing.T) { // ensure cipher suites are unioned and // that TLS_FALLBACK_SCSV is prepended configs := []*Config{ @@ -67,10 +67,13 @@ func TestMakeTLSConfigCipherSuites(t *testing.T) { } for i, config := range configs { - cfg, _ := config.build() - - if !reflect.DeepEqual(cfg.CipherSuites, expectedCiphers[i]) { - t.Errorf("Expected ciphers %v but got %v", expectedCiphers[i], cfg.CipherSuites) + err := config.buildStandardTLSConfig() + if err != nil { + t.Errorf("Test %d: Expected no error, got: %v", i, err) + } + if !reflect.DeepEqual(config.tlsConfig.CipherSuites, expectedCiphers[i]) { + t.Errorf("Test %d: Expected ciphers %v but got %v", + i, expectedCiphers[i], config.tlsConfig.CipherSuites) } } diff --git a/caddytls/handshake.go b/caddytls/handshake.go index eaf6422ff..52d0a0a7d 100644 --- a/caddytls/handshake.go +++ b/caddytls/handshake.go @@ -13,18 +13,19 @@ import ( // configGroup is a type that keys configs by their hostname // (hostnames can have wildcard characters; use the getConfig -// method to get a config by matching its hostname). Its -// GetCertificate function can be used with tls.Config. -type ConfigGroup map[string]*Config +// method to get a config by matching its hostname). +type configGroup map[string]*Config // getConfig gets the config by the first key match for name. // In other words, "sub.foo.bar" will get the config for "*.foo.bar" -// if that is the closest match. This function MAY return nil -// if no match is found. +// if that is the closest match. If no match is found, the first +// (random) config will be loaded, which will defer any TLS alerts +// to the certificate validation (this may or may not be ideal; +// let's talk about it if this becomes problematic). // // This function follows nearly the same logic to lookup // a hostname as the getCertificate function uses. -func (cg ConfigGroup) getConfig(name string) *Config { +func (cg configGroup) getConfig(name string) *Config { name = strings.ToLower(name) // exact match? great, let's use it @@ -42,14 +43,36 @@ func (cg ConfigGroup) getConfig(name string) *Config { } } - // as last resort, try a config that serves all names + // as a fallback, try a config that serves all names if config, ok := cg[""]; ok { return config } + // as a last resort, use a random config + // (even if the config isn't for that hostname, + // it should help us serve clients without SNI + // or at least defer TLS alerts to the cert) + for _, config := range cg { + return config + } + return nil } +// GetConfigForClient gets a TLS configuration satisfying clientHello. +// In getting the configuration, it abides the rules and settings +// defined in the Config that matches clientHello.ServerName. If no +// tls.Config is set on the matching Config, a nil value is returned. +// +// This method is safe for use as a tls.Config.GetConfigForClient callback. +func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { + config := cg.getConfig(clientHello.ServerName) + if config != nil { + return config.tlsConfig, nil + } + return nil, nil +} + // GetCertificate gets a certificate to satisfy clientHello. In getting // the certificate, it abides the rules and settings defined in the // Config that matches clientHello.ServerName. It first checks the in- @@ -58,27 +81,11 @@ func (cg ConfigGroup) getConfig(name string) *Config { // via ACME. // // This method is safe for use as a tls.Config.GetCertificate callback. -func (cg ConfigGroup) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := cg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true) +func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true) return &cert.Certificate, err } -// GetConfigForClient gets a TLS configuration satisfying clientHello. In getting -// the configuration, it abides the rules and settings defined in the -// Config that matches clientHello.ServerName. -// -// This method is safe for use as a tls.Config.GetConfigForClient callback. -func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { - - config := cg.getConfig(clientHello.ServerName) - - if config != nil { - return config.tlsConfig, nil - } - - return nil, nil -} - // getCertDuringHandshake will get a certificate for name. It first tries // the in-memory cache. If no certificate for name is in the cache, the // config most closely corresponding to name will be loaded. If that config @@ -90,21 +97,20 @@ func (cg ConfigGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls // certificate is available. // // This function is safe for concurrent use. -func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { +func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { // First check our in-memory cache to see if we've already loaded it cert, matched, defaulted := getCertificate(name) if matched { return cert, nil } - // Get the relevant TLS config for this name. If OnDemand is enabled, - // then we might be able to load or obtain a needed certificate. - cfg := cg.getConfig(name) - if cfg != nil && cfg.OnDemand && loadIfNecessary { + // If OnDemand is enabled, then we might be able to load or + // obtain a needed certificate + if cfg.OnDemand && loadIfNecessary { // Then check to see if we have one on disk - loadedCert, err := CacheManagedCertificate(name, cfg) + loadedCert, err := cfg.CacheManagedCertificate(name) if err == nil { - loadedCert, err = cg.handshakeMaintenance(name, loadedCert) + loadedCert, err = cfg.handshakeMaintenance(name, loadedCert) if err != nil { log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) } @@ -116,7 +122,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai name = strings.ToLower(name) // Make sure aren't over any applicable limits - err := cg.checkLimitsForObtainingNewCerts(name, cfg) + err := cfg.checkLimitsForObtainingNewCerts(name) if err != nil { return Certificate{}, err } @@ -127,7 +133,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai } // Obtain certificate from the CA - return cg.obtainOnDemandCertificate(name, cfg) + return cfg.obtainOnDemandCertificate(name) } } @@ -143,7 +149,7 @@ func (cg ConfigGroup) getCertDuringHandshake(name string, loadIfNecessary, obtai // now according to mitigating factors we keep track of and preferences the // user has set. If a non-nil error is returned, do not issue a new certificate // for name. -func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) error { +func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error { // User can set hard limit for number of certs for the process to issue if cfg.OnDemandState.MaxObtain > 0 && atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain { @@ -167,7 +173,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since) } - // 👍Good to go + // Good to go 👍 return nil } @@ -176,7 +182,7 @@ func (cg ConfigGroup) checkLimitsForObtainingNewCerts(name string, cfg *Config) // name, it will wait and use what the other goroutine obtained. // // This function is safe for use by multiple concurrent goroutines. -func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certificate, error) { +func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) { // We must protect this process from happening concurrently, so synchronize. obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] @@ -185,7 +191,7 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi // wait for it to finish obtaining the cert and then we'll use it. obtainCertWaitChansMu.Unlock() <-wait - return cg.getCertDuringHandshake(name, true, false) + return cfg.getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and obtain the cert. @@ -228,19 +234,19 @@ func (cg ConfigGroup) obtainOnDemandCertificate(name string, cfg *Config) (Certi lastIssueTimeMu.Unlock() // certificate is already on disk; now just start over to load it and serve it - return cg.getCertDuringHandshake(name, true, false) + return cfg.getCertDuringHandshake(name, true, false) } // handshakeMaintenance performs a check on cert for expiration and OCSP // validity. // // This function is safe for use by multiple concurrent goroutines. -func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { +func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) { // Check cert expiration timeLeft := cert.NotAfter.Sub(time.Now().UTC()) if timeLeft < RenewDurationBefore { log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - return cg.renewDynamicCertificate(name, cert.Config) + return cfg.renewDynamicCertificate(name) } // Check OCSP staple validity @@ -268,7 +274,7 @@ func (cg ConfigGroup) handshakeMaintenance(name string, cert Certificate) (Certi // usable. name should already be lower-cased before calling this function. // // This function is safe for use by multiple concurrent goroutines. -func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certificate, error) { +func (cfg *Config) renewDynamicCertificate(name string) (Certificate, error) { obtainCertWaitChansMu.Lock() wait, ok := obtainCertWaitChans[name] if ok { @@ -276,7 +282,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi // wait for it to finish, then we'll use the new one. obtainCertWaitChansMu.Unlock() <-wait - return cg.getCertDuringHandshake(name, true, false) + return cfg.getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and renew the cert @@ -300,7 +306,7 @@ func (cg ConfigGroup) renewDynamicCertificate(name string, cfg *Config) (Certifi return Certificate{}, err } - return cg.getCertDuringHandshake(name, true, false) + return cfg.getCertDuringHandshake(name, true, false) } // obtainCertWaitChans is used to coordinate obtaining certs for each hostname. diff --git a/caddytls/handshake_test.go b/caddytls/handshake_test.go index b7e35b3a9..eca8dfcb3 100644 --- a/caddytls/handshake_test.go +++ b/caddytls/handshake_test.go @@ -9,7 +9,7 @@ import ( func TestGetCertificate(t *testing.T) { defer func() { certCache = make(map[string]Certificate) }() - cg := make(ConfigGroup) + cfg := new(Config) hello := &tls.ClientHelloInfo{ServerName: "example.com"} helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} @@ -17,10 +17,10 @@ func TestGetCertificate(t *testing.T) { helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} // When cache is empty - if cert, err := cg.GetCertificate(hello); err == nil { + if cert, err := cfg.GetCertificate(hello); err == nil { t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) } - if cert, err := cg.GetCertificate(helloNoSNI); err == nil { + if cert, err := cfg.GetCertificate(helloNoSNI); err == nil { t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) } @@ -28,12 +28,12 @@ func TestGetCertificate(t *testing.T) { defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} certCache[""] = defaultCert certCache["example.com"] = defaultCert - if cert, err := cg.GetCertificate(hello); err != nil { + if cert, err := cfg.GetCertificate(hello); err != nil { t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) } - if cert, err := cg.GetCertificate(helloNoSNI); err != nil { + if cert, err := cfg.GetCertificate(helloNoSNI); err != nil { t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) @@ -41,14 +41,14 @@ func TestGetCertificate(t *testing.T) { // When retrieving wildcard certificate certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} - if cert, err := cg.GetCertificate(helloSub); err != nil { + if cert, err := cfg.GetCertificate(helloSub); err != nil { t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) } else if cert.Leaf.DNSNames[0] != "*.example.com" { t.Errorf("Got wrong certificate, expected wildcard: %v", cert) } // When no certificate matches, the default is returned - if cert, err := cg.GetCertificate(helloNoMatch); err != nil { + if cert, err := cfg.GetCertificate(helloNoMatch); err != nil { t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) } else if cert.Leaf.DNSNames[0] != "example.com" { t.Errorf("Expected default cert with no matches, got: %v", cert) diff --git a/caddytls/maintain.go b/caddytls/maintain.go index 7095d3e93..21b61e13a 100644 --- a/caddytls/maintain.go +++ b/caddytls/maintain.go @@ -152,7 +152,7 @@ func RenewManagedCertificates(allowPrompts bool) (err error) { delete(certCache, "") certCacheMu.Unlock() } - _, err := CacheManagedCertificate(cert.Names[0], cert.Config) + _, err := cert.Config.CacheManagedCertificate(cert.Names[0]) if err != nil { if allowPrompts { return err // operator is present, so report error immediately diff --git a/caddytls/setup.go b/caddytls/setup.go index 3fb9d02b1..11668b8a3 100644 --- a/caddytls/setup.go +++ b/caddytls/setup.go @@ -164,21 +164,15 @@ func setupTLS(c *caddy.Controller) error { return c.Errf("Unsupported Storage provider '%s'", args[0]) } config.StorageProvider = args[0] - - case "http2": + case "alpn": args := c.RemainingArgs() - if len(args) != 1 { + if len(args) == 0 { return c.ArgErr() } - - switch args[0] { - case "off": - config.DisableHTTP2 = true - default: - c.ArgErr() + for _, arg := range args { + config.ALPN = append(config.ALPN, arg) } - - case "muststaple": + case "must_staple": config.MustStaple = true default: return c.Errf("Unknown keyword '%s'", c.Val()) diff --git a/caddytls/setup_test.go b/caddytls/setup_test.go index ce0c9adb5..ed18bb0de 100644 --- a/caddytls/setup_test.go +++ b/caddytls/setup_test.go @@ -91,8 +91,8 @@ func TestSetupParseBasic(t *testing.T) { t.Error("Expected PreferServerCipherSuites = true, but was false") } - if cfg.DisableHTTP2 { - t.Error("Expected HTTP2 to be enabled by default") + if len(cfg.ALPN) != 0 { + t.Error("Expected ALPN empty by default") } // Ensure curve count is correct @@ -121,8 +121,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) { params := `tls ` + certFile + ` ` + keyFile + ` { protocols tls1.0 tls1.2 ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 - muststaple - http2 off + must_staple + alpn http/1.1 }` cfg := new(Config) RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg }) @@ -149,8 +149,8 @@ func TestSetupParseWithOptionalParams(t *testing.T) { t.Error("Expected must staple to be true") } - if !cfg.DisableHTTP2 { - t.Error("Expected HTTP2 to be disabled") + if len(cfg.ALPN) != 1 || cfg.ALPN[0] != "http/1.1" { + t.Errorf("Expected ALPN to contain only 'http/1.1' but got: %v", cfg.ALPN) } }