
Self-supervised Learning

This section covers implementations of popular contrastive learning tasks used in self-supervised learning.

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix it!


FeatureMapContrastiveTask

This task compares sets of feature maps.

In general, the feature map comparison pretext task uses triplets of features. Here are the abstract steps of the comparison.

Generate multiple views of the same image

x1_view_1 = data_augmentation(x1)
x1_view_2 = data_augmentation(x1)

Use a different example to generate additional views (usually within the same batch or a pool of candidates)

x2_view_1 = data_augmentation(x2)
x2_view_2 = data_augmentation(x2)

Pick 3 views to compare; these are the anchor, positive, and negative features

anchor = x1_view_1
positive = x1_view_2
negative = x2_view_1

Generate feature maps for each view

(a0, a1, a2) = encoder(anchor)
(p0, p1, p2) = encoder(positive)

Make a comparison for a set of feature maps

phi = some_score_function()

# the '01' comparison
score = phi(a0, p1)

# the comparison can also be made in the other direction (bidirectional)
score = phi(p0, a1)
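The abstract steps above can be sketched end to end in PyTorch. Everything in this sketch is a hypothetical stand-in chosen for illustration: `data_augmentation` (a random flip plus noise), `ToyEncoder` (three conv layers whose feature maps share one shape, so any pair can be compared), and `phi` (cosine similarity of spatially averaged maps). It is not how FeatureMapContrastiveTask scores feature maps internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def data_augmentation(x):
    # Hypothetical augmentation: random horizontal flip plus small Gaussian noise.
    if torch.rand(1).item() < 0.5:
        x = torch.flip(x, dims=[-1])
    return x + 0.05 * torch.randn_like(x)


class ToyEncoder(nn.Module):
    # Hypothetical encoder: three conv layers whose outputs all have shape
    # (batch, ch, rows, cols), so any two feature maps can be compared.
    def __init__(self, in_ch=3, ch=16):
        super().__init__()
        self.conv0 = nn.Conv2d(in_ch, ch, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        f0 = F.relu(self.conv0(x))
        f1 = F.relu(self.conv1(f0))
        f2 = F.relu(self.conv2(f1))
        return f0, f1, f2


def phi(m1, m2):
    # Hypothetical score: cosine similarity of the spatially averaged maps.
    return F.cosine_similarity(m1.mean(dim=(-2, -1)), m2.mean(dim=(-2, -1)), dim=-1)


encoder = ToyEncoder()
x1 = torch.rand(1, 3, 8, 8)  # (batch, channels, rows, cols)
x2 = torch.rand(1, 3, 8, 8)  # a different example

anchor = data_augmentation(x1)
positive = data_augmentation(x1)
negative = data_augmentation(x2)

with torch.no_grad():
    a0, a1, a2 = encoder(anchor)
    p0, p1, p2 = encoder(positive)
    n0, n1, n2 = encoder(negative)

    score_01 = phi(a0, p1)          # the '01' comparison
    score_01_reverse = phi(p0, a1)  # the same comparison in the other direction
    negative_01 = phi(a0, n1)       # anchor compared with the negative
```

With an untrained encoder the positive and negative scores need not differ; training is what pushes the anchor-positive score above the anchor-negative score.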

In practice, the contrastive task builds a B x B score matrix, where B is the batch size: the diagonal of feature-map set 1 supplies the anchors, the diagonal of feature-map set 2 supplies the positives, and the off-diagonal entries of set 1 supply the negatives.
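One common way to turn such a B x B matrix into a loss is an InfoNCE-style cross-entropy, sketched below under simplifying assumptions: each view is already reduced to a single D-dimensional vector (no spatial maps), and scores are plain dot products. `batch_contrastive_loss` is a hypothetical helper, not part of pl_bolts.

```python
import torch
import torch.nn.functional as F


def batch_contrastive_loss(anchors, positives):
    """Hypothetical InfoNCE-style loss over a (B, B) score matrix.

    anchors, positives: (B, D) tensors; row i of both comes from the same image.
    """
    # scores[i, j] compares anchor i with positive j
    scores = anchors @ positives.t()
    # the correct "class" for anchor i is column i (the diagonal);
    # every other column in row i acts as a negative
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scores, targets), scores


torch.manual_seed(0)
B, D = 8, 32
anchors = F.normalize(torch.randn(B, D), dim=1)
# positives are slightly perturbed copies of the anchors
positives = F.normalize(anchors + 0.05 * torch.randn(B, D), dim=1)
loss, scores = batch_contrastive_loss(anchors, positives)
```

Because the off-diagonal entries of each row come from other images in the batch, a larger batch supplies more negatives per anchor.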

class pl_bolts.losses.self_supervised_learning.FeatureMapContrastiveTask(comparisons='00, 11', tclip=10.0, bidirectional=True)

Bases: torch.nn.modules.module.Module

Warning

FeatureMapContrastiveTask is currently marked as under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Performs an anchor/positive/negative comparison for each tuple of feature maps passed in.

from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask

# extract feature maps (encoder, x_pos and x_anchor are placeholders)
pos_0, pos_1, pos_2 = encoder(x_pos)
anc_0, anc_1, anc_2 = encoder(x_anchor)

# compare only the 0th feature maps (the trailing commas make 1-element tuples)
task = FeatureMapContrastiveTask('00')
loss, regularizer = task((pos_0,), (anc_0,))

# compare (pos_0 to anc_1) and (pos_0 to anc_2)
task = FeatureMapContrastiveTask('01, 02')
losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))
loss = losses.sum()

# compare pos_0 with a randomly chosen anchor map
task = FeatureMapContrastiveTask('0r')
loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2))

# with bidirectional=True (the default), each comparison is done both ways
task = FeatureMapContrastiveTask('01, 02')

# will compare the following:
# 01: (pos_0, anc_1), (anc_0, pos_1)
# 02: (pos_0, anc_2), (anc_0, pos_2)

Parameters
  • comparisons (str) – comma-separated pairs of feature map indices to compare (zero-indexed; ‘r’ means a randomly chosen map), e.g. ‘00, 1r’

  • tclip (float) – clipping value applied to the scores for numerical stability

  • bidirectional (bool) – if True, performs each comparison in both directions
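The pairs listed in the bidirectional example can be enumerated with a small pure-Python helper. `expand_pairs` is hypothetical: it takes already-parsed (i, j) index tuples, where i indexes the first tuple of maps passed to the task and j the second, uses -1 for a random index, and follows the example's naming (pos_* first, anc_* second).

```python
def expand_pairs(index_pairs, first="pos", second="anc", bidirectional=True):
    """Hypothetical helper: list which feature maps each comparison pairs up."""

    def tag(idx):
        # -1 stands for a randomly chosen map, shown as 'r'
        return "r" if idx == -1 else str(idx)

    result = []
    for i, j in index_pairs:
        pairs = [(f"{first}_{tag(i)}", f"{second}_{tag(j)}")]
        if bidirectional:
            # same indices, with the roles of the two tuples swapped
            pairs.append((f"{second}_{tag(i)}", f"{first}_{tag(j)}"))
        result.append((tag(i) + tag(j), pairs))
    return result


print(expand_pairs([(0, 1), (0, 2)]))
# [('01', [('pos_0', 'anc_1'), ('anc_0', 'pos_1')]), ('02', [('pos_0', 'anc_2'), ('anc_0', 'pos_2')])]
```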

forward(anchor_maps, positive_maps)

Takes in two tuples of feature maps, anchor_maps and positive_maps; each pair of maps being compared must have matching dimensions.

Example

>>> import torch
>>> from pytorch_lightning import seed_everything
>>> from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask
>>> seed_everything(0)
0
>>> a1 = torch.rand(3, 5, 2, 2)
>>> a2 = torch.rand(3, 5, 2, 2)
>>> b1 = torch.rand(3, 5, 2, 2)
>>> b2 = torch.rand(3, 5, 2, 2)
>>> task = FeatureMapContrastiveTask('01, 11')
>>> losses, regularizer = task((a1, a2), (b1, b2))
>>> losses
tensor([2.2351, 2.1902])
>>> regularizer
tensor(0.0324)

static parse_map_indexes(comparisons)

Example:

>>> FeatureMapContrastiveTask.parse_map_indexes('11')
[(1, 1)]
>>> FeatureMapContrastiveTask.parse_map_indexes('11,59')
[(1, 1), (5, 9)]
>>> FeatureMapContrastiveTask.parse_map_indexes('11,59, 2r')
[(1, 1), (5, 9), (2, -1)]
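The parsing rule shown by these examples can be approximated in a few lines of plain Python. `parse_comparisons` is a hypothetical re-implementation written only to reproduce the documented outputs, not the library's code; like the examples, it treats each index as a single character, so ‘59’ means maps 5 and 9.

```python
def parse_comparisons(comparisons):
    """Hypothetical sketch: turn '11,59, 2r' into [(1, 1), (5, 9), (2, -1)]."""

    def to_index(ch):
        # 'r' marks a randomly chosen map and is encoded as -1
        return -1 if ch == "r" else int(ch)

    pairs = []
    for token in comparisons.split(","):
        token = token.strip()  # tolerate spaces after commas
        if len(token) != 2:
            raise ValueError(f"expected two single-character indices, got {token!r}")
        pairs.append((to_index(token[0]), to_index(token[1])))
    return pairs


print(parse_comparisons("11,59, 2r"))  # [(1, 1), (5, 9), (2, -1)]
```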

Context prediction tasks

The following tasks aim to predict a target using a context representation.

CPCContrastiveTask

This is the predictive task from CPC (v2).

import torch
from pl_bolts.losses.self_supervised_learning import CPCTask

task = CPCTask(num_input_channels=32)

# (batch, channels, rows, cols)
# each image contributes 7 x 7 = 49 feature vectors, each with 32 dims
Z = torch.rand(3, 32, 7, 7)

loss = task(Z)
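To make the predictive task concrete, here is a heavily simplified CPC-style loss. It is a hypothetical sketch (`SimpleCPCLoss`), not CPCTask’s implementation: it uses plain 1x1 convolutions where CPC derives predictions from an autoregressive context network, and it only predicts a fixed number of rows ahead (`steps`, an assumed parameter). Each prediction is scored against every target vector in the batch, and cross-entropy rewards picking the target located k rows below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCPCLoss(nn.Module):
    """Hypothetical, simplified CPC-style loss (not pl_bolts' CPCTask)."""

    def __init__(self, in_channels, target_dim=64, steps=2):
        super().__init__()
        self.steps = steps
        # 1x1 convolutions map every feature vector to target / prediction space
        self.target_proj = nn.Conv2d(in_channels, target_dim, kernel_size=1)
        # CPC would compute predictions from an autoregressive context network;
        # a 1x1 convolution keeps this sketch short
        self.pred_proj = nn.Conv2d(in_channels, target_dim, kernel_size=1)

    def forward(self, z):
        b, _, h, w = z.shape
        if h <= self.steps:
            raise ValueError("need more rows than prediction steps")
        targets = self.target_proj(z)
        preds = self.pred_proj(z)
        d = targets.size(1)
        # every target vector in the batch: (b * h * w, d)
        all_targets = targets.permute(0, 2, 3, 1).reshape(-1, d)

        losses = []
        for k in range(1, self.steps + 1):
            # predictions from rows 0 .. h-k-1 aim at the vectors k rows below
            src = preds[:, :, : h - k, :].permute(0, 2, 3, 1).reshape(-1, d)
            logits = src @ all_targets.t()  # (b * (h-k) * w, b * h * w)
            # flat index of the correct target for each prediction
            bi = torch.arange(b, device=z.device).repeat_interleave((h - k) * w)
            ri = torch.arange(h - k, device=z.device).repeat_interleave(w).repeat(b) + k
            ci = torch.arange(w, device=z.device).repeat(b * (h - k))
            labels = bi * h * w + ri * w + ci
            losses.append(F.cross_entropy(logits, labels))
        return torch.stack(losses).sum()


torch.manual_seed(0)
task = SimpleCPCLoss(in_channels=32)
Z = torch.rand(3, 32, 7, 7)  # 3 images, 49 vectors of 32 dims each
loss = task(Z)
loss.backward()
```

Every non-matching target vector in the batch, including those from other images, acts as a negative for a given prediction.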

class pl_bolts.losses.self_supervised_learning.CPCTask(num_input_channels, target_dim=64, embed_scale=0.1)

Bases: torch.nn.modules.module.Module

Warning

CPCTask is currently marked as under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Loss used in CPC.


forward(z)

Computes the CPC loss for a batch of feature maps z with shape (batch, channels, rows, cols) and returns the loss.

Note

Call the module instance (for example, task(z)) rather than forward directly: calling the instance runs any registered hooks, while calling forward silently skips them.
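The note’s advice can be checked with a plain nn.Linear standing in for CPCTask (no pl_bolts needed): a forward hook fires when the module instance is called, but not when forward is called directly.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(4, 2)  # stand-in for any nn.Module, e.g. CPCTask
# a forward hook receives (module, inputs, output) after forward runs
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook ran"))

x = torch.rand(1, 4)
layer(x)          # __call__: runs forward and then the registered hook
layer.forward(x)  # direct call: runs forward only, the hook is skipped
handle.remove()

print(calls)  # ['hook ran']
```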
