10 KiB
TorchScript Wrapper Patterns for Nuke
Weight Loading Strategy
The most critical step. A previous attempt scored 2/10 because it wrote a DA v2-style wrapper for a DA v3 model.
Rule: Read the source code. List state_dict keys from a pretrained checkpoint. Match your wrapper's module hierarchy EXACTLY.
import torch

# Dump every checkpoint key with its tensor shape, sorted, so the
# wrapper's module hierarchy can be matched against it line by line.
state = torch.load("model.pt", map_location="cpu")
if "model" in state:
    # Some checkpoints nest the weights under a "model" envelope.
    state = state["model"]
for key, tensor in sorted(state.items()):
    print(f"{key:60s} {tensor.shape}")
Then structure your wrapper's __init__ so that:
- `self.encoder.blocks[i].attn.qkv.weight` maps to the checkpoint's `encoder.blocks.i.attn.qkv.weight`
- Every submodule name matches exactly
Never rely on strict=False to mask mismatches. Always check:
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert len(missing) == 0, f"Missing keys: {missing[:5]}"
assert len(unexpected) == 0, f"Unexpected keys: {unexpected[:5]}"
Standard Wrapper Template
import torch
import torch.nn as nn
import torch.nn.functional as F
class NukeWrapper(torch.nn.Module):
    """TorchScript-compatible wrapper for the Nuke Inference node.

    Takes a [B, C, H, W] float32 sRGB [0, 1] tensor (what Nuke feeds
    the Inference node) and returns a depth map cropped back to the
    input's spatial size.
    """

    def __init__(self):
        super().__init__()
        # ALL attributes must be initialized here.
        # Register the normalization constants as (non-persistent)
        # buffers: plain tensor attributes do not follow .to(device)
        # moves. persistent=False keeps them out of state_dict, so
        # load_state_dict does not report them as missing keys.
        # Checkpoints do not ship mean/std, so bake in the real
        # ImageNet statistics -- zeros/ones would make this
        # normalization a silent no-op.
        self.register_buffer(
            "img_mean", torch.tensor([0.485, 0.456, 0.406]),
            persistent=False,
        )
        self.register_buffer(
            "img_std", torch.tensor([0.229, 0.224, 0.225]),
            persistent=False,
        )
        self.encoder = EncoderNuke(...)
        self.head = HeadNuke(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x comes from Nuke: [B, C, H, W], float32, sRGB [0,1]
        # 1. normalize (ImageNet stats)
        x = (x - self.img_mean[None, :, None, None]) \
            / self.img_std[None, :, None, None]
        # 2. pad to a multiple of the ViT patch size (14)
        h, w = x.shape[2], x.shape[3]
        pad_h = (14 - h % 14) % 14
        pad_w = (14 - w % 14) % 14
        # NOTE(review): reflect padding requires pad < dim size, so
        # inputs narrower than 14 px would fail here -- fine for real
        # plates, add a guard if tiny crops are possible.
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
        # 3. encode
        features = self.encoder(x)
        # 4. decode at the padded resolution
        depth = self.head(features, x.shape[2], x.shape[3])
        # 5. crop back to the caller's original resolution
        depth = depth[:, :, :h, :w]
        return depth
Replacing xformers with Native PyTorch
xformers is NOT available in TorchScript. Replace with
torch.nn.functional.scaled_dot_product_attention:
# BEFORE (xformers - not TorchScript compatible)
# from xformers.ops import memory_efficient_attention
# attn_out = memory_efficient_attention(q, k, v)
# AFTER (native PyTorch, TorchScript safe)
class AttentionNuke(nn.Module):
    """Multi-head self-attention built from TorchScript-safe ops.

    Drop-in replacement for xformers' memory_efficient_attention:
    F.scaled_dot_product_attention gives the fused kernel path and
    scripts cleanly.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # SDPA applies 1/sqrt(head_dim) internally; kept for parity
        # with upstream attention modules that expose it.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, tokens, channels = x.shape
        # One fused projection, then split into per-head q/k/v views
        # of shape [B, heads, N, head_dim].
        packed = self.qkv(x).reshape(
            batch, tokens, 3, self.num_heads, self.head_dim
        )
        q = packed.select(2, 0).transpose(1, 2)
        k = packed.select(2, 1).transpose(1, 2)
        v = packed.select(2, 2).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(out)
Replacing einops with Native torch
einops rearrange is not TorchScript compatible. Common replacements:
# einops: rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
# native:
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
# einops: rearrange(x, 'b c h w -> b (h w) c')
# native:
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
# einops: rearrange(x, 'b n (h d) -> b h n d', h=num_heads)
# native:
x = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)
LayerScale Pattern
Modern ViTs (DinoV2, DA3) use LayerScale -- a learnable per-channel scaling applied after attention and MLP. Missing this means weights won't load and model dynamics are wrong.
class LayerScaleNuke(nn.Module):
    """Learnable per-channel scaling (LayerScale, as in DinoV2/DA3).

    Multiplies the input by a gamma vector initialized to a small
    constant so residual branches start close to identity.
    """

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        gamma0 = torch.full((dim,), init_value)
        self.gamma = nn.Parameter(gamma0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma
class TransformerBlockNuke(nn.Module):
    """Pre-norm ViT block: attention and MLP residual branches, each
    preceded by LayerNorm and scaled by LayerScale (DinoV2 layout)."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        # Attention branch
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AttentionNuke(dim, num_heads)
        self.ls1 = LayerScaleNuke(dim)
        # MLP branch
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MlpNuke(dim)
        self.ls2 = LayerScaleNuke(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x))
        x = x + self.ls1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.ls2(mlp_out)
