Self-supervised Learning
This bolts module houses a collection of self-supervised learning models.
Self-supervised learning extracts representations of an input by solving a pretext task. In this package, we implement many of the current state-of-the-art self-supervised algorithms.
Self-supervised models are trained with unlabeled datasets.
Note
We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix!
Use cases
Here are some use cases for the self-supervised package.
Extracting image features
The models in this module are trained without labels, so they can learn useful image representations (features) from unlabeled data.
In this example, we’ll load a ResNet-50 that was pretrained on ImageNet using SimCLR as the pretext task.
from pl_bolts.models.self_supervised import SimCLR
# load resnet50 pretrained using SimCLR on imagenet
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr_resnet50 = simclr.encoder
simclr_resnet50.eval()
This means you can now extract image representations that were pretrained via unsupervised learning.
Example:
my_dataset = SomeDataset()
for batch in my_dataset:
    x, y = batch
    out = simclr_resnet50(x)
Train with unlabeled data
These models are well suited to training from scratch when you have a large set of unlabeled images.
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.transforms.self_supervised.simclr_transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)
# MyDataset is a placeholder for your own unlabeled dataset
train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
# simclr needs a lot of compute!
model = SimCLR()
trainer = Trainer(tpu_cores=128)
trainer.fit(
    model,
    DataLoader(train_dataset),
    DataLoader(val_dataset),
)
Research
Mix and match any part, or subclass a model to create your own method.
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
amdim_task = FeatureMapContrastiveTask(comparisons='01, 11, 02', bidirectional=True)
model = CPC_v2(contrastive_task=amdim_task)
Contrastive Learning Models
Contrastive self-supervised learning (CSL) is a self-supervised learning approach where we generate representations of instances such that similar instances are near each other and far from dissimilar ones. This is often done by comparing triplets of positive, anchor and negative representations.
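The anchor/positive/negative comparison can be sketched with an InfoNCE-style objective, one common way to implement contrastive learning: the loss is small when the anchor scores its positive higher than every negative. This is a minimal pure-Python sketch, assuming cosine similarity and a temperature of 0.1; the function names are hypothetical and this is not the exact loss of any bolts model.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce_loss(anchor, positive, negatives, temperature=0.1):
    """Cross-entropy of picking the positive (index 0) among all candidates."""
    candidates = [positive] + list(negatives)
    logits = [cosine_similarity(anchor, c) / temperature for c in candidates]
    # numerically stable log-sum-exp over all candidate logits
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - logits[0]


anchor = [1.0, 0.0]
positive = [0.9, 0.1]                 # similar to the anchor
negatives = [[-1.0, 0.0], [0.0, 1.0]]  # dissimilar to the anchor
loss_good = info_nce_loss(anchor, positive, negatives)
# swap roles: a dissimilar "positive" should give a much larger loss
loss_bad = info_nce_loss(anchor, [-1.0, 0.0], [positive, [0.0, 1.0]])
```

Minimizing this loss pulls the anchor toward its positive and pushes it away from the negatives, which is the geometry described above.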
In this section, we list Lightning implementations of popular contrastive learning approaches.
AMDIM
- class pl_bolts.models.self_supervised.AMDIM(datamodule='cifar10', encoder='amdim_encoder', contrastive_task=FeatureMapContrastiveTask( (nce_loss): AmdimNCELoss() ), image_channels=3, image_height=32, encoder_feature_dim=320, embedding_fx_dim=1280, conv_block_depth=10, use_bn=False, tclip=20.0, learning_rate=0.0002, data_dir='', num_classes=10, batch_size=200, num_workers=16, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
Warning
The feature AMDIM is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
PyTorch Lightning implementation of Augmented Multiscale Deep InfoMax (AMDIM).
Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter.
Model implemented by: William Falcon
This code is adapted to Lightning from the authors' original repository.
Example
>>> from pl_bolts.models.self_supervised import AMDIM
>>> model = AMDIM(encoder='resnet18')
Train:
trainer = Trainer()
trainer.fit(model)
- Parameters
datamodule (Union[str, LightningDataModule]) – A LightningDataModule, or a dataset name such as the default 'cifar10'.
encoder (Union[str, Module, LightningModule]) – An encoder name string or an encoder model.
encoder_feature_dim (int) – Called ndf in the paper; this is the representation size of the encoder.
embedding_fx_dim (int) – Output dimension of the embedding function (nrkhs in the paper, for Reproducing Kernel Hilbert Space).
tclip (int) – Soft-clipping non-linearity applied to the scores after computing the regularization term and before computing the log-softmax. This is the ‘second trick’ used in the paper.
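To illustrate what a soft clip does to the scores, here is a pure-Python sketch. It assumes the scaled-tanh form tclip * tanh(score / tclip) used in the original AMDIM code; whether bolts applies exactly this form is an assumption, and the function name is hypothetical. The default of 20.0 comes from the signature above.

```python
import math


def tanh_clip(score, clip_val=20.0):
    """Softly squash a score into (-clip_val, clip_val) with a scaled tanh (assumed form)."""
    return clip_val * math.tanh(score / clip_val)


small = tanh_clip(1.0)    # nearly unchanged: tanh is close to linear near 0
large = tanh_clip(100.0)  # saturates just below clip_val
```

Unlike a hard clamp, the scaled tanh keeps a non-zero gradient everywhere while still bounding extreme scores before the log-softmax.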
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Returns
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
Tuple of dictionaries as described above, with an optional "frequency" key.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
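To make the "interval" and "frequency" keys concrete, here is a simplified pure-Python model of when scheduler.step() would be called. It assumes the scheduler is stepped after every frequency-th unit (epoch for interval="epoch", optimizer step for interval="step"), counting from 1. This illustrates the documented semantics only; it is not Lightning's internal code, and the function name is hypothetical.

```python
def scheduler_step_points(num_units, frequency=1):
    """1-indexed epochs or steps after which scheduler.step() would run (simplified model)."""
    return [unit for unit in range(1, num_units + 1) if unit % frequency == 0]


every_epoch = scheduler_step_points(4)                    # the default frequency of 1
every_third_step = scheduler_step_points(10, frequency=3)
```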
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated"
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }
# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the frequency value specified in the lr_scheduler_config mentioned above.
def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]
In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps, and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
Examples:
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR
# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt
# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]
# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step'  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]
# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )
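The frequency-based alternation described above can be modeled in a few lines of pure Python. This sketch only computes which optimizer index would be active at each step under the cycling rule (optimizer i runs for frequencies[i] consecutive batches, then the next one takes over); it does not call Lightning, and the function name is hypothetical.

```python
from itertools import cycle, islice


def active_optimizer_schedule(frequencies, num_steps):
    """Index of the optimizer used at each step, cycling through the frequencies in order."""
    pattern = [idx for idx, freq in enumerate(frequencies) for _ in range(freq)]
    return list(islice(cycle(pattern), num_steps))


# the {"frequency": 5} / {"frequency": 10} example
schedule = active_optimizer_schedule([5, 10], num_steps=20)
# WGAN-style training: n_critic = 5 discriminator steps per generator step
wgan = active_optimizer_schedule([5, 1], num_steps=12)
```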
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer as needed.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
- forward(img_1, img_2)
Same as torch.nn.Module.forward().
- train_dataloader()
Implement one or more PyTorch DataLoaders for training.
- Returns
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, see the Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
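Warning
Do not assign state in prepare_data(): it runs on a single process (per node), so any state assigned there is not available to the other processes.
Related hooks: fit(), prepare_data(), setup()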
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=True
    )
    return loader
# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]
# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
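When a dict of loaders is returned, each training batch is a dict with one entry per loader. The sketch below shows how such per-step batches could be assembled, using plain lists in place of DataLoaders. It assumes the shorter loaders are cycled until the longest one is exhausted, which we understand to be Lightning's default for training ("max_size_cycle"); other combination modes exist, so check the documentation for your Lightning version. The function name is hypothetical.

```python
from itertools import cycle


def combine_loaders(loaders):
    """Build per-step dict batches, cycling shorter iterables until the longest is exhausted."""
    longest = max(len(loader) for loader in loaders.values())
    iterators = {name: cycle(loader) for name, loader in loaders.items()}
    return [{name: next(it) for name, it in iterators.items()} for _ in range(longest)]


# lists stand in for the mnist_loader / cifar_loader of the example
batches = combine_loaders({"mnist": [1, 2, 3], "cifar": ["a", "b"]})
```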
- training_step(batch, batch_nb)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model-specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the most recent values, so it differs from the actual loss returned in the training/validation step.
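The smoothing can be pictured as a mean over a sliding window of recent losses. This pure-Python sketch assumes a fixed-length window; the window length used by Lightning (we believe 20 by default) is an assumption here, and the RunningMean class is hypothetical.

```python
from collections import deque


class RunningMean:
    """Mean over a sliding window of the most recent values."""

    def __init__(self, window_length=20):
        # older values drop out automatically once the window is full
        self.values = deque(maxlen=window_length)

    def append(self, value):
        self.values.append(value)

    def mean(self):
        return sum(self.values) / len(self.values)


smoothed = RunningMean(window_length=3)
for loss in [1.0, 2.0, 3.0, 10.0]:
    smoothed.append(loss)
```

After the fourth value the first loss has left the window, so the displayed value reflects only the last three steps while the true last loss is 10.0.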
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
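The reason for the normalization: the gradient of loss / k is the gradient of the loss divided by k, so summing the gradients of k normalized micro-batch losses yields their mean, as a single larger batch with a mean-reduced loss would (assuming equally sized micro-batches). A scalar pure-Python sketch, with a hypothetical function name:

```python
def accumulated_gradient(per_batch_grads, accumulate_grad_batches):
    """Sum of gradients of losses that were each divided by accumulate_grad_batches."""
    return sum(grad / accumulate_grad_batches for grad in per_batch_grads)


grads = [1.0, 2.0, 3.0, 6.0]  # gradients from 4 equally sized micro-batches
total = accumulated_gradient(grads, accumulate_grad_batches=4)
```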
- training_step_end(outputs)
Use this when training with dp because training_step() will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss.
Note
If you later switch to ddp or some other mode, this will still be called so that you don’t have to change your code.
# pseudocode
sub_batches = split_batches_for_dp(batch)
step_output = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(step_output)
- Parameters
step_output – What you return in training_step for each batch part.
- Returns
Anything
When using the DP strategy, only a portion of the batch is inside the training_step:
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self(x)
    # softmax uses only a portion of the batch in the denominator
    loss = self.softmax(out)
    loss = nce_loss(loss)
    return loss
If you wish to do something with all the parts of the batch, then use this method to do it:
def training_step(self, batch, batch_idx):
    # batch is 1/num_gpus big
    x, y = batch
    out = self.encoder(x)
    return {"pred": out}
def training_step_end(self, training_step_outputs):
    gpu_0_pred = training_step_outputs[0]["pred"]
    gpu_1_pred = training_step_outputs[1]["pred"]
    gpu_n_pred = training_step_outputs[n]["pred"]
    # this softmax now uses the full batch
    loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
    return loss
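Why the denominator matters: a softmax normalizes over every score it sees, so two DP replicas that each normalize only their half of the batch produce different probabilities than one softmax over the full batch. This pure-Python sketch uses made-up scores to show the difference.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5, 3.0]  # one score per sample in the full batch
full_batch = softmax(scores)
# two DP replicas, each normalizing only over its own half of the batch
per_replica = softmax(scores[:2]) + softmax(scores[2:])
```

The full-batch probabilities sum to 1, while the per-replica ones sum to 1 per half, which is why a batch-wide loss such as NCE should be computed in training_step_end.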
See also
See the Multi GPU Training guide for more details.
- val_dataloader()
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Related hooks: fit(), validate(), prepare_data(), setup()
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=False
    )
    return loader
# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- validation_epoch_end(outputs)
Called at the end of the validation epoch with the outputs of all validation steps.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
- Returns
None
Note
If you didn’t define a validation_step(), this won’t be called.
Examples
With a single dataloader:
def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.
def validation_epoch_end(self, outputs):
    # outputs holds one inner list of step outputs per dataloader
    for dataloader_outs in outputs:
        for out in dataloader_outs:
            ...
    self.log("final_metric", final_value)
- validation_step(batch, batch_nb)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
batch – The output of your DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple val dataloaders are used).
- Returns
Any object or value
None - Validation will skip to the next batch
# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)
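The call order in the pseudocode can be made runnable with plain functions standing in for the hooks. This is a sketch, not Lightning's loop: run_validation_epoch is hypothetical, and it passes batch_idx to validation_step to match the hook's real signature.

```python
def run_validation_epoch(val_data, validation_step, validation_epoch_end, validation_step_end=None):
    """Call the validation hooks in the order shown in the pseudocode."""
    val_outs = []
    for batch_idx, val_batch in enumerate(val_data):
        out = validation_step(val_batch, batch_idx)
        # validation_step_end is optional and only runs if it is defined
        if validation_step_end is not None:
            out = validation_step_end(out)
        val_outs.append(out)
    return validation_epoch_end(val_outs)


result = run_validation_epoch(
    val_data=[[1, 2], [3, 4]],
    validation_step=lambda batch, batch_idx: sum(batch),
    validation_epoch_end=lambda outs: sum(outs) / len(outs),
    validation_step_end=lambda out: out * 2,
)
```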
# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch
    # implement your own
    out = self(x)
    loss = self.loss(out, y)
    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)
    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
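The accuracy lines in CASE 1 take the argmax of each row of logits and count matches with the labels. The same computation in pure Python, with made-up logits and a hypothetical accuracy helper:

```python
def accuracy(logits_batch, labels):
    """Fraction of rows whose argmax equals the label (mirrors torch.argmax + torch.sum)."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits_batch]
    return sum(pred == label for pred, label in zip(preds, labels)) / len(labels)


acc = accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0])
```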
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
BYOL
- class pl_bolts.models.self_supervised.BYOL(learning_rate=0.2, weight_decay=1.5e-06, warmup_epochs=10, max_epochs=1000, base_encoder='resnet50', encoder_out_dim=2048, projector_hidden_dim=4096, projector_out_dim=256, initial_tau=0.996, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL).
Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.
- Parameters
learning_rate (float, optional) – optimizer learning rate. Defaults to 0.2.
weight_decay (float, optional) – optimizer weight decay. Defaults to 1.5e-6.
warmup_epochs (int, optional) – number of epochs for scheduler warmup. Defaults to 10.
max_epochs (int, optional) – maximum number of epochs for scheduler. Defaults to 1000.
base_encoder (Union[str, torch.nn.Module], optional) – base encoder architecture. Defaults to “resnet50”.
encoder_out_dim (int, optional) – base encoder output dimension. Defaults to 2048.
projector_hidden_dim (int, optional) – projector MLP hidden dimension. Defaults to 4096.
projector_out_dim (int, optional) – projector MLP output dimension. Defaults to 256.
initial_tau (float, optional) – initial value of target decay rate used. Defaults to 0.996.
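initial_tau is the starting value of the decay rate for the target network. In the BYOL paper, the target weights are an exponential moving average of the online weights, and tau is increased from its initial value to 1 with a cosine schedule. This pure-Python sketch follows the paper's formulas with parameters as flat lists; whether bolts indexes the schedule by step or by epoch is an assumption, and both function names are hypothetical.

```python
import math


def tau_schedule(step, max_steps, initial_tau=0.996):
    """Cosine schedule from the BYOL paper: tau goes from initial_tau at step 0 to 1 at max_steps."""
    return 1 - (1 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2


def ema_update(target_params, online_params, tau):
    """target <- tau * target + (1 - tau) * online, elementwise."""
    return [tau * t + (1 - tau) * o for t, o in zip(target_params, online_params)]


start_tau = tau_schedule(0, max_steps=100)
end_tau = tau_schedule(100, max_steps=100)
new_target = ema_update([0.0, 1.0], [1.0, 1.0], tau=0.9)
```

As tau approaches 1 the target network changes ever more slowly, which stabilizes the regression targets late in training.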
- Model implemented by:
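Example:
import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.transforms.self_supervised.simclr_transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)
model = BYOL(num_classes=10)
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)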
CLI command:
# cifar10
python byol_module.py --gpus 1
# imagenet
python byol_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32
- calculate_loss(v_online, v_target)
Calculates the similarity loss between the online network's prediction and the target network's projection.
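In the BYOL paper, this loss is the squared L2 distance between the L2-normalized prediction and the L2-normalized target projection, which equals 2 - 2 times their cosine similarity. A pure-Python sketch of that paper formulation follows; the bolts implementation may use an equivalent form that differs only by a constant, and the function names are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def byol_loss(prediction, target_projection):
    """Squared L2 distance between normalized vectors, i.e. 2 - 2 * cosine similarity."""
    p, z = normalize(prediction), normalize(target_projection)
    return sum((a - b) ** 2 for a, b in zip(p, z))


same_direction = byol_loss([1.0, 0.0], [2.0, 0.0])   # only direction matters
orthogonal = byol_loss([1.0, 0.0], [0.0, 1.0])
opposite = byol_loss([1.0, 0.0], [-3.0, 0.0])
```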
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. The return options, lr_scheduler_config keys, and notes for this inherited LightningModule hook are identical to those listed under AMDIM.configure_optimizers() above.
- forward(x)
Returns the encoded representation of a view.
- on_train_batch_end(outputs, batch, batch_idx)
Add callback to perform exponential moving average weight update on target network.
- Return type
CPC (V2)
PyTorch Lightning implementation of Data-Efficient Image Recognition with Contrastive Predictive Coding.
Paper authors: Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, Aaron van den Oord.
Model implemented by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.self_supervised.cpc_transforms import (
    CPCTrainTransformsCIFAR10,
    CPCEvalTransformsCIFAR10,
)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()
# model
model = CPC_v2()
# fit
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
To finetune:
python cpc_finetuner.py \
    --ckpt_path path/to/checkpoint.ckpt \
    --dataset cifar10 \
    --gpus 1
CIFAR-10 and STL-10 baselines
The CPCv2 paper does not report baselines on the CIFAR-10 and STL-10 datasets. The results in the table are taken from the YADIM paper.
Dataset | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR
---|---|---|---|---|---|---|---
CIFAR-10 | 84.52 | | Adam | 64 | 1000 (up to 24 hours) | 1 V100 (32GB) | 4e-5
STL-10 | 78.36 | | Adam | 144 | 1000 (up to 72 hours) | 4 V100 (32GB) | 1e-4
ImageNet | 54.82 | | Adam | 3072 | 1000 (up to 21 days) | 64 V100 (32GB) | 4e-5
CIFAR-10 pretrained model:
from pl_bolts.models.self_supervised import CPC_v2
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/epoch%3D474.ckpt'
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)
cpc_v2.freeze()
STL-10 pretrained model:
from pl_bolts.models.self_supervised import CPC_v2
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/epoch%3D624.ckpt'
cpc_v2 = CPC_v2.load_from_checkpoint(weight_path, strict=False)
cpc_v2.freeze()
CPC (v2) API
- class pl_bolts.models.self_supervised.CPC_v2(encoder_name='cpc_encoder', patch_size=8, patch_overlap=4, online_ft=True, task='cpc', num_workers=4, num_classes=10, learning_rate=0.0001, pretrained=None, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
Warning
The feature CPC_v2 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
- Parameters
encoder_name¶ (
str
) – A string for any of the resnets in torchvision, or the original CPC encoder, or a custon nn.Module encoderpatch_overlap¶ (
int
) – How much overlap each patch should haveonline_ft¶ (
bool
) – If True, enables a 1024-unit MLP to fine-tune onlinetask¶ (
str
) – Which self-supervised task to use (‘cpc’, ‘amdim’, etc…)pretrained¶ (
Optional
[str
]) – If true, will use the weights pretrained (using CPC) on Imagenet
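patch_size and patch_overlap together fix how many patches CPC sees per image. The sketch below only does that arithmetic, assuming patches are cut with a sliding window of stride patch_size - patch_overlap and no padding; the exact resizing and cropping the CPC transforms apply before patching are not modelled here, so treat the 32-pixel case as an illustration rather than the actual CIFAR-10 pipeline.

```python
def patches_per_side(image_size: int, patch_size: int = 8, patch_overlap: int = 4) -> int:
    """Patches along one image side for a sliding window of width patch_size
    moved with stride patch_size - patch_overlap (no padding)."""
    stride = patch_size - patch_overlap
    if stride <= 0:
        raise ValueError("patch_overlap must be smaller than patch_size")
    if image_size < patch_size:
        return 0
    return (image_size - patch_size) // stride + 1


def patch_grid(image_size: int, patch_size: int = 8, patch_overlap: int = 4) -> tuple:
    """(rows, cols, total) patch counts for a square image."""
    n = patches_per_side(image_size, patch_size, patch_overlap)
    return n, n, n * n


print(patch_grid(32))  # (7, 7, 49) with the defaults patch_size=8, patch_overlap=4
print(patch_grid(64))  # (15, 15, 225)
```

Halving the overlap to 0 doubles the stride, which is why overlapping patches give a much denser grid for the same image size.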
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you need one, but GANs and similar models may need several.
- Returns
Any of these 6 options:
1. A single optimizer.
2. A list or tuple of optimizers.
3. Two lists: the first has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
4. A dictionary with an "optimizer" key and, optionally, a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
5. A tuple of dictionaries as described above, with an optional "frequency" key.
6. None: fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
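The "interval" and "frequency" keys decide when scheduler.step() runs: with "epoch" and a frequency of N the scheduler steps every N epochs, with "step" every N optimizer steps. A library-free model of which counts trigger a call, assuming counting starts at 1 and a call happens whenever the count is a multiple of frequency (a simplification of Lightning's internal bookkeeping, not its code):

```python
def scheduler_step_points(num_units: int, frequency: int = 1) -> list:
    """1-based epoch (or step) counts after which scheduler.step() would run,
    assuming a call whenever the count is a multiple of `frequency`."""
    if frequency < 1:
        raise ValueError("frequency must be >= 1")
    return [i for i in range(1, num_units + 1) if i % frequency == 0]


print(scheduler_step_points(5))                 # [1, 2, 3, 4, 5]
print(scheduler_step_points(10, frequency=3))   # [3, 6, 9]
```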
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated"
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }
# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.
Note
The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the frequency value specified in the lr_scheduler_config mentioned above.
def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]
In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps, and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
Examples:
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR
# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt
# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]
# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]
# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )
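The "frequency" key on optimizer dicts makes optimizers take turns on consecutive batches. This plain-Python sketch reproduces the cycling described for the 5/10 example and for the n_critic=5 WGAN example; it models the documented behaviour only and is not Lightning's implementation.

```python
from itertools import accumulate


def optimizer_for_batch(batch_idx: int, frequencies: list) -> int:
    """Index of the optimizer used for `batch_idx` when optimizer i runs for
    frequencies[i] consecutive batches and the pattern repeats forever."""
    if not frequencies or any(f < 1 for f in frequencies):
        raise ValueError("frequencies must be positive integers")
    pos = batch_idx % sum(frequencies)  # position inside the current cycle
    for opt_idx, upper in enumerate(accumulate(frequencies)):
        if pos < upper:
            return opt_idx
    raise RuntimeError("unreachable: pos is always below the cycle length")


# optimizer_one for 5 batches, then optimizer_two for 10, then repeat
print([optimizer_for_batch(i, [5, 10]) for i in range(17)])
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]

# WGAN-style: discriminator (index 0) for n_critic=5 batches, generator once
print([optimizer_for_batch(i, [5, 1]) for i in range(7)])
# [0, 0, 0, 0, 0, 1, 0]
```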
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer as needed.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
- forward(img)
Same as torch.nn.Module.forward().
- training_step(batch, batch_nb)
Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    data, target = batch
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
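The smoothing mentioned above is a running mean over recent values. A minimal sketch of such a windowed mean; the window length is a hypothetical parameter here, not a value taken from Lightning.

```python
from collections import deque


class RunningMean:
    """Mean of the last `window` values; older values drop out automatically."""

    def __init__(self, window: int = 20):  # window length is an assumption
        self.values = deque(maxlen=window)

    def update(self, value: float) -> float:
        self.values.append(value)
        return sum(self.values) / len(self.values)


smoother = RunningMean(window=3)
shown = [smoother.update(v) for v in [4.0, 2.0, 3.0, 1.0]]
print(shown)  # [4.0, 3.0, 3.0, 2.0] while the raw losses were 4, 2, 3, 1
```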
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
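Why the loss is divided by accumulate_grad_batches: the gradient of loss / k is the gradient of the loss divided by k, so summing over k accumulated batches yields the average per-batch gradient rather than k times it. Scalar toy model of that arithmetic (not Lightning code):

```python
def accumulated_gradient(per_batch_grads: list, accumulate_grad_batches: int) -> float:
    """Gradient left in .grad after backpropagating loss / k for each of the
    k accumulated batches (scalar toy model)."""
    if len(per_batch_grads) != accumulate_grad_batches:
        raise ValueError("need one gradient per accumulated batch")
    total = 0.0
    for g in per_batch_grads:
        total += g / accumulate_grad_batches  # d(loss / k) = d(loss) / k
    return total


print(accumulated_gradient([1.0, 2.0, 3.0, 6.0], 4))  # 3.0, the mean of the four gradients
```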
- validation_step(batch, batch_nb)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
batch – The output of your DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple val dataloaders are used).
- Returns
Any object or value
None - Validation will skip to the next batch
# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)
# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch
    # implement your own
    out = self(x)
    loss = self.loss(out, y)
    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)
    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
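The accuracy lines in CASE 1 (torch.argmax over the class dimension, then counting matches) compute the quantity below. This sketch uses plain lists instead of tensors so the arithmetic is easy to check; the logits and targets are made-up values.

```python
def argmax(row: list) -> int:
    """Index of the largest score in one row of logits."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits: list, targets: list) -> float:
    preds = [argmax(row) for row in logits]          # like torch.argmax(out, dim=1)
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return correct / len(targets)


logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [2.0, 1.0, 0.0]]
targets = [1, 0, 2, 1]
print(accuracy(logits, targets))  # 0.75: the last sample is predicted as class 0
```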
Note
If you don’t need to validate, you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
Moco (v2) API
- class pl_bolts.models.self_supervised.Moco_v2(base_encoder='resnet18', emb_dim=128, num_negatives=65536, encoder_momentum=0.999, softmax_temperature=0.07, learning_rate=0.03, momentum=0.9, weight_decay=0.0001, data_dir='./', batch_size=256, use_mlp=False, num_workers=8, *args, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
Warning
The feature Moco_v2 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
PyTorch Lightning implementation of Moco
Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He.
Code adapted from facebookresearch/moco to Lightning by:
Example:
from pytorch_lightning import Trainer
from pl_bolts.models.self_supervised import Moco_v2
model = Moco_v2()
trainer = Trainer()
trainer.fit(model)
CLI command:
# cifar10
python moco2_module.py --gpus 1
# imagenet
python moco2_module.py --gpus 8 --dataset imagenet2012 --data_dir /path/to/imagenet/ --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32
- Parameters
base_encoder (Union[str, Module]) – torchvision model name or torch.nn.Module
num_negatives (int) – queue size; number of negative keys (default: 65536)
encoder_momentum (float) – MoCo momentum for updating the key encoder (default: 0.999)
softmax_temperature (float) – softmax temperature (default: 0.07)
datamodule – the DataModule (train, val, test dataloaders)
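encoder_momentum and num_negatives control the two MoCo mechanisms: the key encoder is an exponential moving average of the query encoder, θ_k ← m·θ_k + (1 − m)·θ_q, and negatives come from a fixed-size first-in-first-out queue of past keys. A toy sketch on plain lists; it is not the pl_bolts code, which operates on tensors.

```python
from collections import deque


def momentum_update(key_params: list, query_params: list, m: float = 0.999) -> list:
    """Element-wise theta_k <- m * theta_k + (1 - m) * theta_q."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


class NegativeQueue:
    """FIFO of key embeddings; once full, the oldest keys are dropped first."""

    def __init__(self, num_negatives: int = 65536):
        self.keys = deque(maxlen=num_negatives)

    def enqueue(self, batch_keys: list) -> None:
        self.keys.extend(batch_keys)


key = momentum_update([1.0, 0.0], [0.0, 1.0], m=0.9)
print([round(v, 6) for v in key])  # [0.9, 0.1]

queue = NegativeQueue(num_negatives=4)   # tiny queue for illustration
queue.enqueue(["k1", "k2", "k3"])
queue.enqueue(["k4", "k5"])
print(list(queue.keys))  # ['k2', 'k3', 'k4', 'k5']
```

With the default m = 0.999 the key encoder moves very slowly, which keeps the keys already in the queue consistent with newly encoded ones.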
- configure_optimizers()
The docstring of this method is the generic LightningModule one, identical to configure_optimizers() under CPC (v2) API above.
- forward(img_q, img_k, queue)
- Input:
img_q: a batch of query images
img_k: a batch of key images
queue: a queue from which to pick negative samples
- Output:
logits, targets
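In MoCo, each query is scored against its positive key and every key in the queue, the scores are divided by softmax_temperature, and the positive is placed at index 0, so every target is 0; the InfoNCE loss is then ordinary cross-entropy. A plain-Python sketch for a single query, assuming already-normalized embeddings (the module itself works on tensor batches):

```python
import math


def dot(a: list, b: list) -> float:
    return sum(x * y for x, y in zip(a, b))


def moco_logits(q: list, k_pos: list, queue: list, temperature: float = 0.07):
    """Positive logit first, then one logit per queued negative; target is 0."""
    scores = [dot(q, k_pos)] + [dot(q, k_neg) for k_neg in queue]
    return [s / temperature for s in scores], 0


def cross_entropy(logits: list, target: int) -> float:
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


q = [1.0, 0.0]                      # query embedding
k_pos = [1.0, 0.0]                  # key from another view of the same image
queue = [[0.0, 1.0], [-1.0, 0.0]]   # keys from other images (negatives)
logits, target = moco_logits(q, k_pos, queue, temperature=0.5)
print(logits, target)  # [2.0, 0.0, -2.0] 0
print(cross_entropy(logits, target))
```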
- init_encoders(base_encoder)
Override to add your own encoders.
- training_step(batch, batch_idx)
The docstring of this method is the generic LightningModule one, identical to training_step() under CPC (v2) API above.
- validation_epoch_end(outputs)
Called at the end of the validation epoch with the outputs of all validation steps.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
- Returns
None
Note
If you didn’t define a validation_step(), this won’t be called.
Examples
With a single dataloader:
def validation_epoch_end(self, val_step_outputs):
    for out in val_step_outputs:
        ...
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.
def validation_epoch_end(self, outputs):
    for dataloader_outs in outputs:
        for out in dataloader_outs:
            ...
    # final_value: whatever you compute from the outputs above
    self.log("final_metric", final_value)
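A library-free sketch of the aggregation such a hook typically performs, covering both the flat list and the list-of-lists shape. It assumes, purely for illustration, that each validation_step returned a dict with a hypothetical "val_loss" float.

```python
def mean(values: list) -> float:
    return sum(values) / len(values)


def summarize(outputs: list, multiple_dataloaders: bool = False):
    """Average the hypothetical 'val_loss' entries returned by validation_step.
    With several dataloaders, outputs is a list of lists (one per dataloader)."""
    if not multiple_dataloaders:
        return mean([out["val_loss"] for out in outputs])
    return [mean([out["val_loss"] for out in dl_outs]) for dl_outs in outputs]


single = [{"val_loss": 1.0}, {"val_loss": 3.0}]
multi = [single, [{"val_loss": 0.5}, {"val_loss": 1.5}, {"val_loss": 1.0}]]
print(summarize(single))                            # 2.0
print(summarize(multi, multiple_dataloaders=True))  # [2.0, 1.0]
```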
- validation_step(batch, batch_idx)
The docstring of this method is the generic LightningModule one, identical to validation_step() under CPC (v2) API above.
SimCLR
PyTorch Lightning implementation of SimCLR
Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.
Model implemented by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform,
SimCLRTrainDataTransform
)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
# model
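# SimCLR requires `gpus`; keep it consistent with the devices given to the Trainer
model = SimCLR(gpus=1, num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')
# fit
trainer = pl.Trainer(accelerator="gpu", devices=1)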
trainer.fit(model, datamodule=dm)
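SimCLR's temperature argument (0.1 by default in the signature) enters the NT-Xent contrastive loss from the SimCLR paper: each of the 2N augmented views must pick out its partner view among the other 2N − 1 embeddings. A small, dependency-free sketch of that loss, written for clarity rather than speed; the pl_bolts implementation is tensorised and may differ in details such as gathering embeddings across GPUs.

```python
import math


def cosine(a: list, b: list) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def nt_xent(z1: list, z2: list, temperature: float = 0.1) -> float:
    """NT-Xent over 2N embeddings: z1[i] and z2[i] are two views of image i.
    Each embedding's positive is its other view; the other 2N - 2 are negatives."""
    z = z1 + z2
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)
        # similarities to every other embedding (self-similarity excluded)
        sims = [cosine(z[i], z[j]) / temperature for j in range(2 * n) if j != i]
        pos_idx = pos if pos < i else pos - 1  # shift because index i was skipped
        m = max(sims)
        log_denom = m + math.log(sum(math.exp(s - m) for s in sims))
        total += log_denom - sims[pos_idx]
    return total / (2 * n)


z1 = [[1.0, 0.0], [0.0, 1.0]]
z2 = [[0.9, 0.1], [0.1, 0.9]]
print(nt_xent(z1, z2))        # small: each view is closest to its partner
print(nt_xent(z1, z2[::-1]))  # large: partners were swapped
```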
CIFAR-10 baseline
Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR
---|---|---|---|---|---|---|---|
 |  | resnet50 | LARS | 2048 | 800 | TPUs | 1.0/1.5
Ours | 88.50 |  | LARS | 2048 | 800 (4 hours) | 8 V100 (16GB) | 1.5
CIFAR-10 pretrained model:
from pl_bolts.models.self_supervised import SimCLR
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-sgd/simclr-cifar10-sgd.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr.freeze()
Pre-training and fine-tuning (single-layer MLP, 1024 hidden units) curves: (figures not included)
To reproduce:
# pretrain
python simclr_module.py \
--gpus 8 \
--dataset cifar10 \
--batch_size 256 \
--num_workers 16 \
--optimizer sgd \
--learning_rate 1.5 \
--exclude_bn_bias \
--max_epochs 800 \
--online_ft
# finetune
python simclr_finetuner.py \
--gpus 4 \
--ckpt_path path/to/simclr/ckpt \
--dataset cifar10 \
--batch_size 64 \
--num_workers 8 \
--learning_rate 0.3 \
--num_epochs 100
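The --batch_size flag in these commands is read here as a per-device value, so the batch seen by one optimizer step is gpus × batch_size × num_nodes; under that assumption the pretraining command gives 8 × 256 = 2048, matching the batch size in the table above. A small arithmetic sketch; the 50,000-image count is the full CIFAR-10 training set and ignores any validation split the data module may hold out.

```python
def global_batch_size(gpus: int, batch_size: int, num_nodes: int = 1) -> int:
    """Samples per optimizer step, assuming batch_size is per device (DDP-style)."""
    return num_nodes * gpus * batch_size if gpus > 0 else batch_size


def iters_per_epoch(num_samples: int, global_batch: int) -> int:
    return num_samples // global_batch


pretrain = global_batch_size(gpus=8, batch_size=256)
print(pretrain)                                  # 2048
print(iters_per_epoch(50_000, pretrain))         # 24 full batches per epoch
print(global_batch_size(gpus=4, batch_size=64))  # 256 for the finetuning command
```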
ImageNet baseline for SimCLR
Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR
---|---|---|---|---|---|---|---|
 |  | resnet50 | LARS | 4096 | 800 | TPUs | 4.8
Ours | 68.4 |  | LARS | 4096 | 800 | 64 V100 (16GB) | 4.8
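The ImageNet learning rate of 4.8 at batch size 4096 matches the linear scaling rule LR = 0.3 × BatchSize / 256 described in the SimCLR paper. The CIFAR-10 runs above use 1.0/1.5 rather than the 2.4 this rule would give at batch 2048, so treat the rule as a starting point, not a fixed recipe.

```python
def scaled_lr(batch_size: int, base_lr: float = 0.3, base_batch: int = 256) -> float:
    """Linear scaling rule: lr = base_lr * batch_size / base_batch."""
    return base_lr * batch_size / base_batch


print(scaled_lr(4096))  # 4.8, the ImageNet value in the table
print(scaled_lr(2048))  # 2.4, versus the 1.0/1.5 actually used for CIFAR-10
```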
ImageNet pretrained model:
from pl_bolts.models.self_supervised import SimCLR
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr.freeze()
To reproduce:
# pretrain
python simclr_module.py \
--dataset imagenet \
--data_path path/to/imagenet
# finetune
python simclr_finetuner.py \
--gpus 8 \
--ckpt_path path/to/simclr/ckpt \
--dataset imagenet \
--data_dir path/to/imagenet/dataset \
--batch_size 256 \
--num_workers 16 \
--learning_rate 0.8 \
--nesterov True \
--num_epochs 90
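The SimCLR constructor takes warmup_epochs, max_epochs, start_lr, learning_rate and final_lr. A common reading of these is a linear warmup from start_lr to learning_rate followed by cosine decay to final_lr; the sketch below implements that schedule in plain Python. Whether pl_bolts uses exactly this formula, and whether it steps per epoch or per iteration, are assumptions here.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, warmup_steps: int,
                     start_lr: float = 0.0, peak_lr: float = 1e-3,
                     final_lr: float = 0.0) -> float:
    """Linear warmup from start_lr to peak_lr, then cosine decay to final_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return start_lr + (peak_lr - start_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at final_lr after the schedule ends
    return final_lr + 0.5 * (peak_lr - final_lr) * (1 + math.cos(math.pi * progress))


# hypothetical run: 100 steps with 10 warmup steps and a peak LR of 1.0
lrs = [warmup_cosine_lr(s, total_steps=100, warmup_steps=10, peak_lr=1.0) for s in range(101)]
print(lrs[0], lrs[5], lrs[10], round(lrs[55], 6), lrs[100])  # 0.0 0.5 1.0 0.5 0.0
```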
SimCLR API
- class pl_bolts.models.self_supervised.SimCLR(gpus, num_samples, batch_size, dataset, num_nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, temperature=0.1, first_conv=True, maxpool1=True, optimizer='adam', exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
Warning
The feature SimCLR is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
- Parameters
- configure_optimizers()[source]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Returns
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple
lr_scheduler_config
).Dictionary, with an
"optimizer"
key, and (optionally) a"lr_scheduler"
key whose value is a single LR scheduler orlr_scheduler_config
.Tuple of dictionaries as described above, with an optional
"frequency"
key.None - Fit will run without any optimizer.
The
lr_scheduler_config
is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.lr_scheduler_config = { # REQUIRED: The scheduler instance "scheduler": lr_scheduler, # The unit of the scheduler's step size, could also be 'step'. # 'epoch' updates the scheduler on epoch end whereas 'step' # updates it after a optimizer update. "interval": "epoch", # How many epochs/steps should pass between calls to # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. "frequency": 1, # Metric to to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping # training if not found. If set to `False`, it will only produce a warning "strict": True, # If using the `LearningRateMonitor` callback to monitor the # learning rate progress, this keyword can be used to specify # a custom logged name "name": None, }
When there are schedulers in which the
.step()
method is conditioned on a value, such as thetorch.optim.lr_scheduler.ReduceLROnPlateau
scheduler, Lightning requires that thelr_scheduler_config
contains the keyword"monitor"
set to the metric name that the scheduler should be conditioned on.# The ReduceLROnPlateau scheduler requires a monitor def configure_optimizers(self): optimizer = Adam(...) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": ReduceLROnPlateau(optimizer, ...), "monitor": "metric_to_track", "frequency": "indicates how often the metric is updated" # If "monitor" references validation metrics, then "frequency" should be set to a # multiple of "trainer.check_val_every_n_epoch". }, } # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler def configure_optimizers(self): optimizer1 = Adam(...) optimizer2 = SGD(...) scheduler1 = ReduceLROnPlateau(optimizer1, ...) scheduler2 = LambdaLR(optimizer2, ...) return ( { "optimizer": optimizer1, "lr_scheduler": { "scheduler": scheduler1, "monitor": "metric_to_track", }, }, {"optimizer": optimizer2, "lr_scheduler": scheduler2}, )
Metrics can be made available to monitor by simply logging it using
self.log('metric_to_track', metric_val)
in yourLightningModule
.Note
The
frequency
value specified in a dict along with theoptimizer
key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the
frequency
value specified in thelr_scheduler_config
mentioned above.def configure_optimizers(self): optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01) optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01) return [ {"optimizer": optimizer_one, "frequency": 5}, {"optimizer": optimizer_two, "frequency": 10}, ]
In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps, and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
Examples:
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )
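The cycling rule for optimizer frequencies (for example 5 steps with the first optimizer, then 10 with the second, repeating; or n_critic discriminator steps per generator step in the WGAN example) can be expressed as a small helper. This is a hypothetical illustration of the rule, not Lightning's internal training loop; it returns which optimizer index would be active at a given 0-based step.

```python
from itertools import accumulate
from typing import List


def active_optimizer(step: int, frequencies: List[int]) -> int:
    """Return the index of the optimizer used at a 0-based step when optimizers cycle by frequency."""
    if not frequencies or any(f <= 0 for f in frequencies):
        raise ValueError("frequencies must be a non-empty list of positive ints")
    pos = step % sum(frequencies)  # position inside the current cycle
    for idx, end in enumerate(accumulate(frequencies)):
        # accumulate([5, 10]) yields 5, 15: steps 0-4 go to optimizer 0, steps 5-14 to optimizer 1
        if pos < end:
            return idx
    raise AssertionError("unreachable: pos is always below sum(frequencies)")
```

For the WGAN-style return value, active_optimizer(step, [5, 1]) gives five discriminator steps followed by one generator step.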
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer as needed.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
- forward(x)
Same as torch.nn.Module.forward().
- nt_xent_loss(out_1, out_2, temperature, eps=1e-06)
Computes the NT-Xent (normalized temperature-scaled cross-entropy) loss. Assumes out_1 and out_2 are normalized. out_1: [batch_size, dim], out_2: [batch_size, dim].
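NT-Xent treats the two augmented views of each image as a positive pair and every other embedding in the 2N-sized batch as a negative. The plain-Python sketch below computes it for lists of L2-normalized vectors so the arithmetic is easy to follow. It is a reading of the standard SimCLR formulation, not the bolts implementation (which works on batched tensors); using eps to guard the denominator is an assumption based on the signature.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def nt_xent_loss(out_1: List[Vector], out_2: List[Vector], temperature: float, eps: float = 1e-6) -> float:
    """NT-Xent loss for two views of a batch; rows are assumed to be L2-normalized."""
    if len(out_1) != len(out_2):
        raise ValueError("both views must have the same batch size")
    n = len(out_1)
    z = out_1 + out_2  # 2N embeddings: view 1 followed by view 2
    total = 0.0
    for i in range(2 * n):
        pos_idx = (i + n) % (2 * n)  # the other view of the same image
        pos = math.exp(dot(z[i], z[pos_idx]) / temperature)
        # denominator: every embedding except z[i] itself (the positive plus 2N - 2 negatives)
        denom = sum(math.exp(dot(z[i], z[k]) / temperature) for k in range(2 * n) if k != i)
        total += -math.log(pos / (denom + eps))  # eps assumed to guard the denominator
    return total / (2 * n)
```

When each image's two views agree and differ from the other images, the loss is small; swapping the pairing raises it sharply.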
- training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
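Dividing each micro-batch loss by accumulate_grad_batches makes the summed gradient equal to the gradient of the average loss, so the effective update does not grow with the accumulation factor. A plain-Python check on a one-parameter squared-error model; the model and both helper names are illustrative only, and "normalized" is taken to mean a plain division, as the note describes.

```python
from typing import List, Tuple


def grad_squared_error(w: float, x: float, y: float) -> float:
    """d/dw of the squared error (w * x - y) ** 2 for a one-parameter linear model."""
    return 2.0 * x * (w * x - y)


def accumulated_grad(w: float, micro_batches: List[Tuple[float, float]], accumulate_grad_batches: int) -> float:
    """Accumulate gradients of loss / accumulate_grad_batches over the micro-batches of one step."""
    if len(micro_batches) != accumulate_grad_batches:
        raise ValueError("expected one micro-batch per accumulation step")
    grad = 0.0
    for x, y in micro_batches:
        # scaling the loss by 1 / k scales its gradient by 1 / k as well
        grad += grad_squared_error(w, x, y) / accumulate_grad_batches
    return grad
```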
- validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
batch – The output of your DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple val dataloaders are used).
- Returns
Any object or value
None - Validation will skip to the next batch
# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...
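The call order in the pseudocode can be made concrete with plain functions. In this sketch the hooks are passed in as callables and the optional ones may be None; it mirrors the pseudocode rather than Lightning's actual evaluation loop, and run_validation_epoch is a hypothetical name. Passing batch_idx to validation_step is an assumption that matches the signature shown.

```python
from typing import Any, Callable, Iterable, List, Optional


def run_validation_epoch(
    val_data: Iterable[Any],
    validation_step: Callable[[Any, int], Any],
    validation_step_end: Optional[Callable[[Any], Any]] = None,
    validation_epoch_end: Optional[Callable[[List[Any]], Any]] = None,
) -> Any:
    """Plain-Python rendering of the validation call order."""
    val_outs = []
    for batch_idx, val_batch in enumerate(val_data):
        out = validation_step(val_batch, batch_idx)
        if validation_step_end is not None:  # optional per-batch post-processing hook
            out = validation_step_end(out)
        val_outs.append(out)
    if validation_epoch_end is not None:  # optional hook that sees every batch output
        return validation_epoch_end(val_outs)
    return val_outs
```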
Examples:
import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
SwAV
PyTorch Lightning implementation of SwAV, adapted from the official implementation.
Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.
Implementation adapted by:
To Train:
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVTrainDataTransform,
    SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization
# data
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed
dm.train_transforms = SwAVTrainDataTransform(
    normalize=stl10_normalization()
)
dm.val_transforms = SwAVEvalDataTransform(
    normalize=stl10_normalization()
)
# model
model = SwAV(
    gpus=1,
    num_samples=dm.num_unlabeled_samples,
    dataset='stl10',
    batch_size=batch_size,
    num_crops=(2, 4)
)
# fit
trainer = pl.Trainer(precision=16, accelerator='auto')
trainer.fit(model, datamodule=dm)
Pre-trained ImageNet
We have included an option to directly load ImageNet weights provided by FAIR into bolts.
You can load the pretrained model using:
ImageNet pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=True)
swav.freeze()
STL-10 baseline
The original paper does not provide baselines on STL10.
Implementation | test acc | Encoder | Optimizer | Batch | Queue used | Epochs | Hardware | LR
---|---|---|---|---|---|---|---|---
Ours | | SwAV resnet50 | LARS | 128 | No | 100 (~9 hr) | 1 V100 (16GB) | 1e-3
STL-10 pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)
swav.freeze()
Pre-training: (figure not reproduced)
Fine-tuning (Single layer MLP, 1024 hidden units): (figure not reproduced)
To reproduce:
# pretrain
python swav_module.py \
    --online_ft \
    --gpus 1 \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --gaussian_blur \
    --queue_length 0 \
    --jitter_strength 1. \
    --num_prototypes 512
# finetune
python swav_finetuner.py \
    --gpus 8 \
    --ckpt_path path/to/swav/ckpt \
    --dataset imagenet \
    --data_dir path/to/imagenet/dataset \
    --batch_size 256 \
    --num_workers 16 \
    --learning_rate 0.8 \
    --nesterov True \
    --num_epochs 90
ImageNet baseline for SwAV
Implementation | test acc | Encoder | Optimizer | Batch | Epochs | Hardware | LR
---|---|---|---|---|---|---|---
Original | 75.3 | resnet50 | LARS | 4096 | 800 | 64 V100s | 4.8
Ours | 74 | | LARS | 4096 | 800 | 64 V100 (16GB) | 4.8
ImageNet pretrained model:
from pl_bolts.models.self_supervised import SwAV
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/bolts_swav_imagenet/swav_imagenet.ckpt'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)
swav.freeze()
SwAV API
- class pl_bolts.models.self_supervised.SwAV(gpus, num_samples, batch_size, dataset, num_nodes=1, arch='resnet50', hidden_mlp=2048, feat_dim=128, warmup_epochs=10, max_epochs=100, num_prototypes=3000, freeze_prototypes_epochs=1, temperature=0.1, sinkhorn_iterations=3, queue_length=0, queue_path='queue', epoch_queue_starts=15, crops_for_assign=(0, 1), num_crops=(2, 6), first_conv=True, maxpool1=True, optimizer='adam', exclude_bn_bias=False, start_lr=0.0, learning_rate=0.001, final_lr=0.0, weight_decay=1e-06, epsilon=0.05, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
- Parameters
gpus (int) – number of GPUs per node used in training; passed to the SwAV module to manage the queue and select the distributed Sinkhorn
num_samples (int) – number of image samples used for training
hidden_mlp (int) – hidden layer size of the non-linear projection head; set to 0 to use a linear projection head
warmup_epochs (int) – apply linear warmup for this many epochs
freeze_prototypes_epochs (int) – epoch until which the gradients of the prototype layer are frozen
sinkhorn_iterations (int) – number of iterations for Sinkhorn normalization
queue_length (int) – use a queue when the batch size is small; must be divisible by the total batch size (i.e. total_gpus * batch_size); set to 0 to remove the queue
epoch_queue_starts (int) – start using the queue after this epoch
crops_for_assign (tuple) – list of crop ids for computing assignments
num_crops (tuple) – number of global and local crops, e.g. [2, 6]
first_conv (bool) – keep the first conv the same as in the original ResNet architecture; if set to False, it is replaced by a kernel-3, stride-1 conv (CIFAR-10)
maxpool1 (bool) – keep the first maxpool layer the same as in the original ResNet architecture; if set to False, the first maxpool is turned off (CIFAR-10, maybe STL-10)
exclude_bn_bias (bool) – exclude batchnorm and bias layers from weight decay in optimizers
final_lr (float) – final learning rate of the cosine learning-rate decay
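sinkhorn_iterations and epsilon control the Sinkhorn-Knopp step that SwAV uses to turn prototype scores into balanced soft assignments (codes): exp(scores / epsilon) is alternately rescaled so that every prototype receives an equal share of the batch and every sample carries equal mass. The plain-Python sketch below follows that procedure as described for SwAV; the function name and list-of-lists layout are our own, and the bolts module runs the equivalent computation on tensors (with a distributed variant selected via gpus).

```python
import math
from typing import List

Matrix = List[List[float]]


def sinkhorn(scores: Matrix, epsilon: float = 0.05, n_iters: int = 3) -> Matrix:
    """Soft codes from a B x K matrix of sample-to-prototype scores; each returned row sums to 1."""
    B, K = len(scores), len(scores[0])
    # Q is K x B: Q[k][b] = exp(score / epsilon), then normalized to total mass 1
    Q = [[math.exp(scores[b][k] / epsilon) for b in range(B)] for k in range(K)]
    total = sum(sum(row) for row in Q)
    Q = [[q / total for q in row] for row in Q]
    for _ in range(n_iters):
        # rows (prototypes): each should carry mass 1/K, i.e. an equal partition of the batch
        for k in range(K):
            row_sum = sum(Q[k])
            Q[k] = [q / (row_sum * K) for q in Q[k]]
        # columns (samples): each should carry mass 1/B
        for b in range(B):
            col_sum = sum(Q[k][b] for k in range(K))
            for k in range(K):
                Q[k][b] /= col_sum * B
    # turn each sample's column into a distribution over prototypes (back to B x K)
    codes = []
    for b in range(B):
        col_sum = sum(Q[k][b] for k in range(K))
        codes.append([Q[k][b] / col_sum for k in range(K)])
    return codes
```

With enough iterations each prototype ends up with roughly B / K samples' worth of mass, even when every sample prefers the same prototype, which is what prevents the trivial all-in-one-cluster solution.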
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple.
- Returns
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
Tuple of dictionaries as described above, with an optional "frequency" key.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
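To make the defaults concrete, here is a hypothetical helper (not a Lightning API) that wraps a bare scheduler into a config and fills in unspecified keys. One assumption to flag: the configuration listing shows "val_loss" as the monitor value, but this sketch leaves monitor unset by default, since it only matters for value-conditioned schedulers such as ReduceLROnPlateau.

```python
from typing import Any, Dict

# Assumed defaults for everything except the required "scheduler" key.
DEFAULT_LR_SCHEDULER_CONFIG: Dict[str, Any] = {
    "interval": "epoch",
    "frequency": 1,
    "monitor": None,  # set this for schedulers such as ReduceLROnPlateau
    "strict": True,
    "name": None,
}


def normalize_scheduler_config(config: Any) -> Dict[str, Any]:
    """Hypothetical helper: wrap a bare scheduler or fill missing keys with the defaults above."""
    if not isinstance(config, dict):
        config = {"scheduler": config}  # a bare scheduler object
    if "scheduler" not in config:
        raise ValueError("lr_scheduler_config requires a 'scheduler' key")
    merged = {**DEFAULT_LR_SCHEDULER_CONFIG, **config}  # user-supplied keys win
    if merged["interval"] not in ("epoch", "step"):
        raise ValueError(f"interval must be 'epoch' or 'step', got {merged['interval']!r}")
    return merged
```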
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": 1,  # indicates how often the metric is updated
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
The frequency value specified in a dict along with the optimizer key is an int corresponding to the number of sequential batches optimized with the specific optimizer. It should be given to none or to all of the optimizers. There is a difference between passing multiple optimizers in a list, and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.
This is different from the frequency value specified in the lr_scheduler_config mentioned above.
def configure_optimizers(self):
    optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
    optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
    return [
        {"optimizer": optimizer_one, "frequency": 5},
        {"optimizer": optimizer_two, "frequency": 10},
    ]
In this example, the first optimizer will be used for the first 5 steps, the second optimizer for the next 10 steps, and that cycle will continue. If an LR scheduler is specified for an optimizer using the lr_scheduler key in the above dict, the scheduler will only be updated when its optimizer is being used.
Examples:
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR

# most cases. no learning rate scheduler
def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
    return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    gen_sch = {
        'scheduler': ExponentialLR(gen_opt, 0.99),
        'interval': 'step',  # called after each training step
    }
    dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )
Note
Some things to know:
Lightning calls .backward() and .step() on each optimizer as needed.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizers.
If you use multiple optimizers, training_step() will have an additional optimizer_idx parameter.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, gradients will be calculated only for the parameters of the current optimizer at each training step.
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step() hook.
- forward(x)
Same as torch.nn.Module.forward().
- on_after_backward()
Called after loss.backward() and before optimizers are stepped.
Note
If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step hook if you need the unscaled gradients.
- on_train_epoch_end()
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, either:
Implement training_epoch_end in the LightningModule OR
Cache data across steps on the attribute(s) of the LightningModule and access them in this hook
- Return type
- on_train_epoch_start()
Called in the training loop at the very beginning of the epoch.
- setup(stage)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
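from torch import nn

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data runs on a single process, so state
        # assigned here is not available on the other processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)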
- training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
- Parameters
batch – The output of your DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple val dataloaders are used).
- Returns
Any object or value
None - Validation will skip to the next batch
# pseudocode of order
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    if defined("validation_step_end"):
        out = validation_step_end(out)
    val_outs.append(out)
val_outs = validation_epoch_end(val_outs)

# if you have one val dataloader:
def validation_step(self, batch, batch_idx):
    ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    ...
Examples:
import torch
import torchvision

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
SimSiam
- class pl_bolts.models.self_supervised.SimSiam(learning_rate=0.05, weight_decay=0.0001, momentum=0.9, warmup_epochs=10, max_epochs=100, base_encoder='resnet50', encoder_out_dim=2048, projector_hidden_dim=2048, projector_out_dim=2048, predictor_hidden_dim=512, exclude_bn_bias=False, **kwargs)
Bases: pytorch_lightning.core.module.LightningModule
PyTorch Lightning implementation of Exploring Simple Siamese Representation Learning (SimSiam).
Paper authors: Xinlei Chen, Kaiming He.
- Parameters
learning_rate (float, optional) – optimizer learning rate. Defaults to 0.05.
weight_decay (float, optional) – optimizer weight decay. Defaults to 1e-4.
momentum (float, optional) – optimizer momentum. Defaults to 0.9.
warmup_epochs (int, optional) – number of epochs for scheduler warmup. Defaults to 10.
max_epochs (int, optional) – maximum number of epochs for scheduler. Defaults to 100.
base_encoder (Union[str, nn.Module], optional) – base encoder architecture. Defaults to “resnet50”.
encoder_out_dim (int, optional) – base encoder output dimension. Defaults to 2048.
projector_hidden_dim (int, optional) – projector MLP hidden dimension. Defaults to 2048.
projector_out_dim (int, optional) – projector MLP output dimension. Defaults to 2048.
predictor_hidden_dim (int, optional) – predictor MLP hidden dimension. Defaults to 512.
exclude_bn_bias (bool, optional) – option to exclude batchnorm and bias terms from weight decay. Defaults to False.
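warmup_epochs and max_epochs parameterize the learning-rate schedule. Assuming it is a linear warmup followed by cosine annealing (the shape the SwAV parameters describe for that model), the hypothetical helper below computes the per-epoch rate: a linear ramp from warmup_start_lr to the base rate over warmup_epochs, then a half cosine down to eta_min at max_epochs. It is not the scheduler SimSiam actually instantiates, whose exact warmup endpoints may differ slightly.

```python
import math


def warmup_cosine_lr(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    max_epochs: int,
    warmup_start_lr: float = 0.0,
    eta_min: float = 0.0,
) -> float:
    """Learning rate at `epoch` for linear warmup followed by cosine annealing."""
    if epoch < warmup_epochs:
        # linear ramp from warmup_start_lr towards base_lr
        return warmup_start_lr + (base_lr - warmup_start_lr) * epoch / warmup_epochs
    # fraction of the post-warmup phase completed, clipped to [0, 1]
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    progress = min(progress, 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))
```

With the defaults above (learning_rate=0.05, warmup_epochs=10, max_epochs=100), the rate reaches 0.05 at epoch 10, is about half that at epoch 55, and decays to 0 at epoch 100.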
- Model implemented by:
Example:
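from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import SimSiam
from pl_bolts.transforms.self_supervised.simclr_transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)

model = SimSiam()
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
trainer = Trainer()
trainer.fit(model, datamodule=dm)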
CLI command:
# cifar10
python simsiam_module.py --gpus 1
# imagenet
python simsiam_module.py --gpus 8 --dataset imagenet2012 --meta_dir /path/to/folder/with/meta.bin/ --batch_size 32
- calculate_loss(v_online, v_target)
Calculates the similarity loss between the online network's prediction and the target network's projection.
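In SimSiam the similarity loss is the negative cosine similarity between a prediction and a target projection, with gradients stopped on the target branch. A plain-Python sketch of the value for a batch, assuming v_online and v_target are lists of vectors; stop-gradient has no counterpart without autograd, so it appears only as a comment, and this is not the bolts tensor implementation.

```python
import math
from typing import List

Vector = List[float]


def _normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def negative_cosine_loss(v_online: List[Vector], v_target: List[Vector]) -> float:
    """Mean negative cosine similarity between predictions and targets."""
    total = 0.0
    for p, z in zip(v_online, v_target):
        # in an autograd framework, z would be detached here (stop-gradient)
        p, z = _normalize(p), _normalize(z)
        total += -sum(a * b for a, b in zip(p, z))
    return total / len(v_online)
```

The value ranges from -1 (aligned) to 1 (opposite). The SimSiam paper symmetrizes it over the two views, 0.5 * (D(p1, z2) + D(p2, z1)); where that averaging happens in this module is not stated here.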
- configure_optimizers()
Configure optimizer and learning rate scheduler.
- static exclude_from_weight_decay(named_params, weight_decay, skip_list=('bias', 'bn'))
Exclude parameters from weight decay.
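The signature suggests parameters are split by name: anything whose name contains an entry of skip_list (by default 'bias' or 'bn') goes into a group with zero weight decay. A sketch under that assumption, which also assumes frozen parameters (requires_grad=False) are dropped. The returned list of dicts has the parameter-group shape that torch.optim optimizers accept; the stand-in parameters in the check are plain strings.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple


def exclude_from_weight_decay(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float,
    skip_list: Sequence[str] = ("bias", "bn"),
) -> List[Dict[str, Any]]:
    """Split parameters into a decayed group and a no-decay group by substring match on the name."""
    params, excluded_params = [], []
    for name, param in named_params:
        if getattr(param, "requires_grad", True) is False:
            continue  # assumed: frozen parameters are left out of both groups
        if any(layer_name in name for layer_name in skip_list):
            excluded_params.append(param)
        else:
            params.append(param)
    return [
        {"params": params, "weight_decay": weight_decay},
        {"params": excluded_params, "weight_decay": 0.0},
    ]
```

Because matching is by substring, a layer named for example "bn1" is caught by "bn"; naming conventions in a custom encoder determine what gets excluded.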