xAAEnet Model

xAAEnet is a fastai adversarial autoencoder composed of an encoder, a latent space, a discriminator, a classifier, and a decoder.

This module provides the main xAAEnet architecture. It contains the AAE model, its U-Net decoder with dropout on skip connections, and the loss functions used during the different training phases.

Overview

[Figure: AAE architecture diagram]

The model follows three main steps:

  1. encode the image with a ResNet34 backbone;
  2. project the features into a latent space z constrained by a discriminator;
  3. reconstruct the image with a U-Net decoder and produce a class prediction from the latent space.

AAE Class

The AAE class brings together the main xAAEnet components:

  • a truncated ResNet34 encoder;
  • a latent bottleneck z with dimension encoding_dims;
  • a linear classification head;
  • a discriminator that constrains the latent space z to a Gaussian distribution;
  • a U-Net decoder that reconstructs the image.

AAE


def AAE(
    input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2, gen_train:bool=True,
    skip_dropout:int=1, # Replaces skip_weight with skip_dropout
):

Adversarial autoencoder used by xAAEnet.

The model encodes an image into a latent vector z, predicts class logits from this latent space, regularizes the latent distribution with an adversarial discriminator, and reconstructs the input image with a U-Net decoder.

U-Net Decoder

The decoder reconstructs the image from the latent vector projected into a spatial representation. It follows the idea behind DynamicUnet: encoder features are retrieved by hooks, then injected into the decoder through skip connections.

The variant used here adds Dropout2d on the skip connections to limit the decoder’s dependence on details passed directly by the encoder. With less information flowing through the skips, the reconstruction has to rely more on the latent vector, which makes it possible to visualize the features that remain after compression by the encoder.
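As an illustration of the idea, here is a minimal sketch of a decoder block that applies Dropout2d to the skip tensor before concatenation. The class SkipDropoutBlock, its layer choices, and the channel sizes are hypothetical and are not taken from the xAAEnet source. Reading the default skip_dropout=1 from the AAE signature as a drop probability of 1.0 is also an assumption.

```python
import torch
import torch.nn as nn


class SkipDropoutBlock(nn.Module):
    """Hypothetical decoder block: upsample, drop skip channels, concatenate, convolve."""

    def __init__(self, up_channels: int, skip_channels: int,
                 out_channels: int, skip_dropout: float = 0.5):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        # Dropout2d zeroes whole channels of the skip feature map in training mode.
        self.skip_drop = nn.Dropout2d(p=skip_dropout)
        self.conv = nn.Sequential(
            nn.Conv2d(up_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)          # match the spatial size of the skip tensor
        skip = self.skip_drop(skip)   # randomly remove encoder detail
        return self.conv(torch.cat([x, skip], dim=1))


# With p=1.0 every skip channel is dropped while the module is in training mode.
block = SkipDropoutBlock(up_channels=512, skip_channels=256,
                         out_channels=256, skip_dropout=1.0)
x = torch.randn(2, 512, 8, 8)        # coarse decoder features
skip = torch.randn(2, 256, 16, 16)   # encoder features from a hook
out = block(x, skip)
```

In evaluation mode (block.eval()), Dropout2d is the identity, so the skip features pass through unchanged unless the dropout is kept active on purpose.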

Attributes Produced During forward

After a forward pass, the model stores several attributes that are later used by losses and visualizations:

  • input_image: original input image;
  • z: latent vector produced by the encoder;
  • gan_fake: discriminator score on the encoded latent vector;
  • gan_real: discriminator score on a simulated Gaussian latent vector;
  • decoder_output: image reconstruction produced by the decoder.

These attributes explain why the loss functions are defined as methods of the AAE class: they depend on intermediate values computed in forward and stored on the model instance.
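The pattern can be sketched with a toy module. ToyAutoencoder and its reconstruction_loss method are hypothetical names, and the mean squared error is an assumed choice of loss; the actual xAAEnet losses are not reproduced here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAutoencoder(nn.Module):
    """Hypothetical stand-in for AAE that stores forward intermediates as attributes."""

    def __init__(self, in_features: int = 16, encoding_dims: int = 4):
        super().__init__()
        self.encoder = nn.Linear(in_features, encoding_dims)
        self.decoder = nn.Linear(encoding_dims, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep the intermediates so that loss methods can read them later.
        self.input_image = x
        self.z = self.encoder(x)
        self.decoder_output = self.decoder(self.z)
        return self.decoder_output

    def reconstruction_loss(self) -> torch.Tensor:
        # Uses attributes set by the most recent forward pass (MSE is an assumption).
        return F.mse_loss(self.decoder_output, self.input_image)


toy = ToyAutoencoder()
_ = toy(torch.randn(8, 16))
loss = toy.reconstruction_loss()
loss.backward()  # gradients flow through the stored tensors
```

One consequence of this design is that a loss method is only meaningful after a forward pass on the same batch, since it reads whatever the last call stored.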

Device Selection

default_device returns the best accelerator available on the machine: CUDA, Apple Silicon MPS, then CPU as a fallback. The model does not force this choice automatically: the user remains free to move the model and batches to the desired device.


default_device


def default_device():

Return the best available PyTorch device: CUDA, MPS, then CPU.

default_device is useful in training scripts to move the model and batches to the right accelerator.

device = default_device()
model = AAE().to(device)
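For readers who want to see the selection order spelled out, the following is a minimal sketch of how such a helper could be written. The function pick_device is a hypothetical re-implementation for illustration, not the library's own code.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper following the CUDA -> MPS -> CPU order."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is only present in recent PyTorch releases, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
# Batches must be moved explicitly, just like the model.
batch = torch.randn(4, 3, 256, 256).to(device)
```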

Compact Model Structure

The raw PyTorch representation of AAE expands every internal ResNet and U-Net block, which makes the documentation difficult to read. At a high level, the model is organized as follows:

  • Encoder: pretrained ResNet34 truncated before the pooling and classification layers.
  • Latent bottleneck: flattening of the encoder feature map, followed by a linear projection to encoding_dims and batch normalization.
  • Classifier head: a linear layer that maps the latent vector z to class logits.
  • Latent discriminator: a small MLP that distinguishes encoded latent vectors from sampled Gaussian latent vectors.
  • Decoder: a dynamic U-Net decoder that reconstructs the input image from the latent vector, with optional Dropout2d on skip connections.
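To make the role of the latent discriminator concrete, here is a minimal sketch of the usual adversarial autoencoder setup. The layer widths, the binary cross-entropy objectives, and the labeling convention (prior samples as 1, encoded vectors as 0) are assumptions based on the standard formulation, not values read from the xAAEnet source.

```python
import torch
import torch.nn as nn

encoding_dims = 128

# Hypothetical small MLP; the hidden width is illustrative only.
discriminator = nn.Sequential(
    nn.Linear(encoding_dims, 64),
    nn.LeakyReLU(0.2),
    nn.Linear(64, 1),  # one logit per latent vector
)

z = torch.randn(16, encoding_dims) * 3.0 + 1.0  # stand-in for encoder output
z_prior = torch.randn_like(z)                   # sample from N(0, I)

gan_fake = discriminator(z)        # score on encoded latent vectors
gan_real = discriminator(z_prior)  # score on sampled Gaussian latent vectors

bce = nn.BCEWithLogitsLoss()
# Discriminator objective: tell prior samples (1) apart from encoded vectors (0).
d_loss = bce(gan_real, torch.ones_like(gan_real)) + bce(gan_fake, torch.zeros_like(gan_fake))
# Encoder objective: make encoded vectors look like prior samples.
g_loss = bce(gan_fake, torch.ones_like(gan_fake))
```

In a real training loop the two objectives are optimized in alternation, and z is typically detached when updating the discriminator so that the encoder is not updated by d_loss.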

For the default configuration input_size=256, input_channels=3, encoding_dims=128, and classes=2, the main tensor flow is:

image [B, 3, 256, 256]
  -> ResNet34 encoder features [B, 512, 8, 8]
  -> flatten [B, 32768]
  -> latent vector z [B, 128]
  -> classifier logits [B, 2]
  -> decoder projection [B, 512, 8, 8]
  -> U-Net decoder reconstruction [B, 3, 256, 256]
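The shapes above can be checked without building the full model. ResNet34 downsamples by a factor of 32 and ends with 512 channels, so a 256 x 256 input gives an 8 x 8 feature map and 512 * 8 * 8 = 32768 flattened values. The sketch below uses a random tensor in place of the encoder output. The to_spatial layer is an assumed form of the decoder projection, and all variable names are illustrative.

```python
import torch
import torch.nn as nn

B, C, H, W = 4, 512, 8, 8          # batch, channels, and 256 // 32 = 8 per side
encoding_dims, classes = 128, 2

features = torch.randn(B, C, H, W)  # stand-in for the ResNet34 encoder output

# Latent bottleneck: flatten, linear projection, batch normalization.
bottleneck = nn.Sequential(
    nn.Flatten(),
    nn.Linear(C * H * W, encoding_dims),
    nn.BatchNorm1d(encoding_dims),
)
classifier = nn.Linear(encoding_dims, classes)
to_spatial = nn.Linear(encoding_dims, C * H * W)  # assumed decoder projection

flat = features.flatten(1)                    # [B, 32768]
z = bottleneck(features)                      # [B, 128]
logits = classifier(z)                        # [B, 2]
projected = to_spatial(z).view(B, C, H, W)    # [B, 512, 8, 8]
```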

The full module can still be inspected interactively with:

model = AAE(input_size=256, input_channels=3, encoding_dims=128, classes=2)
model