
Object Detection

These are common losses used in object detection.

Note

We rely on the community to keep these updated and working. If something doesn’t work, we’d really appreciate a contribution to fix!


GIoU Loss

pl_bolts.losses.object_detection.giou_loss(preds, target)

Warning

The feature giou_loss is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Calculates the generalized intersection over union loss.

It was proposed in the paper Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression.

Parameters
  • preds (Tensor) – an Nx4 batch of prediction bounding boxes with representation [x_min, y_min, x_max, y_max]

  • target (Tensor) – an Mx4 batch of target bounding boxes with representation [x_min, y_min, x_max, y_max]

Example

>>> import torch
>>> from pl_bolts.losses.object_detection import giou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> giou_loss(preds, target)
tensor([[1.0794]])

Return type

Tensor

Returns

GIoU loss as an NxM tensor containing the pairwise GIoU loss for every pair of boxes from preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes
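To make the formula concrete, here is a minimal, framework-free sketch in plain Python. It is not the pl_bolts implementation, which operates on tensors. It assumes every box has positive area, computes GIoU as the IoU minus the fraction of the smallest enclosing box that the union does not cover, and returns 1 - GIoU for every (pred, target) pair as an N x M nested list. The helper names `_area` and `giou_loss_pairwise` are hypothetical.

```python
from typing import List, Sequence

Box = Sequence[float]  # [x_min, y_min, x_max, y_max]


def _area(box: Box) -> float:
    # Width and height are clamped at zero so an empty intersection has zero area.
    return max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])


def giou_loss_pairwise(preds: List[Box], target: List[Box]) -> List[List[float]]:
    """Return an N x M nested list where entry [i][j] is 1 - GIoU(preds[i], target[j])."""
    out = []
    for p in preds:
        row = []
        for t in target:
            # Intersection rectangle (possibly empty).
            inter = _area([max(p[0], t[0]), max(p[1], t[1]), min(p[2], t[2]), min(p[3], t[3])])
            union = _area(p) + _area(t) - inter
            # Smallest axis-aligned box enclosing both p and t.
            hull = _area([min(p[0], t[0]), min(p[1], t[1]), max(p[2], t[2]), max(p[3], t[3])])
            iou = inter / union
            giou = iou - (hull - union) / hull
            row.append(1.0 - giou)
        out.append(row)
    return out


# Two predictions against one target gives a 2 x 1 result.
loss = giou_loss_pairwise([[100, 100, 200, 200], [150, 150, 250, 250]], [[150, 150, 250, 250]])
print(round(loss[0][0], 4), loss[1][0])  # 1.0794 0.0
```

The first entry matches the doctest above: the boxes overlap in a 50 x 50 region (IoU = 2500 / 17500), and the enclosing 150 x 150 box adds a penalty of 5000 / 22500, giving a GIoU of about -0.0794 and a loss of about 1.0794.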


IoU Loss

pl_bolts.losses.object_detection.iou_loss(preds, target)

Warning

The feature iou_loss is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Calculates the intersection over union loss.

Parameters
  • preds (Tensor) – an Nx4 batch of prediction bounding boxes with representation [x_min, y_min, x_max, y_max]

  • target (Tensor) – an Mx4 batch of target bounding boxes with representation [x_min, y_min, x_max, y_max]

Example

>>> import torch
>>> from pl_bolts.losses.object_detection import iou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> iou_loss(preds, target)
tensor([[0.8571]])

Return type

Tensor

Returns

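IoU loss as an NxM tensor containing the pairwise IoU loss for every pair of boxes from preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes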
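The sketch below, in plain Python rather than the tensor-based pl_bolts code, computes 1 - IoU for a single pair of boxes with positive area (the names `_area` and `iou_loss_single` are hypothetical). It also illustrates the motivation for GIoU: once two boxes stop overlapping, the IoU loss is stuck at 1 no matter how far apart they are, so it gives no signal for moving the prediction closer.

```python
from typing import Sequence

Box = Sequence[float]  # [x_min, y_min, x_max, y_max]


def _area(box: Box) -> float:
    # Clamp negative extents to zero so an empty intersection has zero area.
    return max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])


def iou_loss_single(pred: Box, target: Box) -> float:
    """Return 1 - IoU for a single pair of boxes with positive area."""
    inter = _area([max(pred[0], target[0]), max(pred[1], target[1]),
                   min(pred[2], target[2]), min(pred[3], target[3])])
    union = _area(pred) + _area(target) - inter
    return 1.0 - inter / union


# Overlapping boxes from the doctest example above.
print(round(iou_loss_single([100, 100, 200, 200], [150, 150, 250, 250]), 4))  # 0.8571

# Two non-overlapping pairs at very different distances:
# the IoU loss cannot tell them apart because both intersections are empty.
near = iou_loss_single([0, 0, 10, 10], [11, 0, 21, 10])
far = iou_loss_single([0, 0, 10, 10], [100, 0, 110, 10])
print(near, far)  # 1.0 1.0
```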


Reinforcement Learning

These are common losses used in RL.


DQN Loss

pl_bolts.losses.rl.dqn_loss(batch, net, target_net, gamma=0.99)

Warning

The feature dqn_loss is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Calculates the mean squared error (MSE) loss using a mini-batch from the replay buffer.

Parameters
  • batch (Tuple[Tensor, Tensor]) – current mini batch of replay data

  • net (Module) – main training network

  • target_net (Module) – target network of the main training network

  • gamma (float) – discount factor

Return type

Tensor

Returns

loss
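As an illustration of what this MSE is measured against, the sketch below computes the standard one-step DQN target in plain Python, with lists of Q-values standing in for the outputs of `net` and `target_net`. It assumes a conventional (state, action, reward, done, next state) replay transition; the function and argument names are hypothetical and are not the pl_bolts signature.

```python
from typing import List, Sequence


def dqn_loss_sketch(
    q_values: List[Sequence[float]],       # Q(s, .) from the training network, one row per sample
    actions: List[int],                    # action taken in each sample
    rewards: List[float],                  # reward received after each action
    dones: List[bool],                     # True if the episode ended after the action
    next_q_target: List[Sequence[float]],  # Q_target(s', .) from the target network
    gamma: float = 0.99,
) -> float:
    """Mean squared error between Q(s, a) and the one-step TD target."""
    total = 0.0
    for q, a, r, done, nq in zip(q_values, actions, rewards, dones, next_q_target):
        # Bootstrapped value of the next state; zero if the episode ended.
        next_value = 0.0 if done else max(nq)
        target = r + gamma * next_value
        total += (q[a] - target) ** 2
    return total / len(actions)


loss = dqn_loss_sketch(
    q_values=[[1.0, 2.0], [0.5, 0.0]],
    actions=[1, 0],
    rewards=[1.0, 0.0],
    dones=[False, True],
    next_q_target=[[1.0, 3.0], [9.0, 9.0]],
    gamma=0.5,
)
print(loss)  # 0.25
```

In the first sample the target is 1.0 + 0.5 * 3.0 = 2.5; in the second the episode ended, so the target is just the reward 0.0. Both squared errors are 0.25, so their mean is 0.25.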

Double DQN Loss

pl_bolts.losses.rl.double_dqn_loss(batch, net, target_net, gamma=0.99)

Warning

The feature double_dqn_loss is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Calculates the mean squared error (MSE) loss using a mini-batch from the replay buffer. This improves on the original DQN loss by using Double DQN: the training network selects the action for the next state, and the target network supplies that action's value. The source code is heavily commented to explain the process step by step.

Parameters
  • batch (Tuple[Tensor, Tensor]) – current mini batch of replay data

  • net (Module) – main training network

  • target_net (Module) – target network of the main training network

  • gamma (float) – discount factor

Return type

Tensor

Returns

loss
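The only change from the plain DQN loss is how the next-state value is computed. The plain-Python sketch below (hypothetical names, lists standing in for network outputs) computes just the Double DQN targets: the training network chooses the next action and the target network scores it. The squared error against Q(s, a) would then be taken as in the DQN case.

```python
from typing import List, Sequence


def double_dqn_targets(
    rewards: List[float],
    dones: List[bool],
    next_q_online: List[Sequence[float]],  # Q(s', .) from the training network
    next_q_target: List[Sequence[float]],  # Q_target(s', .) from the target network
    gamma: float = 0.99,
) -> List[float]:
    """One-step Double DQN targets: the online net picks a', the target net evaluates it."""
    targets = []
    for r, done, q_on, q_tg in zip(rewards, dones, next_q_online, next_q_target):
        if done:
            # Terminal transition: no bootstrapping.
            targets.append(r)
            continue
        # Action selection with the training network...
        best_action = max(range(len(q_on)), key=lambda a: q_on[a])
        # ...action evaluation with the target network.
        targets.append(r + gamma * q_tg[best_action])
    return targets


double = double_dqn_targets([0.0], [False], [[5.0, 1.0]], [[2.0, 4.0]], gamma=0.5)
vanilla = [0.0 + 0.5 * max([2.0, 4.0])]  # plain DQN takes the target net's own max
print(double, vanilla)  # [1.0] [2.0]
```

Because the plain DQN target both selects and evaluates with the same maximum, it picks up any upward noise in the target network's estimates; decoupling selection from evaluation, as above, is what reduces that overestimation.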

PER DQN Loss

pl_bolts.losses.rl.per_dqn_loss(batch, batch_weights, net, target_net, gamma=0.99)

Warning

The feature per_dqn_loss is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Calculates the mean squared error (MSE) loss, weighted by the priority weights of the batch sampled from the prioritized experience replay (PER) buffer.

Parameters
  • batch (Tuple[Tensor, Tensor]) – current mini batch of replay data

  • batch_weights (List) – how each of these samples is weighted in terms of priority

  • net (Module) – main training network

  • target_net (Module) – target network of the main training network

  • gamma (float) – discount factor

Return type

Tuple[Tensor, ndarray]

Returns

loss and batch_weights
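The weighting step can be sketched in plain Python as follows (hypothetical names). Assuming, as in standard PER, that batch_weights are importance-sampling weights, each squared TD error is scaled by its sample's weight before averaging, so samples that the buffer draws more often contribute less per draw. Returning the per-sample weighted errors alongside the loss is an assumption of this sketch; in prioritized experience replay such errors are commonly used to refresh the buffer's priorities.

```python
from typing import List, Tuple


def weighted_mse(
    predictions: List[float],  # Q(s, a) from the training network
    targets: List[float],      # TD targets, computed as in the DQN loss
    weights: List[float],      # importance-sampling weights from the PER buffer
) -> Tuple[float, List[float]]:
    """Return the weighted mean squared error and the per-sample weighted squared errors."""
    per_sample = [w * (p - t) ** 2 for p, t, w in zip(predictions, targets, weights)]
    return sum(per_sample) / len(per_sample), per_sample


loss, per_sample = weighted_mse([1.0, 2.0], [0.0, 0.0], [1.0, 0.5])
print(loss, per_sample)  # 1.5 [1.0, 2.0]
```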
