统计学习知识---感知机学习算法的拓展(非线性可分数据问题)

el/2024/7/17 20:52:56

感知机算法中的优化方法的几何解释

本部分参考台湾大学林轩田教授机器学习基石课程—PLA部分

PLA算法只有在出现错误分类的时候,才去调整w和b的值,使得错误分类减少。假设我们遇到的数据点(xn,yn)是我们第t次分类错误,那么就有因为是二分类问题,所以只会出现以下两种错误分类的情况: 

  • 第一种:当yn=+1 时,则我们的错误结果为wTxn=wt∗xn=||w||∗||xn||∗cosΘ<0,即cosΘ<0 则Θ太大,为了能过纠正错误,决定减小Θ,就让w(t+1)=wt+x,紫色为改正之后的w(t+1)
  • 同理,对于第二种情况,当yn=-1的时候,则我们的错误结果为wTxn=wt∗xn=||w||∗||xn||∗cosΘ>0,即cosΘ>0 则Θ太小,为了能过纠正错误,决定增大Θ,就让w(t+1)=wt-x,紫色为改正之后的w(t+1)。

综上所述,当分割线遇到点(xn,yn)时,如果分割正确,那么wt就不变,如果分割错误,那么就令 
注意w是分割线wTx=0的法线,也就是说分割线的方向是与w的方向垂直的。。。

思考

PLA 的优点是算法思路比较简单,易于实现。然而这个算法最大的缺点是假设了数据是线性可分的,然而事先并无法知道数据是否线性可分的。假如将PLA 用在线性不可分的数据中时,会导致PLA永远都无法对样本进行完全正确分开从而陷入到死循环中。 
如下图所示,当实例点并不是线性可分的时候,根本找不到一条直线或者一个超平面来完全划分开两类数据,只能利用曲线或者超曲面来划分。

为了避免上面线性不可分的情况,将PLA的条件放宽一点,不再要求所有的样本都能正确的分开,而是要求犯错误的样本尽可能的少,即将问题变为了: 

 

 

也是就是说去寻找一条犯错误最少的线或者超平面。

其实,从实际意义上,是不能的。这是一个著名的NP hard 问题!!!因为线有无穷多个啊!!!无法求得其最优解,因此只能求尽可能接近其最优解的近似解。林教授的课程讲义中提出的一种求解其近似解的算法 Pocket Algorithm(口袋算法,一种贪心算法)。

Pocket Algorithm(口袋算法)

口袋算法基于贪心的思想,他总是让遇到的最好的线(或者超平面)拿在自己手里。简单介绍一下:首先,我们有一条分割线wt,将数据实例不断带入,发现数据点(xn,yn)再上面出现错误分类,那么我们就纠正分割线得到w(t+1),然后我们让wt和w(t+1)遍历所有的数据,看一下哪条线犯的错误少,那么就让w(t+1)代替wt,否则wt不变。

那么如何让算法停下来呢?

由于口袋算法得到的线越来越好(PLA就不一定了,PLA是最终结果最好,其他情况就不一定是什么样子,不一定是越来越好),所以我们就自己规定迭代的次数。

思考?

答案是:PLA更好。先不说PLA可以找到最好的那条线。单从效率上来说,PLA也更好些。最主要的原因是,pocket algorithm 每次比较的时候,都要遍历所有的数据点,且两个算法都要遍历一遍,才会决定那个算法好,而这还是比较一次,如果我们让他迭代500次的,那就麻烦了!!!但是,所有前提是,数据是线性可分的。如果线性不可分,只能用pocket algorithm,因为PLA根本不会停下来(而且PLA的wt也不是每更改一次效果就会比之前的好)!!

与PLA的比较:

