Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.utils.weights.Weights

class nidl.utils.weights.Weights(name, data_dir, filepath)

Bases: object

A class to handle (retrieve and apply) model weights or PyTorch Lightning checkpoints.

Parameters:
name: str

the location of the model weights: hf-hub:path/architecture_name@revision for weights hosted on the Hugging Face Hub, ns-hub:path/architecture_name for weights hosted on the NeuroSpin hub, or a path for weights stored on your local machine.

data_dir: pathlib.Path or str

path where data should be downloaded.

filepath: str

the path of the file in the repo. If the path has a ‘.ckpt’ extension, the file is assumed to be a pytorch_lightning checkpoint.
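The rules for name and filepath can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the nidl implementation: the helpers classify_weights_name and is_lightning_checkpoint are hypothetical, and the repository names in the example are made up.

```python
from pathlib import Path


def classify_weights_name(name: str) -> str:
    """Hypothetical helper: return the weight source implied by `name`."""
    # Prefixes taken from the documented `name` formats.
    if name.startswith("hf-hub:"):
        return "hf-hub"
    if name.startswith("ns-hub:"):
        return "ns-hub"
    # Anything else is treated as a path on the local machine.
    return "local"


def is_lightning_checkpoint(filepath: str) -> bool:
    """Hypothetical helper: a '.ckpt' extension marks a Lightning checkpoint."""
    return Path(filepath).suffix == ".ckpt"


print(classify_weights_name("hf-hub:org/resnet18@v1.0"))  # hf-hub
print(classify_weights_name("ns-hub:team/resnet18"))      # ns-hub
print(classify_weights_name("/data/weights/model.pth"))   # local
print(is_lightning_checkpoint("model.ckpt"))              # True
print(is_lightning_checkpoint("model.pth"))               # False
```

Only the prefix decides the source here; whether nidl performs any further validation of name is not documented on this page.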

HF_URL = 'https://huggingface.co'
NS_URL = 'http://nsap.intra.cea.fr/neurospin-hub/'
__init__(name, data_dir, filepath)
classmethod hf_download(data_dir, hf_id, filepath, hf_revision=None, force_download=False)

Download a given file if not already present.

Downloads always resume when possible. If you want to force a new download, use force_download=True.

Parameters:
data_dir: pathlib.Path or str

path where data should be downloaded.

hf_id: str

the ID of the repository on the Hugging Face Hub.

filepath: str

the path of the file in the repo.

hf_revision: str, default=None

the revision of the repository (a tag, or a commit hash).

force_download: bool, default=False

whether the file should be downloaded even if it already exists in the local cache.

Returns:
weight_file: Path

local path to the model weights.
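The "download only if not already present" behaviour of hf_download can be approximated as follows. This sketch assumes the public Hugging Face Hub file URL scheme (HF_URL/repo_id/resolve/revision/filename) and a flat data_dir/filepath cache layout. Both helpers (hf_resolve_url, needs_download) are hypothetical, nidl may delegate downloading and caching to the huggingface_hub library instead, and nothing here performs a network request.

```python
import tempfile
from pathlib import Path
from typing import Optional

# Same value as Weights.HF_URL.
HF_URL = "https://huggingface.co"


def hf_resolve_url(hf_id: str, filepath: str, hf_revision: Optional[str] = None) -> str:
    """Return the Hub URL serving `filepath` from repository `hf_id`."""
    # The Hub serves raw files under /<repo_id>/resolve/<revision>/<path>;
    # fall back to the default branch "main" when no revision is given.
    revision = hf_revision or "main"
    return f"{HF_URL}/{hf_id}/resolve/{revision}/{filepath}"


def needs_download(data_dir, filepath: str, force_download: bool = False) -> bool:
    """Hypothetical cache check: fetch only if the file is missing or forced."""
    # Assumes a flat layout <data_dir>/<filepath>, for illustration only.
    local_file = Path(data_dir) / filepath
    return force_download or not local_file.exists()


with tempfile.TemporaryDirectory() as tmp:
    print(hf_resolve_url("org/resnet18", "weights.pth", "v1.0"))
    print(needs_download(tmp, "weights.pth"))                       # True: not cached yet
    (Path(tmp) / "weights.pth").touch()
    print(needs_download(tmp, "weights.pth"))                       # False: already present
    print(needs_download(tmp, "weights.pth", force_download=True))  # True: forced
```

The same skip-unless-forced decision applies to ns_download; only the remote location differs.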

classmethod hub_split(hub_name)

Interpret the input hub name specified in the form hf-hub:path/architecture_name@revision or ns-hub:path/architecture_name.

Parameters:
hub_name: str

the hub name to interpret, including its hf-hub: or ns-hub: prefix.

Returns:
hub_id: str

the id of the repository.

hub_revision: str

the revision of the repository.
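The hub name format interpreted by hub_split can be illustrated with a small parser. This is a sketch, not the nidl source: the function name split_hub_name is hypothetical, returning None when no @revision is given is an assumption (this page only types hub_revision as str), and splitting on the last '@' is a design choice.

```python
from typing import Optional, Tuple

# The two documented hub prefixes.
HUB_PREFIXES = ("hf-hub:", "ns-hub:")


def split_hub_name(hub_name: str) -> Tuple[str, Optional[str]]:
    """Split 'hf-hub:path/name@revision' or 'ns-hub:path/name' into (id, revision)."""
    # Strip the hub prefix; any other form is rejected.
    for prefix in HUB_PREFIXES:
        if hub_name.startswith(prefix):
            rest = hub_name[len(prefix):]
            break
    else:
        raise ValueError(f"expected one of {HUB_PREFIXES}, got {hub_name!r}")
    # Text after the last '@' is the revision; no '@' means no revision (assumption).
    hub_id, at, revision = rest.rpartition("@")
    if not at:
        return rest, None
    return hub_id, revision


print(split_hub_name("hf-hub:org/resnet18@v1.0"))  # ('org/resnet18', 'v1.0')
print(split_hub_name("ns-hub:team/resnet18"))      # ('team/resnet18', None)
```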

load_checkpoint(model, *, map_location='cpu', **kwargs)

Load the checkpoint.

Parameters:
model: LightningModule

a pytorch_lightning LightningModule class.

map_location: Union[str, dict, torch.device], default='cpu'

Device mapping used when loading the checkpoint.

**kwargs: Any

Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values, in particular to override trainer parameters such as accelerator or devices.

Returns:
pl.LightningModule or None

The instantiated LightningModule loaded from the checkpoint. Returns None if the provided file is not recognized as a Lightning checkpoint.

load_pretrained(model, weights_only=True)

Load the model weights.

Parameters:
model: torch.nn.Module

an input model that declares a load_pretrained method.

weights_only: bool, default=True

Whether the unpickler should be restricted to loading only tensors, primitive types, dictionaries, and any types added via torch.serialization.add_safe_globals.

classmethod ns_download(data_dir, ns_id, filepath, force_download=False)

Download a given file if not already present.

Downloads always resume when possible. If you want to force a new download, use force_download=True.

Parameters:
data_dir: pathlib.Path or str

path where data should be downloaded.

ns_id: str

the ID of the repository on the NeuroSpin hub.

filepath: str

the path of the file in the repo.

force_download: bool, default=False

whether the file should be downloaded even if it already exists in the local cache.

Returns:
weight_file: Path

local path to the model weights.