
PyTorch toolbox to work with spherical surfaces.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class surfify.models.sit.SiT(dim, depth, heads, mlp_dim, n_patches, n_channels, n_vertices, n_classes=1, pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0)[source]

Bases: Module

The SiT model implements the Surface Vision Transformer.

See also

mSiT

Notes

Debugging messages can be displayed by changing the log level using setup_logging(level='debug').

References

Dahan, Simon et al., Surface Vision Transformers: Attention-Based Modelling applied to Cortical Analysis, MIDL, 2022.

Init SiT.

Parameters:

dim : int

the sequence of N flattened patches of size n_channels x n_vertices is first projected onto a sequence of dimension dim with a trainable linear embedding layer.

depth : int

the number of transformer blocks, each composed of a multi-head self-attention layer (MSA), implementing the self-attention mechanism across the sequence, and a feed-forward network (FFN), which expands then reduces the sequence dimension.

heads : int

number of attention heads.

mlp_dim : int

the hidden dimension of the feed-forward network (FFN) inside each transformer block.

n_patches : int

the number of patches.

n_channels : int

the number of channels.

n_vertices : int

the number of vertices.

n_classes : int, default 1

the number of classes to predict: if <= 0, the latent space is returned instead.

pool : str, default ‘cls’

the pooling strategy: 'cls' (cls token) or 'mean' (mean pooling).

dim_head : int, default 64

the dimension of each attention head.

dropout : float, default 0.

transformer dropout rate.

emb_dropout : float, default 0.

embedding dropout rate.
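How the dim, n_patches, n_channels and n_vertices parameters interact in the patch embedding can be illustrated with a small shape calculation. This is a sketch only: the helper patch_embedding_shapes is hypothetical and not part of surfify.

```python
def patch_embedding_shapes(n_samples, n_channels, n_patches, n_vertices, dim):
    """Return the tensor shapes before and after the linear embedding.

    Each of the n_patches patches is flattened from (n_channels, n_vertices)
    to a vector of length n_channels * n_vertices, then projected onto a
    sequence of dimension dim by a trainable linear layer.
    """
    flat = n_channels * n_vertices
    return (n_samples, n_patches, flat), (n_samples, n_patches, dim)


# Example with made-up sizes: 2 samples, 4 channels, 320 patches of
# 153 vertices, projected to dim=192.
before, after = patch_embedding_shapes(2, 4, 320, 153, 192)
```

Here the flattened patch length is 4 x 153 = 612, so the embedding layer maps (2, 320, 612) to (2, 320, 192).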

forward(x)[source]

Forward method.

Parameters:

x : Tensor (n_samples, n_channels, n_patches, n_vertices)

the input data.

Returns:

x : Tensor (n_samples, n_classes)

the output data: the class predictions, or the latent representation when n_classes <= 0.
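Before the prediction head, the pool argument controls how the final token sequence is reduced to a single vector. The two strategies can be sketched in plain Python; pool_tokens is a hypothetical helper for illustration, not the surfify implementation.

```python
def pool_tokens(tokens, strategy="cls"):
    """Reduce a sequence of token embeddings to a single vector.

    tokens: list of per-token embeddings (lists of floats), where
    tokens[0] is the prepended cls token.
    """
    if strategy == "cls":
        # Keep only the trainable cls token.
        return tokens[0]
    if strategy == "mean":
        # Average over the full sequence, dimension by dimension.
        dim = len(tokens[0])
        return [sum(t[i] for t in tokens) / len(tokens) for i in range(dim)]
    raise ValueError("pool must be 'cls' or 'mean'")
```

With 'cls', only the first (cls) token is forwarded to the head; with 'mean', every token contributes equally.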


© 2025, surfify developers