# TorchScript Wrapper Patterns for Nuke

## Weight Loading Strategy

The most critical step. A previous attempt scored 2/10 because
it wrote a DA v2-style wrapper for a DA v3 model.

**Rule**: Read the source code. List the state_dict keys from a
pretrained checkpoint. Match your wrapper's module hierarchy EXACTLY.

```python
|
|
import torch

# Load the checkpoint on CPU and unwrap the common {"model": ...} nesting.
checkpoint = torch.load("model.pt", map_location="cpu")
if "model" in checkpoint:
    checkpoint = checkpoint["model"]

# Print every key with its tensor shape, sorted so runs diff cleanly.
for name, tensor in sorted(checkpoint.items()):
    print(f"{name:60s} {tensor.shape}")
```

Then structure your wrapper's `__init__` so that:

- `self.encoder.blocks[i].attn.qkv.weight` maps to the checkpoint's
  `encoder.blocks.i.attn.qkv.weight`
- Every submodule name matches exactly

**Never** rely on `strict=False` to mask mismatches. Always check:

```python
|
|
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
|
assert len(missing) == 0, f"Missing keys: {missing[:5]}"
|
|
assert len(unexpected) == 0, f"Unexpected keys: {unexpected[:5]}"
|
|
```

## Standard Wrapper Template

```python
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F


class NukeWrapper(torch.nn.Module):
    """TorchScript-compatible wrapper for Nuke's Inference node.

    Takes input from Nuke as [B, C, H, W] float32 sRGB in [0, 1] and
    returns a depth map cropped back to the caller's resolution.
    """

    def __init__(self):
        super().__init__()
        # ALL attributes must be initialized here (TorchScript requirement).
        # Bake the real ImageNet statistics: the earlier template used
        # zeros/ones, which silently turned normalization into a no-op
        # (these constants are NOT in the checkpoint -- see the
        # "Weight Loading Validation" section's expected_missing list).
        self.img_mean: torch.Tensor = torch.tensor([0.485, 0.456, 0.406])
        self.img_std: torch.Tensor = torch.tensor([0.229, 0.224, 0.225])

        self.encoder = EncoderNuke(...)
        self.head = HeadNuke(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x comes from Nuke: [B, C, H, W], float32, sRGB [0,1]

        # 1. normalize (ImageNet stats)
        x = (x - self.img_mean[None, :, None, None]) \
            / self.img_std[None, :, None, None]

        # 2. pad to patch-aligned size (ViT patch size 14)
        h, w = x.shape[2], x.shape[3]
        pad_h = (14 - h % 14) % 14
        pad_w = (14 - w % 14) % 14
        # NOTE(review): "reflect" padding requires pad < input size; fine for
        # real frames, but would fail on inputs smaller than the patch size.
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")

        # 3. encode
        features = self.encoder(x)

        # 4. decode at the padded resolution
        depth = self.head(features, x.shape[2], x.shape[3])

        # 5. unpad back to the original resolution
        depth = depth[:, :, :h, :w]

        return depth
```

## Replacing xformers with Native PyTorch

xformers is NOT available in TorchScript. Replace with
`torch.nn.functional.scaled_dot_product_attention`:

```python
|
|
# BEFORE (xformers - not TorchScript compatible)
# from xformers.ops import memory_efficient_attention
# attn_out = memory_efficient_attention(q, k, v)

# AFTER (native PyTorch, TorchScript safe)
class AttentionNuke(nn.Module):
    """Multi-head self-attention built only from TorchScript-safe ops."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, tokens, channels = x.shape
        # One fused projection, then split into per-head q / k / v
        # laid out as [B, heads, N, head_dim].
        qkv = self.qkv(x).reshape(
            batch, tokens, 3, self.num_heads, self.head_dim
        )
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)

        # SDPA applies the 1/sqrt(head_dim) scaling internally.
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(out)
```

## Replacing einops with Native torch

einops `rearrange` is not TorchScript compatible. Common replacements:

```python
# einops: rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
# native:
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

# einops: rearrange(x, 'b c h w -> b (h w) c')
# native:
x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

# einops: rearrange(x, 'b n (h d) -> b h n d', h=num_heads)
# native:
x = x.reshape(B, N, num_heads, head_dim).transpose(1, 2)
```

## LayerScale Pattern

Modern ViTs (DINOv2, DA3) use LayerScale -- a learnable per-channel
scaling applied after attention and MLP. Missing it means weights
won't load and the model dynamics are wrong.

```python
|
|
class LayerScaleNuke(nn.Module):
    """Learnable per-channel scaling (LayerScale) used by modern ViTs."""

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        # Initialized near zero so residual branches start almost disabled.
        self.gamma = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

class TransformerBlockNuke(nn.Module):
    """Pre-norm ViT block: attention then MLP, each branch LayerScaled."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        # Attention branch.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AttentionNuke(dim, num_heads)
        self.ls1 = LayerScaleNuke(dim)

        # MLP branch.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MlpNuke(dim)
        self.ls2 = LayerScaleNuke(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x))
        x = x + self.ls1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.ls2(mlp_out)
        return x
```

## Compilation Script Pattern