Compilation Script Pattern
import torch
import argparse
from wrapper import NukeWrapper
def compile_model(variant: str, output_path: str):
    """Build the wrapper, load weights, script, save, and verify.

    Raises AssertionError on any weight mismatch or if the saved
    TorchScript model does not reproduce the in-memory outputs.
    """
    # 1. build wrapper
    model = NukeWrapper(variant=variant)

    # 2. load weights. Checkpoints are sometimes wrapped in a
    #    "model" / "state_dict" envelope -- unwrap before loading.
    ckpt_path = f"models/{variant}.pth"
    state_dict = torch.load(ckpt_path, map_location="cpu")
    for envelope in ("model", "state_dict"):
        if envelope in state_dict:
            state_dict = state_dict[envelope]
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")
    # Validate BOTH directions: unexpected keys mean checkpoint
    # weights were silently dropped, which is just as fatal as
    # missing ones.
    assert len(missing) == 0, \
        f"Weight mismatch! Missing: {missing[:10]}"
    assert len(unexpected) == 0, \
        f"Weight mismatch! Unexpected: {unexpected[:10]}"

    # 3. script the model (prefer script over trace)
    model.eval()
    scripted = torch.jit.script(model)

    # 4. save
    scripted.save(output_path)
    print(f"Saved TorchScript model to {output_path}")

    # 5. verify round-trip: reload and compare against the
    #    in-memory model on a patch-aligned dummy input.
    loaded = torch.jit.load(output_path)
    dummy = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        out_orig = model(dummy)
        out_loaded = loaded(dummy)
    diff = (out_orig - out_loaded).abs().max().item()
    print(f"Round-trip max diff: {diff}")
    assert diff < 1e-5, "Round-trip verification failed!"
if __name__ == "__main__":
    # CLI entry point, e.g.:
    #   python nuke_compile.py --variant vitl --output model.pt
    cli = argparse.ArgumentParser()
    cli.add_argument("--variant", required=True)
    cli.add_argument("--output", required=True)
    opts = cli.parse_args()
    compile_model(opts.variant, opts.output)
