统计学习方法---k近邻法

el/2024/7/17 20:34:09

本文对应《统计学习方法》第3章,用数十行代码实现KNN的kd树构建与搜索算法,并用matplotlib可视化了动画观赏。

k近邻算法

给定一个训练数据集,对新的输入实例,在训练数据集中找到跟它最近的k个实例,根据这k个实例的类判断它自己的类(一般采用多数表决的方法)。

k近邻模型

模型有3个要素——距离度量方法、k值的选择和分类决策规则。

模型

当3要素确定的时候,对任何实例(训练或输入),它所属的类都是确定的,相当于将特征空间分为一些子空间。

距离度量

对n维实数向量空间Rn,经常用Lp距离或曼哈顿Minkowski距离。

Lp距离定义如下:

当p=2时,称为欧氏距离:

当p=1时,称为曼哈顿距离:

当p=∞,它是各个坐标距离的最大值,即:

用图表示如下:

k值的选择

k较小,容易被噪声影响,发生过拟合。

k较大,较远的训练实例也会对预测起作用,容易发生错误。

分类决策规则

使用0-1损失函数衡量,那么误分类率是:

Nk是近邻集合,要使左边最小,右边的必须最大,所以多数表决=经验最小化。

k近邻法的实现:kd树

算法核心在于怎么快速搜索k个近邻出来,朴素做法是线性扫描,不可取,这里介绍的方法是kd树。

构造kd树

对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:

找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S',++i,node=current),直到不可再分。

例子

Python代码

短短几行即可搞定:

  1. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  2.  
  3. class node:
  4.     def __init__(self, point):
  5.         self.left = None
  6.         self.right = None
  7.         self.point = point
  8.         pass
  9.     
  10. def median(lst):
  11.     m = len(lst) / 2
  12.     return lst[m], m
  13.  
  14. def build_kdtree(data, d):
  15.     data = sorted(data, key=lambda x: x[d])
  16.     p, m = median(data)
  17.     tree = node(p)
  18.  
  19.     del data[m]
  20.     print data, p
  21.  
  22.     if m > 0: tree.left = build_kdtree(data[:m], not d)
  23.     if len(data) > 1: tree.right = build_kdtree(data[m:], not d)
  24.     return tree
  25.  
  26. kd_tree = build_kdtree(T, 0)
  27. print kd_tree


可视化

可视化的话则要费点功夫保存中间结果,并恰当地展示出来

  1. # -*- coding:utf-8 -*-
  2. # Filename: kdtree.py
  3. # Authorhankcs
  4. # Date: 2015/2/4 15:01
  5. import copy
  6. import itertools
  7. from matplotlib import pyplot as plt
  8. from matplotlib.patches import Rectangle
  9. from matplotlib import animation
  10.  
  11. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  12.  
  13.  
  14. def draw_point(data):
  15.     X, Y = [], []
  16.     for p in data:
  17.         X.append(p[0])
  18.         Y.append(p[1])
  19.     plt.plot(X, Y, 'bo')
  20.  
  21.  
  22. def draw_line(xy_list):
  23.     for xy in xy_list:
  24.         x, y = xy
  25.         plt.plot(x, y, 'g', lw=2)
  26.  
  27.  
  28. def draw_square(square_list):
  29.     currentAxis = plt.gca()
  30.     colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
  31.     for square in square_list:
  32.         currentAxis.add_patch(
  33.             Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
  34.                       color=next(colors)))
  35.  
  36.  
  37. def median(lst):
  38.     m = len(lst) / 2
  39.     return lst[m], m
  40.  
  41.  
  42. history_quare = []
  43.  
  44. def build_kdtree(data, d, square):
  45.     history_quare.append(square)
  46.     data = sorted(data, key=lambda x: x[d])
  47.     p, m = median(data)
  48.  
  49.     del data[m]
  50.     print data, p
  51.  
  52.     if m >= 0:
  53.         sub_square = copy.deepcopy(square)
  54.         if d == 0:
  55.             sub_square[1][0] = p[0]
  56.         else:
  57.             sub_square[1][1] = p[1]
  58.         history_quare.append(sub_square)
  59.         if m > 0: build_kdtree(data[:m], not d, sub_square)
  60.     if len(data) > 1:
  61.         sub_square = copy.deepcopy(square)
  62.         if d == 0:
  63.             sub_square[0][0] = p[0]
  64.         else:
  65.             sub_square[0][1] = p[1]
  66.         build_kdtree(data[m:], not d, sub_square)
  67.  
  68.  
  69. build_kdtree(T, 0, [[0, 0], [10, 10]])
  70. print history_quare
  71.  
  72.  
  73. # draw an animation to show how it works, the data comes from history
  74. # first set up the figure, the axis, and the plot element we want to animate
  75. fig = plt.figure()
  76. ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
  77. line, = ax.plot([], [], 'g', lw=2)
  78. label = ax.text([], [], '')
  79.  
  80. # initialization function: plot the background of each frame
  81. def init():
  82.     plt.axis([0, 10, 0, 10])
  83.     plt.grid(True)
  84.     plt.xlabel('x_1')
  85.     plt.ylabel('x_2')
  86.     plt.title('build kd tree (www.hankcs.com)')
  87.     draw_point(T)
  88.  
  89.  
  90. currentAxis = plt.gca()
  91. colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])
  92.  
  93. # animation function.  this is called sequentially
  94. def animate(i):
  95.     square = history_quare[i]
  96.     currentAxis.add_patch(
  97.         Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
  98.                   color=next(colors)))
  99.     return
  100.  
  101. # call the animator.  blit=true means only re-draw the parts that have changed.
  102. anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
  103.                                blit=False)
  104. plt.show()
  105. anim.save('kdtree_build.gif', fps=2, writer='imagemagick')

