Source code for surfify.models.vae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Cortical Spherical Variational Auto-Encoder (SVAE) models.

[1] Representation Learning of Resting State fMRI with Variational
Autoencoder: https://github.com/libilab/rsfMRI-VAE
"""

# Imports
import torch
import torch.nn as nn
from torch.distributions import Normal
from numpy import sqrt
from ..utils import get_logger, debug_msg
from ..nn import IcoUpConv, IcoPool, IcoSpMaConv, IcoSpMaConvTranspose
from .base import SphericalBase


# Global parameters
logger = get_logger()


class SphericalVAE(nn.Module):
    """ Spherical VAE architecture.

    Use either RePa - Rectangular Patch convolution method or DiNe -
    Direct Neighbor convolution method.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalGVAE

    References
    ----------
    Representation Learning of Resting State fMRI with Variational
    Autoencoder, NeuroImage 2021.

    Examples
    --------
    >>> import torch
    >>> from surfify.utils import icosahedron
    >>> from surfify.models import SphericalVAE
    >>> verts, tris = icosahedron(order=6)
    >>> x = torch.zeros((1, 2, len(verts)))
    >>> model = SphericalVAE(
    >>>     input_channels=2, input_order=6, latent_dim=64,
    >>>     conv_flts=[32, 32, 64, 64], conv_mode="DiNe", dine_size=1,
    >>>     fusion_level=2, standard_ico=False)
    >>> print(model)
    >>> out = model(x, x)
    >>> print(out[0].shape, out[1].shape)
    """
    def __init__(self, input_channels=1, input_order=5, input_dim=192,
                 latent_dim=64, conv_flts=(32, 32, 64, 64), fusion_level=1,
                 activation="LeakyReLU", batch_norm=False, conv_mode="DiNe",
                 cachedir=None, encoder=None, decoder=None, *args, **kwargs):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of input channels.
        input_order: int, default 5
            the input icosahedron order.
        input_dim: int, default 192
            for the 'SpMa' convolution method, the size of the converted
            3-D surface to the 2-D grid.
        latent_dim: int, default 64
            the size of the stochastic latent state of the SVAE.
        conv_flts: list of int
            the size of convolutional filters.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or
            'DiNe' - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, ie. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative
            factor of `log(order + 1) + 1`.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data are
            concatenated.
        standard_ico: bool, default False
            optionally use the surfify tessellation.
        cachedir: str, default None
            set this folder to use smart caching speedup.
        encoder: nn.Module, default None
            optionally use a custom encoder.
        decoder: nn.Module, default None
            optionally use a custom decoder.
        """
        logger.debug("SphericalVAE init...")
        super().__init__()
        assert conv_mode in ["DiNe", "RePa", "SpMa"]
        use_grid = conv_mode == "SpMa"
        if use_grid and encoder is None:
            encoder = HemiFusionEncoder(
                input_channels, input_dim, latent_dim * 2, conv_flts,
                fusion_level, activation, batch_norm, *args, **kwargs)
        elif encoder is None:
            encoder = SphericalHemiFusionEncoder(
                input_channels, input_order, latent_dim * 2, conv_flts,
                fusion_level, activation, batch_norm, conv_mode,
                *args, **kwargs, cachedir=cachedir)
        if use_grid and decoder is None:
            decoder = HemiFusionDecoder(
                [input_channels, input_dim, input_dim],
                encoder.flatten_dim, latent_dim, conv_flts, fusion_level,
                activation, batch_norm, *args, **kwargs)
        elif decoder is None:
            decoder = SphericalHemiFusionDecoder(
                input_channels, input_order, latent_dim, conv_flts,
                fusion_level, activation, batch_norm, conv_mode,
                *args, **kwargs, cachedir=cachedir)
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, left_x, right_x):
        """ The encoder.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        q(z | x): Normal (batch_size, <latent_dim>)
            a Normal distribution.
        """
        x = self.encoder((left_x, right_x))
        z_mu, z_logvar = torch.chunk(x, chunks=2, dim=1)
        return Normal(loc=z_mu, scale=(z_logvar * 0.5).exp())

    def decode(self, z):
        """ The decoder.

        Parameters
        ----------
        z: Tensor (samples, <latent_dim>)
            the stochastic latent state z.

        Returns
        -------
        left_recon_x: Tensor (samples, <input_channels>, n_vertices)
            reconstructed left cortical texture.
        right_recon_x: Tensor (samples, <input_channels>, n_vertices)
            reconstructed right cortical texture.
        """
        left_recon_x, right_recon_x = self.decoder(z)
        return left_recon_x, right_recon_x

    def reparameterize(self, q, sample=True):
        """ Implement the reparametrization trick.
        """
        if sample:
            return q.rsample()
        return q.loc

    def forward(self, left_x, right_x, sample=True):
        """ The forward method.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        left_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed left cortical texture.
        right_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed right cortical texture.
        """
        logger.debug("SphericalVAE forward pass")
        logger.debug(debug_msg("left cortical", left_x))
        logger.debug(debug_msg("right cortical", right_x))
        q = self.encode(left_x, right_x)
        logger.debug(debug_msg("posterior loc", q.loc))
        logger.debug(debug_msg("posterior scale", q.scale))
        z = self.reparameterize(q, sample)
        logger.debug(debug_msg("z", z))
        left_recon_x, right_recon_x = self.decode(z)
        logger.debug(debug_msg("left recon cortical", left_recon_x))
        logger.debug(debug_msg("right recon cortical", right_recon_x))
        return left_recon_x, right_recon_x, {"q": q, "z": z}
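
# Usage sketch (illustrative, not part of the original module): a minimal
# VAE training step built on top of the forward pass above. The MSE + KL
# objective is the standard ELBO with a unit Gaussian prior, an assumption
# here rather than a loss shipped by surfify.
#
#   model = SphericalVAE(input_channels=2, input_order=6, latent_dim=64,
#                        conv_flts=[32, 32, 64, 64], fusion_level=2)
#   left_recon, right_recon, extra = model(left_x, right_x, sample=True)
#   recon = (nn.functional.mse_loss(left_recon, left_x) +
#            nn.functional.mse_loss(right_recon, right_x))
#   prior = Normal(torch.zeros_like(extra["q"].loc),
#                  torch.ones_like(extra["q"].scale))
#   kl = torch.distributions.kl_divergence(extra["q"], prior).mean()
#   loss = recon + kl
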
class SphericalHemiFusionEncoder(SphericalBase):
    def __init__(self, input_channels, input_order, latent_dim,
                 conv_flts=(64, 128, 128, 256, 256), fusion_level=1,
                 activation="LeakyReLU", batch_norm=False, conv_mode="DiNe",
                 dine_size=1, repa_size=5, repa_zoom=5,
                 dynamic_repa_zoom=False, standard_ico=False, cachedir=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of input channels.
        input_order: int, default 5
            the input icosahedron order.
        latent_dim: int, default 64
            the size of the latent space it encodes to.
        conv_flts: list of int
            the size of convolutional filters.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data are
            concatenated.
        activation: str, default 'LeakyReLU'
            activation function's class name in PyTorch's nn module to use
            after each convolution.
        batch_norm: bool, default False
            optionally use batch normalization after each convolution.
        """
        logger.debug("SphericalHemiFusionEncoder init...")
        super().__init__(
            input_order=input_order, n_layers=len(conv_flts),
            conv_mode=conv_mode, dine_size=dine_size, repa_size=repa_size,
            repa_zoom=repa_zoom, dynamic_repa_zoom=dynamic_repa_zoom,
            standard_ico=standard_ico, cachedir=cachedir)
        self.input_channels = input_channels
        self.conv_flts = list(conv_flts)
        self.activation = getattr(nn, activation)(inplace=True)
        self.n_vertices_down = len(
            self.ico[self.input_order - self.n_layers].vertices)
        logger.debug("  number of vertices small ico: {}".format(
            self.n_vertices_down))
        self.flatten_dim = conv_flts[-1] * self.n_vertices_down
        logger.debug("  dimension for linear: {}".format(self.flatten_dim))
        if fusion_level > self.n_layers or fusion_level <= 0:
            raise ValueError("Impossible to use input fusion level with "
                             "'{0}' layers.".format(self.n_layers))
        self.fusion_level = fusion_level
        self.latent_dim = latent_dim
        self.left_conv = nn.Sequential()
        self.right_conv = nn.Sequential()
        self.w_conv = nn.Sequential()
        input_channels = self.input_channels
        for idx in range(self.n_layers):
            order = self.input_order - idx
            output_channels = self.conv_flts[idx]
            pooling = IcoPool(
                down_neigh_indices=self.ico[order].neighbor_indices,
                down_indices=self.ico[order].down_indices,
                pooling_type="mean")
            if idx < fusion_level:
                # Hemisphere-specific branches: each gets half the filters.
                output_channels = int(output_channels / 2)
                lconv = self.sconv(
                    input_channels, output_channels,
                    self.ico[order].conv_neighbor_indices)
                self.left_conv.add_module("l_conv_{0}".format(idx), lconv)
                if batch_norm:
                    self.left_conv.add_module(
                        "l_bn_{0}".format(idx),
                        nn.BatchNorm1d(output_channels))
                self.left_conv.add_module("pooling_{0}".format(idx), pooling)
                rconv = self.sconv(
                    input_channels, output_channels,
                    self.ico[order].conv_neighbor_indices)
                self.right_conv.add_module("r_conv_{0}".format(idx), rconv)
                if batch_norm:
                    self.right_conv.add_module(
                        "r_bn_{0}".format(idx),
                        nn.BatchNorm1d(output_channels))
                self.right_conv.add_module("pooling_{0}".format(idx),
                                           pooling)
                input_channels = output_channels
            else:
                # Fused branch operating on the concatenated hemispheres.
                input_channels = self.conv_flts[idx - 1]
                conv = self.sconv(
                    input_channels, output_channels,
                    self.ico[order].conv_neighbor_indices)
                self.w_conv.add_module("conv_{0}".format(idx), conv)
                if batch_norm:
                    self.w_conv.add_module(
                        "bn_{0}".format(idx),
                        nn.BatchNorm1d(self.conv_flts[idx]))
                self.w_conv.add_module("pooling_{0}".format(idx), pooling)
        self.w_dense = nn.Linear(self.flatten_dim, self.latent_dim)

    def forward(self, x):
        """ The encoding.

        Parameters
        ----------
        x: 2-uplet of Tensor (batch_size, <input_channels>, n_vertices)
            input left and right cortical textures.

        Returns
        -------
        x: Tensor (batch_size, <latent_dim>)
            the latent representations.
        """
        left_x, right_x = x
        logger.debug("SphericalHemiFusionEncoder forward pass")
        logger.debug(debug_msg("  left cortical", left_x))
        logger.debug(debug_msg("  right cortical", right_x))
        left_x = self._safe_forward(
            self.left_conv, left_x, act=self.activation, skip_last_act=True)
        right_x = self._safe_forward(
            self.right_conv, right_x, act=self.activation,
            skip_last_act=True)
        logger.debug(debug_msg("  left enc", left_x))
        logger.debug(debug_msg("  right enc", right_x))
        x = torch.cat((left_x, right_x), dim=1)
        x = self.activation(x)
        logger.debug(debug_msg("  merged enc", x))
        x = self._safe_forward(self.w_conv, x, act=self.activation)
        logger.debug(debug_msg("  final conv enc", x))
        x = x.view(-1, self.flatten_dim)
        logger.debug(debug_msg("  flattened", x))
        x = self.w_dense(x)
        return x
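
# Shape walkthrough (illustrative): each IcoPool stage drops the
# icosahedron order by one, so with input_order=6 and four filters the
# deepest feature map lives on an order-2 icosahedron. A standard
# icosahedron of order k has 10 * 4 ** k + 2 vertices, hence
# flatten_dim = conv_flts[-1] * 162 in this configuration.
#
#   encoder = SphericalHemiFusionEncoder(
#       input_channels=2, input_order=6, latent_dim=128,
#       conv_flts=(32, 32, 64, 64), fusion_level=2)
#   z = encoder((left_x, right_x))  # (batch_size, 128)
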
class SphericalHemiFusionDecoder(SphericalBase):
    def __init__(self, input_channels, input_order, latent_dim,
                 conv_flts=(64, 128, 128, 256, 256), fusion_level=1,
                 activation="LeakyReLU", batch_norm=False, conv_mode="DiNe",
                 dine_size=1, repa_size=5, repa_zoom=5,
                 dynamic_repa_zoom=False, standard_ico=False, cachedir=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of channels of the reconstructed textures.
        input_order: int, default 5
            the icosahedron order of the reconstructed textures.
        latent_dim: int, default 64
            the size of the latent space it decodes from.
        conv_flts: list of int
            the size of convolutional filters.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data are
            concatenated.
        activation: str, default 'LeakyReLU'
            activation function's class name in PyTorch's nn module to use
            after each convolution.
        batch_norm: bool, default False
            optionally use batch normalization after each convolution.
        """
        logger.debug("SphericalHemiFusionDecoder init...")
        super().__init__(
            input_order=input_order, n_layers=len(conv_flts),
            conv_mode=conv_mode, dine_size=dine_size, repa_size=repa_size,
            repa_zoom=repa_zoom, dynamic_repa_zoom=dynamic_repa_zoom,
            standard_ico=standard_ico, cachedir=cachedir)
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.conv_flts = list(conv_flts)
        self.activation = getattr(nn, activation)(inplace=True)
        self.n_vertices_down = len(
            self.ico[self.input_order - self.n_layers].vertices)
        logger.debug("  number of vertices small ico: {}".format(
            self.n_vertices_down))
        self.flatten_dim = conv_flts[-1] * self.n_vertices_down
        logger.debug("  dimension for linear: {}".format(self.flatten_dim))
        if fusion_level > self.n_layers or fusion_level <= 0:
            raise ValueError("Impossible to use input fusion level with "
                             "'{0}' layers.".format(self.n_layers))
        self.fusion_level = fusion_level
        self.w_dense = nn.Linear(self.latent_dim, self.flatten_dim)
        self.w_conv = nn.Sequential()
        self.left_conv = nn.Sequential()
        self.right_conv = nn.Sequential()
        input_channels = self.conv_flts[self.n_layers - 1]
        for idx in range(self.n_layers - 1, -1, -1):
            order = self.input_order - idx - 1
            input_channels = self.conv_flts[idx]
            output_channels = self.input_channels * 2
            if idx != 0:
                output_channels = self.conv_flts[idx - 1]
            if idx < fusion_level:
                # Hemisphere-specific branches: each gets half the channels.
                output_channels = int(output_channels / 2)
                input_channels = int(input_channels / 2)
                logger.debug("input channels: {}".format(input_channels))
                logger.debug("output channels: {}".format(output_channels))
                l_pooling = IcoUpConv(
                    in_feats=input_channels, out_feats=output_channels,
                    up_neigh_indices=self.ico[order + 1].neighbor_indices,
                    down_indices=self.ico[order + 1].down_indices)
                lconv = self.sconv(
                    output_channels, output_channels,
                    self.ico[order + 1].conv_neighbor_indices)
                self.left_conv.add_module(
                    "l_pooling_{0}".format(idx), l_pooling)
                self.left_conv.add_module("l_conv_{0}".format(idx), lconv)
                if batch_norm:
                    self.left_conv.add_module(
                        "l_bn_{0}".format(idx),
                        nn.BatchNorm1d(output_channels))
                r_pooling = IcoUpConv(
                    in_feats=input_channels, out_feats=output_channels,
                    up_neigh_indices=self.ico[order + 1].neighbor_indices,
                    down_indices=self.ico[order + 1].down_indices)
                rconv = self.sconv(
                    output_channels, output_channels,
                    self.ico[order + 1].conv_neighbor_indices)
                self.right_conv.add_module(
                    "r_pooling_{0}".format(idx), r_pooling)
                self.right_conv.add_module("r_conv_{0}".format(idx), rconv)
                if batch_norm:
                    self.right_conv.add_module(
                        "r_bn_{0}".format(idx),
                        nn.BatchNorm1d(output_channels))
            else:
                # Fused branch operating on the concatenated hemispheres.
                logger.debug("input channels: {}".format(input_channels))
                logger.debug("output channels: {}".format(output_channels))
                logger.debug("order: {}".format(order))
                pooling = IcoUpConv(
                    in_feats=input_channels, out_feats=output_channels,
                    up_neigh_indices=self.ico[order + 1].neighbor_indices,
                    down_indices=self.ico[order + 1].down_indices)
                conv = self.sconv(
                    output_channels, output_channels,
                    self.ico[order + 1].conv_neighbor_indices)
                self.w_conv.add_module("pooling_{0}".format(idx), pooling)
                self.w_conv.add_module("conv_{0}".format(idx), conv)
                if batch_norm:
                    # Normalize over the conv output channels (the original
                    # code used self.conv_flts[idx], which mismatches the
                    # conv output when consecutive filter sizes differ).
                    self.w_conv.add_module(
                        "bn_{0}".format(idx),
                        nn.BatchNorm1d(output_channels))

    def forward(self, x):
        """ The decoding.

        Parameters
        ----------
        x: Tensor (batch_size, <latent_dim>)
            the latent representations.

        Returns
        -------
        left_x: Tensor (batch_size, <input_channels>, n_vertices)
            reconstructed left cortical textures.
        right_x: Tensor (batch_size, <input_channels>, n_vertices)
            reconstructed right cortical textures.
        """
        logger.debug("SphericalHemiFusionDecoder forward pass")
        logger.debug(debug_msg("latent", x))
        x = self.activation(self.w_dense(x))
        x = x.view(-1, self.conv_flts[-1], self.n_vertices_down)
        logger.debug(debug_msg("input to conv", x))
        x = self._safe_forward(self.w_conv, x, act=self.activation)
        logger.debug(debug_msg("before hemi sep", x))
        left_x, right_x = torch.chunk(x, chunks=2, dim=1)
        logger.debug(debug_msg("after hemi sep right", right_x))
        logger.debug(debug_msg("after hemi sep left", left_x))
        left_x = self._safe_forward(
            self.left_conv, left_x, act=self.activation, skip_last_act=True)
        right_x = self._safe_forward(
            self.right_conv, right_x, act=self.activation,
            skip_last_act=True)
        return left_x, right_x
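
# Round-trip sketch (illustrative): a decoder built with the same
# hyper-parameters as the encoder above maps a latent code back to one
# texture per hemisphere on the original order-6 icosahedron
# (10 * 4 ** 6 + 2 = 40962 vertices).
#
#   decoder = SphericalHemiFusionDecoder(
#       input_channels=2, input_order=6, latent_dim=64,
#       conv_flts=(32, 32, 64, 64), fusion_level=2)
#   left_recon, right_recon = decoder(torch.randn(1, 64))
#   # each: (1, 2, 40962)
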
def compute_output_dim(input_dim, convnet):
    """ Compute the output dimension of a convolutional network that takes
    a square input (H = W).

    Parameters
    ----------
    input_dim: int
        input height and width.
    convnet: iterable[nn.Module]
        iterable containing the various layers. For now, the function only
        works with nn.Conv2d and IcoSpMaConv layers.

    Returns
    -------
    output_dim: int
        output dimension.
    """
    output_dim = input_dim
    for layer in convnet:
        if type(layer) is nn.Conv2d:
            output_dim = int(
                (output_dim + 2 * layer.padding[0] - layer.dilation[0] *
                 (layer.kernel_size[0] - 1) - 1) / layer.stride[0] + 1)
        elif type(layer) is IcoSpMaConv:
            # IcoSpMaConv pads explicitly, then applies its inner Conv2d.
            output_dim = compute_output_dim(
                output_dim + 2 * layer.pad, [layer.conv])
    return output_dim
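
# Worked example (illustrative): for a 192 x 192 input and the first
# encoder stage (kernel_size=8, stride=2, padding=3, dilation=1), the
# usual convolution arithmetic
# floor((H + 2 * p - d * (k - 1) - 1) / s + 1) gives
# (192 + 6 - 7 - 1) / 2 + 1 = 96.
#
#   conv = nn.Conv2d(1, 32, kernel_size=8, stride=2, padding=3)
#   assert compute_output_dim(192, [conv]) == 96
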
class HemiFusionEncoder(nn.Module):
    def __init__(self, input_channels, input_dim, latent_dim,
                 conv_flts=(64, 128, 128, 256, 256), fusion_level=1,
                 activation="LeakyReLU", batch_norm=False):
        """ Init class.

        Parameters
        ----------
        input_channels: int, default 1
            the number of input channels.
        input_dim: int, default 192
            the size of the converted 3-D surface to the 2-D grid.
        latent_dim: int, default 64
            the size of the latent space it encodes to.
        conv_flts: list of int
            the size of convolutional filters.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data are
            concatenated.
        activation: str, default 'LeakyReLU'
            activation function's class name in PyTorch's nn module to use
            after each convolution.
        batch_norm: bool, default False
            optionally use batch normalization after each convolution.
        """
        logger.debug("HemiFusionEncoder init...")
        super().__init__()
        self.input_channels = input_channels
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.conv_flts = list(conv_flts)
        self.n_layers = len(self.conv_flts)
        if fusion_level > self.n_layers or fusion_level <= 0:
            raise ValueError("Impossible to use input fusion level with "
                             "'{0}' layers.".format(self.n_layers))
        self.fusion_level = fusion_level
        self.left_conv = nn.Sequential()
        self.right_conv = nn.Sequential()
        self.w_conv = nn.Sequential()
        input_channels = self.input_channels
        for idx in range(self.n_layers):
            if idx == 0:
                kernel_size = 8
                pad = 3
            else:
                kernel_size = 4
                pad = 1
            output_channels = self.conv_flts[idx]
            if idx < fusion_level:
                # Hemisphere-specific branches: each gets half the filters.
                output_channels = int(output_channels / 2)
                lconv = IcoSpMaConv(
                    in_feats=input_channels, out_feats=output_channels,
                    kernel_size=kernel_size, stride=2, pad=pad)
                self.left_conv.add_module("l_conv_{0}".format(idx), lconv)
                if batch_norm:
                    self.left_conv.add_module(
                        "l_bn_{0}".format(idx),
                        nn.BatchNorm2d(output_channels))
                self.left_conv.add_module(
                    "l_act_{0}".format(idx),
                    getattr(nn, activation)(inplace=True))
                rconv = IcoSpMaConv(
                    in_feats=input_channels, out_feats=output_channels,
                    kernel_size=kernel_size, stride=2, pad=pad)
                self.right_conv.add_module("r_conv_{0}".format(idx), rconv)
                if batch_norm:
                    self.right_conv.add_module(
                        "r_bn_{0}".format(idx),
                        nn.BatchNorm2d(output_channels))
                self.right_conv.add_module(
                    "r_act_{0}".format(idx),
                    getattr(nn, activation)(inplace=True))
                input_channels = output_channels
            else:
                # Fused branch operating on the concatenated hemispheres.
                input_channels = self.conv_flts[idx - 1]
                conv = IcoSpMaConv(
                    in_feats=input_channels, out_feats=self.conv_flts[idx],
                    kernel_size=kernel_size, stride=2, pad=pad)
                self.w_conv.add_module("conv_{0}".format(idx), conv)
                if batch_norm:
                    self.w_conv.add_module(
                        "bn_{0}".format(idx),
                        nn.BatchNorm2d(self.conv_flts[idx]))
                self.w_conv.add_module(
                    "act_{0}".format(idx),
                    getattr(nn, activation)(inplace=True))
        self.output_dim = compute_output_dim(
            input_dim, [*self.left_conv, *self.w_conv])
        self.flatten_dim = self.output_dim ** 2 * self.conv_flts[-1]
        self.w_dense = nn.Linear(self.flatten_dim, self.latent_dim)

    def forward(self, x):
        """ The encoder.

        Parameters
        ----------
        x: 2-uplet of Tensor (samples, <input_channels>, azimuth, elevation)
            input left and right cortical textures.

        Returns
        -------
        z: Tensor (batch_size, <latent_dim>)
            the latent representation.
        """
        left_x, right_x = x
        left_x = self.left_conv(left_x)
        right_x = self.right_conv(right_x)
        x = torch.cat((left_x, right_x), dim=1)
        x = self.w_conv(x)
        x = x.view(-1, self.flatten_dim)
        z = self.w_dense(x)
        return z
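
# Shape walkthrough (illustrative, assuming the default filters): with
# input_dim=192, the strided IcoSpMaConv stages halve the grid at every
# level, 192 -> 96 -> 48 -> 24 -> 12 -> 6, so
# flatten_dim = 6 ** 2 * 256 = 9216.
#
#   encoder = HemiFusionEncoder(input_channels=1, input_dim=192,
#                               latent_dim=128, fusion_level=1)
#   z = encoder((left_x, right_x))  # (batch_size, 128)
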
class HemiFusionDecoder(nn.Module):
    def __init__(self, output_shape, before_latent_dim, latent_dim,
                 conv_flts=(64, 128, 128, 256, 256), fusion_level=1,
                 activation="LeakyReLU", batch_norm=False):
        """ Init class.

        Parameters
        ----------
        output_shape: list of int
            the (channels, height, width) shape of the reconstructed
            output.
        before_latent_dim: int
            the size of the dense layer output that is reshaped into a
            square feature map before the transposed convolutions.
        latent_dim: int, default 64
            the size of the latent space it decodes from.
        conv_flts: list of int
            the size of convolutional filters, given in reverse order: the
            first filter in the list will be the last one in the network.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data are
            concatenated.
        activation: str, default 'LeakyReLU'
            activation function's class name in PyTorch's nn module to use
            after each convolution.
        batch_norm: bool, default False
            optionally use batch normalization after each convolution.
        """
        logger.debug("HemiFusionDecoder init...")
        super().__init__()
        self.before_latent_dim = before_latent_dim
        self.conv_flts = list(conv_flts)
        self.n_layers = len(conv_flts)
        self.conv_flts.insert(0, output_shape[0] * 2)
        self.output_shape = output_shape
        # flatten_dim = input_dim ** 2 * conv_flts[-1]
        self.w_dense = nn.Linear(latent_dim, before_latent_dim)
        self.w_conv = nn.Sequential()
        self.left_conv = nn.Sequential()
        self.right_conv = nn.Sequential()
        self.fusion_level = fusion_level
        for idx in range(self.n_layers, 0, -1):
            kernel_size = 4
            pad = 1
            zero_pad = 3
            output_shape = None
            if idx == 1:
                kernel_size = 8
                pad = 1
                zero_pad = 5 - (self.output_shape[1] % 8 // 2)
                output_shape = self.output_shape
            input_channels = self.conv_flts[idx]
            output_channels = self.conv_flts[idx - 1]
            if idx < fusion_level + 1:
                # Hemisphere-specific branches: each gets half the channels.
                input_channels = int(input_channels / 2)
                output_channels = int(output_channels / 2)
                lconv = IcoSpMaConvTranspose(
                    in_feats=input_channels, out_feats=output_channels,
                    kernel_size=kernel_size, stride=2, pad=pad,
                    zero_pad=zero_pad, output_shape=output_shape)
                if batch_norm:
                    self.left_conv.add_module(
                        "l_bn_{0}".format(idx),
                        nn.BatchNorm2d(input_channels))
                self.left_conv.add_module("l_act_{0}".format(idx),
                                          getattr(nn, activation)())
                self.left_conv.add_module("l_conv_{0}".format(idx), lconv)
                rconv = IcoSpMaConvTranspose(
                    in_feats=input_channels, out_feats=output_channels,
                    kernel_size=kernel_size, stride=2, pad=pad,
                    zero_pad=zero_pad, output_shape=output_shape)
                if batch_norm:
                    self.right_conv.add_module(
                        "r_bn_{0}".format(idx),
                        nn.BatchNorm2d(input_channels))
                self.right_conv.add_module("r_act_{0}".format(idx),
                                           getattr(nn, activation)())
                self.right_conv.add_module("r_conv_{0}".format(idx), rconv)
            else:
                # Fused branch operating on the concatenated hemispheres.
                conv = IcoSpMaConvTranspose(
                    input_channels, output_channels,
                    kernel_size=kernel_size, stride=2, pad=pad,
                    zero_pad=zero_pad)
                if batch_norm and idx != self.n_layers:
                    # The batch norm precedes the transposed convolution,
                    # so it must match the incoming channels (the original
                    # code used output_channels here).
                    self.w_conv.add_module(
                        "bn_{0}".format(idx),
                        nn.BatchNorm2d(input_channels))
                self.w_conv.add_module(
                    "act_{0}".format(idx),
                    getattr(nn, activation)(inplace=True))
                self.w_conv.add_module("conv_{0}".format(idx), conv)

    def forward(self, z):
        """ The decoder.

        Parameters
        ----------
        z: Tensor (samples, <latent_dim>)
            the stochastic latent state z.

        Returns
        -------
        left_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed left cortical texture.
        right_recon_x: Tensor (samples, <input_channels>, azimuth, elevation)
            reconstructed right cortical texture.
        """
        x = self.w_dense(z)
        remaining_shape = int(sqrt(x.shape[1] / self.conv_flts[-1]))
        x = x.view(len(x), self.conv_flts[-1], remaining_shape,
                   remaining_shape)
        x = self.w_conv(x)
        left_recon_x, right_recon_x = torch.chunk(x, chunks=2, dim=1)
        left_recon_x = self.left_conv(left_recon_x)
        right_recon_x = self.right_conv(right_recon_x)
        return left_recon_x, right_recon_x
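
# Usage sketch (illustrative): in "SpMa" mode the SphericalVAE wires a
# HemiFusionEncoder / HemiFusionDecoder pair together, and each hemisphere
# is expected as a (batch_size, channels, input_dim, input_dim) grid.
#
#   model = SphericalVAE(input_channels=1, input_dim=192, latent_dim=64,
#                        conv_flts=[64, 128, 128, 256, 256],
#                        conv_mode="SpMa", fusion_level=1)
#   left_x = torch.zeros(1, 1, 192, 192)
#   left_recon, right_recon, extra = model(left_x, left_x)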
