PyTorch toolbox to work with spherical surfaces.
Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class surfify.models.sit.SiT(dim, depth, heads, mlp_dim, n_patches, n_channels, n_vertices, n_classes=1, pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0)
Bases: torch.nn.Module
The SiT model: implements surface vision transformers.
See also
mSiT
Notes
Debugging messages can be displayed by changing the log level with setup_logging(level='debug').
References
Dahan, Simon et al., Surface Vision Transformers: Attention-Based Modelling applied to Cortical Analysis, MIDL, 2022.
Init SiT.
- Parameters:
dim : int
the embedding dimension: the sequence of N flattened patches, each of size n_channels x n_vertices, is first projected onto a sequence of dimension dim by a trainable linear embedding layer.
depth : int
the number of transformer blocks, each composed of a multi-head self-attention (MSA) layer, which implements the self-attention mechanism across the sequence, and a feed-forward network (FFN), which expands then reduces the embedding dimension.
heads : int
number of attention heads.
mlp_dim : int
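the hidden dimension of the feed-forward network (MLP) in each transformer block, which maps the embedding dimension dim to mlp_dim and back.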
n_patches : int
the number of patches N.
n_channels : int
the number of input channels (features) per vertex.
n_vertices : int
the number of vertices in each patch.
n_classes : int, default 1
the number of classes to predict; if <= 0, the latent representation is returned instead.
pool : str, default ‘cls’
the pooling strategy: ‘cls’ (class token) or ‘mean’ (mean pooling).
dim_head : int, default 64
the dimension of each attention head.
dropout : float, default 0.
transformer dropout rate.
emb_dropout : float, default 0.
the embedding dropout rate.
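To illustrate how dim, n_patches, n_channels, n_vertices, pool and n_classes fit together, here is a minimal NumPy sketch of the data flow around the transformer blocks. It is not surfify's implementation: the input layout (batch, n_patches, n_channels, n_vertices), the learnable class token and positional embedding, mean pooling over all tokens (class token included), and the hypothetical helpers init_params and sit_forward_sketch are assumptions modelled on the standard ViT design. Dropout is omitted and the depth transformer blocks are left out.

```python
import numpy as np


def init_params(dim, n_patches, n_channels, n_vertices, n_classes, seed=0):
    """Random parameters for the sketch (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    patch_dim = n_channels * n_vertices  # length of one flattened patch
    params = {
        "w_embed": rng.normal(scale=0.02, size=(patch_dim, dim)),
        "b_embed": np.zeros(dim),
        "cls_token": rng.normal(scale=0.02, size=(1, 1, dim)),
        "pos_embed": rng.normal(scale=0.02, size=(1, n_patches + 1, dim)),
    }
    if n_classes > 0:
        params["w_head"] = rng.normal(scale=0.02, size=(dim, n_classes))
        params["b_head"] = np.zeros(n_classes)
    return params


def sit_forward_sketch(x, params, pool="cls", n_classes=1):
    """x: (batch, n_patches, n_channels, n_vertices), an assumed input layout."""
    batch, n_patches, n_channels, n_vertices = x.shape
    # Flatten each patch into a vector of length n_channels * n_vertices.
    tokens = x.reshape(batch, n_patches, n_channels * n_vertices)
    # Trainable linear embedding: every patch becomes a dim-dimensional token.
    tokens = tokens @ params["w_embed"] + params["b_embed"]
    dim = tokens.shape[-1]
    # Prepend a learnable class token: sequence length becomes n_patches + 1.
    cls = np.broadcast_to(params["cls_token"], (batch, 1, dim))
    tokens = np.concatenate([cls, tokens], axis=1)
    # Add a learnable positional embedding (emb_dropout would apply here).
    tokens = tokens + params["pos_embed"]
    # ... the `depth` transformer blocks would process `tokens` here ...
    # Pool the sequence into one vector per sample.
    if pool == "cls":
        latent = tokens[:, 0]
    elif pool == "mean":
        latent = tokens.mean(axis=1)  # assumed to include the class token
    else:
        raise ValueError(f"unknown pool: {pool!r}")
    if n_classes <= 0:
        return latent  # latent space, shape (batch, dim)
    return latent @ params["w_head"] + params["b_head"]  # (batch, n_classes)


params = init_params(dim=8, n_patches=20, n_channels=4, n_vertices=15, n_classes=3)
x = np.random.default_rng(1).normal(size=(2, 20, 4, 15))
print(sit_forward_sketch(x, params, pool="cls", n_classes=3).shape)   # (2, 3)
print(sit_forward_sketch(x, params, pool="mean", n_classes=0).shape)  # (2, 8)
```

With n_classes <= 0 the sketch returns the pooled latent vector of size dim, mirroring the documented behaviour of returning the latent space.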
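The parameters heads, dim_head, mlp_dim and depth describe the transformer blocks. Below is a minimal NumPy sketch of a single block, assuming a pre-norm layout (layer normalization before each sub-layer, residual connections around both) as in the original ViT. The helper names are hypothetical, layer normalization has no learnable scale or shift, GELU uses its tanh approximation, and dropout is omitted. The sketch shows that attention operates in an inner space of size heads * dim_head, which need not equal dim, and that the FFN maps dim to mlp_dim and back.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize the last axis (no learnable scale/shift in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def init_block(dim, heads, dim_head, mlp_dim, seed=0):
    """Random parameters for one block (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    inner = heads * dim_head  # attention works in this space, not in dim
    return {
        "w_qkv": rng.normal(scale=0.02, size=(dim, 3 * inner)),
        "w_out": rng.normal(scale=0.02, size=(inner, dim)),
        "w1": rng.normal(scale=0.02, size=(dim, mlp_dim)),
        "b1": np.zeros(mlp_dim),
        "w2": rng.normal(scale=0.02, size=(mlp_dim, dim)),
        "b2": np.zeros(dim),
    }


def transformer_block(x, p, heads, dim_head):
    """One pre-norm block: x + MSA(LN(x)), then x + FFN(LN(x)); x is (batch, seq, dim)."""
    batch, seq, _ = x.shape
    # Multi-head self-attention across the sequence.
    q, k, v = np.split(layer_norm(x) @ p["w_qkv"], 3, axis=-1)
    # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
    q, k, v = (t.reshape(batch, seq, heads, dim_head).transpose(0, 2, 1, 3) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(dim_head))  # (batch, heads, seq, seq)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(batch, seq, heads * dim_head)
    x = x + out @ p["w_out"]  # project back to dim, residual connection
    # Feed-forward network: dim -> mlp_dim -> dim, residual connection.
    hidden = gelu(layer_norm(x) @ p["w1"] + p["b1"])
    return x + hidden @ p["w2"] + p["b2"]


p = init_block(dim=16, heads=4, dim_head=8, mlp_dim=32)
x = np.random.default_rng(1).normal(size=(2, 21, 16))
y = transformer_block(x, p, heads=4, dim_head=8)
print(p["w_qkv"].shape, y.shape)  # (16, 96) (2, 21, 16)
```

Stacking depth such blocks, each with its own parameters, gives the transformer encoder; the block preserves the (batch, seq, dim) shape, so blocks can be chained directly.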