Note

This page is reference documentation: it explains only the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.utils.weights.Weights

class nidl.utils.weights.Weights(name, data_dir, filepath)[source]

Bases: object

A class to handle (retrieve and apply) model weights or lightning checkpoints.

Parameters:
name: str

the location of the model weights, specified in the form hf-hub:path/architecture_name@revision if available on the Hugging Face hub, ns-hub:path/architecture_name if available on the NeuroSpin hub, or a path if available on your local machine.

data_dir: pathlib.Path or str

path where data should be downloaded.

filepath: str

the path of the file in the repo. If the path has a '.ckpt' extension, it is assumed to be a pytorch_lightning checkpoint.
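The three accepted forms of the name argument can be told apart by their prefix. A minimal sketch (classify_name is a hypothetical helper, not part of nidl):

```python
# Hypothetical sketch: classify the three accepted forms of ``name``.
def classify_name(name: str) -> str:
    """Return which source a weights name refers to."""
    if name.startswith("hf-hub:"):
        return "huggingface"   # e.g. "hf-hub:org/model@main"
    if name.startswith("ns-hub:"):
        return "neurospin"     # e.g. "ns-hub:project/model"
    return "local"             # anything else is a local filesystem path

print(classify_name("hf-hub:org/model@main"))  # huggingface
print(classify_name("/tmp/weights.ckpt"))      # local
```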

HF_URL = 'https://huggingface.co'
NS_URL = 'http://nsap.intra.cea.fr/neurospin-hub/'
__init__(name, data_dir, filepath)[source]
classmethod hf_download(data_dir, hf_id, filepath, hf_revision=None, force_download=False)[source]

Download a given file if not already present.

Downloads always resume when possible. If you want to force a new download, use force_download=True.

Parameters:
data_dir: pathlib.Path or str

path where data should be downloaded.

hf_id: str

the id of the repository.

filepath: str

the path of the file in the repo.

hf_revision: str, default=None

the revision of the repository (a tag, or a commit hash).

force_download: bool, default=False

whether the file should be downloaded even if it already exists in the local cache.

Returns:
weight_file: Path

local path to the model weights.
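The caching behaviour described above (skip the transfer when the file is already present, unless force_download is set) can be sketched with the standard library alone. Here fetch is a hypothetical stand-in for the actual Hugging Face transfer:

```python
from pathlib import Path

# Hedged sketch of the download cache: only call ``fetch`` when the file
# is missing or a re-download is forced. Real code would stream and
# resume; ``fetch`` is a hypothetical callable returning the file bytes.
def cached_download(data_dir, filepath, fetch, force_download=False) -> Path:
    weight_file = Path(data_dir) / filepath
    if force_download or not weight_file.exists():
        weight_file.parent.mkdir(parents=True, exist_ok=True)
        weight_file.write_bytes(fetch())
    return weight_file
```

A second call with the same arguments returns the cached path without invoking fetch again.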

classmethod hub_split(hub_name)[source]

Interpret the input hub name specified in the form hf-hub:path/architecture_name@revision or ns-hub:path/architecture_name.

Parameters:
hub_name: str

name of the repository.

Returns:
hub_id: str

the id of the repository.

hub_revision: str

the revision of the repository.
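The parsing rule can be sketched as follows: strip the hf-hub:/ns-hub: prefix, then split off an optional @revision suffix. split_hub_name is a hypothetical re-implementation; nidl's actual code may differ:

```python
# Hypothetical sketch of hub-name parsing: "hf-hub:org/model@v1.0"
# yields ("org/model", "v1.0"); names without "@" yield no revision.
def split_hub_name(hub_name: str):
    for prefix in ("hf-hub:", "ns-hub:"):
        if hub_name.startswith(prefix):
            hub_name = hub_name[len(prefix):]
            break
    hub_id, _, hub_revision = hub_name.partition("@")
    return hub_id, hub_revision or None

print(split_hub_name("hf-hub:org/model@v1.0"))  # ('org/model', 'v1.0')
```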

load_checkpoint(model, **kwargs)[source]

Load the checkpoint.

Parameters:
model: LightningModule

a pytorch_lightning module class.

device: torch.device or str

the device on which to load the model and to use for inference. Defaults to cpu. Only a single device is supported for now.

**kwargs:

any extra keyword arguments needed to init the model. Can also be used to override saved hyperparameter values, in particular to override trainer parameters such as accelerator or devices.
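The override mechanics amount to: keyword arguments passed at load time take precedence over the hyperparameter values saved in the checkpoint. A minimal sketch with plain dicts (the saved dict below is illustrative, not nidl's checkpoint format):

```python
# Sketch of kwargs precedence: caller-supplied overrides win over the
# hyperparameters stored in the checkpoint. Illustrative only.
def resolve_hparams(saved_hparams: dict, **overrides) -> dict:
    hparams = dict(saved_hparams)  # start from the values saved in the .ckpt
    hparams.update(overrides)      # caller-supplied kwargs win
    return hparams

saved = {"lr": 1e-3, "accelerator": "gpu", "devices": 4}
print(resolve_hparams(saved, accelerator="cpu", devices=1))
# {'lr': 0.001, 'accelerator': 'cpu', 'devices': 1}
```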

load_pretrained(model)[source]

Load the model weights.

Parameters:
model: torch.nn.Module

an input model with a load_pretrained method declared.
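The contract is that the model itself declares a load_pretrained method, which this class then calls with the loaded weights. A minimal sketch of that contract, where ToyModel and its dict-based state stand in for a real torch.nn.Module and its state_dict:

```python
# Illustrative stand-in for a torch.nn.Module: the model declares
# ``load_pretrained`` and decides how to apply the incoming state.
class ToyModel:
    def __init__(self):
        self.state = {"layer.weight": 0.0}

    def load_pretrained(self, state_dict):
        # Copy only the keys this model knows about; ignore the rest.
        for key in self.state:
            if key in state_dict:
                self.state[key] = state_dict[key]

model = ToyModel()
model.load_pretrained({"layer.weight": 0.5, "extra.bias": 1.0})
print(model.state)  # {'layer.weight': 0.5}
```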

classmethod ns_download(data_dir, ns_id, filepath, force_download=False)[source]

Download a given file if not already present.

Downloads always resume when possible. If you want to force a new download, use force_download=True.

Parameters:
data_dir: pathlib.Path or str

path where data should be downloaded.

ns_id: str

the id of the repository.

filepath: str

the path of the file in the repo.

force_download: bool, default=False

whether the file should be downloaded even if it already exists in the local cache.

Returns:
weight_file: Path

local path to the model weights.

Examples using nidl.utils.weights.Weights

Self-Supervised Contrastive Learning with SimCLR on MNIST