GANs
Collection of Generative Adversarial Networks
Basic GAN
This is a vanilla GAN. The model works with datasets of any size, but the results shown here are for MNIST. Replace the generator, the discriminator, or any part of the training loop to build a new method, or simply fine-tune it on your own data.
Implemented by:
William Falcon
from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

gan = GAN()
trainer = Trainer()
trainer.fit(gan)
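For intuition about what `trainer.fit` drives under the hood, here is a minimal, self-contained sketch of the alternating generator/discriminator updates a vanilla GAN performs. The network sizes and the "real" batch are stand-ins for illustration, not what the pl_bolts `GAN` actually uses.

```python
# Minimal sketch of the alternating updates a vanilla GAN performs.
# All sizes and the "real" batch are stand-ins, not pl_bolts' exact setup.
import torch
import torch.nn as nn

latent_dim, img_dim, batch_size = 32, 784, 8

generator = nn.Sequential(
    nn.Linear(latent_dim, 128), nn.ReLU(),
    nn.Linear(128, img_dim), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(img_dim, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
criterion = nn.BCEWithLogitsLoss()

real = torch.rand(batch_size, img_dim)   # stand-in for a real data batch
z = torch.randn(batch_size, latent_dim)

# Generator step: push the discriminator to score fakes as real (label 1).
fake = generator(z)
g_loss = criterion(discriminator(fake), torch.ones(batch_size, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()

# Discriminator step: real batch -> 1, detached fakes -> 0.
d_real = criterion(discriminator(real), torch.ones(batch_size, 1))
d_fake = criterion(discriminator(fake.detach()), torch.zeros(batch_size, 1))
d_loss = (d_real + d_fake) / 2
opt_d.zero_grad()
d_loss.backward()
opt_d.step()
```

The discriminator step detaches the fakes so its gradients do not flow back into the generator; the two optimizers update their own networks in alternation.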
- class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)
Bases: pytorch_lightning.LightningModule
Vanilla GAN implementation.
Example:
from pl_bolts.models.gans import GAN

m = GAN()
Trainer(gpus=2).fit(m)
Example CLI:
# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
    --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
    --batch_size 256 --learning_rate 0.0001
- Parameters
  - input_channels (int) – Number of channels of the images from the dataset
  - input_height (int) – Height of the images from the dataset
  - input_width (int) – Width of the images from the dataset
  - latent_dim (int) – Dimension of the latent noise vector
  - learning_rate (float) – Learning rate
- forward(z)
Generates an image given input noise z.
Example:
z = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(z)
DCGAN
DCGAN implementation from the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. The implementation is based on the version from PyTorch’s examples.
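The core idea of a DCGAN generator is to upsample a latent vector to an image with a stack of strided transposed convolutions, with batch norm and ReLU between layers and a final Tanh. The sketch below is illustrative; the layer counts and sizes are assumptions, not the exact configuration the pl_bolts `DCGAN` builds.

```python
# Illustrative DCGAN-style generator: transposed convolutions progressively
# upsample a latent vector into an image. Sizes here are assumptions.
import torch
import torch.nn as nn

latent_dim, feature_maps, channels = 100, 64, 1

gen = nn.Sequential(
    nn.ConvTranspose2d(latent_dim, feature_maps * 4, 4, 1, 0, bias=False),  # 1x1 -> 4x4
    nn.BatchNorm2d(feature_maps * 4), nn.ReLU(True),
    nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),  # 4x4 -> 8x8
    nn.BatchNorm2d(feature_maps * 2), nn.ReLU(True),
    nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),  # 8x8 -> 16x16
    nn.BatchNorm2d(feature_maps), nn.ReLU(True),
    nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),  # 16x16 -> 32x32
    nn.Tanh(),
)

z = torch.randn(2, latent_dim, 1, 1)  # DCGAN latents are 4-D: (N, latent_dim, 1, 1)
img = gen(z)
print(img.shape)  # torch.Size([2, 1, 32, 32])
```

Each `ConvTranspose2d` with stride 2 and padding 1 doubles the spatial resolution while halving the number of feature maps, mirroring the paper's architecture.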
Implemented by:
- class pl_bolts.models.gans.DCGAN(beta1=0.5, feature_maps_gen=64, feature_maps_disc=64, image_channels=1, latent_dim=100, learning_rate=0.0002, **kwargs)
Bases: pytorch_lightning.LightningModule
DCGAN implementation.
Example:
from pl_bolts.models.gans import DCGAN

m = DCGAN()
Trainer(gpus=2).fit(m)
Example CLI:
# mnist
python dcgan_module.py --gpus 1

# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
SRGAN
SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. The implementation is based on the version from deeplearning.ai.
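What distinguishes SRGAN's generator objective is that it combines a content loss between the super-resolved and ground-truth images with a small adversarial term. The sketch below is schematic: the paper computes the content loss on VGG feature maps, which is simplified here to pixel-wise MSE, and the tensors are random stand-ins rather than real batches.

```python
# Schematic SRGAN generator loss: content term + weighted adversarial term.
# Simplified sketch -- the paper uses a VGG-feature content loss; MSE is a
# stand-in here, and all tensors are random placeholders.
import torch
import torch.nn.functional as F

sr = torch.rand(4, 3, 96, 96)    # generator output (stand-in)
hr = torch.rand(4, 3, 96, 96)    # ground-truth high-res batch (stand-in)
disc_logits = torch.randn(4, 1)  # discriminator scores for the sr batch

content_loss = F.mse_loss(sr, hr)
adv_loss = F.binary_cross_entropy_with_logits(
    disc_logits, torch.ones_like(disc_logits)  # generator wants "real" labels
)
g_loss = content_loss + 1e-3 * adv_loss  # 1e-3 weighting follows the paper
```

The small 1e-3 weight keeps the adversarial term from overwhelming the content loss, which is what anchors the output to the ground-truth image.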
Implemented by:
MNIST results:
SRGAN MNIST with scale factor of 2 (left: low res, middle: generated high res, right: ground truth high res):
SRGAN MNIST with scale factor of 4:
- SRResNet pretraining command used:
python srresnet_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000
- SRGAN training command used:
python srgan_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000
STL10 results:
SRGAN STL10 with scale factor of 2:
SRGAN STL10 with scale factor of 4:
- SRResNet pretraining command used:
python srresnet_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000
- SRGAN training command used:
python srgan_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000
CelebA results:
SRGAN CelebA with scale factor of 2:
SRGAN CelebA with scale factor of 4:
- SRResNet pretraining command used:
python srresnet_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000
- SRGAN training command used:
python srgan_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000
- class pl_bolts.models.gans.SRGAN(image_channels=3, feature_maps_gen=64, feature_maps_disc=64, num_res_blocks=16, scale_factor=4, generator_checkpoint=None, learning_rate=0.0001, scheduler_step=100, **kwargs)
Bases: pytorch_lightning.LightningModule
SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. It uses a pretrained SRResNet model as the generator if available.
Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:
You can pretrain an SRResNet model with srresnet_module.py.

Example:
from pl_bolts.models.gans import SRGAN

m = SRGAN()
Trainer(gpus=1).fit(m)
Example CLI:
# CelebA dataset, scale_factor 4
python srgan_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srgan_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srgan_module.py --dataset=stl10 --scale_factor=4 --gpus=1
- Parameters
  - image_channels (int) – Number of channels of the images from the dataset
  - feature_maps_gen (int) – Number of feature maps to use for the generator
  - feature_maps_disc (int) – Number of feature maps to use for the discriminator
  - num_res_blocks (int) – Number of res blocks to use in the generator
  - scale_factor (int) – Scale factor for the images (either 2 or 4)
  - generator_checkpoint (Optional[str]) – Generator checkpoint created with SRResNet module
  - scheduler_step (int) – Number of epochs after which the learning rate gets decayed
- class pl_bolts.models.gans.SRResNet(image_channels=3, feature_maps=64, num_res_blocks=16, scale_factor=4, learning_rate=0.0001, **kwargs)
Bases: pytorch_lightning.LightningModule
SRResNet implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. A pretrained SRResNet model is used as the generator for SRGAN.
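SRResNet-style generators from the paper upscale with sub-pixel convolution (PixelShuffle) blocks: a convolution quadruples the channel count, and PixelShuffle rearranges those channels into a 2x larger spatial grid, so scale_factor=4 stacks two such blocks. The sketch below shows the mechanism under that assumption; the exact layer layout in pl_bolts may differ.

```python
# Sub-pixel upsampling as used in SRResNet-style generators (sketch; the
# actual pl_bolts layer layout may differ). Each block doubles resolution,
# so a scale factor of 4 stacks two of them.
import torch
import torch.nn as nn

channels = 64

def upsample_block(ch):
    return nn.Sequential(
        nn.Conv2d(ch, ch * 4, kernel_size=3, padding=1),  # 4x the channels...
        nn.PixelShuffle(2),                               # ...become 2x the resolution
        nn.PReLU(),
    )

x = torch.randn(1, channels, 24, 24)  # low-resolution feature map
up4 = nn.Sequential(upsample_block(channels), upsample_block(channels))
out = up4(x)
print(out.shape)  # torch.Size([1, 64, 96, 96])
```

PixelShuffle is preferred over plain interpolation here because the network learns how to distribute sub-pixel detail, which is central to the paper's photo-realistic results.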
Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:
Example:
from pl_bolts.models.gans import SRResNet

m = SRResNet()
Trainer(gpus=1).fit(m)
Example CLI:
# CelebA dataset, scale_factor 4
python srresnet_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srresnet_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srresnet_module.py --dataset=stl10 --scale_factor=4 --gpus=1
- Parameters
  - image_channels (int) – Number of channels of the images from the dataset
  - feature_maps (int) – Number of feature maps to use
  - num_res_blocks (int) – Number of res blocks to use in the network
  - scale_factor (int) – Scale factor for the images (either 2 or 4)
  - learning_rate (float) – Learning rate