本章将学习以下内容:
- 学习使用OpenCV中的
cv2.kmeans()
函数进行数据聚类
参数详解
输入参数说明
samples
(样本数据):数据类型必须为np.float32
,每个特征应单独作为一列排列。nclusters(K)
(聚类数量):最终需要获得的簇(聚类)数量。criteria
(终止条件):用于设定算法迭代的停止条件,当满足该条件时停止计算,具体是由3个参数组成的元组,格式为 (type
,max_iter
,epsilon
):type
(终止类型)、max_iter
(最大迭代次数)、epsilon
(精度阈值):
3.a 终止条件类型(type of termination criteria)包含以下3种标志位:
cv2.TERM_CRITERIA_EPS
- 当达到指定精度(epsilon)时停止算法迭代
cv2.TERM_CRITERIA_MAX_ITER
- 当达到指定迭代次数(max_iter)时停止算法
cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER
- 满足上述任一条件即停止迭代
- 当达到指定精度(epsilon)时停止算法迭代
3.b 最大迭代次数(
max_iter
):整数型,指定最大迭代次数3.c 精度要求(epsilon):要求的收敛精度阈值
attempts
(尝试次数): 用于指定算法使用不同初始标签执行的次数。算法将返回具有最佳紧密度(compactness)的聚类结果,该紧密度值将作为输出返回。flags
(中心点初始化标志): 用于指定初始聚类中心的选取方式,常用两种标志位:cv2.KMEANS_PP_CENTERS
,基于k-means++算法智能初始化中心点;cv2.KMEANS_RANDOM_CENTERS
:随机初始化中心点。
输出参数说明
compactness
(紧密度): 表示各数据点到其所属聚类中心的平方距离总和labels
(标签数组):与之前文章中'code'含义相同,每个元素被标记为'0','1'...等,对应数据点所属的簇编号。centers
(聚类中心数组): 包含所有聚类中心坐标的数组。
下面将通过三个实际案例演示 K-Means 聚类算法的应用:
案例1:单特征数据聚类分析
典型应用场景:以服装尺码预测系统为例,通过顾客身高数据(单一维度特征)自动划分T恤尺码类别:
接下来将通过以下步骤创建数据集并进行可视化分析:
%matplotlib inline
import numpy as np
import cv2 as cv
from matplotlib import pyplot as plt
x = np.random.randint(25,100,25)
y = np.random.randint(175,255,25)
z = np.hstack((x,y))
z = z.reshape((50,1))
z = np.float32(z)
plt.hist(z,256,[0,256]),plt.show()
/tmp/ipykernel_3845/3580425761.py:11: MatplotlibDeprecationWarning: Passing the range parameter of hist() positionally is deprecated since Matplotlib 3.9; the parameter will become keyword-only in 3.11. plt.hist(z,256,[0,256]),plt.show()
((array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 1., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 2., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 2., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 2., 0., 0., 3., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., 253., 254., 255., 256.]), <BarContainer object of 256 artists>), None)
因此得到了一个包含50个元素、数值范围在0到255之间的数组'z',并将其重塑为列向量格式。
这种数据结构在存在多个特征时更具扩展性。
随后将数据转换为np.float32
类型,最终获得如下处理结果图像:
在应用K-Means算法前,需要明确定义终止条件($criteria$)。 根据需求将配置如下停止准则:
# Define criteria = ( type, max_iter = 10 , epsilon = 1.0 )
criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 1.0)
# Set flags (Just to avoid line break in the code)
flags = cv.KMEANS_RANDOM_CENTERS
# Apply KMeans
compactness,labels,centers = cv.kmeans(z,2,None,criteria,10,flags)
这将返回三个关键结果:紧密度(compactness)、标签(labels)和中心点(centers)。 在本例中,得到的中心点值为60和207。 标签数组的大小与测试数据相同,每个数据点会根据其所属的中心点被标记为'0'、'1'、'2'等。 接下来,根据这些标签将数据划分到不同的簇中。
A = z[labels==0]
B = z[labels==1]
现在将数据可视化呈现为簇A数据点使用红色标注,簇B数据点使用蓝色标注,对应的聚类中心点用黄色星形标记。
# Now plot 'A' in red, 'B' in blue, 'centers' in yellow
plt.hist(A,256,[0,256],color = 'r')
plt.hist(B,256,[0,256],color = 'b')
plt.hist(centers,32,[0,256],color = 'y')
plt.show()
/tmp/ipykernel_3845/15340111.py:2: MatplotlibDeprecationWarning: Passing the range parameter of hist() positionally is deprecated since Matplotlib 3.9; the parameter will become keyword-only in 3.11. plt.hist(A,256,[0,256],color = 'r') /tmp/ipykernel_3845/15340111.py:3: MatplotlibDeprecationWarning: Passing the range parameter of hist() positionally is deprecated since Matplotlib 3.9; the parameter will become keyword-only in 3.11. plt.hist(B,256,[0,256],color = 'b') /tmp/ipykernel_3845/15340111.py:4: MatplotlibDeprecationWarning: Passing the range parameter of hist() positionally is deprecated since Matplotlib 3.9; the parameter will become keyword-only in 3.11. plt.hist(centers,32,[0,256],color = 'y')
import numpy as np
import cv2 as cv
from matplotlib import pyplot as plt
X = np.random.randint(25,50,(25,2))
Y = np.random.randint(60,85,(25,2))
Z = np.vstack((X,Y))
# convert to np.float32
Z = np.float32(Z)
# define criteria and apply kmeans()
criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 1.0)
ret,label,center=cv.kmeans(Z,2,None,criteria,10,cv.KMEANS_RANDOM_CENTERS)
# Now separate the data, Note the flatten()
A = Z[label.ravel()==0]
B = Z[label.ravel()==1]
# Plot the data
plt.scatter(A[:,0],A[:,1])
plt.scatter(B[:,0],B[:,1],c = 'r')
plt.scatter(center[:,0],center[:,1],s = 80,c = 'y', marker = 's')
plt.xlabel('Height'),plt.ylabel('Weight')
plt.show()
import numpy as np
import cv2 as cv
img = cv.imread('/data/cvdata/home.jpg')
Z = img.reshape((-1,3))
# convert to np.float32
Z = np.float32(Z)
# define criteria, number of clusters(K) and apply kmeans()
criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 1.0)
K = 8
ret,label,center=cv.kmeans(Z,K,None,criteria,10,cv.KMEANS_RANDOM_CENTERS)
# Now convert back into uint8, and make original image
center = np.uint8(center)
res = center[label.flatten()]
res2 = res.reshape((img.shape))
plt.imshow(res2)
# cv.imshow('res2',res2)
# cv.waitKey(0)
# cv.destroyAllWindows()
<matplotlib.image.AxesImage at 0x7f40bcfbaea0>
K=8
的结果如下: