Sklearn Datamodule
Utilities to map sklearn or numpy datasets to PyTorch Dataloaders with automatic data splits and GPU/TPU support.
from sklearn.datasets import load_diabetes
from pl_bolts.datamodules import SklearnDataModule

X, y = load_diabetes(return_X_y=True)
# batch_size is set once on the datamodule and shared by all three loaders
loaders = SklearnDataModule(X, y, batch_size=32)
train_loader = loaders.train_dataloader()
val_loader = loaders.val_dataloader()
test_loader = loaders.test_dataloader()
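Or build your own torch datasets:
from sklearn.datasets import load_diabetes
from torch.utils.data import DataLoader
from pl_bolts.datamodules import SklearnDataset

X, y = load_diabetes(return_X_y=True)
# Wrap the arrays in a Dataset, then hand it to a standard DataLoader
dataset = SklearnDataset(X, y)
loader = DataLoader(dataset)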
Sklearn Dataset Class
Transforms a sklearn or numpy dataset to a PyTorch Dataset.
class pl_bolts.datamodules.sklearn_datamodule.SklearnDataset(X, y, X_transform=None, y_transform=None)
Bases: torch.utils.data.
Maps numpy (or sklearn) datasets to PyTorch datasets.
Example
>>> from sklearn.datasets import load_diabetes
>>> from pl_bolts.datamodules import SklearnDataset
...
>>> X, y = load_diabetes(return_X_y=True)
>>> dataset = SklearnDataset(X, y)
>>> len(dataset)
442
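To make the idea of mapping paired arrays to a dataset more concrete, here is a minimal, dependency-free sketch of the map-style protocol (`__len__` and `__getitem__`) that such a wrapper exposes. The class name `ArrayDataset` is hypothetical and this is not the pl_bolts implementation; it only assumes that the optional `X_transform` and `y_transform` callables are applied per sample when an item is fetched.

```python
class ArrayDataset:
    """Hypothetical stand-in for a map-style dataset over paired arrays.

    Illustrative only: it mirrors the constructor arguments shown in the
    signature above, but it is not the library's code.
    """

    def __init__(self, X, y, X_transform=None, y_transform=None):
        if len(X) != len(y):
            raise ValueError("X and y must have the same number of samples")
        self.X = X
        self.y = y
        self.X_transform = X_transform
        self.y_transform = y_transform

    def __len__(self):
        # Number of samples, which a loader uses to build its index range
        return len(self.X)

    def __getitem__(self, idx):
        # Fetch one (features, target) pair and apply optional transforms
        x, y = self.X[idx], self.y[idx]
        if self.X_transform is not None:
            x = self.X_transform(x)
        if self.y_transform is not None:
            y = self.y_transform(y)
        return x, y


# Toy data: three samples with two features each
features = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
targets = [10.0, 20.0, 30.0]

# Assumed use of y_transform: rescale each target when it is fetched
dataset = ArrayDataset(features, targets, y_transform=lambda t: t / 10.0)
sample = dataset[1]
```

Any object with these two methods can be indexed sample by sample, which is what lets a loader batch it without knowing how the underlying arrays are stored.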
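Sklearn DataModule Class
Automatically generates the train, validation and test splits for a NumPy dataset and exposes each split as a dataloader. By default, val_split=0.2 and test_split=0.1 of the samples are held out for validation and testing. Optionally, you can pass in your own validation and test splits via x_val, y_val, x_test and y_test.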
class pl_bolts.datamodules.sklearn_datamodule.SklearnDataModule(X, y, x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, num_workers=0, random_state=1234, shuffle=True, batch_size=16, pin_memory=True, drop_last=False, *args, **kwargs)
Bases: pytorch_lightning.
Example
>>> from sklearn.datasets import load_diabetes
>>> from pl_bolts.datamodules import SklearnDataModule
...
>>> X, y = load_diabetes(return_X_y=True)
>>> loaders = SklearnDataModule(X, y, batch_size=32)
...
>>> # train set
>>> train_loader = loaders.train_dataloader()
>>> len(train_loader.dataset)
310
>>> len(train_loader)
10
>>> # validation set
>>> val_loader = loaders.val_dataloader()
>>> len(val_loader.dataset)
88
>>> len(val_loader)
3
>>> # test set
>>> test_loader = loaders.test_dataloader()
>>> len(test_loader.dataset)
44
>>> len(test_loader)
2
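The sizes in the example follow from the default split fractions and the batch size. The sketch below works through that arithmetic for 442 samples. The helper names `split_sizes` and `num_batches` are hypothetical, and flooring each held-out fraction is an assumption: the library's exact rounding rule is not stated here, although this rule reproduces the numbers shown above.

```python
import math


def split_sizes(n_samples, val_split=0.2, test_split=0.1):
    """Return (n_train, n_val, n_test) for the given fractions.

    Assumption: each held-out fraction is floored independently and the
    remainder goes to training.
    """
    n_val = math.floor(n_samples * val_split)
    n_test = math.floor(n_samples * test_split)
    n_train = n_samples - n_val - n_test
    return n_train, n_val, n_test


def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields over n_samples."""
    if drop_last:
        # The final, incomplete batch is discarded
        return n_samples // batch_size
    # The final, incomplete batch is kept
    return math.ceil(n_samples / batch_size)


# The diabetes dataset in the example has 442 samples
n_train, n_val, n_test = split_sizes(442)

# Batch counts per split for batch_size=32 and drop_last=False
batch_counts = [num_batches(n, 32) for n in (n_train, n_val, n_test)]

# With drop_last=True the partial training batch would be dropped
train_batches_drop_last = num_batches(n_train, 32, drop_last=True)
```

With drop_last=False, 310 training samples give nine full batches of 32 plus one partial batch of 22, hence ten batches in total.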