
Self-supervised learning

These transforms are used in various self-supervised learning approaches.

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!


CPC transforms

Transforms used for CPC

CIFAR-10 Train (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsCIFAR10(patch_size=8, overlap=4)

Bases: object

Warning

The feature CPCTrainTransformsCIFAR10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
img_jitter
col_jitter
rnd_gray
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
CIFAR10(..., transform=CPCTrainTransformsCIFAR10())

# in a DataModule
module = CIFAR10DataModule(PATH, batch_size=32)
module.train_transforms = CPCTrainTransformsCIFAR10()
module.prepare_data()
module.setup()
train_loader = module.train_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.
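The `patch_size`/`overlap` pair determines how many patches each image yields. The sketch below is a minimal illustration, assuming the patch step slides a square window with stride `patch_size - overlap` and skips partial windows at the border (the behavior of `torch.nn.functional.unfold`); `patch_offsets` and `num_patches` are hypothetical helpers, not part of pl_bolts.

```python
from typing import List


def patch_offsets(image_size: int, patch_size: int, overlap: int) -> List[int]:
    """Start positions (along one axis) of the overlapping patches.

    Assumes a sliding window with stride = patch_size - overlap that skips
    partial windows at the border, as torch.nn.functional.unfold does.
    """
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than patch_size")
    return list(range(0, image_size - patch_size + 1, stride))


def num_patches(image_size: int, patch_size: int, overlap: int) -> int:
    """Total patches for a square image: (patches per side) squared."""
    per_side = len(patch_offsets(image_size, patch_size, overlap))
    return per_side * per_side


# CIFAR-10 defaults: 32 x 32 image, patch_size=8, overlap=4, so the stride is 4
print(patch_offsets(32, 8, 4))  # [0, 4, 8, 12, 16, 20, 24]
print(num_patches(32, 8, 4))    # 49
```

With the CIFAR-10 defaults this gives a 7 x 7 grid. Under the same assumption, the ImageNet128 defaults (a 128-pixel image, `patch_size=32`, `overlap=16`) also give `num_patches(128, 32, 16) == 49`.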

CIFAR-10 Eval (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsCIFAR10(patch_size=8, overlap=4)

Bases: object

Warning

The feature CPCEvalTransformsCIFAR10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
CIFAR10(..., transform=CPCEvalTransformsCIFAR10())

# in a DataModule
module = CIFAR10DataModule(PATH, batch_size=32)
module.val_transforms = CPCEvalTransformsCIFAR10()
module.prepare_data()
module.setup()
val_loader = module.val_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.
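To make `overlap` concrete, the following toy sketch cuts a small 2-D grid into overlapping patches in plain Python. It is an illustration under the assumption that the stride is `patch_size - overlap`, not the library's patch implementation; `extract_patches` is a hypothetical name.

```python
from typing import List

Grid = List[List[int]]


def extract_patches(image: Grid, patch_size: int, overlap: int) -> List[Grid]:
    """Cut a 2-D grid into overlapping square patches, in row-major order.

    Illustrative sketch: stride = patch_size - overlap, and partial border
    windows are skipped.
    """
    stride = patch_size - overlap
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height - patch_size + 1, stride):
        for left in range(0, width - patch_size + 1, stride):
            patch = [row[left:left + patch_size] for row in image[top:top + patch_size]]
            patches.append(patch)
    return patches


# A 4x4 toy "image" whose value encodes its position: 10 * row + column
image = [[10 * r + c for c in range(4)] for r in range(4)]
patches = extract_patches(image, patch_size=2, overlap=1)

print(len(patches))  # 9 (a 3 x 3 grid of patches)
print(patches[0])    # [[0, 1], [10, 11]]
print(patches[1])    # [[1, 2], [11, 12]], sharing column 1 with patches[0]
```

Horizontally adjacent patches share exactly `overlap` columns, and vertically adjacent ones share `overlap` rows.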

ImageNet Train (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsImageNet128(patch_size=32, overlap=16)

Bases: object

Warning

The feature CPCTrainTransformsImageNet128 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
ImageNet(..., transform=CPCTrainTransformsImageNet128())

# in a DataModule
module = ImagenetDataModule(PATH, batch_size=32)
module.train_transforms = CPCTrainTransformsImageNet128()
train_loader = module.train_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.

ImageNet Eval (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsImageNet128(patch_size=32, overlap=16)

Bases: object

Warning

The feature CPCEvalTransformsImageNet128 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
ImageNet(..., transform=CPCEvalTransformsImageNet128())

# in a DataModule
module = ImagenetDataModule(PATH, batch_size=32)
module.val_transforms = CPCEvalTransformsImageNet128()
val_loader = module.val_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.

STL-10 Train (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsSTL10(patch_size=16, overlap=8)

Bases: object

Warning

The feature CPCTrainTransformsSTL10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
img_jitter
col_jitter
rnd_gray
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
STL10(..., transform=CPCTrainTransformsSTL10())

# in a DataModule
module = STL10DataModule(PATH, batch_size=32)
module.train_transforms = CPCTrainTransformsSTL10()
module.prepare_data()
train_loader = module.train_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.

STL-10 Eval (c)

class pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsSTL10(patch_size=16, overlap=8)

Bases: object

Warning

The feature CPCEvalTransformsSTL10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms used for CPC:

Transforms:

random_flip
transforms.ToTensor()
normalize
Patchify(patch_size=patch_size, overlap_size=overlap)

Example:

# in a regular dataset
STL10(..., transform=CPCEvalTransformsSTL10())

# in a DataModule
module = STL10DataModule(PATH, batch_size=32)
module.val_transforms = CPCEvalTransformsSTL10()
module.prepare_data()
val_loader = module.val_dataloader()
Parameters
  • patch_size (int) – side length, in pixels, of the square patches cut from the image

  • overlap (int) – number of pixels by which neighboring patches overlap

__call__(inp)

Call self as a function.

AMDIM transforms

Transforms used for AMDIM

CIFAR-10 Train (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsCIFAR10

Bases: object

Warning

The feature AMDIMTrainTransformsCIFAR10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

img_jitter,
col_jitter,
rnd_gray,
transforms.ToTensor(),
normalize

Example:

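from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMTrainTransformsCIFAR10

# the transforms take a single PIL image (e.g. a dataset sample), not a batched tensor
x = Image.new("RGB", (32, 32))

transform = AMDIMTrainTransformsCIFAR10()
(view1, view2) = transform(x)
__call__(inp)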

Call self as a function.
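The training transform returns two views because contrastive methods such as AMDIM compare differently augmented copies of the same image. Below is a minimal sketch of that pattern, with a toy brightness jitter standing in for the real pipeline; `TwoViewTransform` and `random_brightness` are hypothetical names, and this is not pl_bolts' implementation.

```python
import random
from typing import Callable, List, Tuple

Pixels = List[float]


def random_brightness(pixels: Pixels, rng: random.Random, strength: float = 0.4) -> Pixels:
    """Toy stand-in for a color-jitter step: scale every value by one random factor."""
    factor = rng.uniform(1.0 - strength, 1.0 + strength)
    # Clamp to [0, 1], the usual range for float images
    return [min(1.0, max(0.0, p * factor)) for p in pixels]


class TwoViewTransform:
    """Hypothetical helper: run one stochastic pipeline twice on the same input."""

    def __init__(self, augment: Callable[[Pixels, random.Random], Pixels], seed: int = 0) -> None:
        self.augment = augment
        self.rng = random.Random(seed)  # seeded so the example is reproducible

    def __call__(self, inp: Pixels) -> Tuple[Pixels, Pixels]:
        # Each call draws fresh random parameters, so the two views differ
        view1 = self.augment(inp, self.rng)
        view2 = self.augment(inp, self.rng)
        return view1, view2


transform = TwoViewTransform(random_brightness, seed=0)
view1, view2 = transform([0.2, 0.5, 0.8])  # a 3-"pixel" toy image
print(view1)
print(view2)
```

The two views come from the same input but different random draws, which is what makes them a positive pair for a contrastive loss.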

CIFAR-10 Eval (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsCIFAR10

Bases: object

Warning

The feature AMDIMEvalTransformsCIFAR10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

transforms.ToTensor(),
normalize

Example:

from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMEvalTransformsCIFAR10

x = Image.new("RGB", (32, 32))

transform = AMDIMEvalTransformsCIFAR10()
view1 = transform(x)
__call__(inp)

Call self as a function.

ImageNet Train (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsImageNet128(height=128)

Bases: object

Warning

The feature AMDIMTrainTransformsImageNet128 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

img_jitter,
col_jitter,
rnd_gray,
transforms.ToTensor(),
normalize

Example:

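from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMTrainTransformsImageNet128

x = Image.new("RGB", (128, 128))

transform = AMDIMTrainTransformsImageNet128()
(view1, view2) = transform(x)
__call__(inp)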

Call self as a function.

ImageNet Eval (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsImageNet128(height=128)

Bases: object

Warning

The feature AMDIMEvalTransformsImageNet128 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

transforms.Resize(height + 6, interpolation=3),
transforms.CenterCrop(height),
transforms.ToTensor(),
normalize

Example:

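from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMEvalTransformsImageNet128

x = Image.new("RGB", (128, 128))

transform = AMDIMEvalTransformsImageNet128()
view1 = transform(x)
__call__(inp)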

Call self as a function.
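The `Resize(height + 6)` / `CenterCrop(height)` pair first scales the shorter image side to `height + 6` and then cuts a centered `height x height` square, trimming about 3 pixels from each end of the shorter side. The arithmetic can be sketched in plain Python, assuming torchvision's rules for an integer `Resize` size (shorter edge matched, longer edge truncated) and its rounding of center-crop offsets; the helper names are hypothetical.

```python
from typing import Tuple


def resize_shorter_side(width: int, height: int, size: int) -> Tuple[int, int]:
    """(width, height) after scaling the shorter side to `size`, keeping the aspect ratio.

    Follows torchvision's rule for an int `size`: the longer side is truncated with int().
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop square."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# Eval pipeline with height=128: Resize(128 + 6), then CenterCrop(128)
w, h = resize_shorter_side(500, 375, 128 + 6)
print((w, h))                      # (178, 134)
print(center_crop_box(w, h, 128))  # (25, 3, 153, 131)
```

The SimCLR eval transform follows the same pattern with a margin of 10 instead of 6; for example, `resize_shorter_side(500, 375, 224 + 10)` gives `(312, 234)` before the 224-pixel crop.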

STL-10 Train (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsSTL10(height=64)

Bases: object

Warning

The feature AMDIMTrainTransformsSTL10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

img_jitter,
col_jitter,
rnd_gray,
transforms.ToTensor(),
normalize

Example:

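from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMTrainTransformsSTL10

x = Image.new("RGB", (64, 64))

transform = AMDIMTrainTransformsSTL10()
(view1, view2) = transform(x)
__call__(inp)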

Call self as a function.

STL-10 Eval (a)

class pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsSTL10(height=64)

Bases: object

Warning

The feature AMDIMEvalTransformsSTL10 is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Transforms applied to AMDIM.

Transforms:

transforms.Resize(height + 6, interpolation=3),
transforms.CenterCrop(height),
transforms.ToTensor(),
normalize

Example:

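from PIL import Image
from pl_bolts.models.self_supervised.amdim.transforms import AMDIMEvalTransformsSTL10

x = Image.new("RGB", (64, 64))

transform = AMDIMEvalTransformsSTL10()
view1 = transform(x)
__call__(inp)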

Call self as a function.

MOCO V2 transforms

Transforms used for MOCO V2

CIFAR-10 Train (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2TrainCIFAR10Transforms(height=32)

Bases: object

Warning

The feature Moco2TrainCIFAR10Transforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

CIFAR-10 Eval (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2EvalCIFAR10Transforms(height=32)

Bases: object

Warning

The feature Moco2EvalCIFAR10Transforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

STL-10 Train (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2TrainSTL10Transforms(height=64)

Bases: object

Warning

The feature Moco2TrainSTL10Transforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

STL-10 Eval (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2EvalSTL10Transforms(height=64)

Bases: object

Warning

The feature Moco2EvalSTL10Transforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

ImageNet Train (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2TrainImagenetTransforms(height=128)

Bases: object

Warning

The feature Moco2TrainImagenetTransforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

ImageNet Eval (m2)

class pl_bolts.models.self_supervised.moco.transforms.Moco2EvalImagenetTransforms(height=128)

Bases: object

Warning

The feature Moco2EvalImagenetTransforms is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

MoCo v2 augmentation:

https://arxiv.org/pdf/2003.04297.pdf

__call__(inp)

Call self as a function.

SimCLR transforms

Transforms used for SimCLR

Train (sc)

class pl_bolts.models.self_supervised.simclr.transforms.SimCLRTrainDataTransform(input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None)

Bases: object

Transforms for SimCLR during the training step of the pre-training stage.

Transform:

RandomResizedCrop(size=self.input_height)
RandomHorizontalFlip()
RandomApply([color_jitter], p=0.8)
RandomGrayscale(p=0.2)
RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5)
transforms.ToTensor()
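The constructor arguments `input_height` and `jitter_strength` feed into the pipeline above. As an assumption, the sketch below derives the color-jitter strengths the way the SimCLR paper scales its color distortion by a strength `s` (0.8 * s for brightness, contrast and saturation, 0.2 * s for hue). The listing shows the blur kernel as `int(0.1 * self.input_height)`; because torchvision's `GaussianBlur` only accepts odd kernel sizes, the sketch further assumes an even result is bumped up by one. `simclr_aug_params` is a hypothetical helper; check the pl_bolts source for the exact values it uses.

```python
from typing import Dict, Union


def simclr_aug_params(
    input_height: int = 224, jitter_strength: float = 1.0
) -> Dict[str, Union[int, float]]:
    """Hypothetical helper deriving color-jitter strengths and a blur kernel size.

    Assumption: color distortion of strength s uses 0.8 * s for brightness,
    contrast and saturation and 0.2 * s for hue (as in the SimCLR paper), and
    the blur kernel is ~10% of the image side, bumped to the next odd size.
    """
    kernel_size = int(0.1 * input_height)
    if kernel_size % 2 == 0:
        kernel_size += 1  # torchvision's GaussianBlur only accepts odd kernel sizes
    return {
        "brightness": 0.8 * jitter_strength,
        "contrast": 0.8 * jitter_strength,
        "saturation": 0.8 * jitter_strength,
        "hue": 0.2 * jitter_strength,
        "kernel_size": kernel_size,
    }


print(simclr_aug_params(224)["kernel_size"])  # 23: int(22.4) = 22, bumped to odd
print(simclr_aug_params(32)["kernel_size"])   # 3: int(3.2) = 3, already odd
```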

Example:

from PIL import Image
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRTrainDataTransform

transform = SimCLRTrainDataTransform(input_height=32)
x = Image.new("RGB", (32, 32))
(xi, xj, xk) = transform(x)  # xk is only for the online evaluator if used
__call__(sample)

Call self as a function.
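`RandomResizedCrop` opens the training pipeline. The following is a simplified sketch of how such a crop can be sampled, in the style of torchvision's `RandomResizedCrop` (default `scale=(0.08, 1.0)` and `ratio=(3/4, 4/3)`): draw a target area fraction and a log-uniform aspect ratio, retry if the box does not fit, and afterwards resize the crop to `input_height`. Unlike torchvision, this sketch falls back to the whole image after ten failed attempts; `sample_resized_crop` is a hypothetical name.

```python
import math
import random
from typing import Tuple


def sample_resized_crop(
    width: int,
    height: int,
    rng: random.Random,
    scale: Tuple[float, float] = (0.08, 1.0),
    ratio: Tuple[float, float] = (3 / 4, 4 / 3),
    attempts: int = 10,
) -> Tuple[int, int, int, int]:
    """Return a crop box (left, top, crop_width, crop_height).

    Simplified sketch: after `attempts` failed draws it falls back to the
    whole image, whereas torchvision falls back to a ratio-clamped center crop.
    """
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(scale[0], scale[1])  # fraction of the image to keep
        aspect = math.exp(rng.uniform(log_lo, log_hi))        # aspect ratio, log-uniform
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            # Place the box uniformly at random inside the image
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    return 0, 0, width, height


rng = random.Random(0)
for _ in range(3):
    # Each box would then be resized to input_height x input_height
    print(sample_resized_crop(96, 96, rng))
```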

Eval (sc)

class pl_bolts.models.self_supervised.simclr.transforms.SimCLREvalDataTransform(input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None)

Bases: pl_bolts.models.self_supervised.simclr.transforms.SimCLRTrainDataTransform

Transforms for SimCLR during the validation step of the pre-training stage.

Transform:

transforms.Resize(input_height + 10, interpolation=3)
transforms.CenterCrop(input_height)
transforms.ToTensor()

Example:

from PIL import Image
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform

transform = SimCLREvalDataTransform(input_height=32)
x = Image.new("RGB", (32, 32))
(xi, xj, xk) = transform(x)  # xk is only for the online evaluator if used

Identity class

Example:

from pl_bolts.utils.self_supervised import Identity
class pl_bolts.utils.self_supervised.Identity

Bases: torch.nn.

Warning

The feature Identity is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

An identity class to replace arbitrary layers in pretrained models.

Example:

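from torchvision.models import resnet18
from pl_bolts.utils.self_supervised import Identity

model = resnet18()
model.fc = Identity()  # forward() now returns the pooled 512-d features instead of logits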

SSL-ready resnets

Torchvision resnets with the fc layers removed and with the ability to return all feature maps instead of just the last one.

Example:

import torch
from pl_bolts.utils.self_supervised import torchvision_ssl_encoder

resnet = torchvision_ssl_encoder('resnet18', pretrained=False, return_all_feature_maps=True)
x = torch.rand(3, 3, 32, 32)

feat_maps = resnet(x)
pl_bolts.utils.self_supervised.torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False)

Warning

The feature torchvision_ssl_encoder is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Return type

Module
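Because the example feeds 32 x 32 images to an ImageNet-style ResNet, the deeper feature maps become very small. The sketch below computes the spatial sizes, assuming the standard torchvision ResNet-18 layout (7x7 stride-2 stem convolution, 3x3 stride-2 max-pool, and stride-2 downsampling in stages 2 to 4). It does not query pl_bolts, whose encoder may differ, and `resnet18_feature_shapes` is a hypothetical helper.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_shapes(size: int) -> List[Tuple[int, int]]:
    """(channels, side) after the stem and after each of the four ResNet-18 stages.

    Assumes the standard torchvision layout: 7x7/2 conv + 3x3/2 max-pool stem,
    then stages with 64, 128, 256, 512 channels, where stages 2-4 halve the
    resolution with a 3x3 stride-2 convolution.
    """
    side = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    side = conv_out(side, kernel=3, stride=2, padding=1)  # max-pool
    shapes = [(64, side)]                                  # stem output
    shapes.append((64, side))                              # layer1 keeps the resolution
    for channels in (128, 256, 512):                       # layer2 .. layer4
        side = conv_out(side, kernel=3, stride=2, padding=1)
        shapes.append((channels, side))
    return shapes


print(resnet18_feature_shapes(32))   # [(64, 8), (64, 8), (128, 4), (256, 2), (512, 1)]
print(resnet18_feature_shapes(224))  # [(64, 56), (64, 56), (128, 28), (256, 14), (512, 7)]
```

Under these assumptions the last map for a CIFAR-sized input is only 1 x 1, which is worth keeping in mind when choosing which feature maps to use.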


SSL backbone finetuner

class pl_bolts.models.self_supervised.ssl_finetuner.SSLFineTuner(backbone, in_features=2048, num_classes=1000, epochs=100, hidden_dim=None, dropout=0.0, learning_rate=0.1, weight_decay=1e-06, nesterov=False, scheduler_type='cosine', decay_epochs=(60, 80), gamma=0.1, final_lr=0.0)

Bases: pytorch_lightning.

Warning

The feature SSLFineTuner is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Fine-tunes a self-supervised learning backbone using the standard evaluation protocol of training a single-hidden-layer MLP (1024 units) on top of the backbone's features.

Example:

import pytorch_lightning as pl
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc.transforms import (
    CPCEvalTransformsCIFAR10,
    CPCTrainTransformsCIFAR10,
)

# pretrained model
backbone = CPC_v2.load_from_checkpoint(PATH, strict=False)

# dataset + transforms (the test split also needs the CPC patching)
dm = CIFAR10DataModule(data_dir='.')
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()
dm.test_transforms = CPCEvalTransformsCIFAR10()

# finetuner
finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=dm.num_classes)

# train
trainer = pl.Trainer()
trainer.fit(finetuner, dm)

# test
trainer.test(finetuner, datamodule=dm)
Parameters
  • backbone (Module) – a pretrained model

  • in_features (int) – dimensionality of the backbone's output features

  • num_classes (int) – number of classes in the dataset

  • hidden_dim (Optional[int]) – dim of the MLP (1024 default used in self-supervised literature)
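The `scheduler_type`, `decay_epochs`, `gamma` and `final_lr` arguments select a learning-rate schedule. Assuming the cosine option behaves like PyTorch's `CosineAnnealingLR(T_max=epochs, eta_min=final_lr)` and the step option like `MultiStepLR(milestones=decay_epochs, gamma=gamma)`, the per-epoch rates can be sketched in closed form using the signature defaults. The function names are hypothetical, and this is not the finetuner's own code.

```python
import math
from typing import Sequence


def cosine_lr(epoch: int, base_lr: float = 0.1, epochs: int = 100, final_lr: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr (epoch 0) to final_lr (epoch == epochs)."""
    return final_lr + (base_lr - final_lr) * (1 + math.cos(math.pi * epoch / epochs)) / 2


def step_lr(
    epoch: int,
    base_lr: float = 0.1,
    decay_epochs: Sequence[int] = (60, 80),
    gamma: float = 0.1,
) -> float:
    """Multiply base_lr by gamma once for every milestone epoch already reached."""
    passed = sum(1 for milestone in decay_epochs if epoch >= milestone)
    return base_lr * gamma ** passed


for epoch in (0, 30, 60, 80, 99):
    print(f"epoch {epoch:3d}: cosine={cosine_lr(epoch):.4f}  step={step_lr(epoch):.4f}")
```

The cosine schedule decays smoothly to `final_lr`, while the step schedule holds the rate constant and drops it by a factor of `gamma` at each epoch in `decay_epochs`.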
