import numpy as np
# Demo matrix: 4 rows x 3 columns.
arr = np.array([
    [1, 10, 3],    # arr[0]
    [6, 5, 11],    # arr[1]
    [7, 8, 9],     # arr[2]
    [12, 2, 4],    # arr[3]
])

# axis=0 reduces over the rows: for each column, the row index holding its maximum.
col_argmax = np.argmax(arr, axis=0)
# axis=1 reduces over the columns: for each row, the column index holding its maximum.
row_argmax = np.argmax(arr, axis=1)

print('axis=0', col_argmax)
print('axis=1', row_argmax)
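# Expected output (axis=0 gives, per column, the row index of its maximum;
# axis=1 gives, per row, the column index of its maximum):
# axis=0 [3 0 1]
# axis=1 [1 2 2 0]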