# TorchScript Wrapper Patterns for Nuke ## Weight Loading Strategy The most critical step. A previous attempt scored 2/10 because it wrote a DA v2-style wrapper for a DA v3 model. **Rule**: Read the source code. List state_dict keys from a pretrained checkpoint. Match your wrapper's module hierarchy EXACTLY. ```python import torch ckpt = torch.load("model.pt", map_location="cpu") if "model" in ckpt: ckpt = ckpt["model"] for k in sorted(ckpt.keys()): print(f"{k:60s} {ckpt[k].shape}") ``` Then structure your wrapper's `__init__` so that: - `self.encoder.blocks[i].attn.qkv.weight` maps to the checkpoint's `encoder.blocks.i.attn.qkv.weight` - Every submodule name matches exactly **Never** rely on `strict=False` to mask mismatches. Always check: ```python missing, unexpected = model.load_state_dict(state_dict, strict=False) assert len(missing) == 0, f"Missing keys: {missing[:5]}" assert len(unexpected) == 0, f"Unexpected keys: {unexpected[:5]}" ``` ## Standard Wrapper Template ```python import torch import torch.nn as nn import torch.nn.functional as F class NukeWrapper(torch.nn.Module): """TorchScript-compatible wrapper for Nuke Inference node.""" def __init__(self): super().__init__() # ALL attributes must be initialized here self.img_mean: torch.Tensor = torch.zeros(3) self.img_std: torch.Tensor = torch.ones(3) self.encoder = EncoderNuke(...) self.head = HeadNuke(...) def forward(self, x: torch.Tensor) -> torch.Tensor: # x comes from Nuke: [B, C, H, W], float32, sRGB [0,1] # 1. normalize (ImageNet stats) x = (x - self.img_mean[None, :, None, None]) \ / self.img_std[None, :, None, None] # 2. pad to patch-aligned size h, w = x.shape[2], x.shape[3] pad_h = (14 - h % 14) % 14 pad_w = (14 - w % 14) % 14 x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect") # 3. encode features = self.encoder(x) # 4. decode depth = self.head(features, x.shape[2], x.shape[3]) # 5. 
unpad depth = depth[:, :, :h, :w] return depth ``` ## Replacing xformers with Native PyTorch xformers is NOT available in TorchScript. Replace with `torch.nn.functional.scaled_dot_product_attention`: ```python # BEFORE (xformers - not TorchScript compatible) # from xformers.ops import memory_efficient_attention # attn_out = memory_efficient_attention(q, k, v) # AFTER (native PyTorch, TorchScript safe) class AttentionNuke(nn.Module): def __init__(self, dim: int, num_heads: int): super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3) self.proj = nn.Linear(dim, dim) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x).reshape( B, N, 3, self.num_heads, self.head_dim ) q = qkv[:, :, 0].transpose(1, 2) k = qkv[:, :, 1].transpose(1, 2) v = qkv[:, :, 2].transpose(1, 2) attn_out = F.scaled_dot_product_attention(q, k, v) attn_out = attn_out.transpose(1, 2).reshape(B, N, C) return self.proj(attn_out) ``` ## Replacing einops with Native torch einops rearrange is not TorchScript compatible. Common replacements: ```python # einops: rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) # native: x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() # einops: rearrange(x, 'b c h w -> b (h w) c') # native: x = x.permute(0, 2, 3, 1).reshape(B, H * W, C) # einops: rearrange(x, 'b n (h d) -> b h n d', h=num_heads) # native: x = x.reshape(B, N, num_heads, head_dim).transpose(1, 2) ``` ## LayerScale Pattern Modern ViTs (DinoV2, DA3) use LayerScale -- a learnable per-channel scaling applied after attention and MLP. Missing this means weights won't load and model dynamics are wrong. 
```python class LayerScaleNuke(nn.Module): def __init__(self, dim: int, init_value: float = 1e-5): super().__init__() self.gamma = nn.Parameter( init_value * torch.ones(dim) ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.gamma * x class TransformerBlockNuke(nn.Module): def __init__(self, dim: int, num_heads: int): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = AttentionNuke(dim, num_heads) self.ls1 = LayerScaleNuke(dim) self.norm2 = nn.LayerNorm(dim) self.mlp = MlpNuke(dim) self.ls2 = LayerScaleNuke(dim) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.ls1(self.attn(self.norm1(x))) x = x + self.ls2(self.mlp(self.norm2(x))) return x ``` ## Compilation Script Pattern ```python import torch import argparse from wrapper import NukeWrapper def compile_model(variant: str, output_path: str): # 1. build wrapper model = NukeWrapper(variant=variant) # 2. load weights ckpt_path = f"models/{variant}.pth" state_dict = torch.load(ckpt_path, map_location="cpu") missing, unexpected = model.load_state_dict( state_dict, strict=False ) print(f"Missing keys: {len(missing)}") print(f"Unexpected keys: {len(unexpected)}") assert len(missing) == 0, \ f"Weight mismatch! Missing: {missing[:10]}" assert len(unexpected) == 0, \ f"Weight mismatch! Unexpected: {unexpected[:10]}" # 3. script the model (prefer script over trace) model.eval() scripted = torch.jit.script(model) # 4. save scripted.save(output_path) print(f"Saved TorchScript model to {output_path}") # 5. verify round-trip loaded = torch.jit.load(output_path) dummy = torch.randn(1, 3, 518, 518) with torch.no_grad(): out_orig = model(dummy) out_loaded = loaded(dummy) diff = (out_orig - out_loaded).abs().max().item() print(f"Round-trip max diff: {diff}") assert diff < 1e-5, "Round-trip verification failed!" 
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--variant", required=True) parser.add_argument("--output", required=True) args = parser.parse_args() compile_model(args.variant, args.output) ``` ## Testing Patterns Test TorchScript compatibility BEFORE trying to compile: ```python import torch def test_scriptable(model_class, *args, **kwargs): """Verify a module can be scripted.""" model = model_class(*args, **kwargs) try: scripted = torch.jit.script(model) print(f"OK: {model_class.__name__} is scriptable") return scripted except Exception as e: print(f"FAIL: {model_class.__name__}: {e}") raise def test_output_shape(model, input_shape, expected_shape): """Verify output dimensions.""" dummy = torch.randn(*input_shape) with torch.no_grad(): out = model(dummy) assert out.shape == expected_shape, \ f"Expected {expected_shape}, got {out.shape}" print(f"OK: output shape {out.shape}") ``` ## Docker Build Environment For reproducible compilation with pinned dependencies: ```dockerfile FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime WORKDIR /build COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt COPY . . ENTRYPOINT ["python", "nuke_compile.py"] ``` requirements.txt should pin exact versions: ``` torch==2.2.0 torchvision==0.17.0 huggingface_hub>=0.20.0 safetensors>=0.4.0 ``` ## nn.ModuleList Indexing TorchScript cannot index `nn.ModuleList` with a variable. `zip()` and `enumerate()` over ModuleLists also fail because TorchScript can't statically determine the length. 
The only reliable approach is manual unrolling with integer literals: ```python # BAD: variable index for i in range(4): x = self.projects[i](x) # FAILS # BAD: zip over ModuleLists for feat, proj in zip(feats, self.projects): # FAILS x = proj(feat) # GOOD: manually unrolled with literal indices x0 = self.projects[0](feats[0]) x1 = self.projects[1](feats[1]) x2 = self.projects[2](feats[2]) x3 = self.projects[3](feats[3]) ``` Verbose but it's the only way. TorchScript needs to resolve each ModuleList access at compile time. ## F.interpolate Gotcha TorchScript can't always infer `torch.Size` slices as tuple types. Always be explicit, and declare size components as typed locals: ```python # BAD: may fail in TorchScript F.interpolate(x, size=y.shape[2:], mode="bilinear") # BAD: inline arithmetic can fail type inference F.interpolate(x, size=(y.shape[2] * 2, y.shape[3] * 2), ...) # GOOD: explicit int locals, then tuple h: int = y.shape[2] * 2 w: int = y.shape[3] * 2 F.interpolate( x, size=(h, w), mode="bilinear", align_corners=True ) ``` Also for Optional size parameters, use `Optional[Tuple[int, int]]` not just `Tuple[int, int] = None`: ```python # BAD: TorchScript chokes on the default def forward(self, x: torch.Tensor, size: Tuple[int, int] = None): # GOOD: proper Optional typing def forward(self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None): ``` ## Type Annotation Requirements TorchScript requires consistent typing. Common pitfalls: ```python # BAD: variable changes type x = None # NoneType x = torch.zeros(3) # Tensor -- type changed! # GOOD: use Optional or initialize correctly x: Optional[torch.Tensor] = None # BAD: bool in module attribute self.use_bias = True # bool not reliably supported # GOOD: use int self.use_bias: int = 1 # 1 for true, 0 for false ``` ## Weight Loading Validation Never trust `strict=False` to silently pass. 
Always validate: ```python missing, unexpected = model.load_state_dict(state_dict, strict=False) # filter out EXPECTED missing keys (baked constants, skipped modules) expected_missing = ['mean', 'std', 'sky_', 'drop', 'q_norm', 'k_norm'] critical = [k for k in missing if not any(skip in k for skip in expected_missing)] if len(critical) > 0: raise RuntimeError( f"Weight mismatch! {len(critical)} keys missing: " f"{critical[:10]}" ) ``` Add Identity/Dropout placeholders for upstream modules that exist in checkpoints but are functionally no-ops (Dropout(0.0), Identity): ```python # match upstream checkpoint keys even for no-op modules self.q_norm = nn.Identity() self.k_norm = nn.Identity() self.attn_drop = nn.Dropout(0.0) self.proj_drop = nn.Dropout(0.0) ```