Add xAAEnet Blocks to any User Encoder

Add the xAAEnet blocks needed to explain a user-provided model.

The goal of this module is to start from the user’s own encoder and add the xAAEnet blocks defined in 01_model_aae.ipynb. We do not replace the user’s architecture; we add the components needed to analyze and explain it:

  1. keep the user encoder as the feature extractor;
  2. project encoder features into the latent space z;
  3. add a label head from z;
  4. add the latent discriminator that regularizes z;
  5. build a symmetric U-Net decoder with DynamicUnetSkipDropout from the user encoder and its skip connections.
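The sketch below shows how the five steps connect in PyTorch. It is not this module's implementation. The following are all simplifying assumptions, chosen only to show where each block sits:

- the hypothetical class name `TinyXAAESketch`;
- the flatten-then-linear projection to z;
- the two-layer discriminator;
- the N(0, I) prior used for `gan_real`;
- the single upsample-plus-conv decoder that works from z and has no U-Net skips or DynamicUnetSkipDropout.

```python
import torch
from torch import nn


class TinyXAAESketch(nn.Module):
    """Hypothetical, simplified stand-in for the xAAEnet wrapper (not the real class)."""

    def __init__(self, encoder, feat_ch, feat_hw, input_size,
                 input_channels=3, encoding_dims=16, classes=2):
        super().__init__()
        n_feats = feat_ch * feat_hw * feat_hw
        self.feat_shape = (feat_ch, feat_hw, feat_hw)
        self.encoder = encoder                               # 1. user encoder, unchanged
        self.to_z = nn.Linear(n_feats, encoding_dims)        # 2. features -> latent z
        self.head = nn.Linear(encoding_dims, classes)        # 3. label head on z
        self.discriminator = nn.Sequential(                  # 4. latent discriminator
            nn.Linear(encoding_dims, 32), nn.ReLU(), nn.Linear(32, 1)
        )
        # 5. Decoder: one upsample + conv from z, NOT a U-Net with skip connections.
        self.from_z = nn.Linear(encoding_dims, n_feats)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=input_size // feat_hw, mode="nearest"),
            nn.Conv2d(feat_ch, input_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        feats = self.encoder(x)
        self.z = self.to_z(feats.flatten(1))
        # Discriminator logits for encoder codes (fake) and prior samples (real);
        # the N(0, I) prior is an assumption of this sketch.
        self.gan_fake = self.discriminator(self.z)
        self.gan_real = self.discriminator(torch.randn_like(self.z))
        self.decoder_output = self.decoder(self.from_z(self.z).view(-1, *self.feat_shape))
        return self.head(self.z)  # class logits; the other outputs live on attributes


encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)
sketch = TinyXAAESketch(encoder, feat_ch=64, feat_hw=8, input_size=64)
logits = sketch(torch.randn(2, 3, 64, 64))
```

Returning only the logits while caching the other tensors on the module follows the behaviour this page documents for the real model. A small 64×64 input is fine here because the sketch computes no MS-SSIM.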

The encoder must therefore expose spatial feature maps that the U-Net hooks can attach to. Transformer models produce token sequences rather than feature maps, so they must first be wrapped to expose a spatial feature-map encoder; only then can the xAAEnet blocks be added.
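The sketch below makes "hookable spatial feature maps" concrete. It registers a plain PyTorch forward hook on every `nn.Conv2d` of an example encoder and records each output shape. It only illustrates the idea. DynamicUnetSkipDropout has its own hook placement, presumably following fastai's DynamicUnet, which hooks the layers where the spatial size changes (an assumption here).

```python
import torch
from torch import nn

encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)

skip_shapes = []  # filled by the hooks during the forward pass


def record_shape(module, inputs, output):
    # A U-Net decoder needs 4D (batch, channels, height, width) maps for its skips.
    assert output.dim() == 4, f"{module} produced a non-spatial output"
    skip_shapes.append(tuple(output.shape))


handles = [m.register_forward_hook(record_shape)
           for m in encoder.modules() if isinstance(m, nn.Conv2d)]

with torch.no_grad():
    encoder(torch.zeros(1, 3, 256, 256))

for h in handles:
    h.remove()  # hooks should not outlive the inspection

print(skip_shapes)  # [(1, 16, 128, 128), (1, 32, 64, 64), (1, 64, 32, 32)]
```

Each stride-2 convolution halves the resolution, which gives the decoder three scales to reconnect to.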


add_xaae_blocks


def add_xaae_blocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
)->EncoderWithAAEBlocks:

Add xAAEnet analysis blocks to a user encoder.


EncoderWithAAEBlocks


def EncoderWithAAEBlocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
):

User encoder extended with xAAEnet analysis blocks.

The user encoder stays the feature extractor. The added blocks project its features into z, predict from z, regularize z, and reconstruct the input with a symmetric U-Net decoder.


set_module_trainable


def set_module_trainable(
    module:Module, trainable:bool
)->Module:

Enable or disable gradient updates for every parameter in a module.
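The sketch below shows what "enable or disable gradient updates" amounts to in PyTorch. It is a hypothetical re-implementation named `set_module_trainable_sketch`, not the module's source. It assumes freeze_encoder=True does the equivalent for the user encoder.

```python
from torch import nn


def set_module_trainable_sketch(module: nn.Module, trainable: bool) -> nn.Module:
    """Hypothetical re-implementation: toggle requires_grad on every parameter."""
    for param in module.parameters():
        param.requires_grad_(trainable)
    return module  # returned so calls can be chained


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
frozen = count_trainable(set_module_trainable_sketch(encoder, False))    # 0
unfrozen = count_trainable(set_module_trainable_sketch(encoder, True))  # 808
```

Setting requires_grad to False only stops gradients from being computed for those parameters. Activations still flow through the frozen encoder into the projection, head, discriminator, and decoder, which remain trainable.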

Practical Use Case

The user provides the encoder. add_xaae_blocks adds the xAAEnet analysis blocks and automatically builds the matching U-Net decoder through DynamicUnetSkipDropout.

The forward pass returns the class logits (named labels in the example below), while the model keeps z, gan_fake, gan_real, and decoder_output as attributes for training and analysis.

import torch
from torch import nn

# add_xaae_blocks is defined above in this module.
user_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

model = add_xaae_blocks(
    encoder=user_encoder,
    input_size=256,
    input_channels=3,
    encoding_dims=16,
    classes=2,
)

# MS-SSIM (used in `classif_loss_func` / `denoising_ae_loss_func`) needs spatial size > ~160.
batch = torch.randn(2, 3, 256, 256)
labels = model(batch)
labels.shape, model.z.shape, model.decoder_output.shape

Training Loss

The loss methods use the same names as those of AAE in 01_model_aae.ipynb:

  • classif_loss_func(output, target, ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, **kwargs) combines an MS-SSIM + L1 reconstruction loss between decoder_output and input_image, the latent GAN term (gan_fake / gan_real), and cross-entropy on the logits.
  • aae_loss_func(output, target) returns only the adversarial latent loss (same alternating generator / discriminator logic as AAE).
  • denoising_ae_loss_func(clean_xb, RECONS_WEIGHT, ADV_WEIGHT, pred, yb) is for denoising-style pretraining: reconstruction vs clean_xb plus the GAN term; pred and yb are ignored but kept for a fastai-compatible signature.
  • pure_classif_loss_func(pred, target, **kwargs) is plain cross-entropy when you do not need reconstruction or the GAN term.
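The sketch below shows how the weighted terms listed above fit together. It assumes that gan_fake and gan_real are discriminator logits and that gan_real comes from prior samples. It also assumes the generator step pushes encoder codes toward "real" while the discriminator step separates the two.

The following are hypothetical and do not belong to the module's API:

- the functions `latent_gan_loss` and `combined_loss_sketch`;
- the `generator_step` flag.

The real classif_loss_func reads decoder_output and input_image from the model and adds MS-SSIM. This sketch passes the tensors explicitly and keeps only the L1 part of the reconstruction.

```python
import torch
import torch.nn.functional as F


def latent_gan_loss(gan_fake, gan_real, generator_step: bool):
    """Hypothetical AAE-style latent adversarial loss on discriminator logits."""
    if generator_step:
        # Encoder (generator) tries to make its codes look like prior samples.
        return F.binary_cross_entropy_with_logits(gan_fake, torch.ones_like(gan_fake))
    # Discriminator separates prior samples (real) from encoder codes (fake).
    real_loss = F.binary_cross_entropy_with_logits(gan_real, torch.ones_like(gan_real))
    fake_loss = F.binary_cross_entropy_with_logits(gan_fake, torch.zeros_like(gan_fake))
    return 0.5 * (real_loss + fake_loss)


def combined_loss_sketch(logits, target, decoder_output, input_image, gan_fake, gan_real,
                         ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, generator_step=True):
    """Weighted sum in the spirit of classif_loss_func; MS-SSIM is left out here."""
    recons = F.l1_loss(decoder_output, input_image)
    adv = latent_gan_loss(gan_fake, gan_real, generator_step)
    classif = F.cross_entropy(logits, target)
    return ADV_WEIGHT * adv + RECONS_WEIGHT * recons + CLASS_WEIGHT * classif


torch.manual_seed(0)
image = torch.rand(2, 3, 64, 64)
example = dict(
    logits=torch.randn(2, 2), target=torch.tensor([0, 1]),
    decoder_output=image.clone(), input_image=image,  # perfect reconstruction
    gan_fake=torch.randn(2, 1), gan_real=torch.randn(2, 1),
)
loss = combined_loss_sketch(**example, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0)
```

Like classif_loss and aae_loss in the example below, the result is a scalar tensor. The three weights trade off latent regularization, reconstruction quality, and classification accuracy.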

When a user encoder is supplied, encoder_feature_shape is still inferred automatically, so the U-Net and the bottleneck stay consistent with the spatial encoder.
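For the example encoder in Practical Use Case, the inferred shape can be checked by hand with the output-size formula from the nn.Conv2d documentation. The helpers `conv2d_out` and `infer_feature_shape` are hypothetical and handle only plain square convolutions. The module presumably infers the shape some other way, for instance with a dummy forward pass (an assumption).

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    # Output-size formula from the PyTorch nn.Conv2d documentation.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def infer_feature_shape(input_size, layers):
    """layers: list of (out_channels, kernel, stride, padding) for each Conv2d."""
    channels, size = None, input_size
    for out_ch, k, s, p in layers:
        size = conv2d_out(size, k, s, p)
        channels = out_ch
    return (channels, size, size)


example = [(16, 3, 2, 1), (32, 3, 2, 1), (64, 3, 2, 1)]
feature_shape = infer_feature_shape(256, example)  # (64, 32, 32)
```

Each stride-2 layer with padding 1 halves the side: 256 → 128 → 64 → 32. Flattening (64, 32, 32) gives 65536 features for the projection into z.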

Because reconstruction uses MS-SSIM, training images must be wider and taller than about 160 pixels. The examples here use 256×256 inputs so that the nbdev tests and classif_loss_func run without assertion errors from pytorch_msssim.
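The 160-pixel figure comes from the size check in pytorch_msssim's ms_ssim. With the default 11-pixel Gaussian window and four downsamplings, the smaller image side must exceed (11 − 1) × 2⁴ = 160. The hypothetical helpers below reproduce that arithmetic. They assume xAAEnet calls ms_ssim with the default window unless a different win_size is passed.

```python
def msssim_min_side(win_size: int = 11) -> int:
    """Smallest image side that passes pytorch_msssim's check (which is strict)."""
    # The check requires min(H, W) > (win_size - 1) * 2 ** 4 (four downsamplings).
    return (win_size - 1) * 2 ** 4 + 1


def msssim_ok(height: int, width: int, win_size: int = 11) -> bool:
    return min(height, width) >= msssim_min_side(win_size)
```

With the defaults, msssim_min_side() returns 161, so 256×256 passes and 160×160 does not. A smaller window lowers the bound, at the cost of a less smooth similarity estimate.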

target_labels = torch.tensor([0, 1])

classif_loss = model.classif_loss_func(
    labels, target_labels, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0
)

aae_loss = model.aae_loss_func(labels, target_labels)

classif_loss.shape, aae_loss.shape

Encoder Contract

The added blocks stay intentionally close to AAE: they expect an encoder that can be passed to DynamicUnetSkipDropout. In practice, that means a spatial encoder whose intermediate feature maps can be hooked for U-Net skip connections.

If a project uses a transformer, the user should wrap it so the xAAEnet blocks receive spatial feature maps. The notebook keeps that adapter outside this module because it depends on the transformer’s architecture.
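A minimal sketch of such an adapter follows. It assumes a ViT-style backbone that returns (batch, n_tokens, dim) with an optional leading class token and patch tokens on a square grid.

Both class names are hypothetical. `ToyPatchTokens` is a stand-in backbone with no attention layers. The sketch only covers the token-to-map reshape. A real wrapper would also have to expose several resolutions for the U-Net skip connections, which a plain ViT does not provide on its own.

```python
import math

import torch
from torch import nn


class TokensToFeatureMap(nn.Module):
    """Hypothetical adapter: reshapes ViT-style patch tokens into a spatial map."""

    def __init__(self, backbone: nn.Module, has_cls_token: bool = True):
        super().__init__()
        self.backbone = backbone
        self.has_cls_token = has_cls_token

    def forward(self, x):
        tokens = self.backbone(x)      # assumed shape: (batch, n_tokens, dim)
        if self.has_cls_token:
            tokens = tokens[:, 1:]     # drop the class token
        b, n, d = tokens.shape
        side = math.isqrt(n)
        assert side * side == n, "patch tokens must form a square grid"
        # (batch, n, dim) -> (batch, dim, side, side), the layout conv decoders expect
        return tokens.transpose(1, 2).reshape(b, d, side, side)


class ToyPatchTokens(nn.Module):
    """Stand-in for a transformer: patchify and prepend a class token (no attention)."""

    def __init__(self, dim=32, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        t = self.proj(x).flatten(2).transpose(1, 2)  # (batch, n_patches, dim)
        return torch.cat([self.cls.expand(x.shape[0], -1, -1), t], dim=1)


adapter = TokensToFeatureMap(ToyPatchTokens())
fmap = adapter(torch.randn(2, 3, 256, 256))  # 256 patch tokens -> 16 x 16 grid
```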