.agents/skills/materia-nuke-node/references/torchscript-patterns.md

# TorchScript Wrapper Patterns for Nuke
## Weight Loading Strategy
The most critical step. A previous attempt scored 2/10 because
it wrote a DA v2-style wrapper for a DA v3 model.
**Rule**: Read the source code. List state_dict keys from a
pretrained checkpoint. Match your wrapper's module hierarchy EXACTLY.
```python
import torch

ckpt = torch.load("model.pt", map_location="cpu")
if "model" in ckpt:
    ckpt = ckpt["model"]
for k in sorted(ckpt.keys()):
    print(f"{k:60s} {ckpt[k].shape}")
```
Then structure your wrapper's `__init__` so that:
- `self.encoder.blocks[i].attn.qkv.weight` maps to the checkpoint's `encoder.blocks.i.attn.qkv.weight`
- every submodule name matches exactly
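When the wrapper's names cannot mirror the checkpoint exactly, remap the keys once at load time instead of renaming modules. This is a hypothetical sketch: the `pretrained.`/`depth_head.` prefixes are assumptions, so derive the real mapping from the printed key list above.

```python
from typing import Dict

import torch


def remap_keys(state_dict: Dict[str, torch.Tensor],
               prefix_map: Dict[str, str]) -> Dict[str, torch.Tensor]:
    """Rename checkpoint keys into the wrapper's namespace."""
    out = {}
    for k, v in state_dict.items():
        for old, new in prefix_map.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out


# example mapping: upstream "pretrained.*" -> wrapper "encoder.*"
sd = {"pretrained.blocks.0.attn.qkv.weight": torch.zeros(3, 3)}
wrapper_sd = remap_keys(sd, {"pretrained.": "encoder.", "depth_head.": "head."})
```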
**Never** rely on `strict=False` to mask mismatches. Always check:
```python
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert len(missing) == 0, f"Missing keys: {missing[:5]}"
assert len(unexpected) == 0, f"Unexpected keys: {unexpected[:5]}"
```
## Standard Wrapper Template
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NukeWrapper(torch.nn.Module):
    """TorchScript-compatible wrapper for the Nuke Inference node."""

    def __init__(self):
        super().__init__()
        # ALL attributes must be initialized here
        self.img_mean: torch.Tensor = torch.zeros(3)
        self.img_std: torch.Tensor = torch.ones(3)
        self.encoder = EncoderNuke(...)
        self.head = HeadNuke(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x comes from Nuke: [B, C, H, W], float32, sRGB [0,1]
        # 1. normalize (ImageNet stats)
        x = (x - self.img_mean[None, :, None, None]) \
            / self.img_std[None, :, None, None]
        # 2. pad to patch-aligned size
        h, w = x.shape[2], x.shape[3]
        pad_h = (14 - h % 14) % 14
        pad_w = (14 - w % 14) % 14
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
        # 3. encode
        features = self.encoder(x)
        # 4. decode
        depth = self.head(features, x.shape[2], x.shape[3])
        # 5. unpad
        depth = depth[:, :, :h, :w]
        return depth
```
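The padding arithmetic in step 2 is worth sanity-checking in isolation (14 is the ViT patch size assumed above):

```python
def pad_amount(n: int, patch: int = 14) -> int:
    # extra rows/cols needed to reach the next multiple of `patch`
    return (patch - n % patch) % patch


assert pad_amount(518) == 0    # 518 = 37 * 14, already aligned
assert pad_amount(500) == 4    # 500 + 4 = 504 = 36 * 14
assert pad_amount(1080) == 12  # 1080 + 12 = 1092 = 78 * 14
```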
## Replacing xformers with Native PyTorch
xformers is NOT available in TorchScript. Replace with
`torch.nn.functional.scaled_dot_product_attention`:
```python
# BEFORE (xformers - not TorchScript compatible)
# from xformers.ops import memory_efficient_attention
# attn_out = memory_efficient_attention(q, k, v)

# AFTER (native PyTorch, TorchScript safe)
class AttentionNuke(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = attn_out.transpose(1, 2).reshape(B, N, C)
        return self.proj(attn_out)
```
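A quick equivalence check (assumes PyTorch >= 2.0): with default arguments, `scaled_dot_product_attention` matches a manual softmax attention to within float tolerance, since its default scale is `1 / sqrt(head_dim)`:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 32)  # [B, heads, N, head_dim]
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

sdpa = F.scaled_dot_product_attention(q, k, v)
manual = torch.softmax(q @ k.transpose(-2, -1) * 32 ** -0.5, dim=-1) @ v
assert torch.allclose(sdpa, manual, atol=1e-5)
```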
## Replacing einops with Native torch
einops rearrange is not TorchScript compatible. Common replacements:
```python
# einops: rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
# native:
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
# einops: rearrange(x, 'b c h w -> b (h w) c')
# native:
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
# einops: rearrange(x, 'b n (h d) -> b h n d', h=num_heads)
# native:
x = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)
```
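Each pair above is a bijection, so a round-trip check catches transposed axes early:

```python
import torch

B, C, H, W = 2, 8, 4, 6
x = torch.randn(B, C, H, W)

tokens = x.permute(0, 2, 3, 1).reshape(B, H * W, C)    # b c h w -> b (h w) c
back = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2)  # b (h w) c -> b c h w
assert torch.equal(back, x)
```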
## LayerScale Pattern
Modern ViTs (DinoV2, DA3) use LayerScale -- a learnable per-channel
scaling applied after attention and MLP. Missing this means weights
won't load and model dynamics are wrong.
```python
class LayerScaleNuke(nn.Module):
    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x


class TransformerBlockNuke(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AttentionNuke(dim, num_heads)
        self.ls1 = LayerScaleNuke(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MlpNuke(dim)
        self.ls2 = LayerScaleNuke(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
```
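`MlpNuke` is referenced above but not shown. A minimal TorchScript-safe sketch -- the 4x hidden ratio and GELU are assumptions here; match the upstream source exactly:

```python
import torch
import torch.nn as nn


class MlpNuke(nn.Module):
    def __init__(self, dim: int, hidden_ratio: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * hidden_ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * hidden_ratio, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


# the module must script cleanly on its own
scripted = torch.jit.script(MlpNuke(64))
```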
## Compilation Script Pattern
```python
import argparse

import torch

from wrapper import NukeWrapper


def compile_model(variant: str, output_path: str):
    # 1. build wrapper
    model = NukeWrapper(variant=variant)
    # 2. load weights
    ckpt_path = f"models/{variant}.pth"
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")
    assert len(missing) == 0, f"Weight mismatch! Missing: {missing[:10]}"
    # 3. script the model (prefer script over trace)
    model.eval()
    scripted = torch.jit.script(model)
    # 4. save
    scripted.save(output_path)
    print(f"Saved TorchScript model to {output_path}")
    # 5. verify round-trip
    loaded = torch.jit.load(output_path)
    dummy = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        out_orig = model(dummy)
        out_loaded = loaded(dummy)
    diff = (out_orig - out_loaded).abs().max().item()
    print(f"Round-trip max diff: {diff}")
    assert diff < 1e-5, "Round-trip verification failed!"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--variant", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()
    compile_model(args.variant, args.output)
```
## Testing Patterns
Test TorchScript compatibility BEFORE trying to compile:
```python
import torch


def test_scriptable(model_class, *args, **kwargs):
    """Verify a module can be scripted."""
    model = model_class(*args, **kwargs)
    try:
        scripted = torch.jit.script(model)
        print(f"OK: {model_class.__name__} is scriptable")
        return scripted
    except Exception as e:
        print(f"FAIL: {model_class.__name__}: {e}")
        raise


def test_output_shape(model, input_shape, expected_shape):
    """Verify output dimensions."""
    dummy = torch.randn(*input_shape)
    with torch.no_grad():
        out = model(dummy)
    assert out.shape == expected_shape, \
        f"Expected {expected_shape}, got {out.shape}"
    print(f"OK: output shape {out.shape}")
```
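A usage sketch of the same checks (an `nn.Conv2d` stands in for a real wrapper so the example is self-contained):

```python
import torch
import torch.nn as nn

# a depth-style head: 3 channels in, 1 out, spatial size preserved
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)

# scriptability check, then a shape check on the scripted module
scripted = torch.jit.script(model)
dummy = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    out = scripted(dummy)
assert out.shape == (1, 1, 64, 64)
```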
## Docker Build Environment
For reproducible compilation with pinned dependencies:
```dockerfile
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
WORKDIR /build
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
ENTRYPOINT ["python", "nuke_compile.py"]
```
requirements.txt should pin exact versions:
```
torch==2.2.0
torchvision==0.17.0
huggingface_hub>=0.20.0
safetensors>=0.4.0
```
## nn.ModuleList Indexing
TorchScript cannot index `nn.ModuleList` with a variable.
`zip()` and `enumerate()` over ModuleLists also fail because
TorchScript can't statically determine the length.
The only reliable approach is manual unrolling with integer literals:
```python
# BAD: variable index
for i in range(4):
    x = self.projects[i](x)  # FAILS

# BAD: zip over ModuleLists
for feat, proj in zip(feats, self.projects):  # FAILS
    x = proj(feat)

# GOOD: manually unrolled with literal indices
x0 = self.projects[0](feats[0])
x1 = self.projects[1](feats[1])
x2 = self.projects[2](feats[2])
x3 = self.projects[3](feats[3])
```
Verbose, but necessary: TorchScript must resolve each ModuleList
access to a concrete submodule at compile time.
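The unrolled pattern in a scriptable module (`FuseNuke` is a hypothetical name; four feature levels is an assumption borrowed from DPT-style heads):

```python
from typing import List

import torch
import torch.nn as nn


class FuseNuke(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.projects = nn.ModuleList([nn.Linear(dim, dim) for _ in range(4)])

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        # literal indices only: TorchScript binds each access statically
        x0 = self.projects[0](feats[0])
        x1 = self.projects[1](feats[1])
        x2 = self.projects[2](feats[2])
        x3 = self.projects[3](feats[3])
        return x0 + x1 + x2 + x3


scripted = torch.jit.script(FuseNuke(8))
```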
## F.interpolate Gotcha
TorchScript can't always infer `torch.Size` slices as tuple types.
Always be explicit, and declare size components as typed locals:
```python
# BAD: may fail in TorchScript
F.interpolate(x, size=y.shape[2:], mode="bilinear")

# BAD: inline arithmetic can fail type inference
F.interpolate(x, size=(y.shape[2] * 2, y.shape[3] * 2), ...)

# GOOD: explicit int locals, then tuple
h: int = y.shape[2] * 2
w: int = y.shape[3] * 2
F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
```
Also for Optional size parameters, use `Optional[Tuple[int, int]]`
not just `Tuple[int, int] = None`:
```python
from typing import Optional, Tuple

# BAD: TorchScript chokes on the default
def forward(self, x: torch.Tensor, size: Tuple[int, int] = None):
    ...

# GOOD: proper Optional typing
def forward(self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None):
    ...
```
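Both rules combined in a small scriptable module (`UpsampleNuke` is a hypothetical name and the 2x factor an assumption). TorchScript's `is None` refinement narrows `size` to `Tuple[int, int]` in the else branch:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class UpsampleNuke(nn.Module):
    def forward(self, x: torch.Tensor,
                size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        if size is None:
            # explicit int locals keep TorchScript's type inference happy
            out_h: int = x.shape[2] * 2
            out_w: int = x.shape[3] * 2
            out_size = (out_h, out_w)
        else:
            out_size = size
        return F.interpolate(x, size=out_size, mode="bilinear",
                             align_corners=True)


scripted = torch.jit.script(UpsampleNuke())
```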
## Type Annotation Requirements
TorchScript requires consistent typing. Common pitfalls:
```python
# BAD: variable changes type
x = None # NoneType
x = torch.zeros(3) # Tensor -- type changed!
# GOOD: use Optional or initialize correctly
x: Optional[torch.Tensor] = None
# BAD: bool in module attribute
self.use_bias = True # bool not reliably supported
# GOOD: use int
self.use_bias: int = 1 # 1 for true, 0 for false
```
## Weight Loading Validation
Never trust `strict=False` to silently pass. Always validate:
```python
missing, unexpected = model.load_state_dict(state_dict, strict=False)

# filter out EXPECTED missing keys (baked constants, skipped modules)
expected_missing = ["mean", "std", "sky_", "drop", "q_norm", "k_norm"]
critical = [k for k in missing
            if not any(skip in k for skip in expected_missing)]
if len(critical) > 0:
    raise RuntimeError(
        f"Weight mismatch! {len(critical)} keys missing: "
        f"{critical[:10]}"
    )
```
Add Identity/Dropout placeholders for upstream modules that exist
in checkpoints but are functionally no-ops (Dropout(0.0), Identity):
```python
# match upstream checkpoint keys even for no-op modules
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.attn_drop = nn.Dropout(0.0)
self.proj_drop = nn.Dropout(0.0)
```