355 lines
No EOL
11 KiB
Python
355 lines
No EOL
11 KiB
Python
import asyncio
|
|
import logging
|
|
import struct
|
|
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
from typing import Dict, List, Optional
|
|
|
|
from .utils import *
|
|
from .stats import stats
|
|
from .balancer import balancer
|
|
from .config import proxy_config
|
|
from .raw_websocket import RawWebSocket
|
|
|
|
|
|
# Logger shared by everything in this module.
log = logging.getLogger('tg-mtproto-proxy')

# Pre-compiled codec for one little-endian unsigned 32-bit integer.
_st_I_le = struct.Struct('<I')

# 64 zero bytes; pushed through a fresh cipher stream to advance its
# keystream by 64 bytes.
ZERO_64 = bytes(64)

# Default IP address for each known DC id, used as the direct-TCP
# fallback destination.
DC_DEFAULT_IPS: Dict[int, str] = {
    1: '149.154.175.50',
    2: '149.154.167.51',
    3: '149.154.175.100',
    4: '149.154.167.91',
    5: '149.154.171.5',
    203: '91.105.192.100',
}
|
|
|
|
|
|
class CryptoCtx:
    """Holds the four cipher contexts used to re-encrypt one relayed session.

    Attributes:
        clt_dec: decrypts data received from the client.
        clt_enc: encrypts data sent to the client.
        tg_enc: encrypts data sent to Telegram.
        tg_dec: decrypts data received from Telegram.
    """

    __slots__ = ('clt_dec', 'clt_enc', 'tg_enc', 'tg_dec')

    def __init__(self, clt_dec, clt_enc, tg_enc, tg_dec):
        # Client-facing pair.
        self.clt_dec, self.clt_enc = clt_dec, clt_enc
        # Telegram-facing pair.
        self.tg_enc, self.tg_dec = tg_enc, tg_dec
|
|
|
|
|
|
class MsgSplitter:
    """
    Cuts an encrypted TCP byte stream at MTProto transport packet
    boundaries so that each packet can be sent as its own WS frame.

    A decrypted shadow copy of the stream is kept solely to read the
    length headers; the pieces handed back are always ciphertext.
    """
    __slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')

    def __init__(self, relay_init: bytes, proto_int: int):
        """Derive the AES-CTR stream from *relay_init* and reset all buffers.

        Bytes 8..40 of *relay_init* are used as the key and bytes 40..56 as
        the CTR nonce; *proto_int* selects which framing rules apply.
        """
        key = relay_init[8:40]
        nonce = relay_init[40:56]
        self._dec = Cipher(algorithms.AES(key), modes.CTR(nonce)).encryptor()
        # Advance the keystream by 64 bytes before any payload is processed.
        self._dec.update(ZERO_64)
        self._proto = proto_int
        self._cipher_buf = bytearray()
        self._plain_buf = bytearray()
        self._disabled = False

    def split(self, chunk: bytes) -> List[bytes]:
        """Feed *chunk* and return every complete packet now available.

        Returns an empty list while no full packet is buffered. Once a
        header cannot be interpreted, everything buffered is released as a
        single piece and all later chunks are passed through unchanged.
        """
        if not chunk:
            return []
        if self._disabled:
            return [chunk]

        self._cipher_buf.extend(chunk)
        self._plain_buf.extend(self._dec.update(chunk))

        ready: List[bytes] = []
        while self._cipher_buf:
            size = self._next_packet_len()
            if size is None:
                # The packet at the head of the buffer is still incomplete.
                break
            if size > 0:
                ready.append(bytes(self._cipher_buf[:size]))
                # Keep the ciphertext and plaintext views aligned.
                del self._cipher_buf[:size]
                del self._plain_buf[:size]
                continue
            # Unusable length: hand over whatever is buffered and stop
            # splitting from here on.
            ready.extend(self.flush())
            self._disabled = True
            break
        return ready

    def flush(self) -> List[bytes]:
        """Return the buffered ciphertext (if any) as one piece and reset."""
        pending = bytes(self._cipher_buf)
        if not pending:
            return []
        self._cipher_buf.clear()
        self._plain_buf.clear()
        return [pending]

    def _next_packet_len(self) -> Optional[int]:
        """Total length of the packet at the head of the buffer.

        ``None`` means more data is required; ``0`` means the length cannot
        be determined (unknown protocol or a zero-length header).
        """
        if not self._plain_buf:
            return None
        proto = self._proto
        if proto == PROTO_ABRIDGED_INT:
            return self._next_abridged_len()
        if (proto == PROTO_INTERMEDIATE_INT
                or proto == PROTO_PADDED_INTERMEDIATE_INT):
            return self._next_intermediate_len()
        return 0

    def _next_abridged_len(self) -> Optional[int]:
        """Packet length under abridged framing (lengths are in 4-byte words)."""
        buf = self._plain_buf
        available = len(buf)
        lead = buf[0]
        if lead == 0x7F or lead == 0xFF:
            # Extended form: the word count sits in the following three
            # little-endian bytes.
            if available < 4:
                return None
            header_len = 4
            words = int.from_bytes(buf[1:4], 'little')
        else:
            # Short form: the low 7 bits of the first byte are the word count.
            header_len = 1
            words = lead & 0x7F
        payload_len = words * 4
        if payload_len <= 0:
            return 0
        total = header_len + payload_len
        return total if available >= total else None

    def _next_intermediate_len(self) -> Optional[int]:
        """Packet length under (padded) intermediate framing."""
        buf = self._plain_buf
        if len(buf) < 4:
            return None
        (raw_len,) = _st_I_le.unpack_from(buf, 0)
        # The top bit of the 32-bit prefix is masked off the length.
        payload_len = raw_len & 0x7FFFFFFF
        if payload_len <= 0:
            return 0
        total = 4 + payload_len
        return None if len(buf) < total else total
