当前位置: 代码网 > it编程>软件设计>算法 > K最近邻算法:简单高效的分类和回归方法

K最近邻算法:简单高效的分类和回归方法

2024年08月05日 算法 我要评论
K最近邻(K-nearest neighbors,简称KNN)算法是一种基于实例的机器学习方法,可以用于分类和回归问题。它的思想非常简单,但在实践中却表现出了出色的效果。本文将介绍KNN算法的原理、应用场景和优缺点,并通过示例代码演示其实现过程以上代码仅仅的简单演示一遍KNN算法,但是真正的KNN算法并没有这么简单,下节我会通过上述代码的基础上进行简单的优化,并进行封装挑战与创造都是很痛苦的,但是很充实。

🍀简介


🍀knn算法原理

  • 计算待分类样本与训练集中每个样本之间的距离(通常使用欧氏距离或曼哈顿距离)
  • 选取距离最近的k个样本作为邻居
  • 根据邻居样本的标签进行投票,将待分类样本归类为得票最多的类别(分类问题)或计算邻居样本标签的平均值(回归问题)
    欧拉距离如下
    请添加图片描述

🍀knn算法应用场景

  • 分类问题:如垃圾邮件过滤、图像识别等
  • 回归问题:如房价预测、股票价格预测等
  • 推荐系统:根据用户和物品的相似度进行推荐
  • 异常检测:检测异常行为或异常事件

例如在邮件分类上就需要如下步骤

  • 数据准备:
    为了使用knn算法进行邮件分类,我们需要准备一个数据集作为训练样本。这个数据集可以由已标记为垃圾邮件和非垃圾邮件的邮件组成。每封邮件都应该被转化为特征向量表示,通常使用词袋模型来表示每个邮件中的单词频率。

  • 特征提取:
    对于每封邮件,我们可以提取出一组特征,例如:

  • 单词频率:统计邮件中每个单词的出现频率,构建一个向量表示邮件的特征。
    主题关键词:根据主题模型提取关键词,构建一个向量表示邮件的主题内容。

  • 数据预处理:
    在应用knn算法之前,需要对数据进行预处理。常见的预处理步骤包括去除停用词、词干提取和编码转换等。

  • 模型训练:
    将预处理后的数据集划分为训练集和测试集。使用knn算法对训练集进行训练,调整k值和距离度量方式来优化模型性能。可以通过交叉验证等技术来选择最优的k值。

  • 模型评估:
    使用训练好的模型对测试集进行预测,并与真实标签进行比较。常用的评估指标包括准确率、精确率、召回率和f1值等,通过这些指标可以评估模型在垃圾邮件过滤方面的性能。

  • 模型使用:
    将训练好的模型应用于新的邮件数据分类。通过计算待分类邮件与训练集样本的距离,并选取最近的k个邻居样本,根据这些邻居样本的标签进行投票,将待分类邮件划分为得票最多的类别,即确定该邮件是否为垃圾邮件。


🍀knn算法优缺点

  • 简单直观,易于实现和理解
  • 适用于多分类问题
  • 对于样本分布不规则的情况,表现良好
  • 需要存储全部训练样本,计算复杂度较高
  • 对于高维数据,效果不佳
  • 对于样本不平衡的数据集,容易被少数类别影响

🍀knn算法代码示例

import numpy as np
from matplotlib import pyplot as plt
raw_data_x = [[5.1935, 2.3312],
              [3.1201, 1.7815],
              [1.3438, 3.3684],
              [2.5323, 3.2762],
              [2.2804, 1.8670],
              [8.4234, 6.6565],
              [8.7451, 7.5340],
              [9.1522, 2.5141],
              [9.7428, 4.4241],
              [8.9398, 1.7916]]
raw_data_y =[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]  #  0是良性,1是恶性
x_train = np.array(raw_data_x)
y_train = np.array(raw_data_y)
plt.scatter(x_train[y_train==0,0],x_train[y_train==0,1],color='r')
plt.scatter(x_train[y_train==1,0],x_train[y_train==1,1],color='b')
plt.show()

运行结果如下
在这里插入图片描述

x = np.array([8.0936, 3.3657])
plt.scatter(x_train[y_train==0,0],x_train[y_train==0,1],color='r')
plt.scatter(x_train[y_train==1,0],x_train[y_train==1,1],color='b')
plt.scatter(x[0],x[1],color='g')
plt.show()

运行结果如下

在这里插入图片描述

from math import sqrt
distance = []  # 保存和其他所有点的距离    
distance = [sqrt(np.sum((x_train-x)**2)) for x_train in x_train]
k = 3
nearest = np.argsort(distance)
nearest[:k]

运行结果如下
在这里插入图片描述

nearest = [i for i in nearest[:k]]

运行结果如下
在这里插入图片描述

top_k = [i for i in y_train[nearest]]

运行结果如下

在这里插入图片描述

from collections import counter
votes = counter(top_k)

运行结果如下
在这里插入图片描述

y_predict = votes.most_common(1)[0][0]

运行结果如下
在这里插入图片描述

🍀总结

请添加图片描述

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com