-
机器学习回顾篇(9):K-means聚类算法
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import copy
In [2]:
a = np.random.normal(20,5,300)
b = np.random.normal(15,5,300)
cluster1 = np.array([[x, y] for x, y in zip(a,b)])
In [3]:
a = np.random.normal(20,5,300)
b = np.random.normal(45,5,300)
cluster2 = np.array([[x, y] for x, y in zip(a,b)])
In [4]:
a = np.random.normal(55,5,300)
b = np.random.normal(30,5,300)
cluster3 = np.array([[x, y] for x, y in zip(a,b)])
In [5]:
dataset = np.append(np.append(cluster1,cluster2, axis=0),cluster3, axis=0)
In [6]:
for i in dataset:
plt.scatter(i[0], i[1],c='black',s=6)
plt.show()
In [7]:
def calc_dist(simple1, simple2):
"""计算两数据对象间的欧氏距离"""
return np.linalg.norm(simple1-simple2)
In [8]:
def init_centers(k, dataset):
"""随机获取k个初始化聚类中心"""
shuffle_array = np.arange(dataset.shape[0])
np.random.shuffle(shuffle_array)
center_index = shuffle_array[:k] # 获取k个随机索引
center_dict = {}
for i in range(k):
center = dataset[center_index[i]] # 聚类中心
center_dict[i] = center
return center_dict
In [9]:
def k_means(k,dataset):
"""实现K-means算法"""
ds = copy.deepcopy(dataset) # 复制一份数据
epoch = 0 # 迭代次数
center_dict = init_centers(k, ds) # 第一次迭代时,随机初始化k个聚类中心
ds = np.insert(ds, 2, values=-1, axis=1) # 插入一列作为类标签,默认为0
total_last = np.inf # 上一次迭代距离总和
while epoch<=20: # 迭代次数少于20次时继续迭代,也可以直接设为True,当目标函数收敛时自动结束迭代
cluster_dist = {i:0 for i in range(k)} # 记录每一个类簇距离总和
for simple in ds:
min_dist = np.inf # simple 到最近的聚类中心的距离
min_label = -1 # 最近的聚类中心类标签
for label in center_dict.keys():
dist = calc_dist(simple[:2], center_dict[label])
if dist < min_dist:
min_dist = dist
min_label = label
simple[2] = min_label # 将当前样本点划分到最近的聚类中心所在聚类中
cluster_dist[int(min_label)] = cluster_dist[int(min_label)] + min_dist # 更新类簇内部距离总和
loss_now = sum(cluster_dist.values()) # 所有类簇内部距离总和
print("epoch:{}, tatal distance: {}".format(epoch,loss_now))
for i in ds:
if i[2] == 0:
plt.scatter(i[0], i[1],c='red',s=6)
elif i[2] == 1:
plt.scatter(i[0], i[1],c='green',s=6)
else:
plt.scatter(i[0], i[1],c='blue',s=6)
for center in center_dict.values():
plt.scatter(center[0], center[1],c='black')
plt.show()
if total_last == loss_now: # 如果两次迭代距离总和都不变,证明已收敛
break
total_last = loss_now
for label in center_dict.keys(): # 更新聚类中心
simple_list = ds[ds[:,2]==label] # 挑选出类标签为k的所有样本
x = np.mean(simple_list[:, 0])
y = np.mean(simple_list[:, 1])
center_dict[label] = [x, y]
epoch += 1
return ds, center_dict
In [12]:
ds,cluster_label = k_means(3,dataset)
最新更新
博克-定制图例
博克-注释和图例
Bokeh–添加小部件
向博克图添加标签
将交互式滑块添加到博克图
在 Bokeh 中添加按钮
谷歌、微软、Meta?谁才是 Python 最大的金
Objective-C语法之代码块(block)的使用
URL Encode
go语言写http踩得坑
2个场景实例讲解GaussDB(DWS)基表统计信息估
常用的 SQL Server 关键字及其含义
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
SQL Server -- 解决存储过程传入参数作为s
武装你的WEBAPI-OData入门
武装你的WEBAPI-OData便捷查询
武装你的WEBAPI-OData分页查询
武装你的WEBAPI-OData资源更新Delta
5. 武装你的WEBAPI-OData使用Endpoint 05-09
武装你的WEBAPI-OData之API版本管理
武装你的WEBAPI-OData常见问题
武装你的WEBAPI-OData聚合查询
OData WebAPI实践-OData与EDM
OData WebAPI实践-Non-EDM模式