More work on UTF-8 SIMD decode

This commit is contained in:
Kovid Goyal 2024-01-07 20:06:28 +05:30
parent 37c05e3212
commit 8e2d448c5c
No known key found for this signature in database
GPG key ID: 06BC317B515ACE7C

View file

@ -27,7 +27,7 @@ _Pragma("clang diagnostic pop")
#define count_trailing_zeros __builtin_ctz
#if BITS == 128
#define set1_epi8 simde_mm_set1_epi8
#define set1_epi8(x) simde_mm_set1_epi8((char)(x))
#define add_epi8 simde_mm_add_epi8
#define load_unaligned simde_mm_loadu_si128
#define store_aligned simde_mm_store_si128
@ -35,15 +35,18 @@ _Pragma("clang diagnostic pop")
#define cmplt_epi8 simde_mm_cmplt_epi8
#define or_si simde_mm_or_si128
#define and_si simde_mm_and_si128
#define andnot_si simde_mm_andnot_si128
#define movemask_epi8 simde_mm_movemask_epi8
#define extract_lower_quarter_as_chars simde_mm_cvtepu8_epi32
#define shift_left_by_chars simde_mm_slli_epi32
#define shift_left_by_bytes simde_mm_slli_si128
#define shift_right_by_bytes simde_mm_slli_si128
#define blendv_epi8 simde_mm_blendv_epi8
#define shift_left_by_bits16 _mm_slli_epi16
#define shift_right_by_bits32 _mm_srli_epi32
// output[i] = MAX(0, a[i] - b[i])
#define subtract_saturate_epu8 simde_mm_subs_epu8
#define create_zero_integer _mm_setzero_si128
#else
#define set1_epi8 simde_mm256_set1_epi8
#define set1_epi8(x) simde_mm256_set1_epi8((char)(x))
#define add_epi8 simde_mm256_add_epi8
#define load_unaligned simde_mm256_loadu_si256
#define store_aligned simde_mm256_store_si256
@ -51,12 +54,23 @@ _Pragma("clang diagnostic pop")
#define cmplt_epi8(a, b) simde_mm256_cmpgt_epi8(b, a)
#define or_si simde_mm256_or_si256
#define and_si simde_mm256_and_si256
#define andnot_si simde_mm256_andnot_si256
#define movemask_epi8 simde_mm256_movemask_epi8
#define extract_lower_half_as_chars simde_mm256_cvtepu8_epi32
#define shift_left_by_chars simde_mm256_slli_epi32
#define shift_left_by_bytes simde_mm256_slli_si256
#define blendv_epi8 simde_mm256_blendv_epi8
#define subtract_saturate_epu8 simde_mm256_subs_epu8
#define shift_left_by_bits16 _mm256_slli_epi16
#define shift_right_by_bits32 _mm256_srli_epi32
#define create_zero_integer _mm256_setzero_si256
// 256-bit counterpart of the 128-bit shift_right_by_bytes macro. The 128-bit
// variant maps to a single slli-style whole-register byte shift; the 256-bit
// slli/alignr intrinsics are documented by Intel as operating on each 128-bit
// lane independently, so the shift is assembled here from a lane permute plus
// a per-lane operation.
//
// vec: the register to shift.
// amt: number of bytes to shift by. 0 returns vec unchanged; 32 or more
//      returns an all-zero register.
//
// NOTE(review): the underlying x86 instructions take their byte count as an
// immediate, while amt is an ordinary parameter here. Presumably this relies
// on callers passing literal amounts and on the function being inlined --
// confirm it still builds without optimisation.
static inline integer_t
shift_right_by_bytes(const integer_t vec, const unsigned amt) {
// Nothing to shift.
if (amt == 0) return vec;
// Per Intel's documentation, _MM_SHUFFLE(0, 0, 2, 0) == 0x08, which makes
// permute2x128 zero the low lane and copy vec's low lane into the high lane.
// alignr then shifts each lane by amt bytes, pulling the bytes that cross the
// lane boundary from that permuted register.
if (amt < 16) return simde_mm256_alignr_epi8(vec, simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0)), 16 - amt);
// Exactly one lane: the permuted register is already the answer.
if (amt == 16) return simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0));
// More than one lane: start from the permuted register and shift by the
// remaining amt - 16 bytes.
if (amt < 32) return simde_mm256_slli_si256(simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(0, 0, 2, 0)), amt - 16);
// Every byte has been shifted out of the register.
return create_zero_integer();
}
#endif
static inline const uint8_t*
@ -78,13 +92,14 @@ FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const u
}
#define print_register_as_bytes(r) { \
alignas(64) uint8_t data[sizeof(r)]; \
store_aligned((integer_t*)data, r); \
for (unsigned i = 0; i < sizeof(integer_t); i++) { \
uint8_t ch = data[i]; \
if (' ' <= ch && ch < 0x7f) printf(" %c ", ch); else printf("%.2x ", ch); \
} \
printf("\n"); \
printf("%s:\n", #r); \
alignas(64) uint8_t data[sizeof(r)]; \
store_aligned((integer_t*)data, r); \
for (unsigned i = 0; i < sizeof(integer_t); i++) { \
uint8_t ch = data[i]; \
if (' ' <= ch && ch < 0x7f) printf(" %c ", ch); else printf("%.2x ", ch); \
} \
printf("\n"); \
}
static inline void
@ -144,32 +159,64 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
return sentinel_found;
}
#if 1
// Classify the bytes
integer_t state = set1_epi8((char)0x80);
integer_t state = set1_epi8(0x80);
integer_t vec_signed = add_epi8(vec, state);
printf("source:\n"); print_register_as_bytes(vec);
print_register_as_bytes(vec);
integer_t bytes_indicating_start_of_two_byte_sequence = cmplt_epi8(set1_epi8(0xc0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state , set1_epi8((char)0xc2), bytes_indicating_start_of_two_byte_sequence);
state = blendv_epi8(state, set1_epi8(0xc2), bytes_indicating_start_of_two_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 byte sequence and 0x80 on the rest
integer_t bytes_indicating_start_of_three_byte_sequence = cmplt_epi8(set1_epi8(0xe0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state , set1_epi8((char)0xe3), bytes_indicating_start_of_three_byte_sequence);
state = blendv_epi8(state, set1_epi8(0xe3), bytes_indicating_start_of_three_byte_sequence);
integer_t bytes_indicating_start_of_four_byte_sequence = cmplt_epi8(set1_epi8(0xf0 - 1 - 0x80), vec_signed);
state = blendv_epi8(state , set1_epi8((char)0xf4), bytes_indicating_start_of_four_byte_sequence);
state = blendv_epi8(state, set1_epi8(0xf4), bytes_indicating_start_of_four_byte_sequence);
// state now has 0xc2 on all bytes that start a 2 byte sequence, 0xe3 on start of 3-byte sequence, 0xf4 on 4-byte start and 0x80 on rest
print_register_as_bytes(state);
integer_t mask = and_si(state, set1_epi8((char)0xf8)); // keep upper 5 bits of state
printf("mask:\n"); print_register_as_bytes(mask);
integer_t mask = and_si(state, set1_epi8(0xf8)); // keep upper 5 bits of state
print_register_as_bytes(mask);
integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state
printf("count:\n"); print_register_as_bytes(count);
print_register_as_bytes(count);
// count contains 0 for ASCII and number of bytes in sequence for other bytes
#define subtract_shift_and_add(target, amt) add_epi8(target, shift_left_by_bytes(subtract_saturate_epu8(target, set1_epi8(amt)), amt))
#define subtract_shift_and_add(target, amt) add_epi8(target, shift_right_by_bytes(subtract_saturate_epu8(target, set1_epi8(amt)), amt))
// shift 02 bytes by 1 and subtract 1
integer_t counts = subtract_shift_and_add(count, 1);
// shift 03 and 04 bytes by 2 and subtract 2
counts = subtract_shift_and_add(counts, 2);
printf("counts:\n"); print_register_as_bytes(counts);
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
print_register_as_bytes(counts);
#undef subtract_shift_and_add
// Processing
// mask all control bits so that we have only useful bits left
print_register_as_bytes(vec);
vec = andnot_si(mask, vec);
print_register_as_bytes(vec);
// Now calculate the four output vectors
// The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2
// In addition, the ASCII bytes are copied unchanged from vec
integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, create_zero_integer()), vec);
print_register_as_bytes(vec_non_ascii);
integer_t vec_right1 = shift_right_by_bytes(vec_non_ascii, 1);
integer_t output1 = blendv_epi8(vec,
or_si(
vec, and_si(shift_left_by_bits16(vec_right1, 6), set1_epi8(0xc0))
),
cmpeq_epi8(counts, set1_epi8(1))
);
print_register_as_bytes(output1);
// The next byte is made up of 4 bits (5, 4, 3, 2) from locations with count == 2 and the first 4 bits from locations with count == 3
integer_t count2_locations = cmpeq_epi8(counts, set1_epi8(2));
integer_t output2 = and_si(vec, count2_locations);
output2 = shift_right_by_bits32(output2, 2); // selects the bits 5, 4, 3, 2
output2 = or_si(output2, and_si(shift_left_by_bits16(vec_right1, 4), set1_epi8(0xf0))); // move 4 bits left and mask lower four bits and OR
output2 = and_si(output2, count2_locations); // keep only the count2 bytes
print_register_as_bytes(output2);
// The last byte is made up of bits 5 and 6 from count == 3 and 3 bits from count == 4
#endif
return sentinel_found;
}
@ -183,15 +230,18 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
#undef cmplt_epi8
#undef or_si
#undef and_si
#undef andnot_si
#undef movemask_epi8
#undef CONCAT
#undef CONCAT_EXPAND
#undef BITS
#undef shift_left_by_chars
#undef shift_left_by_bytes
#undef shift_right_by_bytes
#undef shift_left_by_bits16
#undef shift_right_by_bits32
#undef shift_right_by_bytes128
#undef extract_lower_quarter_as_chars
#undef extract_lower_half_as_chars
#undef blendv_epi8
#undef add_epi8
#undef subtract_saturate_epu8
#undef create_zero_integer