Speed up base64 decode even more

This commit is contained in:
Kovid Goyal 2023-07-01 12:39:18 +05:30
parent aa86b98eee
commit d01da39dfb
No known key found for this signature in database
GPG key ID: 06BC317B515ACE7C
4 changed files with 30 additions and 39 deletions

View file

@ -27,8 +27,8 @@
bool decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz);
bool encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding);
#ifndef B64_INCLUDED_ONCE
static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz / 4) * 3 + 4; }
static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz / 3) * 4 + 5; }
static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz * 3) / 4 + 4; }
static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz * 4) / 3 + 5; }
#endif
#ifndef B64_INCLUDED_ONCE
@ -44,17 +44,17 @@ static uint8_t b64_decoding_table[256] = {
#endif
static void
inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, const size_t dest_sz) {
inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest) {
for (size_t i = 0, j = 0; i < src_sz;) {
uint32_t sextet_a = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_b = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_c = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_d = src[i] == '=' ? 0 & i++ : b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_a = b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_b = b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_c = b64_decoding_table[src[i++] & 0xff];
uint32_t sextet_d = b64_decoding_table[src[i++] & 0xff];
uint32_t triple = (sextet_a << 3 * 6) + (sextet_b << 2 * 6) + (sextet_c << 1 * 6) + (sextet_d << 0 * 6);
if (j < dest_sz) dest[j++] = (triple >> 2 * 8) & 0xFF;
if (j < dest_sz) dest[j++] = (triple >> 1 * 8) & 0xFF;
if (j < dest_sz) dest[j++] = (triple >> 0 * 8) & 0xFF;
dest[j++] = (triple >> 2 * 8) & 0xFF;
dest[j++] = (triple >> 1 * 8) & 0xFF;
dest[j++] = (triple >> 0 * 8) & 0xFF;
}
}
@ -63,18 +63,17 @@ decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz) {
while (src_sz && src[src_sz-1] == '=') src_sz--; // remove trailing padding
if (!src_sz) { *dest_sz = 0; return true; }
const size_t dest_capacity = *dest_sz;
size_t extra = src_sz % 4;
*dest_sz = src_sz / 4;
size_t extra = src_sz - 4 * *dest_sz;
*dest_sz *= 3;
src_sz -= extra;
*dest_sz = (src_sz / 4) * 3;
if (*dest_sz > dest_capacity) return false;
if (src_sz) inner_func(src, src_sz, dest, *dest_sz);
if (extra > 1) {
if (*dest_sz + 4 > dest_capacity) return false;
if (src_sz) inner_func(src, src_sz, dest);
if (extra > 1 && extra < 4) { // < 4 is not needed but it helps compiler unroll the loop
INPUT_T buf[4] = {0};
for (size_t i = 0; i < extra; i++) buf[i] = src[src_sz+i];
dest += *dest_sz;
inner_func(buf, extra, dest + *dest_sz);
*dest_sz += extra - 1;
if (*dest_sz > dest_capacity) return false;
inner_func(buf, extra, dest, extra-1);
}
if (*dest_sz + 1 > dest_capacity) return false;
dest[*dest_sz] = 0; // ensure zero-terminated

View file

@ -79,24 +79,24 @@ redirect_std_streams(PyObject UNUSED *self, PyObject *args) {
static PyObject*
pybase64_encode(PyObject UNUSED *self, PyObject *args) {
int add_padding = 0;
const char *src; Py_ssize_t src_len;
if (!PyArg_ParseTuple(args, "y#|p", &src, &src_len, &add_padding)) return NULL;
size_t sz = required_buffer_size_for_base64_encode(src_len);
FREE_BUFFER_AFTER_FUNCTION Py_buffer view = {0};
if (!PyArg_ParseTuple(args, "s*|p", &view, &add_padding)) return NULL;
size_t sz = required_buffer_size_for_base64_encode(view.len);
PyObject *ans = PyBytes_FromStringAndSize(NULL, sz);
if (!ans) return NULL;
base64_encode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding);
base64_encode8(view.buf, view.len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding);
if (_PyBytes_Resize(&ans, sz) != 0) return NULL;
return ans;
}
static PyObject*
pybase64_decode(PyObject UNUSED *self, PyObject *args) {
const char *src; Py_ssize_t src_len;
if (!PyArg_ParseTuple(args, "y#", &src, &src_len)) return NULL;
size_t sz = required_buffer_size_for_base64_decode(src_len);
FREE_BUFFER_AFTER_FUNCTION Py_buffer view = {0};
if (!PyArg_ParseTuple(args, "s*", &view)) return NULL;
size_t sz = required_buffer_size_for_base64_decode(view.len);
PyObject *ans = PyBytes_FromStringAndSize(NULL, sz);
if (!ans) return NULL;
base64_decode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz);
base64_decode8(view.buf, view.len, (unsigned char*)PyBytes_AS_STRING(ans), &sz);
if (_PyBytes_Resize(&ans, sz) != 0) return NULL;
return ans;
}

View file

@ -1534,5 +1534,5 @@ def expand_ansi_c_escapes(test: str) -> str: ...
def update_tab_bar_edge_colors(os_window_id: int) -> bool: ...
def mask_kitty_signals_process_wide() -> None: ...
def is_modifier_key(key: int) -> bool: ...
def base64_encode(src: bytes, add_padding: bool = False) -> bytes: ...
def base64_decode(src: bytes) -> bytes: ...
def base64_encode(src: Union[bytes,str], add_padding: bool = False) -> bytes: ...
def base64_decode(src: Union[bytes,str]) -> bytes: ...

View file

@ -20,7 +20,7 @@ from typing import IO, Any, Callable, DefaultDict, Deque, Dict, Iterable, Iterat
from kittens.transfer.librsync import LoadSignature, PatchFile, delta_for_file, signature_of_file
from kittens.transfer.utils import IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path
from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, AES256GCMDecrypt, add_timer, base64_encode, get_boss, get_options
from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, AES256GCMDecrypt, add_timer, base64_decode, base64_encode, get_boss, get_options
from kitty.types import run_once
from .utils import log_error
@ -248,14 +248,6 @@ def serialized_to_field_map() -> Dict[bytes, 'Field[Any]']:
return ans
def b64decode(val: memoryview) -> bytes:
extra = len(val) % 4
if extra != 0:
padding = b'=' * (4 - extra)
val = memoryview(bytes(val) + padding)
return base64.standard_b64decode(val)
@dataclass
class FileTransmissionCommand:
@ -344,12 +336,12 @@ class FileTransmissionCommand:
if issubclass(field.type, Enum):
setattr(ans, field.name, field.type[decode_utf8_buffer(val)])
elif field.type is bytes:
setattr(ans, field.name, b64decode(val))
setattr(ans, field.name, base64_decode(val))
elif field.type is int:
setattr(ans, field.name, int(val))
elif field.type is str:
if field.metadata.get('base64'):
sval = b64decode(val).decode('utf-8')
sval = base64_decode(val).decode('utf-8')
else:
sval = safe_string(decode_utf8_buffer(val))
setattr(ans, field.name, safe_string(sval))