1、Pocket Algorithm事先设定迭代次数,而不是等着算法自己收敛到最优。
2、随机遍历数据集,而不是循环遍历。
3、遇到错误点校正时,只有当新得到的w对于所有的数据优于旧w的时候(也就是整体错误更少的时候)才更新,而PLA算法中,只要出现错误分类就更新。由此也可知,pocket Algorithm算法是保证每次得到的线或这面是越来越好的,而PLA算法不一定。而且,由于Pocket要比较错误率,需要计算所有的数据点,因此效率要地域PLA。所以在线性可分的数据集上,使用PLA算法,而不选择使用Pocket算法。但是,只要迭代次数足够多,Pocket和PLA的效果是一样的,都能够把数据完全正确分开,只是速度慢。

代码实现

下面,我们用Python来实现pocket Algorithm算法。

# -*- encoding:utf-8 -*-
'''
@author=Ada
Python实现的pocket algoritm
'''from numpy import vectorize
import numpy as np
import matplotlib.pyplot as pltclass Pocket:def __init__(self,random_state=None):self.numberOfIter=7000#最大迭代次数self.minWeights=Noneself.intercept=None#bias 截距self.errorCountArr=np.zeros(7000)#统计每次迭代的出错数self.errorCount=[]#统计每次优化或者每个更新权值时候的出错次数self.minErrors=7000#出错数量的初始值,一开始一般设置一个比较大的值self.random_state=random_statedef predict(self,z):if z<0:return -1else:return 1def checkPredictedValue(self,z,actualZ):if(z==actualZ):return Trueelse:return Falsedef fit(self,X,Y):row,col=X.shapeweights=np.array([1.0,1.0,1.0,1.0])vpredict = vectorize(self.predict)vcheckPredictedValue=vectorize(self.checkPredictedValue)learning_rate=1.0bias_val=np.ones((row,1))data=np.concatenate((bias_val,X),axis=1)np.random.seed(self.random_state)count=0iter=0while self.numberOfIter>0:weightedSum=np.dot(data,weights)predictedValues=vpredict(weightedSum)predictions=vcheckPredictedValue(predictedValues,Y)misclassifiedPoints=np.where(predictions==False)#分类错误的数据misclassifiedPoints=misclassifiedPoints[0]numOfErrors=len(misclassifiedPoints)#分类错误的数据量self.errorCountArr[iter]=numOfErrorsif numOfErrors<self.minErrors:self.minErrors=numOfErrorsself.errorCount.append(self.minErrors)count+=1iter+=1misclassifiedIndex=np.random.choice(misclassifiedPoints)#这一步与PLA不同,# 在此是随机从错误数据点里面选择一个点,进行更新权值weights+=(Y[misclassifiedIndex]*learning_rate*data[misclassifiedIndex])self.numberOfIter-=1self.weights=weights[1:]self.intercept=weights[0]def main():data=np.loadtxt('classification.txt',dtype='float',delimiter=',',usecols=(0,1,2,4))X=data[:,0:3]Y=data[:,3]p=Pocket(random_state=2308863)p.fit(X,Y)print "Weights:"print p.weightsprint "Intercept Value:"print p.interceptprint "Minimum Number Of Errors:"print p.minErrorsax1=plt.subplot(121)ax1.plot(np.arange(0,7000),p.errorCountArr)ax2=plt.subplot(122)ax2.plot(np.arange(0,len(p.errorCount)),p.errorCount)plt.show()if __name__ == "__main__":main()

 

实验结果:Weights:[-0.43701921 -0.10683611 0.34784736] 
Intercept Value: -1.0 
Minimum Number Of Errors:935

这里写图片描述

上图的第一幅表示的是每次迭代时的出错数,第二幅图表示的每次更新权重时的出错数。通过上图可以观察到,在7000次迭代过程中,每次迭代出错数是不固定的,而每次更新时出错数是递减的。而且,7000词迭代过程中只有十次左右的更新操作。

参考资料:

1、机器学习基石—PLA 
2、台湾大学林轩田教授机器学习基石课程理解及python实现—-PLA 
3、分类系列之感知器学习算法PLA 和 口袋算法Pocket Algorithm 
4、听课笔记(第二讲): Perceptron-感知机 (台湾国立大学机器学习基石)

 

 


