Implement specialisation for slang shaders

This commit is contained in:
Kovid Goyal 2026-06-29 23:11:53 +05:30
parent c49dcf9fca
commit 5bc8cfaaf5
No known key found for this signature in database
GPG key ID: 06BC317B515ACE7C
3 changed files with 106 additions and 29 deletions

View file

@ -3,6 +3,13 @@
// Distributed under terms of the GPLv3 license.
import blit;
import alpha_blend;
import utils;
extern static const bool is_alpha_mask = false;
extern static const bool texture_is_not_premultiplied = false;
// specialize: alpha_mask: is_alpha_mask=true
// specialize: premult: texture_is_not_premultiplied=true
struct VSOutput
@ -25,9 +32,23 @@ VSOutput vertex_main(
return output;
}
uniform Sampler2D image;
[shader("fragment")]
float4 fragment_main(float2 texcoord : TEXCOORD, uniform Sampler2D image) : SV_Target
{
float4 fragment_main(
float2 texcoord : TEXCOORD,
uniform float3 amask_fg,
uniform float4 amask_bg_premult,
uniform float extra_alpha
) : SV_Target {
float4 color = image.Sample(texcoord);
if (is_alpha_mask) {
color = float4(amask_fg, color.r);
color = vec4_premul(color);
color = alpha_blend_premul(color, amask_bg_premult);
} else color.a *= extra_alpha;
if (texture_is_not_premultiplied) color = vec4_premul(color);
return color;
}

View file

@ -25,12 +25,19 @@ class EntryPoint(NamedTuple):
name: str
class Specialization(NamedTuple):
name: str
variables: dict[str, str]
class SlangFile(NamedTuple):
path: str
text: str
imports: frozenset[str]
entry_points: frozenset[EntryPoint]
module: str
specializable_variables: dict[str, str]
specializations: tuple[Specialization, ...]
@property
def should_compile_to_ir(self) -> bool:
@ -42,9 +49,20 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
entry_points, imports = [], set()
module = ''
found_entry_point = ''
specializable_variables = {}
specializations = []
for line in text.splitlines():
line = line.strip()
if not line or line.startswith('//'):
if not line:
continue
if line.startswith('// specialize: '):
var, sep, spec = line.partition(':')[2].strip().partition(':')
variables = {}
for x in spec.split():
name, sep, val = x.partition('=')
variables[name] = val
specializations.append(Specialization(var, variables))
if line.startswith('//'):
continue
words = line.split()
if found_entry_point:
@ -66,11 +84,16 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
module = words[1].removesuffix(';')
case 'import':
imports.add(words[1].removesuffix(';'))
case 'extern':
if len(words) > 3 and words[1:3] == ['static', 'const']:
specializable_variables[line.partition('=')[0].split()[-1]] = line
case _:
if words[0].startswith('[shader('): # ])
text = words[0].partition('(')[2].partition(')')[0].strip()
found_entry_point = text[1:-1]
return SlangFile(path, text, frozenset(imports), frozenset(entry_points), module)
return SlangFile(
path, text, frozenset(imports), frozenset(entry_points), module, specializable_variables,
tuple(specializations))
@lru_cache(4096)
@ -148,7 +171,7 @@ class Command(NamedTuple):
def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, output_dirpath: str) -> Iterator[Command]:
cmdbase = list(slangc)
cmdbase = list(slangc) + ['-warnings-as-errors', 'all']
for name, sfile in sources.items():
if sfile.should_compile_to_ir:
parts = name.split('.')
@ -164,7 +187,7 @@ def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, o
def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[tuple[str, str, list[str], SlangFile]]:
cmdbase = list(slangc)
cmdbase = list(slangc) + ['-warnings-as-errors', 'all']
for name, sfile in sources.items():
if not sfile.entry_points:
continue
@ -176,30 +199,40 @@ def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]:
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
dest = f'{base_dest}.spv'
cmd += ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name', '-o', dest]
output_mtime = safe_mtime(dest)
module_mtime = os.path.getmtime(slang_module)
needs_build = output_mtime < module_mtime
if needs_build:
built_files.append(dest)
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to SPIR-V ...', cmd)
base_cmd = ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name']
for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
for x in (Specialization('', {}),) + sfile.specializations:
cmd = list(scmd)
dest = f'{base_dest}.{x.name}.spv' if x.name else f'{base_dest}.spv'
if x.name:
cmd.insert(-1, f'{base_dest}.{x.name}.slang-module')
cmd += base_cmd + ['-o', dest]
output_mtime = safe_mtime(dest)
module_mtime = os.path.getmtime(slang_module)
needs_build = output_mtime < module_mtime
if needs_build:
built_files.append(dest)
yield Command(needs_build, f'Linking |{os.path.basename(dest)}| ...', cmd)
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
module_mtime = os.path.getmtime(slang_module)
extra_cmd = ['-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330']
for ep in sfile.entry_points:
c = list(cmd)
c.extend(('-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330'))
dest = f'{base_dest}.{ep.stage.name}.glsl'
c += ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest]
output_mtime = safe_mtime(dest)
needs_build = output_mtime < module_mtime
if needs_build:
built_glsl_files.append(dest)
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c)
for sp in (Specialization('', {}),) + sfile.specializations:
dest = f'{base_dest}.{ep.stage.name}.glsl'
c = list(cmd)
if sp.name:
dest = f'{base_dest}.{sp.name}.{ep.stage.name}.glsl'
c.insert(-1, f'{base_dest}.{sp.name}.slang-module')
c += extra_cmd + ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest]
output_mtime = safe_mtime(dest)
needs_build = output_mtime < module_mtime
if needs_build:
built_glsl_files.append(dest)
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c)
def fixup_opengl_code(glsl_code: str) -> str:
@ -286,6 +319,27 @@ def copy_files_preserving_structure(source_dir: str, dest_dir: str, extension: s
shutil.copy2(file_path, target_path)
def create_specialisations(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[Command]:
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
if sfile.entry_points and sfile.specializations:
for sp in sfile.specializations:
dest = f'{base_dest}.{sp.name}.slang'
lines = []
for key, val in sp.variables.items():
declaration = sfile.specializable_variables[key].rpartition('=')[0]
declaration = declaration.replace('extern ', 'export ', 1)
lines.append(f'{declaration} = {val};')
payload = '\n'.join(lines)
needs_build = True
with suppress(FileNotFoundError), open(dest) as f:
needs_build = f.read() != payload
if needs_build:
with open(dest, 'w') as fw:
fw.write(payload)
yield Command(needs_build, f'Compiling specialisation |{os.path.basename(dest)}|| ...',
list(slangc) + [dest, '-o', dest + '-module'])
def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: ParallelRun) -> None:
src_dir = os.path.abspath('kitty/shaders')
source_tree = get_ordered_sources_in_tree(src_dir)
@ -293,6 +347,8 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
parallel_run(commands_to_compile_dir_to_ir(source_tree, src_dir, build_dir))
# Copy IR to dest_dir
copy_files_preserving_structure(build_dir, dest_dir, '.slang-module')
# Create the specializations
parallel_run(create_specialisations(source_tree, build_dir, dest_dir))
# Now Vulkan shaders
built_spirv_files: list[str] = []
spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files)

