This section implements popular contrastive learning tasks used in self-supervised learning.
We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!
This task compares sets of feature maps.
In general, the feature map comparison pretext task uses triplets of features (anchor, positive and negative). Here are the abstract steps of the comparison.
Generate multiple views of the same image
    x1_view_1 = data_augmentation(x1)
    x1_view_2 = data_augmentation(x1)
Use a different example to generate additional views (usually within the same batch or a pool of candidates)
    x2_view_1 = data_augmentation(x2)
    x2_view_2 = data_augmentation(x2)
Pick 3 views to compare; these are the anchor, positive and negative features
    anchor = x1_view_1
    positive = x1_view_2
    negative = x2_view_1
Generate feature maps for each view
    (a0, a1, a2) = encoder(anchor)
    (p0, p1, p2) = encoder(positive)
Make a comparison for a set of feature maps
    phi = some_score_function()

    # the '01' comparison
    score = phi(a0, p1)

    # and can be bidirectional
    score = phi(p0, a1)
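The abstract steps above can be sketched end to end in PyTorch. This is a minimal illustration under stated assumptions: `data_augmentation`, `ToyEncoder` and `phi` are hypothetical stand-ins (a random flip plus noise, a three-stage conv net, and cosine similarity of globally pooled maps), not components provided by Bolts, and `x1`/`x2` are treated as small batches of random images rather than single images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def data_augmentation(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(()) < 0.5:
        x = torch.flip(x, dims=[-1])  # flip along the width axis
    return x + 0.05 * torch.randn_like(x)


class ToyEncoder(nn.Module):
    """Hypothetical encoder returning three feature maps at decreasing resolution."""

    def __init__(self, channels: int = 8) -> None:
        super().__init__()
        self.block0 = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.block2 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        f0 = F.relu(self.block0(x))   # (B, C, H, W)
        f1 = F.relu(self.block1(f0))  # (B, C, H/2, W/2)
        f2 = F.relu(self.block2(f1))  # (B, C, H/4, W/4)
        return f0, f1, f2


def phi(m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:
    """Hypothetical score function: cosine similarity of globally pooled maps."""
    v1 = m1.mean(dim=(2, 3))  # (B, C)
    v2 = m2.mean(dim=(2, 3))  # (B, C)
    return F.cosine_similarity(v1, v2, dim=1)  # one score per example


encoder = ToyEncoder()
x1 = torch.rand(4, 3, 16, 16)  # hypothetical batch for the first examples
x2 = torch.rand(4, 3, 16, 16)  # hypothetical batch for the other examples

# views of the same examples (anchor, positive) and of different ones (negative)
anchor = data_augmentation(x1)
positive = data_augmentation(x1)
negative = data_augmentation(x2)

a0, a1, a2 = encoder(anchor)
p0, p1, p2 = encoder(positive)
n0, n1, n2 = encoder(negative)

score_01 = phi(a0, p1)          # the '01' comparison
score_01_reverse = phi(p0, a1)  # the same comparison in the other direction
neg_score_01 = phi(a0, n1)      # anchor against a negative's map
```

Because the three maps have different spatial sizes, the pooled score is one simple way to compare any pair of them; a real task may compare local features instead.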
In practice, the contrastive task builds a BxB matrix, where B is the batch size. For set 1 of the feature maps, the diagonal entries are the anchors; for set 2, the diagonal entries are the positives; and the off-diagonal entries of set 1 serve as the negatives.
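The BxB idea can be illustrated with a generic InfoNCE-style loss: score every anchor against every positive in the batch, treat the diagonal as the correct match, and let cross-entropy push the off-diagonal (negative) scores down. This is a sketch of the general technique, not Bolts' implementation; `batch_contrastive_loss` and its `temperature` argument are hypothetical.

```python
import torch
import torch.nn.functional as F


def batch_contrastive_loss(anchors: torch.Tensor, positives: torch.Tensor,
                           temperature: float = 0.1) -> torch.Tensor:
    """InfoNCE-style loss over a B x B score matrix.

    Row i of `anchors` and row i of `positives` come from the same example.
    """
    # scores[i, j] = similarity(anchor_i, positive_j) / temperature, shape (B, B)
    scores = anchors @ positives.t() / temperature
    # The correct "class" for row i is column i (the diagonal);
    # every off-diagonal entry j != i acts as a negative.
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scores, targets)


torch.manual_seed(0)
B, D = 8, 16
anchors = F.normalize(torch.randn(B, D), dim=1)
# positives that match their anchors exactly vs. positives shifted by one example
aligned = batch_contrastive_loss(anchors, anchors)
misaligned = batch_contrastive_loss(anchors, anchors.roll(1, dims=0))
```

Shifting the positives by one row leaves each row's set of scores unchanged but moves the highest score off the diagonal, so the misaligned loss is always larger.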
- class pl_bolts.losses.self_supervised_learning.FeatureMapContrastiveTask(comparisons='00, 11', tclip=10.0, bidirectional=True)
The feature FeatureMapContrastiveTask is currently marked under review. Its compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
Performs an anchor/positive/negative comparison for each tuple of feature maps passed.
    from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

    # extract feature maps (encoder, x_pos and x_anchor are assumed to be defined)
    pos_0, pos_1, pos_2 = encoder(x_pos)
    anc_0, anc_1, anc_2 = encoder(x_anchor)

    # compare only the 0th feature maps (the trailing commas make these tuples)
    task = FeatureMapContrastiveTask('00')
    loss, regularizer = task((pos_0,), (anc_0,))

    # compare (pos_0, anc_1) and (pos_0, anc_2)
    task = FeatureMapContrastiveTask('01, 02')
    losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
    loss = losses.sum()

    # compare pos_0 against a randomly chosen anchor map
    task = FeatureMapContrastiveTask('0r')
    loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
    # with bidirectional=True (the default), the comparisons are done both ways
    task = FeatureMapContrastiveTask('01, 02')

    # will compare the following:
    # 01: (pos_0, anc_1), (anc_0, pos_1)
    # 02: (pos_0, anc_2), (anc_0, pos_2)
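The comparison string can be read mechanically: each comma-separated token holds two characters, a map index for the first tuple and one for the second, with `r` standing for a randomly chosen map. The hypothetical helpers below (`parse_comparisons`, `describe`, and the `-1` sentinel for `r` are assumptions, not Bolts' own parser) reproduce the bidirectional plan listed in the comment above.

```python
from typing import Dict, List, Tuple

RANDOM = -1  # hypothetical sentinel meaning "pick a random feature map" ('r')


def parse_comparisons(comparisons: str) -> List[Tuple[int, int]]:
    """Turn a spec such as '01, 0r' into index pairs: [(0, 1), (0, -1)]."""
    pairs = []
    for token in comparisons.split(","):
        token = token.strip()
        if len(token) != 2:
            raise ValueError(f"each comparison needs exactly two characters, got {token!r}")
        first, second = (RANDOM if ch == "r" else int(ch) for ch in token)
        pairs.append((first, second))
    return pairs


def _name(idx: int) -> str:
    return "r" if idx == RANDOM else str(idx)


def describe(comparisons: str, first: str = "pos", second: str = "anc",
             bidirectional: bool = True) -> Dict[str, List[str]]:
    """List which (map, map) pairs each comparison would score."""
    plan = {}
    for i, j in parse_comparisons(comparisons):
        pairs = [f"({first}_{_name(i)}, {second}_{_name(j)})"]
        if bidirectional:
            # the reverse direction swaps which tuple supplies each index
            pairs.append(f"({second}_{_name(i)}, {first}_{_name(j)})")
        plan[_name(i) + _name(j)] = pairs
    return plan


print(describe("01, 02"))
# {'01': ['(pos_0, anc_1)', '(anc_0, pos_1)'], '02': ['(pos_0, anc_2)', '(anc_0, pos_2)']}
```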
- forward(anchor_maps, positive_maps)
Takes in two tuples of feature maps, the anchor maps and the positive maps; the feature maps being compared must have matching dimensions.
    >>> import torch
    >>> from pytorch_lightning import seed_everything
    >>> from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    >>> seed_everything(0)
    0
    >>> a1 = torch.rand(3, 5, 2, 2)
    >>> a2 = torch.rand(3, 5, 2, 2)
    >>> b1 = torch.rand(3, 5, 2, 2)
    >>> b2 = torch.rand(3, 5, 2, 2)
    >>> task = FeatureMapContrastiveTask('01, 11')
    >>> losses, regularizer = task((a1, a2), (b1, b2))
    >>> losses
    tensor([2.2351, 2.1902])
    >>> regularizer
    tensor(0.0324)
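The class signature exposes a `tclip` argument, and `forward` returns a `regularizer` alongside the losses. Assuming these play the roles they do in AMDIM (raw scores soft-clipped with a scaled `tanh`, plus a small penalty on squared raw scores), the two pieces can be sketched as below. `tanh_clip`, `score_regularizer` and the `0.05` weight are assumptions for illustration, not a description of Bolts' internals.

```python
import torch


def tanh_clip(x: torch.Tensor, clip_val: float = 10.0) -> torch.Tensor:
    """Soft-clip scores into (-clip_val, clip_val) while staying differentiable."""
    return clip_val * torch.tanh(x / clip_val)


def score_regularizer(raw_scores: torch.Tensor, weight: float = 0.05) -> torch.Tensor:
    """Hypothetical AMDIM-style penalty that discourages very large raw scores."""
    return weight * (raw_scores ** 2).mean()


raw = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0])
clipped = tanh_clip(raw, clip_val=10.0)  # large scores saturate near +/-10
reg = score_regularizer(raw)             # 0.05 * mean of squared raw scores
```

Small scores pass through almost unchanged, while extreme ones saturate, which keeps the softmax over scores from becoming degenerate.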
Context prediction tasks
The following tasks aim to predict a target using a context representation.
This is the predictive task from CPC (v2).
    import torch
    from pl_bolts.losses.self_supervised_learning import CPCTask

    task = CPCTask(num_input_channels=32)

    # (batch, channels, rows, cols)
    # this should be thought of as 49 feature vectors (7 x 7) per image, each with 32 dims
    Z = torch.rand(3, 32, 7, 7)
    loss = task(Z)
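To make the "49 feature vectors" reading concrete, here is a simplified CPC-style loss sketch: every grid position becomes one vector, and a vector in row r must pick out, among all target vectors in the batch, the one k rows below it in the same column. `CPCStyleLoss` and `cpc_labels` are hypothetical names; the 1x1 convolution standing in for a context network and the single fixed offset `k` are assumptions. It reuses the argument names `num_input_channels`, `target_dim` and `embed_scale` for readability only and is not the internal implementation of `CPCTask`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cpc_labels(b: int, h: int, w: int, k: int) -> torch.Tensor:
    """For each source position in rows [0, h - k), the flat index of the target k rows below."""
    n = b * (h - k) * w
    idx = torch.arange(n)
    per_image = (h - k) * w
    image = idx // per_image
    offset = idx % per_image  # row * w + col inside the first h - k rows
    return image * h * w + k * w + offset


class CPCStyleLoss(nn.Module):
    """Hypothetical, simplified CPC-style loss: predict the vector k rows below."""

    def __init__(self, num_input_channels: int, target_dim: int = 64, embed_scale: float = 0.1):
        super().__init__()
        self.target_dim = target_dim
        self.embed_scale = embed_scale
        self.target_cnn = nn.Conv2d(num_input_channels, target_dim, kernel_size=1)
        # Stand-in for a context network; real CPC uses a masked (autoregressive) one.
        self.pred_cnn = nn.Conv2d(num_input_channels, target_dim, kernel_size=1)

    def forward(self, z: torch.Tensor, k: int = 1) -> torch.Tensor:
        b, _, h, w = z.shape
        targets = self.target_cnn(z)                   # (b, d, h, w)
        preds = self.pred_cnn(z) * self.embed_scale    # (b, d, h, w)
        # every (row, col) position becomes one d-dim vector: (b * h * w, d)
        tgt = targets.permute(0, 2, 3, 1).reshape(-1, self.target_dim)
        # sources are the first h - k rows; row-major order matches cpc_labels
        src = preds[:, :, : h - k, :].permute(0, 2, 3, 1).reshape(-1, self.target_dim)
        logits = src @ tgt.t()  # score each source against every target in the batch
        labels = cpc_labels(b, h, w, k).to(z.device)
        return F.cross_entropy(logits, labels)


torch.manual_seed(0)
z = torch.rand(3, 32, 7, 7)  # 3 images, each a 7 x 7 grid of 32-dim vectors
loss = CPCStyleLoss(num_input_channels=32)(z, k=2)
```

All other target vectors in the batch, from the same image or from other images, act as negatives for each prediction.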
- class pl_bolts.losses.self_supervised_learning.CPCTask(num_input_channels, target_dim=64, embed_scale=0.1)
The feature CPCTask is currently marked under review. Its compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
Loss used in CPC.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Defines the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
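The difference between calling the module instance and calling `forward` directly can be seen with any plain PyTorch module and a forward hook; nothing here is specific to `CPCTask`, and the `nn.Linear` layer is just a convenient example.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(4, 2)
# a forward hook that records each time it fires
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 4)
layer(x)          # calling the instance runs forward() and the registered hooks
layer.forward(x)  # calling forward() directly skips the hooks
handle.remove()
```

Only the first call appends to `calls`, which is why loss modules such as the ones above should be invoked as `task(...)` rather than `task.forward(...)`.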