< 比较,遮盖和布尔逻辑 | 目录 | 数组排序 >
在前面的小节中,我们学习了如何获取和修改数组的元素或部分元素,我们可以通过简单索引(例如arr[0]
),切片(例如arr[:5]
)和布尔遮盖(例如arr[arr > 0]
)来实现。本节来介绍另外一种数组索引的方式,被称为高级索引。高级索引语法上和前面我们学习到的简单索引很像,区别只是它不是传递标量参数作为索引值,而是传递数组参数作为索引值。它能让我们很迅速的获取和修改复杂数组或子数组的元素值。
import numpy as np
rand = np.random.RandomState(42)
x = rand.randint(100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
假如我们需要访问其中三个不同的元素。我们可以这样做:
[x[3], x[7], x[2]]
[71, 86, 14]
还有一种方法,我们以一个数组的方式将这些元素的索引传递给数组,也可以获得相同的结果:
ind = [3, 7, 4]
x[ind]
array([71, 86, 60])
当使用高级索引时,结果数组的形状取决于索引数组的形状而不是被索引数组的形状:
ind = np.array([[3, 7],
[4, 5]]) # 索引数组是一个2x2数组,结果也将会是一个2x2数组
x[ind]
array([[71, 86], [60, 20]])
高级索引也支持多维数组。例如:
X = np.arange(12).reshape((3, 4))
X
array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]])
就像普通索引一样,第一个参数代表行,第二个参数代表列:
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
array([ 2, 5, 11])
结果中的第一个值是x[0, 2]
,第二个值是x[1, 1]
,第三个值是x[2, 3]
。高级索引的多个维度组合方式也遵守广播的规则,请查阅在数组上计算:广播。因此,如果我们在上面的行索引数组中增加一个维度,结果将变成一个二维数组:
X[row[:, np.newaxis], col]
array([[ 2, 1, 3], [ 6, 5, 7], [10, 9, 11]])
这里,每个行索引都会匹配每个列的向量,就像我们在广播的算术运算中看到一样。例如:
row[:, np.newaxis] * col
array([[0, 0, 0], [2, 1, 3], [4, 2, 6]])
记住高级索引结果的形状是索引数组广播后的形状而不是被索引数组形状,这点非常重要。
print(X)
[[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11]]
我们可以将高级索引和简单索引进行组合:
译者注,实际上这就是个广播,将标量广播成一个向量。
X[2, [2, 0, 1]]
array([10, 8, 9])
我们也可以将高级索引和切片进行组合:
X[1:, [2, 0, 1]]
array([[ 6, 4, 5], [10, 8, 9]])
还可以将高级索引和遮盖进行组合:
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask]
array([[ 0, 2], [ 4, 6], [ 8, 10]])
所有这些索引操作可以提供用户非常灵活的方式来获取和修改数组中的数据。
mean = [0, 0]
cov = [[1, 2],
[2, 5]]
X = rand.multivariate_normal(mean, cov, 100)
X.shape
(100, 2)
使用我们会在第四章详细介绍的Matplotlib工具,我们可以在散点图上绘制这些点:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set() # 设置图表风格,seaborn
plt.scatter(X[:, 0], X[:, 1]);
下面我们使用高级索引来选择20个随机点。方法是先创建一个索引数组,里面的索引值是没有重复的,然后使用这个索引数组来选择点:
indices = np.random.choice(X.shape[0], 20, replace=False)
indices
array([66, 38, 68, 88, 94, 50, 73, 69, 95, 31, 89, 39, 20, 85, 34, 49, 48, 96, 29, 44])
selection = X[indices] # 使用高级索引
selection.shape
(20, 2)
下面我们来看看那些点被选中,让我们上图的基础上将选中的点圈出来:
plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
plt.scatter(selection[:, 0], selection[:, 1],
facecolor='none', s=200);
这种策略经常用来划分数据集,比如用来验证统计模型正确性时需要的训练集和测试集划分(参见超参数及模型验证),还有就是在回答统计问题时进行取样抽象。
x = np.arange(10)
i = np.array([2, 1, 8, 4])
x[i] = 99
print(x)
[ 0 99 99 3 99 5 6 7 99 9]
我们可以使用任何赋值类型操作,例如:
x[i] -= 10
print(x)
[ 0 89 89 3 89 5 6 7 89 9]
请注意下,如果索引数组中有重复的元素的话,这种修改操作可能会导致一个潜在的意料之外的结果。例如:
x = np.zeros(10)
x[[0, 0]] = [4, 6]
print(x)
[6. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
4跑到哪里去了呢?这个操作首先赋值x[0] = 4
,然后赋值x[0] = 6
,因此最后x[0]
的值是6。
上面的例子还算比较清晰,再看下面这个操作:
i = [2, 3, 3, 4, 4, 4]
x[i] += 1
x
array([6., 0., 1., 1., 1., 0., 0., 0., 0., 0.])
我们期望的结果可能是x[3]
的值是2,而x[4]
的值是3,因为这两个元素都多次执行了加法操作。但是为何结果不是呢?这是因为x[i] += 1
是操作x[i] = x[i] + 1
的简写,而x[i] + 1
表达式的值已经计算好了,然后才被赋值给x[i]
。因此,上面的操作不会被扩展为重复的运算,而是一次的赋值操作,造成了这种难以理解的结果。
如果我们真的需要这种重复的操作怎么办?对此,NumPy(版本1.8以上)提供了at()
ufunc方法可以满足这个目的,如下:
x = np.zeros(10)
np.add.at(x, i, 1)
print(x)
[0. 0. 1. 2. 3. 0. 0. 0. 0. 0.]
at()
方法不会预先计算表达式的值,而是每次运算时实时得到,方法在一个数组x
中取得特定索引i
,然后将其取得的值与最后一个参数1
进行相应计算,这里是加法add
。还有一个类似的方法是reduceat()
方法,你可以从NumPy的文档中阅读它的说明。
np.random.seed(42)
x = np.random.randn(100) # 获得一个一维100个标准正态分布值
# 得到一个自定义的数据分组,区间-5至5平均取20个点,每个区间为一个数据分组
bins = np.linspace(-5, 5, 20)
counts = np.zeros_like(bins) # counts是x数值落入区间的计数
# 使用searchsorted,得到x每个元素在bins中落入的区间序号
i = np.searchsorted(bins, x)
# 使用at和add,对x元素在每个区间的元素个数进行计算
np.add.at(counts, i, 1)
counts现在包含着每个数据分组中元素的个数,换句话来说,就是直方图:
译者注:Matplotlib 3.1开始,linestyle关键字参数已经过时,后续版本会抛弃。下面代码依据最新参数更改为drawstyle或ds。
# 用图表展示结果
plt.plot(bins, counts, ds='steps');
当然,如果每次要画直方图的时候,都要经过这么复杂的计算,很不方便。这也就是为什么Matplotlib提供了plt.hist()
方法的原因,可以用一行代码完成上面操作:
plt.hist(x, bins, histtype='step');
这个函数会创建一个和上图基本完全一样的图形。Matplotlib使用np.histogram
函数来计算数据分组,这个函数进行的计算和我们上面的代码非常接近。我们比较一下这两个方法:
print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine: 22.1 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) Custom routine: 16 µs ± 609 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
我们自己写的一行代码比NumPy优化的算法要快出许多,这是因为什么?如果你深入到np.histogram
函数的源代码进行阅读(你可以通过在IPython中输入np.histogram??
来查阅)的时候,你会发现函数除了搜索和计数之外,还做了其他很多工作;这是因为NumPy的函数要更加灵活,而且当数据量变大的时候能够提供更好的性能:
x = np.random.randn(1000000)
print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine: 80 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) Custom routine: 121 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
上面的结果说明当涉及到算法的性能时,永远不可能是一个简单的问题。对于大数据集来说一个很高效的算法,并不一定也适用于小数据集,反之亦然(参见大O复杂度)。我们这里使用自己的代码实现这个算法,目的是理解上面的基本函数,后续读者可以使用这些函数构建自己定义的各种功能。在数据科学应用中使用Python编写代码的关键在于,你能掌握NumPy提供的很方便的函数如np.histogram
,你也能知道什么情况下适合使用它们,当需要更加定制的功能时你还能使用底层的函数自己实现相应的算法。
< 比较,遮盖和布尔逻辑 | 目录 | 数组排序 >