机器学习02:k-近邻算法综合使用

  • 发布时间:2017年1月5日 17:43
  • 作者:杨仕航

上次写了一篇k-近邻算法的简单分类使用方法,包括什么是k-近邻算法以及基本的实现代码。

这次看《机器学习实战》2.2的实例,讲解修正数据和如何计算错误率以及把数据图形化。


该实例大致意思是从约会网站获取数据,通过下面3个维度的数据:

1)每年获得的飞行常客里程数;

2)玩视频游戏所耗时间百分比;

3)每周消费的冰淇淋公升数。

把约会对象划分成3类:

1)不喜欢的人;

2)魅力一般的人;

3)极具魅力的人。

所需的数据我已经上传到百度网盘,大家下载即可:datingTestSet2.txt

该文件数据如下图:

20170105/20170105152833168.png

前3列对应3个维度的数据,第4列是划分的对应类别,一共1000条数据。当然,这1000条数据,只有部分数据会来拿作为训练集,部分用于测试错误率。


第1步、从文件中读取数据

在上次k-近邻算法的简单分类文章中所创建的knn.py文件同个目录下,创建analysis.py文件并加入如下代码:

#coding:utf-8
import numpy as np
from knn import classify0 #k-近邻算法

knn.py文件是包含k-近邻算法的代码,可以再次拿来使用。

《机器学习实战》书中从文件读取的代码过于复杂。我整理如下:

def get_data_from_file(file_path):
    ''' 从文件获取数据
        file_path:文件地址
    '''
    dataset = []
    labels = []
    
    with open(file_path, 'r') as f:
        line = f.readline()
        #逐行读取文件,该方法可避免文件过大,用readlines读取导致内存不够
        while line != '':
            line = line.strip()
            data = line.split('\t') #用tab拆分
            
            #获取数据
            dataset.append(map(float, data[:3]))
            labels.append(int(data[-1]))
            
            line = f.readline() #读取下一行
            
    #把数据转成numpy的数组并返回
    return np.array(dataset), labels

我们把上面网盘中的文件datingTestSet2.txt下载放到同个目录下。用如下方法获取数据:

if __name__ == '__main__':
    file_path = 'datingTestSet2.txt'
    dataset, labels = get_data_from_file(file_path)


第2步、修正数据

《机器学习实战》书中,先讲画图,再讲修正数据。实际上,倒过来讲更好。

你仔细观察,可以发现3个维度的数据差异比较大。

20170105/20170105152833168.png

其中第1列,也就是飞行常客里程数,该数据要比第2,3列的数据大很多。若直接用这些原始数据计算,会因为另外两个参数数据过小,导致计算出来的欧式距离包含这两个参数的成分变得很小。

那么我们需要修正数据,消除大数值的影响。如下代码修正:

def auto_norm(dataset):
    '''修正第1个参数的数值,消除大数值的影响(让数值范围变成0~1)'''
    #求第一维度的最大最小值
    min_vals = dataset.min(axis = 0)
    max_vals = dataset.max(axis = 0)
    rng_vals = max_vals - min_vals #值域

    #每个数值减去其所在列的最小值,再除以范围
    m = dataset.shape[0]
    norm = dataset - np.tile(min_vals, (m, 1))
    norm = norm/np.tile(rng_vals, (m,1))
    
    #返回修正后的数值,值域,最小值
    return norm, rng_vals, min_vals

numpy数组的特性和tile方法上一篇文章已经讲解,这里就不再赘述。

在刚刚的main代码修改如下:

if __name__ == '__main__':
    #获取数据
    file_path = 'datingTestSet2.txt'
    dataset, labels = get_data_from_file(file_path)
    
    #修正数据
    dataset_norm, rng_vals, min_vals = auto_norm(dataset)


第3步、数据图形化

我们得到这个dataset_norm之后,不直观也不容易分析。需要把数据图形化,方便观察。

绘制图表,需要安装matplotlib库。安装方法自行网上搜索。使用方法网上也大把,可以看看这个:

matplotlib-绘制精美的图表

在《机器学习实战》书中绘制的是2D平面图表。而我们的数据是3维的,需要绘制一个3D图表。

先加入如下引用:

import matplotlib.pyplot as plt # 绘制图表
from mpl_toolkits.mplot3d import Axes3D #3D支持

再添加绘制3D图表方法:

