Understanding Numpy expand_dims
Inserting 1 into the shape brackets based on the axis value
A shape of (n,), where n is a number, means the array has only one dimension.
The number of values inside the shape brackets is the number of dimensions.
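A quick way to check this is to compare an array's ndim attribute with the length of its shape tuple. This is a minimal sketch; the arrays a, b and c are only illustrative.

```python
import numpy as np

# Illustrative arrays with 1, 2 and 3 dimensions
a = np.array([1, 2, 3])      # shape (3,)      -> 1 dimension
b = np.array([[1, 2, 3]])    # shape (1, 3)    -> 2 dimensions
c = np.zeros((2, 3, 4))      # shape (2, 3, 4) -> 3 dimensions

# ndim always equals the number of entries in the shape tuple
for name, x in (("a", a), ("b", b), ("c", c)):
    print(name, x.shape, x.ndim, len(x.shape) == x.ndim)
```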
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
print("arr shape:", arr.shape)
print("arr:", arr)

# axis=0 puts the new 1 in front of the existing length: (5,) -> (1, 5)
arr2 = np.expand_dims(arr, 0)
print("expand_dims axis=0, shape:", arr2.shape)
print("expand_dims axis=0, arr2:", arr2)

# axis=1 puts the new 1 after the existing length: (5,) -> (5, 1)
arr2 = np.expand_dims(arr, 1)
print("expand_dims axis=1, shape:", arr2.shape)
print("expand_dims axis=1, arr2:", arr2)

Output:

arr shape: (5,)
arr: [1 2 3 4 5]
expand_dims axis=0, shape: (1, 5)
expand_dims axis=0, arr2: [[1 2 3 4 5]]
expand_dims axis=1, shape: (5, 1)
expand_dims axis=1, arr2: [[1]
 [2]
 [3]
 [4]
 [5]]
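The rule "insert a 1 into the shape at position axis" can be written out by hand and checked against np.expand_dims. In the sketch below, expected_shape is a hypothetical helper written only for this illustration and is not part of NumPy. One assumption to note is that a negative axis counts from the end of the result, which has one more dimension than the input. The tuple-of-axes call at the end assumes NumPy 1.18 or later.

```python
import numpy as np

def expected_shape(shape, axis):
    """Hypothetical helper: place a 1 into `shape` at position `axis`.
    A negative axis counts from the end of the result, which has
    len(shape) + 1 entries."""
    new_ndim = len(shape) + 1
    if not -new_ndim <= axis < new_ndim:
        raise ValueError(f"axis {axis} out of range for a {new_ndim}-D result")
    if axis < 0:
        axis += new_ndim  # e.g. -1 becomes the last slot of the result
    result = list(shape)
    result.insert(axis, 1)
    return tuple(result)

arr = np.array([1, 2, 3, 4, 5])  # shape (5,)
for axis in (0, 1, -1, -2):
    got = np.expand_dims(arr, axis).shape
    print(axis, got, got == expected_shape(arr.shape, axis))

# A 2-D input of shape (2, 3) has three slots where the new 1 can go
m = np.zeros((2, 3))
for axis in (0, 1, 2):
    print(axis, np.expand_dims(m, axis).shape)

# NumPy 1.18+ also accepts a tuple of axes; each listed position gets a 1
print(np.expand_dims(arr, (0, 2)).shape)  # (1, 5, 1)
```

With a 1-D input, axis=-1 behaves like axis=1 and axis=-2 like axis=0, because the result has two dimensions.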