From e38228b5f13b2f2225c99f329336c9d92f95c632 Mon Sep 17 00:00:00 2001 From: Francis Lavoie Date: Fri, 17 Apr 2026 11:45:16 -0400 Subject: [PATCH] reverseproxy: Test that WebSockets + unix-sockets works --- caddytest/integration/reverseproxy_test.go | 131 +++++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/caddytest/integration/reverseproxy_test.go b/caddytest/integration/reverseproxy_test.go index 6e0b3dcff..63ef8b895 100644 --- a/caddytest/integration/reverseproxy_test.go +++ b/caddytest/integration/reverseproxy_test.go @@ -1,9 +1,14 @@ package integration import ( + "bufio" + "crypto/sha1" + "encoding/base64" "fmt" + "io" "net" "net/http" + "net/textproto" "os" "runtime" "strings" @@ -562,3 +567,129 @@ func TestReverseProxyHealthCheckUnixSocketWithoutPort(t *testing.T) { tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!") } + +func TestReverseProxyWebSocketUpgradeUnixSocket(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + f, err := os.CreateTemp("", "*.sock") + if err != nil { + t.Fatalf("failed to create temporary socket file: %v", err) + } + _ = os.Remove(f.Name()) + socketName := f.Name() + + backend := http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path != "/ws" { + http.NotFound(w, req) + return + } + + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + http.Error(w, "missing websocket upgrade headers", http.StatusBadRequest) + return + } + + wsKey := req.Header.Get("Sec-WebSocket-Key") + if wsKey == "" { + http.Error(w, "missing Sec-WebSocket-Key", http.StatusBadRequest) + return + } + + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacker not supported", http.StatusInternalServerError) + return + } + + conn, brw, err := hj.Hijack() + if err != nil { + return + } + defer conn.Close() + + _, _ = brw.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + _, _ = brw.WriteString("Upgrade: websocket\r\n") + _, _ = brw.WriteString("Connection: Upgrade\r\n") + _, _ = brw.WriteString("Sec-WebSocket-Accept: " + computeWebSocketAccept(wsKey) + "\r\n") + _, _ = brw.WriteString("\r\n") + _ = brw.Flush() + }), + } + + unixListener, err := net.Listen("unix", socketName) + if err != nil { + t.Fatalf("failed to listen on unix socket: %v", err) + } + go backend.Serve(unixListener) + t.Cleanup(func() { + _ = backend.Close() + _ = unixListener.Close() + _ = os.Remove(socketName) + }) + runtime.Gosched() + + tester := caddytest.NewTester(t) + tester.InitServer(fmt.Sprintf(` + { + skip_install_trust + admin localhost:2999 + http_port 9080 + https_port 9443 + grace_period 1ns + } + http://localhost:9080 { + reverse_proxy unix/%s + } + `, socketName), "caddyfile") + + conn, err := net.Dial("tcp", "127.0.0.1:9080") + if err != nil { + t.Fatalf("failed to dial caddy listener: %v", err) + } + defer conn.Close() + + wsKey := "dGhlIHNhbXBsZSBub25jZQ==" + request := strings.Join([]string{ + "GET /ws HTTP/1.1", + "Host: localhost:9080", + "Connection: Upgrade", + "Upgrade: websocket", + "Sec-WebSocket-Version: 13", + "Sec-WebSocket-Key: " + wsKey, + "", + "", + }, "\r\n") + + if _, err := io.WriteString(conn, request); err != nil { + t.Fatalf("failed to send websocket handshake request: %v", err) + } + + tpr := textproto.NewReader(bufio.NewReader(conn)) + statusLine, err := tpr.ReadLine() + if err != nil { + t.Fatalf("failed reading handshake status line: %v", err) + } + if !strings.Contains(statusLine, "101") || !strings.Contains(strings.ToLower(statusLine), "switching protocols") { + t.Fatalf("unexpected status line: %q", statusLine) + } + + headers, err := tpr.ReadMIMEHeader() + if err != nil { + t.Fatalf("failed reading handshake headers: %v", err) + } + if !strings.EqualFold(headers.Get("Upgrade"), "websocket") { + t.Fatalf("unexpected Upgrade header: %q", headers.Get("Upgrade")) + } + if !strings.Contains(strings.ToLower(headers.Get("Connection")), "upgrade") { + t.Fatalf("unexpected Connection header: %q", headers.Get("Connection")) + } +} + +func computeWebSocketAccept(wsKey string) string { + h := sha1.Sum([]byte(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + return base64.StdEncoding.EncodeToString(h[:]) +}