Dataset Class
PyTorch offers the `Dataset` class as the primary abstraction for storing and accessing data, for both pre-loaded and external datasets. To define a custom dataset, we create a class that inherits from `Dataset` and implement two methods:
1. `__getitem__`: returns the sample at a given index.
2. `__len__`: returns the number of samples in the dataset.
Here is an example of a simple `LinearData` class.
import torch
from torch.utils.data import Dataset, DataLoader, random_split
class LinearData(Dataset):
    """Map-style dataset that pairs each input value with its target value.

    Args:
        x_input: Indexable inputs (e.g. a 1-D tensor) supporting ``len()``
            and integer indexing.
        y_output: Indexable targets aligned element-by-element with
            ``x_input``; must have the same length.

    Raises:
        ValueError: If ``x_input`` and ``y_output`` differ in length.
    """

    def __init__(self, x_input, y_output):
        # Catch misaligned data up front instead of failing (or silently
        # ignoring trailing targets) later during indexing.
        if len(x_input) != len(y_output):
            raise ValueError(
                "x_input and y_output must have the same length, "
                f"got {len(x_input)} and {len(y_output)}"
            )
        self.x_input = x_input
        self.y_output = y_output

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
                Raising (rather than returning a sentinel value) is what
                lets Python's sequence-iteration fallback stop at the end of
                the data; ``Dataset`` defines no ``__iter__`` of its own.
        """
        size = len(self.x_input)
        if index >= size or index < -size:
            raise IndexError(
                f"index {index} is out of range for dataset of size {size}"
            )
        return (self.x_input[index], self.y_output[index])

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.x_input)
# True parameters of the linear model y = alpha + beta * x + noise.
true_alpha, true_beta = 1.5, 2.0
n_samples = 200
noise_std = 0.1

# Set the seed for reproducibility.
torch.manual_seed(420)

# Generate random inputs, then the noise term. The draw order (x first,
# then epsilon) is what fixes the seeded values, so keep it unchanged.
x = torch.randn(n_samples, dtype=torch.float)
epsilon = noise_std * torch.randn(n_samples, dtype=torch.float)
y = true_alpha + true_beta * x + epsilon

# Wrap the tensors in the dataset.
data = LinearData(x_input=x, y_output=y)

# Show the dataset size and the (x, y) sample at index 5. A bare expression
# is only echoed in an interactive session such as a notebook; print() also
# works when this file runs as a script.
print(len(data), data[5])