|
|
|
|
|
|
|
|
async def do_fallback(reader, writer, relay_init, label,
                      dc: int, is_media: bool, media_tag: str,
                      ctx: CryptoCtx, splitter=None):
    """Try the configured fallback transports in order until one succeeds.

    Direct TCP to the DC's default address is always a candidate (it is
    skipped when *dc* has no known address). The CF proxy is added only when
    enabled in ``proxy_config``, either before or after TCP depending on its
    priority flag.

    Returns:
        True as soon as one transport reports success, False if none did.
    """
    fallback_dst = DC_DEFAULT_IPS.get(dc)
    use_cf = proxy_config.fallback_cfproxy
    cf_first = proxy_config.fallback_cfproxy_priority

    if not use_cf:
        order: List[str] = ['tcp']
    elif cf_first:
        order = ['cf', 'tcp']
    else:
        order = ['tcp', 'cf']

    for method in order:
        if method == 'cf':
            succeeded = await _cfproxy_fallback(
                reader, writer, relay_init, label, ctx,
                dc=dc, is_media=is_media,
                splitter=splitter)
        elif fallback_dst:
            # Only 'tcp' reaches this branch, and only with a known address.
            log.info("[%s] DC%d%s -> TCP fallback to %s:443",
                     label, dc, media_tag, fallback_dst)
            succeeded = await _tcp_fallback(
                reader, writer, fallback_dst, 443,
                relay_init, label, ctx)
        else:
            # TCP requested but no default address for this DC.
            continue
        if succeeded:
            return True
    return False
