Dataset Class

PyTorch provides the `Dataset` class as the primary abstraction for storing and accessing data, for both pre-loaded and external datasets. To create a custom dataset class, we subclass `Dataset` and define two methods:

  1. `__getitem__`: retrieves a single data sample by index.
  2. `__len__`: returns the number of samples in the dataset.

Here is an example of a simple LinearData class.

import torch
from torch.utils.data import Dataset, DataLoader, random_split

class LinearData(Dataset):
    """Map-style dataset pairing each input value with its target value.

    Args:
        x_input: Indexable sequence of inputs (e.g. a 1-D tensor).
        y_output: Indexable sequence of targets; must have the same length
            as ``x_input``.

    Raises:
        ValueError: If ``x_input`` and ``y_output`` differ in length.
    """

    def __init__(self, x_input, y_output):
        # Fail fast: with mismatched lengths, extra targets would be silently
        # ignored (len() follows x_input) or indexing would fail much later.
        if len(x_input) != len(y_output):
            raise ValueError(
                "x_input and y_output must have the same length, "
                f"got {len(x_input)} and {len(y_output)}"
            )
        self.x_input = x_input
        self.y_output = y_output

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``.

        Raises:
            IndexError: If ``index`` is out of range. Raising (instead of
                returning a sentinel value) is what lets plain iteration over
                the dataset -- which relies on the sequence protocol, since
                ``Dataset`` defines no ``__iter__`` -- terminate.
        """
        if index >= len(self.x_input):
            raise IndexError(
                f"index {index} is out of range for dataset of size {len(self)}"
            )
        return (self.x_input[index], self.y_output[index])

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.x_input)
# Ground-truth parameters of the line y = alpha + beta * x (plus noise).
true_alpha = 1.5
true_beta = 2.0

# Fix the global RNG seed so the generated samples are reproducible.
torch.manual_seed(420)

# Draw the 200 inputs first, then 200 noise terms scaled by 0.1; keeping this
# draw order is what makes the values reproducible under the fixed seed.
x = torch.randn(200, dtype=torch.float)
epsilon = 0.1 * torch.randn(200, dtype=torch.float)
y = true_alpha + true_beta * x + epsilon

# Wrap the generated tensors in the custom Dataset.
data = LinearData(x_input=x, y_output=y)

# Inspect the dataset size together with the (x, y) pair at index 5.
(len(data), data[5])
(200, (tensor(0.3826), tensor(2.3138)))