.agents/skills/materia-nuke-node/references/torchscript-patterns.md

TorchScript Wrapper Patterns for Nuke

Weight Loading Strategy

Weight loading is the most critical step. A previous attempt scored 2/10 because it used a DA v2-style wrapper for a DA v3 model, so the wrapper's module names did not match the checkpoint.

Rule: Read the source code. List state_dict keys from a pretrained checkpoint. Match your wrapper's module hierarchy EXACTLY.

import torch

ckpt = torch.load("model.pt", map_location="cpu")
if "model" in ckpt:
    ckpt = ckpt["model"]

for k in sorted(ckpt.keys()):
    print(f"{k:60s}  {ckpt[k].shape}")

Then structure your wrapper's __init__ so that:

  • self.encoder.blocks[i].attn.qkv.weight maps to the checkpoint's encoder.blocks.i.attn.qkv.weight
  • Every submodule name matches exactly
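
A minimal sketch of checking the match before any weights are loaded, assuming the checkpoint is a flat dict of tensors. The helper name diff_keys, the toy model and the fake checkpoint are all hypothetical and only illustrate the comparison:

```python
import torch
import torch.nn as nn


def diff_keys(model: nn.Module, state_dict: dict) -> dict:
    """Compare wrapper keys against checkpoint keys without loading anything."""
    model_sd = model.state_dict()
    model_keys = set(model_sd.keys())
    ckpt_keys = set(state_dict.keys())
    shared = model_keys & ckpt_keys
    return {
        # keys the wrapper expects but the checkpoint lacks
        "missing": sorted(model_keys - ckpt_keys),
        # keys the checkpoint has but the wrapper does not define
        "unexpected": sorted(ckpt_keys - model_keys),
        # same name, different tensor shape
        "shape_mismatch": sorted(
            k for k in shared if model_sd[k].shape != state_dict[k].shape
        ),
    }


# Toy stand-ins (hypothetical): a two-layer wrapper and a fake checkpoint
toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
fake_ckpt = {
    "0.weight": torch.zeros(8, 4),
    "0.bias": torch.zeros(8),
    "1.weight": torch.zeros(2, 16),  # wrong shape on purpose
    "head.bias": torch.zeros(2),     # name the wrapper does not have
}
report = diff_keys(toy, fake_ckpt)
print(report)
```

Running this against the real wrapper and checkpoint before load_state_dict gives a readable list of renames to make, rather than a long exception.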

Never rely on strict=False to mask mismatches. Always check:

missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert len(missing) == 0, f"Missing keys: {missing[:5]}"
assert len(unexpected) == 0, f"Unexpected keys: {unexpected[:5]}"

Standard Wrapper Template

import torch
import torch.nn as nn
import torch.nn.functional as F


class NukeWrapper(torch.nn.Module):
    """TorchScript-compatible wrapper for Nuke Inference node."""

    def __init__(self):
        super().__init__()
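        # ALL attributes must be initialized here.
        # ImageNet mean/std as non-persistent buffers: they follow
        # .to(device) and stay out of state_dict.
        self.register_buffer(
            "img_mean", torch.tensor([0.485, 0.456, 0.406]),
            persistent=False)
        self.register_buffer(
            "img_std", torch.tensor([0.229, 0.224, 0.225]),
            persistent=False)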

        self.encoder = EncoderNuke(...)
        self.head = HeadNuke(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x comes from Nuke: [B, C, H, W], float32, sRGB [0,1]

        # 1. normalize (ImageNet stats)
        x = (x - self.img_mean[None, :, None, None]) \
            / self.img_std[None, :, None, None]

        # 2. pad to patch-aligned size (patch size 14)
        h, w = x.shape[2], x.shape[3]
        pad_h = (14 - h % 14) % 14
        pad_w = (14 - w % 14) % 14
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")

        # 3. encode
        features = self.encoder(x)

        # 4. decode
        depth = self.head(features, x.shape[2], x.shape[3])

        # 5. unpad
        depth = depth[:, :, :h, :w]

        return depth
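
The pad and unpad steps can be checked in isolation. A small sketch, assuming a patch size of 14 as in the template; pad_to_multiple is a hypothetical helper name:

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 14) -> torch.Tensor:
    """Reflect-pad the bottom and right edges so H and W divide by `multiple`."""
    h, w = x.shape[2], x.shape[3]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad order for the last two dims: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")


x = torch.rand(1, 3, 30, 45)
padded = pad_to_multiple(x)          # 30 -> 42, 45 -> 56
restored = padded[:, :, :30, :45]    # same crop as step 5 of the template
print(padded.shape)
```

Reflect padding requires each pad amount to be smaller than the dimension it pads, so very small inputs need a different mode.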

Replacing xformers with Native PyTorch

xformers ops cannot be compiled by TorchScript. Replace them with torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0 and later):

# BEFORE (xformers - not TorchScript compatible)
# from xformers.ops import memory_efficient_attention
# attn_out = memory_efficient_attention(q, k, v)

# AFTER (native PyTorch, TorchScript safe)
class AttentionNuke(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # kept for reference only: SDPA already scales by 1/sqrt(head_dim)
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(
            B, N, 3, self.num_heads, self.head_dim
        )
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = attn_out.transpose(1, 2).reshape(B, N, C)
        return self.proj(attn_out)
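
To confirm the swap is numerically safe, the fused call can be compared against explicit softmax attention. A sketch with random tensors; the shapes are arbitrary and no attention mask or dropout is assumed:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, N, head_dim = 2, 4, 16, 8
q = torch.randn(B, heads, N, head_dim)
k = torch.randn(B, heads, N, head_dim)
v = torch.randn(B, heads, N, head_dim)

# Explicit attention: softmax(q k^T / sqrt(head_dim)) v
scale = head_dim ** -0.5
weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
manual = weights @ v

# Fused call; the default scale is 1/sqrt(head_dim) (requires PyTorch 2.0+)
fused = F.scaled_dot_product_attention(q, k, v)

max_diff = (manual - fused).abs().max().item()
print(f"max abs diff: {max_diff:.2e}")
```

If the upstream model uses a non-default scale or a mask, both must be passed to the fused call explicitly; this sketch does not cover that case.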

Replacing einops with Native PyTorch

einops rearrange is not TorchScript compatible. Common replacements:

# einops: rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
# native:
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

# einops: rearrange(x, 'b c h w -> b (h w) c')
# native:
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

# einops: rearrange(x, 'b n (h d) -> b h n d', h=num_heads)
# native:
x = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)
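
These replacements are easy to get subtly wrong, for example by swapping H and W. A round-trip check without einops, using small arbitrary sizes; the final merge pattern 'b h n d -> b n (h d)' is not listed above and is included here only as the inverse of the head split:

```python
import torch

B, H, W, C = 2, 3, 5, 8
num_heads, head_dim = 2, 4
tokens = torch.arange(B * H * W * C, dtype=torch.float32).reshape(B, H * W, C)

# 'b (h w) c -> b c h w'
fmap = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
# 'b c h w -> b (h w) c' must undo it exactly
back = fmap.permute(0, 2, 3, 1).reshape(B, H * W, C)

# 'b n (h d) -> b h n d'
heads = tokens.reshape(B, H * W, num_heads, head_dim).transpose(1, 2)
# 'b h n d -> b n (h d)' must undo it exactly
merged = heads.transpose(1, 2).reshape(B, H * W, C)

# spot check: pixel (h=2, w=4), channel 3 of batch item 1
same = fmap[1, 3, 2, 4] == tokens[1, 2 * W + 4, 3]
print(bool(same), torch.equal(back, tokens), torch.equal(merged, tokens))
```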

LayerScale Pattern

Modern ViTs (DINOv2, DA3) use LayerScale, a learnable per-channel scale applied to the output of the attention and MLP branches before the residual add. If the wrapper omits it, the checkpoint's gamma weights have nowhere to load and the block outputs are wrong.

class LayerScaleNuke(nn.Module):
    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(
            init_value * torch.ones(dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x


class TransformerBlockNuke(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AttentionNuke(dim, num_heads)
        self.ls1 = LayerScaleNuke(dim)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MlpNuke(dim)
        self.ls2 = LayerScaleNuke(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
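
A quick way to see which state_dict keys LayerScale contributes, and how small its effect is at initialization. This is a sketch: TinyBlock is hypothetical and uses a Linear layer as a stand-in for both attention and the MLP:

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x


class TinyBlock(nn.Module):
    """Hypothetical block: a Linear stands in for attention and for the MLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)
        self.ls1 = LayerScale(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)
        self.ls2 = LayerScale(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


torch.manual_seed(0)
block = TinyBlock(dim=8).eval()
ls_keys = [k for k in block.state_dict() if k.endswith(".gamma")]
print(ls_keys)

x = torch.randn(1, 4, 8)
with torch.no_grad():
    # with gamma at 1e-5 the residual branches barely change the input
    drift = (block(x) - x).abs().max().item()
print(f"max drift at init: {drift:.2e}")
```

The trained gamma values in a real checkpoint are what give the branches their intended weight, which is why these keys must load.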

Compilation Script Pattern

import torch
import argparse
from wrapper import NukeWrapper

def compile_model(variant: str, output_path: str):
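    # 1. build wrapper (assumes NukeWrapper.__init__ accepts a variant)
    model = NukeWrapper(variant=variant)

    # 2. load weights
    ckpt_path = f"models/{variant}.pth"
    state_dict = torch.load(ckpt_path, map_location="cpu")
    if "model" in state_dict:  # unwrap nested checkpoints
        state_dict = state_dict["model"]
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")
    assert len(missing) == 0, \
        f"Weight mismatch! Missing: {missing[:10]}"
    assert len(unexpected) == 0, \
        f"Weight mismatch! Unexpected: {unexpected[:10]}"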

    # 3. script the model (prefer script over trace)
    model.eval()
    scripted = torch.jit.script(model)

    # 4. save
    scripted.save(output_path)
    print(f"Saved TorchScript model to {output_path}")

    # 5. verify round-trip
    loaded = torch.jit.load(output_path)
    dummy = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        out_orig = model(dummy)
        out_loaded = loaded(dummy)
    diff = (out_orig - out_loaded).abs().max().item()
    print(f"Round-trip max diff: {diff}")
    assert diff < 1e-5, "Round-trip verification failed!"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--variant", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()
    compile_model(args.variant, args.output)

Testing Patterns

Test each module for TorchScript compatibility before running the full compile script:

import torch

def test_scriptable(model_class, *args, **kwargs):
    """Verify a module can be scripted."""
    model = model_class(*args, **kwargs)
    try:
        scripted = torch.jit.script(model)
        print(f"OK: {model_class.__name__} is scriptable")
        return scripted
    except Exception as e:
        print(f"FAIL: {model_class.__name__}: {e}")
        raise


def test_output_shape(model, input_shape, expected_shape):
    """Verify output dimensions."""
    dummy = torch.randn(*input_shape)
    with torch.no_grad():
        out = model(dummy)
    assert out.shape == expected_shape, \
        f"Expected {expected_shape}, got {out.shape}"
    print(f"OK: output shape {out.shape}")

Docker Build Environment

For reproducible compilation with pinned dependencies:

FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime

WORKDIR /build
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

ENTRYPOINT ["python", "nuke_compile.py"]

requirements.txt should pin exact versions:

torch==2.2.0
torchvision==0.17.0
huggingface_hub==0.20.0
safetensors==0.4.0

nn.ModuleList Indexing

TorchScript cannot index an nn.ModuleList with a runtime variable; the index must be an integer literal. zip() over a ModuleList together with a runtime-length list also fails, because TorchScript unrolls ModuleList loops at compile time and can't statically determine the list's length.

Iterating the ModuleList on its own (for proj in self.projects, or enumerate(self.projects)) is unrolled at compile time and scripts in recent PyTorch releases. Manual unrolling with integer literals is the most conservative approach:

# BAD: variable index
for i in range(4):
    x = self.projects[i](x)  # FAILS

# BAD: zip of a runtime list with a ModuleList
for feat, proj in zip(feats, self.projects):  # FAILS
    x = proj(feat)

# GOOD: manually unrolled with literal indices
x0 = self.projects[0](feats[0])
x1 = self.projects[1](feats[1])
x2 = self.projects[2](feats[2])
x3 = self.projects[3](feats[3])

Verbose, but TorchScript resolves each ModuleList access at compile time, so literal indices always work.
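
Depending on the PyTorch version, enumerate() over the ModuleList alone (not zipped with a runtime list) may also script, since the loop is unrolled at compile time. A minimal sketch to check this in the build environment before choosing a style; ProjectAll is a hypothetical module:

```python
from typing import List

import torch
import torch.nn as nn


class ProjectAll(nn.Module):
    """Hypothetical module: one Linear per input feature."""

    def __init__(self, dim: int, n: int):
        super().__init__()
        self.projects = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n)])

    def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
        out: List[torch.Tensor] = []
        # The ModuleList has a fixed length, so this loop is unrolled at
        # compile time; `i` indexes the runtime list, not the ModuleList.
        for i, proj in enumerate(self.projects):
            out.append(proj(feats[i]))
        return out


model = ProjectAll(dim=8, n=4).eval()
scripted = torch.jit.script(model)

feats = [torch.randn(2, 8) for _ in range(4)]
with torch.no_grad():
    eager_out = model(feats)
    script_out = scripted(feats)
print(len(script_out))
```

If torch.jit.script raises here on the target PyTorch version, fall back to literal indices.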

F.interpolate Gotcha

TorchScript can't always infer torch.Size slices as tuple types. Always be explicit, and declare size components as typed locals:

# BAD: may fail in TorchScript
F.interpolate(x, size=y.shape[2:], mode="bilinear")

# BAD: inline arithmetic can fail type inference
F.interpolate(x, size=(y.shape[2] * 2, y.shape[3] * 2), ...)

# GOOD: explicit int locals, then tuple
h: int = y.shape[2] * 2
w: int = y.shape[3] * 2
F.interpolate(
    x, size=(h, w), mode="bilinear",
    align_corners=True
)

For optional size parameters, annotate the argument as Optional[Tuple[int, int]]; a bare Tuple[int, int] with a None default does not script:

# BAD: None is not a valid Tuple[int, int] default
def forward(self, x: torch.Tensor, size: Tuple[int, int] = None):
    ...

# GOOD: proper Optional typing (from typing import Optional, Tuple)
def forward(self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None):
    ...
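
Both rules combined in one scriptable module, as a sketch: Resize is a hypothetical name, and doubling H and W when no size is given is an arbitrary choice for illustration:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Resize(nn.Module):
    """Hypothetical module: resize to `size`, or double H and W if size is None."""

    def forward(
        self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        # explicit int locals
        h: int = x.shape[2] * 2
        w: int = x.shape[3] * 2
        if size is not None:
            # inside this branch TorchScript refines size to Tuple[int, int]
            h, w = size
        return F.interpolate(
            x, size=(h, w), mode="bilinear", align_corners=True
        )


scripted = torch.jit.script(Resize())
x = torch.randn(1, 3, 4, 5)
doubled = scripted(x)           # falls back to 2x
explicit = scripted(x, (7, 9))  # uses the given size
print(doubled.shape, explicit.shape)
```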

Type Annotation Requirements

TorchScript requires consistent typing. Common pitfalls:

# BAD: variable changes type
x = None          # NoneType
x = torch.zeros(3)  # Tensor -- type changed!

# GOOD: use Optional or initialize correctly
x: Optional[torch.Tensor] = None

# BAD: bool in module attribute
self.use_bias = True  # bool not reliably supported

# GOOD: use int
self.use_bias: int = 1  # 1 for true, 0 for false

Weight Loading Validation

Never trust strict=False to silently pass. Always validate:

missing, unexpected = model.load_state_dict(state_dict, strict=False)

# filter out EXPECTED missing keys (baked constants, skipped modules)
expected_missing = ['mean', 'std', 'sky_', 'drop', 'q_norm', 'k_norm']
critical = [k for k in missing
            if not any(skip in k for skip in expected_missing)]

if len(critical) > 0:
    raise RuntimeError(
        f"Weight mismatch! {len(critical)} keys missing: "
        f"{critical[:10]}"
    )

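Add Identity/Dropout placeholders for upstream modules that are functionally no-ops (Dropout(0.0), Identity), so the wrapper's module hierarchy mirrors the original. These modules hold no parameters, so they add no checkpoint keys:

# mirror the upstream module hierarchy, even for no-op modules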
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.attn_drop = nn.Dropout(0.0)
self.proj_drop = nn.Dropout(0.0)
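
The validation above can be wrapped in a single helper so every load goes through the same checks. A sketch only: load_checked is a hypothetical name, and rejecting all unexpected keys is an added assumption that the snippet above does not make:

```python
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn


def load_checked(
    model: nn.Module,
    state_dict: dict,
    expected_missing: Sequence[str] = (),
) -> Tuple[List[str], List[str]]:
    """Load with strict=False, then fail loudly on anything not allow-listed."""
    result = model.load_state_dict(state_dict, strict=False)
    critical = [
        k for k in result.missing_keys
        if not any(skip in k for skip in expected_missing)
    ]
    if critical:
        raise RuntimeError(f"{len(critical)} keys missing: {critical[:10]}")
    if result.unexpected_keys:
        raise RuntimeError(
            f"{len(result.unexpected_keys)} unexpected keys: "
            f"{result.unexpected_keys[:10]}"
        )
    return list(result.missing_keys), list(result.unexpected_keys)


# Toy model (hypothetical) with one baked constant the checkpoint lacks
toy = nn.Linear(4, 2)
toy.register_buffer("img_mean", torch.zeros(3))
ckpt = {"weight": torch.ones(2, 4), "bias": torch.zeros(2)}

missing, unexpected = load_checked(toy, ckpt, expected_missing=["mean"])
print(missing, unexpected)
```

Keep the allow-list as short and specific as possible; a broad substring such as 'norm' would hide real mismatches.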