|
|
|
|
|
|
async def _cfproxy_fallback(reader, writer, relay_init, label,
                            ctx: CryptoCtx,
                            dc: int, is_media: bool,
                            splitter=None):
    """Relay the session through a CF proxy WebSocket endpoint.

    Each base domain returned by ``balancer.get_domains_for_dc(dc)`` is tried
    in turn (as ``kws<dc>.<base_domain>``) until one connects. The winning
    domain is reported back to the balancer, *relay_init* is sent as the
    first frame, and the session is then bridged until it ends.

    Returns:
        True once the bridged session has finished. False if no domain
        could be connected or the initial frame could not be delivered; in
        both cases nothing has been read from the client yet, so the caller
        may still try another transport.
    """
    media_tag = ' media' if is_media else ''
    ws = None
    chosen_domain = None

    log.info("[%s] DC%d%s -> trying CF proxy",
             label, dc, media_tag)

    for base_domain in balancer.get_domains_for_dc(dc):
        domain = f'kws{dc}.{base_domain}'
        try:
            ws = await RawWebSocket.connect(domain, domain, timeout=10.0)
        except Exception as exc:
            log.warning("[%s] DC%d%s CF proxy failed: %s",
                        label, dc, media_tag, repr(exc))
            continue
        chosen_domain = base_domain
        break

    if ws is None:
        return False

    if chosen_domain and balancer.update_domain_for_dc(dc, chosen_domain):
        log.info("[%s] Switched active CF domain", label)

    stats.connections_cfproxy += 1

    init_sent = False
    try:
        await ws.send(relay_init)
        init_sent = True
    except Exception as exc:
        log.warning("[%s] DC%d%s CF proxy init send failed: %s",
                    label, dc, media_tag, repr(exc))
        return False
    finally:
        # Before the bridge takes over, this function owns the socket. Close
        # it on any failure, including cancellation, so it is not leaked.
        if not init_sent:
            try:
                await ws.close()
            except Exception:
                pass

    await bridge_ws_reencrypt(reader, writer, ws, label, ctx,
                              dc=dc, is_media=is_media,
                              splitter=splitter)
    return True
|
|
|
|
|
|
async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: CryptoCtx):
    """Relay the session over a direct TCP connection to ``dst:port``.

    Opens the upstream connection (10 s timeout), writes *relay_init* as the
    first bytes, then bridges client and upstream until the session ends.

    Returns:
        True once the bridged session has finished. False if the connection
        could not be opened or *relay_init* could not be written; in both
        cases nothing has been read from the client yet, so the caller may
        still try another transport.
    """
    try:
        rr, rw = await asyncio.wait_for(
            asyncio.open_connection(dst, port), timeout=10)
    except Exception as exc:
        log.warning("[%s] TCP fallback to %s:%d failed: %s",
                    label, dst, port, repr(exc))
        return False

    stats.connections_tcp_fallback += 1

    init_sent = False
    try:
        rw.write(relay_init)
        await rw.drain()
        init_sent = True
    except Exception as exc:
        log.warning("[%s] TCP fallback to %s:%d init send failed: %s",
                    label, dst, port, repr(exc))
        return False
    finally:
        # Before the bridge takes over, this function owns the upstream
        # socket. Close it on any failure, including cancellation, so it is
        # not leaked.
        if not init_sent:
            rw.close()

    await _bridge_tcp_reencrypt(reader, writer, rr, rw, label, ctx)
    return True
