Menu

Deep learning for NeuroImaging in Python.

Source code for nidl.experiment

##########################################################################
# NSAp - Copyright (C) CEA, 2025
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

import collections
import copy
import importlib
import inspect
import itertools
import os
import warnings
from pprint import pprint
from typing import Optional

import toml

from .utils import Bunch, print_multicolor

# Names of the top-level sections accepted in an experiment configuration
# file: `fetch_experiment` rejects any section that is not listed here.
SECTIONS = ("project", "global", "import",
            "scaler", "transform", "compose", "augmentation", "dataset",
            "dataloader", "model", "weights", "loss", "optimizer", "probe",
            "scheduler",
            "training",
            "environments")


def fetch_experiment(
        expfile: str,
        selector: Optional[tuple[str]] = None,
        cv: Optional[tuple[str]] = None,
        logdir: Optional[str] = None,
        verbose: int = 0):
    """ Fetch an experiment from an input configuration file.

    Allowed keys are:

    - project: define here some useful information about your experiment
      such as the 'name', 'author', 'date'...
    - global: define here global variables that can be reused in other
      sections.
    - import: define here imports that can be reused in other sections with
      the 'auto' mode (see description below).
    - scaler: dl interface
    - transform: dl interfaces
    - compose: dl interfaces
    - augmentation: dl interface
    - dataset: dl interface
    - dataloader: dl interface
    - model: dl interface
    - weights: dl interface
    - loss: dl interface
    - optimizer: dl interface
    - scheduler: dl interface
    - probe: dl interface
    - training: define here training settings.
    - environments: define here the interfaces to load in order to fulfill
      your needs and the constraint imposed by the 'interface_occurrences'
      parameter (see description below).

    Interface definition:

    - the 'interface' key contains a name that specifies what class to
      import in absolute terms.
    - the 'interface_version' key contains the expected version of the
      loaded interface. The '__version__' module parameter is checked if
      available and a warning is emitted if a mismatch is detected or the
      version cannot be checked.
    - other keys are the interface parameters.
    - dynamic parameters can be defined by specifying where this parameter
      can be found in a previously loaded interface, i.e.,
      'auto|<name>.<param>'. Note that the order is important here.
    - cross validation can be defined by specifying a list of values, i.e.
      'cv|[1, 2, 3]'. This will automatically instantiate multiple
      interfaces, one for each input setting.

    The code works as follows:

    - multiple interfaces of the same type can be returned.
    - the different choices must be described in an 'environments' section.
    - the output name will be suffixed by the selected building block name,
      i.e. '<section>_<name>'.
    - use the selector input parameter to filter the available interfaces
      in the input configuration file.

    How to define multiple building blocks:

    - the construction is hierarchic, i.e. child building blocks inherit
      the properties of the parent building block.
    - a child building block name contains the parent name as a prefix and
      uses the '.' separator.

    The weights section special case:

    - model names specified in the form
      `hf-hub:path/architecture_name@revision` will be loaded from
      Hugging Face hub.
    - model names specified with a path will be loaded from the local
      machine.

    Parameters
    ----------
    expfile: str
        the experimental design file.
    selector: tuple of str, default=None
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.
    cv: tuple of str, default=None
        if a cross validation scheme is defined, this parameter allows you
        to select only some interfaces (i.e., the best set of
        hyperparameters). Each item is an output name of the form
        '<key>_<index>'.
    logdir: str, default=None
        allows you to save a couple of information about the loaded
        interfaces: for the moment only the source code of each interface.
    verbose: int, default=0
        enable verbose output.

    Returns
    -------
    data: Bunch
        dictionary-like object containing the experiment building blocks.
        When cross validation is used, the 'grid' entry holds the parameter
        sets associated to each generated interface.

    Raises
    ------
    ValueError
        if a section does not define the 'interface' key.
    """
    if logdir is not None:
        assert os.path.isdir(logdir), "Please create the log directory!"
    config = toml.load(expfile, _dict=collections.OrderedDict)
    for key in config:
        assert key in SECTIONS, f"Unexpected section '{key}'!"
    settings = {
        key: config.pop(key) if key in config else None
        for key in ["project", "import", "global", "environments"]}
    selector = selector or []
    # A missing 'environments' section must lead to the explicit assertion
    # message below, not to a TypeError on `key in None`.
    declared_envs = settings["environments"] or {}
    for key in selector:
        assert key in declared_envs, (
            f"Unexpected environment '{key}'!")
    config_env = get_env(settings["global"], settings["import"])
    config = filter_config(config, declared_envs, selector)
    if verbose > 0:
        print(f"[{print_multicolor('Configuration', display=False)}]")
        pprint(config)

    interfaces = {}
    # Output names are '<key>_<index>' and the key itself may contain
    # underscores (e.g. '<section>_<name>'): strip only the trailing index.
    cv_interfaces = [name.rsplit("_", 1)[0] for name in cv or []]
    for key, params in config.items():
        name = params.pop("interface", None)
        if name is None:
            raise ValueError(f"No interface defined for '{key}'!")
        version = params.pop("interface_version", None)
        # Only used by `filter_config`: not an interface parameter.
        params.pop("interface_occurrences", None)
        params, param_sets = update_params(interfaces, params, config_env)
        is_cv = (len(params) > 1)
        for _idx, _params in enumerate(params):
            if verbose > 0:
                print(f"\n[{print_multicolor('Loading', display=False)}] "
                      f"{name}..."
                      f"\nParameters\n{'-'*10}")
                pprint(dict(_params))
            _key = f"{key}_{_idx}" if is_cv else key
            # Skip the cross validation variants that were not requested.
            if (not is_cv or cv is None or key not in cv_interfaces or
                    _key in cv):
                interfaces[_key], code = load_interface(
                    name, _params, version)
                if verbose > 0:
                    print(f"Interface\n{'-'*9}\n{interfaces[_key]}")
                if code is not None and logdir is not None:
                    logfile = os.path.join(logdir, name)
                    with open(logfile, "w", encoding="utf-8") as of:
                        of.write(code)
        if is_cv:
            names = [f"{key}_{_idx}" for _idx in range(len(params))]
            _params = dict(zip(names, param_sets))
            interfaces.setdefault("grid", Bunch())[key] = load_interface(
                "nidl.utils.Bunch", _params, None)[0]

    return Bunch(**interfaces)
