Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.volume.backbones.VisionTransformer3D

class nidl.volume.backbones.VisionTransformer3D(img_size, patch_size, in_chans=1, num_classes=0, global_pool='cls_token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_norm=False, scale_attn_norm=False, scale_mlp_norm=False, proj_bias=True, class_token=True, reg_tokens=0, no_embed_class=False, pre_norm=False, final_norm=True, fc_norm=None, dynamic_img_size=False, pos_embed='learned', drop_rate=0.0, pos_drop_rate=0.0, proj_drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, embed_layer=<class 'nidl.volume.backbones.vit3d.PatchEmbed3D'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, act_layer=<class 'torch.nn.modules.activation.GELU'>, block_fn=<class 'timm.models.vision_transformer.Block'>, mlp_layer=<class 'timm.layers.mlp.Mlp'>, attn_layer=<class 'timm.layers.attention.Attention'>)

Bases: Module

3D Vision Transformer with a timm-like interface.

Parameters:
img_size : int or sequence of int

Spatial size of the input volume, in (H, W, D) order.

patch_size : int or sequence of int

Patch size along each axis, in (H, W, D) order.

in_chans : int, default=1

Number of input channels.

num_classes : int, default=0

Number of output classes. If non-positive, the head is an identity mapping.

global_pool : {"cls_token", "avg", "max", "avgmax", ""}, default="cls_token"

Pooling mode. If "", no pooling is applied and the full token sequence is returned.

embed_dim : int, default=768

Embedding dimension.

depth : int, default=12

Number of transformer blocks.

num_heads : int, default=12

Number of attention heads.

mlp_ratio : float, default=4.0

Ratio of the MLP hidden dimension to embed_dim.

qkv_bias : bool, default=True

Whether to add a bias to the query, key, and value projections.

class_token : bool, default=True

Whether to prepend a class (CLS) token.

reg_tokens : int, default=0

Number of register tokens.

no_embed_class : bool, default=False

If True, position embeddings are added only to the patch tokens, not to the class or register tokens.

pre_norm : bool, default=False

Whether to apply normalization before the transformer blocks.

fc_norm : bool or None, default=None

Whether to apply normalization after pooling. If None, it is enabled only when global_pool == "avg".

dynamic_img_size : bool, default=False

Accepted for interface compatibility with timm; it has no effect.

pos_embed : {"learned", "sincos", "none"}, default="learned"

Absolute positional embedding mode: learned, fixed sine-cosine, or none.

drop_rate : float, default=0.0

Dropout rate applied before the classification head.

pos_drop_rate : float, default=0.0

Dropout rate applied after adding the positional embeddings.

proj_drop_rate : float, default=0.0

Projection dropout rate in each transformer block.

attn_drop_rate : float, default=0.0

Attention dropout rate in each transformer block.

drop_path_rate : float, default=0.0

Maximum stochastic depth (drop path) rate.

norm_layer : callable, default=nn.LayerNorm

Normalization layer constructor.

act_layer : callable, default=nn.GELU

Activation layer constructor.
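This page does not spell out how the pos_embed="sincos" table is built. A common recipe, used for 2D ViTs in MAE, encodes each axis coordinate with fixed sines and cosines at geometrically spaced frequencies and concatenates the per-axis codes. The pure-Python sketch below extends that recipe to a (H, W, D) patch grid under stated assumptions: embed_dim is split evenly across the three axes, the frequency base is 10000, and patches are enumerated in row-major (H, W, D) order. The function names are hypothetical and need not match nidl's internals.

```python
import math


def sincos_1d(dim, positions, base=10000.0):
    """Fixed 1D sine-cosine codes: one row of length `dim` per position."""
    if dim % 2:
        raise ValueError("per-axis dimension must be even")
    half = dim // 2
    # Geometrically spaced frequencies, as in the original Transformer.
    freqs = [1.0 / base ** (i / half) for i in range(half)]
    rows = []
    for p in positions:
        angles = [p * f for f in freqs]
        # First half sines, second half cosines.
        rows.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return rows


def sincos_pos_embed_3d(embed_dim, grid_size):
    """Hypothetical 3D sin-cos table of shape (H * W * D, embed_dim).

    Assumption: embed_dim is split evenly between the H, W and D axes.
    """
    if embed_dim % 6:
        raise ValueError("embed_dim must be divisible by 6 for an even 3-way split")
    axis_dim = embed_dim // 3
    gh, gw, gd = grid_size
    emb_h = sincos_1d(axis_dim, range(gh))
    emb_w = sincos_1d(axis_dim, range(gw))
    emb_d = sincos_1d(axis_dim, range(gd))
    table = []
    # Row-major (H, W, D) enumeration of the patch grid.
    for h in range(gh):
        for w in range(gw):
            for d in range(gd):
                table.append(emb_h[h] + emb_w[w] + emb_d[d])
    return table


# 768-dim embeddings for a 4 x 4 x 4 grid of patches.
pe = sincos_pos_embed_3d(768, (4, 4, 4))
print(len(pe), len(pe[0]))  # 64 768
```

Because such a table is fixed, it adds no trainable parameters.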

Notes

This implementation reuses timm’s transformer Block as-is. Only the 3D-specific pieces are implemented locally: patch embedding, 3D positional embeddings, and checkpoint inflation.
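The exact checkpoint-inflation procedure is not described on this page. A widely used strategy for reusing a 2D patch-embedding kernel in 3D (I3D-style inflation) repeats the 2D kernel along the new depth axis and divides by the depth, so a volume that is constant along D produces the same response as the corresponding 2D image. The sketch below shows that idea on nested lists; the function name inflate_kernel and the choice of D as the last kernel axis (matching the (B, C, H, W, D) input layout) are assumptions.

```python
def inflate_kernel(kernel_2d, depth):
    """Inflate a (out, in, kH, kW) kernel to (out, in, kH, kW, depth).

    Every 2D weight is repeated `depth` times along the new last axis and
    divided by `depth`, so summing over that axis recovers the 2D weight.
    """
    return [
        [
            [[[w / depth] * depth for w in row] for row in in_ch]
            for in_ch in out_ch
        ]
        for out_ch in kernel_2d
    ]


# One output channel, one input channel, a 2 x 2 spatial kernel.
k2d = [[[[1.0, 2.0], [3.0, 4.0]]]]
k3d = inflate_kernel(k2d, depth=4)
print(k3d[0][0][0][1])  # [0.5, 0.5, 0.5, 0.5]
```

Dividing by the depth keeps activation magnitudes comparable to the 2D model at initialization, which is why this rescaling is usually preferred over plain repetition.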

__init__(img_size, patch_size, in_chans=1, num_classes=0, global_pool='cls_token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_norm=False, scale_attn_norm=False, scale_mlp_norm=False, proj_bias=True, class_token=True, reg_tokens=0, no_embed_class=False, pre_norm=False, final_norm=True, fc_norm=None, dynamic_img_size=False, pos_embed='learned', drop_rate=0.0, pos_drop_rate=0.0, proj_drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, embed_layer=<class 'nidl.volume.backbones.vit3d.PatchEmbed3D'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, act_layer=<class 'torch.nn.modules.activation.GELU'>, block_fn=<class 'timm.models.vision_transformer.Block'>, mlp_layer=<class 'timm.layers.mlp.Mlp'>, attn_layer=<class 'timm.layers.attention.Attention'>)

Construct the model. The parameters are described in the class documentation above.
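The pure-Python sketch below illustrates how some scalar constructor arguments can be interpreted. Broadcasting an int img_size or patch_size to (H, W, D) is the natural reading of "int or sequence of int", and the fc_norm fallback follows its parameter description. The linearly increasing per-block drop-path schedule is an assumption borrowed from timm's VisionTransformer, where block i receives drop_path_rate * i / (depth - 1). All function names are hypothetical.

```python
def to_3tuple(value):
    """Broadcast an int to (v, v, v); pass a length-3 sequence through as a tuple."""
    if isinstance(value, int):
        return (value, value, value)
    value = tuple(value)
    if len(value) != 3:
        raise ValueError(f"expected 3 values for (H, W, D), got {len(value)}")
    return value


def resolve_fc_norm(fc_norm, global_pool):
    """Documented fallback: fc_norm=None means normalize only for avg pooling."""
    if fc_norm is None:
        return global_pool == "avg"
    return bool(fc_norm)


def drop_path_schedule(drop_path_rate, depth):
    """Per-block stochastic depth rates rising linearly from 0 to drop_path_rate.

    Assumption: same schedule as timm (torch.linspace(0, drop_path_rate, depth)).
    """
    if depth <= 1:
        return [0.0] * depth
    step = drop_path_rate / (depth - 1)
    return [i * step for i in range(depth)]


print(to_3tuple(16))                 # (16, 16, 16)
print(resolve_fc_norm(None, "avg"))  # True
print(drop_path_schedule(0.1, 3))    # [0.0, 0.05, 0.1]
```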

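forward(x)

Run the full forward pass: token feature extraction, pooling, and the classification head.

Parameters:
x : torch.Tensor

Input tensor of shape (B, C, H, W, D).

Returns:
torch.Tensor

Logits if a classification head is configured (num_classes > 0); otherwise the pooled features, or the full token sequence when global_pool == "".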

forward_features(x)

Compute token features.

Parameters:
x : torch.Tensor

Input tensor of shape (B, C, H, W, D).

Returns:
torch.Tensor

Token features of shape (B, N_total, embed_dim), where N_total counts the patch tokens plus any class and register tokens.
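Assuming non-overlapping patches (a patch embedding whose stride equals patch_size) and an img_size divisible by patch_size along each axis, N_total is the number of grid cells plus the optional class token and the register tokens. The helper name num_tokens below is hypothetical.

```python
from math import prod


def num_tokens(img_size, patch_size, class_token=True, reg_tokens=0):
    """Return (grid, num_patches, n_total) for a non-overlapping 3D patch grid."""
    if isinstance(img_size, int):
        img_size = (img_size,) * 3
    if isinstance(patch_size, int):
        patch_size = (patch_size,) * 3
    for s, p in zip(img_size, patch_size):
        if s % p:
            raise ValueError(f"size {s} is not divisible by patch size {p}")
    # Number of patches along H, W and D.
    grid = tuple(s // p for s, p in zip(img_size, patch_size))
    num_patches = prod(grid)
    # Prefix tokens: optional class token plus register tokens.
    n_total = num_patches + int(class_token) + reg_tokens
    return grid, num_patches, n_total


# A 64^3 volume cut into 16^3 patches, with a CLS token and 4 registers.
print(num_tokens(64, 16, class_token=True, reg_tokens=4))  # ((4, 4, 4), 64, 69)
```

With the defaults class_token=True and reg_tokens=0, the same 64³ volume yields 65 tokens.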

forward_head(x, pre_logits=False)

Apply pooling and the classification head.

Parameters:
x : torch.Tensor

Token features of shape (B, N_total, embed_dim).

pre_logits : bool, default=False

Whether to return the pooled features before the classifier.

Returns:
torch.Tensor

Model outputs: the pooled features if pre_logits is True, otherwise the output of the classification head.
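The shape bookkeeping for forward_head can be sketched as follows. The rules encode assumptions rather than documented behavior: pooling maps (B, N_total, embed_dim) to (B, embed_dim) unless global_pool is "", pre_logits=True skips the head, and the head is an identity when num_classes <= 0 and otherwise a linear layer acting on the last axis. If, as in timm, forward(x) amounts to forward_head(forward_features(x)), the same function also describes the output of forward. head_output_shape is a hypothetical name.

```python
def head_output_shape(batch, n_total, embed_dim, num_classes=0,
                      global_pool="cls_token", pre_logits=False):
    """Hypothetical output shape of forward_head for a given configuration."""
    # Pooling: (B, N_total, embed_dim) -> (B, embed_dim), unless global_pool == "".
    if global_pool == "":
        shape = (batch, n_total, embed_dim)
    else:
        shape = (batch, embed_dim)
    # pre_logits skips the head; num_classes <= 0 means an identity head.
    if pre_logits or num_classes <= 0:
        return shape
    # Assumed linear head acting on the last axis.
    return shape[:-1] + (num_classes,)


print(head_output_shape(2, 65, 768))                                   # (2, 768)
print(head_output_shape(2, 65, 768, num_classes=10))                   # (2, 10)
print(head_output_shape(2, 65, 768, num_classes=10, pre_logits=True))  # (2, 768)
print(head_output_shape(2, 65, 768, global_pool=""))                   # (2, 65, 768)
```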

forward_intermediates(x, indices=(-1,), norm=False, output_fmt='NLC', intermediates_only=False)

Run the forward pass and return the outputs of selected transformer blocks.

Parameters:
x : torch.Tensor

Input tensor of shape (B, C, H, W, D).

indices : sequence of int or int, default=(-1,)

Block indices to return. Negative indices count from the last block.

norm : bool, default=False

Whether to apply the final normalization to the returned intermediates.

output_fmt : {"NLC"}, default="NLC"

Output format for intermediates. Only "NLC" (batch, tokens, channels) is supported.

intermediates_only : bool, default=False

If True, return only the intermediate tensors.

Returns:
tuple or list

If intermediates_only=False, returns (final, intermediates); otherwise returns intermediates.

Raises:
ValueError

If an unsupported output format is requested.
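This page does not say how a bare int passed as indices is interpreted (in timm, an int n selects the last n blocks), so the sketch below only normalizes a sequence of block indices. Negative values are resolved Python-style against depth, out-of-range values are rejected, and the result is sorted, on the assumption that intermediates are collected while iterating over the blocks in order. resolve_indices is a hypothetical helper.

```python
def resolve_indices(indices, depth):
    """Map block indices (negative allowed) to sorted, unique, non-negative ones."""
    resolved = set()
    for i in indices:
        # Negative indices count from the last block, as in Python slicing.
        j = i + depth if i < 0 else i
        if not 0 <= j < depth:
            raise IndexError(f"block index {i} is out of range for depth {depth}")
        resolved.add(j)
    # Assumption: intermediates come back in block order.
    return sorted(resolved)


print(resolve_indices((-1,), 12))             # [11]
print(resolve_indices((-4, -3, -2, -1), 12))  # [8, 9, 10, 11]
print(resolve_indices((0, 5, -1), 12))        # [0, 5, 11]
```

For instance, indices=(-4, -3, -2, -1) with the default depth=12 selects blocks 8 to 11.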

get_classifier()

Return the classifier head.

Returns:
nn.Module

Classifier head.

init_weights()

Initialize model weights.

no_weight_decay()

Return parameter names that should typically avoid weight decay.

Returns:
set of str

Parameter names exempt from weight decay.
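A typical use of no_weight_decay() is to build optimizer parameter groups. The pure-Python sketch below works on (name, shape) pairs so it runs without PyTorch. The returned list of {"params": ..., "weight_decay": ...} dicts has the same layout as PyTorch optimizer parameter groups, except that with a real model you would put the parameter tensors from model.named_parameters() in "params" instead of their names. Exempting every 1-D parameter (biases, normalization weights) is a common convention, not something this page specifies; the parameter names, shapes, and the 0.05 default below are hypothetical.

```python
def split_param_groups(named_shapes, no_decay_names, weight_decay=0.05):
    """Split (name, shape) pairs into weight-decay and no-weight-decay groups."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        # Exempt explicitly listed names and, by convention, all 1-D tensors.
        if name in no_decay_names or len(shape) <= 1:
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Hypothetical parameter names and shapes for embed_dim=768.
named_shapes = [
    ("cls_token", (1, 1, 768)),
    ("pos_embed", (1, 65, 768)),
    ("blocks.0.attn.qkv.weight", (2304, 768)),
    ("blocks.0.attn.qkv.bias", (2304,)),
    ("norm.weight", (768,)),
]
groups = split_param_groups(named_shapes, no_decay_names={"cls_token", "pos_embed"})
print(groups[0]["params"])  # ['blocks.0.attn.qkv.weight']
print(groups[1]["params"])  # ['cls_token', 'pos_embed', 'blocks.0.attn.qkv.bias', 'norm.weight']
```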

pool(x, pool_type=None)

Pool token features.

Parameters:
x : torch.Tensor

Token features of shape (B, N_total, embed_dim).

pool_type : str, optional

Pooling mode. If None, self.global_pool is used.

Returns:
torch.Tensor

Pooled features of shape (B, embed_dim), or the full sequence if pool_type == "".
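The pooling modes can be illustrated on a nested-list (B, N_total, C) tensor. The sketch mirrors timm's conventions, which are assumptions here: "cls_token" takes the first token, "avg" and "max" reduce over the patch tokens only (skipping the num_prefix_tokens class and register tokens), and "avgmax" averages the two results. pool_tokens is a hypothetical name.

```python
def pool_tokens(x, pool_type="cls_token", num_prefix_tokens=1):
    """Pool a (B, N_total, C) nested list of token features."""
    if pool_type == "":
        return x  # no pooling: full token sequence
    if pool_type == "cls_token":
        return [sample[0] for sample in x]
    out = []
    for sample in x:
        # Reduce over patch tokens only, skipping class/register tokens.
        patches = sample[num_prefix_tokens:]
        columns = list(zip(*patches))  # one tuple per channel
        avg = [sum(col) / len(col) for col in columns]
        mx = [max(col) for col in columns]
        if pool_type == "avg":
            out.append(avg)
        elif pool_type == "max":
            out.append(mx)
        elif pool_type == "avgmax":
            out.append([0.5 * (a + m) for a, m in zip(avg, mx)])
        else:
            raise ValueError(f"unknown pool_type: {pool_type!r}")
    return out


# B=1, N_total=3 (token 0 is the CLS token), C=2.
x = [[[9.0, 9.0], [1.0, 2.0], [3.0, 6.0]]]
print(pool_tokens(x, "cls_token"))  # [[9.0, 9.0]]
print(pool_tokens(x, "avg"))        # [[2.0, 4.0]]
print(pool_tokens(x, "avgmax"))     # [[2.5, 5.0]]
```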

reset_classifier(num_classes, global_pool=None)

Reset the classification head.

Parameters:
num_classes : int

New number of classes.

global_pool : str, optional

New global pooling mode.

Examples using nidl.volume.backbones.VisionTransformer3D

Self-Supervised Learning with I-JEPA on MedMNIST3D
