reverseproxy: allow specifying ip version for dynamic a upstream (#5401)

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
This commit is contained in:
Emily Lange
2023-02-27 18:23:09 +01:00
committed by GitHub
parent 096971e313
commit 941eae5f61
3 changed files with 64 additions and 5 deletions

View File

@ -213,6 +213,11 @@ func (sl srvLookup) isFresh() bool {
return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
}
type ipVersions struct {
IPv4 *bool `json:"ipv4,omitempty"`
IPv6 *bool `json:"ipv6,omitempty"`
}
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
@ -240,6 +245,11 @@ type AUpstreams struct {
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
// The IP versions to resolve for. By default, both
// "ipv4" and "ipv6" will be enabled, which
// correspond to A and AAAA records respectively.
Versions *ipVersions `json:"versions,omitempty"`
resolver *net.Resolver
}
@ -286,7 +296,29 @@ func (au *AUpstreams) Provision(_ caddy.Context) error {
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
auStr := repl.ReplaceAll(au.String(), "")
resolveIpv4 := au.Versions.IPv4 == nil || *au.Versions.IPv4
resolveIpv6 := au.Versions.IPv6 == nil || *au.Versions.IPv6
// Map ipVersion early, so we can use it as part of the cache-key.
// This should be fairly inexpensive and comes and the upside of
// allowing the same dynamic upstream (name + port combination)
// to be used multiple times with different ip versions.
//
// It also forced a cache-miss if a previously cached dynamic
// upstream changes its ip version, e.g. after a config reload,
// while keeping the cache-invalidation as simple as it currently is.
var ipVersion string
switch {
case resolveIpv4 && !resolveIpv6:
ipVersion = "ip4"
case !resolveIpv4 && resolveIpv6:
ipVersion = "ip6"
default:
ipVersion = "ip"
}
auStr := repl.ReplaceAll(au.String()+ipVersion, "")
// first, use a cheap read-lock to return a cached result quickly
aAaaaMu.RLock()
@ -311,7 +343,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
ips, err := au.resolver.LookupIPAddr(r.Context(), name)
ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name)
if err != nil {
return nil, err
}