diff --git a/src/hbbs_http/account.rs b/src/hbbs_http/account.rs index 5d21db9c9..81a1e145b 100644 --- a/src/hbbs_http/account.rs +++ b/src/hbbs_http/account.rs @@ -153,7 +153,7 @@ impl OidcSession { } // This URL is used to detect the appropriate TLS implementation for the server. let login_option_url = format!("{}/api/login-options", api_server); - let _ = create_http_client_with_url(&login_option_url, None); + let _ = create_http_client_with_url(&login_option_url, false); write_guard.warmed_api_server = Some(api_server.to_owned()); } diff --git a/src/hbbs_http/downloader.rs b/src/hbbs_http/downloader.rs index e2549f2c1..ac4bcd5eb 100644 --- a/src/hbbs_http/downloader.rs +++ b/src/hbbs_http/downloader.rs @@ -175,7 +175,7 @@ async fn do_download( auto_del_dur: Option, mut rx_cancel: UnboundedReceiver<()>, ) -> ResultType { - let client = create_http_client_async_with_url(&url, Some(false)).await; + let client = create_http_client_async_with_url(&url, true).await; let mut is_all_downloaded = false; tokio::select! { diff --git a/src/hbbs_http/http_client.rs b/src/hbbs_http/http_client.rs index 875c90439..acbf2f784 100644 --- a/src/hbbs_http/http_client.rs +++ b/src/hbbs_http/http_client.rs @@ -120,34 +120,38 @@ pub fn get_url_for_tls<'a>(url: &'a str, proxy_conf: &'a Option) - url } +fn resolve_danger_accept_invalid_cert( + force_strict_tls: bool, + cached_danger_accept_invalid_cert: Option, +) -> Option { + if force_strict_tls { + Some(false) + } else { + cached_danger_accept_invalid_cert + } +} + /// Creates a sync HTTP client for `url`. /// -/// `tls_danger_accept_invalid_cert` has three states: -/// - `None`: use cached TLS backend/cert settings when present; otherwise allow -/// automatic fallback between Rustls/NativeTls and strict/accept-invalid modes. -/// - `Some(false)`: force strict certificate validation, overriding the cache; -/// use for security-critical requests such as updates. -/// - `Some(true)`: force accepting invalid certificates, overriding the cache; -/// use sparingly. -pub fn create_http_client_with_url( - url: &str, - tls_danger_accept_invalid_cert: Option, -) -> SyncClient { +/// Set `force_strict_tls` to `true` for security-critical requests that must +/// reject invalid certificates and ignore the cached certificate policy. +pub fn create_http_client_with_url(url: &str, force_strict_tls: bool) -> SyncClient { let proxy_conf = Config::get_socks(); let tls_url = get_url_for_tls(url, &proxy_conf); let tls_type = get_cached_tls_type(tls_url); let is_tls_type_cached = tls_type.is_some(); let tls_type = tls_type.unwrap_or(TlsType::Rustls); - let danger_accept_invalid_cert = - tls_danger_accept_invalid_cert.or_else(|| get_cached_tls_accept_invalid_cert(tls_url)); - let allow_accept_invalid_fallback = danger_accept_invalid_cert.is_none(); + let danger_accept_invalid_cert = resolve_danger_accept_invalid_cert( + force_strict_tls, + get_cached_tls_accept_invalid_cert(tls_url), + ); create_http_client_with_url_( url, tls_url, tls_type, is_tls_type_cached, danger_accept_invalid_cert, - allow_accept_invalid_fallback, + danger_accept_invalid_cert, ) } @@ -157,16 +161,16 @@ fn create_http_client_with_url_( tls_type: TlsType, is_tls_type_cached: bool, danger_accept_invalid_cert: Option, - allow_accept_invalid_fallback: bool, + original_danger_accept_invalid_cert: Option, ) -> SyncClient { let mut client = create_http_client(tls_type, danger_accept_invalid_cert.unwrap_or(false)); - if is_tls_type_cached && !allow_accept_invalid_fallback { + if is_tls_type_cached && original_danger_accept_invalid_cert.is_some() { return client; } if let Err(e) = client.head(url).send() { if e.is_request() { match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) { - (TlsType::Rustls, _, None) if allow_accept_invalid_fallback => { + (TlsType::Rustls, _, None) => { log::warn!( "Failed to connect to server {} with rustls-tls: {:?}, trying accept invalid cert", tls_url, @@ -178,7 +182,7 @@ fn create_http_client_with_url_( tls_type, is_tls_type_cached, Some(true), - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, ); } (TlsType::Rustls, false, Some(_)) => { @@ -192,15 +196,11 @@ fn create_http_client_with_url_( tls_url, TlsType::NativeTls, is_tls_type_cached, - if allow_accept_invalid_fallback { - None - } else { - danger_accept_invalid_cert - }, - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, + original_danger_accept_invalid_cert, ); } - (TlsType::NativeTls, _, None) if allow_accept_invalid_fallback => { + (TlsType::NativeTls, _, None) => { log::warn!( "Failed to connect to server {} with native-tls: {:?}, trying accept invalid cert", tls_url, @@ -212,7 +212,7 @@ fn create_http_client_with_url_( tls_type, is_tls_type_cached, Some(true), - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, ); } _ => { @@ -249,32 +249,25 @@ fn create_http_client_with_url_( /// Creates an async HTTP client for `url`. /// -/// `tls_danger_accept_invalid_cert` has three states: -/// - `None`: use cached TLS backend/cert settings when present; otherwise allow -/// automatic fallback between Rustls/NativeTls and strict/accept-invalid modes. -/// - `Some(false)`: force strict certificate validation, overriding the cache; -/// use for security-critical requests such as updates. -/// - `Some(true)`: force accepting invalid certificates, overriding the cache; -/// use sparingly. -pub async fn create_http_client_async_with_url( - url: &str, - tls_danger_accept_invalid_cert: Option, -) -> AsyncClient { +/// Set `force_strict_tls` to `true` for security-critical requests that must +/// reject invalid certificates and ignore the cached certificate policy. +pub async fn create_http_client_async_with_url(url: &str, force_strict_tls: bool) -> AsyncClient { let proxy_conf = Config::get_socks(); let tls_url = get_url_for_tls(url, &proxy_conf); let tls_type = get_cached_tls_type(tls_url); let is_tls_type_cached = tls_type.is_some(); let tls_type = tls_type.unwrap_or(TlsType::Rustls); - let danger_accept_invalid_cert = - tls_danger_accept_invalid_cert.or_else(|| get_cached_tls_accept_invalid_cert(tls_url)); - let allow_accept_invalid_fallback = danger_accept_invalid_cert.is_none(); + let danger_accept_invalid_cert = resolve_danger_accept_invalid_cert( + force_strict_tls, + get_cached_tls_accept_invalid_cert(tls_url), + ); create_http_client_async_with_url_( url, tls_url, tls_type, is_tls_type_cached, danger_accept_invalid_cert, - allow_accept_invalid_fallback, + danger_accept_invalid_cert, ) .await } @@ -286,16 +279,16 @@ async fn create_http_client_async_with_url_( tls_type: TlsType, is_tls_type_cached: bool, danger_accept_invalid_cert: Option, - allow_accept_invalid_fallback: bool, + original_danger_accept_invalid_cert: Option, ) -> AsyncClient { let mut client = create_http_client_async(tls_type, danger_accept_invalid_cert.unwrap_or(false)); - if is_tls_type_cached && !allow_accept_invalid_fallback { + if is_tls_type_cached && original_danger_accept_invalid_cert.is_some() { return client; } if let Err(e) = client.head(url).send().await { match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) { - (TlsType::Rustls, _, None) if allow_accept_invalid_fallback => { + (TlsType::Rustls, _, None) => { log::warn!( "Failed to connect to server {} with rustls-tls: {:?}, trying accept invalid cert", tls_url, @@ -307,7 +300,7 @@ async fn create_http_client_async_with_url_( tls_type, is_tls_type_cached, Some(true), - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, ) .await; } @@ -322,16 +315,12 @@ async fn create_http_client_async_with_url_( tls_url, TlsType::NativeTls, is_tls_type_cached, - if allow_accept_invalid_fallback { - None - } else { - danger_accept_invalid_cert - }, - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, + original_danger_accept_invalid_cert, ) .await; } - (TlsType::NativeTls, _, None) if allow_accept_invalid_fallback => { + (TlsType::NativeTls, _, None) => { log::warn!( "Failed to connect to server {} with native-tls: {:?}, trying accept invalid cert", tls_url, @@ -343,7 +332,7 @@ async fn create_http_client_async_with_url_( tls_type, is_tls_type_cached, Some(true), - allow_accept_invalid_fallback, + original_danger_accept_invalid_cert, ) .await; } @@ -370,3 +359,34 @@ async fn create_http_client_async_with_url_( } client } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn force_strict_tls_overrides_cached_cert_policy() { + assert_eq!(resolve_danger_accept_invalid_cert(true, None), Some(false)); + assert_eq!( + resolve_danger_accept_invalid_cert(true, Some(false)), + Some(false) + ); + assert_eq!( + resolve_danger_accept_invalid_cert(true, Some(true)), + Some(false) + ); + } + + #[test] + fn non_strict_tls_uses_cached_cert_policy() { + assert_eq!(resolve_danger_accept_invalid_cert(false, None), None); + assert_eq!( + resolve_danger_accept_invalid_cert(false, Some(false)), + Some(false) + ); + assert_eq!( + resolve_danger_accept_invalid_cert(false, Some(true)), + Some(true) + ); + } +} diff --git a/src/hbbs_http/record_upload.rs b/src/hbbs_http/record_upload.rs index b841368d4..3c6dba94e 100644 --- a/src/hbbs_http/record_upload.rs +++ b/src/hbbs_http/record_upload.rs @@ -32,7 +32,7 @@ pub fn run(rx: Receiver) { ); // This URL is used for TLS connectivity testing and fallback detection. let login_option_url = format!("{}/api/login-options", &api_server); - let client = create_http_client_with_url(&login_option_url, None); + let client = create_http_client_with_url(&login_option_url, false); let mut uploader = RecordUploader { client, api_server, diff --git a/src/updater.rs b/src/updater.rs index e389b15dd..07b2b2d56 100644 --- a/src/updater.rs +++ b/src/updater.rs @@ -173,7 +173,7 @@ fn ensure_verified_update_file( file_path: &Path, expected_sha256: &str, ) -> ResultType<()> { - let client = create_http_client_with_url(download_url, Some(false)); + let client = create_http_client_with_url(download_url, true); let mut is_file_exists = false; if file_path.exists() { // Check if the file size is the same as the server file size @@ -439,7 +439,7 @@ fn fetch_github_asset_sha256(update_url: &str, download_url: &str) -> ResultType } fn fetch_github_release_metadata(api_url: &str) -> ResultType { - let client = create_http_client_with_url(&api_url, Some(false)); + let client = create_http_client_with_url(&api_url, true); let response = client .get(api_url) .header(reqwest::header::USER_AGENT, "rustdesk-updater")