🪭 refactor: safe structured failure logging for langfuse fanout gateway (#14050)

This commit is contained in:
Ravi Kumar L 2026-07-02 16:46:19 +02:00 committed by GitHub
parent b6c0bc7c0d
commit ac7dc490a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 316 additions and 51 deletions

View file

@ -0,0 +1,77 @@
package main
import (
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"strings"
)
type upstreamStatusError struct {
status int
}
func (e upstreamStatusError) Error() string {
return fmt.Sprintf("upstream status %d", e.status)
}
func (g *gateway) writeGatewayError(w http.ResponseWriter, r *http.Request, route route, operation string, status int, message string, err error, attrs ...any) {
g.logGatewayFailure(r, route, operation, status, err, attrs...)
http.Error(w, message, status)
}
func (g *gateway) logGatewayFailure(r *http.Request, route route, operation string, status int, err error, attrs ...any) {
fields := gatewayLogFields(r, route, operation, attrs...)
fields = append(fields, "status", status)
if err != nil {
fields = append(fields, "error", safeErrorMessage(err))
}
if status >= http.StatusInternalServerError {
slog.Error("langfuse fanout gateway request failed", fields...)
return
}
slog.Warn("langfuse fanout gateway request failed", fields...)
}
func (g *gateway) logGatewayWarning(r *http.Request, route route, operation string, message string, attrs ...any) {
fields := gatewayLogFields(r, route, operation, attrs...)
fields = append(fields, "warning", message)
slog.Warn("langfuse fanout gateway warning", fields...)
}
func gatewayLogFields(r *http.Request, route route, operation string, attrs ...any) []any {
fields := []any{
"method", r.Method,
"path", normalizeMetricPath(r.URL.Path),
"operation", operation,
"destination", routeDestinationLabel(route),
}
return append(fields, attrs...)
}
func routeDestinationLabel(route route) string {
if strings.HasPrefix(route.path, mediaUploadProxyPath) {
return "fanout"
}
if route.destination == "" {
return centralName
}
return "tenant_" + route.destination
}
func safeErrorMessage(err error) string {
if err == nil {
return ""
}
var upstreamErr upstreamStatusError
if errors.As(err, &upstreamErr) {
return upstreamErr.Error()
}
var urlErr *url.Error
if errors.As(err, &urlErr) {
return fmt.Sprintf("%s: URL request failed", urlErr.Op)
}
return "error details redacted"
}

View file

@ -265,7 +265,7 @@ func (g *gateway) handle(w http.ResponseWriter, r *http.Request) {
func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route route) {
body, err := readMaybeGzip(r)
if err != nil {
http.Error(w, "failed to read request body", http.StatusBadRequest)
g.writeGatewayError(w, r, route, "trace_export", http.StatusBadRequest, "failed to read request body", err)
return
}
@ -275,7 +275,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou
strings.TrimSpace(r.Header.Get("Authorization")) != "" {
body, err = addTenantRouteAttributes(body, contentType, route.destination)
if err != nil {
http.Error(w, "failed to add OTLP tenant routing attributes", http.StatusBadRequest)
g.writeGatewayError(w, r, route, "trace_route_attributes", http.StatusBadRequest, "failed to add OTLP tenant routing attributes", err)
return
}
}
@ -285,7 +285,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou
contentEncoding = "gzip"
body, err = gzipBytes(body)
if err != nil {
http.Error(w, "failed to encode request body", http.StatusInternalServerError)
g.writeGatewayError(w, r, route, "trace_gzip", http.StatusInternalServerError, "failed to encode request body", err)
return
}
}
@ -293,7 +293,7 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou
resp, err := g.forwardTraceToCollector(r.Context(), r.Header, body, contentType, contentEncoding)
if err != nil {
g.recordTraceExport(route, "error")
http.Error(w, fmt.Sprintf("trace collector export failed: %v", err), http.StatusBadGateway)
g.writeGatewayError(w, r, route, "trace_collector", http.StatusBadGateway, "trace collector export failed", err)
return
}
defer resp.Body.Close()
@ -307,18 +307,18 @@ func (g *gateway) handleTraces(w http.ResponseWriter, r *http.Request, route rou
func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, route route) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 2<<20))
if err != nil {
http.Error(w, "failed to read media create request", http.StatusBadRequest)
g.writeGatewayError(w, r, route, "media_create", http.StatusBadRequest, "failed to read media create request", err)
return
}
if err := g.cfg.uploadStore.Ping(r.Context()); err != nil {
g.recordUploadPlanStoreError("ping")
http.Error(w, fmt.Sprintf("media upload plan store unavailable: %v", err), http.StatusBadGateway)
g.writeGatewayError(w, r, route, "upload_plan_ping", http.StatusBadGateway, "media upload plan store unavailable", err)
return
}
destinations := g.mediaDestinations(route, r.Header.Get("Authorization"))
if len(destinations) == 0 {
http.Error(w, "no media destinations configured", http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "no media destinations configured", nil)
return
}
@ -341,29 +341,26 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout
wg.Wait()
for _, result := range responses {
if result.err != nil {
http.Error(w, fmt.Sprintf("%s media create failed: %v", result.destination.name, result.err), http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, fmt.Sprintf("%s media create failed", result.destination.name), result.err, "upstream_destination", result.destination.name)
return
}
}
mediaID := responses[0].response.MediaID
if mediaID == "" {
http.Error(w, "upstream media create returned empty mediaId", http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "upstream media create returned empty mediaId", nil, "upstream_destination", responses[0].destination.name)
return
}
// Langfuse derives mediaId from the content hash today, so all fanout
// destinations should converge on the same id for the same POST body.
for _, response := range responses[1:] {
if response.response.MediaID != mediaID {
log.Printf(
"upstream media IDs differ: %s=%s %s=%s",
responses[0].destination.name,
mediaID,
response.destination.name,
response.response.MediaID,
)
g.recordMediaDivergence("media_id", response.destination.name)
http.Error(w, "upstream media IDs differ across destinations", http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_create", http.StatusBadGateway, "upstream media IDs differ across destinations", errors.New("upstream media IDs differ across destinations"),
"kind", "media_id",
"upstream_destination", response.destination.name,
"reference_destination", responses[0].destination.name,
)
return
}
}
@ -393,6 +390,10 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout
}
if hadUploadURL {
for _, destination := range missingUploadURLDestinations {
g.logGatewayWarning(r, route, "media_create", "upstream media upload URL presence differs across destinations",
"kind", "upload_url_presence",
"upstream_destination", destination,
)
g.recordMediaDivergence("upload_url_presence", destination)
}
}
@ -401,12 +402,12 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout
if len(uploadPlan.Destinations) > 0 {
uploadID, err := randomID()
if err != nil {
http.Error(w, "failed to create media upload id", http.StatusInternalServerError)
g.writeGatewayError(w, r, route, "upload_plan_create", http.StatusInternalServerError, "failed to create media upload id", err)
return
}
if err := g.storeUpload(r.Context(), uploadID, uploadPlan); err != nil {
g.recordUploadPlanStoreError("put")
http.Error(w, fmt.Sprintf("failed to store media upload plan: %v", err), http.StatusBadGateway)
g.writeGatewayError(w, r, route, "upload_plan_put", http.StatusBadGateway, "failed to store media upload plan", err)
return
}
g.recordUploadPlanCreated(uploadPlan)
@ -420,13 +421,13 @@ func (g *gateway) handleMediaCreate(w http.ResponseWriter, r *http.Request, rout
func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route route) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
if err != nil {
http.Error(w, "failed to read media patch request", http.StatusBadRequest)
g.writeGatewayError(w, r, route, "media_patch", http.StatusBadRequest, "failed to read media patch request", err)
return
}
destinations := g.mediaDestinations(route, r.Header.Get("Authorization"))
if len(destinations) == 0 {
http.Error(w, "no media destinations configured", http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_patch", http.StatusBadGateway, "no media destinations configured", nil)
return
}
@ -450,7 +451,7 @@ func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route
wg.Wait()
for _, result := range results {
if result.err != nil {
http.Error(w, fmt.Sprintf("%s media patch failed: %v", result.destination, result.err), http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_patch", http.StatusBadGateway, fmt.Sprintf("%s media patch failed", result.destination), result.err, "upstream_destination", result.destination)
return
}
}
@ -460,16 +461,16 @@ func (g *gateway) handleMediaPatch(w http.ResponseWriter, r *http.Request, route
func (g *gateway) handleMediaGet(w http.ResponseWriter, r *http.Request, route route) {
destinations := g.mediaDestinations(route, r.Header.Get("Authorization"))
if len(destinations) == 0 {
http.Error(w, "no media destinations configured", http.StatusBadGateway)
g.writeGatewayError(w, r, route, "media_get", http.StatusBadGateway, "no media destinations configured", nil)
return
}
target := destinations[0]
if route.destination != "" {
target = destinations[len(destinations)-1]
}
resp := g.getMedia(r.Context(), target, route.path, r.URL.RawQuery)
if resp == nil {
http.Error(w, "media get failed", http.StatusBadGateway)
resp, err := g.getMedia(r.Context(), target, route.path, r.URL.RawQuery)
if err != nil {
g.writeGatewayError(w, r, route, "media_get", http.StatusBadGateway, "media get failed", err, "upstream_destination", target.name)
return
}
defer resp.Body.Close()
@ -478,21 +479,21 @@ func (g *gateway) handleMediaGet(w http.ResponseWriter, r *http.Request, route r
_, _ = io.Copy(w, resp.Body)
}
func (g *gateway) getMedia(ctx context.Context, target destination, path string, rawQuery string) *http.Response {
func (g *gateway) getMedia(ctx context.Context, target destination, path string, rawQuery string) (*http.Response, error) {
upstreamURL := target.baseURL + path
if rawQuery != "" {
upstreamURL += "?" + rawQuery
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL, nil)
if err != nil {
return nil
return nil, err
}
req.Header.Set("Authorization", target.authorization)
resp, err := g.doUpstream(req, "media_get", target.name)
if err != nil {
return nil
return nil, err
}
return resp
return resp, nil
}
func (g *gateway) handleMetrics(w http.ResponseWriter, r *http.Request) {
@ -520,21 +521,23 @@ func (g *gateway) handleMediaUpload(w http.ResponseWriter, r *http.Request) {
plan, ok, err := g.takeUpload(r.Context(), uploadID)
if err != nil {
g.recordUploadPlanStoreError("take")
http.Error(w, "failed to load media upload plan", http.StatusBadGateway)
g.writeGatewayError(w, r, route{path: r.URL.Path}, "upload_plan_take", http.StatusBadGateway, "failed to load media upload plan", err)
return
}
if !ok {
g.recordUploadPlanMiss()
http.Error(w, "unknown or expired upload", http.StatusNotFound)
g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusNotFound, "unknown or expired upload", nil)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxUploadBytes(plan.ContentLength)))
if err != nil {
if err := g.restoreUpload(r.Context(), uploadID, plan); err != nil {
attrs := []any{}
if restoreErr := g.restoreUpload(r.Context(), uploadID, plan); restoreErr != nil {
g.recordUploadPlanStoreError("restore")
attrs = append(attrs, "restore_error", safeErrorMessage(restoreErr))
}
http.Error(w, "failed to read upload body", http.StatusBadRequest)
g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusBadRequest, "failed to read upload body", err, attrs...)
return
}
@ -558,10 +561,12 @@ func (g *gateway) handleMediaUpload(w http.ResponseWriter, r *http.Request) {
status := http.StatusOK
for _, result := range results {
if result.err != nil {
if err := g.restoreUpload(r.Context(), uploadID, plan); err != nil {
attrs := []any{"upstream_destination", result.destination}
if restoreErr := g.restoreUpload(r.Context(), uploadID, plan); restoreErr != nil {
g.recordUploadPlanStoreError("restore")
attrs = append(attrs, "restore_error", safeErrorMessage(restoreErr))
}
http.Error(w, fmt.Sprintf("%s upload failed: %v", result.destination, result.err), http.StatusBadGateway)
g.writeGatewayError(w, r, route{path: r.URL.Path}, "media_upload", http.StatusBadGateway, fmt.Sprintf("%s upload failed", result.destination), result.err, attrs...)
return
}
if result.status > status {
@ -606,8 +611,8 @@ func (g *gateway) forwardTraceToCollector(ctx context.Context, headers http.Head
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer resp.Body.Close()
text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text)))
drainResponseBody(resp.Body)
return nil, upstreamStatusError{status: resp.StatusCode}
}
return resp, nil
}
@ -625,8 +630,8 @@ func (g *gateway) postMediaCreate(ctx context.Context, dest destination, body []
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return mediaUploadResponse{}, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text)))
drainResponseBody(resp.Body)
return mediaUploadResponse{}, upstreamStatusError{status: resp.StatusCode}
}
var result mediaUploadResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
@ -674,8 +679,8 @@ func (g *gateway) putMedia(ctx context.Context, dest uploadDestination, body []b
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return resp.StatusCode, fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text)))
drainResponseBody(resp.Body)
return resp.StatusCode, upstreamStatusError{status: resp.StatusCode}
}
return resp.StatusCode, nil
}
@ -687,8 +692,8 @@ func (g *gateway) doExpect2xx(operation string, destination string, req *http.Re
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
text, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(text)))
drainResponseBody(resp.Body)
return upstreamStatusError{status: resp.StatusCode}
}
return nil
}
@ -697,13 +702,16 @@ func (g *gateway) doUpstream(req *http.Request, operation string, destination st
startedAt := time.Now()
resp, err := g.cfg.client.Do(req)
if err != nil {
duration := time.Since(startedAt)
if g.metrics != nil {
g.metrics.recordUpstream(operation, destination, "error", time.Since(startedAt))
g.metrics.recordUpstream(operation, destination, "error", duration)
}
return nil, err
}
duration := time.Since(startedAt)
upstreamStatusClass := statusClass(resp.StatusCode)
if g.metrics != nil {
g.metrics.recordUpstream(operation, destination, statusClass(resp.StatusCode), time.Since(startedAt))
g.metrics.recordUpstream(operation, destination, upstreamStatusClass, duration)
}
return resp, nil
}
@ -712,11 +720,7 @@ func (g *gateway) recordTraceExport(route route, result string) {
if g.metrics == nil {
return
}
destination := centralName
if route.destination != "" {
destination = "tenant_" + route.destination
}
g.metrics.recordTraceExport(destination, result)
g.metrics.recordTraceExport(routeDestinationLabel(route), result)
}
func (g *gateway) recordMediaDivergence(kind string, destination string) {
@ -1034,6 +1038,10 @@ func copyResponseHeaders(target http.Header, source http.Header) {
}
}
func drainResponseBody(body io.Reader) {
_, _ = io.Copy(io.Discard, io.LimitReader(body, 4096))
}
func normalizeBaseURL(value string) string {
value = strings.TrimSpace(value)
if value == "" {

View file

@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
@ -131,6 +132,46 @@ func TestTraceProxyForwardsExistingRoutingAttributesToCollector(t *testing.T) {
}
}
func TestTraceProxyDoesNotReturnCollectorErrorDetails(t *testing.T) {
var logBuffer bytes.Buffer
previousLogger := slog.Default()
slog.SetDefault(slog.New(slog.NewJSONHandler(&logBuffer, nil)))
t.Cleanup(func() {
slog.SetDefault(previousLogger)
})
collector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "failed https://storage.example.com/object?X-Amz-Signature=secret", http.StatusBadGateway)
}))
defer collector.Close()
gw := newTestGatewayWithCollector(collector.URL)
body := buildTraceRequest(t, nil)
req := httptest.NewRequest(http.MethodPost, otelTracePath, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/x-protobuf")
resp := httptest.NewRecorder()
gw.handle(resp, req)
if resp.Code != http.StatusBadGateway {
t.Fatalf("status = %d, body = %s", resp.Code, resp.Body.String())
}
if strings.Contains(resp.Body.String(), "storage.example.com") || strings.Contains(resp.Body.String(), "secret") {
t.Fatalf("response leaked collector error details: %s", resp.Body.String())
}
if strings.TrimSpace(resp.Body.String()) != "trace collector export failed" {
t.Fatalf("unexpected response body: %s", resp.Body.String())
}
logOutput := logBuffer.String()
if strings.Contains(logOutput, "storage.example.com") || strings.Contains(logOutput, "secret") {
t.Fatalf("log leaked collector error details: %s", logOutput)
}
if !strings.Contains(logOutput, `"operation":"trace_collector"`) {
t.Fatalf("log missing operation context: %s", logOutput)
}
if got := strings.Count(strings.TrimSpace(logOutput), "\n") + 1; got != 1 {
t.Fatalf("expected one gateway failure log, got %d: %s", got, logOutput)
}
}
func TestGzipTraceProxyAddsRoutingAttributesFromPath(t *testing.T) {
t.Parallel()
@ -330,6 +371,89 @@ func TestMediaUploadRejectsInvalidIDBeforeReadingBody(t *testing.T) {
}
}
func TestUploadPlanStoreErrorsUseGenericResponses(t *testing.T) {
t.Parallel()
sensitiveErr := errors.New("redis://internal-redis:6379 leaked-secret")
t.Run("ping", func(t *testing.T) {
t.Parallel()
store := newFakeUploadPlanStore()
store.pingErr = sensitiveErr
gw := newTestGatewayWithStore("http://central.invalid", nil, store)
req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`))
resp := httptest.NewRecorder()
gw.handle(resp, req)
assertGenericErrorResponse(t, resp, http.StatusBadGateway, "media upload plan store unavailable")
})
t.Run("put", func(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
uploadURL := "http://storage.invalid/upload"
writeJSON(w, http.StatusCreated, mediaUploadResponse{
MediaID: "same-media-id",
UploadURL: &uploadURL,
})
}))
defer upstream.Close()
store := newFakeUploadPlanStore()
store.putErr = sensitiveErr
gw := newTestGatewayWithStore(upstream.URL, nil, store)
req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`))
resp := httptest.NewRecorder()
gw.handle(resp, req)
assertGenericErrorResponse(t, resp, http.StatusBadGateway, "failed to store media upload plan")
})
t.Run("take", func(t *testing.T) {
t.Parallel()
store := newFakeUploadPlanStore()
store.takeErr = sensitiveErr
gw := newTestGatewayWithStore("http://central.invalid", nil, store)
req := httptest.NewRequest(http.MethodPut, mediaUploadProxyPath+"abcdef1234", strings.NewReader("hello"))
resp := httptest.NewRecorder()
gw.handle(resp, req)
assertGenericErrorResponse(t, resp, http.StatusBadGateway, "failed to load media upload plan")
})
}
func TestUploadPlanStoreErrorsUseRedactedLogs(t *testing.T) {
var logBuffer bytes.Buffer
previousLogger := slog.Default()
slog.SetDefault(slog.New(slog.NewJSONHandler(&logBuffer, nil)))
t.Cleanup(func() {
slog.SetDefault(previousLogger)
})
store := newFakeUploadPlanStore()
store.pingErr = errors.New("redis://internal-redis:6379 leaked-secret")
gw := newTestGatewayWithStore("http://central.invalid", nil, store)
req := httptest.NewRequest(http.MethodPost, mediaPath, strings.NewReader(`{"contentLength":5}`))
resp := httptest.NewRecorder()
gw.handle(resp, req)
assertGenericErrorResponse(t, resp, http.StatusBadGateway, "media upload plan store unavailable")
logOutput := logBuffer.String()
if strings.Contains(logOutput, "internal-redis") || strings.Contains(logOutput, "leaked-secret") {
t.Fatalf("log leaked upload plan store details: %s", logOutput)
}
if !strings.Contains(logOutput, `"error":"error details redacted"`) {
t.Fatalf("log missing redacted error context: %s", logOutput)
}
}
func TestMediaUploadIsOneTime(t *testing.T) {
t.Parallel()
@ -687,6 +811,36 @@ func TestMetricsEndpointRequiresBearerToken(t *testing.T) {
}
}
func TestSafeErrorMessageRedactsURLErrorURL(t *testing.T) {
t.Parallel()
err := &url.Error{
Op: "Put",
URL: "https://storage.example.com/object?X-Amz-Signature=secret",
Err: errors.New("lookup bucket.storage.example.com: connection refused"),
}
message := safeErrorMessage(err)
if strings.Contains(message, "storage.example.com") || strings.Contains(message, "bucket") || strings.Contains(message, "secret") {
t.Fatalf("safe error leaked URL details: %q", message)
}
if !strings.Contains(message, "Put") || !strings.Contains(message, "URL request failed") {
t.Fatalf("safe error lost useful context: %q", message)
}
}
func TestSafeErrorMessageRedactsGenericError(t *testing.T) {
t.Parallel()
message := safeErrorMessage(errors.New("redis://internal-redis:6379 leaked-secret"))
if strings.Contains(message, "internal-redis") || strings.Contains(message, "leaked-secret") {
t.Fatalf("safe error leaked generic error details: %q", message)
}
if message != "error details redacted" {
t.Fatalf("unexpected generic error message: %q", message)
}
}
func TestTraceProxyRecordsPrometheusMetrics(t *testing.T) {
t.Parallel()
@ -878,6 +1032,19 @@ func scrapeMetrics(t *testing.T, gw *gateway) string {
return resp.Body.String()
}
func assertGenericErrorResponse(t *testing.T, resp *httptest.ResponseRecorder, status int, body string) {
t.Helper()
if resp.Code != status {
t.Fatalf("status = %d, body = %s", resp.Code, resp.Body.String())
}
if strings.TrimSpace(resp.Body.String()) != body {
t.Fatalf("unexpected body: %s", resp.Body.String())
}
if strings.Contains(resp.Body.String(), "internal-redis") || strings.Contains(resp.Body.String(), "leaked-secret") {
t.Fatalf("response leaked upload plan store details: %s", resp.Body.String())
}
}
func newUploadURLPath(t *testing.T, value string) string {
t.Helper()
parsed, err := url.Parse(value)
@ -890,6 +1057,10 @@ func newUploadURLPath(t *testing.T, value string) string {
type fakeUploadPlanStore struct {
mu sync.Mutex
plans map[string]uploadPlan
putErr error
takeErr error
pingErr error
}
func newFakeUploadPlanStore() *fakeUploadPlanStore {
@ -897,6 +1068,9 @@ func newFakeUploadPlanStore() *fakeUploadPlanStore {
}
func (s *fakeUploadPlanStore) Put(_ context.Context, uploadID string, plan uploadPlan) error {
if s.putErr != nil {
return s.putErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.plans[uploadID] = plan
@ -904,6 +1078,9 @@ func (s *fakeUploadPlanStore) Put(_ context.Context, uploadID string, plan uploa
}
func (s *fakeUploadPlanStore) Take(_ context.Context, uploadID string) (uploadPlan, bool, error) {
if s.takeErr != nil {
return uploadPlan{}, false, s.takeErr
}
s.mu.Lock()
defer s.mu.Unlock()
plan, ok := s.plans[uploadID]
@ -912,6 +1089,9 @@ func (s *fakeUploadPlanStore) Take(_ context.Context, uploadID string) (uploadPl
}
func (s *fakeUploadPlanStore) Ping(_ context.Context) error {
if s.pingErr != nil {
return s.pingErr
}
return nil
}