Understanding Numpy expand_dims

Inserting 1 into the shape tuple based on the axis value

Posted by yaohong on Monday, November 30, 2020



A shape of (n,), where n is an integer, means the array has only one dimension.

The number of entries in the shape tuple is the number of dimensions (also available as the array's ndim attribute).
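To check this directly, here is a minimal sketch that compares the length of the shape tuple with ndim for a 1-D and a 2-D array. The names vec and mat are only illustrative.

```python
import numpy as np

# A 1-D array: its shape tuple has one entry, so ndim is 1.
vec = np.array([1, 2, 3, 4, 5])
print(vec.shape, len(vec.shape), vec.ndim)  # (5,) 1 1

# A 2-D array with 2 rows and 3 columns: two entries in the shape tuple.
mat = np.zeros((2, 3))
print(mat.shape, len(mat.shape), mat.ndim)  # (2, 3) 2 2
```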

import numpy as np
arr = np.array([1, 2, 3, 4, 5])
print("arr shape:", arr.shape)
print("arr:", arr)
# axis=0: the new length-1 axis goes in front, so (5,) becomes (1, 5)
arr2 = np.expand_dims(arr, 0)
print("expand_dims axis=0, shape:", arr2.shape)
print("expand_dims axis=0, arr2:", arr2)
# axis=1: the new length-1 axis goes after the existing one, so (5,) becomes (5, 1)
arr2 = np.expand_dims(arr, 1)
print("expand_dims axis=1, shape:", arr2.shape)
print("expand_dims axis=1, arr2:", arr2)

Output:

arr shape: (5,)
arr: [1 2 3 4 5]
expand_dims axis=0, shape: (1, 5)
expand_dims axis=0, arr2: [[1 2 3 4 5]]
expand_dims axis=1, shape: (5, 1)
expand_dims axis=1, arr2: [[1]
 [2]
 [3]
 [4]
 [5]]
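expand_dims is not the only way to get these shapes. As a sketch (variable names are illustrative), the same row and column vectors can be produced with np.newaxis indexing or with reshape. NumPy's documentation describes the result of expand_dims as a view of the input, so np.shares_memory reports that no data was copied.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])

# Three ways to add a leading axis of length 1 (a row vector).
row_a = np.expand_dims(arr, 0)
row_b = arr[np.newaxis, :]
row_c = arr.reshape(1, -1)

# Three ways to add a trailing axis of length 1 (a column vector).
col_a = np.expand_dims(arr, 1)
col_b = arr[:, np.newaxis]
col_c = arr.reshape(-1, 1)

print(row_a.shape, row_b.shape, row_c.shape)  # (1, 5) (1, 5) (1, 5)
print(col_a.shape, col_b.shape, col_c.shape)  # (5, 1) (5, 1) (5, 1)
print(np.array_equal(row_a, row_b) and np.array_equal(row_a, row_c))  # True
print(np.array_equal(col_a, col_b) and np.array_equal(col_a, col_c))  # True

# expand_dims returns a view: it shares memory with the original array.
print(np.shares_memory(arr, row_a))  # True
```

np.newaxis is simply an alias for None, so arr[None, :] works the same way.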

numpy.expand_dims

expand_dims inserts a new axis of length 1 into the shape tuple, at the position given by the axis value. All other entries of the shape keep their order.
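The rule can be written out as plain shape arithmetic. The helper inserted_shape below is hypothetical (it is not part of NumPy): it inserts a 1 into the shape tuple at position axis, counting negative axes from the end of the result. The loop compares its prediction with what np.expand_dims actually returns for every valid axis of a (1, 2, 2) array like the one used below.

```python
import numpy as np

def inserted_shape(shape, axis):
    """Hypothetical helper: predict the shape np.expand_dims returns
    by inserting a 1 into the shape tuple at position `axis`."""
    out_ndim = len(shape) + 1
    if not -out_ndim <= axis < out_ndim:
        raise ValueError(f"axis {axis} out of range for a {out_ndim}-D result")
    if axis < 0:
        axis += out_ndim  # negative axes count from the end of the *result*
    new_shape = list(shape)
    new_shape.insert(axis, 1)
    return tuple(new_shape)

arr = np.array([[[1, 2], [3, 4]]])  # shape (1, 2, 2)
for axis in range(-4, 4):
    predicted = inserted_shape(arr.shape, axis)
    actual = np.expand_dims(arr, axis).shape
    print(axis, predicted, actual, predicted == actual)
```

Every line of the loop should end in True. The normalization step matters: a plain list.insert(-1, 1) would put the 1 before the last entry, whereas expand_dims(arr, -1) appends it at the end.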

import numpy as np
arr = np.array([[[1, 2], [3, 4]]])  # shape (1, 2, 2)
print("arr shape:", arr.shape)
print("arr:", arr)
# axis=0: insert the new axis in front
arr2 = np.expand_dims(arr, 0)
print("expand_dims axis=0, shape:", arr2.shape)
print("expand_dims axis=0, arr2:", arr2)
# axis=1: insert the new axis after the first (already length-1) axis
arr2 = np.expand_dims(arr, 1)
print("expand_dims axis=1, shape:", arr2.shape)
print("expand_dims axis=1, arr2:", arr2)
# axis=2: insert the new axis between the two length-2 axes
arr2 = np.expand_dims(arr, 2)
print("expand_dims axis=2, shape:", arr2.shape)
print("expand_dims axis=2, arr2:", arr2)

Output:

arr shape: (1, 2, 2)
arr: [[[1 2]
  [3 4]]]
expand_dims axis=0, shape: (1, 1, 2, 2)
expand_dims axis=0, arr2: [[[[1 2]
   [3 4]]]]
expand_dims axis=1, shape: (1, 1, 2, 2)
expand_dims axis=1, arr2: [[[[1 2]
   [3 4]]]]
expand_dims axis=2, shape: (1, 2, 1, 2)
expand_dims axis=2, arr2: [[[[1 2]]

  [[3 4]]]]
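In the output above, axis=0 and axis=1 give the same result. This is because arr already starts with a length-1 axis, so putting the new 1 before it or right after it produces the same shape. The sketch below checks that, and also shows negative axes, a tuple of axes (supported since NumPy 1.18), and the equivalent None indexing for axis=2. Variable names are illustrative.

```python
import numpy as np

arr = np.array([[[1, 2], [3, 4]]])  # shape (1, 2, 2)

# arr already starts with a length-1 axis, so inserting a new one before it
# (axis=0) or right after it (axis=1) yields the same shape and the same values.
a0 = np.expand_dims(arr, 0)
a1 = np.expand_dims(arr, 1)
print(a0.shape, a1.shape, np.array_equal(a0, a1))  # (1, 1, 2, 2) (1, 1, 2, 2) True

# Negative axes count from the end of the result: -1 appends a trailing axis.
print(np.expand_dims(arr, -1).shape)  # (1, 2, 2, 1)

# With a tuple, each listed position of the result holds a new length-1 axis.
print(np.expand_dims(arr, (0, 4)).shape)  # (1, 1, 2, 2, 1)

# Indexing with None (np.newaxis) matches np.expand_dims(arr, 2).
print(arr[:, :, None, :].shape)  # (1, 2, 1, 2)
```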
