GANs
Collection of Generative Adversarial Networks
Basic GAN
This is a vanilla GAN. The model works with any dataset size, but the results shown are for MNIST. Replace the generator, the discriminator, or any part of the training loop to build a new method, or simply fine-tune on your own data.
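A vanilla GAN trains a discriminator D to tell real images from generated ones and a generator G to fool it. The plain-Python sketch below shows the two losses under the usual binary cross-entropy formulation, with the non-saturating generator loss, evaluated on hand-picked discriminator probabilities. We assume, without having checked the pl_bolts source, that the module's training step follows this standard setup; the exact reduction (sum vs. mean of the real and fake terms) may differ. Changing either loss is one example of the training-loop changes mentioned above.

```python
import math
from typing import Sequence


def bce(prob: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one predicted probability and a 0/1 target."""
    prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """D should output 1 on real images and 0 on generated ones."""
    real = sum(bce(p, 1.0) for p in d_real) / len(d_real)
    fake = sum(bce(p, 0.0) for p in d_fake) / len(d_fake)
    return real + fake


def generator_loss(d_fake: Sequence[float]) -> float:
    """Non-saturating generator loss: G wants D to label its samples as real."""
    return sum(bce(p, 1.0) for p in d_fake) / len(d_fake)


if __name__ == "__main__":
    # Hypothetical discriminator outputs (probabilities) for one mini-batch.
    d_real = [0.9, 0.8]
    d_fake = [0.2, 0.1]
    print(f"D loss: {discriminator_loss(d_real, d_fake):.4f}")
    print(f"G loss: {generator_loss(d_fake):.4f}")
```

In a real training loop these two losses are minimised in alternation, one optimizer step for D and one for G per batch.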
Implemented by: William Falcon
[Figure: example outputs]
[Figure: loss curves]
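from pytorch_lightning import Trainer
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.gans import GAN
...
dm = MNISTDataModule()
gan = GAN(input_channels=1, input_height=28, input_width=28)  # MNIST images are 1 x 28 x 28
trainer = Trainer()
trainer.fit(gan, datamodule=dm)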
class pl_bolts.models.gans.GAN(input_channels, input_height, input_width, latent_dim=32, learning_rate=0.0002, **kwargs)
Bases: pytorch_lightning.
Vanilla GAN implementation.
Example:
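from pytorch_lightning import Trainer
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.gans import GAN

m = GAN(input_channels=1, input_height=28, input_width=28)
Trainer(gpus=2).fit(m, datamodule=MNISTDataModule())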
Example CLI:
# mnist
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder --batch_size 256 --learning_rate 0.0001
forward(z)
Generates an image given input noise z.
Example:
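import torch
from pl_bolts.models.gans import GAN

gan = GAN.load_from_checkpoint(PATH)  # PATH: path to a trained GAN checkpoint
z = torch.randn(16, gan.hparams.latent_dim)  # standard-normal noise, one row per image
img = gan(z)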
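forward has to turn a batch of latent vectors into images of shape (input_channels, input_height, input_width). The shape-only sketch below illustrates that contract with a hypothetical toy_generator (not pl_bolts' own generator class): one fixed random linear layer maps each latent_dim-long vector to channels * height * width values, a tanh squashes them into [-1, 1] (a common choice, assumed here), and the flat output is reshaped into nested [C][H][W] lists.

```python
import math
import random
from typing import List


def sample_noise(batch_size: int, latent_dim: int, seed: int = 0) -> List[List[float]]:
    """Draw z ~ N(0, 1) with shape (batch_size, latent_dim)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]


def toy_generator(z_row: List[float], channels: int, height: int, width: int,
                  seed: int = 1) -> List[List[List[float]]]:
    """Hypothetical stand-in for an MLP generator.

    One fixed random linear layer (latent_dim -> channels*height*width),
    tanh to squash values into [-1, 1], then a reshape to [C][H][W].
    """
    rng = random.Random(seed)  # same seed -> same "weights" on every call
    out_dim = channels * height * width
    weights = [[rng.gauss(0.0, 0.1) for _ in z_row] for _ in range(out_dim)]
    flat = [math.tanh(sum(w * x for w, x in zip(row, z_row))) for row in weights]
    plane = height * width
    return [
        [flat[c * plane + h * width: c * plane + (h + 1) * width] for h in range(height)]
        for c in range(channels)
    ]


if __name__ == "__main__":
    z = sample_noise(batch_size=4, latent_dim=32)  # 32 is the GAN's default latent_dim
    images = [toy_generator(row, channels=1, height=28, width=28) for row in z]
    print(len(images), len(images[0]), len(images[0][0]), len(images[0][0][0]))  # 4 1 28 28
```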
DCGAN
DCGAN implementation from the paper "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks". The implementation is based on the version in PyTorch's examples.
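Since the implementation is based on PyTorch's DCGAN example, that example's layer layout explains the image sizes involved. The arithmetic below applies the standard output-size formulas for Conv2d and ConvTranspose2d to the example's (kernel, stride, padding) settings and its channel progression. Treating pl_bolts' DCGAN as using exactly these settings is an assumption; the channel defaults come from the DCGAN signature (latent_dim=100, feature_maps_gen=64, image_channels=1).

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Side length after nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Side length after nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) per layer in PyTorch's DCGAN example.
GEN_LAYERS: List[Tuple[int, int, int]] = [(4, 1, 0)] + [(4, 2, 1)] * 4
DISC_LAYERS: List[Tuple[int, int, int]] = [(4, 2, 1)] * 4 + [(4, 1, 0)]


def generator_sizes(start: int = 1) -> List[int]:
    """Spatial size after each generator layer, starting from a 1x1 latent input."""
    sizes = [start]
    for k, s, p in GEN_LAYERS:
        sizes.append(conv_transpose_out(sizes[-1], k, s, p))
    return sizes


def discriminator_sizes(start: int = 64) -> List[int]:
    """Spatial size after each discriminator layer, ending in a single score."""
    sizes = [start]
    for k, s, p in DISC_LAYERS:
        sizes.append(conv_out(sizes[-1], k, s, p))
    return sizes


def generator_channels(latent_dim: int = 100, feature_maps_gen: int = 64,
                       image_channels: int = 1) -> List[int]:
    """Channel counts through the generator: 8f, 4f, 2f, f, then image channels."""
    f = feature_maps_gen
    return [latent_dim, f * 8, f * 4, f * 2, f, image_channels]


if __name__ == "__main__":
    print(generator_sizes())      # [1, 4, 8, 16, 32, 64]
    print(discriminator_sizes())  # [64, 32, 16, 8, 4, 1]
    print(generator_channels())   # [100, 512, 256, 128, 64, 1]
```

Under this layout the generator emits 64x64 images and a 64x64 input is what the discriminator reduces to a single 1x1 score; other input sizes would need resizing or a different layer stack.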
Implemented by:
[Figure: example MNIST outputs]
[Figure: example LSUN bedroom outputs]
[Figure: MNIST loss curves]
[Figure: LSUN loss curves]
class pl_bolts.models.gans.DCGAN(beta1=0.5, feature_maps_gen=64, feature_maps_disc=64, image_channels=1, latent_dim=100, learning_rate=0.0002, **kwargs)
Bases: pytorch_lightning.
DCGAN implementation.
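The signature does not say what beta1 controls. In the DCGAN paper and PyTorch's DCGAN example it is Adam's first-moment decay rate, lowered from the usual 0.9 to 0.5 (together with a 0.0002 learning rate, which matches the default here), and we assume the same meaning in this module. The scalar Adam sketch below shows the effect: after a run of positive gradients followed by one negative gradient, the momentum estimate with beta1=0.5 already points the new way, while with beta1=0.9 it still points the old way.

```python
import math
from typing import List, Sequence, Tuple


def adam_trace(grads: Sequence[float], lr: float = 0.0002, beta1: float = 0.5,
               beta2: float = 0.999, eps: float = 1e-8) -> List[Tuple[float, float]]:
    """Apply Adam to one scalar parameter (initialised to 0.0).

    Returns (parameter, bias-corrected first moment) after every step.
    """
    theta, m, v = 0.0, 0.0, 0.0
    trace = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g      # running mean of gradients (momentum)
        v = beta2 * v + (1 - beta2) * g * g  # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction for the zero init
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
        trace.append((theta, m_hat))
    return trace


if __name__ == "__main__":
    grads = [1.0, 1.0, 1.0, -1.0]  # gradient flips sign on the last step
    for b1 in (0.5, 0.9):
        _, m_hat = adam_trace(grads, beta1=b1)[-1]
        print(f"beta1={b1}: final momentum estimate {m_hat:+.4f}")
```

A smaller beta1 makes the update react faster to the constantly shifting gradients of the adversarial game, which is the stabilising effect the DCGAN authors reported.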
Example:
from pytorch_lightning import Trainer
from pl_bolts.models.gans import DCGAN

m = DCGAN()
Trainer(gpus=2).fit(m)
Example CLI:
# mnist
python dcgan_module.py --gpus 1

# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3