From d01da39dfbdb1329f084f79fec7aa874b05ca2ff Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sat, 1 Jul 2023 12:39:18 +0530 Subject: [PATCH] Speed up base64 decode even more --- kitty/base64.h | 35 +++++++++++++++++------------------ kitty/data-types.c | 16 ++++++++-------- kitty/fast_data_types.pyi | 4 ++-- kitty/file_transmission.py | 14 +++----------- 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/kitty/base64.h b/kitty/base64.h index 7417fbed4..482416337 100644 --- a/kitty/base64.h +++ b/kitty/base64.h @@ -27,8 +27,8 @@ bool decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz); bool encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding); #ifndef B64_INCLUDED_ONCE -static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz / 4) * 3 + 4; } -static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz / 3) * 4 + 5; } +static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz * 3) / 4 + 4; } +static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz * 4) / 3 + 5; } #endif #ifndef B64_INCLUDED_ONCE @@ -44,17 +44,17 @@ static uint8_t b64_decoding_table[256] = { #endif static void -inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, const size_t dest_sz) { +inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest) { for (size_t i = 0, j = 0; i < src_sz;) { - uint32_t sextet_a = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_b = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_c = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; - uint32_t sextet_d = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_a = b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_b = b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_c = b64_decoding_table[src[i++] & 0xff]; + uint32_t sextet_d = b64_decoding_table[src[i++] & 0xff]; uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6); - if (j < dest_sz) dest[j++] = (triple >> 2 * 8) & 0xFF; - if (j < dest_sz) dest[j++] = (triple >> 1 * 8) & 0xFF; - if (j < dest_sz) dest[j++] = (triple >> 0 * 8) & 0xFF; + dest[j++] = (triple >> 2 * 8) & 0xFF; + dest[j++] = (triple >> 1 * 8) & 0xFF; + dest[j++] = (triple >> 0 * 8) & 0xFF; } } @@ -63,18 +63,17 @@ decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz) { while (src_sz && src[src_sz-1] == '=') src_sz--; // remove trailing padding if (!src_sz) { *dest_sz = 0; return true; } const size_t dest_capacity = *dest_sz; - size_t extra = src_sz % 4; + *dest_sz = src_sz / 4; + size_t extra = src_sz - 4 * *dest_sz; + *dest_sz *= 3; src_sz -= extra; - *dest_sz = (src_sz / 4) * 3; - if (*dest_sz > dest_capacity) return false; - if (src_sz) inner_func(src, src_sz, dest, *dest_sz); - if (extra > 1) { + if (*dest_sz + 4 > dest_capacity) return false; + if (src_sz) inner_func(src, src_sz, dest); + if (extra > 1 && extra < 4) { // < 4 is not needed but it helps compiler unroll the loop INPUT_T buf[4] = {0}; for (size_t i = 0; i < extra; i++) buf[i] = src[src_sz+i]; - dest += *dest_sz; + inner_func(buf, extra, dest + *dest_sz); *dest_sz += extra - 1; - if (*dest_sz > dest_capacity) return false; - inner_func(buf, extra, dest, extra-1); } if (*dest_sz + 1 > dest_capacity) return false; dest[*dest_sz] = 0; // ensure zero-terminated diff --git a/kitty/data-types.c b/kitty/data-types.c index a29b80832..f4b3d74a7 100644 --- a/kitty/data-types.c +++ b/kitty/data-types.c @@ -79,24 +79,24 @@ redirect_std_streams(PyObject UNUSED *self, PyObject *args) { static PyObject* pybase64_encode(PyObject UNUSED *self, PyObject *args) { int add_padding = 0; - const char *src; Py_ssize_t src_len; - if (!PyArg_ParseTuple(args, "y#|p", &src, &src_len, &add_padding)) return NULL; - size_t sz = required_buffer_size_for_base64_encode(src_len); + FREE_BUFFER_AFTER_FUNCTION Py_buffer view = {0}; + if (!PyArg_ParseTuple(args, "s*|p", &view, &add_padding)) return NULL; + size_t sz = required_buffer_size_for_base64_encode(view.len); PyObject *ans = PyBytes_FromStringAndSize(NULL, sz); if (!ans) return NULL; - base64_encode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding); + base64_encode8(view.buf, view.len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding); if (_PyBytes_Resize(&ans, sz) != 0) return NULL; return ans; } static PyObject* pybase64_decode(PyObject UNUSED *self, PyObject *args) { - const char *src; Py_ssize_t src_len; - if (!PyArg_ParseTuple(args, "y#", &src, &src_len)) return NULL; - size_t sz = required_buffer_size_for_base64_decode(src_len); + FREE_BUFFER_AFTER_FUNCTION Py_buffer view = {0}; + if (!PyArg_ParseTuple(args, "s*", &view)) return NULL; + size_t sz = required_buffer_size_for_base64_decode(view.len); PyObject *ans = PyBytes_FromStringAndSize(NULL, sz); if (!ans) return NULL; - base64_decode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz); + base64_decode8(view.buf, view.len, (unsigned char*)PyBytes_AS_STRING(ans), &sz); if (_PyBytes_Resize(&ans, sz) != 0) return NULL; return ans; } diff --git a/kitty/fast_data_types.pyi b/kitty/fast_data_types.pyi index 5f0b37a45..c26c3b67e 100644 --- a/kitty/fast_data_types.pyi +++ b/kitty/fast_data_types.pyi @@ -1534,5 +1534,5 @@ def expand_ansi_c_escapes(test: str) -> str: ... def update_tab_bar_edge_colors(os_window_id: int) -> bool: ... def mask_kitty_signals_process_wide() -> None: ... def is_modifier_key(key: int) -> bool: ... -def base64_encode(src: bytes, add_padding: bool = False) -> bytes: ... -def base64_decode(src: bytes) -> bytes: ... +def base64_encode(src: Union[bytes,str], add_padding: bool = False) -> bytes: ... +def base64_decode(src: Union[bytes,str]) -> bytes: ... diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 0b6b303ca..9539af7e5 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -20,7 +20,7 @@ from typing import IO, Any, Callable, DefaultDict, Deque, Dict, Iterable, Iterat from kittens.transfer.librsync import LoadSignature, PatchFile, delta_for_file, signature_of_file from kittens.transfer.utils import IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path -from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, AES256GCMDecrypt, add_timer, base64_encode, get_boss, get_options +from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, AES256GCMDecrypt, add_timer, base64_decode, base64_encode, get_boss, get_options from kitty.types import run_once from .utils import log_error @@ -248,14 +248,6 @@ def serialized_to_field_map() -> Dict[bytes, 'Field[Any]']: return ans -def b64decode(val: memoryview) -> bytes: - extra = len(val) % 4 - if extra != 0: - padding = b'=' * (4 - extra) - val = memoryview(bytes(val) + padding) - return base64.standard_b64decode(val) - - @dataclass class FileTransmissionCommand: @@ -344,12 +336,12 @@ class FileTransmissionCommand: if issubclass(field.type, Enum): setattr(ans, field.name, field.type[decode_utf8_buffer(val)]) elif field.type is bytes: - setattr(ans, field.name, b64decode(val)) + setattr(ans, field.name, base64_decode(val)) elif field.type is int: setattr(ans, field.name, int(val)) elif field.type is str: if field.metadata.get('base64'): - sval = b64decode(val).decode('utf-8') + sval = base64_decode(val).decode('utf-8') else: sval = safe_string(decode_utf8_buffer(val)) setattr(ans, field.name, safe_string(sval))