```python
|
|
import torch
|
|
import argparse
|
|
from wrapper import NukeWrapper
|
|
|
|
def compile_model(variant: str, output_path: str):
    """Build, weight-load, script, save, and round-trip-verify a NukeWrapper.

    Raises RuntimeError on any weight mismatch or round-trip divergence.
    """
    # 1. build wrapper
    model = NukeWrapper(variant=variant)

    # 2. load weights
    ckpt_path = f"models/{variant}.pth"
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False
    )
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")
    # Raise (not assert: asserts are stripped under `python -O`) and check
    # BOTH directions -- unexpected keys also signal a wrapper/checkpoint
    # hierarchy mismatch, not just missing ones.
    if missing:
        raise RuntimeError(f"Weight mismatch! Missing: {missing[:10]}")
    if unexpected:
        raise RuntimeError(f"Weight mismatch! Unexpected: {unexpected[:10]}")

    # 3. script the model (prefer script over trace)
    model.eval()
    scripted = torch.jit.script(model)

    # 4. save
    scripted.save(output_path)
    print(f"Saved TorchScript model to {output_path}")

    # 5. verify round-trip: reload and compare against the eager model
    loaded = torch.jit.load(output_path)
    dummy = torch.randn(1, 3, 518, 518)
    with torch.no_grad():
        out_orig = model(dummy)
        out_loaded = loaded(dummy)
    diff = (out_orig - out_loaded).abs().max().item()
    print(f"Round-trip max diff: {diff}")
    if diff >= 1e-5:
        raise RuntimeError("Round-trip verification failed!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--variant", required=True)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()
    compile_model(args.variant, args.output)
```

## Testing Patterns

Test TorchScript compatibility BEFORE trying to compile:

```python
|
|
import torch
|
|
|
|
def test_scriptable(model_class, *args, **kwargs):
|
|
"""Verify a module can be scripted."""
|
|
model = model_class(*args, **kwargs)
|
|
try:
|
|
scripted = torch.jit.script(model)
|
|
print(f"OK: {model_class.__name__} is scriptable")
|
|
return scripted
|
|
except Exception as e:
|
|
print(f"FAIL: {model_class.__name__}: {e}")
|
|
raise
|
|
|
|
|
|
def test_output_shape(model, input_shape, expected_shape):
|
|
"""Verify output dimensions."""
|
|
dummy = torch.randn(*input_shape)
|
|
with torch.no_grad():
|
|
out = model(dummy)
|
|
assert out.shape == expected_shape, \
|
|
f"Expected {expected_shape}, got {out.shape}"
|
|
print(f"OK: output shape {out.shape}")
|
|
```

## Docker Build Environment

For reproducible compilation with pinned dependencies:

```dockerfile
|
|
# Pinned PyTorch base image keeps the TorchScript output reproducible.
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime

WORKDIR /build

# Install pinned dependencies first so Docker caches this layer
# independently of source changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

ENTRYPOINT ["python", "nuke_compile.py"]
```

requirements.txt should pin exact versions:

```
torch==2.2.0
torchvision==0.17.0
huggingface_hub>=0.20.0
safetensors>=0.4.0
```

## nn.ModuleList Indexing

TorchScript cannot index `nn.ModuleList` with a variable.
`zip()` and `enumerate()` over ModuleLists also fail because
TorchScript can't statically determine their length.

The only reliable approach is manual unrolling with integer literals:

```python
# BAD: variable index
for i in range(4):
    x = self.projects[i](x)  # FAILS

# BAD: zip over ModuleLists
for feat, proj in zip(feats, self.projects):  # FAILS
    x = proj(feat)

# GOOD: manually unrolled with literal indices
x0 = self.projects[0](feats[0])
x1 = self.projects[1](feats[1])
x2 = self.projects[2](feats[2])
x3 = self.projects[3](feats[3])
```

Verbose, but it is the only way: TorchScript must resolve each
ModuleList access at compile time.

## F.interpolate Gotcha

TorchScript can't always infer `torch.Size` slices as tuple types.
Always be explicit, and declare size components as typed locals:

```python
|
|
# BAD: may fail in TorchScript
F.interpolate(x, size=y.shape[2:], mode="bilinear")

# BAD: inline arithmetic can fail type inference
F.interpolate(x, size=(y.shape[2] * 2, y.shape[3] * 2), ...)

# GOOD: explicit int locals, then tuple
h: int = y.shape[2] * 2
w: int = y.shape[3] * 2
F.interpolate(
    x, size=(h, w), mode="bilinear",
    align_corners=True
)
```

Also, for Optional size parameters, use `Optional[Tuple[int, int]]`,
not `Tuple[int, int] = None`:

```python
# BAD: TorchScript chokes on the default
def forward(self, x: torch.Tensor, size: Tuple[int, int] = None):

# GOOD: proper Optional typing
def forward(self, x: torch.Tensor, size: Optional[Tuple[int, int]] = None):
```

## Type Annotation Requirements

TorchScript requires consistent typing. Common pitfalls:

```python
# BAD: variable changes type
x = None            # NoneType
x = torch.zeros(3)  # Tensor -- type changed!

# GOOD: use Optional or initialize correctly
x: Optional[torch.Tensor] = None

# BAD: bool in module attribute
self.use_bias = True  # bool not reliably supported

# GOOD: use int
self.use_bias: int = 1  # 1 for true, 0 for false
```

## Weight Loading Validation

Never trust `strict=False` to pass silently. Always validate:

```python
missing, unexpected = model.load_state_dict(state_dict, strict=False)

# filter out EXPECTED missing keys (baked constants, skipped modules)
expected_missing = ['mean', 'std', 'sky_', 'drop', 'q_norm', 'k_norm']
critical = [k for k in missing
            if not any(skip in k for skip in expected_missing)]

# Anything left is a genuine wrapper/checkpoint hierarchy mismatch.
if len(critical) > 0:
    raise RuntimeError(
        f"Weight mismatch! {len(critical)} keys missing: "
        f"{critical[:10]}"
    )
```

Add `Identity`/`Dropout` placeholders for upstream modules that exist
in checkpoints but are functionally no-ops (`Dropout(0.0)`, `Identity`):

```python
# match upstream checkpoint keys even for no-op modules
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.attn_drop = nn.Dropout(0.0)
self.proj_drop = nn.Dropout(0.0)
```
|