Source code for nidl.experiment
##########################################################################
# NSAp - Copyright (C) CEA, 2025
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
import collections
import copy
import importlib
import inspect
import itertools
import os
import warnings
from pprint import pprint
from typing import Optional
import toml
from .utils import Bunch, print_multicolor
SECTIONS = ("project", "global", "import",
"scaler", "transform", "compose", "augmentation", "dataset",
"dataloader", "model", "weights", "loss", "optimizer", "probe",
"scheduler",
"training",
"environments")
def fetch_experiment(
expfile: str,
selector: Optional[tuple[str]] = None,
cv: Optional[tuple[str]] = None,
logdir: Optional[str] = None,
verbose: int = 0):
""" Fetch an experiement from an input configuration file.
Allowed keys are:
- project: define here some usefull information about your experiment such
as the 'name', 'author', 'date'...
- global: define here global variables that can be reused in other
sections.
- import: define here import that can be reused in other sections with the
'auto' mode (see desciption below).
- scaler: dl interface
- transform: dl interfaces
- compose: dl interfaces
- augmentation: dl interface
- dataset: dl interface
- dataloader: dl interface
- model: dl interface
- weights: dl interface
- loss: dl interface
- optimizer: dl interface
- scheduler: dl interface
- probe: dl interface
- training: define here training settings.
    - environments: define here the interfaces to load in order to fulfill
      your needs and the constraints imposed by the 'interface_occurrences'
      parameter (see description below).
Interface definition:
- the 'interface' key contains a name that specifies what class to import
in absolute terms.
    - the 'interface_version' key contains the expected version of the loaded
      interface. The '__version__' module parameter is checked if available,
      and a warning is displayed if a mismatch is detected or the version
      cannot be checked.
    - other keys are the interface parameters.
    - dynamic parameters can be defined by specifying where the parameter can
      be found in a previously loaded interface, i.e., 'auto|<name>.<param>'.
      Note that the order is important here.
    - cross validation can be defined by specifying a list of values, i.e.
      'cv|[1, 2, 3]'. This will automatically instantiate multiple interfaces,
      one for each input setting.
    The code works as follows:
- multiple interfaces of the same type can be returned.
- the different choices must be described in an 'environments' section.
    - the output name will be suffixed with the environment name.
    - use the 'selector' input parameter to filter the available interfaces in
      the input configuration file.
How to define multiple building blocks:
- the construction is hierarchic, i.e. child building blocks inherit
the properties of the parent building block.
    - a child building block name contains the parent name as a prefix and
      uses the '.' separator.
The weights section special case:
    - model names specified in the form `hf-hub:path/architecture_name@revision`
      will be loaded from the Hugging Face Hub.
    - model names specified with a path will be loaded from the local machine.
Parameters
----------
expfile: str
the experimental design file.
selector: tuple of str, default=None
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.
    cv: tuple of str, default=None
        if a cross validation scheme is defined, this parameter allows you to
        select only some interfaces (i.e., the best set of hyperparameters).
    logdir: str, default=None
        allows you to save some information about the loaded interfaces: for
        the moment, only the source code of each interface.
verbose: int, default=0
enable verbose output.
Returns
-------
data: Bunch
        dictionary-like object containing the experiment building blocks.
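
    Examples
    --------
    A minimal configuration sketch: the interface paths and parameters below
    (e.g. 'torchvision.datasets.MNIST') are illustrative only and assume the
    corresponding packages are installed::

        [global]
        root = "/tmp/data"

        [dataset]
        interface = "torchvision.datasets.MNIST"
        root = "auto|root"
        download = true

        [dataloader]
        interface = "torch.utils.data.DataLoader"
        dataset = "auto|dataset"
        batch_size = "cv|[16, 32]"

    Assuming this content is saved in an 'experiment.toml' file, the building
    blocks could then be retrieved with::

        exp = fetch_experiment("experiment.toml", verbose=1)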
"""
if logdir is not None:
assert os.path.isdir(logdir), "Please create the log directory!"
config = toml.load(expfile, _dict=collections.OrderedDict)
for key in config:
assert key in SECTIONS, f"Unexpected section '{key}'!"
settings = {key: config.pop(key) if key in config else None
for key in ["project", "import", "global", "environments"]}
selector = selector or []
for key in selector:
assert key in settings["environments"], (
f"Unexpected environment '{key}'!")
config_env = get_env(settings["global"], settings["import"])
config = filter_config(config, settings["environments"], selector)
if verbose > 0:
print(f"[{print_multicolor('Configuration', display=False)}]")
pprint(config)
interfaces = {}
cv_interfaces = [name.split("_")[0] for name in cv or []]
for key, params in config.items():
name = params.pop("interface") if "interface" in params else None
if name is None:
raise ValueError(f"No interface defined for '{key}'!")
version = (params.pop("interface_version")
if "interface_version" in params else None)
if "interface_occurrences" in params:
params.pop("interface_occurrences")
params, param_sets = update_params(interfaces, params, config_env)
is_cv = (len(params) > 1)
for _idx, _params in enumerate(params):
if verbose > 0:
print(f"\n[{print_multicolor('Loading', display=False)}] "
f"{name}..."
f"\nParameters\n{'-'*10}")
pprint(dict(_params))
_key = f"{key}_{_idx}" if is_cv else key
if (not is_cv or cv is None or key not in cv_interfaces or
_key in cv):
interfaces[_key], code = load_interface(name, _params, version)
if verbose > 0:
print(f"Interface\n{'-'*9}\n{interfaces[_key]}")
if code is not None and logdir is not None:
logfile = os.path.join(logdir, name)
with open(logfile, "w") as of:
of.write(code)
if is_cv:
names = [f"{key}_{_idx}" for _idx in range(len(params))]
_params = dict(zip(names, param_sets))
interfaces.setdefault("grid", Bunch())[key] = load_interface(
"nidl.utils.Bunch", _params, None)[0]
return Bunch(**interfaces)
def get_env(
env: dict,
modules: dict) -> dict:
""" Dynamically update an environement.
Parameters
----------
env: dict
        the environment to update.
modules: dict
        some modules to add to the current environment.
Returns
-------
updated_env: dict
        the updated environment with the input modules imported.
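
    Examples
    --------
    A small sketch of the dynamic resolution; the variable and module names
    below are arbitrary:

    >>> env = get_env({"ndim": 3, "shape": "auto|(ndim, ndim)"},
    ...               {"odict": "collections.OrderedDict"})
    >>> env["shape"]
    (3, 3)
    >>> env["odict"]
    <class 'collections.OrderedDict'>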
"""
updated_env = copy.copy(env or {})
if modules is not None:
for key, name in modules.items():
if "." in name:
module_name, object_name = name.rsplit(".", 1)
else:
module_name, object_name = name, None
mod = importlib.import_module(module_name)
if object_name is not None:
updated_env[key] = getattr(mod, object_name)
else:
updated_env[key] = mod
for key, val in updated_env.items():
if isinstance(val, str) and val.startswith("auto|"):
attr = val.split("|")[-1]
try:
updated_env[key] = eval(attr, globals(), updated_env)
except Exception as exc:
print(f"\n[{print_multicolor('Help', display=False)}]..."
f"\nEnvironment\n{'-'*11}")
pprint(updated_env)
raise ValueError(
f"Can't find the '{attr}' dynamic global argument. Please "
"check for a typo in your configuration file.") from exc
return updated_env
def filter_config(
config: dict,
env: dict,
selector: tuple[str]) -> dict:
""" Filter configuration based on declared environements and user selector.
Parameters
----------
config: dict
the current configuration.
env: dict
        the declared environments.
selector: tuple of str
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.
Returns
-------
filter_conf: dict
the filtered configuration.
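
    Examples
    --------
    A small sketch with two declared 'model' alternatives; the interface
    paths ('my.SmallNet', 'my.LargeNet') are hypothetical:

    >>> from collections import OrderedDict
    >>> config = {"model": OrderedDict(
    ...     lr=0.01,
    ...     small=OrderedDict(interface="my.SmallNet"),
    ...     large=OrderedDict(interface="my.LargeNet"))}
    >>> env = {"cpu": {"model": "small"}}
    >>> filtered = filter_config(config, env, selector=("cpu", ))
    >>> list(filtered)
    ['model_small']
    >>> filtered["model_small"]["lr"]
    0.01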
"""
selected_env = {}
for env_name in selector:
for key, val in env[env_name].items():
if isinstance(val, list):
selected_env.setdefault(key, []).extend(val)
else:
selected_env.setdefault(key, []).append(val)
filter_config = collections.OrderedDict()
for section, params in config.items():
if selected_env.get(section) == ["none"]:
continue
shared_params, multi_params = {}, []
for name in params:
if isinstance(params[name], collections.OrderedDict):
assert section in selected_env, (
f"Multi-interface '{section}' environments not defined "
"properly!")
if name in selected_env[section]:
multi_params.append((name, params[name]))
else:
shared_params[name] = params[name]
n_envs = (1 if len(multi_params) == 0 else len(multi_params))
multi_envs = (len(multi_params) > 0)
if ("interface_occurrences" in params
and params["interface_occurrences"] != n_envs):
raise ValueError(
f"The maximum occurence of the '{section}' interface is not "
f"respected: {params['interface_occurrences']} vs. {n_envs}. "
"Please update the loaded environments accordingly.")
if multi_envs:
for name, _params in multi_params:
_params.update(shared_params)
if params.get("interface_occurrences") == 1:
filter_config[section] = _params
else:
filter_config[f"{section}_{name}"] = _params
else:
filter_config[section] = shared_params
return filter_config
def update_params(
interfaces: dict,
params: dict,
        env: dict) -> tuple:
""" Replace auto and cv parameters.
Parameters
----------
interfaces: dict
the currently loaded interfaces.
params: dict
the interface parameters.
env: dict
the local environment.
Returns
-------
updated_params: list of dict
the interface parameters with the auto attributes replaced in place. In
case of cross validation a list of parameters is returned.
param_sets: list of dict
the cross validation parameter sets. None means no cross validation.
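
    Examples
    --------
    A small sketch of the cross validation expansion; the parameter names are
    arbitrary:

    >>> params, param_sets = update_params(
    ...     interfaces={}, params={"lr": "cv|[0.01, 0.1]", "epochs": 5},
    ...     env={})
    >>> param_sets
    [{'lr': 0.01}, {'lr': 0.1}]
    >>> [item["lr"] for item in params]
    [0.01, 0.1]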
"""
env.update(globals())
grid_search_params = {}
for key, val in params.items():
if isinstance(val, str) and val.startswith(("auto|", "cv|")):
attr = val.split("|")[-1]
try:
params[key] = eval(attr, interfaces, env)
except Exception as exc:
interfaces.pop("__builtins__")
print(f"\n[{print_multicolor('Help', display=False)}]..."
f"\nEnvironment\n{'-'*11}")
pprint(env)
print(f"\nInterfaces\n{'-'*10}")
pprint(interfaces)
raise ValueError(
f"Can't find the '{attr}' dynamic argument. Please check "
"for a typo in your configuration file.") from exc
interfaces.pop("__builtins__")
if isinstance(val, str) and val.startswith("cv|"):
grid_search_params[key] = params[key]
if len(grid_search_params) > 0:
keys = grid_search_params.keys()
param_sets = [
dict(zip(keys, values))
for values in itertools.product(*grid_search_params.values())]
_params = []
for cv_params in param_sets:
_params.append(copy.deepcopy(params))
_params[-1].update(cv_params)
params = _params
else:
param_sets = None
params = [params]
return params, param_sets
def load_interface(
name: str,
params: dict,
version: Optional[str]):
""" Load an interface.
Parameters
----------
name: str
the interface name argument that specifies what class to
import in absolute terms, i.e. 'my_module.my_class'.
params: dict
the interface parameters.
    version: str, default=None
        the expected module version.
Returns
-------
    instance: object
        an instance of the loaded interface class.
    code: str
        the source code of the instantiated class, or None if it cannot be
        retrieved.
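
    Examples
    --------
    A small sketch using a standard-library class as a stand-in interface:

    >>> instance, code = load_interface("argparse.Namespace", {"lr": 0.01},
    ...                                 None)
    >>> instance
    Namespace(lr=0.01)
    >>> code.startswith("class Namespace")
    True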
"""
module_name, class_name = name.rsplit(".", 1)
root_module_name = module_name.split(".")[0]
root_mod = importlib.import_module(root_module_name)
if version is not None:
mod_version = getattr(root_mod, "__version__", None)
if mod_version is None:
warnings.warn(
f"The '{module_name}' module has no '__version__' parameter!",
ImportWarning, stacklevel=2)
elif mod_version != version:
warnings.warn(
f"The '{name}' interface has a different version!",
ImportWarning, stacklevel=2)
mod = importlib.import_module(module_name)
cls = getattr(mod, class_name)
assert inspect.isclass(cls), "An interface MUST be defined as a class!"
try:
code = inspect.getsource(cls)
except Exception:
warnings.warn(
f"Impossible to retrieve the '{name}' source code!",
ImportWarning, stacklevel=2)
code = None
return cls(**params), code