diff --git a/kitty/shaders/graphics.slang b/kitty/shaders/graphics.slang index 5d1832cc9..28703532f 100644 --- a/kitty/shaders/graphics.slang +++ b/kitty/shaders/graphics.slang @@ -3,6 +3,13 @@ // Distributed under terms of the GPLv3 license. import blit; +import alpha_blend; +import utils; + +extern static const bool is_alpha_mask = false; +extern static const bool texture_is_not_premultiplied = false; +// specialize: alpha_mask: is_alpha_mask=true +// specialize: premult: texture_is_not_premultiplied=true struct VSOutput @@ -25,9 +32,23 @@ VSOutput vertex_main( return output; } +uniform Sampler2D image; + [shader("fragment")] -float4 fragment_main(float2 texcoord : TEXCOORD, uniform Sampler2D image) : SV_Target -{ +float4 fragment_main( + float2 texcoord : TEXCOORD, + uniform float3 amask_fg, + uniform float4 amask_bg_premult, + uniform float extra_alpha +) : SV_Target { float4 color = image.Sample(texcoord); + if (is_alpha_mask) { + color = float4(amask_fg, color.r); + color = vec4_premul(color); + color = alpha_blend_premul(color, amask_bg_premult); + } else color.a *= extra_alpha; + + if (texture_is_not_premultiplied) color = vec4_premul(color); + return color; } diff --git a/kitty/shaders/slang.py b/kitty/shaders/slang.py index d5b3d551f..8e64eeebd 100644 --- a/kitty/shaders/slang.py +++ b/kitty/shaders/slang.py @@ -25,12 +25,19 @@ class EntryPoint(NamedTuple): name: str +class Specialization(NamedTuple): + name: str + variables: dict[str, str] + + class SlangFile(NamedTuple): path: str text: str imports: frozenset[str] entry_points: frozenset[EntryPoint] module: str + specializable_variables: dict[str, str] + specializations: tuple[Specialization, ...] @property def should_compile_to_ir(self) -> bool: @@ -42,9 +49,20 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile: entry_points, imports = [], set() module = '' found_entry_point = '' + specializable_variables = {} + specializations = [] for line in text.splitlines(): line = line.strip() - if not line or line.startswith('//'): + if not line: + continue + if line.startswith('// specialize: '): + var, sep, spec = line.partition(':')[2].strip().partition(':') + variables = {} + for x in spec.split(): + name, sep, val = x.partition('=') + variables[name] = val + specializations.append(Specialization(var, variables)) + if line.startswith('//'): continue words = line.split() if found_entry_point: @@ -66,11 +84,16 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile: module = words[1].removesuffix(';') case 'import': imports.add(words[1].removesuffix(';')) + case 'extern': + if len(words) > 3 and words[1:3] == ['static', 'const']: + specializable_variables[line.partition('=')[0].split()[-1]] = line case _: if words[0].startswith('[shader('): # ]) text = words[0].partition('(')[2].partition(')')[0].strip() found_entry_point = text[1:-1] - return SlangFile(path, text, frozenset(imports), frozenset(entry_points), module) + return SlangFile( + path, text, frozenset(imports), frozenset(entry_points), module, specializable_variables, + tuple(specializations)) @lru_cache(4096) @@ -148,7 +171,7 @@ class Command(NamedTuple): def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, output_dirpath: str) -> Iterator[Command]: - cmdbase = list(slangc) + cmdbase = list(slangc) + ['-warnings-as-errors', 'all'] for name, sfile in sources.items(): if sfile.should_compile_to_ir: parts = name.split('.') @@ -164,7 +187,7 @@ def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, o def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[tuple[str, str, list[str], SlangFile]]: - cmdbase = list(slangc) + cmdbase = list(slangc) + ['-warnings-as-errors', 'all'] for name, sfile in sources.items(): if not sfile.entry_points: continue @@ -176,30 +199,40 @@ def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]: - for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): - dest = f'{base_dest}.spv' - cmd += ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name', '-o', dest] - output_mtime = safe_mtime(dest) - module_mtime = os.path.getmtime(slang_module) - needs_build = output_mtime < module_mtime - if needs_build: - built_files.append(dest) - yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to SPIR-V ...', cmd) + base_cmd = ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name'] + for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): + for x in (Specialization('', {}),) + sfile.specializations: + cmd = list(scmd) + dest = f'{base_dest}.{x.name}.spv' if x.name else f'{base_dest}.spv' + if x.name: + cmd.insert(-1, f'{base_dest}.{x.name}.slang-module') + cmd += base_cmd + ['-o', dest] + output_mtime = safe_mtime(dest) + module_mtime = os.path.getmtime(slang_module) + needs_build = output_mtime < module_mtime + if needs_build: + built_files.append(dest) + yield Command(needs_build, f'Linking |{os.path.basename(dest)}| ...', cmd) + def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]: for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): module_mtime = os.path.getmtime(slang_module) + extra_cmd = ['-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330'] for ep in sfile.entry_points: - c = list(cmd) - c.extend(('-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330')) - dest = f'{base_dest}.{ep.stage.name}.glsl' - c += ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest] - output_mtime = safe_mtime(dest) - needs_build = output_mtime < module_mtime - if needs_build: - built_glsl_files.append(dest) - yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c) + for sp in (Specialization('', {}),) + sfile.specializations: + dest = f'{base_dest}.{ep.stage.name}.glsl' + c = list(cmd) + if sp.name: + dest = f'{base_dest}.{sp.name}.{ep.stage.name}.glsl' + c.insert(-1, f'{base_dest}.{sp.name}.slang-module') + c += extra_cmd + ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest] + output_mtime = safe_mtime(dest) + needs_build = output_mtime < module_mtime + if needs_build: + built_glsl_files.append(dest) + yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c) def fixup_opengl_code(glsl_code: str) -> str: @@ -286,6 +319,27 @@ def copy_files_preserving_structure(source_dir: str, dest_dir: str, extension: s shutil.copy2(file_path, target_path) +def create_specialisations(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[Command]: + for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir): + if sfile.entry_points and sfile.specializations: + for sp in sfile.specializations: + dest = f'{base_dest}.{sp.name}.slang' + lines = [] + for key, val in sp.variables.items(): + declaration = sfile.specializable_variables[key].rpartition('=')[0] + declaration = declaration.replace('extern ', 'export ', 1) + lines.append(f'{declaration} = {val};') + payload = '\n'.join(lines) + needs_build = True + with suppress(FileNotFoundError), open(dest) as f: + needs_build = f.read() != payload + if needs_build: + with open(dest, 'w') as fw: + fw.write(payload) + yield Command(needs_build, f'Compiling specialisation |{os.path.basename(dest)}|| ...', + list(slangc) + [dest, '-o', dest + '-module']) + + def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: ParallelRun) -> None: src_dir = os.path.abspath('kitty/shaders') source_tree = get_ordered_sources_in_tree(src_dir) @@ -293,6 +347,8 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle parallel_run(commands_to_compile_dir_to_ir(source_tree, src_dir, build_dir)) # Copy IR to dest_dir copy_files_preserving_structure(build_dir, dest_dir, '.slang-module') + # Create the specializations + parallel_run(create_specialisations(source_tree, build_dir, dest_dir)) # Now Vulkan shaders built_spirv_files: list[str] = [] spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files) diff --git a/kitty/shaders/utils.slang b/kitty/shaders/utils.slang index 9b5673aca..0d2b03e9f 100644 --- a/kitty/shaders/utils.slang +++ b/kitty/shaders/utils.slang @@ -5,30 +5,30 @@ module utils; // Return 0 if x < 1 otherwise 1 -__generic +public __generic vector zero_or_one(vector x) { return step((vector)1.0f, x); } // condition must be zero or one. When 1 thenval is returned otherwise elseval -__generic +public __generic vector if_one_then(vector condition, vector thenval, vector elseval) { return lerp(elseval, thenval, condition); } // a < b ? thenval : elseval -__generic +public __generic vector if_less_than(vector a, vector b, vector thenval, vector elseval) { return lerp(thenval, elseval, step(b, a)); } // Replaces vec4(rgb * a, a) -float4 vec4_premul(float3 rgb, float a) { +public float4 vec4_premul(float3 rgb, float a) { return float4(rgb * a, a); } // Overloaded variation replacing vec4(rgba.rgb * rgba.a, rgba.a) -float4 vec4_premul(float4 rgba) { +public float4 vec4_premul(float4 rgba) { return float4(rgba.rgb * rgba.a, rgba.a); }