# Train an Explainable xAAEnet Model


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## Training pipeline

`train_xaaenet` always runs the same sequence:

1.  **Adversarial** — latent GAN only, with alternating generator /
    discriminator updates.
2.  **Autoencoder** — denoising reconstruction on corrupted inputs
    (Gaussian noise + random patch masking).
3.  **Classifier** — weighted sum of reconstruction, adversarial, and
    cross-entropy terms.
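The classifier-phase objective is a weighted sum of three terms. The sketch below assumes each term is already computed as a scalar and that the `classif_*_weight` arguments combine them linearly. The helper name `classifier_objective` is hypothetical and is not part of `tell_me_why`.

```python
def classifier_objective(
    recons_loss: float,
    class_loss: float,
    adv_loss: float,
    recons_weight: float = 0.599,  # default of classif_recons_weight
    class_weight: float = 0.001,   # default of classif_class_weight
    adv_weight: float = 0.4,       # default of classif_adv_weight
) -> float:
    """Hypothetical helper: linear combination of the three phase-3 loss terms."""
    return (
        recons_weight * recons_loss
        + class_weight * class_loss
        + adv_weight * adv_loss
    )


# Example with made-up loss values.
total = classifier_objective(recons_loss=0.2, class_loss=0.7, adv_loss=0.5)
```

With the defaults, the three weights sum to 1 (0.599 + 0.001 + 0.4). The cross-entropy weight is small compared with the other two.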

Each phase saves a checkpoint under `models_dir`, reloads the best
weights, then continues to the next phase.

After each phase, when visualization is enabled, `train_xaaenet` follows
the same order:

1.  **Extract** validation latent vectors `z` from the trained model.
2.  **Display** the figure inline in the notebook (when running under
    IPython).
3.  **Save** a PNG under `models_dir`.

- After **adversarial** and **autoencoder**: t-SNE (`save_tsne=True`).
- After **classifier** (default curriculum): supervised **PLS** biplot
  (`save_pls=True`). The PLS figure can also be produced after any
  phase through the private helper `_pls_validation_latent_after_phase`.
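The phase-to-figure schedule can be written as a small lookup. This sketch assumes two things:

- The t-SNE files are named after each phase's checkpoint name (`tsne_<checkpoint_name>.png`).
- The PLS file is named `pls_<classif_fname>.png`, inside `models_dir`, as described under the example figures.

The helper `expected_figures` is hypothetical and is not part of `tell_me_why`.

```python
from pathlib import Path


def expected_figures(
    models_dir="models",
    adv_fname="xaaenet_adv",
    ae_fname="xaaenet_ae",
    classif_fname="xaaenet_classif",
    save_tsne=True,
    save_pls=True,
):
    """Hypothetical: list the PNGs a default run is expected to write, in order."""
    models_dir = Path(models_dir)
    figures = []
    if save_tsne:
        # t-SNE after the adversarial and autoencoder phases
        # (assumes <checkpoint_name> is adv_fname / ae_fname).
        figures += [models_dir / f"tsne_{adv_fname}.png",
                    models_dir / f"tsne_{ae_fname}.png"]
    if save_pls:
        # Supervised PLS biplot, once, after the classifier phase.
        figures.append(models_dir / f"pls_{classif_fname}.png")
    return figures
```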

## Quick start

Build your model and a fastai `DataLoaders` with `ImageBlock` inputs and
`CategoryBlock` targets, then call:

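``` python
from tell_me_why.training import train_xaaenet

# `model`: an `AAE` or `EncoderWithAAEBlocks`.
# `dls`: fastai `DataLoaders` built from (ImageBlock, CategoryBlock).
# The loss weights below match the defaults in the signature.
learn = train_xaaenet(
    model,
    dls,
    # Autoencoder phase: reconstruction vs. adversarial terms.
    ae_recons_weight=0.4,
    ae_adv_weight=0.6,
    # Classifier phase: reconstruction, cross-entropy and adversarial terms.
    classif_recons_weight=0.599,
    classif_class_weight=0.001,
    classif_adv_weight=0.4,
)
```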

Input images should be at least about **160×160** pixels, because the
reconstruction loss uses MS-SSIM, which repeatedly downsamples each
image.
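The 160-pixel figure matches a common MS-SSIM setup: five scales (four 2× downsamplings) and an 11×11 Gaussian window. We assume, without confirmation, that `tell_me_why` uses settings like these.

For comparison, the `pytorch_msssim` package checks that the smaller image side is strictly greater than `(win_size - 1) * 2**4`. The function names below are hypothetical.

```python
def msssim_min_side(win_size: int = 11, n_scales: int = 5) -> int:
    """Side length that MS-SSIM inputs must exceed (assumed pytorch_msssim-style rule).

    Each of the (n_scales - 1) downsamplings halves the image, and the coarsest
    scale must still be wider than the (win_size - 1) window support.
    """
    return (win_size - 1) * 2 ** (n_scales - 1)


def fits_msssim(height: int, width: int, win_size: int = 11, n_scales: int = 5) -> bool:
    # Strict inequality on the smaller side, mirroring the check described above.
    return min(height, width) > msssim_min_side(win_size, n_scales)
```

Under this assumed rule, a 160×160 input sits exactly on the bound and would be rejected. A slightly larger size such as 192×192 or 224×224 is the safer choice.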

## `train_xaaenet`

<table>
<colgroup>
<col style="width: 62%" />
<col style="width: 37%" />
</colgroup>
<thead>
<tr>
<th>Argument</th>
<th>Role</th>
</tr>
</thead>
<tbody>
<tr>
<td><code>epochs_adv</code>, <code>epochs_ae</code>,
<code>epochs_classif</code></td>
<td>Epoch count per phase</td>
</tr>
<tr>
<td><code>ae_recons_weight</code>, <code>ae_adv_weight</code></td>
<td>Loss weights during autoencoder training</td>
</tr>
<tr>
<td><code>classif_recons_weight</code>,
<code>classif_class_weight</code>, <code>classif_adv_weight</code></td>
<td>Loss weights during classifier training</td>
</tr>
<tr>
<td><code>adv_low_threshold</code>, <code>adv_high_threshold</code></td>
<td>Validation-loss band that drives the alternating GAN training</td>
</tr>
<tr>
<td><code>mask_ratio</code>, <code>patch_size</code>,
<code>noise_std</code></td>
<td>Corruption strength for denoising AE training</td>
</tr>
<tr>
<td><code>models_dir</code>, <code>*_fname</code></td>
<td>Checkpoint directory and file names</td>
</tr>
<tr>
<td><code>save_tsne</code></td>
<td>After adversarial / autoencoder: extract validation <code>z</code>,
show t-SNE, save PNG</td>
</tr>
<tr>
<td><code>tsne_max_points</code>, <code>tsne_perplexity</code></td>
<td>Subsample size and t-SNE perplexity for those figures</td>
</tr>
<tr>
<td><code>save_pls</code></td>
<td>After classifier phase (default): extract validation <code>z</code>,
show PLS biplot, save PNG</td>
</tr>
<tr>
<td><code>show_latent_figures</code></td>
<td>Display figures inline in the notebook before saving the PNG
(<code>None</code> = auto-detect IPython)</td>
</tr>
<tr>
<td><code>pls_target_class</code></td>
<td>Binary target level in <code>dls.vocab</code> for the PLS axis
(default: second class)</td>
</tr>
<tr>
<td><code>pls_max_points</code></td>
<td>Max validation points for the PLS figure (default: 10 000)</td>
</tr>
<tr>
<td><code>extract_latent</code></td>
<td>Save train+valid latent vectors <code>z</code> after the full run
(to <code>latent_path</code>, or a default under
<code>models_dir</code>)</td>
</tr>
</tbody>
</table>

## Corruption transforms (autoencoder phase)

During phase 2, inputs are corrupted before the forward pass;
reconstruction is computed against the **clean** batch stored by
`CorruptionCallback`.

------------------------------------------------------------------------

### CorruptionCallback

``` python

def CorruptionCallback(
    corruption_tfms:list
):

```

*Apply corruption transforms to `xb` while keeping a clean copy for the
loss.*
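The callback pattern is: corrupt the input, but keep the clean batch as the reconstruction target. Here it is sketched in NumPy without fastai.

Everything below is illustrative:

- The class `CorruptionSketch`, its `clean` attribute, and the MSE placeholder loss are hypothetical.
- In the library, the callback operates on fastai batches (`xb`) inside the training loop.
- The library's reconstruction loss uses MS-SSIM, not MSE.

```python
import numpy as np


class CorruptionSketch:
    """Hypothetical stand-in for CorruptionCallback: corrupt inputs, keep a clean copy."""

    def __init__(self, corruption_tfms):
        self.corruption_tfms = list(corruption_tfms)
        self.clean = None

    def before_batch(self, xb: np.ndarray) -> np.ndarray:
        # Keep an untouched copy; the reconstruction loss compares against it.
        self.clean = xb.copy()
        corrupted = xb
        for tfm in self.corruption_tfms:  # applied in order, e.g. noise then masking
            corrupted = tfm(corrupted)
        return corrupted


def reconstruction_mse(recon: np.ndarray, clean: np.ndarray) -> float:
    # Placeholder reconstruction loss (MSE), computed against the CLEAN batch.
    return float(np.mean((recon - clean) ** 2))


rng = np.random.default_rng(0)
xb = rng.random((2, 3, 8, 8))
cb = CorruptionSketch([lambda x: np.clip(x + rng.normal(0.0, 0.05, x.shape), 0.0, 1.0)])
x_in = cb.before_batch(xb)  # what the model would see during phase 2
```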

------------------------------------------------------------------------

### RandomMasking

``` python

def RandomMasking(
    mask_ratio:float=0.3, patch_size:int=16
):

```

*Zero out random square patches independently per image.*
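One way to implement patch masking is sketched below in NumPy. It assumes `mask_ratio` is the fraction of non-overlapping `patch_size × patch_size` grid cells that are zeroed in each image. The actual sampling rule in `RandomMasking` may differ, and `random_patch_mask` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def random_patch_mask(xb: np.ndarray, mask_ratio: float = 0.3, patch_size: int = 16,
                      rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Zero out a random subset of grid patches, sampled independently per image.

    xb: batch of shape (N, C, H, W). Returns a new array; xb is not modified.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, _, h, w = xb.shape
    gh, gw = h // patch_size, w // patch_size  # full patches along each axis
    n_patches = gh * gw
    n_masked = int(round(mask_ratio * n_patches))
    out = xb.copy()
    for i in range(n):
        # A fresh set of patch indices for every image in the batch.
        for idx in rng.choice(n_patches, size=n_masked, replace=False):
            r, c = divmod(int(idx), gw)
            out[i, :, r * patch_size:(r + 1) * patch_size,
                c * patch_size:(c + 1) * patch_size] = 0.0
    return out
```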

------------------------------------------------------------------------

### AddGaussianNoise

``` python

def AddGaussianNoise(
    mean:float=0.0, std:float=0.05
):

```

*Add Gaussian noise to a batch of images (values stay in \[0, 1\]).*
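A NumPy sketch of the same idea follows. It assumes the values are kept in `[0, 1]` by clipping after the noise is added. The helper name `add_gaussian_noise` is hypothetical.

```python
from typing import Optional

import numpy as np


def add_gaussian_noise(xb: np.ndarray, mean: float = 0.0, std: float = 0.05,
                       rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Add i.i.d. N(mean, std^2) noise to every pixel, then clip to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    noisy = xb + rng.normal(loc=mean, scale=std, size=xb.shape)
    return np.clip(noisy, 0.0, 1.0)
```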

## Adversarial callback

`UnfreezeFcCritAdaptative` toggles `model.gen_train` and freezes either
the discriminator (`fc_crit*`) or the rest of the network, based on
validation loss and epoch schedule.

------------------------------------------------------------------------

### UnfreezeFcCritAdaptative

``` python

def UnfreezeFcCritAdaptative(
    switch_every:int=3, low_threshold:float=0.15, high_threshold:float=0.6, window_size:int=3
):

```

*Alternate generator / discriminator training during the adversarial
phase.*
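The exact switching rule is not documented here. What follows is one plausible policy, purely illustrative and not the library's logic. It assumes the monitored validation loss is the discriminator's loss, and it works as follows:

- Smooth the loss over the last `window_size` epochs.
- Train the discriminator when the smoothed loss rises above `high_threshold`.
- Train the generator when the smoothed loss falls below `low_threshold`.
- Otherwise, alternate every `switch_every` epochs.

The function `choose_mode` and the `"gen"` / `"disc"` labels are hypothetical.

```python
def choose_mode(history, epoch, current_mode, switch_every=3,
                low_threshold=0.15, high_threshold=0.6, window_size=3):
    """Hypothetical switching policy: returns 'gen' or 'disc' for the next epoch.

    history: recent validation losses (assumed to be the discriminator loss).
    """
    recent = list(history)[-window_size:]
    if len(recent) == window_size:
        smoothed = sum(recent) / window_size
        if smoothed > high_threshold:
            return "disc"  # discriminator is losing: let it catch up
        if smoothed < low_threshold:
            return "gen"   # discriminator dominates: train the generator
    # Inside the band, or not enough history yet: alternate on a fixed schedule.
    if epoch > 0 and epoch % switch_every == 0:
        return "disc" if current_mode == "gen" else "gen"
    return current_mode
```

In the library, a decision like this would set `model.gen_train` and choose which parameters to freeze: either the `fc_crit*` discriminator layers or the rest of the network.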

## Metrics, latent extraction, and t-SNE figures

`LossAttrMetric` logs scalar attributes that the model sets during its
forward pass (`adv_loss`, `recons_loss`, …). Latent extraction, both
for the figures and for `extract_latent=True`, runs a validation pass
and stacks `model.z`, so you do not need to add a separate callback
yourself.

At the end of each training phase, `train_xaaenet` **extracts**
validation `z`, then calls the figure helper for that phase:

- `save_latent_tsne_figure` — t-SNE (default: up to 5000 points,
  perplexity 30), **display** inline, then PNG.
- `save_latent_pls_figure` — supervised **PLS** (2 components) on `z`
  and binary targets, **display** inline, then PNG (one arrow for the
  target direction).

## Example output figures

With default options, figures are displayed in the notebook and then
written as PNG files under `models_dir`. Below are representative
examples (validation set, binary classification).

### t-SNE (`save_tsne=True`)

Saved after the **adversarial** and **autoencoder** phases as
`tsne_<checkpoint_name>.png`:

<figure>
<img src="images/tsne_example.png"
alt="t-SNE of the validation latent space" />
<figcaption aria-hidden="true">t-SNE of the validation latent
space</figcaption>
</figure>

Unsupervised t-SNE embedding of `z` (e.g. 128-D → 2-D). Points are not
colored by class; the plot tracks how the latent cloud evolves between
phases.

### PLS (`save_pls=True`)

Saved once after the **classifier** phase as `pls_<classif_fname>.png`:

<figure>
<img src="images/pls_example.png"
alt="PLS biplot — supervised latent space" />
<figcaption aria-hidden="true">PLS biplot — supervised latent
space</figcaption>
</figure>

Supervised PLS on `z`: points are colored by their score on PLS
component 1, and **one white arrow** shows the direction of the binary
target in latent space.

The figure above illustrates the overall PLS style (axes, colorbar,
latent cloud).

------------------------------------------------------------------------

### save_latent_pls_figure

``` python

def save_latent_pls_figure(
    z:torch.Tensor, targets:torch.Tensor, save_path:str | Path, class_names:Sequence[str] | None=None,
    target_class:str | None=None, phase:str='', encoding_dims:int | None=None, max_points:int=10000,
    random_state:int=42, show:bool | None=None
)->Path:

```

*Supervised PLS biplot of `z` vs binary targets: display inline in a
notebook, then save a PNG.*
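With a single binary target, the first PLS weight vector is the normalized covariance direction `X^T y`, where `X` is centered `z` and `y` is the centered 0/1 target.

The sketch below is a from-scratch NIPALS PLS1 in NumPy, for illustration only. Which PLS implementation the library uses is not stated here, and the name `pls1_two_components` is hypothetical.

```python
import numpy as np


def pls1_two_components(z: np.ndarray, y: np.ndarray, n_components: int = 2):
    """NIPALS PLS1. Returns scores T of shape (n, k) and weights W of shape (d, k).

    z: latent vectors, shape (n, d). y: binary targets (0/1), shape (n,).
    """
    X = z - z.mean(axis=0)            # center the latent vectors
    t_y = y.astype(float) - y.mean()  # center the binary target
    n, d = X.shape
    T = np.zeros((n, n_components))
    W = np.zeros((d, n_components))
    for k in range(n_components):
        w = X.T @ t_y                 # covariance direction with the target
        w /= np.linalg.norm(w)
        t = X @ w                     # scores on component k
        p = X.T @ t / (t @ t)         # X loadings
        q = (t_y @ t) / (t @ t)       # y loading
        X = X - np.outer(t, p)        # deflate X
        t_y = t_y - q * t             # deflate y
        T[:, k], W[:, k] = t, w
    return T, W
```

In a biplot like the one above, `T[:, 0]` would play the role of the component-1 score used for coloring. `W[:, 0]` is the direction in `z`-space that covaries most with the target.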

------------------------------------------------------------------------

### save_latent_tsne_figure

``` python

def save_latent_tsne_figure(
    z:torch.Tensor, save_path:str | Path, phase:str, encoding_dims:int | None=None, max_points:int=5000,
    perplexity:float=30, random_state:int=42, show:bool | None=None
)->Path:

```

*Project `z` with t-SNE, display inline in a notebook, then save a PNG.*
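`max_points` caps how many validation vectors go into t-SNE, whose cost grows quickly with the number of points. The sketch below draws a reproducible subsample. It assumes uniform sampling without replacement, seeded by `random_state`; how the library actually subsamples is not documented here, and `subsample_latents` is a hypothetical name.

```python
import numpy as np


def subsample_latents(z: np.ndarray, max_points: int = 5000, random_state: int = 42):
    """Return (z_sub, idx): at most `max_points` rows of z, chosen uniformly at random."""
    n = z.shape[0]
    if n <= max_points:
        return z, np.arange(n)        # nothing to drop
    rng = np.random.default_rng(random_state)
    idx = np.sort(rng.choice(n, size=max_points, replace=False))
    return z[idx], idx
```

The subsample could then be passed to an embedding such as scikit-learn's `TSNE(n_components=2, perplexity=30, random_state=42)`. Keep `perplexity` below the number of points, since scikit-learn rejects larger values.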

------------------------------------------------------------------------

### LossAttrMetric

``` python

def LossAttrMetric(
    attr:str
):

```

*Average a float attribute stored on the model (e.g. `model.adv_loss`).*
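Below is a plain-Python sketch of the idea. It mirrors the `reset` / `accumulate` / `value` shape of a fastai `Metric` without importing fastai.

Assumptions:

- The attribute is a scalar that the model sets during each forward pass.
- The average is weighted by batch size; the library may use a plain mean.
- `LossAttrSketch` and `FakeModel` are hypothetical names.

```python
class LossAttrSketch:
    """Hypothetical: batch-size-weighted mean of `getattr(model, attr)` over an epoch."""

    def __init__(self, attr: str):
        self.attr = attr
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def accumulate(self, model, batch_size: int):
        # Read the scalar the model stored during its forward pass.
        self.total += float(getattr(model, self.attr)) * batch_size
        self.count += batch_size

    @property
    def value(self):
        return self.total / self.count if self.count else None

    @property
    def name(self):
        return self.attr


class FakeModel:
    adv_loss = 0.0


m, metric = FakeModel(), LossAttrSketch("adv_loss")
for loss, bs in [(0.5, 4), (0.2, 4), (0.8, 2)]:
    m.adv_loss = loss  # what a forward pass would set
    metric.accumulate(m, bs)
```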

------------------------------------------------------------------------

### GetLatentSpace

``` python

def GetLatentSpace(
    after_create:NoneType=None, before_fit:NoneType=None, before_epoch:NoneType=None, before_train:NoneType=None,
    before_batch:NoneType=None, after_pred:NoneType=None, after_loss:NoneType=None, before_backward:NoneType=None,
    after_cancel_backward:NoneType=None, after_backward:NoneType=None, before_step:NoneType=None,
    after_cancel_step:NoneType=None, after_step:NoneType=None, after_cancel_batch:NoneType=None,
    after_batch:NoneType=None, after_cancel_train:NoneType=None, after_train:NoneType=None,
    before_validate:NoneType=None, after_cancel_validate:NoneType=None, after_validate:NoneType=None,
    after_cancel_epoch:NoneType=None, after_epoch:NoneType=None, after_cancel_fit:NoneType=None,
    after_fit:NoneType=None
):

```

*Collect latent vectors `z` during validation into `learn.z_valid`.*
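Collecting `z` over a validation pass amounts to three steps. After each batch's forward pass, read `model.z`, take it off the computation graph, and then concatenate all batches at the end.

The NumPy sketch below uses a hypothetical toy encoder in place of the real model. In the library, the model exposes `z`, and the result lands in `learn.z_valid`.

```python
import numpy as np


class ToyEncoder:
    """Hypothetical model that stores its latent code on `self.z`."""

    def __init__(self, in_dim: int, z_dim: int, seed: int = 0):
        self.W = np.random.default_rng(seed).normal(size=(in_dim, z_dim))
        self.z = None

    def __call__(self, xb: np.ndarray) -> np.ndarray:
        self.z = np.tanh(xb @ self.W)  # latent vectors for this batch
        return self.z


def collect_latents(model, batches):
    """Run the model on each batch and stack the per-batch `model.z` arrays."""
    zs = []
    for xb in batches:
        model(xb)
        zs.append(model.z)  # with PyTorch this would be model.z.detach().cpu()
    return np.concatenate(zs, axis=0)
```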

------------------------------------------------------------------------

### train_xaaenet

``` python

def train_xaaenet(
    model:ExplainableModel, # `AAE` or `EncoderWithAAEBlocks`.
    dls:DataLoaders, # fastai `DataLoaders` with `(ImageBlock, CategoryBlock)`.
    epochs_adv:int=35, epochs_ae:int=50, epochs_classif:int=30, ae_recons_weight:float=0.4, ae_adv_weight:float=0.6,
    classif_recons_weight:float=0.599, classif_class_weight:float=0.001, classif_adv_weight:float=0.4,
    lr_max:float=0.0001, lr_max_factor:float=3.0, mask_ratio:float=0.2, patch_size:int=16, noise_std:float=0.05,
    adv_low_threshold:float=0.65, adv_high_threshold:float=0.8, grad_accum:int=4, patience:int=10,
    models_dir:str | Path='models', adv_fname:str='xaaenet_adv', ae_fname:str='xaaenet_ae',
    classif_fname:str='xaaenet_classif',
    save_tsne:bool=True, # If True, save t-SNE figures after the adversarial and autoencoder phases.
    tsne_max_points:int=5000, tsne_perplexity:float=30,
    save_pls:bool=True, # If True, after the classifier phase extract validation ``z``, show a supervised PLS biplot, save PNG.
    pls_target_class:str | None=None, # Name of the binary target level in `dls.vocab` for the PLS axis (default: second class).
    pls_max_points:int=10000, # Cap on validation points for the PLS figure (default: 10_000).
    extract_latent:bool=True, # If True, save stacked train+valid latent vectors to `latent_path` (or a default under `models_dir`).
    latent_path:str | Path | None=None,
    show_latent_figures:bool | None=None, # If not False, display t-SNE / PLS figures inline when running inside a notebook (before saving PNG).
)->Learner:

```

*Train an explainable xAAEnet model: adversarial, then autoencoder, then
classifier.*
