Menu

Deep learning for NeuroImaging in Python.

Source code for nidl.experiment

##########################################################################
# NSAp - Copyright (C) CEA, 2025
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

import collections
import copy
import importlib
import inspect
import itertools
import os
import warnings
from pprint import pprint
from typing import Optional

import toml
from toml.decoder import InlineTableDict

from .utils import Bunch, print_multicolor

# Top-level section names accepted in an experiment configuration file;
# `fetch_experiment` rejects any other section name.
SECTIONS = (
    "project",
    "global",
    "import",
    "scaler",
    "transform",
    "compose",
    "augmentation",
    "dataset",
    "dataloader",
    "model",
    "weights",
    "loss",
    "optimizer",
    "probe",
    "scheduler",
    "training",
    "environments",
)


def fetch_experiment(
        expfile: str,
        selector: Optional[tuple[str]] = None,
        cv: Optional[tuple[str]] = None,
        logdir: Optional[str] = None,
        verbose: int = 0):
    """ Fetch an experiment from an input configuration file.

    Allowed keys are:

    - project: define here some useful information about your experiment
      such as the 'name', 'author', 'date'...
    - global: define here global variables that can be reused in other
      sections.
    - import: define here imports that can be reused in other sections
      with the 'auto' mode (see description below).
    - scaler: dl interface
    - transform: dl interfaces
    - compose: dl interfaces
    - augmentation: dl interface
    - dataset: dl interface
    - dataloader: dl interface
    - model: dl interface
    - weights: dl interface
    - loss: dl interface
    - optimizer: dl interface
    - scheduler: dl interface
    - probe: dl interface
    - training: define here training settings.
    - environments: define here the interfaces to load in order to fulfill
      your needs and the constraint imposed by the 'interface_occurrences'
      parameter (see description below).

    Interface definition:

    - the 'interface' key contains a name that specifies what class to
      import in absolute terms.
    - the 'interface_version' key contains the expected version of the
      loaded interface. The '__version__' parameter of the root module is
      checked if available and a warning is displayed if a mismatch is
      detected or the version cannot be checked.
    - other keys are the interface parameters.
    - dynamic parameters can be defined by specifying where this parameter
      can be found in a previously loaded interface, i.e.,
      'auto|<name>.<param>'. Note that the order is important here.
    - cross validation can be defined by specifying a list of values, i.e.
      'cv|[1, 2, 3]'. This will automatically instantiate multiple
      interfaces, one for each input setting.

    The code works as follows:

    - multiple interfaces of the same type can be returned.
    - the different choices must be described in an 'environments' section.
    - the output name is '<section>_<block name>' (unless the section sets
      'interface_occurrences = 1', in which case it is '<section>').
    - use the selector input parameter to filter the available interfaces
      in the input configuration file.

    How to define multiple building blocks:

    - the construction is hierarchic, i.e. child building blocks inherit
      the properties of the parent building block.
    - a child building block name contains the parent name as a prefix and
      uses the '.' separator.

    The weights section special case:

    - model names specified in the form
      `hf-hub:path/architecture_name@revision` will be loaded from Hugging
      Face hub.
    - model names specified with a path will be loaded from the local
      machine.

    Parameters
    ----------
    expfile: str
        the experimental design file.
    selector: tuple of str, default=None
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.
    cv: tuple of str, default=None
        if a cross validation scheme is defined, this parameter allows you
        to select only some interfaces (i.e., the best set of
        hyperparameters), using the '<name>_<index>' output names.
    logdir: str, default=None
        allows you to save a couple of information about the loaded
        interfaces: for the moment only the source code of each interface.
    verbose: int, default=0
        enable verbose output.

    Returns
    -------
    data: Bunch
        dictionary-like object containing the experiment building blocks.
        Cross-validated interfaces are stored as '<name>_<index>' and their
        parameter sets are gathered under the 'grid' key.
    """
    if logdir is not None:
        assert os.path.isdir(logdir), "Please create the log directory!"
    config = toml.load(expfile, _dict=collections.OrderedDict)
    for key in config:
        assert key in SECTIONS, f"Unexpected section '{key}'!"
    settings = {
        key: config.pop(key) if key in config else None
        for key in ["project", "import", "global", "environments"]}
    selector = selector or []
    # Treat a missing [environments] section as empty so that an unknown
    # selector triggers the explicit assertion instead of a TypeError.
    declared_envs = settings["environments"] or {}
    for key in selector:
        assert key in declared_envs, (
            f"Unexpected environment '{key}'!")
    config_env = get_env(settings["global"], settings["import"])
    config = filter_config(config, settings["environments"], selector)
    if verbose > 0:
        print(f"[{print_multicolor('Configuration', display=False)}]")
        pprint(config)

    interfaces = {}
    # Cross-validated interfaces are named '<key>_<index>' where <key> may
    # itself contain underscores (e.g. '<section>_<block>'): strip only the
    # trailing index to recover the interface key.
    cv_interfaces = [name.rsplit("_", 1)[0] for name in cv or []]
    for key, params in config.items():
        name = params.pop("interface", None)
        if name is None:
            raise ValueError(f"No interface defined for '{key}'!")
        version = params.pop("interface_version", None)
        # Only used by filter_config to validate the environments.
        params.pop("interface_occurrences", None)
        params, param_sets = update_params(interfaces, params, config_env)
        is_cv = (len(params) > 1)
        for idx, _params in enumerate(params):
            if verbose > 0:
                print(f"\n[{print_multicolor('Loading', display=False)}] "
                      f"{name}..."
                      f"\nParameters\n{'-'*10}")
                pprint(dict(_params))
            _key = f"{key}_{idx}" if is_cv else key
            # Skip cross-validated instances not requested by the user.
            if (not is_cv or cv is None or key not in cv_interfaces
                    or _key in cv):
                interfaces[_key], code = load_interface(
                    name, _params, version)
                if verbose > 0:
                    print(f"Interface\n{'-'*9}\n{interfaces[_key]}")
                if code is not None and logdir is not None:
                    logfile = os.path.join(logdir, name)
                    with open(logfile, "w", encoding="utf-8") as of:
                        of.write(code)
        if is_cv:
            names = [f"{key}_{idx}" for idx in range(len(params))]
            _params = dict(zip(names, param_sets))
            interfaces.setdefault("grid", Bunch())[key] = load_interface(
                "nidl.utils.Bunch", _params, None)[0]
    return Bunch(**interfaces)