def get_env(
        env: dict,
        modules: dict) -> dict:
    """ Dynamically update an environment.

    The input modules/objects are imported and stored in a shallow copy of
    the input environment, then every string value starting with 'auto|'
    is replaced by the evaluation of the expression following the prefix.
    Values are evaluated in the environment order, so an expression can
    only rely on entries that are already resolved.

    Parameters
    ----------
    env: dict
        an environment to update. The input dictionary is not modified.
    modules: dict
        some modules to add in the current environment: the keys are the
        names to declare and the values are the absolute names to import,
        i.e. 'my_module' or 'my_module.my_object'.

    Returns
    -------
    updated_env: dict
        the updated environment with the input modules imported.

    Raises
    ------
    ValueError
        if a dynamic 'auto|' expression cannot be evaluated.
    """
    updated_env = copy.copy(env or {})
    if modules is not None:
        for key, name in modules.items():
            if "." in name:
                module_name, object_name = name.rsplit(".", 1)
            else:
                module_name, object_name = name, None
            mod = importlib.import_module(module_name)
            if object_name is not None:
                updated_env[key] = getattr(mod, object_name)
            else:
                updated_env[key] = mod
    for key, val in updated_env.items():
        if isinstance(val, str) and val.startswith("auto|"):
            # Split only on the first '|' so that the expression itself can
            # contain this character (e.g. 'auto|a | b').
            attr = val.split("|", 1)[1]
            try:
                # NOTE(security): the expression comes from the
                # configuration file and is evaluated as Python code, only
                # load configuration files you trust.
                updated_env[key] = eval(attr, globals(), updated_env)
            except Exception as exc:
                print(f"\n[{print_multicolor('Help', display=False)}]..."
                      f"\nEnvironment\n{'-'*11}")
                pprint(updated_env)
                raise ValueError(
                    f"Can't find the '{attr}' dynamic global argument. Please "
                    "check for a typo in your configuration file.") from exc
    return updated_env
def filter_config(
        config: dict,
        env: dict,
        selector: tuple[str]) -> dict:
    """ Keep only the building blocks requested by the selected
    environments.

    In a section, every parameter stored as an ordered dictionary is
    considered as a candidate building block, and the remaining parameters
    are shared by all the building blocks of the section. A section whose
    selection is exactly ['none'] is discarded.

    Parameters
    ----------
    config: dict
        the current configuration. The selected building blocks are
        updated in place with the shared parameters of their section.
    env: dict
        the declared environments.
    selector: tuple of str
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.

    Returns
    -------
    filter_conf: dict
        the filtered configuration. A section with selected building blocks
        is returned either under its own name (when 'interface_occurrences'
        is 1) or under one '<section>_<name>' entry per building block.

    Raises
    ------
    ValueError
        if the number of selected building blocks does not match the
        'interface_occurrences' parameter of the section.
    """
    # Gather, section by section, the building block names requested by
    # the selected environments.
    requested = {}
    for env_name in selector:
        for section, choice in env[env_name].items():
            bucket = requested.setdefault(section, [])
            if isinstance(choice, list):
                bucket.extend(choice)
            else:
                bucket.append(choice)

    filtered = collections.OrderedDict()
    for section, params in config.items():
        if requested.get(section) == ["none"]:
            continue

        # Split the section between shared parameters and selected blocks.
        common, variants = {}, []
        for name, value in params.items():
            if not isinstance(value, collections.OrderedDict):
                common[name] = value
                continue
            assert section in requested, (
                f"Multi-interface '{section}' environments not defined "
                "properly!")
            if name in requested[section]:
                variants.append((name, value))

        n_envs = len(variants) or 1
        if ("interface_occurrences" in params and
                params["interface_occurrences"] != n_envs):
            raise ValueError(
                f"The maximum occurence of the '{section}' interface is not "
                f"respected: {params['interface_occurrences']} vs. {n_envs}. "
                "Please update the loaded environments accordingly.")

        if not variants:
            filtered[section] = common
            continue
        unique = (params.get("interface_occurrences") == 1)
        for name, variant in variants:
            variant.update(common)
            out_name = section if unique else f"{section}_{name}"
            filtered[out_name] = variant
    return filtered
