core: propagate ECH keys to the QUIC listener (#7670)

This commit is contained in:
Zen Dodd 2026-04-24 05:33:41 +10:00 committed by GitHub
parent 441d5eb062
commit 41aee97386
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 72 additions and 1 deletions

View file

@ -462,7 +462,10 @@ func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config
sqs := newSharedQUICState(tlsConf) sqs := newSharedQUICState(tlsConf)
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
//nolint:gosec //nolint:gosec
quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient} quicTlsConfig := &tls.Config{
GetConfigForClient: sqs.getConfigForClient,
GetEncryptedClientHelloKeys: sqs.getEncryptedClientHelloKeys,
}
// Require clients to verify their source address when we're handling more than 1000 handshakes per second. // Require clients to verify their source address when we're handling more than 1000 handshakes per second.
// TODO: make tunable? // TODO: make tunable?
limiter := rate.NewLimiter(1000, 1000) limiter := rate.NewLimiter(1000, 1000)
@ -540,6 +543,16 @@ func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Co
return sqs.activeTlsConf.GetConfigForClient(ch) return sqs.activeTlsConf.GetConfigForClient(ch)
} }
// getEncryptedClientHelloKeys is used as tls.Config's GetEncryptedClientHelloKeys field.
func (sqs *sharedQUICState) getEncryptedClientHelloKeys(ch *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
sqs.rmu.RLock()
defer sqs.rmu.RUnlock()
if sqs.activeTlsConf.GetEncryptedClientHelloKeys == nil {
return nil, nil
}
return sqs.activeTlsConf.GetEncryptedClientHelloKeys(ch)
}
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc // addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
// so that when cancelled, the active tls.Config will change // so that when cancelled, the active tls.Config will change
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) { func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) {

View file

@ -15,6 +15,7 @@
package caddy package caddy
import ( import (
"crypto/tls"
"reflect" "reflect"
"testing" "testing"
@ -175,6 +176,63 @@ func TestJoinNetworkAddress(t *testing.T) {
} }
} }
func TestSharedQUICStateGetEncryptedClientHelloKeys(t *testing.T) {
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
initialKeys := []tls.EncryptedClientHelloKey{{Config: []byte("initial"), PrivateKey: []byte("initial-key")}}
updatedKeys := []tls.EncryptedClientHelloKey{{Config: []byte("updated"), PrivateKey: []byte("updated-key")}}
initialConfig := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return initialKeys, nil
},
}
sqs := newSharedQUICState(initialConfig)
keys, err := sqs.getEncryptedClientHelloKeys(hello)
if err != nil {
t.Fatalf("getting initial ECH keys: %v", err)
}
if !reflect.DeepEqual(keys, initialKeys) {
t.Fatalf("unexpected initial ECH keys: got %#v, want %#v", keys, initialKeys)
}
updatedConfig := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return updatedKeys, nil
},
}
_, cancel := sqs.addState(updatedConfig)
sqs.rmu.Lock()
sqs.activeTlsConf = updatedConfig
sqs.rmu.Unlock()
keys, err = sqs.getEncryptedClientHelloKeys(hello)
if err != nil {
t.Fatalf("getting updated ECH keys: %v", err)
}
if !reflect.DeepEqual(keys, updatedKeys) {
t.Fatalf("unexpected updated ECH keys: got %#v, want %#v", keys, updatedKeys)
}
cancel(nil)
keys, err = sqs.getEncryptedClientHelloKeys(hello)
if err != nil {
t.Fatalf("getting restored ECH keys: %v", err)
}
if !reflect.DeepEqual(keys, initialKeys) {
t.Fatalf("unexpected restored ECH keys: got %#v, want %#v", keys, initialKeys)
}
}
func TestParseNetworkAddress(t *testing.T) { func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
input string input string