diff --git a/src/mtp_handler.erl b/src/mtp_handler.erl
index a90898d..11ad75f 100644
--- a/src/mtp_handler.erl
+++ b/src/mtp_handler.erl
@@ -244,6 +244,7 @@ handle_upstream_data(<>, #state{stage = init, sta
                            secret = Secret, listener = Listener} = S) ->
     case mtp_obfuscated:from_header(Header, Secret) of
         {ok, DcId, PacketLayerMod, CryptoCodecSt} ->
+            maybe_check_replay(Header),
             mtp_metric:count_inc([?APP, protocol_ok, total],
                                  1, #{labels => [Listener, PacketLayerMod]}),
             PacketCodec = PacketLayerMod:new(),
@@ -263,6 +264,16 @@ handle_upstream_data(Bin, #state{stage = init, stage_state = <<>>} = S) ->
 handle_upstream_data(Bin, #state{stage = init, stage_state = Buf} = S) ->
     handle_upstream_data(<> , S#state{stage_state = <<>>}).
 
+%% Session replay guard: the same first 64-byte packet must not be accepted twice.
+%% The check runs only when `mtp_session_storage' is listed in the
+%% `replay_checks_enabled' application env. Returns `ok' when the check is
+%% disabled and `true' when the packet is new; raises
+%% `{protocol_error, replay_session_detected, Packet}' otherwise.
+maybe_check_replay(Packet) ->
+    EnabledChecks = application:get_env(?APP, replay_checks_enabled, []),
+    case lists:member(mtp_session_storage, EnabledChecks) of
+        false ->
+            ok;
+        true ->
+            case mtp_session_storage:check_add(Packet) of
+                new ->
+                    true;
+                _AlreadySeen ->
+                    error({protocol_error, replay_session_detected, Packet})
+            end
+    end.
+
 up_send(Packet, #state{stage = tunnel,
                        codec = UpCodec,
diff --git a/src/mtp_session_storage.erl b/src/mtp_session_storage.erl
new file mode 100644
index 0000000..7f9122f
--- /dev/null
+++ b/src/mtp_session_storage.erl
@@ -0,0 +1,229 @@
+%%%-------------------------------------------------------------------
+%%% @doc
+%%% Storage to store last used sessions to protect from replay-attacks
+%%% used in some countries to detect mtproto proxy.
+%%%
+%%% Data is stored in ?DATA_TAB and there is additional index table ?HISTOGRAM_TAB, where
+%%% we store "secondary index" histogram: how many sessions have been added in each 5 minute
+%%% interval. It is used to make periodic cleanup procedure more efficient.
+%%% @end
+%%% Created : 19 May 2019 by Sergey
+%%%-------------------------------------------------------------------
+-module(mtp_session_storage).
+
+-behaviour(gen_server).
+
+%% API
+-export([start_link/0,
+         check_add/1,
+         status/0]).
+
+%% gen_server callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+         terminate/2, code_change/3]).
+
+-include_lib("stdlib/include/ms_transform.hrl").
+
+-define(DATA_TAB, ?MODULE).
+-define(HISTOGRAM_TAB, mtp_session_storage_histogram).
+
+%% 5-minute buckets
+-define(HISTOGRAM_BUCKET_SIZE, 300).
+-define(CHECK_INTERVAL, 60).
+
+%% Server state. Fields are *typed* with `::'. Writing `=' here would turn
+%% `ets:tid()' / `gen_timeout:tout()' into default-value expressions, i.e.
+%% calls to functions that do not exist.
+%% Both tables are created with `named_table', so the handle kept here is the
+%% table name (an atom); hence `ets:tab()' rather than `ets:tid()'.
+-record(state, {data_tab :: ets:tab(),
+                histogram_tab :: ets:tab(),
+                clean_timer :: gen_timeout:tout()}).
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+
+%% @doc Start the storage server, registered locally under the module name.
+start_link() ->
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
+
+%% @doc Record the fingerprint of a 64-byte initial packet in the storage.
+%% Returns `new' if this fingerprint was not stored before and `used' if it was.
+%% In the `used' case the stored timestamp is refreshed to the current time.
+%% Any other packet size fails with `function_clause'.
+-spec check_add(binary()) -> new | used.
+check_add(Packet) when byte_size(Packet) =:= 64 ->
+    Now = erlang:system_time(second),
+    check_add_at(Packet, Now).
+
+%% Same as check_add/1, but with an explicit timestamp (seconds).
+%% Writes go directly to the public ETS tables from the calling process; the
+%% gen_server is not involved.
+check_add_at(Packet, Now) ->
+    Record = {fingerprint(Packet), Now},
+    HistogramBucket = bucket(Now),
+    %% Count this insertion in the bucket for `Now', creating the bucket with 0
+    %% first if it does not exist yet.
+    ets:update_counter(?HISTOGRAM_TAB, HistogramBucket, 1, {HistogramBucket, 0}),
+    case ets:insert_new(?DATA_TAB, Record) of
+        true ->
+            new;
+        false ->
+            %% TODO: should decrement old record's histogram counter, but skip this for simplicity
+            %% Overwrite so that the record carries the time of the latest use.
+            ets:insert(?DATA_TAB, Record),
+            used
+    end.
+
+%% @doc Storage statistics, computed by the server process.
+%% `histogram_size' is the sum of all bucket counters, `histogram_oldest_ts' is
+%% the start timestamp of the oldest bucket (or the current time when there are
+%% no buckets) and `histogram_oldest_age' is its age in seconds.
+-spec status() -> #{tab_size := non_neg_integer(),
+                    tab_memory_kb := non_neg_integer(),
+                    histogram_buckets := non_neg_integer(),
+                    histogram_size := non_neg_integer(),
+                    histogram_oldest_ts := non_neg_integer(),
+                    histogram_oldest_age := non_neg_integer()}.
+status() ->
+    gen_server:call(?MODULE, status).
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+
+%% Creates both ETS tables (owned by this process) and the cleanup timer.
+%% NOTE(review): the timer is only built with gen_timeout:new/1 here; confirm
+%% that this also arms it, otherwise the first `timeout' message never arrives.
+init([]) ->
+    {DataTab, HistTab} = new_storage(),
+    Timer = gen_timeout:new(#{timeout => ?CHECK_INTERVAL}),
+    {ok, #state{data_tab = DataTab,
+                histogram_tab = HistTab,
+                clean_timer = Timer}}.
+
+%% `status': report table sizes plus histogram totals (see status/0).
+handle_call(status, _From, #state{data_tab = DataTid, histogram_tab = HistTid} = State) ->
+    Now = erlang:system_time(second),
+    Size = ets:info(DataTid, size),
+    Memory = tab_memory(DataTid),
+    MemoryKb = round(Memory / 1024),
+    HistSize = ets:info(HistTid, size),
+    %% One pass over the histogram: oldest bucket start time and sum of counters.
+    %% `Now' is the initial "oldest" value, so an empty histogram yields age 0.
+    {HistOldest, HistTotal} =
+        ets:foldl(fun({Bucket, Count}, {Oldest, Total}) ->
+                          {erlang:min(Oldest, bucket_to_ts(Bucket)), Total + Count}
+                  end, {Now, 0}, HistTid),
+    Status = #{tab_size => Size,
+               tab_memory_kb => MemoryKb,
+               histogram_buckets => HistSize,
+               histogram_size => HistTotal,
+               histogram_oldest_ts => HistOldest,
+               histogram_oldest_age => Now - HistOldest},
+    {reply, Status, State}.
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+%% `timeout': when the clean timer has really expired, run the cleanup with the
+%% options from the `replay_check_session_storage_opts' env, log the result and
+%% reset + bump the timer; otherwise only reset it.
+handle_info(timeout, #state{data_tab = DataTab, histogram_tab = HistTab, clean_timer = Timer0} = State) ->
+    Timer =
+        case gen_timeout:is_expired(Timer0) of
+            true ->
+                Opts = application:get_env(mtproto_proxy, replay_check_session_storage_opts,
+                                           #{max_age_minutes => 360}),
+                Cleans = clean_storage(DataTab, HistTab, Opts),
+                Remaining = ets:info(DataTab, size),
+                lager:info("storage cleaned: ~p; remaining: ~p", [Cleans, Remaining]),
+                gen_timeout:bump(gen_timeout:reset(Timer0));
+            false ->
+                gen_timeout:reset(Timer0)
+        end,
+    {noreply, State#state{clean_timer = Timer}};
+handle_info(Unexpected, State) ->
+    %% This process owns both ETS tables: crashing on a stray message would
+    %% drop the whole replay storage, so log the message and carry on.
+    lager:warning("unexpected message: ~p", [Unexpected]),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+
+%% Extract the storage key from a 64-byte packet.
+fingerprint(<<_:8/binary, KeyIV:(32 + 16)/binary, _:8/binary>>) ->
+    %% It would be better to use whole 64b packet as fingerprint, but will use only
+    %% 48b Key + IV part to save some space.
+    KeyIV.
+
+%% Histogram bucket number that a timestamp (seconds) falls into.
+bucket(Timestamp) ->
+    Timestamp div ?HISTOGRAM_BUCKET_SIZE.
+
+%% Timestamp (seconds) of the first second of a bucket.
+bucket_to_ts(BucketTime) ->
+    BucketTime * ?HISTOGRAM_BUCKET_SIZE.
+
+bucket_next(BucketTime) ->
+    BucketTime + 1.
+
+
+%% Create the data and histogram tables. Both are public named tables because
+%% check_add/1 writes to them from the calling process.
+new_storage() ->
+    DataTab = ets:new(?DATA_TAB, [set, public, named_table, {write_concurrency, true}]),
+    HistTab = ets:new(?HISTOGRAM_TAB, [set, public, named_table, {write_concurrency, true}]),
+    {DataTab, HistTab}.
+
+
+%% Run the `space', `count' and `max_age' checks in that order.
+%% Returns `{Check, RemovedCount}' for every check that actually removed data.
+clean_storage(DataTid, HistogramTid, CleanOpts) ->
+    lists:filtermap(fun(Check) -> do_clean(DataTid, HistogramTid, CleanOpts, Check) end,
+                    [space, count, max_age]).
+
+%% Apply a single check. Returns `{true, {Check, Removed}}' when records were
+%% removed and `false' when the limit is not exceeded or not configured.
+do_clean(DataTid, HistTid, #{max_memory_mb := MaxMem}, space) ->
+    TabMemBytes = tab_memory(DataTid),
+    MaxMemBytes = MaxMem * 1024 * 1024,
+    case TabMemBytes > MaxMemBytes of
+        true ->
+            PercentToShrink = (TabMemBytes - MaxMemBytes) / TabMemBytes,
+            Removed = shrink_percent(DataTid, HistTid, PercentToShrink),
+            {true, {space, Removed}};
+        false ->
+            false
+    end;
+do_clean(DataTid, HistTid, #{max_items := MaxItems}, count) ->
+    Count = ets:info(DataTid, size),
+    case Count > MaxItems of
+        true ->
+            PercentToShrink = (Count - MaxItems) / Count,
+            Removed = shrink_percent(DataTid, HistTid, PercentToShrink),
+            {true, {count, Removed}};
+        false ->
+            false
+    end;
+do_clean(DataTid, HistTid, #{max_age_minutes := MaxAge}, max_age) ->
+    %% First scan histogram table, because it's cheaper
+    CutBucket = bucket(erlang:system_time(second) - (MaxAge * 60)),
+    HistMs = ets:fun2ms(fun({BucketTs, _}) when BucketTs =< CutBucket -> true end),
+    case ets:select_count(HistTid, HistMs) of
+        0 ->
+            false;
+        _ ->
+            Removed = remove_older(CutBucket, DataTid, HistTid),
+            {true, {max_age, Removed}}
+    end;
+do_clean(_DataTid, _HistTid, _CleanOpts, _Check) ->
+    %% The limit for this check is not present in the options (the fallback
+    %% options used when the env is unset contain only `max_age_minutes'):
+    %% skip the check instead of crashing the table owner with function_clause.
+    false.
+
+
+%% Memory used by a table, in bytes (`ets:info(Tid, memory)' is in words).
+tab_memory(Tid) ->
+    WordSize = erlang:system_info(wordsize),
+    Words = ets:info(Tid, memory),
+    Words * WordSize.
+
+%% Remove the oldest buckets so that roughly `Percent' (a fraction, 0 =< P < 1)
+%% of the data records goes away. Returns the number of removed data records.
+shrink_percent(DataTid, HistTid, Percent) when Percent < 1,
+                                               Percent >= 0 ->
+    Count = ets:info(DataTid, size),
+    ToRemove = trunc(Count * Percent),
+    case lists:sort(ets:tab2list(HistTid)) of     % oldest first
+        [] ->
+            %% No buckets to cut by, so nothing can be removed.
+            0;
+        HistByTime ->
+            CutBucketTime = find_cut_bucket(HistByTime, ToRemove, 0),
+            remove_older(CutBucketTime, DataTid, HistTid)
+    end.
+
+%% Find the timestamp such that if we remove buckets that are older than this timestamp then we
+%% will remove at least `ToRemove' items.
+%% Expects a non-empty list sorted oldest first; the last bucket is returned if
+%% the running total never reaches `ToRemove'.
+find_cut_bucket([{BucketTime, _}], _, _) ->
+    BucketTime;
+find_cut_bucket([{BucketTime, Count} | Tail], ToRemove, Total) ->
+    NewTotal = Total + Count,
+    case NewTotal >= ToRemove of
+        true ->
+            BucketTime;
+        false ->
+            find_cut_bucket(Tail, ToRemove, NewTotal)
+    end.
+
+%% @doc remove records that are in CutBucketTime bucket or older.
+%% Returns number of removed data records.
+-spec remove_older(integer(), ets:tab(), ets:tab()) -> non_neg_integer().
+remove_older(CutBucketTime, DataTid, HistTid) ->
+    %% | --- | --- | --- | --
+    %% ^ oldest bucket
+    %%       ^ 2nd bucket
+    %%             ^ 3rd bucket
+    %%                   ^ current bucket
+    %% If CutBucketTime is 2nd bucket, following will be removed:
+    %% | --- | ---
+    EdgeBucketTime = bucket_next(CutBucketTime),
+    HistMs = ets:fun2ms(fun({BucketTs, _}) when BucketTs < EdgeBucketTime -> true end),
+    DataCutTime = bucket_to_ts(EdgeBucketTime),
+    %% Strictly older than the edge: a record stamped exactly at DataCutTime
+    %% belongs to the edge bucket, whose histogram entry is kept.
+    DataMs = ets:fun2ms(fun({_, Time}) when Time < DataCutTime -> true end),
+    ets:select_delete(HistTid, HistMs),
+    ets:select_delete(DataTid, DataMs).
diff --git a/src/mtproto_proxy.app.src b/src/mtproto_proxy.app.src
index 22e0321..4af6ab1 100644
--- a/src/mtproto_proxy.app.src
+++ b/src/mtproto_proxy.app.src
@@ -59,7 +59,29 @@
         {allowed_protocols, [mtp_abridged, mtp_intermediate, mtp_secure]},
 
         {init_dc_connections, 2},
-        {clients_per_dc_connection, 300}
+        {clients_per_dc_connection, 300},
+
+        %% List of enabled replay-attack checks. See
+        %% https://habr.com/ru/post/452144/
+        {replay_checks_enabled, [mtp_session_storage]},
+
+        %% Options for the `mtp_session_storage` replay-attack check.
+        %% These limits are approximate: they are not enforced in real time,
+        %% but by a cleanup that runs once per minute.
+        {replay_check_session_storage_opts,
+         #{%% Start removing the oldest items when there are more than
+           %% `max_items` records in the storage
+           max_items => 4000000,
+           %% Start removing the oldest items when the storage occupies more
+           %% than `max_memory_mb` megabytes of memory.
+           %% One session uses ~130-150 bytes on 64-bit Linux, so
+           %% 1 GB is enough to store ~8 million sessions, which is
+           %% 24 hours of ~90 connections per second
+           %% (512 MB therefore holds roughly 4 million sessions)
+           max_memory_mb => 512,
+           %% Remove items that were last used more than `max_age_minutes`
+           %% minutes ago.
+           %% Less than 10 minutes doesn't make much sense
+           max_age_minutes => 360}}
 
         %% Should be module with function `notify/4' exported.
         %% See mtp_metric:notify/4 for details
diff --git a/src/mtproto_proxy_sup.erl b/src/mtproto_proxy_sup.erl
index 52d5091..a0559c5 100644
--- a/src/mtproto_proxy_sup.erl
+++ b/src/mtproto_proxy_sup.erl
@@ -52,6 +52,8 @@ init([]) ->
                 type => supervisor,
                 start => {mtp_dc_pool_sup, start_link, []}},
               #{id => mtp_config,
-                start => {mtp_config, start_link, []}}
+                start => {mtp_config, start_link, []}},
+              #{id => mtp_session_storage,
+                start => {mtp_session_storage, start_link, []}}
              ],
     {ok, {SupFlags, Childs}}.
diff --git a/test/mtp_test_client.erl b/test/mtp_test_client.erl
index c3807c8..36b7e8e 100644
--- a/test/mtp_test_client.erl
+++ b/test/mtp_test_client.erl
@@ -2,6 +2,7 @@
 -module(mtp_test_client).
 
 -export([connect/5,
+         connect/6,
          send/2,
          recv_packet/2,
          recv_all/2,
@@ -16,13 +17,17 @@
 -type tcp_error() :: inet:posix() | closed.  % | timeout.
 
 connect(Host, Port, Secret, DcId, Protocol) ->
+    Seed = crypto:strong_rand_bytes(58),  % fresh random 58-byte seed for every connection
+    connect(Host, Port, Seed, Secret, DcId, Protocol).
+
+connect(Host, Port, Seed, Secret, DcId, Protocol) ->  % variant with a caller-supplied seed
     Opts = [{packet, raw},
             {mode, binary},
             {active, false},
             {buffer, 1024},
             {send_timeout, 5000}],
     {ok, Sock} = gen_tcp:connect(Host, Port, Opts, 1000),
+    {Header, _, _, CryptoLayer} = mtp_obfuscated:client_create(Seed, Secret, Protocol, DcId),  % Seed is handed to client_create/4, which returns the Header sent first
     ok = gen_tcp:send(Sock, Header),
     PacketLayer = Protocol:new(),
     Codec = mtp_codec:new(mtp_obfuscated, CryptoLayer,
diff --git a/test/single_dc_SUITE.erl b/test/single_dc_SUITE.erl
index 4f27e7d..1dffb45 100644
--- a/test/single_dc_SUITE.erl
+++ b/test/single_dc_SUITE.erl
@@ -12,7 +12,8 @@
          packet_too_large_case/1,
          downstream_size_backpressure_case/1,
          downstream_qlen_backpressure_case/1,
-         config_change_case/1
+         config_change_case/1,
+         replay_attack_case/1
         ]).
 
 -export([set_env/2,
@@ -308,6 +309,36 @@ config_change_case(Cfg) when is_list(Cfg) ->
     ?assertEqual(PortsBefore, mtproto_proxy_app:running_ports()),
     ok.
 
+
+%% @doc test replay attack protection.
+%% Attempts to connect with the same 1st 64-byte packet should be rejected.
+replay_attack_case({pre, Cfg}) ->
+    setup_single(?FUNCTION_NAME, 10000 + ?LINE, #{}, Cfg);
+replay_attack_case({post, Cfg}) ->
+    stop_single(Cfg);
+replay_attack_case(Cfg) when is_list(Cfg) ->
+    %% Two connections are opened with the same Seed; the second one is
+    %% expected to be counted as a replay and closed by the proxy.
+    DcId = ?config(dc_id, Cfg),
+    Host = ?config(mtp_host, Cfg),
+    Port = ?config(mtp_port, Cfg),
+    Secret = ?config(mtp_secret, Cfg),
+    Seed = crypto:strong_rand_bytes(58),
+    Connect =
+        fun() ->
+                mtp_test_client:connect(Host, Port, Seed, Secret, DcId, mtp_secure)
+        end,
+    ReplayErrors =
+        fun() ->
+                mtp_test_metric:get_tags(
+                  count, [?APP, protocol_error, total], [replay_session_detected])
+        end,
+    %% Nothing has been reported before the first connection...
+    ?assertEqual(not_found, ReplayErrors()),
+    FirstCli = Connect(),
+    _FirstCliAfterSend = mtp_test_client:send(crypto:strong_rand_bytes(64), FirstCli),
+    %% ...nor after it
+    ?assertEqual(not_found, ReplayErrors()),
+    ReplayCli = Connect(),
+    %% The third argument is extra context shown if the assertion fails
+    ?assertEqual(
+       ok, mtp_test_metric:wait_for_value(
+             count, [?APP, protocol_error, total], [replay_session_detected], 1, 5000),
+       {mtp_session_storage:status(),
+        sys:get_state(mtp_test_metric)}),
+    ?assertEqual(1, ReplayErrors()),
+    ?assertEqual({error, closed}, mtp_test_client:recv_packet(ReplayCli, 1000)).
+
 %% TODO: send a lot, not read, and then close - assert connection IDs are cleaned up
 
 %% Helpers