diff --git a/doc/migration-flow.md b/doc/migration-flow.md new file mode 100644 index 0000000..3302f67 --- /dev/null +++ b/doc/migration-flow.md @@ -0,0 +1,52 @@ +# Transparent client migration on DC connection death + +Telegram periodically closes the TCP connection to the proxy ("DC connection +rotation", typically every 30–70 s). Instead of dropping all clients +multiplexed on that connection, the proxy remaps each idle client to a +surviving (or freshly-started) DC connection transparently. + +**Key actors:** +- `mtp_down_conn (old)` — the dying downstream connection process +- `mtp_dc_pool` — pool managing all downstream connections for one DC +- `mtp_handler` — one process per connected Telegram client +- `mtp_down_conn (new)` — replacement downstream spawned by the pool + +```mermaid +sequenceDiagram + participant TG as Telegram + participant OldDown as mtp_down_conn (old) + participant Pool as mtp_dc_pool + participant Handler as mtp_handler + participant NewDown as mtp_down_conn (new) + + TG->>OldDown: TCP close + + OldDown->>Pool: downstream_closing(self()) [sync] + Pool-->>Pool: remove OldDown from ds_store + monitors + Pool-->>NewDown: spawn & connect (maybe_restart_connection) + Pool-->>OldDown: ok + + OldDown->>Handler: migrate(OldDown) [cast, to all known upstreams] + + Note over OldDown: drain_mailbox(5000) + + alt upstream_new in mailbox + Note over Pool,OldDown: Race: pool processed a {get} call just before
downstream_closing — upstream_new cast already queued + Pool-->>OldDown: upstream_new(Handler2, Opts) [cast, queued] + OldDown->>Handler2: migrate(OldDown) [cast, immediately] + end + + alt Handler was blocked in down_send + Handler-->>OldDown: {send, Data} [call, in mailbox] + OldDown-->>Handler: {error, migrating} + Note over Handler: metric[mid_send] → stop
(client reconnects and resends) + else Handler was idle + Handler->>Pool: migrate(OldDown, self(), Opts) [sync] + Pool-->>Pool: remove Handler from upstreams map + Pool->>NewDown: upstream_new(Handler, Opts) [cast] + Pool-->>Handler: NewDown pid + Note over Handler: down = NewDown
metric[ok] + end + + Note over OldDown: stop {shutdown, downstream_migrated} +``` diff --git a/src/mtp_dc_pool.erl b/src/mtp_dc_pool.erl index 8e95682..b1f192a 100644 --- a/src/mtp_dc_pool.erl +++ b/src/mtp_dc_pool.erl @@ -17,6 +17,8 @@ -export([start_link/1, get/3, return/2, + downstream_closing/2, + migrate/4, add_connection/1, ack_connected/2, status/1, @@ -80,6 +82,23 @@ get(Pool, Upstream, #{addr := _} = Opts) -> return(Pool, Upstream) -> gen_server:cast(Pool, {return, Upstream}). +%% Called by a downstream that received tcp_closed with active upstreams. +%% Removes the downstream from the pool store synchronously so it won't receive +%% new upstreams while its handlers are migrating. +-spec downstream_closing(pid(), downstream()) -> ok. +downstream_closing(Pool, Downstream) -> + gen_server:call(Pool, {downstream_closing, Downstream}). + +%% Atomically return an upstream from a dying downstream and assign it to a new one. +%% Avoids the "attempt to release unknown connection" warning that return+get would cause. +-spec migrate(pid(), downstream(), upstream(), + #{addr := mtp_config:netloc_v4v6(), + ad_tag => binary(), + packet_layer => mtp_down_conn:packet_layer()}) -> + downstream() | {error, empty | not_found}. +migrate(Pool, OldDown, Upstream, Opts) -> + gen_server:call(Pool, {migrate, OldDown, Upstream, Opts}). + add_connection(Pool) -> gen_server:call(Pool, add_connection, 10000). @@ -108,6 +127,11 @@ handle_call({get, Upstream, Opts}, _From, State) -> {Downstream, State1} -> {reply, Downstream, State1} end; +handle_call({downstream_closing, Downstream}, _From, State) -> + {reply, ok, handle_downstream_closing(Downstream, State)}; +handle_call({migrate, OldDown, Upstream, Opts}, _From, State) -> + {Reply, State1} = handle_migrate(OldDown, Upstream, Opts, State), + {reply, Reply, State1}; handle_call(add_connection, _From, State) -> State1 = connect(State), {reply, ok, State1}; @@ -178,6 +202,39 @@ handle_return(Upstream, #state{downstreams = Ds, St#state{downstreams = Ds1, upstreams = Us1}. +%% Remove a dying downstream from the store before its handlers migrate. +%% Called synchronously by mtp_down_conn so removal is complete before +%% {migrate, Self} is sent to upstreams. +handle_downstream_closing(Downstream, #state{downstreams = Ds, + downstream_monitors = DsM, + pending_downstreams = Pending} = St) -> + DsM1 = maps:filter( + fun(MonRef, Pid) when Pid =:= Downstream -> + erlang:demonitor(MonRef, [flush]), + false; + (_, _) -> + true + end, DsM), + Ds1 = ds_remove(Downstream, Ds), + Pending1 = lists:delete(Downstream, Pending), + maybe_restart_connection( + St#state{downstreams = Ds1, + downstream_monitors = DsM1, + pending_downstreams = Pending1}). + +%% Atomically reassign an upstream from its dying downstream to a new one. +handle_migrate(_OldDown, Upstream, Opts, #state{upstreams = Us} = St) -> + case maps:take(Upstream, Us) of + {{_AnyOldDown, OldMonRef}, Us1} -> + erlang:demonitor(OldMonRef, [flush]), + case handle_get(Upstream, Opts, St#state{upstreams = Us1}) of + {empty, St1} -> {{error, empty}, St1}; + {NewDown, St1} -> {NewDown, St1} + end; + error -> + {{error, not_found}, St} + end. 
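+
+%% Illustrative only (a sketch, not part of this change): seen from the
+%% handler's side, the migration call typically looks like
+%%
+%%   Opts = #{addr => ClientAddr, ad_tag => AdTag, packet_layer => PacketMod},
+%%   case mtp_dc_pool:migrate(Pool, OldDown, self(), Opts) of
+%%       {error, _Reason} -> stop_handler();  %% no replacement downstream
+%%       NewDown -> NewDown                   %% keep serving via NewDown
+%%   end
+%%
+%% ClientAddr, AdTag, PacketMod and stop_handler/0 are placeholders; the real
+%% caller is the {migrate, OldDown} clause in mtp_handler.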
+ handle_down(MonRef, Pid, Reason, #state{downstreams = Ds, downstream_monitors = DsM, upstreams = Us, @@ -196,6 +253,8 @@ handle_down(MonRef, Pid, Reason, #state{downstreams = Ds, case Reason of {shutdown, downstream_socket_closed} -> ?LOG_INFO("Downstream=~p closed (no active clients)", [Pid]); + {shutdown, downstream_migrated} -> + ?LOG_INFO("Downstream=~p finished migrating clients", [Pid]); _ -> ?LOG_ERROR("Downstream=~p is down. reason=~p", [Pid, Reason]) end, @@ -204,7 +263,7 @@ handle_down(MonRef, Pid, Reason, #state{downstreams = Ds, downstreams = Ds1, downstream_monitors = DsM1}); _ -> - ?LOG_ERROR("Unexpected DOWN. ref=~p, pid=~p, reason=~p", [MonRef, Pid, Reason]), + ?LOG_WARNING("Unexpected DOWN. ref=~p, pid=~p, reason=~p", [MonRef, Pid, Reason]), St end end. diff --git a/src/mtp_down_conn.erl b/src/mtp_down_conn.erl index 77534a7..b3a4f39 100644 --- a/src/mtp_down_conn.erl +++ b/src/mtp_down_conn.erl @@ -101,7 +101,7 @@ shutdown(Conn) -> gen_server:cast(Conn, shutdown). %% To be called by upstream --spec send(handle(), iodata()) -> ok | {error, unknown_upstream}. +-spec send(handle(), iodata()) -> ok | {error, unknown_upstream | migrating}. send(Conn, Data) -> gen_server:call(Conn, {send, Data}, ?SEND_TIMEOUT * 2). @@ -166,10 +166,24 @@ handle_info({tcp, Sock, Data}, #state{sock = Sock, dc_id = DcId} = S) -> {ok, S1} = handle_downstream_data(Data, S), activate_if_no_overflow(S1), {noreply, S1}; -handle_info({tcp_closed, Sock}, #state{sock = Sock, upstreams = Ups} = State) -> +handle_info({tcp_closed, Sock}, #state{sock = Sock, upstreams = Ups, pool = Pool} = State) -> case map_size(Ups) of - 0 -> {stop, {shutdown, downstream_socket_closed}, State}; - _ -> {stop, downstream_socket_closed, State} + 0 -> + {stop, {shutdown, downstream_socket_closed}, State}; + N -> + %% Remove self from pool first so no new upstreams can be assigned. + ok = mtp_dc_pool:downstream_closing(Pool, self()), + ?LOG_INFO("Downstream socket closed with ~p active client(s); migrating", [N]), + %% Notify all known upstreams to migrate immediately. + [mtp_handler:migrate(Upstream, self()) || Upstream <- maps:keys(Ups)], + %% Drain remaining mailbox messages: + %% - {send,...} calls: reply {error, migrating} to unblock callers + %% - upstream_new casts: handlers assigned to us by the pool just + %% before downstream_closing ran; migrate them immediately + NDrained = drain_mailbox(5000), + NDrained > 0 andalso + ?LOG_INFO("Drained ~p pending send call(s) during migration", [NDrained]), + {stop, {shutdown, downstream_migrated}, State} end; handle_info({tcp_error, Sock, Reason}, #state{sock = Sock} = State) -> {stop, {downstream_tcp_error, Reason}, State}; @@ -194,6 +208,9 @@ handle_info(handshake_timeout, #state{stage = Stage, dc_id = DcId} = St) -> end. +terminate({shutdown, downstream_migrated}, _State) -> + %% Normal shutdown during migration; no need to log or notify upstreams. + ok; terminate(Reason, #state{upstreams = Ups}) -> NUps = map_size(Ups), case Reason of @@ -207,8 +224,8 @@ terminate(Reason, #state{upstreams = Ups}) -> lists:foreach( fun(Upstream) -> ok = mtp_handler:send(Upstream, {close_ext, Self}) - end, maps:keys(Ups)), - ok. + end, maps:keys(Ups)). + code_change(_OldVsn, State, _Extra) -> {ok, State}. @@ -471,6 +488,26 @@ non_ack_cleanup_upstream(Upstream, #state{non_ack_count = Cnt, St#state{non_ack_count = Cnt - UpsCnt, non_ack_bytes = Oct - UpsOct}). +%% Drain pending messages from our mailbox during migration. 
+%% - gen_server:call({send,_}): reply {error, migrating} so callers unblock +%% - gen_server:cast({upstream_new,...}): send {migrate} immediately +%% Timeout controls how long to wait for the next message before giving up. +%% Returns count of drained send calls. +drain_mailbox(Timeout) -> + drain_mailbox(Timeout, 0). + +drain_mailbox(Timeout, NSend) -> + receive + {'$gen_call', From, {send, _Data}} -> + gen_server:reply(From, {error, migrating}), + drain_mailbox(Timeout, NSend + 1); + {'$gen_cast', {upstream_new, Upstream, _Opts}} -> + mtp_handler:migrate(Upstream, self()), + drain_mailbox(Timeout, NSend) + after Timeout -> + NSend + end. + %% %% Connect / handshake diff --git a/src/mtp_handler.erl b/src/mtp_handler.erl index 56df270..f2df969 100644 --- a/src/mtp_handler.erl +++ b/src/mtp_handler.erl @@ -10,7 +10,7 @@ -behaviour(ranch_protocol). %% API --export([start_link/3, start_link/4, send/2]). +-export([start_link/3, start_link/4, send/2, migrate/2]). -export([hex/1, unhex/1]). -export([keys_str/0]). @@ -80,6 +80,10 @@ keys_str() -> send(Upstream, Packet) -> gen_server:cast(Upstream, Packet). +-spec migrate(pid(), OldDown :: mtp_down_conn:handle()) -> ok. +migrate(Upstream, OldDown) -> + gen_server:cast(Upstream, {migrate, OldDown}). + %% Callbacks %% Custom gen_server init @@ -175,6 +179,27 @@ handle_cast({close_ext, Down}, #state{down = Down, sock = USock, transport = UTr ?LOG_DEBUG("asked to close connection by downstream"), ok = UTrans:close(USock), {stop, normal, S#state{down = undefined}}; +handle_cast({migrate, OldDown}, #state{down = OldDown, dc_id = {_DcId, Pool}, + codec = Codec, addr = Addr, + ad_tag = AdTag, listener = Listener} = S) -> + {PacketLayerMod, _} = mtp_codec:info(packet, Codec), + Opts = #{addr => Addr, ad_tag => AdTag, packet_layer => PacketLayerMod}, + case mtp_dc_pool:migrate(Pool, OldDown, self(), Opts) of + {error, Reason} -> + ?LOG_DEBUG("Migration failed (~p), closing client", [Reason]), + true = is_atom(Reason), + mtp_metric:count_inc([?APP, downstream_migration, total], 1, + #{labels => [Listener, Reason]}), + {stop, normal, S#state{down = undefined}}; + NewDown -> + ?LOG_DEBUG("Migrated from ~p to ~p", [OldDown, NewDown]), + mtp_metric:count_inc([?APP, downstream_migration, total], 1, + #{labels => [Listener, ok]}), + {noreply, S#state{down = NewDown}} + end; +handle_cast({migrate, _StaleDown}, S) -> + %% Stale migrate from a previous down_conn — already migrated, ignore. + {noreply, S}; handle_cast({simple_ack, Down, Confirm}, #state{down = Down} = S) -> ?LOG_INFO("Simple ack: ~p, ~p", [Down, Confirm]), {noreply, S}; @@ -613,13 +638,19 @@ up_send_raw(Data, #state{sock = Sock, end end, #{labels => [Listener]}). -down_send(Packet, #state{down = Down} = S) -> +down_send(Packet, #state{down = Down, listener = Listener} = S) -> %% ?LOG_DEBUG(">Down: ~p", [Packet]), case mtp_down_conn:send(Down, Packet) of ok -> {ok, S}; {error, unknown_upstream} -> - handle_unknown_upstream(S) + handle_unknown_upstream(S); + {error, migrating} -> + %% DC connection is closing; this packet was never sent to TG. + %% Stop the handler so the client reconnects and resends. + mtp_metric:count_inc([?APP, downstream_migration, total], 1, + #{labels => [Listener, mid_send]}), + throw({stop, normal, S}) end. 
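+
+%% Migration metric labels used in this module: `ok' means a clean hand-off to
+%% a new downstream, `mid_send' means the handler was blocked in down_send and
+%% the packet was dropped (the client has to reconnect and resend), and the
+%% pool error reason (`empty' or `not_found') means no replacement downstream
+%% was available.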
handle_unknown_upstream(#state{down = Down, sock = USock, transport = UTrans} = S) -> diff --git a/test/mtp_test_middle_server.erl b/test/mtp_test_middle_server.erl index b7985ff..d47597d 100644 --- a/test/mtp_test_middle_server.erl +++ b/test/mtp_test_middle_server.erl @@ -5,7 +5,8 @@ -export([start/2, stop/1, - get_rpc_handler_state/1]). + get_rpc_handler_state/1, + close_connection/1]). -export([start_link/3, ranch_init/1]). -export([init/1, @@ -61,6 +62,10 @@ stop(Id) -> get_rpc_handler_state(Pid) -> gen_statem:call(Pid, get_rpc_handler_state). +%% Close the server-side TCP socket, simulating Telegram rotating the connection. +close_connection(Pid) -> + gen_statem:call(Pid, close_connection). + %% Callbacks start_link(Ref, Transport, Opts) -> @@ -159,6 +164,9 @@ on_tunnel(info, {tcp, _Sock, TcpData}, #t_state{codec = Codec0} = S) -> {keep_state, activate(S2#t_state{codec = Codec1})}; on_tunnel({call, From}, get_rpc_handler_state, #t_state{rpc_handler_state = HSt}) -> {keep_state_and_data, [{reply, From, HSt}]}; +on_tunnel({call, From}, close_connection, #t_state{sock = Sock, transport = Transport}) -> + Transport:close(Sock), + {stop_and_reply, normal, [{reply, From, ok}]}; on_tunnel(Type, Event, S) -> handle_event(Type, Event, ?FUNCTION_NAME, S). diff --git a/test/mtp_test_reporter_rpc.erl b/test/mtp_test_reporter_rpc.erl new file mode 100644 index 0000000..5970a10 --- /dev/null +++ b/test/mtp_test_reporter_rpc.erl @@ -0,0 +1,17 @@ +%% @doc rpc_handler for mtp_test_middle_server that echoes packets and reports +%% each one to a registered process named `mtp_test_rpc_sink'. +%% The report message is `{rpc_from, self(), ConnId, Data}', where `self()' is +%% the mtp_test_middle_server Ranch connection pid — useful for tests that need +%% to identify which DC connection a client is multiplexed on and close it. +-module(mtp_test_reporter_rpc). +-export([init/1, + handle_rpc/2]). + +init(_) -> + #{}. + +handle_rpc({data, ConnId, Data}, St) -> + mtp_test_rpc_sink ! {rpc_from, self(), ConnId, Data}, + {rpc, {proxy_ans, ConnId, Data}, St}; +handle_rpc({remote_closed, ConnId}, St) -> + {noreply, St#{ConnId => closed}}. diff --git a/test/single_dc_SUITE.erl b/test/single_dc_SUITE.erl index 3eec673..a981949 100644 --- a/test/single_dc_SUITE.erl +++ b/test/single_dc_SUITE.erl @@ -27,7 +27,10 @@ domain_fronting_replay_case/1, per_sni_secrets_on_case/1, per_sni_secrets_wrong_secret_case/1, - malformed_tls_hello_decode_error_case/1 + malformed_tls_hello_decode_error_case/1, + downstream_migration_case/1, + downstream_migration_multi_case/1, + downstream_migration_empty_pool_case/1 ]). -export([set_env/2, @@ -781,6 +784,103 @@ malformed_tls_hello_decode_error_case(Cfg) when is_list(Cfg) -> 1, mtp_test_metric:get_tags( count, [?APP, protocol_error, total], [?FUNCTION_NAME, tls_bad_client_hello])). +%% @doc Client survives a DC connection rotation (1 client, 2 DC connections available). 
+downstream_migration_case({pre, Cfg}) -> + setup_single(?FUNCTION_NAME, 10000 + ?LINE, + #{init_dc_connections => 2, rpc_handler => mtp_test_reporter_rpc}, Cfg); +downstream_migration_case({post, Cfg}) -> + stop_single(Cfg); +downstream_migration_case(Cfg) when is_list(Cfg) -> + DcId = ?config(dc_id, Cfg), + Host = ?config(mtp_host, Cfg), + Port = ?config(mtp_port, Cfg), + Secret = ?config(mtp_secret, Cfg), + Pool = mtp_dc_pool:dc_to_pool_name(DcId), + register(mtp_test_rpc_sink, self()), + try + Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure), + Cli1 = ping(Cli), + %% Receive the reporter notification to learn which Ranch/middle-server pid + %% this client's DC connection tunnels through. + ServerPid = receive {rpc_from, Pid, _, _} -> Pid end, + ok = mtp_test_middle_server:close_connection(ServerPid), + %% Wait until handler has successfully migrated to the surviving downstream. + ok = mtp_test_metric:wait_for_value( + count, [?APP, downstream_migration, total], + [?FUNCTION_NAME, ok], 1, 5000), + %% Client must still work after migration. + Cli2 = ping(Cli1), + %% Pool tracking must be clean: exactly 1 upstream registered. + ?assertMatch(#{n_upstreams := 1}, mtp_dc_pool:status(Pool)), + ok = mtp_test_client:close(Cli2) + after + unregister(mtp_test_rpc_sink) + end. + +%% @doc All clients survive when one of two DC connections is rotated. +downstream_migration_multi_case({pre, Cfg}) -> + setup_single(?FUNCTION_NAME, 10000 + ?LINE, + #{init_dc_connections => 2, rpc_handler => mtp_test_reporter_rpc}, Cfg); +downstream_migration_multi_case({post, Cfg}) -> + stop_single(Cfg); +downstream_migration_multi_case(Cfg) when is_list(Cfg) -> + DcId = ?config(dc_id, Cfg), + Host = ?config(mtp_host, Cfg), + Port = ?config(mtp_port, Cfg), + Secret = ?config(mtp_secret, Cfg), + Pool = mtp_dc_pool:dc_to_pool_name(DcId), + N = 3, + register(mtp_test_rpc_sink, self()), + try + Clients = [mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure) + || _ <- lists:seq(1, N)], + Clients1 = [ping(C) || C <- Clients], + %% Drain all N {rpc_from,...} messages, grouping by ServerPid to find + %% which DC connection each client landed on. + Groups = lists:foldl(fun(_, Acc) -> + receive {rpc_from, Pid, _, _} -> + maps:update_with(Pid, fun(C) -> C + 1 end, 1, Acc) + end + end, #{}, lists:seq(1, N)), + %% Close the DC connection carrying the most clients. + {ServerPid, NOnServer} = hd(lists:reverse(lists:keysort(2, maps:to_list(Groups)))), + ok = mtp_test_middle_server:close_connection(ServerPid), + %% Wait until exactly NOnServer clients have successfully migrated. + ok = mtp_test_metric:wait_for_value( + count, [?APP, downstream_migration, total], + [?FUNCTION_NAME, ok], NOnServer, 5000), + Clients2 = [ping(C) || C <- Clients1], + ?assertMatch(#{n_upstreams := N}, mtp_dc_pool:status(Pool)), + [ok = mtp_test_client:close(C) || C <- Clients2] + after + unregister(mtp_test_rpc_sink) + end, + ok. + +%% @doc When pool is empty after DC rotation, client closes gracefully. 
+downstream_migration_empty_pool_case({pre, Cfg}) -> + setup_single(?FUNCTION_NAME, 10000 + ?LINE, + #{init_dc_connections => 0}, Cfg); +downstream_migration_empty_pool_case({post, Cfg}) -> + stop_single(Cfg); +downstream_migration_empty_pool_case(Cfg) when is_list(Cfg) -> + DcId = ?config(dc_id, Cfg), + Host = ?config(mtp_host, Cfg), + Port = ?config(mtp_port, Cfg), + Secret = ?config(mtp_secret, Cfg), + DcCfg = ?config(dc_conf, Cfg), + Pool = mtp_dc_pool:dc_to_pool_name(DcId), + %% Manually add one connection and wait for it to be ready + ok = mtp_dc_pool:add_connection(Pool), + ok = wait_for_pool_status(Pool, fun(S) -> maps:get(n_downstreams, S) >= 1 end, 5000), + Cli = mtp_test_client:connect(Host, Port, Secret, DcId, mtp_secure), + _Cli1 = ping(Cli), + [Conn] = mtp_test_datacenter:middle_connections(DcCfg), + ok = mtp_test_middle_server:close_connection(Conn), + %% Pool stays empty (init_dc_connections=0 so no replacement spawned) + %% Client must close gracefully + ?assertEqual({error, closed}, mtp_test_client:recv_packet(_Cli1, 2000)). + setup_single(Name, MtpPort, DcCfg0, Cfg) -> setup_single(Name, "127.0.0.1", MtpPort, DcCfg0, Cfg). @@ -796,7 +896,11 @@ setup_single(Name, MtpIpStr, MtpPort, DcCfg0, Cfg) -> secret => Secret, tag => <<"dcbe8f1493fa4cd9ab300891c0b5b326">>}], application:load(mtproto_proxy), - Cfg1 = set_env([{ports, Listeners}], Cfg), + AppEnv = case maps:find(init_dc_connections, DcCfg0) of + {ok, N} -> [{init_dc_connections, N}]; + error -> [] + end, + Cfg1 = set_env([{ports, Listeners}] ++ AppEnv, Cfg), {ok, DcCfg} = mtp_test_datacenter:start_dc(PubKey, DcConf, DcCfg0), {ok, _} = application:ensure_all_started(mtproto_proxy), {ok, MtpIp} = inet:parse_address(MtpIpStr), @@ -851,3 +955,23 @@ ping(Cli0) -> {ok, Packet, Cli2} = mtp_test_client:recv_packet(Cli1, 1000), ?assertEqual(Data, Packet), Cli2. + +wait_for_pool_status(Pool, Pred, Timeout) -> + Deadline = erlang:monotonic_time(millisecond) + Timeout, + wait_for_pool_status_loop(Pool, Pred, Deadline). + +wait_for_pool_status_loop(Pool, Pred, Deadline) -> + Status = mtp_dc_pool:status(Pool), + case Pred(Status) of + true -> + ok; + false -> + Remaining = Deadline - erlang:monotonic_time(millisecond), + case Remaining > 0 of + true -> + timer:sleep(50), + wait_for_pool_status_loop(Pool, Pred, Deadline); + false -> + {error, {timeout, Status}} + end + end.
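+
+%% Note: wait_for_pool_status/3 returns ok once Pred(Status) holds; on timeout
+%% it returns {error, {timeout, LastStatus}}, so callers that match
+%% `ok = wait_for_pool_status(...)' fail with a badmatch carrying the last
+%% observed pool status.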