
GANs

Collection of Generative Adversarial Networks


Basic GAN

This is a vanilla GAN. The model works with images of any size (set via input_channels, input_height and input_width), but the results below are shown for MNIST. Replace the generator, the discriminator, or any part of the training loop to build a new method, or simply fine-tune on your data.
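The loss curves further down track two competing objectives: the discriminator tries to separate real from generated images, and the generator tries to fool it. The sketch below is a minimal illustration with plain floats. It assumes that the discriminator outputs a probability D(·) in (0, 1) and that the generator uses the common non-saturating loss. This page does not confirm the exact formulation this module uses.

```python
import math


def bce(p: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy of a single probability p against a 0/1 target."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def discriminator_loss(d_real: float, d_fake: float) -> float:
    """D should output 1 on real images and 0 on generated ones (terms summed)."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake: float) -> float:
    """Non-saturating generator loss: G wants D to label its samples as real."""
    return bce(d_fake, 1.0)


# A confident discriminator: high score on real data, low score on fakes.
print(round(discriminator_loss(0.9, 0.1), 4))  # 0.2107
print(round(generator_loss(0.1), 4))           # 2.3026
```

Suppose the discriminator cannot tell real from fake and outputs 0.5 for both. Under this summed formulation its loss is then 2 ln 2 ≈ 1.386, which is a useful reference point when reading a discriminator curve.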

Implemented by:

  • William Falcon

Example outputs:

Basic GAN generated samples

Loss curves:

Basic GAN discriminator loss
Basic GAN generator loss

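from pytorch_lightning import Trainer
from pl_bolts.models.gans import GAN
...
# input shape must match the data, e.g. 1 x 28 x 28 for MNIST
gan = GAN(input_channels=1, input_height=28, input_width=28)
trainer = Trainer()
trainer.fit(gan)  # training data must also be supplied, e.g. trainer.fit(gan, datamodule=dm)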
class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)

Bases: pytorch_lightning.

Vanilla GAN implementation.

Example:

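from pytorch_lightning import Trainer
from pl_bolts.models.gans import GAN

m = GAN(input_channels=1, input_height=28, input_width=28)  # e.g. MNIST-sized images
Trainer(gpus=2).fit(m)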

Example CLI:

# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' \
    --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder \
    --batch_size 256 --learning_rate 0.0001

Parameters
  • input_channels (int) – number of channels of an image

  • input_height (int) – image height

  • input_width (int) – image width

  • latent_dim (int) – dimension of the latent noise vector fed to the generator

  • learning_rate (float) – the learning rate
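These parameters fix the generator's input and output shapes. The generator maps a latent_dim-dimensional noise vector to an image of shape (input_channels, input_height, input_width). The sketch below shows how the values fit together in a hypothetical fully connected generator. The class name SketchGenerator and the hidden size of 256 are illustrative assumptions, not the pl_bolts classes.

```python
import torch
from torch import nn


class SketchGenerator(nn.Module):
    """Hypothetical MLP generator: latent vector -> image tensor (not the pl_bolts class)."""

    def __init__(self, latent_dim: int, img_shape: tuple, hidden_dim: int = 256):
        super().__init__()
        self.img_shape = img_shape
        out_features = img_shape[0] * img_shape[1] * img_shape[2]  # C * H * W
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, out_features),
            nn.Tanh(),  # outputs in [-1, 1], suited to images normalized to that range
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        flat = self.net(z)  # (batch, C * H * W)
        return flat.view(z.size(0), *self.img_shape)  # (batch, C, H, W)


# MNIST-sized example: input_channels=1, input_height=28, input_width=28, latent_dim=32
gen = SketchGenerator(latent_dim=32, img_shape=(1, 28, 28))
z = torch.randn(4, 32)
print(gen(z).shape)  # torch.Size([4, 1, 28, 28])
```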

forward(z)

Generates an image given input noise z.

Example:

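import torch
from pl_bolts.models.gans import GAN

gan = GAN.load_from_checkpoint(PATH)  # PATH: path to a trained checkpoint
z = torch.randn(16, gan.hparams.latent_dim)  # batch of 16 Gaussian noise vectors
img = gan(z)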

DCGAN

DCGAN implementation from the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. The implementation is based on the version from PyTorch’s examples.
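In PyTorch's DCGAN example, the generator treats the latent vector as a 1 x 1 feature map. It then upsamples it with transposed convolutions: one layer with kernel 4, stride 1, padding 0, followed by four layers with kernel 4, stride 2, padding 1. This implementation is described as following that example. Assuming it keeps the same layout, the standard transposed-convolution output-size formula (dilation 1, no output padding) traces the spatial size through the generator:

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a transposed convolution with dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel


# Assumed layer layout, following PyTorch's DCGAN example: (kernel, stride, padding)
layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size = 1  # latent vector reshaped to a 1 x 1 feature map
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

Under this layout the generator emits 64 x 64 images, so inputs such as 28 x 28 MNIST digits would need resizing to match.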

Implemented by:

Example MNIST outputs:

DCGAN generated MNIST samples

Example LSUN bedroom outputs:

DCGAN generated LSUN bedroom samples

MNIST Loss curves:

DCGAN MNIST discriminator loss
DCGAN MNIST generator loss

LSUN Loss curves:

DCGAN LSUN discriminator loss
DCGAN LSUN generator loss
class pl_bolts.models.gans.DCGAN(beta1=0.5, feature_maps_gen=64, feature_maps_disc=64, image_channels=1, latent_dim=100, learning_rate=0.0002, **kwargs)

Bases: pytorch_lightning.

DCGAN implementation.

Example:

from pytorch_lightning import Trainer
from pl_bolts.models.gans import DCGAN

m = DCGAN()
Trainer(gpus=2).fit(m)

Example CLI:

# mnist
python dcgan_module.py --gpus 1

# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
Parameters
  • beta1 (float) – Beta1 value for Adam optimizer

  • feature_maps_gen (int) – Number of feature maps to use for the generator

  • feature_maps_disc (int) – Number of feature maps to use for the discriminator

  • image_channels (int) – Number of channels of the images from the dataset

  • latent_dim (int) – Dimension of the latent space

  • learning_rate (float) – Learning rate
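beta1 and learning_rate are the usual Adam settings for GAN training. The sketch below assumes one Adam optimizer per network with betas=(beta1, 0.999), as in PyTorch's DCGAN example. The helper make_gan_optimizers and the tiny nn.Linear stand-in networks are hypothetical and are not part of pl_bolts.

```python
import torch
from torch import nn


def make_gan_optimizers(generator: nn.Module, discriminator: nn.Module,
                        learning_rate: float = 0.0002, beta1: float = 0.5):
    """Hypothetical helper: one Adam optimizer per network, beta2 left at 0.999."""
    betas = (beta1, 0.999)
    opt_gen = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=betas)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=betas)
    return opt_gen, opt_disc


# Stand-in modules, only to show the call; real DCGAN networks are convolutional.
gen, disc = nn.Linear(100, 10), nn.Linear(10, 1)
opt_gen, opt_disc = make_gan_optimizers(gen, disc)
print(opt_gen.defaults["betas"])  # (0.5, 0.999)
```

The DCGAN paper reports that Adam's default beta1 of 0.9 caused training oscillation and instability, while reducing it to 0.5 helped stabilize training. This matches the default beta1=0.5 above.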

forward(noise)

Generates an image given input noise.

Example:

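import torch
from pl_bolts.models.gans import DCGAN

gan = DCGAN.load_from_checkpoint(PATH)  # PATH: path to a trained checkpoint
noise = torch.randn(16, gan.hparams.latent_dim)  # batch of 16 Gaussian noise vectors
img = gan(noise)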
Return type

Tensor
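A generator built from transposed convolutions expects 4-D input, but the forward example samples noise of shape (batch_size, latent_dim). One common way to bridge the two is to append two singleton spatial dimensions before the first layer; this sketch assumes that approach. The standalone ConvTranspose2d is a hypothetical stand-in for the first generator layer, with feature_maps_gen * 8 output channels as in PyTorch's DCGAN example.

```python
import torch
from torch import nn

batch_size, latent_dim, feature_maps_gen = 8, 100, 64

noise = torch.randn(batch_size, latent_dim)  # (B, latent_dim)
noise_4d = noise.view(*noise.shape, 1, 1)     # (B, latent_dim, 1, 1)

# Hypothetical first generator layer: kernel 4, stride 1, padding 0 maps 1 x 1 -> 4 x 4
first_layer = nn.ConvTranspose2d(latent_dim, feature_maps_gen * 8, 4, 1, 0, bias=False)
print(first_layer(noise_4d).shape)  # torch.Size([8, 512, 4, 4])
```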
