数据挖掘实战之K-Means算法python实现

数据挖掘实战之K-Means算法python实现

概念:

在“无监督”学习当中,训练样本的标记信息是未知的,目标是通过对无标记训练样本的学习来揭示数据的内在性质及规律,为进一步的数据分析提供基础。此类学习任务中研究最多、应用最广的是“聚类”,而我们此次所实现的K-Means算法则是其中代表性之一。

tips K-Means算法可以看作高斯混合聚类在混合成分方差相等、且每个样本仅指派给一个混合成分时的特例,聚类簇数k通常需要用户提供,有一些启发式用于自动确定k,但常用的仍是基于不同k值多次运行后选取最佳结果。

import numpy as np
import matplotlib.pyplot as plt

# 定义方法用来计算两点距离
def distance(e1, e2):
    return np.sqrt((e1[0]-e2[0])**2+(e1[1]-e2[1])**2)

# 定义方法用来标记类别中心
def means(arr):
    return np.array([np.mean([e[0] for e in arr]), np.mean([e[1] for e in arr])])

# arr中距离a最远的元素,用于初始化聚类中心
def farthest(k_arr, arr):
    f = [0, 0]
    max_d = 0
    for e in arr:
        d = 0
        for i in range(k_arr.__len__()):
            d = d + np.sqrt(distance(k_arr[i], e))
        if d > max_d:
            max_d = d
            f = e
    return f

# arr中距离a最近的元素,用于聚类
def closest(a, arr):
    c = arr[1]
    min_d = distance(a, arr[1])
    arr = arr[1:]
    for e in arr:
        d = distance(a, e)
        if d < min_d:
            min_d = d
            c = e
    return c


if __name__ == "__main__":
    # 定义一个变量用来存储读取进来的数据
    arr = []
    # 读入数据
    # 此处路径设置为自身数据所保存的路径
    with open("C:/iris.txt") as fp:
        for line in fp.readlines():
            # 这里可以读取每一行
            list = line.split(",")
            # 获取x
            first = float(list[0])
            # 获取y
            second = float(list[1])
            arr.append([first,second])
    # print(type(arr))
    # print(arr)

    ## 初始化聚类中心和聚类容器
    # 本次实验数据集当中类别有三个,则定义变量为3
    m = 3
    r = np.random.randint(arr.__len__() - 1)
    k_arr = np.array([arr[r]])
    cla_arr = [[]]
    for i in range(m-1):
        k = farthest(k_arr, arr)
        k_arr = np.concatenate([k_arr, np.array([k])])
        cla_arr.append([])

    ## 迭代聚类
    n = 20
    cla_temp = cla_arr
    for i in range(n):    # 迭代n次
        for e in arr:    # 把集合里每一个元素聚到最近的类
            ki = 0        # 假定距离第一个中心最近
            min_d = distance(e, k_arr[ki])
            for j in range(1, k_arr.__len__()):
                if distance(e, k_arr[j]) < min_d:    # 找到更近的聚类中心
                    min_d = distance(e, k_arr[j])
                    ki = j
            cla_temp[ki].append(e)
        # 迭代更新聚类中心
        for k in range(k_arr.__len__()):
            if n - 1 == i:
                break
            k_arr[k] = means(cla_temp[k])
            cla_temp[k] = []

    ## 对所得结果进行可视化展示
    col = [HotPink, Aqua, Chartreuse]
    col2 = [blue,green,black]
    for i in range(m):
        # 数据点
        plt.scatter([e[0] for e in cla_temp[i]], [e[1] for e in cla_temp[i]], color=col[i])
        # 各类别中心
        plt.scatter(k_arr[i][0], k_arr[i][1], linewidth=5, color=col2[i])
    plt.show()

上图即为最终分类结果。

注意事项:

为避免运行时间过长,通常设置一个最大运行轮数或者最小调整幅度阈值,若达到最大轮数或调整幅度小于阈值,则停止运行。

下一章节我们将采用高斯混合聚类(Mixture-of-Gaussian)方法,对iris数据集进行分类。

经验分享 程序员 微信小程序 职场和发展