Understanding the arange, unsqueeze, repeat, and stack methods in PyTorch

Posted by yaohong on Friday, July 30, 2021


  • torch.arange(start=0, end, step=1) returns a 1-D tensor of size ⌈(end - start) / step⌉. The values begin at start and increase by the common difference step, covering the half-open interval [start, end), so end itself is excluded.

  • torch.unsqueeze(input, dim) returns a new tensor with a dimension of size one inserted at the specified position. dim can be any value in the range [-input.dim() - 1, input.dim() + 1); a negative dim is interpreted as dim + input.dim() + 1.

  • tensor.repeat(*sizes) returns a new tensor in which the original is repeated sizes[i] times along dimension i, so each dimension of the new shape is the original size multiplied by the corresponding argument. The number of arguments must be at least tensor.dim(); if more arguments are given, the tensor is treated as if dimensions of size one were prepended to its shape.

  • torch.stack(tensors, dim=0) joins a sequence of tensors along a new dimension. All tensors must have the same shape.
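
The shape rules above can be checked directly. The following is a minimal sketch; the tensor names (r, t, a, b) are chosen here for illustration and are not part of the walkthrough that follows.

```python
import torch

# arange: size is ceil((end - start) / step); end itself is excluded.
r = torch.arange(1, 10, 3)          # values 1, 4, 7
print(r.shape)                      # torch.Size([3])

# unsqueeze: a negative dim counts from the end (dim=-1 appends a new last axis).
t = torch.arange(0, 3)              # shape [3]
print(t.unsqueeze(0).shape)         # torch.Size([1, 3])
print(t.unsqueeze(-1).shape)        # torch.Size([3, 1])

# repeat: with more arguments than dimensions, size-one axes are prepended first,
# so shape [3] is treated as [1, 3] and becomes [2 * 1, 2 * 3].
print(t.repeat(2, 2).shape)         # torch.Size([2, 6])

# stack: the inputs must share a shape; the result gains one new axis.
a = torch.zeros(3, 3)
b = torch.ones(3, 3)
print(torch.stack([a, b], dim=0).shape)   # torch.Size([2, 3, 3])
print(torch.stack([a, b], dim=2).shape)   # torch.Size([3, 3, 2])
```

Note that repeat copies data, so the result owns its own memory rather than being a view of the original tensor.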

import torch
output_size = 3

y0 = torch.arange(0, output_size)  # 1-D tensor [0, 1, 2]
print("y0.dim: ", y0.dim())
y1 = y0.unsqueeze(1)  # add a trailing axis: shape [3] -> [3, 1]
y2 = y1.repeat(1, output_size)  # like numpy.tile(): shape [3, 1] -> [3, 3]
print("y0: ", y0, "\n\ty1: ", y1, y1.shape, "\n\ty2: ", y2, y2.shape)
# Output:
#     y0.dim:  1
#     y0:  tensor([0, 1, 2]) 
#     y1:  tensor([[0],
#         [1],
#         [2]]) torch.Size([3, 1]) 
#     y2:  tensor([[0, 0, 0],
#         [1, 1, 1],
#         [2, 2, 2]]) torch.Size([3, 3])


x0 = torch.arange(0, output_size)  # 1-D tensor [0, 1, 2]
print("x0.dim: ", x0.dim())
x1 = x0.unsqueeze(0)  # add a leading axis: shape [3] -> [1, 3]
x2 = x1.repeat(output_size, 1)  # repeat the row: shape [1, 3] -> [3, 3]
print("x0: ", x0, "\n\tx1: ", x1, x1.shape, "\n\tx2: ", x2, x2.shape)
# Output:
#     x0.dim:  1
#     x0:  tensor([0, 1, 2]) 
#     x1:  tensor([[0, 1, 2]]) torch.Size([1, 3]) 
#     x2:  tensor([[0, 1, 2],
#         [0, 1, 2],
#         [0, 1, 2]]) torch.Size([3, 3])


# Stack along a new last axis so that grid_xy[i, j] holds the pair (x, y) = (j, i).
grid_xy = torch.stack([x2, y2], dim=2)
print("grid_xy: ", grid_xy, grid_xy.shape)
# Output:
# grid_xy:  tensor([[[0, 0],
#          [1, 0],
#          [2, 0]],

#         [[0, 1],
#          [1, 1],
#          [2, 1]],

#         [[0, 2],
#          [1, 2],
#          [2, 2]]]) torch.Size([3, 3, 2])

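# Insert size-one axes at positions 0 and 3: shape [3, 3, 2] -> [1, 3, 3, 1, 2]
grid_xy = grid_xy.unsqueeze(0).unsqueeze(3)
print("grid_xy: ", grid_xy.shape)
# Repeat 2 times along axis 0 and 3 times along axis 3, then convert to float
grid_xy = grid_xy.repeat(2, 1, 1, 3, 1).float()
print("grid_xy \t\trepeat:(2, 1, 1, 3, 1) ")
print("grid_xy: ", grid_xy.shape)
# Output: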
# grid_xy:  torch.Size([1, 3, 3, 1, 2])
# grid_xy         repeat:(2, 1, 1, 3, 1) 
# grid_xy:  torch.Size([2, 3, 3, 3, 2])
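
To see what the final repeat actually produced, the sketch below rebuilds the grid from scratch and checks two things: that each cell stores its (x, y) coordinate, and that every slice along the two repeated axes (0 and 3) is an identical copy of the 3 x 3 grid. The names idx and base are introduced here for illustration only.

```python
import torch

output_size = 3
idx = torch.arange(0, output_size)
x2 = idx.unsqueeze(0).repeat(output_size, 1)   # x2[i, j] == j
y2 = idx.unsqueeze(1).repeat(1, output_size)   # y2[i, j] == i
base = torch.stack([x2, y2], dim=2)            # shape [3, 3, 2], last axis is (x, y)

grid_xy = base.unsqueeze(0).unsqueeze(3).repeat(2, 1, 1, 3, 1).float()

# The cell at row 1, column 2 stores (x, y) = (2, 1).
print(grid_xy[0, 1, 2, 0])                     # tensor([2., 1.])

# Every slice along the repeated axes is an identical copy of the base grid.
for n in range(2):
    for k in range(3):
        assert torch.equal(grid_xy[n, :, :, k, :], base.float())
```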
