Understanding the arange, unsqueeze, repeat, and stack methods in PyTorch
- torch.arange(start=0, end, step=1): returns a 1-D tensor of size ceil((end - start) / step). The values begin at start and increase by the common difference step; end itself is excluded.
- torch.unsqueeze(input, dim): returns a new tensor with a dimension of size one inserted at the specified position. Any dim value in the range [-input.dim() - 1, input.dim() + 1) can be used.
- tensor.repeat(*sizes): returns a new tensor whose shape is the original shape multiplied by the arguments, dimension by dimension. The number of arguments must be at least the number of dimensions of the tensor. If there are more arguments than dimensions, the original shape is first padded with leading dimensions of size one, so the last argument always multiplies the last dimension.
- torch.stack(tensors, dim): concatenates a sequence of tensors along a new dimension. All tensors must have the same shape.
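The shape rules in the list above can be checked with a few small tensors. This is a minimal sketch; the variable names (a, v, m, r, s0, s1) and the example values are chosen here for illustration only.

```python
import math
import torch

# arange: end is excluded, so the length is ceil((end - start) / step).
a = torch.arange(1, 10, 4)
print(a)  # tensor([1, 5, 9])
print(a.numel() == math.ceil((10 - 1) / 4))  # True

# unsqueeze: a negative dim counts from the end; -1 appends a trailing axis.
v = torch.arange(3)
print(v.unsqueeze(0).shape)   # torch.Size([1, 3])
print(v.unsqueeze(-1).shape)  # torch.Size([3, 1])

# repeat: with more arguments than dimensions, the shape [2, 3] is treated
# as [1, 2, 3] before being multiplied by (4, 1, 2).
m = torch.zeros(2, 3)
r = m.repeat(4, 1, 2)
print(r.shape)  # torch.Size([4, 2, 6])

# stack: every input has shape [3]; dim picks where the new axis goes.
s0 = torch.stack([v, v + 10], dim=0)
s1 = torch.stack([v, v + 10], dim=1)
print(s0.shape, s1.shape)  # torch.Size([2, 3]) torch.Size([3, 2])
```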
import torch
output_size = 3
y0 = torch.arange(0, output_size)
print("y0.dim: ", y0.dim())
y1 = y0.unsqueeze(1)  # add a column axis: shape [3] -> [3, 1]
y2 = y1.repeat(1, output_size)  # tile along the columns, like numpy.tile()
print("y0: ", y0, "\n\ty1: ", y1, y1.shape, "\n\ty2: ", y2, y2.shape)
# Output:
# y0.dim: 1
# y0: tensor([0, 1, 2])
# y1: tensor([[0],
#         [1],
#         [2]]) torch.Size([3, 1])
# y2: tensor([[0, 0, 0],
#         [1, 1, 1],
#         [2, 2, 2]]) torch.Size([3, 3])
x0 = torch.arange(0, output_size)
print("x0.dim: ", x0.dim())
x1 = x0.unsqueeze(0)  # add a row axis: shape [3] -> [1, 3]
x2 = x1.repeat(output_size, 1)  # tile along the rows
print("x0: ", x0, "\n\tx1: ", x1, x1.shape, "\n\tx2: ", x2, x2.shape)
# Output:
# x0.dim: 1
# x0: tensor([0, 1, 2])
# x1: tensor([[0, 1, 2]]) torch.Size([1, 3])
# x2: tensor([[0, 1, 2],
#         [0, 1, 2],
#         [0, 1, 2]]) torch.Size([3, 3])
grid_xy = torch.stack([x2, y2], dim=2)  # pair up x and y on a new last axis
print("grid_xy: ", grid_xy, grid_xy.shape)
# Output:
# grid_xy: tensor([[[0, 0],
#          [1, 0],
#          [2, 0]],
#         [[0, 1],
#          [1, 1],
#          [2, 1]],
#         [[0, 2],
#          [1, 2],
#          [2, 2]]]) torch.Size([3, 3, 2])
grid_xy = grid_xy.unsqueeze(0).unsqueeze(3)  # shape [3, 3, 2] -> [1, 3, 3, 1, 2]
print("grid_xy: ", grid_xy.shape)
grid_xy = grid_xy.repeat(2, 1, 1, 3, 1).float()  # copy along axes 0 and 3
print("grid_xy \t\trepeat:(2, 1, 1, 3, 1) ")
print("grid_xy: ", grid_xy.shape)
# Output:
# grid_xy: torch.Size([1, 3, 3, 1, 2])
# grid_xy repeat:(2, 1, 1, 3, 1)
# grid_xy: torch.Size([2, 3, 3, 3, 2])
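To see what the stacked and repeated grid actually contains, the following sketch rebuilds it with the same calls and checks two properties: each cell at [row, col] holds the pair (x, y) = (col, row), and every copy made by repeat along axes 0 and 3 equals the original grid. The names coords, tiled, b, and k are introduced here for illustration and are not part of the code above.

```python
import torch

output_size = 3
coords = torch.arange(0, output_size)

# Same construction as above: x varies along columns, y varies along rows.
x2 = coords.unsqueeze(0).repeat(output_size, 1)
y2 = coords.unsqueeze(1).repeat(1, output_size)
grid_xy = torch.stack([x2, y2], dim=2)  # shape [3, 3, 2]

# Each cell stores its own (x, y) coordinate, i.e. (col, row).
for row in range(output_size):
    for col in range(output_size):
        assert grid_xy[row, col].tolist() == [col, row]

# Insert size-one axes at positions 0 and 3, then copy along them.
tiled = grid_xy.unsqueeze(0).unsqueeze(3).repeat(2, 1, 1, 3, 1).float()
print(tiled.shape)  # torch.Size([2, 3, 3, 3, 2])

# Every slice along the two repeated axes is the original grid.
for b in range(2):
    for k in range(3):
        assert torch.equal(tiled[b, :, :, k, :], grid_xy.float())
```

Because repeat copies data, the result uses 2 * 3 times the memory of the original grid.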