View file

@ -5,30 +5,30 @@
module utils;
// Return 0 if x < 1 otherwise 1
__generic<T : __BuiltinFloatingPointType, int N = 1>
public __generic<T : __BuiltinFloatingPointType, int N = 1>
vector<T, N> zero_or_one(vector<T, N> x) {
return step((vector<T, N>)1.0f, x);
}
// condition must be zero or one. When 1 thenval is returned otherwise elseval
__generic<T : __BuiltinFloatingPointType, int N = 1>
public __generic<T : __BuiltinFloatingPointType, int N = 1>
vector<T, N> if_one_then(vector<T, N> condition, vector<T, N> thenval, vector<T, N> elseval) {
return lerp(elseval, thenval, condition);
}
// a < b ? thenval : elseval
__generic<T : __BuiltinFloatingPointType, int N = 1>
public __generic<T : __BuiltinFloatingPointType, int N = 1>
vector<T, N> if_less_than(vector<T, N> a, vector<T, N> b, vector<T, N> thenval, vector<T, N> elseval) {
return lerp(thenval, elseval, step(b, a));
}
// Replaces vec4(rgb * a, a)
float4 vec4_premul(float3 rgb, float a) {
public float4 vec4_premul(float3 rgb, float a) {
return float4(rgb * a, a);
}
// Overloaded variation replacing vec4(rgba.rgb * rgba.a, rgba.a)
float4 vec4_premul(float4 rgba) {
public float4 vec4_premul(float4 rgba) {
return float4(rgba.rgb * rgba.a, rgba.a);
}