Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.utils.weights.Weights
- class nidl.utils.weights.Weights(name, data_dir, filepath)
Bases: object
A class to handle (retrieve and apply) model weights or PyTorch Lightning checkpoints.
- Parameters:
- name: str
the location of the model weights, specified as hf-hub:path/architecture_name@revision if available on the Hugging Face Hub, as ns-hub:path/architecture_name if available on the NeuroSpin hub, or as a local path if available on your machine.
- data_dir: pathlib.Path or str
path where data should be downloaded.
- filepath: str
the path of the file in the repo. If the path has a '.ckpt' extension, the file is assumed to be a pytorch_lightning checkpoint.
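The two constructor conventions above (a hub prefix or a local path in name, and the '.ckpt' suffix on filepath) can be illustrated with a small helper. This is a minimal sketch, not the class's actual code: the function classify_weights and the repository names in the usage line are hypothetical.

```python
from pathlib import Path
from typing import Tuple


def classify_weights(name: str, filepath: str) -> Tuple[str, bool]:
    """Hypothetical helper: guess where the weights live and whether they are a checkpoint.

    Returns a (source, is_checkpoint) pair where source is "hf-hub",
    "ns-hub" or "local".
    """
    if name.startswith("hf-hub:"):
        source = "hf-hub"
    elif name.startswith("ns-hub:"):
        source = "ns-hub"
    else:
        # Anything without a hub prefix is treated as a local path.
        source = "local"
    # A '.ckpt' suffix marks a pytorch_lightning checkpoint.
    is_checkpoint = Path(filepath).suffix == ".ckpt"
    return source, is_checkpoint


print(classify_weights("hf-hub:org/simclr@v1", "model.ckpt"))  # ('hf-hub', True)
print(classify_weights("/data/weights", "encoder.pth"))  # ('local', False)
```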
- HF_URL = 'https://huggingface.co'
- NS_URL = 'http://nsap.intra.cea.fr/neurospin-hub/'
- classmethod hf_download(data_dir, hf_id, filepath, hf_revision=None, force_download=False)
Download a given file if not already present.
Downloads always resume when possible. If you want to force a new download, use force_download=True.
- Parameters:
- data_dir: pathlib.Path or str
path where data should be downloaded.
- hf_id: str
the id of the repository.
- filepath: str
the path of the file in the repo.
- hf_revision: str, default=None
the revision of the repository (a tag, or a commit hash).
- force_download: bool, default=False
whether the file should be downloaded even if it already exists in the local cache.
- Returns:
- weight_file: Path
local path to the model weights.
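For orientation, files on the Hugging Face Hub are served from a "resolve" URL built from the repository id, the revision and the file path. The sketch below assumes hf_download targets that standard layout (the real method may delegate to the huggingface_hub library instead), and assumes "main" as the fallback revision when hf_revision is None; hf_resolve_url is a hypothetical name.

```python
from typing import Optional

HF_URL = "https://huggingface.co"


def hf_resolve_url(hf_id: str, filepath: str, hf_revision: Optional[str] = None) -> str:
    """Hypothetical helper: build the Hub download URL for one file of a repository."""
    # Assumption: with no explicit revision, fall back to the default "main" branch.
    revision = hf_revision or "main"
    return f"{HF_URL}/{hf_id}/resolve/{revision}/{filepath}"


print(hf_resolve_url("org/simclr", "model.ckpt"))
```

Passing a tag or commit hash as hf_revision pins the download to that exact version of the file.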
- classmethod hub_split(hub_name)
Parse a hub name given in the form hf-hub:path/architecture_name@revision or ns-hub:path/architecture_name.
- Parameters:
- hub_name: str
name of the repository.
- Returns:
- hub_id: str
the id of the repository.
- hub_revision: str
the revision of the repository.
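The documented name format can be split with plain string partitioning. This is a minimal sketch of the parsing rule, not the method itself: split_hub_name is hypothetical, and returning None when no "@revision" suffix is present is an assumption (the real method's behavior for ns-hub names is not documented here).

```python
from typing import Optional, Tuple


def split_hub_name(hub_name: str) -> Tuple[str, Optional[str]]:
    """Hypothetical helper: split 'prefix:path/architecture_name[@revision]'."""
    prefix, sep, rest = hub_name.partition(":")
    if not sep or prefix not in ("hf-hub", "ns-hub"):
        raise ValueError(f"not a hub name: {hub_name!r}")
    # An optional "@revision" suffix follows the repository id.
    hub_id, at, revision = rest.partition("@")
    # Assumption: a missing revision is reported as None.
    return hub_id, (revision if at else None)


print(split_hub_name("hf-hub:org/simclr@v1"))  # ('org/simclr', 'v1')
```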
- load_checkpoint(model, **kwargs)
Load the checkpoint.
- Parameters:
- model: LightningModule
a pytorch_lightning module class.
- device: torch.device or str
the device on which to load the model and to use for inference. Defaults to cpu. Only a single device is currently supported.
- **kwargs: any extra keyword arguments needed to initialize the model. They can also be used to override saved hyperparameter values, in particular trainer parameters such as accelerator or devices.
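The description of **kwargs matches the behavior of PyTorch Lightning's LightningModule.load_from_checkpoint classmethod, which forwards extra keyword arguments to the module's __init__ and accepts a map_location for the target device. The sketch below assumes load_checkpoint delegates to that classmethod; load_checkpoint_sketch is a hypothetical name, and lightning is not imported so the snippet runs without it.

```python
from typing import Any


def load_checkpoint_sketch(model: Any, checkpoint_path: str, device: str = "cpu", **kwargs: Any) -> Any:
    """Hypothetical helper: restore a Lightning module class from a checkpoint file.

    `model` is expected to be a LightningModule subclass (the class, not an
    instance): load_from_checkpoint is a classmethod that builds a new module.
    """
    # map_location moves the stored tensors onto the requested device;
    # extra kwargs override hyperparameters saved in the checkpoint.
    return model.load_from_checkpoint(checkpoint_path, map_location=device, **kwargs)
```

With a real Lightning module, a call would look like load_checkpoint_sketch(MyModule, "last.ckpt", device="cuda:0"), where MyModule stands for any user-defined LightningModule subclass.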
- load_pretrained(model)
Load the model weights.
- Parameters:
- model: torch.nn.Module
an input model that declares a load_pretrained method.
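Since the model must declare its own load_pretrained method, applying weights amounts to delegating to that method. A minimal sketch follows: apply_pretrained is hypothetical, and passing the local weight file path as the single argument is an assumption about the method's expected input.

```python
from pathlib import Path
from typing import Any, Union


def apply_pretrained(model: Any, weight_file: Union[str, Path]) -> Any:
    """Hypothetical helper: delegate weight loading to the model's own method."""
    loader = getattr(model, "load_pretrained", None)
    if not callable(loader):
        raise TypeError(f"{type(model).__name__} does not declare a load_pretrained method")
    # Assumption: the method accepts the local path of the downloaded weights.
    loader(weight_file)
    return model
```

Checking for the method up front gives a clear error instead of an AttributeError deep inside the loading code.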
- classmethod ns_download(data_dir, ns_id, filepath, force_download=False)
Download a given file if not already present.
Downloads always resume when possible. If you want to force a new download, use force_download=True.
- Parameters:
- data_dir: pathlib.Path or str
path where data should be downloaded.
- ns_id: str
the id of the repository.
- filepath: str
the path of the file in the repo.
- force_download: bool, default=False
whether the file should be downloaded even if it already exists in the local cache.
- Returns:
- weight_file: Path
local path to the model weights.
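The "download only if not already present" rule shared by both download methods can be sketched with an injected fetch function. Everything here beyond that rule is an assumption: the on-disk layout <data_dir>/<ns_id>/<filepath>, the URL NS_URL + "<ns_id>/<filepath>", and the name cached_download are hypothetical, and resuming partial downloads is not modeled.

```python
from pathlib import Path
from typing import Callable, Union

NS_URL = "http://nsap.intra.cea.fr/neurospin-hub/"


def cached_download(
    data_dir: Union[str, Path],
    ns_id: str,
    filepath: str,
    fetch: Callable[[str, Path], None],
    force_download: bool = False,
) -> Path:
    """Hypothetical helper: fetch a file unless a cached copy already exists."""
    # Hypothetical on-disk layout: <data_dir>/<ns_id>/<filepath>.
    weight_file = Path(data_dir) / ns_id / filepath
    if weight_file.exists() and not force_download:
        return weight_file  # cache hit: nothing to download
    weight_file.parent.mkdir(parents=True, exist_ok=True)
    # Hypothetical URL layout: NS_URL + "<ns_id>/<filepath>".
    fetch(f"{NS_URL}{ns_id}/{filepath}", weight_file)
    return weight_file
```

Injecting fetch keeps the caching decision separate from the transport, so the rule can be exercised without network access.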
Examples using nidl.utils.weights.Weights
Self-Supervised Contrastive Learning with SimCLR on MNIST