diff --git a/option/ccm.go b/option/ccm.go index b4be72ea7..dd55a4ba4 100644 --- a/option/ccm.go +++ b/option/ccm.go @@ -102,7 +102,6 @@ type CCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` - PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/option/ocm.go b/option/ocm.go index 0f364821f..e508abae7 100644 --- a/option/ocm.go +++ b/option/ocm.go @@ -102,7 +102,6 @@ type OCMExternalCredentialOptions struct { Token string `json:"token"` Reverse bool `json:"reverse,omitempty"` Detour string `json:"detour,omitempty"` - PlanWeight float64 `json:"plan_weight,omitempty"` UsagesPath string `json:"usages_path,omitempty"` PollInterval badoption.Duration `json:"poll_interval,omitempty"` } diff --git a/service/ccm/credential_external.go b/service/ccm/credential_external.go index 807a06fe8..a5781d6f6 100644 --- a/service/ccm/credential_external.go +++ b/service/ccm/credential_external.go @@ -29,17 +29,16 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - configuredPlanWeight float64 - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -113,22 +112,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.CCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - configuredPlanWeight := options.PlanWeight - if configuredPlanWeight <= 0 { - configuredPlanWeight = 1 - } - cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - configuredPlanWeight: configuredPlanWeight, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -291,7 +284,12 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - return c.configuredPlanWeight + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.remotePlanWeight > 0 { + return c.state.remotePlanWeight + } + return 10 } func (c *externalCredential) weeklyResetTime() time.Time { @@ -422,6 +420,12 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyUtilization = value * 100 } } + if planWeight := headers.Get("X-CCM-Plan-Weight"); planWeight != "" { + value, err := strconv.ParseFloat(planWeight, 64) + if err == nil && value > 0 { + c.state.remotePlanWeight = value + } + } if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() @@ -525,6 +529,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { @@ -540,6 +545,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } diff --git a/service/ccm/credential_state.go b/service/ccm/credential_state.go index 87c9afde2..83db5a197 100644 --- a/service/ccm/credential_state.go +++ b/service/ccm/credential_state.go @@ -72,6 +72,7 @@ type credentialState struct { rateLimitResetAt time.Time accountType string rateLimitTier string + remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int unavailable bool diff --git a/service/ccm/service.go b/service/ccm/service.go index 4fd880f8a..f940d7309 100644 --- a/service/ccm/service.go +++ b/service/ccm/service.go @@ -766,47 +766,48 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]float64{ "five_hour_utilization": avgFiveHour, "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, }) } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64) { - var totalFiveHour, totalWeekly float64 - var count int +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.CCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, cred := range provider.allCredentials() { if !cred.isAvailable() { continue } - // Exclude the user's own external_credential (their contribution to us) if userConfig.ExternalCredential != "" && cred.tagName() == userConfig.ExternalCredential { continue } - // If user doesn't allow external usage, exclude all external credentials if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 - if scaledFiveHour > 100 { - scaledFiveHour = 100 + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 } - scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 - if scaledWeekly > 100 { - scaledWeekly = 100 + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 } - totalFiveHour += scaledFiveHour - totalWeekly += scaledWeekly - count++ + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight } - if count == 0 { - return 100, 100 + if totalWeight == 0 { + return 100, 100, 0 } - return totalFiveHour / float64(count), totalWeekly / float64(count) + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.CCMUser) { @@ -815,11 +816,14 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use return } - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) // Rewrite utilization headers to aggregated average (convert back to 0.0-1.0 range) headers.Set("anthropic-ratelimit-unified-5h-utilization", strconv.FormatFloat(avgFiveHour/100, 'f', 6, 64)) headers.Set("anthropic-ratelimit-unified-7d-utilization", strconv.FormatFloat(avgWeekly/100, 'f', 6, 64)) + if totalWeight > 0 { + headers.Set("X-CCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } } func (s *Service) InterfaceUpdated() { diff --git a/service/ocm/credential_external.go b/service/ocm/credential_external.go index 2c9dce46b..2a09c84ad 100644 --- a/service/ocm/credential_external.go +++ b/service/ocm/credential_external.go @@ -30,18 +30,17 @@ import ( const reverseProxyBaseURL = "http://reverse-proxy" type externalCredential struct { - tag string - baseURL string - token string - credDialer N.Dialer - httpClient *http.Client - state credentialState - stateMutex sync.RWMutex - pollAccess sync.Mutex - pollInterval time.Duration - configuredPlanWeight float64 - usageTracker *AggregatedUsage - logger log.ContextLogger + tag string + baseURL string + token string + credDialer N.Dialer + httpClient *http.Client + state credentialState + stateMutex sync.RWMutex + pollAccess sync.Mutex + pollInterval time.Duration + usageTracker *AggregatedUsage + logger log.ContextLogger onBecameUnusable func() interrupted bool @@ -130,22 +129,16 @@ func newExternalCredential(ctx context.Context, tag string, options option.OCMEx requestContext, cancelRequests := context.WithCancel(context.Background()) reverseContext, reverseCancel := context.WithCancel(context.Background()) - configuredPlanWeight := options.PlanWeight - if configuredPlanWeight <= 0 { - configuredPlanWeight = 1 - } - cred := &externalCredential{ - tag: tag, - token: options.Token, - pollInterval: pollInterval, - configuredPlanWeight: configuredPlanWeight, - logger: logger, - requestContext: requestContext, - cancelRequests: cancelRequests, - reverse: options.Reverse, - reverseContext: reverseContext, - reverseCancel: reverseCancel, + tag: tag, + token: options.Token, + pollInterval: pollInterval, + logger: logger, + requestContext: requestContext, + cancelRequests: cancelRequests, + reverse: options.Reverse, + reverseContext: reverseContext, + reverseCancel: reverseCancel, } if options.URL == "" { @@ -313,7 +306,12 @@ func (c *externalCredential) weeklyCap() float64 { } func (c *externalCredential) planWeight() float64 { - return c.configuredPlanWeight + c.stateMutex.RLock() + defer c.stateMutex.RUnlock() + if c.state.remotePlanWeight > 0 { + return c.state.remotePlanWeight + } + return 10 } func (c *externalCredential) weeklyResetTime() time.Time { @@ -459,6 +457,12 @@ func (c *externalCredential) updateStateFromHeaders(headers http.Header) { c.state.weeklyUtilization = value } } + if planWeight := headers.Get("X-OCM-Plan-Weight"); planWeight != "" { + value, err := strconv.ParseFloat(planWeight, 64) + if err == nil && value > 0 { + c.state.remotePlanWeight = value + } + } if hadData { c.state.consecutivePollFailures = 0 c.state.lastUpdated = time.Now() @@ -562,6 +566,7 @@ func (c *externalCredential) pollUsage(ctx context.Context) { var statusResponse struct { FiveHourUtilization float64 `json:"five_hour_utilization"` WeeklyUtilization float64 `json:"weekly_utilization"` + PlanWeight float64 `json:"plan_weight"` } err = json.NewDecoder(response.Body).Decode(&statusResponse) if err != nil { @@ -577,6 +582,9 @@ func (c *externalCredential) pollUsage(ctx context.Context) { c.state.consecutivePollFailures = 0 c.state.fiveHourUtilization = statusResponse.FiveHourUtilization c.state.weeklyUtilization = statusResponse.WeeklyUtilization + if statusResponse.PlanWeight > 0 { + c.state.remotePlanWeight = statusResponse.PlanWeight + } if c.state.hardRateLimited && time.Now().After(c.state.rateLimitResetAt) { c.state.hardRateLimited = false } diff --git a/service/ocm/credential_state.go b/service/ocm/credential_state.go index 3cb1f48b9..ea7c621bf 100644 --- a/service/ocm/credential_state.go +++ b/service/ocm/credential_state.go @@ -71,6 +71,7 @@ type credentialState struct { hardRateLimited bool rateLimitResetAt time.Time accountType string + remotePlanWeight float64 lastUpdated time.Time consecutivePollFailures int unavailable bool diff --git a/service/ocm/service.go b/service/ocm/service.go index 9dded4740..1f93b4d5c 100644 --- a/service/ocm/service.go +++ b/service/ocm/service.go @@ -832,19 +832,19 @@ func (s *Service) handleStatusEndpoint(w http.ResponseWriter, r *http.Request) { } provider.pollIfStale(r.Context()) - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]float64{ "five_hour_utilization": avgFiveHour, "weekly_utilization": avgWeekly, + "plan_weight": totalWeight, }) } -func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64) { - var totalFiveHour, totalWeekly float64 - var count int +func (s *Service) computeAggregatedUtilization(provider credentialProvider, userConfig *option.OCMUser) (float64, float64, float64) { + var totalWeightedRemaining5h, totalWeightedRemainingWeekly, totalWeight float64 for _, cred := range provider.allCredentials() { if !cred.isAvailable() { continue @@ -855,22 +855,25 @@ func (s *Service) computeAggregatedUtilization(provider credentialProvider, user if !userConfig.AllowExternalUsage && cred.isExternal() { continue } - scaledFiveHour := cred.fiveHourUtilization() / cred.fiveHourCap() * 100 - if scaledFiveHour > 100 { - scaledFiveHour = 100 + weight := cred.planWeight() + remaining5h := cred.fiveHourCap() - cred.fiveHourUtilization() + if remaining5h < 0 { + remaining5h = 0 } - scaledWeekly := cred.weeklyUtilization() / cred.weeklyCap() * 100 - if scaledWeekly > 100 { - scaledWeekly = 100 + remainingWeekly := cred.weeklyCap() - cred.weeklyUtilization() + if remainingWeekly < 0 { + remainingWeekly = 0 } - totalFiveHour += scaledFiveHour - totalWeekly += scaledWeekly - count++ + totalWeightedRemaining5h += remaining5h * weight + totalWeightedRemainingWeekly += remainingWeekly * weight + totalWeight += weight } - if count == 0 { - return 100, 100 + if totalWeight == 0 { + return 100, 100, 0 } - return totalFiveHour / float64(count), totalWeekly / float64(count) + return 100 - totalWeightedRemaining5h/totalWeight, + 100 - totalWeightedRemainingWeekly/totalWeight, + totalWeight } func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, userConfig *option.OCMUser) { @@ -879,7 +882,7 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use return } - avgFiveHour, avgWeekly := s.computeAggregatedUtilization(provider, userConfig) + avgFiveHour, avgWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit")) if activeLimitIdentifier == "" { @@ -888,6 +891,9 @@ func (s *Service) rewriteResponseHeadersForExternalUser(headers http.Header, use headers.Set("x-"+activeLimitIdentifier+"-primary-used-percent", strconv.FormatFloat(avgFiveHour, 'f', 2, 64)) headers.Set("x-"+activeLimitIdentifier+"-secondary-used-percent", strconv.FormatFloat(avgWeekly, 'f', 2, 64)) + if totalWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(totalWeight, 'f', -1, 64)) + } } func (s *Service) InterfaceUpdated() { diff --git a/service/ocm/service_websocket.go b/service/ocm/service_websocket.go index f348f7fa4..2f4911959 100644 --- a/service/ocm/service_websocket.go +++ b/service/ocm/service_websocket.go @@ -368,8 +368,9 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential ResetAt int64 `json:"reset_at"` } `json:"secondary"` } `json:"rate_limits"` - LimitName string `json:"limit_name"` - MeteredLimitName string `json:"metered_limit_name"` + LimitName string `json:"limit_name"` + MeteredLimitName string `json:"metered_limit_name"` + PlanWeight float64 `json:"plan_weight"` } err := json.Unmarshal(data, &rateLimitsEvent) if err != nil { @@ -398,6 +399,9 @@ func (s *Service) handleWebSocketRateLimitsEvent(data []byte, selectedCredential headers.Set("x-"+identifier+"-secondary-reset-at", strconv.FormatInt(w.ResetAt, 10)) } } + if rateLimitsEvent.PlanWeight > 0 { + headers.Set("X-OCM-Plan-Weight", strconv.FormatFloat(rateLimitsEvent.PlanWeight, 'f', -1, 64)) + } selectedCredential.updateStateFromHeaders(headers) } @@ -436,7 +440,11 @@ func (s *Service) rewriteWebSocketRateLimitsForExternalUser(data []byte, provide return nil, err } - averageFiveHour, averageWeekly := s.computeAggregatedUtilization(provider, userConfig) + averageFiveHour, averageWeekly, totalWeight := s.computeAggregatedUtilization(provider, userConfig) + + if totalWeight > 0 { + event["plan_weight"], _ = json.Marshal(totalWeight) + } primaryData, err := rewriteWebSocketRateLimitWindow(rateLimits["primary"], averageFiveHour) if err != nil {