本节之前,我们主要关注NumPy中那些获取和操作数组数据的工具。本小节我们会介绍对NumPy数组进行排序的算法。这些算法在基础计算机科学领域是很热门的课题:如果你学习过相关的课程的话,你可能梦(或者根据你的经理,可能是噩梦)到过有关插入排序、选择排序、归并排序、快速排序、冒泡排序和其他很多很多名词。这些都是为了完成一件工作的:对数组进行排序。
例如,一个简单的选择排序会重复寻找列表中最小的值,然后和当前值进行交换,直到列表排序完成。我们可以在Python中用简单的几行代码完成这个算法:
import numpy as np
def selection_sort(x):
for i in range(len(x)):
swap = i + np.argmin(x[i:]) # 寻找子数组中的最小值的索引序号
(x[i], x[swap]) = (x[swap], x[i]) # 交换当前值和最小值
return x
x = np.array([2, 1, 4, 3, 5])
selection_sort(x)
array([1, 2, 3, 4, 5])
任何一个5年的计算机科学专业都会教你,选择排序很简单,但是对于大的数组来说运行效率就不够了。对于数组具有$N$个值,它需要$N$次循环,每次循环中需要$\sim N$次比较和寻找来交换元素。大O表示法经常用来对算法性能进行定量分析(参见大O复杂度),选择排序平均需要$\mathcal{O}[N^2]$:如果列表中的元素个数加倍,执行时间增长大约是原来的4倍。
甚至选择排序也远比下面这个bogo排序算法有效地多,这是作者最喜爱的排序算法:
def bogosort(x):
while np.any(x[:-1] > x[1:]):
np.random.shuffle(x)
return x
x = np.array([2, 1, 4, 3, 5])
bogosort(x)
array([1, 2, 3, 4, 5])
这个有趣而粗苯的算法完全依赖于概率:它重复的对数组进行随机的乱序直到结果刚好是正确排序为止。这个算法平均需要$\mathcal{O}[N \times N!]$,即N乘以N的阶乘,明显的,在真实情况下,它不应该被用于排序计算。
幸运的是,Python內建有了排序算法,比我们刚才提到那些简单的算法都要高效。我们从Python內建的排序开始介绍,然后再去讨论NumPy中为了数组优化的排序函数。
x = np.array([2, 1, 4, 3, 5])
np.sort(x)
array([1, 2, 3, 4, 5])
如果你期望直接改变数组的数据进行排序,你可以对数组对象使用它的sort
方法:
x.sort()
print(x)
[1 2 3 4 5]
相关的函数是argsort
,它将返回排好序后元素原始的序号序列:
x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x)
print(i)
[1 0 3 2 4]
结果的第一个元素是数组中最小元素的序号,第二个元素是数组中第二小元素的序号,以此类推。这些序号可以通过高级索引的方式使用,从而获得一个排好序的数组:
译者注:更好的问题应该是,假如我们希望获得数组中第二、三小的元素,我们可以这样做:
x[i[1:3]]
x[i]
array([1, 2, 3, 4, 5])
NumPy的排序算法可以沿着多维数组的某些轴axis
进行,如行或者列。例如:
rand = np.random.RandomState(42)
X = rand.randint(0, 10, (4, 6))
print(X)
[[6 3 7 4 6 9] [2 6 7 4 3 7] [7 2 5 4 1 7] [5 1 4 0 9 5]]
# 沿着每列对数据进行排序
np.sort(X, axis=0)
array([[2, 1, 4, 0, 1, 5], [5, 2, 5, 4, 3, 7], [6, 3, 7, 4, 6, 7], [7, 6, 7, 4, 9, 9]])
# 沿着每行对数据进行排序
np.sort(X, axis=1)
array([[3, 4, 6, 6, 7, 9], [2, 3, 4, 6, 7, 7], [1, 2, 4, 5, 7, 7], [0, 1, 4, 5, 5, 9]])
必须注意的是,这样的排序会独立的对每一行或者每一列进行排序。因此结果中原来行或列之间的联系都会丢失。
x = np.array([7, 2, 3, 1, 6, 5, 4])
np.partition(x, 3)
array([2, 1, 3, 4, 6, 5, 7])
你可以看到结果中最小的三个值在左边,其余4个值位于数组的右边,每个分区内部,元素的顺序是任意的。
和排序一样,我们可以按照任意维度对一个多维数组进行分区:
np.partition(X, 2, axis=1)
array([[3, 4, 6, 7, 6, 9], [2, 3, 4, 7, 6, 7], [1, 2, 4, 5, 7, 7], [0, 1, 4, 5, 9, 5]])
结果中每行的前两个元素就是该行最小的两个值,该行其余的值会出现在后面。
最后,就像np.argsort
函数可以返回排好序的元素序号一样,np.argpartition
可以计算分区后元素的序号。后面的例子中我们会看到它的使用。
X = rand.rand(10, 2)
我们先来观察一下这些点的分布情况,散点图很适合这种情形:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set() # 图表风格,seaborn
plt.scatter(X[:, 0], X[:, 1], s=100);
现在让我们来计算每两个点之间的距离。距离平方的定义是两点坐标差的平方和。应用广播(在数组上计算:广播)和聚合(聚合:Min, Max, 以及其他)函数,我们可以使用一行代码就能计算出所有点之间的距离平方:
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
上面的这行代码包含很多的内容值得探讨,如果对于不是特别熟悉广播机制的读者来说,看起来可能会让人难以理解。当你读到这样的代码的时候,将它们打散成一步步的操作会有帮助:
# 计算每两个点之间的坐标距离
differences = X[:, np.newaxis, :] - X[np.newaxis, :, :]
differences.shape
(10, 10, 2)
# 计算距离的平方
sq_differences = differences ** 2
sq_differences.shape
(10, 10, 2)
# 按照最后一个维度求和
dist_sq = sq_differences.sum(-1)
dist_sq.shape
(10, 10)
你可以检查这个矩阵的对角线元素,对角线元素的值是点与其自身的距离平方,应该全部为0:
dist_sq.diagonal()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
确认正确。现在我们已经有了一个距离平方的矩阵,然后就可以使用np.argsort
函数来按照每行来排序。最左边的列就会给出每个点的最近邻:
nearest = np.argsort(dist_sq, axis=1)
print(nearest)
[[0 3 9 7 1 4 2 5 6 8] [1 4 7 9 3 6 8 5 0 2] [2 1 4 6 3 0 8 9 7 5] [3 9 7 0 1 4 5 8 6 2] [4 1 8 5 6 7 9 3 0 2] [5 8 6 4 1 7 9 3 2 0] [6 8 5 4 1 7 9 3 2 0] [7 9 3 1 4 0 5 8 6 2] [8 5 6 4 1 7 9 3 2 0] [9 7 3 0 1 4 5 8 6 2]]
结果中的第一列是0到9的数字:这是因为距离每个点最近的是自己,正如我们预料的一样。
上面我们进行了完整的排序,事实上我们并不需要这么做。如果我们只是对最近的$K$个邻居感兴趣的话,我们可以使用分区来完成,只需要在距离平方矩阵中对每行进行$K+1$分区,只需要调用np.argpartition
函数即可:
K = 2
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)
为了展示最近邻的网络结构,我们在图中为每个点和它最近的两个点之间连上线:
plt.scatter(X[:, 0], X[:, 1], s=100)
# 为每个点和它最近的两个点之间连上线
K = 2
for i in range(X.shape[0]):
for j in nearest_partition[i, :K+1]:
# 从X[i]连线到X[j]
# 使用一些zip的魔术方法画线
plt.plot(*zip(X[j], X[i]), color='black')
图上的每个点都和与它最近的两个点相连。初看起来,你可能注意到有些点的连线可能超过2条,这很奇怪:实际原因是如果A是B的最近两个近邻之一,并不代表着B也必须是A的最近两个近邻之一。
虽然使用广播和逐行排序的方式完成任务可能没有使用循环来的直观,但是在Python中这是一种非常有效的方式。你可能忍不住使用循环的方式对每个点去计算它相应的最近邻,但是这种方式几乎肯定会比我们前面使用的向量化方案要慢很多。向量化的解法还有一个优点,那就是它不关心数据的尺寸:我们可以使用同样的代码和方法计算100个点或1,000,000个点以及任意维度数的数据的最近邻。
最后,需要说明的是,当对一个非常大的数据集进行最近邻搜索时,还有一种基于树或相似的算法能够将时间复杂度从$\mathcal{O}[N^2]$优化到$\mathcal{O}[N\log N]$或更好。其中一个例子是KD-Tree。
额外内容:大 O 复杂度
大O复杂度是一种衡量随着输入数据的增加,需要执行的操作的数量的量级情况的指标。要正确使用它,需要深入了解计算机科学的理论知识,要和其他相关的概念如小O复杂度,大$\theta$复杂度,大$\Omega$复杂度区分开来,更加不容易。虽然精确地描述出这些复杂度是属于算法的范畴,除了学院派计算机科学理论的测验和评分以外,你在其他应用领域很难看到这些严格的定义和划分。在数据科学领域中,我们不会使用这样死板的大O复杂度概念,虽然这和算法领域的概念在精确程度上有一定差距。带着对理论学者和学院派的歉意,本书将一直使用对大O复杂度的这种非精确概念解释。
大O复杂度,简单来说,会告诉你当你的数据增大时,你的算法运行需要的时间。例如你有一个$\mathcal{O}[N]$(英文读作"Order $N$")的算法,对于N=1000的数据量,它需要运行1秒,那么对于N=5000的数据量,算法需要执行的时间就为5秒。如果你的算法复杂度为$\mathcal{O}[N^2]$(英文读作"Order N squared"),对于N=1000的数据量需要运行1秒,那么你可以预期当数据量增长为N=5000时,运行时间为25秒。
对于我们的目标来说,N通常代表着数据集的大小(数据点的数量,维度数等)。当我们需要分析的数据样本量达到百万级或十亿级时,$\mathcal{O}[N]$和$\mathcal{O}[N^2]$之间的差距将会是巨大的。
请记住大O复杂度本身并不能告诉你实际上运算消耗的时间,它仅仅能够告诉你当N变化时,运行时间会怎样随之发生变化。通常来说,$\mathcal{O}[N]$复杂度的算法被认为肯定要比$\mathcal{O}[N^2]$复杂度的算法要好。但对于小的数据集来说,好的大O复杂度算法并不一定能带来更快的执行效率。例如,某个特定情况下,$\mathcal{O}[N^2]$复杂度的算法可能需要0.01秒的运行时间而$\mathcal{O}[N]$复杂度的算法可能需要1秒。但是如果将N增大1000倍,那么$\mathcal{O}[N]$复杂度的算法将会胜出。
我们这里使用的这种非严格定义的大O复杂度对于算法的性能也是有指示意义的,在本书的后续部分当我们讨论到算法范畴时都会应用到它。