搜索kd树

上面的代码其实并没有搜索kd树,现在来实现搜索。

搜索跟二叉树一样来,是一个递归的过程。先找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。以最近邻为例,实现如下(本实现由于测试数据简单,没有做超球体与超立体相交的逻辑):

  1. # -*- coding:utf-8 -*-
  2. # Filename: search_kdtree.py
  3. # Authorhankcs
  4. # Date: 2015/2/4 15:01
  5.  
  6. T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
  7.  
  8.  
  9. class node:
  10.     def __init__(self, point):
  11.         self.left = None
  12.         self.right = None
  13.         self.point = point
  14.         self.parent = None
  15.         pass
  16.  
  17.     def set_left(self, left):
  18.         if left == None: pass
  19.         left.parent = self
  20.         self.left = left
  21.  
  22.     def set_right(self, right):
  23.         if right == None: pass
  24.         right.parent = self
  25.         self.right = right
  26.  
  27.  
  28. def median(lst):
  29.     m = len(lst) / 2
  30.     return lst[m], m
  31.  
  32.  
  33. def build_kdtree(data, d):
  34.     data = sorted(data, key=lambda x: x[d])
  35.     p, m = median(data)
  36.     tree = node(p)
  37.  
  38.     del data[m]
  39.  
  40.     if m > 0: tree.set_left(build_kdtree(data[:m], not d))
  41.     if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
  42.     return tree
  43.  
  44.  
  45. def distance(a, b):
  46.     print a, b
  47.     return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
  48.  
  49.  
  50. def search_kdtree(tree, d, target):
  51.     if target[d] < tree.point[d]:
  52.         if tree.left != None:
  53.             return search_kdtree(tree.left, not d, target)
  54.     else:
  55.         if tree.right != None:
  56.             return search_kdtree(tree.right, not d, target)
  57.  
  58.     def update_best(t, best):
  59.         if t == None: return
  60.         t = t.point
  61.         d = distance(t, target)
  62.         if d < best[1]:
  63.             best[1] = d
  64.             best[0] = t
  65.  
  66.     best = [tree.point, 100000.0]
  67.     while (tree.parent != None):
  68.         update_best(tree.parent.left, best)
  69.         update_best(tree.parent.right, best)
  70.         tree = tree.parent
  71.     return best[0]
  72.  
  73.  
  74. kd_tree = build_kdtree(T, 0)
  75. print search_kdtree(kd_tree, 0, [9, 4])

去掉注释和空白,大概数十行,Python真不愧是可运行的伪码。

输出:

  1. [8, 1] [9, 4]
  2. [5, 4] [9, 4]
  3. [9, 6] [9, 4]
  4. [9, 6]

可见对于点[9, 4],在n=6的数据集中,kdtree算法一共只进行了3次计算。




http://www.ngui.cc/el/5557017.html

相关文章

统计学习方法---

4、朴素贝叶斯法 http://www.hankcs.com/ml/naive-bayesian-method.html http://blog.csdn.net/u010626937/article/details/73810753 5、决策树 http://www.hankcs.com/ml/decision-tree.html 6、逻辑斯谛回归与最大熵模型 http://www.hankcs.com/ml/the-logistic-regressi…

pandas数据合并与重塑---concat方法

谈到pandas数据的行更新、表合并等操作&#xff0c;一般用到的方法有concat、join、merge。但这三种方法对于很多新手来说&#xff0c;都不太好分清使用的场合与用途。今天就pandas官网中关于数据合并和重述的章节做个使用方法的总结。 1、concat pd.concat(objs, axis0, joino…

pandas数据合并与重塑---join、merge方法

在上一篇文章中&#xff0c;我整理了pandas在数据合并和重塑中常用到的concat方法的使用说明。在这里&#xff0c;将接着介绍pandas中也常常用到的join 和merge方法 merge pandas的merge方法提供了一种类似于SQL的内存链接操作&#xff0c;官网文档提到它的性能会比其他开源语言…

XGBoost简介---相关概念、原理

XGBoost是2014年2月诞生的专注于梯度提升算法的机器学习函数库&#xff0c;此函数库因其优良的学习效果以及高效的训练速度而获得广泛的关注。仅在2015年&#xff0c;在Kaggle竞赛中获胜的29个算法中&#xff0c;有17个使用了XGBoost库&#xff0c;而作为对比&#xff0c;近年大…

Sklearn工具包---train_test_split随机划分训练集和测试集

一般形式&#xff1a; train_test_split是交叉验证中常用的函数&#xff0c;功能是从样本中随机的按比例选取train data和test data&#xff0c;形式为&#xff1a; X_train,X_test, y_train, y_test cross_validation.train_test_split(train_data,train_target,test_size0.4…

sklearn工具包---分类效果评估(acc、recall、F1、ROC、回归、距离)

一、acc、recall、F1、混淆矩阵、分类综合报告 1、准确率 第一种方式&#xff1a;accuracy_score # 准确率 import numpy as np from sklearn.metrics import accuracy_score y_pred [0, 2, 1, 3,9,9,8,5,8] y_true [0, 1, 2, 3,2,6,3,5,9] #共9个数据&#xff0c;3个相同…

python中可变和不可变对象(复值,拷贝,函数值传递)

python中有可变对象和不可变对象&#xff0c;可变对象&#xff1a; list, dict.不可变对象有: int, string, float, tuple.python不可变对象int, string, float, tuple先来看一个例子 def int_test(): i 77 j 77 print(id(77)) #140396579590…

推荐算法概述:基于内容的推荐算法、协同过滤推荐算法和基于知识的推荐算法

“无意中发现了一个巨牛的人工智能教程&#xff0c;忍不住分享一下给大家。教程不仅是零基础&#xff0c;通俗易懂&#xff0c;而且非常风趣幽默&#xff0c;像看小说一样&#xff01;觉得太牛了&#xff0c;所以分享给大家。点这里可以跳转到教程。” 所谓推荐算法就是利用用…

推荐算法--基于物品的协同过滤算法

“无意中发现了一个巨牛的人工智能教程&#xff0c;忍不住分享一下给大家。教程不仅是零基础&#xff0c;通俗易懂&#xff0c;而且非常风趣幽默&#xff0c;像看小说一样&#xff01;觉得太牛了&#xff0c;所以分享给大家。点这里可以跳转到教程。” ItemCF&#xff1a;ItemC…

Neo4j简介及Py2Neo的用法(python操作neo4j)

博客原文&#xff1a; http://cuiqingcai.com/4778.html