Self-supervised Learning Contrastive tasks
This section implements popular contrastive learning tasks used in self-supervised learning.
This task compares sets of feature maps.
In general, the feature map comparison pretext task uses triplets of features. The abstract steps of the comparison are as follows.
Generate multiple views of the same image
    x1_view_1 = data_augmentation(x1)
    x1_view_2 = data_augmentation(x1)
Use a different example to generate additional views (usually within the same batch or a pool of candidates)
    x2_view_1 = data_augmentation(x2)
    x2_view_2 = data_augmentation(x2)
Pick three views to compare: the anchor, the positive, and the negative
    anchor = x1_view_1
    positive = x1_view_2
    negative = x2_view_1
Generate feature maps for each view
    (a0, a1, a2) = encoder(anchor)
    (p0, p1, p2) = encoder(positive)
Make a comparison for a set of feature maps
    phi = some_score_function()

    # the '01' comparison
    score = phi(a0, p1)

    # and can be bidirectional
    score = phi(p0, a1)
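The steps above can be sketched end to end in PyTorch. This is a minimal illustration, not the library's implementation: `data_augmentation` (additive noise), `ToyEncoder` (three strided convolutions that return three feature maps) and the dot-product `phi` are all stand-ins assumed for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)


def data_augmentation(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical augmentation: add a little Gaussian noise.
    return x + 0.1 * torch.randn_like(x)


class ToyEncoder(nn.Module):
    """Hypothetical encoder returning three feature maps at decreasing resolution."""

    def __init__(self, in_channels: int = 3, dim: int = 8) -> None:
        super().__init__()
        self.conv0 = nn.Conv2d(in_channels, dim, kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor):
        f0 = self.conv0(x)
        f1 = self.conv1(f0)
        f2 = self.conv2(f1)
        return f0, f1, f2


def phi(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumed score function: dot product of globally average-pooled maps,
    # so maps with different spatial sizes can still be compared.
    return (a.mean(dim=(2, 3)) * b.mean(dim=(2, 3))).sum(dim=1)


# Two different examples, each a batch of one 16x16 RGB image.
x1 = torch.rand(1, 3, 16, 16)
x2 = torch.rand(1, 3, 16, 16)

anchor = data_augmentation(x1)
positive = data_augmentation(x1)   # second view of the same example
negative = data_augmentation(x2)   # view of a different example

encoder = ToyEncoder()
a0, a1, a2 = encoder(anchor)
p0, p1, p2 = encoder(positive)
n0, n1, n2 = encoder(negative)

# The '01' comparison, taken in both directions.
score_pos = phi(a0, p1) + phi(p0, a1)
score_neg = phi(a0, n1) + phi(n0, a1)
print(score_pos.shape, score_neg.shape)  # torch.Size([1]) torch.Size([1])
```

A contrastive objective would then push `score_pos` above `score_neg`; the choice of score function and loss is left open here.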
In practice, the contrastive task builds a BxB matrix of scores, where B is the batch size. The diagonal entries for set 1 of the feature maps are the anchors, the diagonal entries for set 2 are the positives, and the off-diagonal entries of set 1 are the negatives.
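To make the BxB layout concrete, here is a minimal sketch that scores every example in one set of feature maps against every example in the other and treats the diagonal as the matching pairs. The softmax cross-entropy objective, the cosine-similarity score, and the temperature of 0.1 are assumptions made for illustration, and `batch_score_matrix` is a hypothetical helper; the library's own loss may differ.

```python
import torch
import torch.nn.functional as F


def batch_score_matrix(set_1: torch.Tensor, set_2: torch.Tensor) -> torch.Tensor:
    """Return a (B, B) matrix; entry (i, j) scores example i of set 1 against example j of set 2."""
    # Flatten each feature map to one unit-length vector per example: (B, C*H*W).
    v1 = F.normalize(set_1.flatten(start_dim=1), dim=1)
    v2 = F.normalize(set_2.flatten(start_dim=1), dim=1)
    return v1 @ v2.t()


torch.manual_seed(0)
B = 4
set_1 = torch.rand(B, 5, 2, 2)
set_2 = set_1 + 0.01 * torch.randn(B, 5, 2, 2)  # slightly perturbed "second view"

scores = batch_score_matrix(set_1, set_2)

# Diagonal entries pair each example with its own second view (positives);
# every off-diagonal entry pairs it with a different example (negatives).
targets = torch.arange(B)
loss = F.cross_entropy(scores / 0.1, targets)  # 0.1 is an assumed temperature

print(scores.shape)  # torch.Size([4, 4])
```

Framing the problem as B-way classification over each row is one common way to use all other examples in the batch as negatives without sampling them explicitly.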
FeatureMapContrastiveTask(comparisons='00, 11', tclip=10.0, bidirectional=True)
Performs an anchor, positive, and negative comparison for each tuple of feature maps passed.
    # extract feature maps
    pos_0, pos_1, pos_2 = encoder(x_pos)
    anc_0, anc_1, anc_2 = encoder(x_anchor)

    # compare only the 0th feature maps
    # (the trailing commas make one-element tuples)
    task = FeatureMapContrastiveTask('00')
    loss, regularizer = task((pos_0,), (anc_0,))

    # compare (pos_0 to anc_1) and (pos_0 to anc_2)
    task = FeatureMapContrastiveTask('01, 02')
    losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
    loss = losses.sum()

    # compare (pos_0 vs a randomly chosen anc map)
    task = FeatureMapContrastiveTask('0r')
    loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
    # with bidirectional the comparisons are done both ways
    task = FeatureMapContrastiveTask('01, 02')

    # will compare the following:
    # 01: (pos_0, anc_1), (anc_0, pos_1)
    # 02: (pos_0, anc_2), (anc_0, pos_2)
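The comparison strings and the effect of `bidirectional` can be illustrated with a small standalone helper. `expand_comparisons` is a hypothetical function written for this page, not part of the library; it only mirrors the pairing listed in the comments above and does not handle the random `'r'` index.

```python
from typing import List, Tuple

Pair = Tuple[str, str]


def expand_comparisons(comparisons: str, bidirectional: bool = True) -> List[List[Pair]]:
    """List the feature-map pairs implied by a spec such as '01, 02'."""
    expanded = []
    for spec in comparisons.split(","):
        spec = spec.strip()
        if len(spec) != 2 or not spec.isdigit():
            raise ValueError(f"expected two digits per comparison, got {spec!r}")
        i, j = spec
        pairs = [(f"pos_{i}", f"anc_{j}")]
        if bidirectional:
            # Same indices with the roles of the two inputs swapped.
            pairs.append((f"anc_{i}", f"pos_{j}"))
        expanded.append(pairs)
    return expanded


for spec, pairs in zip(["01", "02"], expand_comparisons("01, 02")):
    print(spec, pairs)
# 01 [('pos_0', 'anc_1'), ('anc_0', 'pos_1')]
# 02 [('pos_0', 'anc_2'), ('anc_0', 'pos_2')]
```

With `bidirectional=False` each spec yields only the first pair, which halves the number of comparisons.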
Takes in a set of tuples. Each tuple holds two feature maps whose dimensions all match.
    >>> import torch
    >>> from pytorch_lightning import seed_everything
    >>> from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
    >>> seed_everything(0)
    0
    >>> a1 = torch.rand(3, 5, 2, 2)
    >>> a2 = torch.rand(3, 5, 2, 2)
    >>> b1 = torch.rand(3, 5, 2, 2)
    >>> b2 = torch.rand(3, 5, 2, 2)
    ...
    >>> task = FeatureMapContrastiveTask('01, 11')
    ...
    >>> losses, regularizer = task((a1, a2), (b1, b2))
    >>> losses
    tensor([2.2351, 2.1902])
    >>> regularizer
    tensor(0.0324)
Context prediction tasks
The following tasks aim to predict a target using a context representation.
This is the predictive task from CPC (v2).
    import torch
    from pl_bolts.losses.self_supervised_learning import CPCTask

    task = CPCTask(num_input_channels=32)

    # (batch, channels, rows, cols)
    # this should be thought of as 49 feature vectors, each with 32 dims
    Z = torch.rand(3, 32, 7, 7)

    loss = task(Z)
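The comment about 49 feature vectors can be checked by reshaping the input. This sketch uses only standard tensor operations and makes no claim about what `CPCTask` does internally.

```python
import torch

Z = torch.rand(3, 32, 7, 7)  # (batch, channels, rows, cols)

# Merge the 7x7 grid into one axis, then put channels last:
# each of the 49 grid positions becomes a 32-dimensional vector.
vectors = Z.flatten(start_dim=2).transpose(1, 2)

print(vectors.shape)  # torch.Size([3, 49, 32])

# The vector at grid position (row, col) sits at row-major index row * 7 + col.
row, col = 2, 5
assert torch.equal(vectors[0, row * 7 + col], Z[0, :, row, col])
```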