mirror of
https://github.com/kovidgoyal/kitty.git
synced 2026-07-01 21:36:09 +00:00
Implement specialisation for slang shaders
This commit is contained in:
parent
c49dcf9fca
commit
5bc8cfaaf5
3 changed files with 106 additions and 29 deletions
|
|
@ -3,6 +3,13 @@
|
|||
// Distributed under terms of the GPLv3 license.
|
||||
|
||||
import blit;
|
||||
import alpha_blend;
|
||||
import utils;
|
||||
|
||||
extern static const bool is_alpha_mask = false;
|
||||
extern static const bool texture_is_not_premultiplied = false;
|
||||
// specialize: alpha_mask: is_alpha_mask=true
|
||||
// specialize: premult: texture_is_not_premultiplied=true
|
||||
|
||||
|
||||
struct VSOutput
|
||||
|
|
@ -25,9 +32,23 @@ VSOutput vertex_main(
|
|||
return output;
|
||||
}
|
||||
|
||||
uniform Sampler2D image;
|
||||
|
||||
[shader("fragment")]
|
||||
float4 fragment_main(float2 texcoord : TEXCOORD, uniform Sampler2D image) : SV_Target
|
||||
{
|
||||
float4 fragment_main(
|
||||
float2 texcoord : TEXCOORD,
|
||||
uniform float3 amask_fg,
|
||||
uniform float4 amask_bg_premult,
|
||||
uniform float extra_alpha
|
||||
) : SV_Target {
|
||||
float4 color = image.Sample(texcoord);
|
||||
if (is_alpha_mask) {
|
||||
color = float4(amask_fg, color.r);
|
||||
color = vec4_premul(color);
|
||||
color = alpha_blend_premul(color, amask_bg_premult);
|
||||
} else color.a *= extra_alpha;
|
||||
|
||||
if (texture_is_not_premultiplied) color = vec4_premul(color);
|
||||
|
||||
return color;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,12 +25,19 @@ class EntryPoint(NamedTuple):
|
|||
name: str
|
||||
|
||||
|
||||
class Specialization(NamedTuple):
|
||||
name: str
|
||||
variables: dict[str, str]
|
||||
|
||||
|
||||
class SlangFile(NamedTuple):
|
||||
path: str
|
||||
text: str
|
||||
imports: frozenset[str]
|
||||
entry_points: frozenset[EntryPoint]
|
||||
module: str
|
||||
specializable_variables: dict[str, str]
|
||||
specializations: tuple[Specialization, ...]
|
||||
|
||||
@property
|
||||
def should_compile_to_ir(self) -> bool:
|
||||
|
|
@ -42,9 +49,20 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
|
|||
entry_points, imports = [], set()
|
||||
module = ''
|
||||
found_entry_point = ''
|
||||
specializable_variables = {}
|
||||
specializations = []
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'):
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith('// specialize: '):
|
||||
var, sep, spec = line.partition(':')[2].strip().partition(':')
|
||||
variables = {}
|
||||
for x in spec.split():
|
||||
name, sep, val = x.partition('=')
|
||||
variables[name] = val
|
||||
specializations.append(Specialization(var, variables))
|
||||
if line.startswith('//'):
|
||||
continue
|
||||
words = line.split()
|
||||
if found_entry_point:
|
||||
|
|
@ -66,11 +84,16 @@ def parse_slang_text(text: str, path: str = '') -> SlangFile:
|
|||
module = words[1].removesuffix(';')
|
||||
case 'import':
|
||||
imports.add(words[1].removesuffix(';'))
|
||||
case 'extern':
|
||||
if len(words) > 3 and words[1:3] == ['static', 'const']:
|
||||
specializable_variables[line.partition('=')[0].split()[-1]] = line
|
||||
case _:
|
||||
if words[0].startswith('[shader('): # ])
|
||||
text = words[0].partition('(')[2].partition(')')[0].strip()
|
||||
found_entry_point = text[1:-1]
|
||||
return SlangFile(path, text, frozenset(imports), frozenset(entry_points), module)
|
||||
return SlangFile(
|
||||
path, text, frozenset(imports), frozenset(entry_points), module, specializable_variables,
|
||||
tuple(specializations))
|
||||
|
||||
|
||||
@lru_cache(4096)
|
||||
|
|
@ -148,7 +171,7 @@ class Command(NamedTuple):
|
|||
|
||||
|
||||
def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, output_dirpath: str) -> Iterator[Command]:
|
||||
cmdbase = list(slangc)
|
||||
cmdbase = list(slangc) + ['-warnings-as-errors', 'all']
|
||||
for name, sfile in sources.items():
|
||||
if sfile.should_compile_to_ir:
|
||||
parts = name.split('.')
|
||||
|
|
@ -164,7 +187,7 @@ def commands_to_compile_dir_to_ir(sources: dict[str, SlangFile], src_dir: str, o
|
|||
|
||||
|
||||
def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[tuple[str, str, list[str], SlangFile]]:
|
||||
cmdbase = list(slangc)
|
||||
cmdbase = list(slangc) + ['-warnings-as-errors', 'all']
|
||||
for name, sfile in sources.items():
|
||||
if not sfile.entry_points:
|
||||
continue
|
||||
|
|
@ -176,30 +199,40 @@ def iter_entry_point_shaders(sources: dict[str, SlangFile], build_dir: str, dest
|
|||
|
||||
|
||||
def commands_to_compile_to_spirv(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_files: list[str]) -> Iterator[Command]:
|
||||
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
|
||||
dest = f'{base_dest}.spv'
|
||||
cmd += ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name', '-o', dest]
|
||||
output_mtime = safe_mtime(dest)
|
||||
module_mtime = os.path.getmtime(slang_module)
|
||||
needs_build = output_mtime < module_mtime
|
||||
if needs_build:
|
||||
built_files.append(dest)
|
||||
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to SPIR-V ...', cmd)
|
||||
base_cmd = ['-target', 'spirv', '-capability', 'vk_mem_model', '-fvk-use-entrypoint-name']
|
||||
for base_dest, slang_module, scmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
|
||||
for x in (Specialization('', {}),) + sfile.specializations:
|
||||
cmd = list(scmd)
|
||||
dest = f'{base_dest}.{x.name}.spv' if x.name else f'{base_dest}.spv'
|
||||
if x.name:
|
||||
cmd.insert(-1, f'{base_dest}.{x.name}.slang-module')
|
||||
cmd += base_cmd + ['-o', dest]
|
||||
output_mtime = safe_mtime(dest)
|
||||
module_mtime = os.path.getmtime(slang_module)
|
||||
needs_build = output_mtime < module_mtime
|
||||
if needs_build:
|
||||
built_files.append(dest)
|
||||
yield Command(needs_build, f'Linking |{os.path.basename(dest)}| ...', cmd)
|
||||
|
||||
|
||||
|
||||
def commands_to_compile_to_glsl(sources: dict[str, SlangFile], build_dir: str, dest_dir: str, built_glsl_files: list[str]) -> Iterator[Command]:
|
||||
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
|
||||
module_mtime = os.path.getmtime(slang_module)
|
||||
extra_cmd = ['-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330']
|
||||
for ep in sfile.entry_points:
|
||||
c = list(cmd)
|
||||
c.extend(('-line-directive-mode', 'none', '-target', 'glsl', '-profile', 'glsl_330'))
|
||||
dest = f'{base_dest}.{ep.stage.name}.glsl'
|
||||
c += ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest]
|
||||
output_mtime = safe_mtime(dest)
|
||||
needs_build = output_mtime < module_mtime
|
||||
if needs_build:
|
||||
built_glsl_files.append(dest)
|
||||
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c)
|
||||
for sp in (Specialization('', {}),) + sfile.specializations:
|
||||
dest = f'{base_dest}.{ep.stage.name}.glsl'
|
||||
c = list(cmd)
|
||||
if sp.name:
|
||||
dest = f'{base_dest}.{sp.name}.{ep.stage.name}.glsl'
|
||||
c.insert(-1, f'{base_dest}.{sp.name}.slang-module')
|
||||
c += extra_cmd + ['-entry', ep.name, '-stage', ep.stage.name, '-o', dest]
|
||||
output_mtime = safe_mtime(dest)
|
||||
needs_build = output_mtime < module_mtime
|
||||
if needs_build:
|
||||
built_glsl_files.append(dest)
|
||||
yield Command(needs_build, f'Linking |{os.path.basename(slang_module)}| to GLSL {ep.stage} shader ...', c)
|
||||
|
||||
|
||||
def fixup_opengl_code(glsl_code: str) -> str:
|
||||
|
|
@ -286,6 +319,27 @@ def copy_files_preserving_structure(source_dir: str, dest_dir: str, extension: s
|
|||
shutil.copy2(file_path, target_path)
|
||||
|
||||
|
||||
def create_specialisations(sources: dict[str, SlangFile], build_dir: str, dest_dir: str) -> Iterator[Command]:
|
||||
for base_dest, slang_module, cmd, sfile in iter_entry_point_shaders(sources, build_dir, dest_dir):
|
||||
if sfile.entry_points and sfile.specializations:
|
||||
for sp in sfile.specializations:
|
||||
dest = f'{base_dest}.{sp.name}.slang'
|
||||
lines = []
|
||||
for key, val in sp.variables.items():
|
||||
declaration = sfile.specializable_variables[key].rpartition('=')[0]
|
||||
declaration = declaration.replace('extern ', 'export ', 1)
|
||||
lines.append(f'{declaration} = {val};')
|
||||
payload = '\n'.join(lines)
|
||||
needs_build = True
|
||||
with suppress(FileNotFoundError), open(dest) as f:
|
||||
needs_build = f.read() != payload
|
||||
if needs_build:
|
||||
with open(dest, 'w') as fw:
|
||||
fw.write(payload)
|
||||
yield Command(needs_build, f'Compiling specialisation |{os.path.basename(dest)}|| ...',
|
||||
list(slangc) + [dest, '-o', dest + '-module'])
|
||||
|
||||
|
||||
def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: ParallelRun) -> None:
|
||||
src_dir = os.path.abspath('kitty/shaders')
|
||||
source_tree = get_ordered_sources_in_tree(src_dir)
|
||||
|
|
@ -293,6 +347,8 @@ def compile_builtin_shaders(build_dir: str, dest_dir: str, parallel_run: Paralle
|
|||
parallel_run(commands_to_compile_dir_to_ir(source_tree, src_dir, build_dir))
|
||||
# Copy IR to dest_dir
|
||||
copy_files_preserving_structure(build_dir, dest_dir, '.slang-module')
|
||||
# Create the specializations
|
||||
parallel_run(create_specialisations(source_tree, build_dir, dest_dir))
|
||||
# Now Vulkan shaders
|
||||
built_spirv_files: list[str] = []
|
||||
spirv_commands = commands_to_compile_to_spirv(source_tree, build_dir, dest_dir, built_spirv_files)
|
||||
|
|
|
|||
|
|
@ -5,30 +5,30 @@
|
|||
module utils;
|
||||
|
||||
// Return 0 if x < 1 otherwise 1
|
||||
__generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
public __generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
vector<T, N> zero_or_one(vector<T, N> x) {
|
||||
return step((vector<T, N>)1.0f, x);
|
||||
}
|
||||
|
||||
// condition must be zero or one. When 1 thenval is returned otherwise elseval
|
||||
__generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
public __generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
vector<T, N> if_one_then(vector<T, N> condition, vector<T, N> thenval, vector<T, N> elseval) {
|
||||
return lerp(elseval, thenval, condition);
|
||||
}
|
||||
|
||||
// a < b ? thenval : elseval
|
||||
__generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
public __generic<T : __BuiltinFloatingPointType, int N = 1>
|
||||
vector<T, N> if_less_than(vector<T, N> a, vector<T, N> b, vector<T, N> thenval, vector<T, N> elseval) {
|
||||
return lerp(thenval, elseval, step(b, a));
|
||||
}
|
||||
|
||||
// Replaces vec4(rgb * a, a)
|
||||
float4 vec4_premul(float3 rgb, float a) {
|
||||
public float4 vec4_premul(float3 rgb, float a) {
|
||||
return float4(rgb * a, a);
|
||||
}
|
||||
|
||||
// Overloaded variation replacing vec4(rgba.rgb * rgba.a, rgba.a)
|
||||
float4 vec4_premul(float4 rgba) {
|
||||
public float4 vec4_premul(float4 rgba) {
|
||||
return float4(rgba.rgb * rgba.a, rgba.a);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue