0%

机器学习之分类算法:KNN(K-Nearest Neighbors)K近邻算法

这里介绍机器学习中最简单最朴素的一种分类算法——KNN算法

abstract.png

基本原理

KNN(K-Nearest Neighbors)算法是一种监督学习的分类算法。其底层逻辑很简单:一个样本的类别是由与其距离最近的K个邻居的类别决定。即所谓的近朱者赤,近墨者黑。具体地:对于一个待分类的样本而言,首先从训练集中找到与其距离最近的K个邻居,然后根据这K个邻居的类别进行投票。此时,投票数最多的类别即为该样本的分类类别。其中,常用的距离度量有欧式距离、曼哈顿距离、余弦距离等

从上述流程不难看出,该算法虽然是一个监督学习算法需要训练集。但其却并不需要进行训练,因为其只是简单地将训练集存储起来,然后在预测推理阶段才会使用该训练集进行计算。故该算法是一个懒惰学习算法

该算法的K值是一个超参数,其表示有多少个邻居参与投票,决定了模型的预测结果、泛化能力。具体地,如果K值过小,则可能会过拟合、泛化能力差;如果K值过大,则可能会导致欠拟合。具体地,K值一般是奇数,以避免投票时出现平局。可根据业务领域经验先确定一个K值的取值范围。然后通过交叉验证法来评估、确定K值的最终取值

实践

下面通过SKlearn提供的KNN分类器来实现一个分类任务。这里选用为经典的鸢尾花Iris数据集。该数据集包含150个样本,选取了鸢尾花的四个特征(sepal length花萼长度、sepal width花萼宽度、petal length花瓣长度、petal width花瓣宽度)用于预测鸢尾花的品种(setosa/versicolor/virginica)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
import seaborn as sns
from sklearn.metrics import classification_report

# 加载数据集: 鸢尾花
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 特征名称
feature_names = iris.feature_names
# 类别名称
label_names = iris.target_names

# 划分训练集、测试集,比例为7:3。固定random_state种子保证每次划分结果一样,可重复
X_train, X_test, y_train, y_test = train_test_split(X, y ,test_size=0.3, random_state=69)

# 创建KNN分类器实例 K值设为3,距离度量使用欧式距离
knn = KNeighborsClassifier(n_neighbors=3, metric="euclidean")

# 训练模型(事实并未训练,只是存储训练集)
knn.fit(X_train, y_train)

# 预测测试集
y_pred = knn.predict(X_test)

# 计算评估指标
report = classification_report(y_test, y_pred, target_names=label_names)
print(f"------------------------ 评估指标 ------------------------")
print(f"{report}")

# 计算混淆矩阵
confusion_matrix = confusion_matrix(y_true=y_test,y_pred=y_pred)

# 绘制混淆矩阵
# annot=True: 每个单元格显示数值; fmt="d": 单元格数值格式为整数
# xticklabels / yticklabels: 设置X轴/Y轴的刻度标签名称
sns.heatmap(confusion_matrix, annot=True, cmap="Blues", fmt="d", xticklabels=label_names, yticklabels=label_names)
# X轴标题:预测标签
plt.xlabel("predicted Label")
# Y轴标题:真实标签
plt.ylabel("True Label")
plt.title("KNN: Confusion Matrix")
plt.show()

输出结果如下:

1
2
3
4
5
6
7
8
9
10
------------------------ 评估指标 ------------------------
precision recall f1-score support

setosa 1.00 1.00 1.00 16
versicolor 0.92 1.00 0.96 12
virginica 1.00 0.94 0.97 17

accuracy 0.98 45
macro avg 0.97 0.98 0.98 45
weighted avg 0.98 0.98 0.98 45

figure 1.png

参考文献

  • 图解机器学习和深度学习入门 山口达辉、松田洋之著
请我喝杯咖啡捏~
  • 本文作者: Aaron Zhu
  • 本文链接: https://xyzghio.xyz/KNN/
  • 版权声明: 本博客所有文章除特别声明外,均采用 BY-NC-ND 许可协议。转载请注明出处!

欢迎关注我的微信公众号:青灯抽丝