# xAAEnet Model


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

This module provides the main xAAEnet architecture. It contains the
`AAE` model, its U-Net decoder with dropout on skip connections, and the
loss functions used during the different training phases.

## Overview

![AAE architecture diagram](images/schema_bloc_AAE.png)

The model follows three main steps:

1.  encode the image with a ResNet34 backbone;
2.  project the features into a latent space `z` constrained by a
    discriminator;
3.  reconstruct the image with a U-Net decoder and produce a class
    prediction from the latent space.

## `AAE` Class

The `AAE` class brings together the main xAAEnet components:

- a truncated ResNet34 encoder;
- a latent bottleneck `z` with dimension `encoding_dims`;
- a linear classification head;
- a discriminator that pushes the distribution of the latent vectors `z`
  toward a Gaussian prior;
- a U-Net decoder that reconstructs the image.
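The component list above can be sketched as a toy PyTorch module. This is a minimal illustration under stated assumptions, not the xAAEnet implementation: the class name `ToyAAE` is hypothetical, a two-layer convolutional encoder and decoder stand in for the truncated ResNet34 and the U-Net decoder, skip connections are omitted, and the discriminator layer sizes are arbitrary. It only shows how `z` feeds the classifier, the discriminator, and the decoder.

```python
import torch
from torch import nn


class ToyAAE(nn.Module):
    """Hypothetical, simplified stand-in for AAE (not the real implementation)."""

    def __init__(self, input_size=32, input_channels=3, encoding_dims=16, classes=2):
        super().__init__()
        # Toy encoder standing in for the truncated ResNet34: downsamples by 4.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 8, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.feat_shape = (16, input_size // 4, input_size // 4)
        n_feats = 16 * (input_size // 4) ** 2
        # Latent bottleneck: flatten -> linear -> batch norm.
        self.bottleneck = nn.Sequential(
            nn.Flatten(), nn.Linear(n_feats, encoding_dims), nn.BatchNorm1d(encoding_dims)
        )
        # Linear classification head on z.
        self.classifier = nn.Linear(encoding_dims, classes)
        # Small MLP discriminator on latent vectors (layer sizes are assumptions).
        self.discriminator = nn.Sequential(
            nn.Linear(encoding_dims, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1)
        )
        # Toy decoder: project z back to a feature map, then upsample by 4.
        self.decoder_proj = nn.Linear(encoding_dims, n_feats)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(8, input_channels, 4, stride=2, padding=1),
        )

    def forward(self, x):
        self.input_image = x
        self.z = self.bottleneck(self.encoder(x))
        # Discriminator scores on the encoded z and on a Gaussian sample.
        self.gan_fake = self.discriminator(self.z)
        self.gan_real = self.discriminator(torch.randn_like(self.z))
        h = self.decoder_proj(self.z).view(-1, *self.feat_shape)
        self.decoder_output = self.decoder(h)
        return self.classifier(self.z)  # assumption: forward returns class logits


model = ToyAAE()
logits = model(torch.randn(4, 3, 32, 32))  # batch > 1: BatchNorm1d needs it in train mode
print(logits.shape, model.z.shape, model.decoder_output.shape)
```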

------------------------------------------------------------------------

### AAE

``` python

def AAE(
    input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2, gen_train:bool=True,
    skip_dropout:int=1, # Replaces skip_weight with skip_dropout
):

```

*Adversarial autoencoder used by xAAEnet.*

The model encodes an image into a latent vector `z`, predicts class
logits from this latent space, regularizes the latent distribution with
an adversarial discriminator, and reconstructs the input image with a
U-Net decoder.
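The adversarial regularization can be illustrated with the classic adversarial autoencoder recipe: the discriminator learns to label Gaussian samples as real and encoded vectors as fake, while the encoder is trained to fool it. The sketch below assumes binary cross-entropy on logits and a standard normal prior; the functions `discriminator_loss` and `generator_loss` and the discriminator architecture are hypothetical and not taken from xAAEnet.

```python
import torch
from torch import nn
import torch.nn.functional as F

encoding_dims = 128
# Hypothetical discriminator: a small MLP that outputs one logit per latent vector.
discriminator = nn.Sequential(
    nn.Linear(encoding_dims, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1)
)


def discriminator_loss(z):
    """Train D: Gaussian samples are 'real' (1), encoded z is 'fake' (0)."""
    real = discriminator(torch.randn_like(z))
    fake = discriminator(z.detach())  # detach: this loss must not update the encoder
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real))
            + F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake)))


def generator_loss(z):
    """Train the encoder: push D to classify the encoded z as 'real'."""
    fake = discriminator(z)
    return F.binary_cross_entropy_with_logits(fake, torch.ones_like(fake))


z = torch.randn(8, encoding_dims, requires_grad=True)  # stand-in for encoder output
d_loss = discriminator_loss(z)
g_loss = generator_loss(z)
g_loss.backward()  # gradients reach z (and, in a real model, the encoder)
print(z.grad.shape)
```

Detaching `z` in the discriminator step keeps the two phases separate: the discriminator update only changes the discriminator, and the encoder is only updated through `generator_loss`.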

## U-Net Decoder

The decoder reconstructs the image from the latent vector projected into
a spatial representation. It follows the design of fastai's
`DynamicUnet`: encoder features are captured by forward hooks, then
injected into the decoder through skip connections.

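The variant used here adds `Dropout2d` on the skip connections to limit
the decoder's dependence on details passed directly by the encoder.
Because skip features can be dropped, the decoder has to rely more on
the information carried through the latent bottleneck, which makes it
possible to visualize the features that survive compression by the
encoder.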
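A single decoder block with dropout on its skip connection might look like the following. This is a simplified sketch, not the actual decoder: the class name `SkipDropoutBlock`, the nearest-neighbour upsampling, the single convolution, and the dropout probability `p` are all assumptions. `nn.Dropout2d` zeroes entire channels of the encoder feature map in training mode (and rescales the rest), so the decoder cannot count on any particular skip channel being present.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SkipDropoutBlock(nn.Module):
    """Hypothetical U-Net decoder block with channel dropout on the skip input."""

    def __init__(self, in_ch, skip_ch, out_ch, p=0.5):
        super().__init__()
        self.skip_dropout = nn.Dropout2d(p)  # drops whole skip channels in train mode
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1), nn.ReLU()
        )

    def forward(self, x, skip):
        # Upsample decoder features to the spatial size of the skip connection.
        x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
        skip = self.skip_dropout(skip)
        return self.conv(torch.cat([x, skip], dim=1))


block = SkipDropoutBlock(in_ch=64, skip_ch=32, out_ch=32, p=0.5)
x = torch.randn(2, 64, 8, 8)       # decoder features
skip = torch.randn(2, 32, 16, 16)  # encoder features captured by a hook
print(block(x, skip).shape)
```

In evaluation mode (`block.eval()`), `Dropout2d` is the identity, so the skip connections are used in full at inference time.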

## Attributes Produced During `forward`

After a forward pass, the model stores several attributes that are later
used by losses and visualizations:

- `input_image`: original input image;
- `z`: latent vector produced by the encoder;
- `gan_fake`: discriminator score on the encoded latent vector;
- `gan_real`: discriminator score on a latent vector sampled from a
  Gaussian distribution;
- `decoder_output`: image reconstruction produced by the decoder.

These attributes explain why the loss functions are defined as methods of
the model: they depend on intermediate values computed in `forward`.
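The pattern of losses that read attributes cached by `forward` can be sketched as follows. The class `CachedLossModel`, the methods `recon_loss`, `classif_loss` and `total_loss`, the toy linear layers on vectors, and the choice of mean-squared error and cross-entropy are illustrative assumptions, not xAAEnet's actual loss functions. The point is that each loss only receives the predictions and targets, because everything else was stored during the forward pass.

```python
import torch
from torch import nn
import torch.nn.functional as F


class CachedLossModel(nn.Module):
    """Hypothetical model that caches intermediates for its loss methods."""

    def __init__(self, dims=16, classes=2):
        super().__init__()
        self.encoder = nn.Linear(dims, dims)   # toy layers operating on vectors
        self.decoder = nn.Linear(dims, dims)
        self.head = nn.Linear(dims, classes)

    def forward(self, x):
        self.input_image = x                   # cached for the reconstruction loss
        self.z = self.encoder(x)
        self.decoder_output = self.decoder(self.z)
        return self.head(self.z)               # class logits

    def recon_loss(self, pred, target):
        # pred/target are unused: compares the cached reconstruction with the cached input.
        return F.mse_loss(self.decoder_output, self.input_image)

    def classif_loss(self, pred, target):
        return F.cross_entropy(pred, target)

    def total_loss(self, pred, target, recon_weight=1.0):
        return self.classif_loss(pred, target) + recon_weight * self.recon_loss(pred, target)


model = CachedLossModel()
x, y = torch.randn(4, 16), torch.tensor([0, 1, 1, 0])
pred = model(x)                  # forward must run before any loss is computed
loss = model.total_loss(pred, y)
loss.backward()
```

Keeping the `(pred, target)` signature lets such a method be passed wherever a training loop expects a loss function of that form, even though some terms ignore their arguments.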

## Device Selection

`default_device` returns the best accelerator available on the machine:
CUDA, Apple Silicon MPS, then CPU as a fallback. The model does not move
itself to this device automatically: the user remains free to move the
model and batches to the desired device.
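The documented fallback order can be expressed with standard PyTorch availability checks. This is a sketch of the described behaviour, not necessarily the module's exact code; because `torch.backends.mps` only exists in PyTorch 1.12 and later, the sketch looks it up with `getattr` so that it also runs on older versions.

```python
import torch


def default_device() -> torch.device:
    """Return CUDA if available, else Apple Silicon MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)  # absent on PyTorch < 1.12
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = default_device()
x = torch.zeros(2, 3).to(device)  # batches are moved explicitly, like the model
print(device.type)
```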

------------------------------------------------------------------------

### default_device

``` python

def default_device():

```

*Return the best available PyTorch device: CUDA, MPS, then CPU.*

`default_device` is useful in training scripts to move the model and
batches to the right accelerator.

``` python
device = default_device()
model = AAE().to(device)
```

## Compact Model Structure

The raw PyTorch representation of `AAE` expands every internal ResNet
and U-Net block, which makes the documentation difficult to read. At a
high level, the model is organized as follows:

- **Encoder**: pretrained ResNet34 truncated before the pooling and
  classification layers.
- **Latent bottleneck**: flattening of the encoder feature map, followed
  by a linear projection to `encoding_dims` and batch normalization.
- **Classifier head**: a linear layer that maps the latent vector `z` to
  class logits.
- **Latent discriminator**: a small MLP that distinguishes encoded
  latent vectors from sampled Gaussian latent vectors.
- **Decoder**: a dynamic U-Net decoder that reconstructs the input image
  from the latent vector, with optional `Dropout2d` on skip connections.

For the default configuration `input_size=256`, `input_channels=3`,
`encoding_dims=128`, and `classes=2`, the main tensor flow is as
follows; the classifier and the decoder both branch from `z`:

``` text
image [B, 3, 256, 256]
  -> ResNet34 encoder features [B, 512, 8, 8]
  -> flatten [B, 32768]
  -> latent vector z [B, 128]
       |-> classifier logits [B, 2]
       `-> decoder projection [B, 512, 8, 8]
             -> U-Net decoder reconstruction [B, 3, 256, 256]
```
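The shapes in this diagram follow from the ResNet34 output stride of 32 and its 512 final channels. The helper below (the name `aae_shapes` is hypothetical, pure Python) reproduces that arithmetic for other configurations, assuming `input_size` is divisible by 32 and that the decoder projection mirrors the encoder feature map as in the default case.

```python
def aae_shapes(input_size=256, input_channels=3, encoding_dims=128, classes=2, batch="B"):
    """Return the main tensor shapes of the AAE data flow as tuples."""
    if input_size % 32:
        raise ValueError("input_size must be divisible by 32 (ResNet34 output stride)")
    s = input_size // 32  # spatial size after the last ResNet34 stage
    c = 512               # channels of the last ResNet34 stage
    return {
        "image": (batch, input_channels, input_size, input_size),
        "encoder_features": (batch, c, s, s),
        "flatten": (batch, c * s * s),
        "z": (batch, encoding_dims),
        "logits": (batch, classes),
        "decoder_projection": (batch, c, s, s),
        "reconstruction": (batch, input_channels, input_size, input_size),
    }


for name, shape in aae_shapes().items():
    print(f"{name:>20}: {list(shape)}")
print(aae_shapes(input_size=128)["flatten"])  # ('B', 8192)
```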

The full module can still be inspected interactively with:

``` python
model = AAE(input_size=256, input_channels=3, encoding_dims=128, classes=2)
model
```