Testing Patterns
Test TorchScript compatibility BEFORE trying to compile:
import torch
def test_scriptable(model_class, *args, **kwargs):
"""Verify a module can be scripted."""
model = model_class(*args, **kwargs)
try:
scripted = torch.jit.script(model)
print(f"OK: {model_class.__name__} is scriptable")
return scripted
except Exception as e:
print(f"FAIL: {model_class.__name__}: {e}")
raise
def test_output_shape(model, input_shape, expected_shape):
"""Verify output dimensions."""
dummy = torch.randn(*input_shape)
with torch.no_grad():
out = model(dummy)
assert out.shape == expected_shape, \
f"Expected {expected_shape}, got {out.shape}"
print(f"OK: output shape {out.shape}")
Docker Build Environment
For reproducible compilation with pinned dependencies:
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
WORKDIR /build
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
ENTRYPOINT ["python", "nuke_compile.py"]
requirements.txt should pin exact versions:
torch==2.2.0
torchvision==0.17.0
huggingface_hub>=0.20.0
safetensors>=0.4.0
nn.ModuleList Indexing
TorchScript cannot index nn.ModuleList with a variable.
Direct iteration (for m in self.module_list) is supported, but
zip() over ModuleLists and anything needing a runtime index is not,
so pairing a list of feature tensors with their projection modules
usually fails to compile.
The most reliable approach is manual unrolling with integer literals:
# BAD: variable index
for i in range(4):
x = self.projects[i](x) # FAILS
# BAD: zip over ModuleLists
for feat, proj in zip(feats, self.projects): # FAILS
x = proj(feat)
# GOOD: manually unrolled with literal indices
x0 = self.projects[0](feats[0])
x1 = self.projects[1](feats[1])
x2 = self.projects[2](feats[2])
x3 = self.projects[3](feats[3])
Verbose but it's the only way. TorchScript needs to resolve each ModuleList access at compile time.
F.interpolate Gotcha
TorchScript can't always infer torch.Size slices as tuple types.
Always be explicit, and declare size components as typed locals:
# BAD: may fail in TorchScript
F.interpolate(x, size=y.shape[2:], mode="bilinear")
# BAD: inline arithmetic can fail type inference
F.interpolate(x, size=(y.shape[2] * 2, y.shape[3] * 2), ...)
# GOOD: explicit int locals, then tuple
h: int = y.shape[2] * 2
w: int = y.shape[3] * 2
F.interpolate(
x, size=(h, w), mode="bilinear",
align_corners=True
)
Also for Optional size parameters, use Optional[Tuple[int, int]]
not just Tuple[int, int] = None:
# BAD: TorchScript chokes on the default
def forward(self, x: torch.Tensor, size: Tuple[int, int] = None):
# GOOD: proper Optional typing
def forward(self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None):
Type Annotation Requirements
TorchScript requires consistent typing. Common pitfalls:
# BAD: variable changes type
x = None # NoneType
x = torch.zeros(3) # Tensor -- type changed!
# GOOD: use Optional or initialize correctly
x: Optional[torch.Tensor] = None
# BAD: bool in module attribute
self.use_bias = True # bool not reliably supported
# GOOD: use int
self.use_bias: int = 1 # 1 for true, 0 for false
Weight Loading Validation
Never trust strict=False to silently pass. Always validate:
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# filter out EXPECTED missing keys (baked constants, skipped modules)
expected_missing = ['mean', 'std', 'sky_', 'drop', 'q_norm', 'k_norm']
critical = [k for k in missing
if not any(skip in k for skip in expected_missing)]
if len(critical) > 0:
raise RuntimeError(
f"Weight mismatch! {len(critical)} keys missing: "
f"{critical[:10]}"
)
Add Identity/Dropout placeholders for modules that exist in the upstream architecture but are functional no-ops at inference (Dropout(0.0), Identity), so attribute names and forward() calls line up with the original model:
# match upstream checkpoint keys even for no-op modules
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.attn_drop = nn.Dropout(0.0)
self.proj_drop = nn.Dropout(0.0)