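import torch
from torch import nn

# A small convolutional encoder: three stride-2 convolutions, each halving the spatial size.
user_encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)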
model = add_xaae_blocks(
    encoder=user_encoder,
    input_size=256,
    input_channels=3,
    encoding_dims=16,
    classes=2,
)
# MS-SSIM (used in `classif_loss_func` / `denoising_ae_loss_func`) needs spatial size > ~160.
batch = torch.randn(2, 3, 256, 256)
labels = model(batch)
labels.shape, model.z.shape, model.decoder_output.shape

Add xAAEnet Blocks to any User Encoder
The goal of this module is to take a user's own encoder and add the xAAEnet blocks defined previously. The user's model is not replaced; the module adds the components needed to analyze and explain it:
- keep the user encoder as the feature extractor;
- project encoder features into the latent space z;
- add a label head from z;
- add the latent discriminator that regularizes z;
- build a symmetric U-Net decoder with DynamicUnetSkipDropout from the user encoder and its skip connections.
The encoder must therefore expose spatial feature maps that can be used by the U-Net hooks. For transformer models, the user should provide a wrapper that exposes a spatial feature-map encoder before adding these xAAEnet blocks.
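As an illustration of the wrapper idea for transformer models, here is a minimal sketch that turns a token sequence back into a spatial feature map. The class name TokensToFeatureMap and all of its arguments are hypothetical and not part of this library. It only shows the reshape step: it uses no positional embedding and exposes a single resolution, so whether the U-Net hooks can find usable skip connections on such a wrapper is not guaranteed.

```python
import torch
from torch import nn


class TokensToFeatureMap(nn.Module):
    """Hypothetical wrapper: image -> patch tokens -> transformer -> (B, C, H, W) map."""

    def __init__(self, in_channels=3, dim=32, patch=16, depth=1, heads=4):
        super().__init__()
        # Non-overlapping patches embedded with a strided convolution.
        self.patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch, stride=patch)
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=2 * dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):
        feat = self.patch_embed(x)                 # (B, C, H/patch, W/patch)
        b, c, h, w = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)   # (B, H*W, C) token sequence
        tokens = self.transformer(tokens)          # shape unchanged
        # Put the tokens back on the patch grid so downstream code sees a feature map.
        return tokens.transpose(1, 2).reshape(b, c, h, w)


wrapped = TokensToFeatureMap().eval()
with torch.no_grad():
    fmap = wrapped(torch.randn(2, 3, 64, 64))
print(fmap.shape)  # torch.Size([2, 32, 4, 4])
```

A wrapper meant for real use would likely need to expose several resolutions, since a symmetric U-Net decoder relies on skip connections at each scale.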
add_xaae_blocks
def add_xaae_blocks(
encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
freeze_encoder:bool=False
)->EncoderWithAAEBlocks:
Add xAAEnet analysis blocks to a user encoder.
EncoderWithAAEBlocks
def EncoderWithAAEBlocks(
encoder:Module, input_size:int=256, input_channels:int=3, encoding_dims:int=128, classes:int=2,
linear:torch.nn.modules.module.Module | None=None, gen_train:bool=True, skip_dropout:float=1.0,
freeze_encoder:bool=False
):
User encoder extended with xAAEnet analysis blocks.
The user encoder stays the feature extractor. The added blocks project its features into z, predict from z, regularize z, and reconstruct the input with a symmetric U-Net decoder.
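The roles of the added blocks can be pictured with a small stand-alone sketch. The layer choices, sizes, and variable names below are assumptions made for illustration and may not match the layers EncoderWithAAEBlocks actually builds. The Gaussian prior used for the "real" latent samples is likewise an assumption, taken from the usual adversarial autoencoder setup.

```python
import torch
from torch import nn

# Illustrative sizes only; not the library defaults.
feat_channels, feat_size, encoding_dims, classes = 64, 8, 16, 2

# Project encoder features into the latent space z.
to_latent = nn.Sequential(
    nn.Flatten(),
    nn.Linear(feat_channels * feat_size * feat_size, encoding_dims),
)
# Predict labels from z.
label_head = nn.Linear(encoding_dims, classes)
# Discriminator that scores latent codes, used to regularize z.
latent_discriminator = nn.Sequential(
    nn.Linear(encoding_dims, 32), nn.ReLU(), nn.Linear(32, 1)
)

features = torch.randn(4, feat_channels, feat_size, feat_size)  # stand-in for encoder output
z = to_latent(features)
logits = label_head(z)
d_fake = latent_discriminator(z)                     # scores for encoded samples
d_real = latent_discriminator(torch.randn_like(z))   # scores for samples from the assumed prior
print(z.shape, logits.shape, d_fake.shape)
```

The reconstruction branch is left out here because it depends on the encoder's intermediate feature maps, which the U-Net decoder reads through hooks.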
set_module_trainable
def set_module_trainable(
module:Module, trainable:bool
)->Module:
Enable or disable gradient updates for every parameter in a module.
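A helper with this behavior can be written in a few lines. The sketch below is an assumed equivalent, named set_trainable_sketch to make clear it is not the library's own code; it toggles requires_grad on every parameter and returns the module.

```python
from torch import nn


def set_trainable_sketch(module: nn.Module, trainable: bool) -> nn.Module:
    """Assumed equivalent of set_module_trainable, for illustration only."""
    for param in module.parameters():
        param.requires_grad_(trainable)
    return module


encoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.ReLU())
set_trainable_sketch(encoder, False)
frozen = all(not p.requires_grad for p in encoder.parameters())
print(frozen)  # True
```

This is the kind of call a freeze_encoder=True option would need, although how the library wires the two together is not shown here.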
Practical Use Case
The user provides the encoder. add_xaae_blocks adds the xAAEnet analysis blocks and automatically builds the matching U-Net decoder through DynamicUnetSkipDropout.
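The forward pass returns the label predictions. The latent code and the reconstruction stay available on the model as model.z and model.decoder_output.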
Training Loss
Defined Losses:
- denoising_ae_loss_func(clean_xb, RECONS_WEIGHT, ADV_WEIGHT, pred, yb) is for denoising-style pretraining: reconstruction vs clean_xb plus the GAN term; pred and yb are ignored but kept for a fastai-compatible signature.
- aae_loss_func(output, target) returns only the adversarial latent loss (same alternating generator / discriminator logic as AAE).
- classif_loss_func(output, target, ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT, **kwargs) combines MS-SSIM + L1 reconstruction on decoder_output vs input_image, the latent GAN term (gan_fake / gan_real), and cross-entropy on the logits.
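To make the role of the three weights concrete, the following sketch combines the terms as a plain weighted sum. That form is an assumption; the exact combination inside classif_loss_func is not documented here. The function name combined_loss_sketch is hypothetical, the MS-SSIM part of the reconstruction term is left out, and the adversarial term is passed in as a placeholder value.

```python
import torch
import torch.nn.functional as F


def combined_loss_sketch(logits, target, recons, image, adv_term,
                         ADV_WEIGHT, RECONS_WEIGHT, CLASS_WEIGHT):
    """Assumed weighted sum of reconstruction, adversarial and classification terms."""
    recons_term = F.l1_loss(recons, image)         # L1 only; MS-SSIM omitted in this sketch
    class_term = F.cross_entropy(logits, target)   # cross-entropy on the logits
    return RECONS_WEIGHT * recons_term + ADV_WEIGHT * adv_term + CLASS_WEIGHT * class_term


logits = torch.randn(2, 2)
target = torch.tensor([0, 1])
image = torch.rand(2, 3, 8, 8)
recons = torch.rand(2, 3, 8, 8)
adv_term = torch.tensor(0.7)  # placeholder for the latent GAN term

total = combined_loss_sketch(
    logits, target, recons, image, adv_term,
    ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0,
)
print(total.shape)  # torch.Size([]), a scalar loss
```

Under this assumed form, setting ADV_WEIGHT and RECONS_WEIGHT to zero reduces the total to the classification term alone.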
The user encoder path still infers encoder_feature_shape automatically so the U-Net and bottleneck stay consistent with your spatial encoder.
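For a plain convolutional encoder, the feature shape can also be checked by hand with the standard convolution output-size formula. The helper conv_out_size below is illustrative and not part of the library; it is applied to three stride-2, kernel-3, padding-1 convolutions on a 256-pixel input.

```python
def conv_out_size(size, kernel_size=3, stride=2, padding=1):
    """Spatial output size of a convolution (dilation 1), per the usual formula."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 256
sizes = []
for _ in range(3):  # three stride-2 convolutions
    size = conv_out_size(size)
    sizes.append(size)
print(sizes)  # [128, 64, 32]
```

With 64 output channels in the last convolution, such an encoder would produce a 64×32×32 feature map per image.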
Because reconstruction uses MS-SSIM, training images should be wider and taller than about 160 pixels; the example below uses 256×256 so nbdev tests and classif_loss_func run without assertion errors from pytorch_msssim.
target_labels = torch.tensor([0, 1])
classif_loss = model.classif_loss_func(
labels, target_labels, ADV_WEIGHT=0.1, RECONS_WEIGHT=0.1, CLASS_WEIGHT=1.0
)
aae_loss = model.aae_loss_func(labels, target_labels)
classif_loss.shape, aae_loss.shape
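The size requirement tied to MS-SSIM can be checked before training. The sketch assumes the pytorch_msssim defaults of an 11-pixel window and five scales, for which the smaller image side must exceed (11 - 1) * 2**4 = 160. Both helper names are hypothetical.

```python
def min_msssim_side(win_size=11, levels=5):
    """Size that the smaller image side must exceed (assumed pytorch_msssim rule)."""
    return (win_size - 1) * 2 ** (levels - 1)


def is_large_enough(height, width, win_size=11, levels=5):
    """True if an image of this size should pass the assumed MS-SSIM size check."""
    return min(height, width) > min_msssim_side(win_size, levels)


print(min_msssim_side())          # 160
print(is_large_enough(256, 256))  # True
print(is_large_enough(160, 160))  # False
```

A smaller window or fewer scales would lower the limit, at the cost of changing what the reconstruction term measures.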