Source code for yamle.models.unet

from typing import Union, List, Tuple, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import math

from yamle.models.operations import (
    DoubleConv2d,
    OutputActivation,
    Normalization,
    ReshapeOutput,
)
from yamle.models.model import BaseModel

from yamle.defaults import DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY, SEGMENTATION_KEY


[docs] class DownBlock(nn.Module): """This class is used to create a down block of the UNet model. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. normalization (str): The normalization to use. """ def __init__(self, in_channels: int, out_channels: int, normalization: str) -> None: super(DownBlock, self).__init__() self._conv = DoubleConv2d( in_channels, out_channels, normalization=normalization ) self._pool = nn.MaxPool2d(2)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """The forward function of the down block.""" x = self._conv(x) return self._pool(x), x
[docs] class UpBlock(nn.Module): """This class is used to create an up block of the UNet model. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. normalization (str): The normalization to use. """ def __init__(self, in_channels: int, out_channels: int, normalization: str) -> None: super(UpBlock, self).__init__() self._up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2) self._conv = DoubleConv2d( in_channels, out_channels, normalization=normalization ) def _center_crop(self, x: torch.Tensor, x_ref: torch.Tensor) -> torch.Tensor: """A helper function to center crop the input tensor, given a reference tensor.""" diffY = x_ref.size()[2] - x.size()[2] diffX = x_ref.size()[3] - x.size()[3] x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) return x
[docs] def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: """The forward function of the up block.""" x = self._up(x) x = self._center_crop(x, skip) x = torch.cat([x, skip], dim=1) return self._conv(x)
[docs] class UNetModel(BaseModel): """This class is used to create the UNet model. Args: init_features (int): The number of initial features. normalization (str): The type of normalization to use. Defaults to `batch`. Choices are `batch`, `layer`, `instance` or `None`. """ tasks = [DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY, SEGMENTATION_KEY] def __init__( self, init_features: int = 32, normalization: str = "batch", *args: Any, **kwargs: Any, ) -> None: super(UNetModel, self).__init__(*args, **kwargs) assert normalization in ["batch", "layer", "instance", None] self._normalization = normalization self._features = [ init_features, init_features * 2, init_features * 4, init_features * 8, init_features * 16, ] # Add one extra pointwise convolution to the UNet model such that # There is an explicit 1x1 convolution which can be manipulated by some method self._input = nn.Conv2d(self._inputs_dim[0], self._features[0], 1) self._down1 = DownBlock(self._features[0], self._features[0], normalization) self._down2 = DownBlock(self._features[0], self._features[1], normalization) self._down3 = DownBlock(self._features[1], self._features[2], normalization) self._down4 = DownBlock(self._features[2], self._features[3], normalization) self._center = DoubleConv2d( self._features[3], self._features[4], normalization=normalization ) self._up4 = UpBlock(self._features[4], self._features[3], normalization) self._up3 = UpBlock(self._features[3], self._features[2], normalization) self._up2 = UpBlock(self._features[2], self._features[1], normalization) self._up1 = UpBlock(self._features[1], self._features[0], normalization) self._output = nn.Conv2d(self._features[0], self._outputs_dim, 1) self._output_activation = OutputActivation(self._task, dim=1) self._depth = 4
[docs] def add_method_specific_layers(self, method: str, **kwargs: Any) -> None: """This method is used to add method specific layers to the model. Args: method (str): The method to use. """ super().add_method_specific_layers(method, **kwargs) if method in ["dun", "mimmo"]: self._reshaping_layers = nn.ModuleList() available_heads = [True] * (self._depth-1) if "heads" in kwargs and kwargs["heads"]: self._heads = nn.ModuleList() if "available_heads" in kwargs: available_heads = kwargs["available_heads"] for i, available_head in enumerate(available_heads): if not available_head: continue layers = [] layers.append( nn.Upsample( scale_factor=2**i, mode="bilinear", align_corners=True ) ) layers.append(nn.Conv2d(self._features[i], self._features[0], 1)) layers.append( Normalization( self._normalization, dimension=2, norm_kwargs={"num_features": self._features[0]}, ) ) layers.append(nn.ReLU()) self._reshaping_layers.append(nn.Sequential(*layers)) if "heads" in kwargs and kwargs["heads"]: head = [] head.append( nn.Conv2d(self._features[0], self._output[0].out_features, 1) ) head.append(ReshapeOutput(num_members=kwargs["num_members"])) self._heads.append(nn.Sequential(*head)) elif method in ["early_exit"]: gamma = kwargs["gamma"] hidden_feature_size_output = ( self._output.in_features if method == "early_exit" else self._output[0].in_features ) size_output = ( self._output.out_features if method == "early_exit" else self._output[0].out_features ) self._reshaping_layers = nn.ModuleList() heads = [1] * (self._depth) if "heads" in kwargs and kwargs["heads"] is not None: heads = kwargs["heads"] assert ( len(heads) == self._depth ), f"Number of heads should be {self._depth}, but got {len(heads)}" for i in range(1, self._depth): if not heads[i - 1]: continue layers = [] hidden_feature_size = int( math.sqrt(1 + gamma) ** (self._depth - i) * hidden_feature_size_output ) if gamma > 0: layers.append( nn.Upsample( scale_factor=2 ** (i - 1), mode="bilinear", align_corners=True, ) ) layers.append( nn.Conv2d(self._features[i - 1], hidden_feature_size, 1) ) layers.append( Normalization( self._normalization, dimension=2, norm_kwargs={"num_features": hidden_feature_size}, ) ) layers.append(nn.ReLU()) layers.append(nn.Conv2d(hidden_feature_size, size_output, 1)) else: layers.append( nn.Upsample( scale_factor=2 ** (i - 1), mode="bilinear", align_corners=True, ) ) layers.append(nn.Conv2d(self._features[i - 1], size_output, 1)) self._reshaping_layers.append(nn.Sequential(*layers)) else: raise ValueError(f"Method {method} is not supported.")
[docs] def forward( self, x: torch.Tensor, staged_output: bool = False, input_kwargs: Dict[str, Any] = {}, output_kwargs: Dict[str, Any] = {}, ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: """The forward function of the UNet model. Args: x (torch.Tensor): The input tensor. staged_output (bool): Whether to return the output of each layer. Defaults to `False`. input_kwargs (Dict[str, Any]): The input arguments to pass to the input layer. Defaults to `{}`. output_kwargs (Dict[str, Any]): The input arguments to pass to the output layer. Defaults to `{}`. """ layers_outputs = [] x = self._input(x, **input_kwargs) x, x1 = self._down1(x) if staged_output: layers_outputs.append(x1) x, x2 = self._down2(x) if staged_output: layers_outputs.append(x2) x, x3 = self._down3(x) if staged_output: layers_outputs.append(x3) x, x4 = self._down4(x) if staged_output: layers_outputs.append(x4) x = self._center(x) x = self._up4(x, x4) x = self._up3(x, x3) x = self._up2(x, x2) x = self._up1(x, x1) x = self.final_layer(x, **output_kwargs) if staged_output: return x, layers_outputs return x
[docs] def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor: """This function is used to get the final layer output.""" x = self._output(x, **output_kwargs) return self._output_activation(x)
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: """This function is used to add specific arguments to the parser.""" parser = super(UNetModel, UNetModel).add_specific_args(parent_parser) parser.add_argument( "--model_init_features", type=int, default=32, help="The number of initial features.", ) parser.add_argument( "--model_normalization", type=str, choices=["batch", "instance", "group", "layer", None], default="batch", help="The normalization to use.", ) return parser