def update_params(
        interfaces: dict,
        params: dict,
        env: dict) -> tuple[list[dict], Optional[list[dict]]]:
    """ Replace auto and cv parameters.

    Every string parameter starting with 'auto|' or 'cv|' is replaced by
    the evaluation of the expression following the prefix. Names are
    searched first in the local environment, then in the module globals,
    and finally in the currently loaded interfaces. The 'cv|' parameters
    must evaluate to a sequence of values: one set of parameters is
    generated for each combination of these values.

    Parameters
    ----------
    interfaces: dict
        the currently loaded interfaces.
    params: dict
        the interface parameters. The dynamic parameters are replaced in
        place by their evaluation.
    env: dict
        the local environment. The input dictionary is not modified.

    Returns
    -------
    updated_params: list of dict
        the interface parameters with the auto attributes replaced. In case
        of cross validation one item per parameter set is returned,
        otherwise the list contains a single item.
    param_sets: list of dict
        the cross validation parameter sets. None means no cross
        validation.

    Raises
    ------
    ValueError
        if a dynamic expression cannot be evaluated.
    """
    # Work on a separate lookup scope: the caller environment must not be
    # polluted with the module globals, and user-defined entries take
    # precedence over them as in `get_env`.
    scope = {**globals(), **env}
    grid_search_params = {}
    for key, val in params.items():
        if not (isinstance(val, str) and val.startswith(("auto|", "cv|"))):
            continue
        # Split only on the first '|' so that the expression itself can
        # contain this character (e.g. 'auto|a | b').
        attr = val.split("|", 1)[1]
        try:
            # NOTE(security): the expression comes from the configuration
            # file and is evaluated as Python code, only load configuration
            # files you trust.
            params[key] = eval(attr, interfaces, scope)
        except Exception as exc:
            # `eval` registers '__builtins__' in its globals: remove it to
            # keep only real interfaces in the report below.
            interfaces.pop("__builtins__", None)
            print(f"\n[{print_multicolor('Help', display=False)}]..."
                  f"\nEnvironment\n{'-'*11}")
            pprint(env)
            print(f"\nInterfaces\n{'-'*10}")
            pprint(interfaces)
            raise ValueError(
                f"Can't find the '{attr}' dynamic argument. Please check "
                "for a typo in your configuration file.") from exc
        interfaces.pop("__builtins__", None)
        if val.startswith("cv|"):
            grid_search_params[key] = params[key]

    if len(grid_search_params) > 0:
        # One parameter set for each combination of the cv values.
        keys = grid_search_params.keys()
        param_sets = [
            dict(zip(keys, values))
            for values in itertools.product(*grid_search_params.values())]
        _params = []
        for cv_params in param_sets:
            _params.append(copy.deepcopy(params))
            _params[-1].update(cv_params)
        params = _params
    else:
        param_sets = None
        params = [params]
    return params, param_sets
def load_interface(
        name: str,
        params: dict,
        version: Optional[str]):
    """ Load an interface.

    Parameters
    ----------
    name: str
        the interface name argument that specifies what class to import in
        absolute terms, i.e. 'my_module.my_class'.
    params: dict
        the interface parameters, passed as keyword arguments to the class
        constructor.
    version: str, default None
        the expected module version. It is compared with the '__version__'
        attribute of the root module and a warning is emitted if this
        attribute is missing or different.

    Returns
    -------
    cls: object
        an instance of the requested class built from the input
        parameters.
    code: str
        the source code of the requested class, None in case of issue.

    Raises
    ------
    ValueError
        if the interface name is not a dotted 'module.class' path.
    """
    module_name, sep, class_name = name.rpartition(".")
    if not sep:
        # Without this check the unpacking fails with an error that does
        # not mention the faulty interface.
        raise ValueError(
            f"Invalid interface name '{name}': a class specified in "
            "absolute terms is expected, i.e. 'my_module.my_class'.")
    root_module_name = module_name.split(".")[0]
    root_mod = importlib.import_module(root_module_name)
    if version is not None:
        mod_version = getattr(root_mod, "__version__", None)
        if mod_version is None:
            warnings.warn(
                f"The '{module_name}' module has no '__version__' parameter!",
                ImportWarning, stacklevel=2)
        elif mod_version != version:
            warnings.warn(
                f"The '{name}' interface has a different version!",
                ImportWarning, stacklevel=2)
    mod = importlib.import_module(module_name)
    cls = getattr(mod, class_name)
    assert inspect.isclass(cls), "An interface MUST be defined as a class!"
    try:
        code = inspect.getsource(cls)
    except Exception:
        # Best effort: the source code is only used for logging purposes.
        warnings.warn(
            f"Impossible to retrieve the '{name}' source code!",
            ImportWarning, stacklevel=2)
        code = None
    return cls(**params), code

Follow us

© 2025, nidl developers