Add xAAEnet Blocks to Any User Encoder

Add the xAAEnet blocks needed to explain a user-provided model.

The goal of this module is to start from a user’s own encoder and add the xAAEnet blocks defined previously. We do not replace the user’s model; we add the components needed to analyze and explain it:

  1. keep the user encoder as the feature extractor;
  2. project encoder features into the latent space z;
  3. add a label head from z;
  4. add the latent discriminator that regularizes z;
  5. build a symmetric U-Net decoder with DynamicUnetSkipDropout from the user encoder and its skip connections.

The encoder must therefore output spatial feature maps (B, C, H, W) that the U-Net hooks can attach to. Transformer encoders, which usually return token sequences, should first be wrapped so that they return spatial feature maps before these xAAEnet blocks are added.
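One way to write such a wrapper for a ViT-style encoder is sketched below. It assumes the wrapped encoder returns patch tokens of shape (B, N, C), optionally preceded by one class token, laid out row-major on a square grid; TokensToFeatureMap, token_encoder and grid_size are hypothetical names. The wrapper only fixes the output shape: whether DynamicUnetSkipDropout can find useful intermediate skip connections inside it depends on where that class places its hooks, which this section does not specify.

```python
import torch
from torch import nn


class TokensToFeatureMap(nn.Module):
    """Hypothetical wrapper: turns (B, N, C) patch tokens into a (B, C, H, W) feature map."""

    def __init__(self, token_encoder: nn.Module, grid_size: int, has_cls_token: bool = True):
        super().__init__()
        self.token_encoder = token_encoder
        self.grid_size = grid_size          # patches per side, e.g. 224 // 16 = 14
        self.has_cls_token = has_cls_token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.token_encoder(x)      # (B, N, C)
        if self.has_cls_token:
            tokens = tokens[:, 1:]          # drop the leading class token
        b, n, c = tokens.shape
        h = w = self.grid_size
        assert n == h * w, "token count must equal grid_size ** 2"
        # (B, N, C) -> (B, C, N) -> (B, C, H, W), assuming row-major patch order
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```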


add_xaae_blocks


def add_xaae_blocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
)->EncoderWithAAEBlocks:

Add xAAEnet analysis blocks to a user encoder.


EncoderWithAAEBlocks


def EncoderWithAAEBlocks(
    encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
    linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
    freeze_encoder:bool=False
):

User encoder extended with xAAEnet analysis blocks.

The user encoder stays the feature extractor. The added blocks project its features into z, predict from z, regularize z, and reconstruct the input with a symmetric U-Net decoder.
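A much-simplified sketch of the first added blocks follows, under stated assumptions: ToyXAAEHeads is a hypothetical name, a plain nn.Linear on the flattened feature map stands in for the projection to z, a small MLP stands in for the latent discriminator, and the U-Net decoder is omitted. The actual layers inside EncoderWithAAEBlocks may differ.

```python
import torch
from torch import nn


class ToyXAAEHeads(nn.Module):
    """Hypothetical, simplified stand-in for the added blocks (not the library code)."""

    def __init__(self, encoder: nn.Module, feat_dim: int, encoding_dims: int = 16, classes: int = 2):
        super().__init__()
        self.encoder = encoder                                  # user encoder kept as feature extractor
        self.to_z = nn.Linear(feat_dim, encoding_dims)          # project features into z
        self.label_head = nn.Linear(encoding_dims, classes)     # predict labels from z
        self.discriminator = nn.Sequential(                     # latent discriminator that scores z
            nn.Linear(encoding_dims, 32), nn.ReLU(), nn.Linear(32, 1)
        )
        # The symmetric U-Net decoder built from the encoder skips is omitted here.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(x)                 # (B, C, H, W) spatial feature maps
        self.z = self.to_z(feats.flatten(1))    # (B, encoding_dims), stored like model.z
        return self.label_head(self.z)          # (B, classes) logits
```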


set_module_trainable


def set_module_trainable(
    module:Module, trainable:bool
)->Module:

Enable or disable gradient updates for every parameter in a module.
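Assuming set_module_trainable simply toggles requires_grad, an equivalent sketch is shown below (set_module_trainable_sketch is a hypothetical name). It is presumably what freeze_encoder=True applies to the user encoder, although this section does not say so explicitly.

```python
from torch import nn


def set_module_trainable_sketch(module: nn.Module, trainable: bool) -> nn.Module:
    """Sketch: flip requires_grad on every parameter and return the same module."""
    for param in module.parameters():
        param.requires_grad_(trainable)
    return module
```

Note that requires_grad does not stop BatchNorm layers from updating their running statistics; if those should stay fixed as well, the module also needs to be put in eval() mode.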

Practical Use Case

The user provides the encoder. add_xaae_blocks adds the xAAEnet analysis blocks and automatically builds the matching U-Net decoder through DynamicUnetSkipDropout.

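The forward pass returns the label logits; the latent code and the reconstruction are kept on the model as model.z and model.decoder_output.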

import torch
from torch import nn

user_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)

model = add_xaae_blocks(
    encoder=user_encoder,
    input_size=256,
    input_channels=3,
    encoding_dims=16,
    classes=2,
)

# MS-SSIM (used in `classif_loss_func` / `denoising_ae_loss_func`) needs spatial size > ~160.
batch = torch.randn(2, 3, 256, 256)
labels = model(batch)
labels.shape, model.z.shape, model.decoder_output.shape
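To check which feature map user_encoder hands to the added blocks, PyTorch's nn.Conv2d output-size formula can be applied to each of its three stride-2 convolutions (conv2d_out is a hypothetical helper). For a 256×256 input the encoder output is 64×32×32, an 8× spatial reduction that the symmetric decoder has to undo to produce a 256×256 decoder_output.

```python
def conv2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension for nn.Conv2d (PyTorch's documented formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


size = 256
for _ in range(3):  # the three Conv2d(kernel_size=3, stride=2, padding=1) layers in user_encoder
    size = conv2d_out(size, kernel_size=3, stride=2, padding=1)
print(size)  # 32, so the encoder output is (64, 32, 32) per sample
```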

Training Loss

Defined Losses:

  • denoising_ae_loss_func(clean_xb, RECONS_WEIGHT, ADV_WEIGHT, pred, yb) is meant for denoising-style pretraining: it compares the reconstruction against clean_xb and adds the GAN term; pred and yb are ignored but kept for a fastai-compatible signature.
  • aae_loss_func(output, target) returns only the adversarial latent loss, with the same alternating generator/discriminator logic as the AAE.
  • classif_loss_func(output, target, ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, **kwargs) combines an MS-SSIM + L1 reconstruction loss between decoder_output and input_image, the latent GAN term (gan_fake / gan_real), and cross-entropy on the logits.
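A minimal sketch of how the three terms of classif_loss_func could be combined, assuming a plain weighted sum: the list above names the terms and weights but not the exact formula. combined_loss_sketch and msssim_term are hypothetical; the MS-SSIM part is passed in precomputed (for example 1 - ms_ssim(...)) so the block depends only on PyTorch.

```python
import torch
import torch.nn.functional as F


def combined_loss_sketch(recon, image, adv_term, logits, target,
                         ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, msssim_term=None):
    """Hypothetical weighted sum; the real classif_loss_func may combine terms differently."""
    recons = F.l1_loss(recon, image)                # L1 part of the reconstruction loss
    if msssim_term is not None:                     # optional precomputed MS-SSIM loss
        recons = recons + msssim_term
    classif = F.cross_entropy(logits, target)       # cross-entropy on the label logits
    return RECONS_WEIGHT * recons + ADV_WEIGHT * adv_term + CLASS_WEIGHT * classif
```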

When starting from a user encoder, encoder_feature_shape is still inferred automatically, so the U-Net decoder and the bottleneck stay consistent with the spatial output of that encoder.
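A plausible way to perform this inference is a single dummy forward pass, sketched below; infer_feature_shape is a hypothetical helper, and the library may do it differently (for example through hooks, as fastai's DynamicUnet does with model_sizes).

```python
import torch
from torch import nn


@torch.no_grad()
def infer_feature_shape(encoder: nn.Module, input_channels: int, input_size: int) -> tuple:
    """Hypothetical helper: run one dummy batch through the encoder and return (C, H, W)."""
    was_training = encoder.training
    encoder.eval()  # avoid updating BatchNorm statistics during the probe
    try:
        # Put the dummy input on the encoder's device (CPU if it has no parameters).
        device = next(encoder.parameters(), torch.empty(0)).device
        dummy = torch.zeros(1, input_channels, input_size, input_size, device=device)
        return tuple(encoder(dummy).shape[1:])  # drop the batch dimension
    finally:
        encoder.train(was_training)  # restore the original mode
```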

Because the reconstruction loss uses MS-SSIM, both the height and the width of training images must exceed about 160 pixels; the example uses 256×256 inputs so that the nbdev tests and classif_loss_func run without assertion errors from pytorch_msssim.
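The 160-pixel limit comes from pytorch_msssim: its ms_ssim function asserts that the smaller image side is larger than (win_size - 1) * 2**4, which is 160 for the default win_size=11, because of the four downsamplings between its five scales. The hypothetical helpers below reproduce that check so a dataset can be validated before training starts.

```python
def msssim_min_side(win_size: int = 11, downsamplings: int = 4) -> int:
    """Value the smaller image side must strictly exceed for pytorch_msssim's ms_ssim."""
    return (win_size - 1) * 2 ** downsamplings


def fits_msssim(height: int, width: int, win_size: int = 11) -> bool:
    """True if an image of this size passes the ms_ssim size assertion."""
    return min(height, width) > msssim_min_side(win_size)


print(msssim_min_side())  # 160
```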

target_labels = torch.tensor([0, 1])

classif_loss = model.classif_loss_func(
    labels, target_labels, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0
)

aae_loss = model.aae_loss_func(labels, target_labels)

classif_loss.shape, aae_loss.shape
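aae_loss_func is described as using the usual alternating AAE objective. A sketch of that objective follows, assuming the discriminator outputs logits, that gan_real holds its scores on samples from the prior and gan_fake its scores on the encoded z, and that gen_train selects the generator step; these mappings and the name latent_adversarial_loss are assumptions, not the library's code.

```python
import torch
import torch.nn.functional as F


def latent_adversarial_loss(d_real: torch.Tensor, d_fake: torch.Tensor, gen_train: bool) -> torch.Tensor:
    """Standard AAE latent GAN losses on discriminator logits (sketch).

    d_real: discriminator logits for samples drawn from the prior p(z)
    d_fake: discriminator logits for encoder outputs z
    """
    if gen_train:
        # Generator (encoder) step: push encoded z to be scored as "real".
        return F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
    # Discriminator step: prior samples -> 1, encoded z -> 0.
    real = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
    fake = F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    return 0.5 * (real + fake)
```

In a training loop, one would typically alternate a discriminator update with gen_train=False (prior samples such as torch.randn_like(z) against a detached z) and an encoder update with gen_train=True.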