http://www.ngui.cc/el/5557020.html

相关文章

统计学习方法---感知机算法拓展(神经网络)

神经元 神经元是神经网络的基本单元&#xff0c;接受多个神经元传递过来的输入信号&#xff0c;然后通过激活函数计算输出信号。 从图里可以看到每个输入信号都有一个权重w&#xff0c;这个权重是动态改变的。我们平时所说的训练神经网络主要是训练&#xff08;修正&#xff09…

统计学习方法---KNN(K近邻)

前言 k邻近算法&#xff08;k-nearest&#xff09;是一种判别模型&#xff0c;解决分类问题和回归问题&#xff0c;以分类问题为主&#xff0c;在此我们也主要介绍分类问题中的k近邻算法。 k近邻算法的输入为实例的特征向量&#xff0c;对应予特征空间中的点&#xff1b;输出…

统计学习方法---k近邻法

本文对应《统计学习方法》第3章&#xff0c;用数十行代码实现KNN的kd树构建与搜索算法&#xff0c;并用matplotlib可视化了动画观赏。 k近邻算法 给定一个训练数据集&#xff0c;对新的输入实例&#xff0c;在训练数据集中找到跟它最近的k个实例&#xff0c;根据这k个实例的类…

统计学习方法---

4、朴素贝叶斯法 http://www.hankcs.com/ml/naive-bayesian-method.html http://blog.csdn.net/u010626937/article/details/73810753 5、决策树 http://www.hankcs.com/ml/decision-tree.html 6、逻辑斯谛回归与最大熵模型 http://www.hankcs.com/ml/the-logistic-regressi…

pandas数据合并与重塑---concat方法

谈到pandas数据的行更新、表合并等操作&#xff0c;一般用到的方法有concat、join、merge。但这三种方法对于很多新手来说&#xff0c;都不太好分清使用的场合与用途。今天就pandas官网中关于数据合并和重述的章节做个使用方法的总结。 1、concat pd.concat(objs, axis0, joino…

pandas数据合并与重塑---join、merge方法

在上一篇文章中&#xff0c;我整理了pandas在数据合并和重塑中常用到的concat方法的使用说明。在这里&#xff0c;将接着介绍pandas中也常常用到的join 和merge方法 merge pandas的merge方法提供了一种类似于SQL的内存链接操作&#xff0c;官网文档提到它的性能会比其他开源语言…

XGBoost简介---相关概念、原理

XGBoost是2014年2月诞生的专注于梯度提升算法的机器学习函数库&#xff0c;此函数库因其优良的学习效果以及高效的训练速度而获得广泛的关注。仅在2015年&#xff0c;在Kaggle竞赛中获胜的29个算法中&#xff0c;有17个使用了XGBoost库&#xff0c;而作为对比&#xff0c;近年大…

Sklearn工具包---train_test_split随机划分训练集和测试集

一般形式&#xff1a; train_test_split是交叉验证中常用的函数&#xff0c;功能是从样本中随机的按比例选取train data和test data&#xff0c;形式为&#xff1a; X_train,X_test, y_train, y_test cross_validation.train_test_split(train_data,train_target,test_size0.4…

sklearn工具包---分类效果评估(acc、recall、F1、ROC、回归、距离)

一、acc、recall、F1、混淆矩阵、分类综合报告 1、准确率 第一种方式&#xff1a;accuracy_score # 准确率 import numpy as np from sklearn.metrics import accuracy_score y_pred [0, 2, 1, 3,9,9,8,5,8] y_true [0, 1, 2, 3,2,6,3,5,9] #共9个数据&#xff0c;3个相同…

python中可变和不可变对象(复值,拷贝,函数值传递)

python中有可变对象和不可变对象&#xff0c;可变对象&#xff1a; list, dict.不可变对象有: int, string, float, tuple.python不可变对象int, string, float, tuple先来看一个例子 def int_test(): i 77 j 77 print(id(77)) #140396579590…