def chart(dataset, labels):
    #二维散点图
    """fig = plt.figure()
    ax = fig.add_subplot(111)

    ls = 15.*np.array(labels)
    ax.scatter(norm[:,0], norm[:,1], ls, ls)
    plt.show()"""

    #三维散点图
    ls = 15.*np.array(labels)
    #创建一个三维的绘图工程
    ax = plt.subplot(111, projection='3d') 
    ax.scatter(dataset[:,0],dataset[:,1],dataset[:,2], ls,ls,ls)

    #设置坐标轴
    ax.set_xlabel('sports')
    ax.set_ylabel('play')
    ax.set_zlabel('eat')

    plt.show()

绘制二维图表的代码也放在里面。再修改main部分的代码:

if __name__ == '__main__':
    #获取数据
    file_path = 'datingTestSet2.txt'
    dataset, labels = get_data_from_file(file_path)
    
    #修正数据
    dataset_norm, rng_vals, min_vals = auto_norm(dataset)
    
    #绘图
    chart(dataset_norm, labels)

用python运行该文件,可看到下图:

20170105/20170105172504670.png

可鼠标拖动和保存、调整等。不同颜色的点,代表不同类型的人。这样可以直观看出数据分布。


第4步、计算错误率

为了更好使用该模型,需要给出一个合理的k值。

随机训练集中的10%数据,用k-近邻算法得到对应的分类。再和实际的分类对比,看看是否正确。统计计算对应的错误率。

#测试数据(主要看k值)
def test(dataset, labels, k):
    #测试方法,从dataset随机取出10%的数据。鉴于数据本身就具有随机性,直接去前10%即可
    ratio = 0.1
    m = dataset.shape[0]

    test_num = int(m*ratio)
    test_rng = dataset[test_num:m]
    test_label = labels[test_num:m]

    err_count = 0
    for i in range(test_num):
        #通过k-近邻算法得到标签
        result_label = classify0(dataset[i], test_rng, test_label, k)[0] 

        #对比标签,错误累计
        if result_label != labels[i]:
            err_count += 1

    print("all:%s, test:%s. k=%s, err:%s, error ratio: %.2f" % (m, test_num, k, err_count, err_count/float(test_num)))

我们可以给定一些k值,计算。修改main部分的代码:

if __name__ == '__main__':
    #获取数据
    file_path = 'datingTestSet2.txt'
    dataset, labels = get_data_from_file(file_path)
    
    #修正数据
    dataset_norm, rng_vals, min_vals = auto_norm(dataset)
    
    #绘图
    #chart(dataset_norm, labels)
    
    #计算错误率,获取最佳的k值
    for k in range(3,11):
        test(norm, labels, k)

运行结果如下:

20170105/20170105173426394.png

从图中可以看出,当k=4时,错误率最小。那我们可以确定k为4。


第5步、使用算法

确定k之后,就可以拿这个算法来使用。如下代码,我给一些数据,判断该约会对象应该归入哪一类:

if __name__ == '__main__':
    #获取数据
    file_path = 'datingTestSet2.txt'
    dataset, labels = get_data_from_file(file_path)
    
    #修正数据
    dataset_norm, rng_vals, min_vals = auto_norm(dataset)
    
    #绘图
    #chart(dataset_norm, labels)
    
    #计算错误率,获取最佳的k值
    #for k in range(3,11):
    #    test(norm, labels, k)
        
    #测试发现k=4时,错误率最低
    k = 4
    
    #使用算法
    ff_miles = 10000
    play_games = 10    
    ice_cream = 0.5

    item = np.array([ff_miles, play_games, ice_cream])
    norm_item = (item - min_vals)/rng_vals #消除特征
    result = classify0(norm_item, norm, labels, k)[0]

    classes = [u'不喜欢的人', u'魅力一般的人', u'极具魅力的人']
    print(u"飞行常客公里数:%s\n玩视频游戏时间占比:%s%%\n每周消费冰淇淋公斤数:%s\n为%s" % (ff_miles, play_games, ice_cream, classes[result-1]))

可以得到如下结果:

20170105/20170105174313686.png

点击查看相关目录

上一篇:我的网站搭建(第39天) 博文随机推荐

下一篇:机器学习01:k-近邻算法简单分类

相关专题: 机器学习实战   

评论列表

智慧如你,不想发表一下意见吗?

新的评论

清空

猜你喜欢

  • 猜测中,请稍等...