From 8e2d448c5cb2fe175f0f85ec566a3480cbaf0410 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 7 Jan 2024 20:06:28 +0530 Subject: [PATCH] More work on UTF-8 SIMD decode --- kitty/simd-string-impl.h | 104 +++++++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 27 deletions(-) diff --git a/kitty/simd-string-impl.h b/kitty/simd-string-impl.h index a244c4c71..3baf0f0ab 100644 --- a/kitty/simd-string-impl.h +++ b/kitty/simd-string-impl.h @@ -27,7 +27,7 @@ _Pragma("clang diagnostic pop") #define count_trailing_zeros __builtin_ctz #if BITS == 128 -#define set1_epi8 simde_mm_set1_epi8 +#define set1_epi8(x) simde_mm_set1_epi8((char)(x)) /* cast lets callers pass constants above 0x7f */ #define add_epi8 simde_mm_add_epi8 #define load_unaligned simde_mm_loadu_si128 #define store_aligned simde_mm_store_si128 @@ -35,15 +35,18 @@ _Pragma("clang diagnostic pop") #define cmplt_epi8 simde_mm_cmplt_epi8 #define or_si simde_mm_or_si128 #define and_si simde_mm_and_si128 +#define andnot_si simde_mm_andnot_si128 /* andnot_si(a, b) computes (~a) & b */ #define movemask_epi8 simde_mm_movemask_epi8 #define extract_lower_quarter_as_chars simde_mm_cvtepu8_epi32 -#define shift_left_by_chars simde_mm_slli_epi32 -#define shift_left_by_bytes simde_mm_slli_si128 +#define shift_right_by_bytes simde_mm_slli_si128 /* slli moves bytes toward higher indices, i.e. rightwards in the index-order dump made by print_register_as_bytes */ #define blendv_epi8 simde_mm_blendv_epi8 +#define shift_left_by_bits16 simde_mm_slli_epi16 /* bit shift within each 16-bit lane; simde_ wrapper for consistency with the other aliases */ +#define shift_right_by_bits32 simde_mm_srli_epi32 /* logical bit shift within each 32-bit lane */ // output[i] = MAX(0, a[i] - b[1i]) #define subtract_saturate_epu8 simde_mm_subs_epu8 +#define create_zero_integer simde_mm_setzero_si128 /* all-zero register */ #else -#define set1_epi8 simde_mm256_set1_epi8 +#define set1_epi8(x) simde_mm256_set1_epi8((char)(x)) /* cast lets callers pass constants above 0x7f */ #define add_epi8 simde_mm256_add_epi8 #define load_unaligned simde_mm256_loadu_si256 #define store_aligned simde_mm256_store_si256 @@ -51,12 +54,23 @@ _Pragma("clang diagnostic pop") #define cmplt_epi8(a, b) simde_mm256_cmpgt_epi8(b, a) #define or_si simde_mm256_or_si256 #define and_si simde_mm256_and_si256 +#define andnot_si simde_mm256_andnot_si256 /* andnot_si(a, b) computes (~a) & b */ #define movemask_epi8 
simde_mm256_movemask_epi8 #define extract_lower_half_as_chars simde_mm256_cvtepu8_epi32 -#define shift_left_by_chars simde_mm256_slli_epi32 -#define shift_left_by_bytes simde_mm256_slli_si256 #define blendv_epi8 simde_mm256_blendv_epi8 #define subtract_saturate_epu8 simde_mm256_subs_epu8 +#define shift_left_by_bits16 simde_mm256_slli_epi16 /* bit shift within each 16-bit lane; simde_ wrapper for consistency with the other aliases */ +#define shift_right_by_bits32 simde_mm256_srli_epi32 /* logical bit shift within each 32-bit lane */ +#define create_zero_integer simde_mm256_setzero_si256 /* all-zero register */ + +static inline integer_t +shift_right_by_bytes(const integer_t vec, const unsigned amt) { /* Moves every byte of vec amt positions toward higher byte indices across the whole 256-bit register: bytes are carried over the 128-bit lane boundary, vacated low bytes become zero, and amt >= 32 yields all zeros. The permute with _MM_SHUFFLE(0, 0, 2, 0) builds [low lane = 0, high lane = low lane of vec], which supplies the carried bytes. NOTE(review): alignr/slli_si256 take immediate byte counts, so this presumably relies on being inlined with a literal amt - confirm. */ + if (amt == 0) return vec; + if (amt < 16) return simde_mm256_alignr_epi8(vec, simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0)), 16 - amt); + if (amt == 16) return simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0)); + if (amt < 32) return simde_mm256_slli_si256(simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0)), amt - 16); + return create_zero_integer(); +} #endif static inline const uint8_t* @@ -78,13 +92,14 @@ FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const u } #define print_register_as_bytes(r) { \ - alignas(64) uint8_t data[sizeof(r)]; \ - store_aligned((integer_t*)data, r); \ - for (unsigned i = 0; i < sizeof(integer_t); i++) { \ - uint8_t ch = data[i]; \ - if (' ' <= ch && ch < 0x7f) printf(" %c ", ch); else printf("%.2x ", ch); \ - } \ - printf("\n"); \ + printf("%s:\n", #r); /* label the dump with the expression text of the register argument */ \ + alignas(64) uint8_t data[sizeof(r)]; \ + store_aligned((integer_t*)data, r); \ + for (unsigned i = 0; i < sizeof(integer_t); i++) { \ + uint8_t ch = data[i]; \ + if (' ' <= ch && ch < 0x7f) printf(" %c ", ch); else printf("%.2x ", ch); \ + } \ + printf("\n"); \ } static inline void @@ -144,32 +159,64 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { return sentinel_found; } -#if 1 // Classify the bytes - integer_t state = set1_epi8((char)0x80); + integer_t state = set1_epi8(0x80); integer_t vec_signed = add_epi8(vec, state); - printf("source:\n"); 
print_register_as_bytes(vec); + print_register_as_bytes(vec); integer_t bytes_indicating_start_of_two_byte_sequence = cmplt_epi8(set1_epi8(0xc0 - 1 - 0x80), vec_signed); - state = blendv_epi8(state , set1_epi8((char)0xc2), bytes_indicating_start_of_two_byte_sequence); + state = blendv_epi8(state, set1_epi8(0xc2), bytes_indicating_start_of_two_byte_sequence); // state now has 0xc2 on all bytes that start a 2 byte sequence and 0x80 on the rest integer_t bytes_indicating_start_of_three_byte_sequence = cmplt_epi8(set1_epi8(0xe0 - 1 - 0x80), vec_signed); - state = blendv_epi8(state , set1_epi8((char)0xe3), bytes_indicating_start_of_three_byte_sequence); + state = blendv_epi8(state, set1_epi8(0xe3), bytes_indicating_start_of_three_byte_sequence); integer_t bytes_indicating_start_of_four_byte_sequence = cmplt_epi8(set1_epi8(0xf0 - 1 - 0x80), vec_signed); - state = blendv_epi8(state , set1_epi8((char)0xf4), bytes_indicating_start_of_four_byte_sequence); + state = blendv_epi8(state, set1_epi8(0xf4), bytes_indicating_start_of_four_byte_sequence); // state now has 0xc2 on all bytes that start a 2 byte sequence, 0xe3 on start of 3-byte sequence, 0xf4 on 4-byte start and 0x80 on rest print_register_as_bytes(state); - integer_t mask = and_si(state, set1_epi8((char)0xf8)); // keep upper 5 bits of state - printf("mask:\n"); print_register_as_bytes(mask); + integer_t mask = and_si(state, set1_epi8(0xf8)); // keep upper 5 bits of state + print_register_as_bytes(mask); integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state - printf("count:\n"); print_register_as_bytes(count); + print_register_as_bytes(count); // count contains 0 for ASCII and number of bytes in sequence for other bytes -#define subtract_shift_and_add(target, amt) add_epi8(target, shift_left_by_bytes(subtract_saturate_epu8(target, set1_epi8(amt)), amt)) +#define subtract_shift_and_add(target, amt) add_epi8(target, shift_right_by_bytes(subtract_saturate_epu8(target, set1_epi8(amt)), amt)) + // 
shift 02 bytes by 1 and subtract 1 integer_t counts = subtract_shift_and_add(count, 1); + // shift 03 and 04 bytes by 2 and subtract 2 counts = subtract_shift_and_add(counts, 2); - printf("counts:\n"); print_register_as_bytes(counts); + // counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes + print_register_as_bytes(counts); +#undef subtract_shift_and_add + // Processing + // mask all control bits so that we have only useful bits left + print_register_as_bytes(vec); + vec = andnot_si(mask, vec); + print_register_as_bytes(vec); + + // Now calculate the four output vectors + + // The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2 + // In addition, the ASCII bytes are copied unchanged from vec + integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, create_zero_integer()), vec); + print_register_as_bytes(vec_non_ascii); + integer_t vec_right1 = shift_right_by_bytes(vec_non_ascii, 1); + integer_t output1 = blendv_epi8(vec, + or_si( + vec, and_si(shift_left_by_bits16(vec_right1, 6), set1_epi8(0xc0)) + ), + cmpeq_epi8(counts, set1_epi8(1)) + ); + print_register_as_bytes(output1); + + // The next byte is made up of 4 bits (5, 4, 3, 2) from locations with count == 2 and the first 4 bits from locations with count == 3 + integer_t count2_locations = cmpeq_epi8(counts, set1_epi8(2)); + integer_t output2 = and_si(vec, count2_locations); + output2 = shift_right_by_bits32(output2, 2); // selects the bits 5, 4, 3, 2 + output2 = or_si(output2, and_si(shift_left_by_bits16(vec_right1, 4), set1_epi8(0xf0))); // move 4 bits left and mask lower four bits and OR + output2 = and_si(output2, count2_locations); // keep only the count2 bytes + print_register_as_bytes(output2); + + // The last byte is made up of bits 5 and 6 from count == 3 and 3 bits from count == 4 -#endif return sentinel_found; } @@ -183,15 +230,18 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, 
size_t src_sz) { #undef cmplt_epi8 #undef or_si #undef and_si +#undef andnot_si #undef movemask_epi8 #undef CONCAT #undef CONCAT_EXPAND #undef BITS -#undef shift_left_by_chars -#undef shift_left_by_bytes +#undef shift_right_by_bytes +#undef shift_left_by_bits16 +#undef shift_right_by_bits32 #undef shift_right_by_bytes128 #undef extract_lower_quarter_as_chars #undef extract_lower_half_as_chars #undef blendv_epi8 #undef add_epi8 #undef subtract_saturate_epu8 +#undef create_zero_integer