Tensor Index Slicing

As with NumPy arrays, elements of a tensor are accessed through indexing and slicing. The code below creates a 3x5 tensor that the following examples use to access values by their position.

import torch

x = torch.arange(15).reshape((3,5))
x
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
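To make the NumPy comparison concrete, here is a small sketch, assuming NumPy is installed alongside PyTorch, that builds the same grid as a NumPy array and indexes both objects with identical [row, column] syntax. The variable names a and x are illustrative only.

```python
import numpy as np
import torch

# Build the same 3x5 grid of values 0..14 in NumPy and in PyTorch
a = np.arange(15).reshape((3, 5))
x = torch.arange(15).reshape((3, 5))

# Both objects report the same shape
print(a.shape, tuple(x.shape))        # (3, 5) (3, 5)

# The same [row, column] index syntax selects the same values
print(a[1, 2], x[1, 2].item())        # 7 7
print(a[2, :].tolist() == x[2, :].tolist())  # True
```

The .item() and .tolist() calls only convert the PyTorch results to plain Python values so they can be compared directly with the NumPy ones.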

Accessing Individual Elements

To access an individual value by position, we pass its row index and column index, in that order. Because indexing starts at 0, the element in the second row and third column is accessed with the indices 1, 2:

x[1, 2]
tensor(7)
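Note that x[1, 2] returns a zero-dimensional tensor rather than a plain Python number. The sketch below shows how .item() converts it, and that chained indexing x[1][2] reaches the same element in two steps (first the row, then the entry within that row), at the cost of creating an intermediate row tensor.

```python
import torch

x = torch.arange(15).reshape((3, 5))

element = x[1, 2]           # zero-dimensional tensor: tensor(7)
print(element.dim())        # 0, i.e. a scalar tensor with no dimensions
print(element.item())       # 7, now a plain Python int

# Chained indexing: x[1] is the second row, and [2] picks its third entry
print(x[1][2].item())       # 7
```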

Accessing Entire Rows

Similarly, we can access an entire row, or a subset of a row, using the slice x[n, :], where $n$ is the index of the row of interest. The colon on its own selects every column.

For example, the code below returns the third row of the tensor.

x[2, :]
tensor([10, 11, 12, 13, 14])
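To take only a subset of a row, the colon can be given start:stop bounds. The sketch below keeps columns 1 to 3 of the third row; as with Python lists, the stop index is excluded. It also checks that leaving out the column slice entirely returns the same full row.

```python
import torch

x = torch.arange(15).reshape((3, 5))

# Columns 1, 2 and 3 of the third row (stop index 4 is excluded)
partial_row = x[2, 1:4]
print(partial_row)                  # tensor([11, 12, 13])

# x[2] and x[2, :] both return the whole third row
print(torch.equal(x[2], x[2, :]))   # True
```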

Accessing Entire Columns

Conversely, we can access an entire column, or a subset of a column, using the slice x[:, n], where $n$ is the index of the column of interest. Here the colon selects every row.

For example, the code below returns the fourth column of the tensor.

x[:, 3]
tensor([ 3, 8, 13])
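A subset of a column works the same way, with the bounds placed on the row axis instead. The sketch below keeps only the first two entries of the fourth column, then slices both axes at once to pull out a small 2-D block (rows 1 to 2, columns 3 to 4), which combines the row and column slicing shown above.

```python
import torch

x = torch.arange(15).reshape((3, 5))

# Rows 0 and 1 of the fourth column (stop index 2 is excluded)
partial_col = x[0:2, 3]
print(partial_col)          # tensor([3, 8])

# Slicing both axes returns a 2x2 block from the bottom-right corner
block = x[1:3, 3:5]
print(block.tolist())       # [[8, 9], [13, 14]]
```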

index_select()

PyTorch provides an index_select() function, which selects entries of a tensor along a given dimension using a predefined tensor of indices.

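The code below stores a set of indices in a tensor. We then select the corresponding rows of x by passing these indices, together with dim=0, to the index_select() function. The rows are returned in the order given by the indices, so row 1 appears before row 0.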

indices = torch.tensor([1, 0])
indices
tensor([1, 0])
subset = torch.index_select(x, dim=0, index=indices)
subset
tensor([[5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4]])
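The dim argument decides which axis the indices refer to. The sketch below reuses the same indices with dim=1 to pick columns instead of rows, and checks two equivalent spellings: indexing x directly with the index tensor (which selects along the first dimension, like dim=0) and the method form x.index_select(dim, index).

```python
import torch

x = torch.arange(15).reshape((3, 5))
indices = torch.tensor([1, 0])

# dim=1 applies the indices to columns: column 1 first, then column 0
cols = torch.index_select(x, dim=1, index=indices)
print(cols.tolist())        # [[1, 0], [6, 5], [11, 10]]

# Indexing with the index tensor directly matches dim=0
rows = torch.index_select(x, dim=0, index=indices)
print(torch.equal(x[indices], rows))                  # True

# The method form on the tensor gives the same result as the function
print(torch.equal(x.index_select(1, indices), cols))  # True
```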