
GANs

Collection of Generative Adversarial Networks

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!


Basic GAN

This is a vanilla GAN. The model works with datasets of any size, but the results shown here are for MNIST. Replace the generator, the discriminator, or any part of the training loop to build a new method, or simply fine-tune on your own data.

Implemented by:

  • William Falcon

Example outputs:

Basic GAN generated samples

Loss curves:

Basic GAN discriminator and generator loss curves
from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

gan = GAN(input_channels=1, input_height=28, input_width=28)
trainer = Trainer()
trainer.fit(gan)
class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Vanilla GAN implementation.

Example:

from pl_bolts.models.gans import GAN
from pytorch_lightning import Trainer

m = GAN(input_channels=1, input_height=28, input_width=28)
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
    --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
    --batch_size 256 --learning_rate 0.0001
Parameters
  • input_channels (int) – number of channels of an image

  • input_height (int) – image height

  • input_width (int) – image width

  • latent_dim (int) – dimension of the latent space

  • learning_rate (float) – the learning rate

forward(z)[source]

Generates an image given input noise z.

Example:

z = torch.rand(batch_size, latent_dim)
gan = GAN.load_from_checkpoint(PATH)
img = gan(z)
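As the loss curves above suggest, training alternates between a discriminator step and a generator step. A minimal sketch of those two loss computations in plain PyTorch (the tiny linear networks and the BCE-with-logits objective here are illustrative stand-ins, not the bolts internals):

```python
import torch
from torch import nn
import torch.nn.functional as F

latent_dim = 32

# Illustrative stand-in networks for a flattened 28x28 image (784 pixels).
generator = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

real = torch.randn(16, 784)          # stand-in batch of real images
z = torch.rand(16, latent_dim)       # latent noise, as in the forward() example

# Discriminator step: push real -> 1 and fake -> 0; detach so the
# generator gets no gradient from this step.
fake = generator(z).detach()
d_loss = (
    F.binary_cross_entropy_with_logits(discriminator(real), torch.ones(16, 1))
    + F.binary_cross_entropy_with_logits(discriminator(fake), torch.zeros(16, 1))
)

# Generator step: try to fool the discriminator (fake -> 1).
g_loss = F.binary_cross_entropy_with_logits(discriminator(generator(z)), torch.ones(16, 1))
```

In practice each loss is followed by its own optimizer step, which is what produces the two separate curves logged above.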

DCGAN

DCGAN implementation from the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. The implementation is based on the version from PyTorch’s examples.

Implemented by:

Example MNIST outputs:

DCGAN generated MNIST samples

Example LSUN bedroom outputs:

DCGAN generated LSUN bedroom samples

MNIST Loss curves:

DCGAN MNIST discriminator and generator loss curves

LSUN Loss curves:

DCGAN LSUN discriminator and generator loss curves
class pl_bolts.models.gans.DCGAN(beta1=0.5, feature_maps_gen=64, feature_maps_disc=64, image_channels=1, latent_dim=100, learning_rate=0.0002, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Warning

DCGAN is currently under review: compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

DCGAN implementation.

Example:

from pl_bolts.models.gans import DCGAN
from pytorch_lightning import Trainer

m = DCGAN()
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python dcgan_module.py --gpus 1

# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
Parameters
  • beta1 (float) – Beta1 value for Adam optimizer

  • feature_maps_gen (int) – Number of feature maps to use for the generator

  • feature_maps_disc (int) – Number of feature maps to use for the discriminator

  • image_channels (int) – Number of channels of the images from the dataset

  • latent_dim (int) – Dimension of the latent space

  • learning_rate (float) – Learning rate

forward(noise)[source]

Generates an image given input noise.

Example:

noise = torch.rand(batch_size, latent_dim)
dcgan = DCGAN.load_from_checkpoint(PATH)
img = dcgan(noise)
Return type

Tensor
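The latent_dim and feature_maps_gen parameters determine the shape of the generator's transposed-convolution stack: the latent vector is treated as a 1x1 "image" and repeatedly upsampled. A sketch of the first and last stages of a DCGAN-style generator, assuming the defaults latent_dim=100, feature_maps_gen=64, image_channels=1 (a full DCGAN generator has more intermediate stages):

```python
import torch
from torch import nn

latent_dim, feature_maps_gen, image_channels = 100, 64, 1

# First stage projects the 1x1 latent "image" to a 4x4 feature map;
# the final stage doubles the spatial size and maps to image channels.
gen = nn.Sequential(
    nn.ConvTranspose2d(latent_dim, feature_maps_gen * 4, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(feature_maps_gen * 4),
    nn.ReLU(),
    nn.ConvTranspose2d(feature_maps_gen * 4, image_channels, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),  # outputs in [-1, 1], matching normalized training images
)

noise = torch.rand(8, latent_dim, 1, 1)
out = gen(noise)
print(out.shape)  # torch.Size([8, 1, 8, 8])
```

Each stride-2 transposed convolution doubles the spatial resolution, so stacking more of them reaches the target image size.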

SRGAN

SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. The implementation is based on the version from deeplearning.ai.

Implemented by:

MNIST results:

SRGAN MNIST with scale factor of 2 (left: low res, middle: generated high res, right: ground truth high res):

SRGAN MNIST with scale factor of 2

SRGAN MNIST with scale factor of 4:

SRGAN MNIST with scale factor of 4
SRResNet pretraining command used:

python srresnet_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

STL10 results:

SRGAN STL10 with scale factor of 2:

SRGAN STL10 with scale factor of 2

SRGAN STL10 with scale factor of 4:

SRGAN STL10 with scale factor of 4
SRResNet pretraining command used:

python srresnet_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

CelebA results:

SRGAN CelebA with scale factor of 2:

SRGAN CelebA with scale factor of 2

SRGAN CelebA with scale factor of 4:

SRGAN CelebA with scale factor of 4
SRResNet pretraining command used:

python srresnet_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used:

python srgan_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000
class pl_bolts.models.gans.SRGAN(image_channels=3, feature_maps_gen=64, feature_maps_disc=64, num_res_blocks=16, scale_factor=4, generator_checkpoint=None, learning_rate=0.0001, scheduler_step=100, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Warning

SRGAN is currently under review: compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

SRGAN implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. It uses a pretrained SRResNet model as the generator if available.

Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:

You can pretrain an SRResNet model with srresnet_module.py.

Example:

from pl_bolts.models.gans import SRGAN
from pytorch_lightning import Trainer

m = SRGAN()
Trainer(gpus=1).fit(m)

Example CLI:

# CelebA dataset, scale_factor 4
python srgan_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srgan_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srgan_module.py --dataset=stl10 --scale_factor=4 --gpus=1
Parameters
  • image_channels (int) – Number of channels of the images from the dataset

  • feature_maps_gen (int) – Number of feature maps to use for the generator

  • feature_maps_disc (int) – Number of feature maps to use for the discriminator

  • num_res_blocks (int) – Number of res blocks to use in the generator

  • scale_factor (int) – Scale factor for the images (either 2 or 4)

  • generator_checkpoint (Optional[str]) – Generator checkpoint created with SRResNet module

  • learning_rate (float) – Learning rate

  • scheduler_step (int) – Number of epochs after which the learning rate gets decayed

forward(lr_image)[source]

Generates a high resolution image given a low resolution image.

Example:

srgan = SRGAN.load_from_checkpoint(PATH)
hr_image = srgan(lr_image)
Return type

Tensor
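The scale_factor parameter is realized in the generator by sub-pixel (PixelShuffle) upsampling: each x2 stage expands the channel count 4x with a convolution, then PixelShuffle rearranges those channels into a 2x larger spatial grid. A sketch of that upsampling idea under the default feature_maps_gen=64 and scale_factor=4 (illustrative, not the exact bolts layer stack):

```python
import torch
from torch import nn

feature_maps, scale_factor = 64, 4

# One x2 stage per factor of two: scale_factor=2 -> 1 stage, 4 -> 2 stages.
stages = []
for _ in range(scale_factor // 2):
    stages += [
        nn.Conv2d(feature_maps, feature_maps * 4, kernel_size=3, padding=1),
        nn.PixelShuffle(2),  # (C*4, H, W) -> (C, 2H, 2W)
        nn.PReLU(),
    ]
upsample = nn.Sequential(*stages)

x = torch.rand(1, feature_maps, 24, 24)  # low-resolution feature map
print(upsample(x).shape)  # torch.Size([1, 64, 96, 96])
```

Learning the upsampling this way, instead of plain interpolation, is what lets the generator synthesize high-frequency detail.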

class pl_bolts.models.gans.SRResNet(image_channels=3, feature_maps=64, num_res_blocks=16, scale_factor=4, learning_rate=0.0001, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Warning

SRResNet is currently under review: compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

SRResNet implementation from the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. A pretrained SRResNet model is used as the generator for SRGAN.

Code adapted from https-deeplearning-ai/GANs-Public to Lightning by:

Example:

from pl_bolts.models.gans import SRResNet
from pytorch_lightning import Trainer

m = SRResNet()
Trainer(gpus=1).fit(m)

Example CLI:

# CelebA dataset, scale_factor 4
python srresnet_module.py --dataset=celeba --scale_factor=4 --gpus=1

# MNIST dataset, scale_factor 4
python srresnet_module.py --dataset=mnist --scale_factor=4 --gpus=1

# STL10 dataset, scale_factor 4
python srresnet_module.py --dataset=stl10 --scale_factor=4 --gpus=1
Parameters
  • image_channels (int) – Number of channels of the images from the dataset

  • feature_maps (int) – Number of feature maps to use

  • num_res_blocks (int) – Number of res blocks to use in the generator

  • scale_factor (int) – Scale factor for the images (either 2 or 4)

  • learning_rate (float) – Learning rate

forward(lr_image)[source]

Creates a high resolution image given a low resolution image.

Example:

srresnet = SRResNet.load_from_checkpoint(PATH)
hr_image = srresnet(lr_image)
Return type

Tensor
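The num_res_blocks parameter controls how many residual blocks the network stacks before upsampling. A sketch of one SRResNet-style residual block (conv-BN-PReLU-conv-BN with an identity skip), following the paper's description; the exact bolts implementation may differ in detail:

```python
import torch
from torch import nn

class ResBlock(nn.Module):
    """One SRResNet-style residual block: conv-BN-PReLU-conv-BN plus skip."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.body(x)  # identity skip preserves resolution and aids gradient flow

block = ResBlock(64)               # channels = feature_maps default
x = torch.rand(2, 64, 16, 16)
print(block(x).shape)  # torch.Size([2, 64, 16, 16])
```

Because each block preserves shape, num_res_blocks can be raised or lowered without changing any other layer.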
