
GANs

Collection of Generative Adversarial Networks


Basic GAN

This is a vanilla GAN. It can be trained on an image dataset of any size, though the results shown here are for MNIST. Replace the generator, discriminator, or any part of the training loop to build a new method, or simply finetune it on your own data.
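
For example, a minimal sketch of swapping in your own generator could look like the code below. The self.generator attribute name and the way the training loop consumes it are assumptions about the module's internals, not a confirmed part of the bolts API; adjust to your version.

import torch.nn as nn
from pl_bolts.models.gans import GAN


class MLPGenerator(nn.Module):
    # Toy generator: noise of shape (B, latent_dim) -> image of shape (B, C, H, W)
    def __init__(self, latent_dim, channels, height, width):
        super().__init__()
        self.shape = (channels, height, width)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, channels * height * width),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z).view(z.size(0), *self.shape)


class CustomGAN(GAN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assumed attribute name for the generator; swap in any module
        # that maps (B, latent_dim) noise to (B, C, H, W) images.
        self.generator = MLPGenerator(latent_dim=32, channels=1, height=28, width=28)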

Implemented by:

  • William Falcon

Example outputs:

Basic GAN generated samples

Loss curves:

Basic GAN discriminator loss and generator loss

from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer
...
gan = GAN(input_channels=1, input_height=28, input_width=28)  # MNIST-shaped images
trainer = Trainer()
trainer.fit(gan)
class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Vanilla GAN implementation.

Example:

from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

m = GAN(input_channels=1, input_height=28, input_width=28)
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
    --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
    --batch_size 256 --learning_rate 0.0001
Parameters
  • input_channels (int) – number of channels of an image

  • input_height (int) – image height

  • input_width (int) – image width

  • latent_dim (int) – dimension of the latent noise vector fed to the generator

  • learning_rate (float) – the learning rate
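
Putting the parameters together, a minimal MNIST-sized training run could look like the sketch below. The parameter values simply mirror the list above; the MNISTDataModule and the Trainer settings are illustrative assumptions about how you wire up the data.

from pytorch_lightning import Trainer
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.gans import GAN

# 1x28x28 matches MNIST; latent_dim and learning_rate are the documented defaults
gan = GAN(
    input_channels=1,
    input_height=28,
    input_width=28,
    latent_dim=32,
    learning_rate=0.0002,
)

dm = MNISTDataModule(data_dir=".")
trainer = Trainer(gpus=1, max_epochs=5)
trainer.fit(gan, datamodule=dm)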

forward(z)[source]

Generates an image given input noise z.

Example:

z = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(z)

DCGAN

DCGAN implementation from the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. The implementation is based on the version from PyTorch’s examples.

Implemented by:

Example MNIST outputs:

DCGAN generated MNIST samples

Example LSUN bedroom outputs:

DCGAN generated LSUN bedroom samples

MNIST Loss curves:

DCGAN MNIST discriminator loss and generator loss

LSUN Loss curves:

DCGAN LSUN discriminator loss and generator loss
class pl_bolts.models.gans.DCGAN(beta1=0.5, feature_maps_gen=64, feature_maps_disc=64, image_channels=1, latent_dim=100, learning_rate=0.0002, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

DCGAN implementation.

Example:

from pl_bolts.models.gans import DCGAN
from pytorch_lightning import Trainer

m = DCGAN()
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python dcgan_module.py --gpus 1

# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
Parameters
  • beta1 (float) – Beta1 value for Adam optimizer

  • feature_maps_gen (int) – Number of feature maps to use for the generator

  • feature_maps_disc (int) – Number of feature maps to use for the discriminator

  • image_channels (int) – Number of channels of the images from the dataset

  • latent_dim (int) – Dimension of the latent space

  • learning_rate (float) – Learning rate
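
For a 3-channel dataset such as CIFAR-10 (mirroring the CLI example above), a rough instantiation and a quick smoke test of the generator might look like this sketch; the batch size and the idea of sampling from an untrained model are just for illustration.

import torch
from pl_bolts.models.gans import DCGAN

# 3-channel images, matching the CIFAR-10 CLI example above
dcgan = DCGAN(
    image_channels=3,
    feature_maps_gen=64,
    feature_maps_disc=64,
    latent_dim=100,
    learning_rate=0.0002,
)

# Sample noise and generate a batch of (untrained) fake images
noise = torch.rand(16, 100)
fake_images = dcgan(noise)
print(fake_images.shape)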

forward(noise)[source]

Generates an image given input noise.

Example:

noise = torch.rand(batch_size, latent_dim)
dcgan = DCGAN.load_from_checkpoint(PATH)
img = dcgan(noise)
Return type

Tensor

SRGAN

SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. The implementation is based on the version from deeplearning.ai.

Implemented by:

MNIST results:

SRGAN MNIST with scale factor of 2 (left: low res, middle: generated high res, right: ground truth high res):

SRGAN MNIST with scale factor of 2

SRGAN MNIST with scale factor of 4:

SRGAN MNIST with scale factor of 4

SRResNet pretraining command used:

python srresnet_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

STL10 results:

SRGAN STL10 with scale factor of 2:

SRGAN STL10 with scale factor of 2

SRGAN STL10 with scale factor of 4:

SRGAN STL10 with scale factor of 4

SRResNet pretraining command used:

python srresnet_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

CelebA results:

SRGAN CelebA with scale factor of 2:

SRGAN CelebA with scale factor of 2

SRGAN CelebA with scale factor of 4:

SRGAN CelebA with scale factor of 4

SRResNet pretraining command used:

python srresnet_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000
class pl_bolts.models.gans.SRGAN(image_channels=3, feature_maps_gen=64, feature_maps_disc=64, num_res_blocks=16, scale_factor=4, generator_checkpoint=None, learning_rate=0.0001, scheduler_step=100, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. It uses a pretrained SRResNet model as the generator if available.

Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:

You can pretrain a SRResNet model with srresnet_module.py.
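
In code, the pretrain-then-finetune handoff could look roughly like the sketch below. The checkpoint path is hypothetical, and passing a file produced by srresnet_module.py with --save_model_checkpoint is an assumption about the expected format of generator_checkpoint.

from pytorch_lightning import Trainer
from pl_bolts.models.gans import SRGAN

# 1) Pretrain the generator with the CLI, e.g.
#    python srresnet_module.py --dataset=celeba --scale_factor=4 --save_model_checkpoint ...
# 2) Point SRGAN at the resulting checkpoint file (hypothetical path):
srgan = SRGAN(
    image_channels=3,
    scale_factor=4,
    generator_checkpoint="checkpoints/srresnet.ckpt",
)
Trainer(gpus=1, max_steps=50000).fit(srgan)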

Example:

from pl_bolts.models.gans import SRGAN
from pytorch_lightning import Trainer

m = SRGAN()
Trainer(gpus=1).fit(m)

Example CLI:

# CelebA dataset, scale_factor 4
python srgan_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srgan_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srgan_module.py --dataset=stl10 --scale_factor=4 --gpus=1
Parameters
  • image_channels (int) – Number of channels of the images from the dataset

  • feature_maps_gen (int) – Number of feature maps to use for the generator

  • feature_maps_disc (int) – Number of feature maps to use for the discriminator

  • num_res_blocks (int) – Number of res blocks to use in the generator

  • scale_factor (int) – Scale factor for the images (either 2 or 4)

  • generator_checkpoint (Optional[str]) – Generator checkpoint created with SRResNet module

  • learning_rate (float) – Learning rate

  • scheduler_step (int) – Number of epochs after which the learning rate gets decayed

forward(lr_image)[source]

Generates a high resolution image given a low resolution image.

Example:

srgan = SRGAN.load_from_checkpoint(PATH)
hr_image = srgan(lr_image)
Return type

Tensor
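
As a rough illustration of the shapes involved (the tensor sizes are an example only; with a fully convolutional generator the spatial dimensions are expected to grow by scale_factor):

import torch
from pl_bolts.models.gans import SRGAN

srgan = SRGAN(image_channels=3, scale_factor=4)

# A batch of 3-channel 32x32 low-resolution images
lr_image = torch.rand(1, 3, 32, 32)
hr_image = srgan(lr_image)
# With scale_factor=4 the output should be roughly (1, 3, 128, 128)
print(hr_image.shape)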

class pl_bolts.models.gans.SRResNet(image_channels=3, feature_maps=64, num_res_blocks=16, scale_factor=4, learning_rate=0.0001, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

SRResNet implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. A pretrained SRResNet model is used as the generator for SRGAN.

Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:

Example:

from pl_bolts.models.gans import SRResNet
from pytorch_lightning import Trainer

m = SRResNet()
Trainer(gpus=1).fit(m)

Example CLI:

# CelebA dataset, scale_factor 4
python srresnet_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srresnet_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srresnet_module.py --dataset=stl10 --scale_factor=4 --gpus=1
Parameters
  • image_channels (int) – Number of channels of the images from the dataset

  • feature_maps (int) – Number of feature maps to use

  • num_res_blocks (int) – Number of res blocks to use in the generator

  • scale_factor (int) – Scale factor for the images (either 2 or 4)

  • learning_rate (float) – Learning rate
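
For instance, the MNIST results with a scale factor of 2 shown earlier correspond roughly to this configuration; the Trainer settings are illustrative, and data wiring is left to the CLI script or your own datamodule.

from pytorch_lightning import Trainer
from pl_bolts.models.gans import SRResNet

# Single-channel MNIST images, 2x super-resolution
srresnet = SRResNet(
    image_channels=1,
    feature_maps=64,
    num_res_blocks=16,
    scale_factor=2,
    learning_rate=0.0001,
)
Trainer(gpus=1, max_steps=25000).fit(srresnet)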

forward(lr_image)[source]

Creates a high resolution image given a low resolution image.

Example:

srresnet = SRResNet.load_from_checkpoint(PATH)
hr_image = srresnet(lr_image)
Return type

Tensor