|
|
|
|
|
|
async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label,
                              ctx: CryptoCtx,
                              dc=None, is_media=False,
                              splitter: Optional[MsgSplitter] = None):
    """
    Pump data both ways between the client TCP stream and the WebSocket,
    re-encrypting every chunk.

    client TCP -> decrypt(clt_dec) -> encrypt(tg_enc) -> WS
    WS         -> decrypt(tg_dec)  -> encrypt(clt_enc) -> client TCP

    With a *splitter*, upstream data is cut into separate frames before
    being sent. The bridge stops as soon as either direction ends; it then
    logs a session summary and closes both the WebSocket and the client
    writer, ignoring any errors raised while closing.
    """
    if dc:
        dc_tag = f"DC{dc}{'m' if is_media else ''}"
    else:
        dc_tag = "DC?"

    sent_bytes = 0
    recv_bytes = 0
    sent_chunks = 0
    recv_chunks = 0
    clock = asyncio.get_running_loop().time
    started_at = clock()

    async def pump_up():
        nonlocal sent_bytes, sent_chunks
        try:
            while True:
                raw = await reader.read(65536)
                if not raw:
                    # Client EOF: push out whatever the splitter still holds.
                    if splitter:
                        leftover = splitter.flush()
                        if leftover:
                            await ws.send(leftover[0])
                    return
                size = len(raw)
                stats.bytes_up += size
                sent_bytes += size
                sent_chunks += 1
                recoded = ctx.tg_enc.update(ctx.clt_dec.update(raw))
                if not splitter:
                    await ws.send(recoded)
                    continue
                frames = splitter.split(recoded)
                if not frames:
                    # No complete packet buffered yet.
                    continue
                if len(frames) == 1:
                    await ws.send(frames[0])
                else:
                    await ws.send_batch(frames)
        except (asyncio.CancelledError, ConnectionError, OSError):
            return
        except Exception as e:
            log.debug("[%s] tcp->ws ended: %s", label, e)

    async def pump_down():
        nonlocal recv_bytes, recv_chunks
        try:
            while True:
                frame = await ws.recv()
                if frame is None:
                    return
                size = len(frame)
                stats.bytes_down += size
                recv_bytes += size
                recv_chunks += 1
                writer.write(ctx.clt_enc.update(ctx.tg_dec.update(frame)))
                await writer.drain()
        except (asyncio.CancelledError, ConnectionError, OSError):
            return
        except Exception as e:
            log.debug("[%s] ws->tcp ended: %s", label, e)

    workers = [asyncio.create_task(pump_up()),
               asyncio.create_task(pump_down())]
    try:
        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Stop the surviving direction, then reap both tasks.
        for worker in workers:
            worker.cancel()
        for worker in workers:
            try:
                await worker
            except BaseException:
                pass
        elapsed = clock() - started_at
        log.info("[%s] %s WS session closed: "
                 "^%s (%d pkts) v%s (%d pkts) in %.1fs",
                 label, dc_tag,
                 human_bytes(sent_bytes), sent_chunks,
                 human_bytes(recv_bytes), recv_chunks,
                 elapsed)
        # Best-effort teardown of both endpoints.
        try:
            await ws.close()
        except BaseException:
            pass
        try:
            writer.close()
            await writer.wait_closed()
        except BaseException:
            pass
|
|
|
|
|
|
async def _bridge_tcp_reencrypt(reader, writer, remote_reader, remote_writer,
                                label, ctx: CryptoCtx):
    """Pump data both ways between two TCP streams, re-encrypting each chunk.

    Client -> remote data is decrypted with ``clt_dec`` and re-encrypted with
    ``tg_enc``; remote -> client data uses ``tg_dec`` then ``clt_enc``. The
    bridge stops as soon as either direction ends, then closes both writers,
    ignoring any errors raised while closing.
    """

    async def pump(source, sink, upstream):
        try:
            while True:
                block = await source.read(65536)
                if not block:
                    return
                size = len(block)
                if upstream:
                    stats.bytes_up += size
                    block = ctx.tg_enc.update(ctx.clt_dec.update(block))
                else:
                    stats.bytes_down += size
                    block = ctx.clt_enc.update(ctx.tg_dec.update(block))
                sink.write(block)
                await sink.drain()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            log.debug("[%s] forward ended: %s", label, e)

    workers = [
        asyncio.create_task(pump(reader, remote_writer, True)),
        asyncio.create_task(pump(remote_reader, writer, False)),
    ]
    try:
        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Stop the surviving direction, then reap both tasks.
        for worker in workers:
            worker.cancel()
        for worker in workers:
            try:
                await worker
            except BaseException:
                pass
        # Best-effort close of the client side first, then the remote side.
        for endpoint in (writer, remote_writer):
            try:
                endpoint.close()
                await endpoint.wait_closed()
            except BaseException:
                pass