From 7029a35df9b4ba2cc35fcfc6efa68d91c4fc10e8 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Fri, 6 Jun 2025 13:28:58 +0530 Subject: [PATCH] Port subseq tests to fzf --- tools/fzf/algo_test.go | 100 +++++++++++++++++++++++++++++++++++++++++ tools/fzf/types.go | 30 ++++++------- 2 files changed, 115 insertions(+), 15 deletions(-) diff --git a/tools/fzf/algo_test.go b/tools/fzf/algo_test.go index a5a5a2f7a..330eb2769 100644 --- a/tools/fzf/algo_test.go +++ b/tools/fzf/algo_test.go @@ -1,9 +1,16 @@ package fzf import ( + "cmp" "fmt" "sort" + "strconv" + "strings" "testing" + + gcmp "github.com/google/go-cmp/cmp" + + "github.com/kovidgoyal/kitty/tools/utils" ) var _ = fmt.Print @@ -87,4 +94,97 @@ func TestFZFAlgo(t *testing.T) { assertMatch(t, fn, "Foo Bar Baz", "fbb", -1, -1, 0) assertMatch(t, fn, "fooBarbaz", "fooBarbazz", -1, -1, 0) } + + var positions [][]int + sort_by_score := false + fn.Case_sensitive = false + + simple := func(items, query string, expected ...string) { + ilist := utils.Splitlines(items) + matches, err := fn.Score(ilist, query) + if err != nil { + t.Fatal(err) + } + if sort_by_score { + slist := make([]int, len(matches)) + for i := range len(slist) { + slist[i] = i + } + utils.StableSort(slist, func(a, b int) int { + return cmp.Compare(matches[b].Score, matches[a].Score) + }) + nlist, nmatches := make([]string, len(ilist)), make([]Result, len(matches)) + for i, j := range slist { + nlist[i] = ilist[j] + nmatches[i] = matches[j] + } + ilist = nlist + matches = nmatches + } + actual := make([]string, 0, len(matches)) + actual_positions := make([][]int, 0, len(matches)) + for i, m := range matches { + if m.Score > 0 { + sort.Ints(m.Positions) + actual = append(actual, ilist[i]) + actual_positions = append(actual_positions, m.Positions) + } + } + if expected == nil { + expected = []string{} + } + + if diff := gcmp.Diff(expected, actual); diff != "" { + t.Fatalf("Failed for items: %#v\nQuery: %#v\nMatches: %#v\n%s", ilist, query, matches, diff) + } + if positions != nil { + if diff := gcmp.Diff(positions, actual_positions); diff != "" { + t.Fatalf("Failed positions for items: %v\n%s", ilist, diff) + } + positions = nil + } + } + simple("test\nxyz", "te", "test") + simple("abc\nxyz", "ba") + simple("abc\n123", "abc", "abc") + simple("test\nxyz", "Te", "test") + simple("test\nxyz", "XY", "xyz") + simple("test\nXYZ", "xy", "XYZ") + simple("test\nXYZ", "mn") + + positions = [][]int{{0, 2}, {0, 1}} + simple("abc\nac", "ac", "abc", "ac") + positions = [][]int{{0}} + simple("abc\nv", "a", "abc") + positions = [][]int{{1, 3}} + simple("汉a字b\nxyz", "ab", "汉a字b") + + sort_by_score = true + // Match at start + simple("archer\nelementary", "e", "elementary", "archer") + // Match at level factor + simple("xxxy\nxx/y", "y", "xx/y", "xxxy") + // CamelCase + simple("xxxy\nxxxY", "y", "xxxY", "xxxy") + // Distance + simple("abbc\nabc", "ac", "abc", "abbc") + // Extreme chars + simple("xxa\naxx", "a", "axx", "xxa") + // Highest score + positions = [][]int{{3}} + simple("xa/a", "a", "xa/a") + + sort_by_score = false + items := make([]string, 256) + + for i := range items { + items[i] = strconv.Itoa(i) + } + expected := make([]string, 0, len(items)) + for _, x := range items { + if strings.ContainsRune(x, rune('2')) { + expected = append(expected, x) + } + } + simple(strings.Join(items, "\n"), "2", expected...) } diff --git a/tools/fzf/types.go b/tools/fzf/types.go index c9bae90f0..3e52d3c8c 100644 --- a/tools/fzf/types.go +++ b/tools/fzf/types.go @@ -20,25 +20,25 @@ type Chars struct { runes []rune } -const ( - overflow64 uint64 = 0x8080808080808080 - overflow32 uint32 = 0x80808080 -) - func check_ascii(bytes []byte) (ascii_until int) { + slen := len(bytes) + // Process 8 bytes at a time i := 0 - for ; i <= len(bytes)-8; i += 8 { - if (overflow64 & *(*uint64)(unsafe.Pointer(&bytes[i]))) > 0 { - return i + for ; i+8 <= slen; i += 8 { + v := *(*uint64)(unsafe.Pointer(&bytes[i])) + // If any byte has its high bit set, v & 0x8080808080808080 != 0 + if v&0x8080808080808080 != 0 { + // At least one non-ASCII byte in this chunk, find which + for j := range 8 { + if bytes[i+j]&utf8.RuneSelf != 0 { + return i + j + } + } } } - for ; i <= len(bytes)-4; i += 4 { - if (overflow32 & *(*uint32)(unsafe.Pointer(&bytes[i]))) > 0 { - return i - } - } - for ; i < len(bytes); i++ { - if bytes[i] >= utf8.RuneSelf { + // Handle remaining bytes + for ; i < slen; i++ { + if bytes[i] > 127 { return i } }