def get_env(
        env: dict,
        modules: dict) -> dict:
    """ Dynamically update an environment.

    Parameters
    ----------
    env: dict
        an environment to update (None is treated as empty). Values given
        as 'auto|<expression>' strings are replaced by the evaluation of
        the expression, which may reference previously defined entries.
    modules: dict
        some modules to add in the current environment. A value without a
        dot is imported as a module; a dotted value 'pkg.mod.obj' imports
        'pkg.mod' and binds its 'obj' attribute.

    Returns
    -------
    updated_env: dict
        a shallow copy of the input environment with the input modules
        imported and the 'auto|' expressions evaluated.

    Raises
    ------
    ValueError
        if an 'auto|' expression cannot be evaluated.
    """
    updated_env = copy.copy(env or {})
    if modules is not None:
        for key, name in modules.items():
            if "." in name:
                module_name, object_name = name.rsplit(".", 1)
            else:
                module_name, object_name = name, None
            mod = importlib.import_module(module_name)
            updated_env[key] = (
                mod if object_name is None else getattr(mod, object_name))
    for key, val in updated_env.items():
        if isinstance(val, str) and val.startswith("auto|"):
            # Split on the first '|' only so that the expression itself may
            # use the bitwise-or / set-union operator.
            attr = val.split("|", 1)[1]
            # NOTE(review): eval() executes arbitrary code taken from the
            # configuration file; only load trusted experiment files.
            try:
                updated_env[key] = eval(attr, globals(), updated_env)
            except Exception as exc:
                print(f"\n[{print_multicolor('Help', display=False)}]..."
                      f"\nEnvironment\n{'-'*11}")
                pprint(updated_env)
                raise ValueError(
                    f"Can't find the '{attr}' dynamic global argument. Please "
                    "check for a typo in your configuration file.") from exc
    return updated_env
def filter_config(
        config: dict,
        env: dict,
        selector: tuple[str]) -> dict:
    """ Filter configuration based on declared environments and user selector.

    Parameters
    ----------
    config: dict
        the current configuration, mapping each section name to its
        parameters. Non-inline sub-tables of a section define alternative
        (child) building blocks; every other value is a shared (parent)
        parameter.
    env: dict
        the declared environments: each environment maps section names to
        one block name or a list of block names.
    selector: tuple of str
        if multiple interfaces of the same type are defined, this parameter
        allows you to select the appropriate environments.

    Returns
    -------
    filter_conf: dict
        the filtered configuration. A section selected with the single
        value 'none' is dropped. When child blocks are selected, each one is
        returned merged with the section's shared parameters (the child's
        own values take precedence) under the '<section>_<block>' name, or
        under '<section>' when 'interface_occurrences' equals 1. Otherwise
        the section's shared parameters are returned as is.

    Raises
    ------
    ValueError
        if the number of selected blocks does not match the section's
        'interface_occurrences' value.
    """
    # Gather, for each section, the block names requested by the selected
    # environments.
    selected_env = {}
    for env_name in selector:
        for section, names in env[env_name].items():
            block_names = selected_env.setdefault(section, [])
            if isinstance(names, list):
                block_names.extend(names)
            else:
                block_names.append(names)

    filtered_conf = collections.OrderedDict()
    for section, params in config.items():
        if selected_env.get(section) == ["none"]:
            continue
        shared_params, multi_params = {}, []
        for name, value in params.items():
            # Plain TOML sub-tables describe child blocks; inline tables are
            # ordinary parameter values.
            is_child_block = (
                isinstance(value, collections.OrderedDict)
                and not isinstance(value, InlineTableDict))
            if is_child_block:
                assert section in selected_env, (
                    f"Multi-interface '{section}' environments not defined "
                    "properly!")
                if name in selected_env[section]:
                    multi_params.append((name, value))
            else:
                shared_params[name] = value
        n_envs = len(multi_params) or 1
        if ("interface_occurrences" in params
                and params["interface_occurrences"] != n_envs):
            raise ValueError(
                f"The maximum occurence of the '{section}' interface is not "
                f"respected: {params['interface_occurrences']} vs. {n_envs}. "
                "Please update the loaded environments accordingly.")
        if not multi_params:
            filtered_conf[section] = shared_params
            continue
        for name, child_params in multi_params:
            # Children inherit the parent parameters but may override them;
            # build a new mapping so the parsed configuration is not mutated.
            merged = collections.OrderedDict(shared_params)
            merged.update(child_params)
            if params.get("interface_occurrences") == 1:
                filtered_conf[section] = merged
            else:
                filtered_conf[f"{section}_{name}"] = merged
    return filtered_conf
