Autoencoders
This section houses autoencoders and variational autoencoders.
Basic AE
This is the simplest autoencoder. You can use it like so:
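from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import AE
model = AE(input_height=32)  # input_height is required (32 for CIFAR-10)
dm = CIFAR10DataModule(data_dir='.')  # the AE has no built-in dataloaders
trainer = Trainer()
trainer.fit(model, datamodule=dm)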
You can override any part of this AE to build your own variation.
from pl_bolts.models.autoencoders import AE

class MyAEFlavor(AE):
    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = YourSuperFancyEncoder(...)  # placeholder: your own encoder module
        return encoder
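Overriding works because the model obtains its encoder from an overridable method rather than hard-coding it. The plain-Python sketch below mimics that factory-hook pattern so it runs without PyTorch. ToyAE, IdentityEncoder, ScalingEncoder and ToyAEFlavor are hypothetical stand-ins, not bolts classes, and the assumption that the base class calls init_encoder from its constructor is made for illustration only.

```python
class IdentityEncoder:
    """Hypothetical default encoder: passes features through unchanged."""

    def __call__(self, x):
        return list(x)


class ScalingEncoder:
    """Hypothetical custom encoder: scales every feature by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return [self.factor * v for v in x]


class ToyAE:
    def __init__(self, hidden_dim=4, latent_dim=2, input_width=2, input_height=2):
        # Assumption: the base class builds its encoder through the hook,
        # so subclasses can swap it without rewriting __init__.
        self.encoder = self.init_encoder(hidden_dim, latent_dim, input_width, input_height)

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        return IdentityEncoder()

    def encode(self, x):
        return self.encoder(x)


class ToyAEFlavor(ToyAE):
    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        # Only the hook changes; all construction logic is inherited.
        return ScalingEncoder(factor=2.0)


base = ToyAE()
custom = ToyAEFlavor()
print(base.encode([1.0, 2.0]))    # [1.0, 2.0]
print(custom.encode([1.0, 2.0]))  # [2.0, 4.0]
```

The design keeps construction order in one place: a subclass changes what gets built, not when or how it is wired in.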
You can use the pretrained models present in bolts.
CIFAR-10 pretrained model:
from pl_bolts.models.autoencoders import AE
ae = AE(input_height=32)
print(AE.pretrained_weights_available())
ae = ae.from_pretrained('cifar10-resnet18')
ae.freeze()
[Figures: Training; Reconstructions]
Both the input and the reconstructed images shown are normalized, because the model was trained on normalized images.
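Because the model works in normalized space, its inputs and outputs must be mapped back before they look like ordinary images. A minimal sketch of per-channel normalization and its inverse, assuming the usual (x - mean) / std scheme; MEAN and STD are placeholder values, not the statistics used to train the bolts checkpoints, so substitute the ones from your own transforms.

```python
# Placeholder per-channel statistics (assumption, not the bolts values).
MEAN = (0.5, 0.5, 0.5)
STD = (0.25, 0.25, 0.25)


def normalize(pixel, mean=MEAN, std=STD):
    """Map an RGB pixel in [0, 1] to normalized space: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


def denormalize(pixel, mean=MEAN, std=STD):
    """Invert normalize(): x = x_norm * std + mean, clamped to [0, 1] for display."""
    return tuple(min(1.0, max(0.0, x * s + m)) for x, m, s in zip(pixel, mean, std))


original = (0.2, 0.5, 0.9)
stored = normalize(original)    # the representation the model sees and outputs
restored = denormalize(stored)  # what you would plot
print(stored)
print(restored)
```

Clamping only matters for model outputs, which can drift slightly outside the valid pixel range after denormalization.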
- class pl_bolts.models.autoencoders.AE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, latent_dim=256, lr=0.0001, **kwargs)
Bases: pytorch_lightning.
Standard AE.
The model is available pretrained on different datasets.
Example:
# not pretrained
ae = AE(input_height=32)
# pretrained on cifar10
ae = AE(input_height=32).from_pretrained('cifar10-resnet18')
- Parameters
  - first_conv (bool) – if True, use the standard kernel_size 7, stride 2 conv at the start; if False, replace it with a kernel_size 3, stride 1 conv
  - maxpool1 (bool) – if True, use the standard maxpool, which reduces the spatial dimensions of the features by a factor of 2
  - enc_out_dim (int) – set according to the output channel count of the encoder used (512 for resnet18, 2048 for resnet50)
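The first_conv and maxpool1 flags matter most for small images such as CIFAR-10, where the standard ResNet stem shrinks a 32x32 input very quickly. One way to see this is to trace the spatial size through the encoder with the usual output-size formula floor((n + 2p - k) / s) + 1. A minimal sketch, assuming a torchvision-style ResNet layout; the kernel, stride and padding values are assumptions, not values read from bolts.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a conv/pool along one axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_spatial_size(input_height, first_conv=False, maxpool1=False):
    """Spatial size of the last ResNet stage (assumed torchvision-style layout)."""
    n = input_height
    # Stem: standard 7x7/s2/p3 conv, or an assumed 3x3/s1/p1 replacement.
    n = conv_out(n, 7, 2, 3) if first_conv else conv_out(n, 3, 1, 1)
    # Standard 3x3/s2/p1 maxpool, or a size-preserving 1x1 pool when disabled.
    n = conv_out(n, 3, 2, 1) if maxpool1 else conv_out(n, 1, 1, 0)
    # layer1 keeps the size; layer2, layer3 and layer4 each downsample with stride 2.
    for _ in range(3):
        n = conv_out(n, 3, 2, 1)
    return n


# Channels after the last stage (enc_out_dim): basic blocks end with 512
# channels, bottleneck blocks expand that by 4.
ENC_OUT_DIM = {"resnet18": 512, "resnet50": 512 * 4}

print(encoder_spatial_size(32))                                   # 4
print(encoder_spatial_size(32, first_conv=True, maxpool1=True))   # 1
print(encoder_spatial_size(96, first_conv=True))                  # 6
print(encoder_spatial_size(224, first_conv=True, maxpool1=True))  # 7
print(ENC_OUT_DIM["resnet50"])                                    # 2048
```

With both flags left at False, a 32x32 input still ends as a 4x4 map instead of collapsing to 1x1, which keeps more spatial information before the features are pooled down to an enc_out_dim vector.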
Variational Autoencoders
Basic VAE
Use the VAE like so:
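from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
model = VAE(input_height=32)  # input_height is required (32 for CIFAR-10)
dm = CIFAR10DataModule(data_dir='.')  # the VAE has no built-in dataloaders
trainer = Trainer()
trainer.fit(model, datamodule=dm)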
You can override any part of this VAE to build your own variation.
import torch
from pl_bolts.models.autoencoders import VAE

class MyVAEFlavor(VAE):
    def get_posterior(self, mu, std):
        # do something other than the default
        P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
        return P
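In a VAE, a Gaussian posterior is typically sampled with the reparameterization trick, which keeps the sample differentiable with respect to mu and std. A pure-Python sketch of that step, assuming a diagonal Gaussian N(mu, std²) per latent dimension; it is illustrative only and not the bolts implementation.

```python
import random


def reparameterize(mu, std, rng):
    """Draw z = mu + std * eps with eps ~ N(0, 1), one value per latent dimension."""
    # All randomness lives in eps, so in a framework with autograd z remains a
    # differentiable function of mu and std.
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]


rng = random.Random(0)  # seeded for reproducibility
mu = [0.5, -1.0]
std = [1.0, 0.5]
samples = [reparameterize(mu, std, rng) for _ in range(20000)]

n = len(samples)
mean_z0 = sum(z[0] for z in samples) / n
var_z1 = sum((z[1] - mu[1]) ** 2 for z in samples) / n
print(mean_z0)  # close to mu[0] = 0.5
print(var_z1)   # close to std[1] ** 2 = 0.25
```

Sampling eps from a fixed N(0, 1) and then shifting and scaling it is what lets gradients reach mu and std; sampling z directly from N(mu, std²) would not.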
You can use the pretrained models present in bolts.
CIFAR-10 pretrained model:
from pl_bolts.models.autoencoders import VAE
vae = VAE(input_height=32)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('cifar10-resnet18')
vae.freeze()
[Figures: Training; Reconstructions]
Both the input and the reconstructed images shown are normalized, because the model was trained on normalized images.
STL-10 pretrained model:
from pl_bolts.models.autoencoders import VAE
vae = VAE(input_height=96, first_conv=True)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('stl10-resnet18')
vae.freeze()
[Figure: Training]
- class pl_bolts.models.autoencoders.VAE(input_height, enc_type='resnet18', first_conv=False, maxpool1=False, enc_out_dim=512, kl_coeff=0.1, latent_dim=256, lr=0.0001, **kwargs)
Bases: pytorch_lightning.
Standard VAE with a Gaussian prior and an approximate posterior.
The model is available pretrained on different datasets.
Example:
# not pretrained
vae = VAE(input_height=32)
# pretrained on cifar10
vae = VAE(input_height=32).from_pretrained('cifar10-resnet18')
# pretrained on stl10 (96x96 images)
vae = VAE(input_height=96, first_conv=True).from_pretrained('stl10-resnet18')
- Parameters
  - first_conv (bool) – if True, use the standard kernel_size 7, stride 2 conv at the start; if False, replace it with a kernel_size 3, stride 1 conv
  - maxpool1 (bool) – if True, use the standard maxpool, which reduces the spatial dimensions of the features by a factor of 2
  - enc_out_dim (int) – set according to the output channel count of the encoder used (512 for resnet18, 2048 for resnet50)
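The kl_coeff argument in the signature weights the KL term of the objective against the reconstruction term. A minimal sketch of that weighting, assuming a diagonal Gaussian posterior and a standard-normal prior and using the closed-form KL divergence; the bolts implementation may estimate the KL term differently (for example by sampling), and recon_loss here is simply a given number.

```python
import math


def kl_to_standard_normal(mu, std):
    """Closed-form KL( N(mu, std^2) || N(0, 1) ), summed over latent dimensions."""
    return sum(0.5 * (s * s + m * m - 1.0 - math.log(s * s)) for m, s in zip(mu, std))


def vae_loss(recon_loss, mu, std, kl_coeff=0.1):
    """Weighted objective: reconstruction term + kl_coeff * KL term."""
    kl = kl_to_standard_normal(mu, std)
    return recon_loss + kl_coeff * kl, kl


# A posterior identical to the prior costs nothing in KL.
print(kl_to_standard_normal([0.0, 0.0], [1.0, 1.0]))  # 0.0

# Shifting the mean by 1 costs 0.5 nats; with kl_coeff=0.1 the loss is about 2.05.
loss, kl = vae_loss(recon_loss=2.0, mu=[1.0], std=[1.0])
print(kl)  # 0.5
```

Raising kl_coeff pulls the approximate posterior toward the prior at the expense of reconstruction quality; lowering it does the opposite.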