From ac7dc490a78bf8531d126e08c08d3d7592b1a691 Mon Sep 17 00:00:00 2001 From: Ravi Kumar L Date: Thu, 2 Jul 2026 16:46:19 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=AD=20refactor:=20safe=20structured=20?= =?UTF-8?q?failure=20logging=20for=20langfuse=20fanout=20gateway=20(#14050?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cmd/langfuse-fanout/logging.go | 77 ++++++++ .../cmd/langfuse-fanout/main.go | 110 ++++++----- .../cmd/langfuse-fanout/main_test.go | 180 ++++++++++++++++++ 3 files changed, 316 insertions(+), 51 deletions(-) create mode 100644 otel/langfuse-fanout/cmd/langfuse-fanout/logging.go diff --git a/otel/langfuse-fanout/cmd/langfuse-fanout/logging.go b/otel/langfuse-fanout/cmd/langfuse-fanout/logging.go new file mode 100644 index 0000000000..0297fadbb6 --- /dev/null +++ b/otel/langfuse-fanout/cmd/langfuse-fanout/logging.go @@ -0,0 +1,77 @@ +package main + +import ( + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" +) + +type upstreamStatusError struct { + status int +} + +func (e upstreamStatusError) Error() string { + return fmt.Sprintf("upstream status %d", e.status) +} + +func (g *gateway) writeGatewayError(w http.ResponseWriter, r *http.Request, route route, operation string, status int, message string, err error, attrs ...any) { + g.logGatewayFailure(r, route, operation, status, err, attrs...) + http.Error(w, message, status) +} + +func (g *gateway) logGatewayFailure(r *http.Request, route route, operation string, status int, err error, attrs ...any) { + fields := gatewayLogFields(r, route, operation, attrs...) + fields = append(fields, "status", status) + if err != nil { + fields = append(fields, "error", safeErrorMessage(err)) + } + if status >= http.StatusInternalServerError { + slog.Error("langfuse fanout gateway request failed", fields...) + return + } + slog.Warn("langfuse fanout gateway request failed", fields...) +} + +func (g *gateway) logGatewayWarning(r *http.Request, route route, operation string, message string, attrs ...any) { + fields := gatewayLogFields(r, route, operation, attrs...) + fields = append(fields, "warning", message) + slog.Warn("langfuse fanout gateway warning", fields...) +} + +func gatewayLogFields(r *http.Request, route route, operation string, attrs ...any) []any { + fields := []any{ + "method", r.Method, + "path", normalizeMetricPath(r.URL.Path), + "operation", operation, + "destination", routeDestinationLabel(route), + } + return append(fields, attrs...) +} + +func routeDestinationLabel(route route) string { + if strings.HasPrefix(route.path, mediaUploadProxyPath) { + return "fanout" + } + if route.destination == "" { + return centralName + } + return "tenant_" + route.destination +} + +func safeErrorMessage(err error) string { + if err == nil { + return "" + } + var upstreamErr upstreamStatusError + if errors.As(err, &upstreamErr) { + return upstreamErr.Error() + } + var urlErr *url.Error + if errors.As(err, &urlErr) { + return fmt.Sprintf("%s: URL request failed", urlErr.Op) + } + return "error details redacted" +} diff --git a/otel/langfuse-fanout/cmd/langfuse-fanout/main.go b/otel/langfuse-fanout/cmd/langfuse-fanout/main.go index 147dd3e92a..fde79c27d6 100644 --- a/otel/langfuse-fanout/cmd/langfuse-fanout/main.go +++ b/otel/langfuse-fanout/cmd/langfuse-fanout/main.go @@ -265,7 +265,7 @@ func (g *gateway) handle(w http.ResponseWriter, r *http.Request) { func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route route) { body, err := readMaybeGzip(r) if err != nil { - http.Error(w, "failed to read request body", http.StatusBadRequest) + g.writeGatewayError(w, r, route, "trace_export", http.StatusBadRequest, "failed to read request body", err) return } @@ -275,7 +275,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou strings.TrimSpace(r.Header.Get("Authorization")) != "" { body, err = addTenantRouteAttributes(body, contentType, route.destination) if err != nil { - http.Error(w, "failed to add OTLP tenant routing attributes", http.StatusBadRequest) + g.writeGatewayError(w, r, route, "trace_route_attributes", http.StatusBadRequest, "failed to add OTLP tenant routing attributes", err) return } } @@ -285,7 +285,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou contentEncoding = "gzip" body, err = gzipBytes(body) if err != nil { - http.Error(w, "failed to encode request body", http.StatusInternalServerError) + g.writeGatewayError(w, r, route, "trace_gzip", http.StatusInternalServerError, "failed to encode request body", err) return } } @@ -293,7 +293,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou resp, err := g.forwardTraceToCollector(r.Context(), r.Header, body, contentType, contentEncoding) if err != nil { g.recordTraceExport(route, "error") - http.Error(w, fmt.Sprintf("trace collector export failed: %v", err), http.StatusBadGateway) + g.writeGatewayError(w, r, route, "trace_collector", http.StatusBadGateway, "trace collector export failed", err) return } defer resp.Body.Close() @@ -307,18 +307,18 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, route route) { body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 2<<20)) if err != nil { - http.Error(w, "failed to read media create request", http.StatusBadRequest) + g.writeGatewayError(w, r, route, "media_create", http.StatusBadRequest, "failed to read media create request", err) return } if err := g.cfg.uploadStore.Ping(r.Context()); err != nil { g.recordUploadPlanStoreError("ping") - http.Error(w, fmt.Sprintf("media upload plan store unavailable: %v", err), http.StatusBadGateway) + g.writeGatewayError(w, r, route, "upload_plan_ping", http.StatusBadGateway, "media upload plan store unavailable", err) return } destinations := g.mediaDestinations(route, r.Header.Get("Authorization")) if len(destinations) == 0 { - http.Error(w, "no media destinations configured", http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "no media destinations configured", nil) return } @@ -341,29 +341,26 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout wg.Wait() for _, result := range responses { if result.err != nil { - http.Error(w, fmt.Sprintf("%s media create failed: %v", result.destination.name, result.err), http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, fmt.Sprintf("%s media create failed", result.destination.name), result.err, "upstream_destination", result.destination.name) return } } mediaID := responses[0].response.MediaID if mediaID == "" { - http.Error(w, "upstream media create returned empty mediaId", http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "upstream media create returned empty mediaId", nil, "upstream_destination", responses[0].destination.name) return } // Langfuse derives mediaId from the content hash today, so all fanout // destinations should converge on the same id for the same POST body. for _, response := range responses[1:] { if response.response.MediaID != mediaID { - log.Printf( - "upstream media IDs differ: %s=%s %s=%s", - responses[0].destination.name, - mediaID, - response.destination.name, - response.response.MediaID, - ) g.recordMediaDivergence("media_id", response.destination.name) - http.Error(w, "upstream media IDs differ across destinations", http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "upstream media IDs differ across destinations", errors.New("upstream media IDs differ across destinations"), + "kind", "media_id", + "upstream_destination", response.destination.name, + "reference_destination", responses[0].destination.name, + ) return } } @@ -393,6 +390,10 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout } if hadUploadURL { for _, destination := range missingUploadURLDestinations { + g.logGatewayWarning(r, route, "media_create", "upstream media upload URL presence differs across destinations", + "kind", "upload_url_presence", + "upstream_destination", destination, + ) g.recordMediaDivergence("upload_url_presence", destination) } } @@ -401,12 +402,12 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout if len(uploadPlan.Destinations) > 0 { uploadID, err := randomID() if err != nil { - http.Error(w, "failed to create media upload id", http.StatusInternalServerError) + g.writeGatewayError(w, r, route, "upload_plan_create", http.StatusInternalServerError, "failed to create media upload id", err) return } if err := g.storeUpload(r.Context(), uploadID, uploadPlan); err != nil { g.recordUploadPlanStoreError("put") - http.Error(w, fmt.Sprintf("failed to store media upload plan: %v", err), http.StatusBadGateway) + g.writeGatewayError(w, r, route, "upload_plan_put", http.StatusBadGateway, "failed to store media upload plan", err) return } g.recordUploadPlanCreated(uploadPlan) @@ -420,13 +421,13 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route route) { body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20)) if err != nil { - http.Error(w, "failed to read media patch request", http.StatusBadRequest) + g.writeGatewayError(w, r, route, "media_patch", http.StatusBadRequest, "failed to read media patch request", err) return } destinations := g.mediaDestinations(route, r.Header.Get("Authorization")) if len(destinations) == 0 { - http.Error(w, "no media destinations configured", http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_patch", http.StatusBadGateway, "no media destinations configured", nil) return } @@ -450,7 +451,7 @@ func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route wg.Wait() for _, result := range results { if result.err != nil { - http.Error(w, fmt.Sprintf("%s media patch failed: %v", result.destination, result.err), http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_patch", http.StatusBadGateway, fmt.Sprintf("%s media patch failed", result.destination), result.err, "upstream_destination", result.destination) return } } @@ -460,16 +461,16 @@ func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route func (g *gateway) handleMediaGet(w http.ResponseWriter, r *http.Request, route route) { destinations := g.mediaDestinations(route, r.Header.Get("Authorization")) if len(destinations) == 0 { - http.Error(w, "no media destinations configured", http.StatusBadGateway) + g.writeGatewayError(w, r, route, "media_get", http.StatusBadGateway, "no media destinations configured", nil) return } target := destinations[0] if route.destination != "" { target = destinations[len(destinations)-1] } - resp := g.getMedia(r.Context(), target, route.path, r.URL.RawQuery) - if resp == nil { - http.Error(w, "media get failed", http.StatusBadGateway) + resp, err := g.getMedia(r.Context(), target, route.path, r.URL.RawQuery) + if err != nil { + g.writeGatewayError(w, r, route, "media_get", http.StatusBadGateway, "media get failed", err, "upstream_destination", target.name) return } defer resp.Body.Close() @@ -478,21 +479,21 @@ func (g *gateway) handleMediaGet(w http.ResponseWriter, r *http.Request, route r _, _ = io.Copy(w, resp.Body) } -func (g *gateway) getMedia(ctx context.Context, target destination, path string, rawQuery string) *http.Response { +func (g *gateway) getMedia(ctx context.Context, target destination, path string, rawQuery string) (*http.Response, error) { upstreamURL := target.baseURL + path if rawQuery != "" { upstreamURL += "?" + rawQuery } req, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL, nil) if err != nil { - return nil + return nil, err } req.Header.Set("Authorization", target.authorization) resp, err := g.doUpstream(req, "media_get", target.name) if err != nil { - return nil + return nil, err } - return resp + return resp, nil } func (g *gateway) handleMetrics(w http.ResponseWriter, r *http.Request) { @@ -520,21 +521,23 @@ func (g *gateway) handleMediaUpload(w http.ResponseWriter, r *http.Request) { plan, ok, err := g.takeUpload(r.Context(), uploadID) if err != nil { g.recordUploadPlanStoreError("take") - http.Error(w, "failed to load media upload plan", http.StatusBadGateway) + g.writeGatewayError(w, r, route{path: r.URL.Path}, "upload_plan_take", http.StatusBadGateway, "failed to load media upload plan", err) return } if !ok { g.recordUploadPlanMiss() - http.Error(w, "unknown or expired upload", http.StatusNotFound) + g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusNotFound, "unknown or expired upload", nil) return } body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxUploadBytes(plan.ContentLength))) if err != nil { - if err := g.restoreUpload(r.Context(), uploadID, plan); err != nil { + attrs := []any{} + if restoreErr := g.restoreUpload(r.Context(), uploadID, plan); restoreErr != nil { g.recordUploadPlanStoreError("restore") + attrs = append(attrs, "restore_error", safeErrorMessage(restoreErr)) } - http.Error(w, "failed to read upload body", http.StatusBadRequest) + g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusBadRequest, "failed to read upload body", err, attrs...) return } @@ -558,10 +561,12 @@ func (g *gateway) handleMediaUpload(w http.ResponseWriter, r *http.Request) { status := http.StatusOK for _, result := range results { if result.err != nil { - if err := g.restoreUpload(r.Context(), uploadID, plan); err != nil { + attrs := []any{"upstream_destination", result.destination} + if restoreErr := g.restoreUpload(r.Context(), uploadID, plan); restoreErr != nil { g.recordUploadPlanStoreError("restore") + attrs = append(attrs, "restore_error", safeErrorMessage(restoreErr)) } - http.Error(w, fmt.Sprintf("%s upload failed: %v", result.destination, result.err), http.StatusBadGateway) + g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusBadGateway, fmt.Sprintf("%s upload failed", result.destination), result.err, attrs...) return } if result.status > status { @@ -606,8 +611,8 @@ func (g *gateway) forwardTraceToCollector(ctx context.Context, headers http.Head } if resp.StatusCode < 200 || resp.StatusCode >= 300 { defer resp.Body.Close() - text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return nil, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text))) + drainResponseBody(resp.Body) + return nil, upstreamStatusError{status: resp.StatusCode} } return resp, nil } @@ -625,8 +630,8 @@ func (g *gateway) postMediaCreate(ctx context.Context, dest destination, body [] } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return mediaUploadResponse{}, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text))) + drainResponseBody(resp.Body) + return mediaUploadResponse{}, upstreamStatusError{status: resp.StatusCode} } var result mediaUploadResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { @@ -674,8 +679,8 @@ func (g *gateway) putMedia(ctx context.Context, dest uploadDestination, body []b } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return resp.StatusCode, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text))) + drainResponseBody(resp.Body) + return resp.StatusCode, upstreamStatusError{status: resp.StatusCode} } return resp.StatusCode, nil } @@ -687,8 +692,8 @@ func (g *gateway) doExpect2xx(operation string, destination string, req *http.Re } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { - text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text))) + drainResponseBody(resp.Body) + return upstreamStatusError{status: resp.StatusCode} } return nil } @@ -697,13 +702,16 @@ func (g *gateway) doUpstream(req *http.Request, operation string, destination st startedAt := time.Now() resp, err := g.cfg.client.Do(req) if err != nil { + duration := time.Since(startedAt) if g.metrics != nil { - g.metrics.recordUpstream(operation, destination, "error", time.Since(startedAt)) + g.metrics.recordUpstream(operation, destination, "error", duration) } return nil, err } + duration := time.Since(startedAt) + upstreamStatusClass := statusClass(resp.StatusCode) if g.metrics != nil { - g.metrics.recordUpstream(operation, destination, statusClass(resp.StatusCode), time.Since(startedAt)) + g.metrics.recordUpstream(operation, destination, upstreamStatusClass, duration) } return resp, nil } @@ -712,11 +720,7 @@ func (g *gateway) recordTraceExport(route route, result string) { if g.metrics == nil { return } - destination := centralName - if route.destination != "" { - destination = "tenant_" + route.destination - } - g.metrics.recordTraceExport(destination, result) + g.metrics.recordTraceExport(routeDestinationLabel(route), result) } func (g *gateway) recordMediaDivergence(kind string, destination string) { @@ -1034,6 +1038,10 @@ func copyResponseHeaders(target http.Header, source http.Header) { } } +func drainResponseBody(body io.Reader) { + _, _ = io.Copy(io.Discard, io.LimitReader(body, 4096)) +} + func normalizeBaseURL(value string) string { value = strings.TrimSpace(value) if value == "" { diff --git a/otel/langfuse-fanout/cmd/langfuse-fanout/main_test.go b/otel/langfuse-fanout/cmd/langfuse-fanout/main_test.go index aaf8b593c2..b64de878e5 100644 --- a/otel/langfuse-fanout/cmd/langfuse-fanout/main_test.go +++ b/otel/langfuse-fanout/cmd/langfuse-fanout/main_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "io" + "log/slog" "net/http" "net/http/httptest" "net/url" @@ -131,6 +132,46 @@ func TestTraceProxyForwardsExistingRoutingAttributesToCollector(t *testing.T) { } } +func TestTraceProxyDoesNotReturnCollectorErrorDetails(t *testing.T) { + var logBuffer bytes.Buffer + previousLogger := slog.Default() + slog.SetDefault(slog.New(slog.NewJSONHandler(&logBuffer, nil))) + t.Cleanup(func() { + slog.SetDefault(previousLogger) + }) + collector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "failed https://storage.example.com/object?X-Amz-Signature=secret", http.StatusBadGateway) + })) + defer collector.Close() + + gw := newTestGatewayWithCollector(collector.URL) + body := buildTraceRequest(t, nil) + req := httptest.NewRequest(http.MethodPost, otelTracePath, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/x-protobuf") + resp := httptest.NewRecorder() + + gw.handle(resp, req) + if resp.Code != http.StatusBadGateway { + t.Fatalf("status = %d, body = %s", resp.Code, resp.Body.String()) + } + if strings.Contains(resp.Body.String(), "storage.example.com") || strings.Contains(resp.Body.String(), "secret") { + t.Fatalf("response leaked collector error details: %s", resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != "trace collector export failed" { + t.Fatalf("unexpected response body: %s", resp.Body.String()) + } + logOutput := logBuffer.String() + if strings.Contains(logOutput, "storage.example.com") || strings.Contains(logOutput, "secret") { + t.Fatalf("log leaked collector error details: %s", logOutput) + } + if !strings.Contains(logOutput, `"operation":"trace_collector"`) { + t.Fatalf("log missing operation context: %s", logOutput) + } + if got := strings.Count(strings.TrimSpace(logOutput), "\n") + 1; got != 1 { + t.Fatalf("expected one gateway failure log, got %d: %s", got, logOutput) + } +} + func TestGzipTraceProxyAddsRoutingAttributesFromPath(t *testing.T) { t.Parallel() @@ -330,6 +371,89 @@ func TestMediaUploadRejectsInvalidIDBeforeReadingBody(t *testing.T) { } } +func TestUploadPlanStoreErrorsUseGenericResponses(t *testing.T) { + t.Parallel() + + sensitiveErr := errors.New("redis://internal-redis:6379 leaked-secret") + + t.Run("ping", func(t *testing.T) { + t.Parallel() + + store := newFakeUploadPlanStore() + store.pingErr = sensitiveErr + gw := newTestGatewayWithStore("http://central.invalid", nil, store) + + req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`)) + resp := httptest.NewRecorder() + gw.handle(resp, req) + + assertGenericErrorResponse(t, resp, http.StatusBadGateway, "media upload plan store unavailable") + }) + + t.Run("put", func(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + uploadURL := "http://storage.invalid/upload" + writeJSON(w, http.StatusCreated, mediaUploadResponse{ + MediaID: "same-media-id", + UploadURL: &uploadURL, + }) + })) + defer upstream.Close() + + store := newFakeUploadPlanStore() + store.putErr = sensitiveErr + gw := newTestGatewayWithStore(upstream.URL, nil, store) + + req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`)) + resp := httptest.NewRecorder() + gw.handle(resp, req) + + assertGenericErrorResponse(t, resp, http.StatusBadGateway, "failed to store media upload plan") + }) + + t.Run("take", func(t *testing.T) { + t.Parallel() + + store := newFakeUploadPlanStore() + store.takeErr = sensitiveErr + gw := newTestGatewayWithStore("http://central.invalid", nil, store) + + req := httptest.NewRequest(http.MethodPut, mediaUploadProxyPath+"abcdef1234", strings.NewReader("hello")) + resp := httptest.NewRecorder() + gw.handle(resp, req) + + assertGenericErrorResponse(t, resp, http.StatusBadGateway, "failed to load media upload plan") + }) +} + +func TestUploadPlanStoreErrorsUseRedactedLogs(t *testing.T) { + var logBuffer bytes.Buffer + previousLogger := slog.Default() + slog.SetDefault(slog.New(slog.NewJSONHandler(&logBuffer, nil))) + t.Cleanup(func() { + slog.SetDefault(previousLogger) + }) + + store := newFakeUploadPlanStore() + store.pingErr = errors.New("redis://internal-redis:6379 leaked-secret") + gw := newTestGatewayWithStore("http://central.invalid", nil, store) + + req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`)) + resp := httptest.NewRecorder() + gw.handle(resp, req) + + assertGenericErrorResponse(t, resp, http.StatusBadGateway, "media upload plan store unavailable") + logOutput := logBuffer.String() + if strings.Contains(logOutput, "internal-redis") || strings.Contains(logOutput, "leaked-secret") { + t.Fatalf("log leaked upload plan store details: %s", logOutput) + } + if !strings.Contains(logOutput, `"error":"error details redacted"`) { + t.Fatalf("log missing redacted error context: %s", logOutput) + } +} + func TestMediaUploadIsOneTime(t *testing.T) { t.Parallel() @@ -687,6 +811,36 @@ func TestMetricsEndpointRequiresBearerToken(t *testing.T) { } } +func TestSafeErrorMessageRedactsURLErrorURL(t *testing.T) { + t.Parallel() + + err := &url.Error{ + Op: "Put", + URL: "https://storage.example.com/object?X-Amz-Signature=secret", + Err: errors.New("lookup bucket.storage.example.com: connection refused"), + } + + message := safeErrorMessage(err) + if strings.Contains(message, "storage.example.com") || strings.Contains(message, "bucket") || strings.Contains(message, "secret") { + t.Fatalf("safe error leaked URL details: %q", message) + } + if !strings.Contains(message, "Put") || !strings.Contains(message, "URL request failed") { + t.Fatalf("safe error lost useful context: %q", message) + } +} + +func TestSafeErrorMessageRedactsGenericError(t *testing.T) { + t.Parallel() + + message := safeErrorMessage(errors.New("redis://internal-redis:6379 leaked-secret")) + if strings.Contains(message, "internal-redis") || strings.Contains(message, "leaked-secret") { + t.Fatalf("safe error leaked generic error details: %q", message) + } + if message != "error details redacted" { + t.Fatalf("unexpected generic error message: %q", message) + } +} + func TestTraceProxyRecordsPrometheusMetrics(t *testing.T) { t.Parallel() @@ -878,6 +1032,19 @@ func scrapeMetrics(t *testing.T, gw *gateway) string { return resp.Body.String() } +func assertGenericErrorResponse(t *testing.T, resp *httptest.ResponseRecorder, status int, body string) { + t.Helper() + if resp.Code != status { + t.Fatalf("status = %d, body = %s", resp.Code, resp.Body.String()) + } + if strings.TrimSpace(resp.Body.String()) != body { + t.Fatalf("unexpected body: %s", resp.Body.String()) + } + if strings.Contains(resp.Body.String(), "internal-redis") || strings.Contains(resp.Body.String(), "leaked-secret") { + t.Fatalf("response leaked upload plan store details: %s", resp.Body.String()) + } +} + func newUploadURLPath(t *testing.T, value string) string { t.Helper() parsed, err := url.Parse(value) @@ -890,6 +1057,10 @@ func newUploadURLPath(t *testing.T, value string) string { type fakeUploadPlanStore struct { mu sync.Mutex plans map[string]uploadPlan + + putErr error + takeErr error + pingErr error } func newFakeUploadPlanStore() *fakeUploadPlanStore { @@ -897,6 +1068,9 @@ func newFakeUploadPlanStore() *fakeUploadPlanStore { } func (s *fakeUploadPlanStore) Put(_ context.Context, uploadID string, plan uploadPlan) error { + if s.putErr != nil { + return s.putErr + } s.mu.Lock() defer s.mu.Unlock() s.plans[uploadID] = plan @@ -904,6 +1078,9 @@ func (s *fakeUploadPlanStore) Put(_ context.Context, uploadID string, plan uploa } func (s *fakeUploadPlanStore) Take(_ context.Context, uploadID string) (uploadPlan, bool, error) { + if s.takeErr != nil { + return uploadPlan{}, false, s.takeErr + } s.mu.Lock() defer s.mu.Unlock() plan, ok := s.plans[uploadID] @@ -912,6 +1089,9 @@ func (s *fakeUploadPlanStore) Take(_ context.Context, uploadID string) (uploadPl } func (s *fakeUploadPlanStore) Ping(_ context.Context) error { + if s.pingErr != nil { + return s.pingErr + } return nil }