def update_params(
        interfaces: dict,
        params: dict,
        env: dict) -> tuple[list[dict], Optional[list[dict]]]:
    """ Replace auto and cv parameters.

    Parameters
    ----------
    interfaces: dict
        the currently loaded interfaces, available by name in the
        'auto|' / 'cv|' expressions. Not modified.
    params: dict
        the interface parameters. String values starting with 'auto|' or
        'cv|' are replaced in place by the evaluation of the expression that
        follows the first '|'.
    env: dict
        the local environment, available in the expressions (it takes
        precedence over the module globals). Not modified.

    Returns
    -------
    updated_params: list of dict
        the interface parameters with the dynamic values evaluated: a
        single-element list when there is no cross validation, otherwise one
        deep-copied parameter dict per combination of the 'cv|' values.
    param_sets: list of dict
        the cross validation parameter sets (Cartesian product of the 'cv|'
        values). None means no cross validation.

    Raises
    ------
    ValueError
        if a dynamic expression cannot be evaluated.
    """
    # Evaluation scope: module globals overlaid by the user environment.
    scope = {**globals(), **(env or {})}
    # eval() inserts '__builtins__' into the globals mapping it receives:
    # evaluate against a copy so the caller's interfaces stay untouched.
    namespace = dict(interfaces)
    grid_search_params = {}
    for key, val in params.items():
        if not (isinstance(val, str) and val.startswith(("auto|", "cv|"))):
            continue
        # Split on the first '|' only so that the expression itself may use
        # the bitwise-or / set-union operator.
        attr = val.split("|", 1)[1]
        # NOTE(review): eval() executes arbitrary code taken from the
        # configuration file; only load trusted experiment files.
        try:
            params[key] = eval(attr, namespace, scope)
        except Exception as exc:
            print(f"\n[{print_multicolor('Help', display=False)}]..."
                  f"\nEnvironment\n{'-'*11}")
            pprint(env)
            print(f"\nInterfaces\n{'-'*10}")
            pprint(interfaces)
            raise ValueError(
                f"Can't find the '{attr}' dynamic argument. Please check "
                "for a typo in your configuration file.") from exc
        if val.startswith("cv|"):
            grid_search_params[key] = params[key]

    if not grid_search_params:
        return [params], None
    keys = list(grid_search_params)
    param_sets = [
        dict(zip(keys, values))
        for values in itertools.product(*grid_search_params.values())]
    updated_params = []
    for cv_params in param_sets:
        candidate = copy.deepcopy(params)
        candidate.update(cv_params)
        updated_params.append(candidate)
    return updated_params, param_sets
def load_interface(
        name: str,
        params: dict,
        version: Optional[str]):
    """ Load an interface.

    Parameters
    ----------
    name: str
        the interface name argument that specifies what class to import in
        absolute terms, i.e. 'my_module.my_class'.
    params: dict
        the interface parameters, passed as keyword arguments to the class.
    version: str, default None
        the expected version of the root module (first component of
        `name`), compared to its '__version__' attribute. A warning is
        emitted if the attribute is missing or differs. None disables the
        check.

    Returns
    -------
    instance: object
        the class instantiated with `params`.
    code: str
        the source code of the class, None in case of issue.
    """
    module_name, class_name = name.rsplit(".", 1)
    root_module_name = module_name.split(".")[0]
    if version is not None:
        root_mod = importlib.import_module(root_module_name)
        mod_version = getattr(root_mod, "__version__", None)
        if mod_version is None:
            # The attribute is looked up on the root module: name it.
            warnings.warn(
                f"The '{root_module_name}' module has no '__version__' "
                "parameter!", ImportWarning, stacklevel=2)
        elif mod_version != version:
            warnings.warn(
                f"The '{name}' interface has a different version: expected "
                f"'{version}', found '{mod_version}'!",
                ImportWarning, stacklevel=2)
    mod = importlib.import_module(module_name)
    cls = getattr(mod, class_name)
    assert inspect.isclass(cls), "An interface MUST be defined as a class!"
    try:
        code = inspect.getsource(cls)
    except Exception:
        # Best effort only: e.g. built-in or dynamically created classes
        # have no retrievable source.
        warnings.warn(
            f"Impossible to retrieve the '{name}' source code!",
            ImportWarning, stacklevel=2)
        code = None
    return cls(**params), code

Follow us

© 2025, nidl developers