import torch
from torch import nn

# A plain spatial CNN encoder: three stride-2 convolutions take 256 -> 32 pixels per side.
user_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)
# Add the xAAEnet blocks: a 16-dimensional latent z and two output classes.
model = add_xaae_blocks(
    encoder=user_encoder,
    input_size=256,
    input_channels=3,
    encoding_dims=16,
    classes=2,
)
# MS-SSIM (used in `classif_loss_func` / `denoising_ae_loss_func`) needs spatial size > ~160.
batch = torch.randn(2, 3, 256, 256)
labels = model(batch)  # class logits; z and decoder_output are kept as model attributes
labels.shape, model.z.shape, model.decoder_output.shape
Add xAAEnet Blocks to any User Encoder
The goal of this module is to start from the user’s own encoder and add the xAAEnet blocks defined in 01_model_aae.ipynb. We do not replace the user’s model design; we add the components needed to analyze and explain it:
- keep the user encoder as the feature extractor;
- project encoder features into the latent space z;
- add a label head from z;
- add the latent discriminator that regularizes z;
- build a symmetric U-Net decoder with DynamicUnetSkipDropout from the user encoder and its skip connections.
The encoder must therefore expose spatial feature maps that can be used by the U-Net hooks. For transformer models, the user should provide a wrapper that exposes a spatial feature-map encoder before adding these xAAEnet blocks.
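To make the list of added blocks concrete, here is a minimal PyTorch sketch of the latent part only: a projection from spatial encoder features to z, a label head on z, and a latent discriminator. `LatentHeadsSketch`, its layer sizes, and the standard Gaussian prior used for `gan_real` are illustrative assumptions; the actual blocks, and the U-Net decoder omitted here, are built by `add_xaae_blocks`.

```python
import torch
from torch import nn


class LatentHeadsSketch(nn.Module):
    """Hypothetical sketch of the latent blocks added on top of a spatial encoder.

    Not the library's EncoderWithAAEBlocks: names and layer choices are assumptions.
    """

    def __init__(self, encoder: nn.Module, feat_shape: tuple, encoding_dims: int, classes: int):
        super().__init__()
        self.encoder = encoder
        c, h, w = feat_shape
        # Project flattened encoder feature maps into the latent vector z
        self.to_z = nn.Sequential(nn.Flatten(), nn.Linear(c * h * w, encoding_dims))
        # Label head that predicts classes from z only
        self.label_head = nn.Linear(encoding_dims, classes)
        # Latent discriminator: one logit saying "this code looks like a prior sample"
        self.discriminator = nn.Sequential(
            nn.Linear(encoding_dims, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(x)        # spatial feature maps (B, C, H, W)
        self.z = self.to_z(feats)      # latent code (B, encoding_dims)
        # Keep discriminator scores as attributes instead of returning them
        self.gan_fake = self.discriminator(self.z)
        self.gan_real = self.discriminator(torch.randn_like(self.z))  # assumed N(0, I) prior
        return self.label_head(self.z)  # class logits


encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)
# 64x64 inputs -> 16x16 feature maps with 32 channels after two stride-2 convolutions
heads = LatentHeadsSketch(encoder, feat_shape=(32, 16, 16), encoding_dims=8, classes=2)
logits = heads(torch.randn(2, 3, 64, 64))
```

Returning only the logits keeps the model usable with ordinary classification loops, while the cached attributes remain available to the losses.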
add_xaae_blocks
def add_xaae_blocks(
encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
freeze_encoder:bool=False
)->EncoderWithAAEBlocks:
Add xAAEnet analysis blocks to a user encoder.
EncoderWithAAEBlocks
def EncoderWithAAEBlocks(
encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
freeze_encoder:bool=False
):
User encoder extended with xAAEnet analysis blocks.
The user encoder stays the feature extractor. The added blocks project its features into z, predict from z, regularize z, and reconstruct the input with a symmetric U-Net decoder.
set_module_trainable
def set_module_trainable(
module:Module, trainable:bool
)->Module:
Enable or disable gradient updates for every parameter in a module.
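The documented behavior of `set_module_trainable` fits in a few lines. The following is a hypothetical re-implementation (`set_module_trainable_sketch`) written from the docstring, not the module's source:

```python
from torch import nn


def set_module_trainable_sketch(module: nn.Module, trainable: bool) -> nn.Module:
    """Hypothetical re-implementation: toggle requires_grad on every parameter."""
    for param in module.parameters():
        param.requires_grad_(trainable)
    return module  # returning the module allows chaining, as the signature suggests


encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
set_module_trainable_sketch(encoder, False)
frozen = all(not p.requires_grad for p in encoder.parameters())
```

Turning off `requires_grad` only stops gradient updates. BatchNorm layers left in training mode still update their running statistics, so a fully frozen encoder (for example with `freeze_encoder=True`) may also need those layers put in eval mode.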
Practical Use Case
The user provides the encoder. add_xaae_blocks adds the xAAEnet analysis blocks and automatically builds the matching U-Net decoder through DynamicUnetSkipDropout.
The forward pass returns labels, while the model keeps z, gan_fake, gan_real, and decoder_output as attributes for training and analysis.
Training Loss
Loss names match AAE in 01_model_aae.ipynb:
- classif_loss_func(output, target, ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, **kwargs) combines MS-SSIM + L1 reconstruction on decoder_output vs input_image, the latent GAN term (gan_fake / gan_real), and cross-entropy on the logits.
- aae_loss_func(output, target) returns only the adversarial latent loss (same alternating generator / discriminator logic as AAE).
- denoising_ae_loss_func(clean_xb, RECONS_WEIGHT, ADV_WEIGHT, pred, yb) is for denoising-style pretraining: reconstruction vs clean_xb plus the GAN term; pred and yb are ignored but kept for a fastai-compatible signature.
- pure_classif_loss_func(pred, target, **kwargs) is plain cross-entropy, for when you need neither reconstruction nor the GAN term.
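The latent GAN term follows the usual adversarial-autoencoder idea: a discriminator separates prior samples from encoder codes, and the encoder is trained to fool it. The sketch below is a generic version using binary cross-entropy on discriminator logits; `latent_gan_losses` is a hypothetical name, and whether `aae_loss_func` uses BCE, this label convention, or this split into generator and discriminator steps is an assumption this section does not confirm.

```python
import torch
import torch.nn.functional as F


def latent_gan_losses(gan_fake: torch.Tensor, gan_real: torch.Tensor):
    """Generic AAE latent adversarial losses from discriminator logits (sketch).

    gan_fake: discriminator logits for encoder codes z
    gan_real: discriminator logits for samples drawn from the prior
    """
    # Generator (encoder) step: encoder codes should be scored as "real" (1)
    gen_loss = F.binary_cross_entropy_with_logits(gan_fake, torch.ones_like(gan_fake))
    # Discriminator step: prior samples are "real" (1), encoder codes are "fake" (0).
    # In an alternating loop, gan_fake for this step would come from z.detach()
    # so the discriminator update does not push gradients into the encoder.
    disc_loss = (
        F.binary_cross_entropy_with_logits(gan_real, torch.ones_like(gan_real))
        + F.binary_cross_entropy_with_logits(gan_fake, torch.zeros_like(gan_fake))
    )
    return gen_loss, disc_loss


# An undecided discriminator (all logits 0) gives log(2) per BCE term.
gen_loss, disc_loss = latent_gan_losses(torch.zeros(4, 1), torch.zeros(4, 1))
```

How classif_loss_func scales this term with ADV_WEIGHT relative to the reconstruction and classification terms is not shown here.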
When a user encoder is supplied, encoder_feature_shape is still inferred automatically, so the U-Net decoder and the latent bottleneck stay consistent with your spatial encoder.
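A common way to infer a shape like `encoder_feature_shape` is a single dummy forward pass; whether this module does exactly that is an assumption. The hypothetical helper `infer_feature_shape` below assumes square inputs of side `input_size`. For the example encoder, each `Conv2d(kernel_size=3, stride=2, padding=1)` halves the side, so 256 becomes 128, 64, then 32.

```python
import torch
from torch import nn


def infer_feature_shape(encoder: nn.Module, input_channels: int, input_size: int) -> tuple:
    """Hypothetical helper: run one dummy batch to read the encoder output shape (C, H, W)."""
    was_training = encoder.training
    encoder.eval()  # avoid updating BatchNorm statistics with the dummy batch
    with torch.no_grad():
        out = encoder(torch.zeros(1, input_channels, input_size, input_size))
    encoder.train(was_training)  # restore the caller's mode
    return tuple(out.shape[1:])


user_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)
shape = infer_feature_shape(user_encoder, input_channels=3, input_size=256)
```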
Because reconstruction uses MS-SSIM, training images should be wider and taller than about 160 pixels; the examples in this section use 256×256 inputs so nbdev tests and classif_loss_func run without assertion errors from pytorch_msssim.
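The 160-pixel threshold follows from the multi-scale structure of MS-SSIM. Assuming pytorch_msssim's defaults (`win_size=11` and five scale weights, hence four 2x downsamplings), the smaller image side must exceed `(win_size - 1) * 2**4 = 160`. This pure-Python check, with hypothetical helper names, reproduces that rule:

```python
def msssim_min_side(win_size: int = 11, n_scales: int = 5) -> int:
    """Exclusive lower bound on the smaller image side for multi-scale SSIM.

    Each of the n_scales - 1 downsamplings halves the image, and the coarsest
    scale still has to accommodate the Gaussian window.
    """
    return (win_size - 1) * 2 ** (n_scales - 1)


def msssim_size_ok(height: int, width: int, win_size: int = 11, n_scales: int = 5) -> bool:
    """True if an image of this size passes the MS-SSIM size check."""
    return min(height, width) > msssim_min_side(win_size, n_scales)
```

Because the bound is strict, a 160×160 image fails while 256×256 passes. A smaller window lowers the bound, for example to 96 pixels with `win_size=7`.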
target_labels = torch.tensor([0, 1])
classif_loss = model.classif_loss_func(
    labels, target_labels, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0
)
aae_loss = model.aae_loss_func(labels, target_labels)
classif_loss.shape, aae_loss.shape
Encoder Contract
The added blocks stay intentionally close to AAE: they expect an encoder that can be passed to DynamicUnetSkipDropout. In practice, that means a spatial encoder whose intermediate feature maps can be hooked for U-Net skip connections.
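Hookable here means that forward hooks on intermediate layers see spatial tensors. The sketch below records each layer's output shape on the example encoder and lists the layers just before the spatial size drops, which are natural skip-connection sources. It resembles how fastai's DynamicUnet picks skip layers only in spirit; `collect_feature_shapes` and `size_change_indices` are hypothetical helpers.

```python
import torch
from torch import nn


def collect_feature_shapes(encoder: nn.Sequential, x: torch.Tensor) -> list:
    """Record the output shape of every child layer via forward hooks (sketch)."""
    shapes = []
    handles = [
        layer.register_forward_hook(
            lambda module, inputs, output: shapes.append(tuple(output.shape))
        )
        for layer in encoder
    ]
    try:
        with torch.no_grad():
            encoder(x)
    finally:
        for handle in handles:
            handle.remove()  # detach hooks so later forward passes are unaffected
    return shapes


def size_change_indices(shapes: list) -> list:
    """Indices of layers whose next layer has a different spatial size (H, W)."""
    return [i for i in range(len(shapes) - 1) if shapes[i][-2:] != shapes[i + 1][-2:]]


encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)
shapes = collect_feature_shapes(encoder, torch.zeros(1, 3, 256, 256))
skip_idxs = size_change_indices(shapes)  # ReLUs at 128x128 and 64x64 resolution
```

An encoder that flattens its features or returns token sequences gives these hooks nothing spatial to work with, which is why the contract requires feature maps.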
If a project uses a transformer, the user should wrap it so the xAAEnet blocks receive spatial feature maps. The notebook keeps that adapter outside this module because it depends on the transformer’s architecture.
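As an illustration of such a wrapper, assume a ViT-style backbone that returns tokens `(B, N, C)` with one leading class token over a square patch grid; the tokens can then be reshaped into a `(B, C, H, W)` feature map. `TokensToFeatureMap` and `ToyPatchTokens` are hypothetical stand-ins, not part of this module.

```python
import math

import torch
from torch import nn


class ToyPatchTokens(nn.Module):
    """Stand-in for a transformer: patch embedding plus a learnable class token."""

    def __init__(self, dim: int = 32, patch: int = 16):
        super().__init__()
        self.embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.embed(x).flatten(2).transpose(1, 2)  # (B, N, dim), row-major patches
        cls = self.cls.expand(x.shape[0], -1, -1)
        return torch.cat([cls, tokens], dim=1)              # (B, 1 + N, dim)


class TokensToFeatureMap(nn.Module):
    """Hypothetical adapter: reshape (B, N, C) tokens into a (B, C, H, W) map."""

    def __init__(self, backbone: nn.Module, has_cls_token: bool = True):
        super().__init__()
        self.backbone = backbone
        self.has_cls_token = has_cls_token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.backbone(x)
        if self.has_cls_token:
            tokens = tokens[:, 1:]  # drop the class token, keep patch tokens
        b, n, c = tokens.shape
        side = math.isqrt(n)
        if side * side != n:
            raise ValueError(f"{n} patch tokens do not form a square grid")
        # (B, N, C) -> (B, C, N) -> (B, C, side, side), undoing the row-major flatten
        return tokens.transpose(1, 2).reshape(b, c, side, side)


adapter = TokensToFeatureMap(ToyPatchTokens(dim=32, patch=16), has_cls_token=True)
out = adapter(torch.randn(2, 3, 256, 256))
```

This adapter exposes a single resolution. A U-Net decoder is likely to benefit from skip connections only if the wrapper also exposes intermediate spatial stages, for example by reshaping outputs of several transformer blocks.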