Understanding NumPy expand_dims
Shape (n,)
A shape of (n,), where n is a number, means the array has only one dimension.
The number of values in the shape tuple is the number of dimensions of the array.
import numpy as np
# 1-D array: shape (5,)
arr = np.array([1, 2, 3, 4, 5])
print("arr shape:", arr.shape)
print("arr:", arr)
# axis=0 puts the new length-1 axis first: (5,) -> (1, 5)
arr2 = np.expand_dims(arr, 0)
print("expand_dims axis=0, shape:", arr2.shape)
print("expand_dims axis=0, arr2:", arr2)
# axis=1 puts the new length-1 axis second: (5,) -> (5, 1)
arr2 = np.expand_dims(arr, 1)
print("expand_dims axis=1, shape:", arr2.shape)
print("expand_dims axis=1, arr2:", arr2)
Output:
arr shape: (5,)
arr: [1 2 3 4 5]
expand_dims axis=0, shape: (1, 5)
expand_dims axis=0, arr2: [[1 2 3 4 5]]
expand_dims axis=1, shape: (5, 1)
expand_dims axis=1, arr2: [[1]
 [2]
 [3]
 [4]
 [5]]
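The link between the shape tuple and the number of dimensions can be checked directly. The following is a small sketch; the names row and col are illustrative only and are not part of the example above. It compares len(a.shape) with a.ndim before and after expand_dims.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
row = np.expand_dims(arr, 0)  # illustrative name: shape (1, 5)
col = np.expand_dims(arr, 1)  # illustrative name: shape (5, 1)

for name, a in [("arr", arr), ("row", row), ("col", col)]:
    # The length of the shape tuple always equals ndim.
    print(name, a.shape, len(a.shape), a.ndim)

# expand_dims adds an axis of length 1, so the element count is unchanged.
print(arr.size, row.size, col.size)
```

Each expand_dims call adds exactly one entry to the shape tuple, so ndim goes from 1 to 2 while the number of elements stays at 5.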
numpy.expand_dims
expand_dims
It works like inserting a 1 into the shape tuple at the position given by the axis value.
import numpy as np
# 3-D array: shape (1, 2, 2)
arr = np.array([[[1, 2], [3, 4]]])
print("arr shape:", arr.shape)
print("arr:", arr)
# axis=0: (1, 2, 2) -> (1, 1, 2, 2)
arr2 = np.expand_dims(arr, 0)
print("expand_dims axis=0, shape:", arr2.shape)
print("expand_dims axis=0, arr2:", arr2)
# axis=1: (1, 2, 2) -> (1, 1, 2, 2), the same shape as axis=0 because the original first axis is already 1
arr2 = np.expand_dims(arr, 1)
print("expand_dims axis=1, shape:", arr2.shape)
print("expand_dims axis=1, arr2:", arr2)
# axis=2: (1, 2, 2) -> (1, 2, 1, 2)
arr2 = np.expand_dims(arr, 2)
print("expand_dims axis=2, shape:", arr2.shape)
print("expand_dims axis=2, arr2:", arr2)
Output:
arr shape: (1, 2, 2)
arr: [[[1 2]
  [3 4]]]
expand_dims axis=0, shape: (1, 1, 2, 2)
expand_dims axis=0, arr2: [[[[1 2]
   [3 4]]]]
expand_dims axis=1, shape: (1, 1, 2, 2)
expand_dims axis=1, arr2: [[[[1 2]
   [3 4]]]]
expand_dims axis=2, shape: (1, 2, 1, 2)
expand_dims axis=2, arr2: [[[[1 2]]

  [[3 4]]]]
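The "insert a 1 into the shape tuple" reading can be tested with a short sketch. The helper shape_after_expand below is hypothetical (it is not a NumPy function) and only handles non-negative axis values. It predicts the new shape with a plain list insert and compares the prediction with what np.expand_dims returns.

```python
import numpy as np

def shape_after_expand(shape, axis):
    """Hypothetical helper: predict the shape by inserting 1 at position `axis`.

    Assumes a non-negative axis in the range 0..len(shape).
    """
    new_shape = list(shape)
    new_shape.insert(axis, 1)
    return tuple(new_shape)

arr = np.array([[[1, 2], [3, 4]]])  # shape (1, 2, 2)

for axis in range(arr.ndim + 1):
    predicted = shape_after_expand(arr.shape, axis)
    actual = np.expand_dims(arr, axis).shape
    print(axis, predicted, actual)
    assert predicted == actual
```

For this array the loop covers axis values 0 through 3, and the predicted and actual shapes agree for each one, which is consistent with the description above.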
