# Add xAAEnet Blocks to any User Encoder


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

The goal of this module is to start from the user’s own encoder and add
the xAAEnet blocks defined in `01_model_aae.ipynb`. We do not replace
the user’s model idea; we add the components needed to analyze and
explain it:

1.  keep the user encoder as the feature extractor;
2.  project encoder features into the latent space `z`;
3.  add a label head from `z`;
4.  add the latent discriminator that regularizes `z`;
5.  build a symmetric U-Net decoder with `DynamicUnetSkipDropout` from
    the user encoder and its skip connections.

The encoder must therefore expose spatial feature maps that can be used
by the U-Net hooks. For transformer models, the user should provide a
wrapper that exposes a spatial feature-map encoder before adding these
xAAEnet blocks.
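One way to check that an encoder exposes spatial feature maps is to record the output shape of each stage on a dummy batch. The sketch below uses plain PyTorch forward hooks; `feature_map_shapes` and `toy_encoder` are hypothetical names used only for illustration and are not part of this module.

```python
import torch
from torch import nn


def feature_map_shapes(encoder: nn.Module, input_size: int = 256, input_channels: int = 3):
    """Record (layer name, output shape) for each direct child of `encoder`."""
    shapes = []
    hooks = [
        child.register_forward_hook(
            lambda mod, inp, out: shapes.append((type(mod).__name__, tuple(out.shape)))
        )
        for child in encoder.children()
    ]
    was_training = encoder.training
    encoder.eval()
    with torch.no_grad():
        encoder(torch.zeros(1, input_channels, input_size, input_size))
    # Clean up so the encoder is left exactly as it was found.
    for hook in hooks:
        hook.remove()
    encoder.train(was_training)
    return shapes


# Hypothetical toy encoder: two stride-2 convolutions halve the resolution twice.
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

shapes = feature_map_shapes(toy_encoder, input_size=64)
for name, shape in shapes:
    print(name, shape)
```

Every recorded shape should have four dimensions `(batch, channels, height, width)`. An encoder that returns token sequences or flattened vectors would show up here as a three- or two-dimensional output.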

------------------------------------------------------------------------

### add_xaae_blocks

``` python

def add_xaae_blocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
)->EncoderWithAAEBlocks:

```

*Add xAAEnet analysis blocks to a user encoder.*

------------------------------------------------------------------------

### EncoderWithAAEBlocks

``` python

def EncoderWithAAEBlocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
):

```

*User encoder extended with xAAEnet analysis blocks.*

The user encoder stays the feature extractor. The added blocks project
its features into `z`, predict from `z`, regularize `z`, and reconstruct
the input with a symmetric U-Net decoder.
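The data flow described above can be pictured with a small stand-alone module. This is a simplified sketch and not the implementation of `EncoderWithAAEBlocks`: the class `ToyLatentBlocks`, its layer sizes, and the standard Gaussian prior used for `gan_real` are all assumptions, and the U-Net decoder is left out.

```python
import torch
from torch import nn


class ToyLatentBlocks(nn.Module):
    """Illustrative only: projection to z, label head, latent discriminator."""

    def __init__(self, encoder, feature_shape, encoding_dims=16, classes=2):
        super().__init__()
        channels, height, width = feature_shape
        self.encoder = encoder  # the user encoder stays the feature extractor
        self.to_z = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * height * width, encoding_dims)
        )
        self.label_head = nn.Linear(encoding_dims, classes)
        self.discriminator = nn.Sequential(
            nn.Linear(encoding_dims, 32), nn.ReLU(), nn.Linear(32, 1)
        )

    def forward(self, x):
        self.z = self.to_z(self.encoder(x))
        # Discriminator scores for encoded samples and for prior samples.
        # Assumption: the prior is a standard Gaussian of the same shape as z.
        self.gan_fake = self.discriminator(self.z)
        self.gan_real = self.discriminator(torch.randn_like(self.z))
        return self.label_head(self.z)


toy_encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)
# A 32x32 input is halved twice, so the encoder output is (16, 8, 8).
toy_model = ToyLatentBlocks(toy_encoder, feature_shape=(16, 8, 8))
toy_labels = toy_model(torch.randn(4, 3, 32, 32))
print(toy_labels.shape, toy_model.z.shape, toy_model.gan_fake.shape)
```

The sketch mirrors the convention that the forward pass returns only the label output while `z`, `gan_fake`, and `gan_real` are kept as attributes.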

------------------------------------------------------------------------

### set_module_trainable

``` python

def set_module_trainable(
    module:Module, trainable:bool
)->Module:

```

*Enable or disable gradient updates for every parameter in a module.*
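
A minimal sketch of what such a helper typically does, assuming it simply toggles `requires_grad` on each parameter. The name `set_trainable_sketch` is hypothetical, and the actual `set_module_trainable` may differ in detail.

```python
from torch import nn


def set_trainable_sketch(module: nn.Module, trainable: bool) -> nn.Module:
    """Toggle gradient updates for every parameter and return the module."""
    for param in module.parameters():
        param.requires_grad_(trainable)
    return module


encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())

set_trainable_sketch(encoder, False)
frozen_flags = [p.requires_grad for p in encoder.parameters()]

set_trainable_sketch(encoder, True)
unfrozen_flags = [p.requires_grad for p in encoder.parameters()]

print(frozen_flags, unfrozen_flags)
```

Returning the module allows the call to be chained, which matches the `->Module` return annotation in the signature above.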

## Practical Use Case

The user provides the encoder. `add_xaae_blocks` adds the xAAEnet
analysis blocks and automatically builds the matching U-Net decoder
through `DynamicUnetSkipDropout`.

The forward pass returns `labels`, while the model keeps `z`,
`gan_fake`, `gan_real`, and `decoder_output` as attributes for training
and analysis.

``` python
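import torch
from torch import nn

# `add_xaae_blocks` is defined in this module and is assumed to be in scope.
user_encoder = nn.Sequential(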
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

model = add_xaae_blocks(
    encoder=user_encoder,
    input_size=256,
    input_channels=3,
    encoding_dims=16,
    classes=2,
)

# MS-SSIM (used in `classif_loss_func` / `denoising_ae_loss_func`) needs both
# spatial sides to be larger than about 160 pixels.
batch = torch.randn(2, 3, 256, 256)
labels = model(batch)
labels.shape, model.z.shape, model.decoder_output.shape
```

## Training Loss

The loss functions use the same names as in `AAE` from `01_model_aae.ipynb`:

- `classif_loss_func(output, target, ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, **kwargs)`
  combines MS-SSIM + L1 reconstruction on `decoder_output` vs
  `input_image`, the latent GAN term (`gan_fake` / `gan_real`), and
  cross-entropy on the logits.
- `aae_loss_func(output, target)` returns **only** the adversarial
  latent loss (same alternating generator / discriminator logic as
  `AAE`).
- `denoising_ae_loss_func(clean_xb, RECONS_WEIGHT, ADV_WEIGHT, pred, yb)`
  is for denoising-style pretraining: reconstruction vs `clean_xb` plus
  the GAN term; `pred` and `yb` are ignored but kept for a
  fastai-compatible signature.
- `pure_classif_loss_func(pred, target, **kwargs)` is plain
  cross-entropy when you do not need reconstruction or the GAN term.
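
To make the role of the three weights concrete, here is a rough sketch of a weighted sum of the three terms. It is not the code of `classif_loss_func`: `combined_loss_sketch` is a hypothetical name, the MS-SSIM part of the reconstruction term is omitted (only L1 is used), and the binary cross-entropy form of the generator-side adversarial term is an assumption.

```python
import torch
import torch.nn.functional as F


def combined_loss_sketch(logits, target, decoder_output, input_image, gan_fake,
                         ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0):
    """Weighted sum of adversarial, reconstruction and classification terms."""
    # Reconstruction: L1 only here; the documented loss also uses MS-SSIM.
    recons = F.l1_loss(decoder_output, input_image)
    # Adversarial (assumed generator side): push the discriminator to score
    # encoded samples as "real" (target 1).
    adv = F.binary_cross_entropy_with_logits(gan_fake, torch.ones_like(gan_fake))
    # Classification: cross-entropy on the logits.
    classif = F.cross_entropy(logits, target)
    return ADV_WEIGHT * adv + RECONS_WEIGHT * recons + CLASS_WEIGHT * classif


torch.manual_seed(0)
logits = torch.randn(2, 2)
target = torch.tensor([0, 1])
image = torch.rand(2, 3, 32, 32)
reconstruction = torch.rand(2, 3, 32, 32)
gan_fake = torch.randn(2, 1)

loss = combined_loss_sketch(logits, target, reconstruction, image, gan_fake)
class_only = combined_loss_sketch(
    logits, target, reconstruction, image, gan_fake,
    ADV_WEIGHT=0.0, RECONS_WEIGHT=0.0, CLASS_WEIGHT=1.0,
)
print(loss.shape, float(loss), float(class_only))
```

Setting `ADV_WEIGHT` and `RECONS_WEIGHT` to zero reduces the sketch to plain cross-entropy, which is the situation `pure_classif_loss_func` is described as covering.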

`encoder_feature_shape` is still inferred automatically from the user
encoder, so the U-Net and the bottleneck stay consistent with **your**
spatial encoder.

Because reconstruction uses MS-SSIM, training images should be **wider
and taller than about 160 pixels**; the examples on this page use
`256×256` inputs so `nbdev` tests and `classif_loss_func` run without
assertion errors from `pytorch_msssim`.

``` python
target_labels = torch.tensor([0, 1])

classif_loss = model.classif_loss_func(
    labels, target_labels, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0
)

aae_loss = model.aae_loss_func(labels, target_labels)

classif_loss.shape, aae_loss.shape
```
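
The image-size requirement mentioned above can be checked before training. The sketch below assumes that `pytorch_msssim` requires the smaller image side to exceed `(win_size - 1) * 2**4`, with a default `win_size` of 11. Both helper names are hypothetical, and the exact condition should be confirmed against the installed version of the library.

```python
def msssim_min_side(win_size: int = 11, downsamplings: int = 4) -> int:
    """Threshold that the smaller image side must exceed (assumed formula)."""
    return (win_size - 1) * 2 ** downsamplings


def is_large_enough(height: int, width: int, win_size: int = 11) -> bool:
    """True if both sides are strictly larger than the assumed threshold."""
    return min(height, width) > msssim_min_side(win_size)


threshold = msssim_min_side()
print(threshold)                      # 160
print(is_large_enough(256, 256))      # True
print(is_large_enough(160, 160))      # False: the comparison is strict
```

Under this assumption, the threshold shrinks with a smaller window, so datasets with small images would need either resizing or a reconstruction loss that does not rely on MS-SSIM.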

## Encoder Contract

The added blocks stay intentionally close to `AAE`: they expect an
encoder that can be passed to `DynamicUnetSkipDropout`. In practice,
that means a spatial encoder whose intermediate feature maps can be
hooked for U-Net skip connections.

If a project uses a transformer, the user should wrap it so the xAAEnet
blocks receive spatial feature maps. The notebook keeps that adapter
outside this module because it depends on the transformer’s
architecture.
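
As an illustration of what such an adapter might look like, the sketch below reshapes a sequence of patch tokens back into a spatial feature map. `ToyPatchTokens` and `TokensToFeatureMap` are hypothetical classes written for this example; it assumes the token model returns `(batch, tokens, channels)` on a square patch grid with no class token.

```python
import torch
from torch import nn


class ToyPatchTokens(nn.Module):
    """Stand-in for a transformer backbone: patch embedding to (B, N, C) tokens."""

    def __init__(self, in_channels=3, dim=32, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)  # (B, N, C)


class TokensToFeatureMap(nn.Module):
    """Hypothetical adapter: reshape (B, N, C) tokens into a (B, C, H, W) map."""

    def __init__(self, token_model: nn.Module, grid_size: int):
        super().__init__()
        self.token_model = token_model
        self.grid_size = grid_size

    def forward(self, x):
        tokens = self.token_model(x)
        batch, n_tokens, channels = tokens.shape
        # Assumption: square patch grid and no extra class token.
        assert n_tokens == self.grid_size ** 2, "unexpected number of tokens"
        return tokens.transpose(1, 2).reshape(
            batch, channels, self.grid_size, self.grid_size
        )


backbone = ToyPatchTokens(dim=32, patch=16)
adapter = TokensToFeatureMap(backbone, grid_size=4)  # 64 / 16 = 4 patches per side

images = torch.randn(2, 3, 64, 64)
feature_map = adapter(images)
print(feature_map.shape)
```

This sketch yields a single-resolution map only. Whether that is enough for `DynamicUnetSkipDropout` depends on the adapter also exposing feature maps at several resolutions for the skip connections, which is not addressed here.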
