Self-supervised learning

Collection of useful functions for self-supervised learning


Identity class

class pl_bolts.utils.self_supervised.Identity[source]

Bases: torch.nn.Module

An identity class to replace arbitrary layers in pretrained models.

Example:

from pl_bolts.utils.self_supervised import Identity

from torchvision.models import resnet18

model = resnet18()
model.fc = Identity()
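Conceptually, Identity is just a module whose forward pass returns its input unchanged. A minimal sketch of the idea (not the library's actual source):

```python
import torch
from torch import nn

class Identity(nn.Module):
    """Minimal sketch of an identity layer: forward returns its input unchanged."""

    def forward(self, x):
        return x

# Swapping this in for a classifier head (e.g. model.fc) leaves
# the backbone's features untouched.
layer = Identity()
t = torch.rand(2, 3)
out = layer(t)
```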

SSL-ready resnets

Torchvision resnets with the fc layers removed and with the ability to return all feature maps instead of just the last one.

Example:

from pl_bolts.utils.self_supervised import torchvision_ssl_encoder

resnet = torchvision_ssl_encoder('resnet18', pretrained=False, return_all_feature_maps=True)
x = torch.rand(3, 3, 32, 32)

feat_maps = resnet(x)

pl_bolts.utils.self_supervised.torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False)[source]
Return type

Module
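The return_all_feature_maps behaviour can be illustrated with a toy encoder; the class and layer sizes below are hypothetical, chosen only to show the pattern of collecting intermediate feature maps rather than returning the last one:

```python
import torch
from torch import nn

class ToyEncoder(nn.Module):
    """Hypothetical encoder illustrating the return_all_feature_maps pattern."""

    def __init__(self, return_all_feature_maps=False):
        super().__init__()
        self.return_all_feature_maps = return_all_feature_maps
        # Two conv blocks stand in for a resnet's stages (illustrative only)
        self.blocks = nn.ModuleList([
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
        ])

    def forward(self, x):
        feats = []
        for block in self.blocks:
            x = block(x)
            feats.append(x)
        # Either every intermediate map, or just the final one
        return feats if self.return_all_feature_maps else [x]

enc = ToyEncoder(return_all_feature_maps=True)
maps = enc(torch.rand(3, 3, 32, 32))
```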


SSL backbone finetuner

class pl_bolts.models.self_supervised.ssl_finetuner.SSLFineTuner(backbone, in_features=2048, num_classes=1000, epochs=100, hidden_dim=None, dropout=0.0, learning_rate=0.1, weight_decay=1e-06, nesterov=False, scheduler_type='cosine', decay_epochs=[60, 80], gamma=0.1, final_lr=0.0)[source]

Bases: pytorch_lightning.LightningModule

Finetunes a self-supervised learning backbone using the standard evaluation protocol: a single-layer MLP with 1024 hidden units.

Example:

import pytorch_lightning as pl

from pl_bolts.models.self_supervised import SSLFineTuner
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc.transforms import (CPCEvalTransformsCIFAR10,
                                                            CPCTrainTransformsCIFAR10)

# pretrained model
backbone = CPC_v2.load_from_checkpoint(PATH, strict=False)

# dataset + transforms
dm = CIFAR10DataModule(data_dir='.')
dm.train_transforms = CPCTrainTransformsCIFAR10()
dm.val_transforms = CPCEvalTransformsCIFAR10()

# finetuner
finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)

# train
trainer = pl.Trainer()
trainer.fit(finetuner, dm)

# test
trainer.test(datamodule=dm)

Parameters
  • backbone (Module) – a pretrained model

  • in_features (int) – feature dim of backbone outputs

  • num_classes (int) – number of classes in the dataset

  • hidden_dim (Optional[int]) – dim of the MLP; 1024 is the default used in the self-supervised literature
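Given the parameters above, the evaluation-protocol head amounts to a single hidden-layer MLP mapping backbone features to class logits. A hedged sketch of such a head — make_head is a hypothetical helper for illustration, not the pl_bolts API:

```python
import torch
from torch import nn

def make_head(in_features, num_classes, hidden_dim=1024, dropout=0.0):
    """Sketch of a finetuning head: linear probe when hidden_dim is None,
    otherwise a single hidden-layer MLP (hypothetical helper)."""
    if hidden_dim is None:
        return nn.Linear(in_features, num_classes)
    return nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    )

# Shapes match the defaults in the signature above (in_features=2048, num_classes=1000)
head = make_head(2048, 1000)
logits = head(torch.rand(4, 2048))
```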
