Sklearn Datamodule

Utilities that map sklearn or numpy datasets to PyTorch DataLoaders, with automatic train/validation/test splits and GPU/TPU support.

from sklearn.datasets import load_diabetes
from pl_bolts.datamodules import SklearnDataModule

X, y = load_diabetes(return_X_y=True)
# batch_size is a constructor argument; the *_dataloader() methods take no arguments
loaders = SklearnDataModule(X, y, batch_size=32)

train_loader = loaders.train_dataloader()
val_loader = loaders.val_dataloader()
test_loader = loaders.test_dataloader()

Or build your own torch dataset and wrap it in a DataLoader:

from sklearn.datasets import load_diabetes
from pl_bolts.datamodules import SklearnDataset
from torch.utils.data import DataLoader

X, y = load_diabetes(return_X_y=True)
dataset = SklearnDataset(X, y)
loader = DataLoader(dataset)

Sklearn Dataset Class

Transforms a sklearn or numpy dataset to a PyTorch Dataset.

class pl_bolts.datamodules.sklearn_datamodule.SklearnDataset(X, y, X_transform=None, y_transform=None)

Bases: Generic[torch.utils.data.dataset.T_co]

Warning

The feature SklearnDataset is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Maps numpy (or sklearn) datasets to PyTorch datasets.

Example

>>> from sklearn.datasets import load_diabetes
>>> from pl_bolts.datamodules import SklearnDataset
...
>>> X, y = load_diabetes(return_X_y=True)
>>> dataset = SklearnDataset(X, y)
>>> len(dataset)
442
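
The mapping can be pictured as a map-style dataset: an object exposing `__len__` and `__getitem__` that returns one `(x, y)` pair per index and applies the optional transforms per sample. The sketch below is an illustration only, not the pl_bolts source; `ArrayDataset` is a hypothetical name, and plain lists stand in for NumPy arrays so that it runs without extra dependencies.

```python
class ArrayDataset:
    """Hypothetical map-style dataset over two indexable arrays (not the pl_bolts class)."""

    def __init__(self, X, y, X_transform=None, y_transform=None):
        if len(X) != len(y):
            raise ValueError("X and y must have the same number of samples")
        self.X = X
        self.y = y
        self.X_transform = X_transform
        self.y_transform = y_transform

    def __len__(self):
        # One item per row of X.
        return len(self.X)

    def __getitem__(self, idx):
        x_item, y_item = self.X[idx], self.y[idx]
        # Transforms are optional callables applied to a single sample.
        if self.X_transform is not None:
            x_item = self.X_transform(x_item)
        if self.y_transform is not None:
            y_item = self.y_transform(y_item)
        return x_item, y_item


# Nested lists stand in for NumPy arrays; only len() and indexing are used.
X = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
y = [0.0, 1.0, 2.0]
dataset = ArrayDataset(X, y, X_transform=lambda row: [2 * v for v in row])
x1, y1 = dataset[1]
```

Because only `len()` and integer indexing are required, an object of this shape can be passed to `torch.utils.data.DataLoader` in the same way as `SklearnDataset`.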

Sklearn DataModule Class

Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits.

class pl_bolts.datamodules.sklearn_datamodule.SklearnDataModule(X, y, x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, num_workers=0, random_state=1234, shuffle=True, batch_size=16, pin_memory=True, drop_last=False, *args, **kwargs)

Bases: pytorch_lightning.core.datamodule.LightningDataModule

Warning

The feature SklearnDataModule is currently marked under review. Compatibility with other Lightning projects is not guaranteed, and the API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html

Example

>>> from sklearn.datasets import load_diabetes
>>> from pl_bolts.datamodules import SklearnDataModule
...
>>> X, y = load_diabetes(return_X_y=True)
>>> loaders = SklearnDataModule(X, y, batch_size=32)
...
>>> # train set
>>> train_loader = loaders.train_dataloader()
>>> len(train_loader.dataset)
310
>>> len(train_loader)
10
>>> # validation set
>>> val_loader = loaders.val_dataloader()
>>> len(val_loader.dataset)
88
>>> len(val_loader)
3
>>> # test set
>>> test_loader = loaders.test_dataloader()
>>> len(test_loader.dataset)
44
>>> len(test_loader)
2
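
The sizes in the example above follow from the defaults `val_split=0.2` and `test_split=0.1` applied to the 442 diabetes samples, and from `batch_size=32` with `drop_last=False`. The arithmetic can be sketched as below. The rounding rule (floor each hold-out size, give the remainder to training) is an assumption that happens to reproduce the documented numbers; the library's own splitting code may differ in detail. `split_sizes` and `num_batches` are hypothetical helper names.

```python
import math


def split_sizes(n_samples, val_split=0.2, test_split=0.1):
    # Assumed rounding: floor each hold-out size, training keeps the rest.
    n_val = math.floor(n_samples * val_split)
    n_test = math.floor(n_samples * test_split)
    return n_samples - n_val - n_test, n_val, n_test


def num_batches(n_items, batch_size, drop_last=False):
    # With drop_last=False the final, smaller batch still counts.
    if drop_last:
        return n_items // batch_size
    return math.ceil(n_items / batch_size)


n_train, n_val, n_test = split_sizes(442)
print(n_train, n_val, n_test)  # 310 88 44
print([num_batches(n, 32) for n in (n_train, n_val, n_test)])  # [10, 3, 2]
```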

prepare_data_per_node

If True, the process with LOCAL_RANK=0 on every node calls prepare_data(). Otherwise only the process with NODE_RANK=0 and LOCAL_RANK=0 calls it.
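
As a reading aid, the rule can be written as a small predicate over the two rank values. This is a sketch of the rule as stated here, not Lightning's internal code; `calls_prepare_data` is a hypothetical helper.

```python
def calls_prepare_data(prepare_data_per_node, node_rank, local_rank):
    """Return True if the process with these ranks should run prepare_data()."""
    if prepare_data_per_node:
        # One process per node: the local leader.
        return local_rank == 0
    # One process overall: the local leader on the first node.
    return node_rank == 0 and local_rank == 0


# Example cluster: 2 nodes with 2 processes each, as (node_rank, local_rank).
ranks = [(node, local) for node in range(2) for local in range(2)]
per_node = [r for r in ranks if calls_prepare_data(True, *r)]
global_only = [r for r in ranks if calls_prepare_data(False, *r)]
print(per_node)     # [(0, 0), (1, 0)]
print(global_only)  # [(0, 0)]
```

Setting the flag to True is typically what you want when nodes do not share a filesystem, since each node then needs its own copy of the data.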

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader of zero length on a local rank is allowed. The default is False.

test_dataloader()

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Related methods:

  • test()

  • prepare_data()

  • setup()
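
The download/process pattern and the warning about state can be illustrated with a plain-Python stand-in. `ToyDataModule` is hypothetical; a real data module would subclass LightningDataModule and return a torch DataLoader, and the JSON file here merely stands in for a downloaded dataset.

```python
import json
import os
import tempfile


class ToyDataModule:
    """Plain-Python stand-in for a data module (not a Lightning class)."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.test_items = None

    def prepare_data(self):
        # Only write to disk here. This hook may run in a single process,
        # so attributes assigned here would be missing in the others.
        path = os.path.join(self.data_dir, "data.json")
        if not os.path.exists(path):
            with open(path, "w") as f:
                json.dump(list(range(10)), f)

    def setup(self, stage=None):
        # Runs in every process: load from disk, split, assign state.
        with open(os.path.join(self.data_dir, "data.json")) as f:
            items = json.load(f)
        self.test_items = items[8:]

    def test_dataloader(self):
        # A real implementation would wrap the split in a DataLoader.
        return self.test_items


with tempfile.TemporaryDirectory() as tmp:
    dm = ToyDataModule(tmp)
    dm.prepare_data()
    dm.setup(stage="test")
    test_items = dm.test_dataloader()
```

The point of the split is that `setup()` can rebuild all in-memory state from what `prepare_data()` left on disk, regardless of which process ran the download.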

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

DataLoader

Returns

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Return type

DataLoader

Returns

A collection of torch.utils.data.DataLoader objects specifying training samples. Multiple dataloaders can be returned as a list or a dict, as shown in the examples below.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Related methods:

  • fit()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
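
For intuition about what a distributed sampler does, the sketch below splits dataset indices across processes by striding, padding by wrap-around so every rank gets the same number of samples. It mirrors the general idea behind torch.utils.data.DistributedSampler without shuffling; it is not Lightning's or PyTorch's code, and `shard_indices` is a hypothetical helper.

```python
import math


def shard_indices(num_samples, world_size, rank):
    """Return the dataset indices one rank would see (no shuffling)."""
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    # Pad by wrapping around so the indices divide evenly across ranks.
    padded = [i % num_samples for i in range(total)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return padded[rank:total:world_size]


shards = [shard_indices(10, 4, rank) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```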

Example:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}

val_dataloader()

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type

DataLoader

Returns

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
