
Self-supervised Learning

This Bolts module houses a collection of self-supervised learning models.

Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.

Self-supervised models are trained on unlabeled datasets.


We rely on the community to keep these models updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!

Use cases

Here are some use cases for the self-supervised package.

Extracting image features

The models in this module are trained without labels, so they can learn image representations (features) that are useful for downstream tasks.

In this example, we’ll load a ResNet-50 that was pretrained on ImageNet using SimCLR as the pretext task.

from pl_bolts.models.self_supervised import SimCLR

# load resnet50 pretrained using SimCLR on imagenet
weight_path = ''
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

simclr_resnet50 = simclr.encoder

This means you can now extract image representations that were pretrained via unsupervised learning.


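import torch
from torch.utils.data import DataLoader

my_dataloader = DataLoader(SomeDataset(), batch_size=32)  # SomeDataset: placeholder for your own dataset
simclr_resnet50.eval()
for x, y in my_dataloader:
    with torch.no_grad():
        out = simclr_resnet50(x)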

Train with unlabeled data

These models are well suited to training from scratch when you have a huge set of unlabeled images.

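from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.transforms.self_supervised.simclr_transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)

# MyDataset is a placeholder for your own unlabeled image dataset
train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())

# simclr needs a lot of compute!
model = SimCLR()  # pass the constructor arguments SimCLR requires (see its API)
trainer = Trainer(tpu_cores=128)
trainer.fit(model, DataLoader(train_dataset), DataLoader(val_dataset))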


Mix and match any part, or subclass a model to create your own method:

from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPC_v2(contrastive_task=amdim_task)

Contrastive Learning Models

Contrastive self-supervised learning (CSL) is a self-supervised learning approach where we generate representations of instances such that similar instances are near each other and far from dissimilar ones. This is often done by comparing triplets of positive, anchor and negative representations.
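As a concrete illustration of the anchor/positive/negative idea, the sketch below computes a triplet margin loss for a single triplet in plain Python. The squared Euclidean distance, the margin of 1.0 and the toy vectors are illustrative assumptions; the models listed below use their own objectives (for example NCE-style losses in AMDIM and CPC) rather than this exact loss.

```python
from typing import Sequence


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def triplet_margin_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negative: Sequence[float],
    margin: float = 1.0,
) -> float:
    """Hinge loss that reaches zero once the negative is at least `margin`
    farther (in squared distance) from the anchor than the positive is."""
    d_pos = squared_distance(anchor, positive)
    d_neg = squared_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)


anchor = [1.0, 0.0]
positive = [0.9, 0.1]   # e.g. an augmented view of the same instance
negative = [-1.0, 0.0]  # e.g. a view of a different instance
loss = triplet_margin_loss(anchor, positive, negative)
print(loss)
```

Swapping the positive and the negative produces a large loss, which is the signal that pushes similar instances together and dissimilar ones apart.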

In this section, we list Lightning implementations of popular contrastive learning approaches.


class pl_bolts.models.self_supervised.AMDIM(datamodule='cifar10', encoder='amdim_encoder', contrastive_task=FeatureMapContrastiveTask((nce_loss): AmdimNCELoss()), image_channels=3, image_height=32, encoder_feature_dim=320, embedding_fx_dim=1280, conv_block_depth=10, use_bn=False, tclip=20.0, learning_rate=0.0002, data_dir='', num_classes=10, batch_size=200, num_workers=16, **kwargs)

Bases: pytorch_lightning.core.module.LightningModule


The AMDIM feature is currently marked as under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases.

PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM)

Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.

Model implemented by: William Falcon

This code is adapted to Lightning from the original authors’ repository.


>>> from pl_bolts.models.self_supervised import AMDIM
>>> model = AMDIM(encoder='resnet18')


trainer = Trainer()
trainer.fit(model)

Parameters

  • datamodule (Union[str, LightningDataModule]) – A LightningDataModule, or the name of a dataset as a string

  • encoder (Union[str, Module, LightningModule]) – An encoder name string or an encoder model

  • image_channels (int) – Number of input image channels (e.g. 3 for RGB)

  • image_height (int) – Height of the input images, in pixels

  • encoder_feature_dim (int) – Called ndf in the paper; the representation size of the encoder.

  • embedding_fx_dim (int) – Output dimension of the embedding function (called nrkhs in the paper, after Reproducing Kernel Hilbert Space).

  • conv_block_depth (int) – Depth of each encoder block.

  • use_bn (bool) – If True, uses batch normalization.

  • tclip (float) – Soft-clipping non-linearity applied to the scores after computing the regularization term and before computing the log-softmax. This is the ‘second trick’ used in the paper.

  • learning_rate (float) – The learning rate

  • data_dir (str) – Directory where the data is stored

  • num_classes (int) – Number of classes in the dataset

  • batch_size (int) – The batch size
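The soft clipping controlled by tclip can be sketched as a scaled tanh, tclip * tanh(score / tclip), which behaves almost like the identity for small scores and saturates smoothly at ±tclip for large ones. This is a minimal plain-Python illustration that assumes the tanh form; the name tanh_clip is used only for illustration, and the model itself applies the clipping to tensors of scores rather than to single floats.

```python
import math


def tanh_clip(score: float, clip_val: float = 20.0) -> float:
    """Softly clip `score` into the range [-clip_val, clip_val].

    For |score| much smaller than clip_val the result is close to `score`;
    for large |score| it saturates smoothly instead of hard-clipping.
    """
    return clip_val * math.tanh(score / clip_val)


for s in (0.5, 10.0, 100.0, -1000.0):
    print(f"{s:>8} -> {tanh_clip(s):.3f}")
```

Unlike a hard clamp, the tanh keeps a non-zero gradient for all finite scores, which is one reason a soft clip is preferred here.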


configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you need one, but for GANs or similar models you might need several.


Returns any of these 6 options:

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": 1,  # indicates how often the metric is updated
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.


The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
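The cycling described above can be reproduced with a small helper that maps a step number to the index of the optimizer in use. This is a plain-Python sketch for intuition only; active_optimizer is a hypothetical name and is not part of the Lightning API, which handles this scheduling internally.

```python
from typing import Sequence


def active_optimizer(step: int, frequencies: Sequence[int]) -> int:
    """Return the index of the optimizer used at `step` when optimizers
    take turns, each running for its frequency in consecutive steps."""
    cycle_length = sum(frequencies)
    position = step % cycle_length  # position inside the current cycle
    for idx, freq in enumerate(frequencies):
        if position < freq:
            return idx
        position -= freq
    raise AssertionError("unreachable: position is always < cycle_length")


schedule = [active_optimizer(step, [5, 10]) for step in range(20)]
print(schedule)
```

With frequencies [5, 10], steps 0 to 4 use optimizer 0, steps 5 to 14 use optimizer 1, and the cycle restarts at step 15.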


# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )


Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.

forward(img_1, img_2)

Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Returns your model’s output.


train_dataloader()

Implement one or more PyTorch DataLoaders for training.


Returns a collection of torch.utils.data.DataLoader objects specifying training samples. For multiple dataloaders, see the examples below.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.


Do not assign state in prepare_data().

  • fit()

  • prepare_data()

  • setup()


Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.


import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}

training_step(batch, batch_nb)

Here you compute and return the training loss and some additional metrics for, e.g., the progress bar or logger.

Parameters

  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Integer displaying index of this batch

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.


Returns any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.


def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    x, y = batch
    out, hiddens = self.lstm(x, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}


The loss value shown in the progress bar is smoothed (averaged) over the most recent values, so it differs from the actual loss returned in the training/validation step.


When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
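Why normalize by accumulate_grad_batches? Gradients from successive backward passes are summed, so scaling each micro-batch loss by 1/N makes the accumulated gradient equal to the gradient of the average loss over the N micro-batches. The plain-Python sketch below checks this for a scalar model with a squared-error loss; the model, the data and the name grad are illustrative assumptions, not Lightning internals.

```python
def grad(w: float, x: float, y: float) -> float:
    """Gradient of the squared error (w * x - y) ** 2 with respect to w."""
    return 2.0 * x * (w * x - y)


w = 0.5
micro_batches = [(1.0, 2.0), (2.0, 1.0), (3.0, 0.0), (-1.0, 1.0)]
n = len(micro_batches)  # plays the role of accumulate_grad_batches

# Sum the gradients of each normalized micro-batch loss (loss / n).
accumulated = 0.0
for x, y in micro_batches:
    accumulated += grad(w, x, y) / n

# Gradient of the mean loss over all micro-batches at once.
full = sum(grad(w, x, y) for x, y in micro_batches) / n

print(accumulated, full)
```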


training_step_end(step_output)

Use this when training with DP, because training_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.


If you later switch to DDP or some other mode, this will still be called, so you don’t have to change your code.

# pseudocode
sub_batches = split_batches_for_dp(batch)
step_output = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(step_output)

step_output – What you return in training_step for each batch part.



When using the DP strategy, only a portion of the batch is inside the training_step:

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self(x)

    # softmax uses only a portion of the batch in the denominator
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return loss

If you wish to do something with all the parts of the batch, then use this method to do it:

def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch

    out = self.encoder(x)
    return {"pred": out}

def training_step_end(self, training_step_outputs):
    # collect the predictions from every GPU's part of the batch
    all_preds = [out["pred"] for out in training_step_outputs]

    # this softmax now uses the full batch
    loss = nce_loss(all_preds)
    return loss

See also

See the Multi GPU Training guide for more details.
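To see why a softmax (or NCE loss) computed inside training_step under DP differs from one computed over the full batch, compare normalizing a set of scores all at once with normalizing each half separately, as two devices would. The logits and the two-way split are illustrative assumptions; this plain-Python sketch does not use Lightning.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a flat list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, 3.0]  # scores for the full batch

# Full batch: every score shares a single denominator.
full = softmax(logits)

# DP-style split across two devices: each half gets its own denominator.
half = len(logits) // 2
split = softmax(logits[:half]) + softmax(logits[half:])

print([round(p, 3) for p in full])
print([round(p, 3) for p in split])
```

The split version sums to 1 per device rather than over the whole batch, so its probabilities (and any loss built on them) change with the number of devices.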


val_dataloader()

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

  • fit()

  • validate()

  • prepare_data()

  • setup()


Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.


Returns a torch.utils.data.DataLoader or a sequence of them specifying validation samples.


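import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=False
    )

    return loader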

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]


If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.


In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.


validation_epoch_end(outputs)

Called at the end of the validation epoch with the outputs of all validation steps.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.




If you didn’t define a validation_step(), this won’t be called.


With a single dataloader:

def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...  # do something with each validation step output

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.

def validation_epoch_end(self, outputs):
    for dataloader_idx, dataloader_outs in enumerate(outputs):
        for out in dataloader_outs:
            ...  # do something with each step output of this dataloader

    final_value = ...  # compute your aggregate metric
    self.log("final_metric", final_value)
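The list-of-lists structure for multiple dataloaders is usually reduced one inner list at a time. The plain-Python sketch below assumes, purely for illustration, that each validation step returned a dict with a 'val_loss' float; the helper name and the metric keys it produces are hypothetical and not part of the Lightning API.

```python
from statistics import mean
from typing import Dict, List


def mean_loss_per_dataloader(outputs: List[List[Dict[str, float]]]) -> Dict[str, float]:
    """Average the 'val_loss' values of each dataloader's step outputs.

    `outputs[i]` holds the step outputs produced by dataloader i, in order.
    """
    return {
        f"val_loss/dataloader_{idx}": mean(step["val_loss"] for step in dataloader_outs)
        for idx, dataloader_outs in enumerate(outputs)
    }


outputs = [
    [{"val_loss": 0.5}, {"val_loss": 0.25}],                    # dataloader 0
    [{"val_loss": 1.0}, {"val_loss": 2.0}, {"val_loss": 3.0}],  # dataloader 1
]
result = mean_loss_per_dataloader(outputs)
print(result)
```

Inside validation_epoch_end, the resulting dict could then be passed to self.log_dict().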

validation_step(batch, batch_nb)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

Parameters

  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch (only if multiple val dataloaders are used).

Returns

  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...


If you don’t need to validate you don’t need to implement this method.


When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.


class pl_bolts.models.self_supervised.BYOL(learning_rate=0.2, weight_decay=1.5e-06, warmup_epochs=10, max_epochs=1000, base_encoder='resnet50', encoder_out_dim=2048, projector_hidden_dim=4096, projector_out_dim=256, initial_tau=0.996, **kwargs)

Bases: pytorch_lightning.core.module.LightningModule

PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL)

Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

  • learning_rate (float, optional) – optimizer learning rate. Defaults to 0.2.

  • weight_decay (float, optional) – optimizer weight decay. Defaults to 1.5e-6.

  • warmup_epochs (int, optional) – number of epochs for scheduler warmup. Defaults to 10.

  • max_epochs (int, optional) – maximum number of epochs for scheduler. Defaults to 1000.

  • base_encoder (Union[str, torch.nn.Module], optional) – base encoder architecture. Defaults to “resnet50”.

  • encoder_out_dim (int, optional) – base encoder output dimension. Defaults to 2048.

  • projector_hidden_dim (int, optional) – projector MLP hidden dimension. Defaults to 4096.

  • projector_out_dim (int, optional) – projector MLP output dimension. Defaults to 256.

  • initial_tau (float, optional) – initial value of target decay rate used. Defaults to 0.996.

Model implemented by:


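import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

model = BYOL(num_classes=10)

dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)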

CLI command:

# cifar10
python --gpus 1

# imagenet
    --gpus 8
    --dataset imagenet2012
    --data_dir /path/to/imagenet/
    --meta_dir /path/to/folder/with/meta.bin/
    --batch_size 32

calculate_loss(v_online, v_target)

Calculates the similarity loss between the online network’s prediction and the target network’s projection.

  • v_online (Tensor) – Online network view

  • v_target (Tensor) – Target network view

Return type

Tensor


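The BYOL paper defines this loss as the squared distance between the L2-normalized online prediction and the L2-normalized target projection, which simplifies to 2 - 2 times their cosine similarity. The plain-Python sketch below computes that quantity for a single pair of vectors; the function names are hypothetical, and the library implementation works on batched tensors and may scale or average the loss differently.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def byol_pair_loss(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Squared distance between the L2-normalized vectors, i.e. 2 - 2 * cos."""
    return 2.0 - 2.0 * cosine_similarity(prediction, target)


print(byol_pair_loss([3.0, 4.0], [4.0, 3.0]))
```

Because both vectors are normalized, the loss depends only on their direction: it is 0 for aligned vectors, 2 for orthogonal ones and 4 for opposite ones.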
configure_optimizers()

This method carries the same LightningModule documentation as configure_optimizers() under AMDIM above (return options, lr_scheduler_config and examples).


forward(x)

Returns the encoded representation of a view.

Parameters

x (Tensor) – sample to be encoded

Return type

Tensor

on_train_batch_end(outputs, batch, batch_idx)

Add callback to perform exponential moving average weight update on target network.

Return type

None

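The exponential moving average (EMA) keeps the target weights a slowly moving average of the online weights: theta_target <- tau * theta_target + (1 - tau) * theta_online. In the BYOL paper, tau starts at a base value (initial_tau here) and is increased towards 1 with a cosine schedule, tau_k = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2, where k is the current step and K the total number of steps. The plain-Python sketch below uses flat lists of floats in place of model parameters; the function names are hypothetical and this is not the library's callback code.

```python
import math
from typing import List


def tau_at_step(step: int, total_steps: int, initial_tau: float = 0.996) -> float:
    """Cosine schedule moving tau from initial_tau (step 0) to 1.0 (final step)."""
    return 1.0 - (1.0 - initial_tau) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0


def ema_update(target: List[float], online: List[float], tau: float) -> List[float]:
    """One EMA step: move each target weight a fraction (1 - tau) towards online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


target = [0.0, 0.0]
online = [1.0, -1.0]  # online weights held fixed for illustration
total_steps = 100
for step in range(total_steps):
    target = ema_update(target, online, tau_at_step(step, total_steps))
print(target)
```

As tau approaches 1 the target network changes more and more slowly, which is what makes it a stable regression target for the online network.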
training_step(batch, batch_idx)

Complete training loop.

Return type

Tensor

validation_step(batch, batch_idx)

Complete validation loop.

Return type

Tensor

CPC (V2)

PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding

Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.

Model implemented by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.self_supervised.cpc_transforms import (
    CPCEvalTransformsCIFAR10,
    CPCTrainTransformsCIFAR10,
)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()

# model
model = CPC_v2()

# fit
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

To finetune:

    --ckpt_path path/to/checkpoint.ckpt
    --dataset cifar10
    --gpus 1

CIFAR-10 and STL-10 baselines

CPCv2 does not report baselines on the CIFAR-10 and STL-10 datasets. The results in the table below are reported from the YADIM paper.

CPCv2 implementation results

Epochs                   Hardware
1000 (up to 24 hours)    1 V100 (32GB)
1000 (up to 72 hours)    4 V100 (32GB)
1000 (up to 21 days)     64 V100 (32GB)


CIFAR-10 pretrained model:

from pl_bolts.models.self_supervised import CPC_v2

weight_path = ''
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)



Figures: pretraining validation loss; online finetuning accuracy

STL-10 pretrained model:

from pl_bolts.models.self_supervised import CPC_v2

weight_path = ''
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)



pretraining validation loss


online finetuning accuracy

CPC (v2) API

class pl_bolts.models.self_supervised.CPC_v2(encoder_name='cpc_encoder', patch_size=8, patch_overlap=4, online_ft=True, task='cpc', num_workers=4, num_classes=10, learning_rate=0.0001, pretrained=None, **kwargs)[source]

Bases: pytorch_lightning.core.module.LightningModule


The feature CPC_v2 is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details:

  • encoder_name (str) – A string naming any of the ResNets in torchvision or the original CPC encoder, or a custom nn.Module encoder

  • patch_size (int) – How big to make the image patches

  • patch_overlap (int) – How much overlap each patch should have

  • online_ft (bool) – If True, enables a 1024-unit MLP to fine-tune online

  • task (str) – Which self-supervised task to use (‘cpc’, ‘amdim’, etc…)

  • num_workers (int) – number of dataloader workers

  • num_classes (int) – number of classes

  • learning_rate (float) – learning rate

  • pretrained (Optional[str]) – If set, uses the weights pretrained (using CPC) on ImageNet
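To make patch_size and patch_overlap concrete, here is a sketch that cuts a single image into overlapping square patches with torch.nn.functional.unfold. It assumes that the stride between neighboring patches is patch_size - patch_overlap, so the defaults (8 and 4) give a stride of 4. The bolts CPC transforms may differ in detail, and `patchify` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def patchify(img: torch.Tensor, patch_size: int = 8, patch_overlap: int = 4) -> torch.Tensor:
    """Split a (C, H, W) image into a (rows, cols, C, patch_size, patch_size) grid of patches."""
    stride = patch_size - patch_overlap  # assumption: overlap is subtracted from the patch size
    c, h, w = img.shape
    # unfold expects batched input: (1, C, H, W) -> (1, C * p * p, L), channels outermost
    cols = F.unfold(img.unsqueeze(0), kernel_size=patch_size, stride=stride)
    n_rows = (h - patch_size) // stride + 1
    n_cols = (w - patch_size) // stride + 1
    # (1, C*p*p, L) -> (L, C, p, p); L runs over patch positions in row-major order
    patches = cols.squeeze(0).transpose(0, 1).reshape(-1, c, patch_size, patch_size)
    return patches.reshape(n_rows, n_cols, c, patch_size, patch_size)


img = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
patches = patchify(img)  # 32x32 CIFAR-sized image -> 7x7 grid of 8x8 patches
```

Each patch is later encoded separately, and CPC predicts the representations of lower patches from the context above them.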


Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.


Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.
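The reason "monitor" is required shows up in plain PyTorch: ReduceLROnPlateau.step() takes the monitored value as an argument, and Lightning supplies it from the logged metric. The sketch below drives the scheduler by hand outside Lightning, using made-up metric values (metric_history is hypothetical).

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to build an optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)
# Halve the LR once the metric has not improved for more than 1 epoch.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1)

# Hypothetical per-epoch values of the metric named in "monitor".
metric_history = [1.0, 0.8, 0.8, 0.8, 0.8]
lrs = []
for metric_to_track in metric_history:
    optimizer.step()  # no gradients here, so the parameter is unchanged
    # Unlike most schedulers, step() needs the monitored value.
    scheduler.step(metric_to_track)
    lrs.append(optimizer.param_groups[0]["lr"])
```

The learning rate stays at 0.1 while the metric improves. It drops to 0.05 on the fourth epoch, after two epochs without improvement.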


The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
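The cycling rule described above can be written as a tiny pure-Python helper. This sketch illustrates the described behavior and is not Lightning's internal code. The name optimizer_index_for_batch is hypothetical.

```python
from itertools import accumulate
from typing import Sequence


def optimizer_index_for_batch(batch_idx: int, frequencies: Sequence[int]) -> int:
    """Return which optimizer handles `batch_idx` when optimizers take turns by frequency."""
    cycle_length = sum(frequencies)          # e.g. 5 + 10 = 15 batches per cycle
    position = batch_idx % cycle_length      # where this batch falls inside the cycle
    for idx, boundary in enumerate(accumulate(frequencies)):
        if position < boundary:
            return idx
    raise AssertionError("unreachable: position is always < cycle_length")


# Frequencies 5 and 10, as in the example above.
schedule = [optimizer_index_for_batch(i, [5, 10]) for i in range(16)]
```

Batches 0 to 4 go to the first optimizer and batches 5 to 14 go to the second. Batch 15 starts a new cycle with the first optimizer again.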


from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )


Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with the key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
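The "interval" key controls when the scheduler's .step() is called. The following plain-PyTorch loop gives a rough picture of the automatic behavior described above; it is not Lightning's actual training loop. A step-interval scheduler is stepped after every optimizer update, and an epoch-interval scheduler once per epoch. The two tiny Linear models are hypothetical stand-ins for model_gen and model_dis.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

gen = torch.nn.Linear(4, 4)  # stand-in for model_gen
dis = torch.nn.Linear(4, 1)  # stand-in for model_dis
gen_opt = Adam(gen.parameters(), lr=0.01)
dis_opt = Adam(dis.parameters(), lr=0.02)
gen_sch = ExponentialLR(gen_opt, gamma=0.99)   # behaves like 'interval': 'step'
dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # behaves like 'interval': 'epoch'

epochs, batches_per_epoch = 2, 3
for epoch in range(epochs):
    for _ in range(batches_per_epoch):
        x = torch.randn(2, 4)
        loss = gen(x).pow(2).mean() + dis(x).pow(2).mean()
        gen_opt.zero_grad()
        dis_opt.zero_grad()
        loss.backward()
        gen_opt.step()
        dis_opt.step()
        gen_sch.step()  # step interval: after every optimizer update
    dis_sch.step()  # epoch interval: once at the end of each epoch
```

After 2 epochs of 3 batches, the generator LR has decayed 6 times (0.01 * 0.99**6). The discriminator LR has moved only 2 steps along its cosine curve.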


Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Your model’s output

training_step(batch, batch_nb)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Integer displaying index of this batch

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.


Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.


def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}


The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.


When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
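The normalization in the note above is what keeps gradient accumulation equivalent to one larger batch. The following plain-PyTorch check assumes equally sized batches and a mean-reduced loss. Backpropagating loss / accumulate_grad_batches over N batches gives the same gradient as a single backward pass on the mean of the N batch losses.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
batches = [torch.randn(4, 3) for _ in range(4)]
accumulate_grad_batches = len(batches)

# Accumulation: scale each batch loss by 1/N before calling backward().
model.zero_grad()
for x in batches:
    loss = model(x).pow(2).mean()
    (loss / accumulate_grad_batches).backward()  # gradients add up in .grad
accumulated = model.weight.grad.clone()

# Reference: one backward pass on the average of the N batch losses.
model.zero_grad()
mean_loss = torch.stack([model(x).pow(2).mean() for x in batches]).mean()
mean_loss.backward()
reference = model.weight.grad.clone()
```

Without the division, the accumulated gradient would be N times larger. That would effectively scale the learning rate by accumulate_grad_batches.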

validation_step(batch, batch_nb)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple val dataloaders used)


  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...


import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...


If you don’t need to validate you don’t need to implement this method.


When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
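For a plain nn.Module, the eval-mode and no-gradient context from the note above can be reproduced by hand. In this sketch, run_validation is a hypothetical helper that mirrors the described behavior and the accuracy computation from CASE 1; it is not Lightning's loop.

```python
import torch
from torch import nn


def run_validation(model: nn.Module, val_batches) -> float:
    """Compute accuracy in eval mode without autograd, then restore the previous mode."""
    was_training = model.training
    model.eval()  # disables dropout and uses running batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():  # no graph is built, so no gradients accumulate
        for x, y in val_batches:
            labels_hat = torch.argmax(model(x), dim=1)
            correct += (labels_hat == y).sum().item()
            total += y.numel()
    model.train(was_training)  # back to training mode, as Lightning does after validation
    return correct / total


model = nn.Sequential(nn.Linear(2, 2), nn.Dropout(p=0.5))
with torch.no_grad():
    model[0].weight.copy_(torch.eye(2))  # identity layer so predictions are deterministic
    model[0].bias.zero_()
acc = run_validation(model, [(torch.eye(2), torch.tensor([0, 1]))])
```

Because dropout is inactive in eval mode, the identity layer classifies both inputs correctly.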

Moco (v2) API

class pl_bolts.models.self_supervised.Moco_v2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)[source]

Bases: pytorch_lightning.core.module.LightningModule


The feature Moco_v2 is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details:

PyTorch Lightning implementation of Moco

Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.

Code adapted from facebookresearch/moco to Lightning by:


from pytorch_lightning import Trainer
from pl_bolts.models.self_supervised import Moco_v2

model = Moco_v2()
trainer = Trainer()
trainer.fit(model)

CLI command:

# cifar10
python --gpus 1

# imagenet
    --gpus 8
    --dataset imagenet2012
    --data_dir /path/to/imagenet/
    --meta_dir /path/to/folder/with/meta.bin/
    --batch_size 32

  • base_encoder (Union[str, Module]) – torchvision model name or torch.nn.Module

  • emb_dim (int) – feature dimension (default: 128)

  • num_negatives (int) – queue size; number of negative keys (default: 65536)

  • encoder_momentum (float) – momentum used to update the key encoder (default: 0.999)

  • softmax_temperature (float) – softmax temperature (default: 0.07)

  • learning_rate (float) – the learning rate

  • momentum (float) – optimizer momentum

  • weight_decay (float) – optimizer weight decay

  • datamodule – the DataModule (train, val, test dataloaders)

  • data_dir (str) – the directory to store data

  • batch_size (int) – batch size

  • use_mlp (bool) – add an MLP head to the encoders

  • num_workers (int) – workers for the loaders
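The queue behind num_negatives can be pictured as a fixed-size ring buffer of past key embeddings. The class below is a hypothetical sketch based on the MoCo paper's description, not the bolts implementation. KeyQueue and dequeue_and_enqueue are illustrative names, and the sketch assumes num_negatives is a multiple of the batch size.

```python
import torch
import torch.nn.functional as F


class KeyQueue:
    """FIFO of negative keys stored column-wise, shape (emb_dim, num_negatives)."""

    def __init__(self, emb_dim: int = 128, num_negatives: int = 65536) -> None:
        # Start from random unit vectors so the first batches already have negatives.
        self.queue = F.normalize(torch.randn(emb_dim, num_negatives), dim=0)
        self.ptr = 0  # column where the next batch of keys is written

    @torch.no_grad()
    def dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Overwrite the oldest columns with a (batch_size, emb_dim) batch of new keys."""
        batch_size = keys.shape[0]
        num_negatives = self.queue.shape[1]
        # Simplifying assumption: the queue length is a multiple of the batch size.
        assert num_negatives % batch_size == 0
        self.queue[:, self.ptr : self.ptr + batch_size] = keys.T
        self.ptr = (self.ptr + batch_size) % num_negatives  # wrap around


q = KeyQueue(emb_dim=3, num_negatives=8)
first = torch.arange(12, dtype=torch.float32).reshape(4, 3)
q.dequeue_and_enqueue(first)               # fills columns 0..3
q.dequeue_and_enqueue(torch.zeros(4, 3))   # fills columns 4..7, pointer wraps to 0
```

Keys from the most recent batches replace the oldest ones. This lets the number of negatives exceed the batch size without recomputing old keys.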


Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.


Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.


The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.


from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )


Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with the key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.

forward(img_q, img_k, queue)[source]

  • img_q – a batch of query images

  • img_k – a batch of key images

  • queue – a queue from which to pick negative samples


logits, targets
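The (logits, targets) pair follows the MoCo recipe. After the query and key encoders embed img_q and img_k, each query gets one positive logit against its matching key and K negative logits against the queue. All logits are divided by softmax_temperature, and the positive sits in column 0. This minimal sketch assumes that q, k and queue are already L2-normalized and that the queue is stored as (emb_dim, num_negatives). moco_logits is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def moco_logits(q: torch.Tensor, k: torch.Tensor, queue: torch.Tensor, temperature: float = 0.07):
    """Return (logits, targets) for a MoCo-style InfoNCE loss.

    q, k:  (N, C) normalized query and key embeddings of the same images
    queue: (C, K) normalized negative keys
    """
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)  # (N, 1) positive similarities
    l_neg = torch.einsum("nc,ck->nk", q, queue)  # (N, K) negative similarities
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # The positive key is always column 0, so every target is 0.
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return logits, targets


q = F.normalize(torch.randn(4, 8), dim=1)
k = F.normalize(torch.randn(4, 8), dim=1)
queue = F.normalize(torch.randn(8, 16), dim=0)
logits, targets = moco_logits(q, k, queue)
loss = F.cross_entropy(logits, targets)  # InfoNCE as (1 + K)-way classification
```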


Override to add your own encoders.

training_step(batch, batch_idx)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Integer displaying index of this batch

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.


Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.


def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}


The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.


When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.


Called at the end of the validation epoch with the outputs of all validation steps.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.




If you didn’t define a validation_step(), this won’t be called.


With a single dataloader:

def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...

With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.

def validation_epoch_end(self, outputs):
    for dataloader_outs in outputs:
        # dataloader_outs holds the validation_step outputs of one dataloader
        for out in dataloader_outs:
            ...

    self.log("final_metric", final_value)

validation_step(batch, batch_idx)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)

  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple val dataloaders used)


  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...


import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...


If you don’t need to validate you don’t need to implement this method.


When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.


PyTorch Lightning implementation of SimCLR

Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

Model implemented by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model
# `gpus` has no default in the SimCLR signature, so it must be passed explicitly
model = SimCLR(gpus=1, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')

# fit
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

CIFAR-10 baseline

CIFAR-10 implementation results

  • 800 epochs (4 hours) on 8 V100 (16GB)


CIFAR-10 pretrained model:

from pl_bolts.models.self_supervised import SimCLR

weight_path = ''
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)



pretraining validation loss

Fine-tuning (Single layer MLP, 1024 hidden units):

finetuning validation accuracy
finetuning test accuracy
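The fine-tuning protocol named above uses a single-hidden-layer MLP with 1024 hidden units on top of the pretrained encoder. The sketch below assumes a hypothetical stand-in encoder (a flatten plus a linear layer) in place of the pretrained SimCLR encoder, with a feature size of 2048. The encoder is kept frozen by computing its features under torch.no_grad(), and only the head is optimized.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FineTuneHead(nn.Module):
    """Single-hidden-layer MLP classifier trained on frozen encoder features."""

    def __init__(self, in_features: int, num_classes: int, hidden_dim: int = 1024) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.net(features)


# Hypothetical stand-in for the pretrained encoder (e.g. simclr.encoder).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2048))
head = FineTuneHead(in_features=2048, num_classes=10)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)  # only the head is optimized

x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))
with torch.no_grad():
    features = encoder(x)  # frozen: no autograd graph through the encoder
logits = head(features)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```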

To reproduce:

# pretrain
    --gpus 8
    --dataset cifar10
    --batch_size 256
    --num_workers 16
    --optimizer sgd
    --learning_rate 1.5
    --max_epochs 800

# finetune
    --gpus 4
    --ckpt_path path/to/simclr/ckpt
    --dataset cifar10
    --batch_size 64
    --num_workers 8
    --learning_rate 0.3
    --num_epochs 100

ImageNet baseline for SimCLR

ImageNet implementation results

  • Hardware: 64 V100 (16GB)


ImageNet pretrained model:

from pl_bolts.models.self_supervised import SimCLR

weight_path = ''
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)


To reproduce:

# pretrain
    --dataset imagenet
    --data_path path/to/imagenet

# finetune
    --gpus 8
    --ckpt_path path/to/simclr/ckpt
    --dataset imagenet
    --data_dir path/to/imagenet/dataset
    --batch_size 256
    --num_workers 16
    --learning_rate 0.8
    --nesterov True
    --num_epochs 90


class pl_bolts.models.self_supervised.SimCLR(gpus, num_samples, batch_size, dataset, num_nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, temperature=0.1, first_conv=True, maxpool1=True, optimizer='adam', exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, **kwargs)[source]

Bases: pytorch_lightning.core.module.LightningModule


The feature SimCLR is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details:

  • batch_size (int) – the batch size

  • num_samples (int) – num samples in the dataset

  • warmup_epochs (int) – epochs to warmup the lr for

  • learning_rate (float) – the optimizer learning rate

  • weight_decay (float) – the optimizer weight decay

  • temperature (float) – the loss temperature
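The temperature parameter scales SimCLR's NT-Xent (normalized temperature-scaled cross entropy) loss. Below is a compact sketch of that loss as described in the SimCLR paper. It assumes that z1[i] and z2[i] are projections of two augmented views of the same image. nt_xent is a hypothetical helper and not the exact bolts function.

```python
import math

import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """NT-Xent over 2N views: each view's positive is the other view of the same image."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D) unit vectors
    sim = z @ z.t() / temperature  # (2N, 2N) scaled cosine similarities
    # A view is neither its own positive nor its own negative.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i (first view) pairs with row i + N, and row i + N pairs with row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)


# Two identical "views" of 4 orthogonal embeddings: positives score 1/T, negatives 0.
z = torch.eye(4)
loss = nt_xent(z, z.clone(), temperature=0.1)
expected = math.log(1 + 6 * math.exp(-10))  # 6 negatives per row, positive logit 10
```

A lower temperature sharpens the softmax, so the loss concentrates on the hardest negatives.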


Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.


Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.


The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
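The optimizer-frequency cycle can be sketched in plain Python. `active_optimizer_index` is a hypothetical helper written for illustration, not part of Lightning; it assumes 0-based batch steps and a frequency on every optimizer, and only reports which optimizer would handle each batch.

```python
from itertools import accumulate
from typing import List


def active_optimizer_index(step: int, frequencies: List[int]) -> int:
    """Return the index of the optimizer that handles a given (0-based) batch step.

    Hypothetical helper: optimizers take turns in a repeating cycle, and
    optimizer i handles frequencies[i] consecutive batches per cycle.
    """
    cycle_length = sum(frequencies)
    position = step % cycle_length
    # Cumulative end of each optimizer's block, e.g. [5, 15] for [5, 10].
    for idx, block_end in enumerate(accumulate(frequencies)):
        if position < block_end:
            return idx
    raise AssertionError("unreachable: position is always < cycle_length")


# Frequencies 5 and 10: steps 0-4 use optimizer 0, steps 5-14 use optimizer 1,
# then the cycle restarts at step 15.
schedule = [active_optimizer_index(step, [5, 10]) for step in range(20)]
print(schedule)
```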


# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
from torch.optim.lr_scheduler import CosineAnnealingLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )


Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.


Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Your model’s output

nt_xent_loss(out_1, out_2, temperature, eps=1e-06)[source]

Assumes out_1 and out_2 are normalized. Shapes: out_1: [batch_size, dim], out_2: [batch_size, dim].
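A minimal sketch of the NT-Xent loss, written as a cross-entropy over the 2N embeddings of a batch. It assumes, as stated above, that out_1 and out_2 are already L2-normalized with shape [batch_size, dim], and that row i of each tensor forms a positive pair. `nt_xent_loss_sketch` is a hypothetical name; this is not the bolts source, which additionally takes an `eps` argument for numerical stability.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss_sketch(out_1: torch.Tensor, out_2: torch.Tensor, temperature: float) -> torch.Tensor:
    """NT-Xent loss for a batch of positive pairs (out_1[i], out_2[i]).

    Every other embedding among the 2 * batch_size embeddings acts as a negative.
    """
    batch_size = out_1.shape[0]
    out = torch.cat([out_1, out_2], dim=0)  # [2N, dim]
    # Inputs are normalized, so the dot products are cosine similarities.
    logits = out @ out.t() / temperature  # [2N, 2N]
    # Remove each embedding's similarity with itself from the softmax denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=out.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and the positive for row i + N is row i.
    targets = torch.cat(
        [torch.arange(batch_size, 2 * batch_size), torch.arange(0, batch_size)]
    ).to(out.device)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
view_1 = F.normalize(torch.randn(4, 8), dim=1)
view_2 = F.normalize(torch.randn(4, 8), dim=1)
loss = nt_xent_loss_sketch(view_1, view_2, temperature=0.5)
```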

training_step(batch, batch_idx)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Integer displaying index of this batch

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.


Any of the following:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.


def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
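To make "hidden states from the previous truncated backprop step" concrete, here is a plain-PyTorch sketch of the mechanism, assuming a batch-first LSTM and a placeholder loss: the sequence is split along the time dimension, the hidden state is carried into the next chunk, and it is detached so that gradients do not flow past a chunk boundary. It illustrates the idea only and is not Lightning's implementation.

```python
import torch

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
sequence = torch.randn(2, 12, 4)  # [batch, time, features]
truncated_bptt_steps = 5

hiddens = None  # the LSTM initializes zero states when given None
chunk_outputs = []
for chunk in sequence.split(truncated_bptt_steps, dim=1):  # chunks of 5, 5 and 2 time steps
    out, hiddens = lstm(chunk, hiddens)
    loss = out.pow(2).mean()  # placeholder loss for the sketch
    loss.backward()  # gradients accumulate in lstm's parameters
    # Detach so the next chunk's backward pass stops at this chunk boundary
    # (this chunk's graph has already been freed by backward()).
    hiddens = tuple(h.detach() for h in hiddens)
    chunk_outputs.append(out)
```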


The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
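The smoothing can be pictured as a running mean over a fixed window of recent values. The sketch below is illustrative only: `RunningMean` is a hypothetical helper, and the window size of 3 is arbitrary (Lightning's own window length and implementation may differ).

```python
from collections import deque


class RunningMean:
    """Mean of the most recent `window` values (hypothetical helper for illustration)."""

    def __init__(self, window: int) -> None:
        self.values = deque(maxlen=window)  # the oldest value is dropped automatically

    def append(self, value: float) -> float:
        self.values.append(value)
        return sum(self.values) / len(self.values)


smoother = RunningMean(window=3)
# Raw step losses 4, 2, 3, 1 are displayed as the mean of up to the last 3 values.
displayed = [smoother.append(loss) for loss in [4.0, 2.0, 3.0, 1.0]]
```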


When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
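Why the normalization matters: with a mean-reduced loss and equal-sized micro-batches, dividing each micro-batch loss by accumulate_grad_batches makes the summed gradients equal to those of one large batch. The plain-PyTorch sketch below checks this on a toy linear model; it illustrates the arithmetic and is not Lightning's internal code.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
data, target = torch.randn(8, 3), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction
accumulate_grad_batches = 4

# Reference: a single batch of 8 samples.
model.zero_grad()
loss_fn(model(data), target).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2 samples, each loss divided by accumulate_grad_batches.
model.zero_grad()
for x, y in zip(data.chunk(accumulate_grad_batches), target.chunk(accumulate_grad_batches)):
    (loss_fn(model(x), y) / accumulate_grad_batches).backward()  # gradients add up in .grad
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```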

validation_step(batch, batch_idx)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)

  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple val dataloaders used)


  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...


If you don’t need to validate you don’t need to implement this method.


When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
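A minimal plain-PyTorch sketch of the mode switching described above. `run_validation` is a hypothetical helper; Lightning's real loop also runs hooks and logging, but the eval / no-grad / train sequence is the same idea.

```python
import torch


def run_validation(model: torch.nn.Module, batches):
    """Hypothetical validation loop: eval mode and no autograd, then back to training mode."""
    model.eval()  # e.g. dropout is disabled and batchnorm uses running statistics
    outputs = []
    with torch.no_grad():  # no autograd graph is recorded inside this block
        for batch in batches:
            outputs.append(model(batch))
    model.train()  # back to training mode; grad mode is restored when the with-block exits
    return outputs


net = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.Dropout(p=0.5))
predictions = run_validation(net, [torch.randn(4, 3), torch.randn(4, 3)])
```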


PyTorch Lightning implementation of SwAV. Adapted from the official implementation.

Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.

Implementation adapted by:

To Train:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVTrainDataTransform, SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

# data
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed

dm.train_transforms = SwAVTrainDataTransform(
    normalize=stl10_normalization()
)

dm.val_transforms = SwAVEvalDataTransform(
    normalize=stl10_normalization()
)

# model
model = SwAV(
    gpus=1,
    num_samples=dm.num_unlabeled_samples,
    dataset='stl10',
    batch_size=batch_size
)

# fit
trainer = pl.Trainer(precision=16, accelerator='auto')
trainer.fit(model, datamodule=dm)

Pre-trained ImageNet

We have included an option to directly load ImageNet weights provided by FAIR into bolts.

You can load the pretrained model using:

ImageNet pretrained model:

from pl_bolts.models.self_supervised import SwAV

weight_path = ''
swav = SwAV.load_from_checkpoint(weight_path, strict=True)


STL-10 baseline

The original paper does not provide baselines on STL10.

STL-10 implementation results


Columns include test acc and Queue used. Reported run: SwAV resnet50, 100 epochs (~9 hr) on 1 V100 (16GB).


STL-10 pretrained model:

from pl_bolts.models.self_supervised import SwAV

weight_path = ''
swav = SwAV.load_from_checkpoint(weight_path, strict=False)



Figures: pretraining validation loss; online finetuning validation accuracy.

Fine-tuning (Single layer MLP, 1024 hidden units):

Figures: finetuning validation accuracy; finetuning validation loss.

To reproduce:

# pretrain
    --gpus 1
    --batch_size 128
    --learning_rate 1e-3
    --queue_length 0
    --jitter_strength 1.
    --num_prototypes 512

# finetune
--gpus 8
--ckpt_path path/to/simclr/ckpt
--dataset imagenet
--data_dir path/to/imagenet/dataset
--batch_size 256
--num_workers 16
--learning_rate 0.8
--nesterov True
--num_epochs 90

Imagenet baseline for SwAV

Cifar-10 implementation results


Columns include test acc. Hardware: 64 V100s; 64 V100 (16GB).


Imagenet pretrained model:

from pl_bolts.models.self_supervised import SwAV

weight_path = ''
swav = SwAV.load_from_checkpoint(weight_path, strict=False)



class pl_bolts.models.self_supervised.SwAV(gpus, num_samples, batch_size, dataset, num_nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, num_prototypes=3000, freeze_prototypes_epochs=1, temperature=0.1, sinkhorn_iterations=3, queue_length=0, queue_path='queue', epoch_queue_starts=15, crops_for_assign=(0, 1), num_crops=(2, 6), first_conv=True, maxpool1=True, optimizer='adam', exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, epsilon=0.05, **kwargs)[source]

Bases: pytorch_lightning.core.module.LightningModule

  • gpus (int) – number of gpus per node used in training, passed to SwAV module to manage the queue and select distributed sinkhorn

  • num_nodes (int) – number of nodes to train on

  • num_samples (int) – number of image samples used for training

  • batch_size (int) – batch size per GPU in ddp

  • dataset (str) – dataset being used for train/val

  • arch (str) – encoder architecture used for pre-training

  • hidden_mlp (int) – size of the hidden layer of the non-linear projection head; set to 0 to use a linear projection head

  • feat_dim (int) – output dim of the projection head

  • warmup_epochs (int) – apply linear warmup for this many epochs

  • max_epochs (int) – epoch count for pre-training

  • num_prototypes (int) – count of prototype vectors

  • freeze_prototypes_epochs (int) – number of initial epochs during which the gradients of the prototype layer are frozen

  • temperature (float) – loss temperature

  • sinkhorn_iterations (int) – iterations for sinkhorn normalization

  • queue_length (int) – length of the feature queue, useful when the batch size is small; must be divisible by the total batch size (i.e. total_gpus * batch_size). Set to 0 to disable the queue

  • queue_path (str) – folder within the logs directory

  • epoch_queue_starts (int) – start using the queue after this epoch

  • crops_for_assign (tuple) – list of crop ids for computing assignment

  • num_crops (tuple) – number of global and local crops, ex: [2, 6]

  • first_conv (bool) – keep the first conv layer the same as in the original ResNet architecture; if set to False, it is replaced by a kernel-3, stride-1 conv (cifar-10)

  • maxpool1 (bool) – keep first maxpool layer same as the original resnet architecture, if set to false, first maxpool is turned off (cifar10, maybe stl10)

  • optimizer (str) – optimizer to use

  • exclude_bn_bias (bool) – exclude batchnorm and bias layers from weight decay in optimizers

  • start_lr (float) – starting lr for linear warmup

  • learning_rate (float) – learning rate

  • final_lr (float) – final learning rate for the cosine learning-rate decay

  • weight_decay (float) – weight decay for optimizer

  • epsilon (float) – epsilon value (regularization strength) for the Sinkhorn-Knopp assignments in SwAV
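The epsilon and sinkhorn_iterations parameters control the Sinkhorn-Knopp normalization that SwAV uses to turn prototype scores into soft assignments ("codes"). Below is a single-process sketch following the procedure described in the SwAV paper; `sinkhorn_assignments` is a hypothetical name, and the distributed variant (selected via gpus) would additionally all-reduce the sums across processes, which is omitted here.

```python
import torch


@torch.no_grad()
def sinkhorn_assignments(scores: torch.Tensor, epsilon: float = 0.05, sinkhorn_iterations: int = 3) -> torch.Tensor:
    """Soft prototype assignments for one batch (single-process sketch).

    scores: [batch_size, num_prototypes] similarities between features and prototypes.
    Returns a [batch_size, num_prototypes] matrix whose rows each sum to 1.
    """
    q = torch.exp(scores / epsilon).t()  # [K, B]
    num_prototypes, batch_size = q.shape
    q /= q.sum()  # start from a joint distribution over (prototype, sample)
    for _ in range(sinkhorn_iterations):
        # Rows: give every prototype the same total mass 1/K.
        q /= q.sum(dim=1, keepdim=True)
        q /= num_prototypes
        # Columns: give every sample the same total mass 1/B.
        q /= q.sum(dim=0, keepdim=True)
        q /= batch_size
    q *= batch_size  # each sample's assignment now sums to 1
    return q.t()


torch.manual_seed(0)
prototype_scores = torch.rand(8, 4) * 2 - 1  # stand-in for cosine similarities in [-1, 1]
codes = sinkhorn_assignments(prototype_scores)
```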


Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.


Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • Tuple of dictionaries as described above, with an optional "frequency" key.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            # "frequency" indicates how often the metric is updated.
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
            "frequency": 1,
        },
    }

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.


The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:

  • In the former case, all optimizers will operate on the given batch in each optimization step.

  • In the latter, only one optimizer will operate on the given batch at every step.

This is different from the frequency value specified in the lr_scheduler_config mentioned above.

def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]

In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.


# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
from torch.optim.lr_scheduler import CosineAnnealingLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )


Some things to know:

  • Lightning calls .backward() and .step() on each optimizer as needed.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.

  • If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.

  • If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.


Same as torch.nn.Module.forward().

  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.


Your model’s output


Called after loss.backward() and before optimizers are stepped.


If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step if you need the unscaled gradients.


Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule OR

  2. Cache data across steps on the attribute(s) of the LightningModule and access them in this hook

Return type



Called in the training loop at the very beginning of the epoch.


Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.


stage – either 'fit', 'validate', 'test', or 'predict'


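class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):

        # don't do this: prepare_data runs on a single process,
        # so state assigned here is not available to the other processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
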
training_step(batch, batch_idx)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

  • batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.

  • batch_idx (int) – Integer displaying index of this batch

  • optimizer_idx (int) – When using multiple optimizers, this argument will also be present.

  • hiddens (Any) – Passed in if truncated_bptt_steps > 0.


Any of the following:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.


def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}


The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.


When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch, batch_idx)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)

  • batch – The output of your DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple val dataloaders used)


  • Any object or value

  • None - Validation will skip to the next batch

# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...


If you don’t need to validate you don’t need to implement this method.


When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.


class pl_bolts.models.self_supervised.SimSiam(learning_rate=0.05, weight_decay=0.0001, momentum=0.9, warmup_epochs=10, max_epochs=100, base_encoder='resnet50', encoder_out_dim=2048, projector_hidden_dim=2048, projector_out_dim=2048, predictor_hidden_dim=512, exclude_bn_bias=False, **kwargs)[source]

Bases: pytorch_lightning.core.module.LightningModule

PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam).

Paper authors: Xinlei Chen, Kaiming He.

  • learning_rate (float, optional) – optimizer learning rate. Defaults to 0.05.

  • weight_decay (float, optional) – optimizer weight decay. Defaults to 1e-4.

  • momentum (float, optional) – optimizer momentum. Defaults to 0.9.

  • warmup_epochs (int, optional) – number of epochs for scheduler warmup. Defaults to 10.

  • max_epochs (int, optional) – maximum number of epochs for scheduler. Defaults to 100.

  • base_encoder (Union[str, nn.Module], optional) – base encoder architecture. Defaults to “resnet50”.

  • encoder_out_dim (int, optional) – base encoder output dimension. Defaults to 2048.

  • projector_hidden_dim (int, optional) – projector MLP hidden dimension. Defaults to 2048.

  • projector_out_dim (int, optional) – projector MLP output dimension. Defaults to 2048.

  • predictor_hidden_dim (int, optional) – predictor MLP hidden dimension. Defaults to 512.

  • exclude_bn_bias (bool, optional) – option to exclude batchnorm and bias terms from weight decay. Defaults to False.
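SimSiam's learning_rate, warmup_epochs and max_epochs parameterize a learning-rate schedule with warmup. The sketch below shows one common shape for such a schedule, linear warmup followed by cosine decay, computed per epoch. It is an assumption for illustration, not necessarily the exact scheduler SimSiam configures; `warmup_cosine_lr` is hypothetical, and start_lr/final_lr are extra illustrative arguments (SimSiam does not take them).

```python
import math


def warmup_cosine_lr(epoch: int, base_lr: float, warmup_epochs: int, max_epochs: int,
                     start_lr: float = 0.0, final_lr: float = 0.0) -> float:
    """Learning rate at `epoch`: linear warmup to base_lr, then cosine decay to final_lr."""
    if epoch < warmup_epochs:
        # Linear interpolation from start_lr to base_lr over the warmup epochs.
        return start_lr + (base_lr - start_lr) * epoch / warmup_epochs
    # Fraction of the post-warmup phase completed, in [0, 1] for epoch <= max_epochs.
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress))


# Using SimSiam's documented defaults: learning_rate=0.05, warmup_epochs=10, max_epochs=100.
lrs = [warmup_cosine_lr(epoch, 0.05, 10, 100) for epoch in range(101)]
```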

Model implemented by:


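from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import SimSiam
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

model = SimSiam()

dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = Trainer()
trainer.fit(model, datamodule=dm)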

CLI command:

# cifar10
python --gpus 1

# imagenet
    --gpus 8
    --dataset imagenet2012
    --meta_dir /path/to/folder/with/meta.bin/
    --batch_size 32

calculate_loss(v_online, v_target)[source]

Calculates the similarity loss between the online network's prediction and the target network's projection.

  • v_online (Tensor) – Online network view

  • v_target (Tensor) – Target network view
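For reference, the SimSiam paper defines the loss as a negative cosine similarity D(p, z) with a stop-gradient on z, applied symmetrically to both views. The sketch below follows that definition; the names p (predictor output) and z (projector output) and both helper functions are assumptions, and the bolts calculate_loss may arrange the computation differently.

```python
import torch
import torch.nn.functional as F


def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """D(p, z): mean negative cosine similarity, with z treated as a constant."""
    z = z.detach()  # stop-gradient on the target branch
    return -F.cosine_similarity(p, z, dim=-1).mean()


def symmetric_simsiam_loss(p1: torch.Tensor, p2: torch.Tensor,
                           z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    # Each view's prediction is pulled toward the other view's projection.
    return 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)


p = torch.randn(4, 8, requires_grad=True)
z = torch.randn(4, 8, requires_grad=True)
negative_cosine_similarity(p, z).backward()
```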

Return type



Configure optimizer and learning rate scheduler.

static exclude_from_weight_decay(named_params, weight_decay, skip_list=('bias', 'bn'))[source]

Exclude parameters from weight decay.
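A sketch of how such an exclusion is commonly done: split named parameters into a decayed and a non-decayed optimizer parameter group. `split_weight_decay_groups` is hypothetical; it mirrors the documented arguments (named_params, weight_decay, skip_list), but the substring-matching rule and the returned group layout are assumptions. Matching "bn" depends on how BatchNorm modules are named in the model.

```python
import torch


def split_weight_decay_groups(named_params, weight_decay, skip_list=("bias", "bn")):
    """Build two optimizer parameter groups: with and without weight decay.

    A parameter is exempt from weight decay when any entry of skip_list
    appears as a substring of its name.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue  # frozen parameters are left out of both groups
        if any(skip in name for skip in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


model = torch.nn.Sequential()
model.add_module("conv", torch.nn.Conv2d(3, 8, kernel_size=3))
model.add_module("bn", torch.nn.BatchNorm2d(8))
# conv.weight is decayed; conv.bias, bn.weight and bn.bias are not.
groups = split_weight_decay_groups(model.named_parameters(), weight_decay=1e-4)
optimizer = torch.optim.SGD(groups, lr=0.05, momentum=0.9)
```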

Return type



Returns encoded representation of a view.

Return type


training_step(batch, batch_idx)[source]

Complete training loop.

Return type


validation_step(batch, batch_idx)[source]

Complete validation loop.

Return type


Read the Docs v: 0.7.0