
快速学会一个机器学习算法:高斯混合模型
在数据科学和机器学习领域,聚类分析是一种重要的无监督学习方法,用于发现数据中的潜在模式和结构。高斯混合模型(Gaussian Mixture Model,简称GMM)作为一种强大的概率模型,在聚类分析中具有广泛的应用。本文将详细介绍GMM聚类的算法原理、数学基础,并通过一个案例分析展示其实际应用。
一、GMM算法简介
高斯混合模型(GMM)是一种基于概率的聚类方法,假设数据集由多个高斯分布(也称为“成分”或“簇”)混合生成。与K-Means等传统聚类算法不同,GMM不仅考虑簇的中心,还考虑簇的形状和大小,通过估计每个数据点属于各个簇的概率,实现更为灵活和准确的聚类效果。
GMM在以下场景中表现出色:
- 复杂数据分布:适用于簇形状不规则、大小不一的数据集。
- 软聚类:允许数据点属于多个簇,适用于模糊边界的聚类任务。
- 概率解释:提供每个数据点的聚类概率,有助于后续的统计分析和决策。
二、GMM算法原理
2.1 概率模型
2.2 期望最大化(EM)算法
GMM的参数估计通常采用期望最大化(Expectation-Maximization,EM)算法。EM算法是一种迭代优化方法,适用于含有隐含变量或不完全数据的概率模型。
EM算法包含两个主要步骤,E步(期望步)和M步(最大化步),反复迭代直到收敛。
- E步(Expectation Step)
在当前参数估计值 下,计算每个数据点属于第 个高斯成分的后验概率(即责任度):
- M步(Maximization Step)
基于E步计算的责任度,重新估计模型参数:
- 收敛条件
EM算法在每次迭代中都会增加似然函数的值,直到似然函数的增幅低于预设的阈值或达到最大迭代次数时停止。
三、案例分析:GMM聚类实战
本文将通过一个简单的案例,使用Python实现GMM聚类,展示其在模拟数据上的应用效果。
3.1 数据生成
首先,我们生成一个包含三个不同簇的二维数据集,每个簇的数据点服从不同的高斯分布。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
import matplotlib.patches as patches
import matplotlib.colors as mcolors
# 设置随机种子
np.random.seed(42)
# 生成三个高斯分布的样本
n_samples = 500
# 第一个簇
mean1 = [2, 0]
cov1 = [[1, 0.2], [0.2, 1]]
X1 = np.random.multivariate_normal(mean1, cov1, n_samples)
# 第二个簇
mean2 = [3, 3]
cov2 = [[1, -0.3], [-0.3, 1]]
X2 = np.random.multivariate_normal(mean2, cov2, n_samples)
# 第三个簇
mean3 = [0, 3]
cov3 = [[1, 0], [0, 1]]
X3 = np.random.multivariate_normal(mean3, cov3, n_samples)
# 合并数据
X = np.vstack((X1, X2, X3))
# 绘制原始数据点图
plt.figure(figsize=(10, 8))
plt.scatter(X1[:, 0], X1[:, 1], s=30, color='red', label='簇 1', alpha=0.6)
plt.scatter(X2[:, 0], X2[:, 1], s=30, color='green', label='簇 2', alpha=0.6)
plt.scatter(X3[:, 0], X3[:, 1], s=30, color='blue', label='簇 3', alpha=0.6)
plt.title('原始数据点分布', fontsize=16)
plt.xlabel('特征1', fontsize=14)
plt.ylabel('特征2', fontsize=14)
plt.legend(title='原始簇类别')
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()
原始数据散点图:
3.2 GMM聚类实现
使用Scikit-learn库中的GaussianMixture
类实现GMM聚类,并预测数据点的簇标签。
# 定义GMM模型,假设有3个簇
gmm = GaussianMixture(n_compnotallow=3, covariance_type='full', random_state=42)
# 拟合GMM模型
gmm.fit(X)
# 预测簇标签
labels = gmm.predict(X)
# 获取GMM的参数
weights = gmm.weights_
means = gmm.means_
covariances = gmm.covariances_
print("GMM混合权重:", weights)
print("GMM均值:\n", means)
print("GMM协方差矩阵:\n", covariances)
输出结果:
GMM混合权重: [0.34443739 0.3287613 0.32680131]
GMM均值:
[[ 2.95728907 3.11741938]
[-0.04155174 2.96571577]
[ 1.9849524 -0.00788892]]
GMM协方差矩阵:
[[[ 0.99816731 -0.25715754]
[-0.25715754 1.02762528]]
[[ 0.91888485 -0.01749 ]
[-0.01749 0.96829226]]
[[ 0.89100266 0.17762317]
[ 0.17762317 0.95128116]]]
3.3 结果可视化
绘制聚类结果和高斯分布的等高线,直观展示GMM的聚类效果。
# 定义颜色
colors = list(mcolors.TABLEAU_COLORS.values())
plt.figure(figsize=(12, 8))
# 绘制数据点
for i in range(gmm.n_components):
plt.scatter(X[labels == i, 0], X[labels == i, 1],
s=30, color=colors[i], label=f'簇 {i+1}', alpha=0.5)
# 绘制高斯分布的等高线
ax = plt.gca()
for i in range(gmm.n_components):
mean = means[i]
cov = covariances[i]
eigenvalues, eigenvectors = np.linalg.eigh(cov)
order = eigenvalues.argsort()[::-1]
eigenvalues, eigenvectors = eigenvalues[order], eigenvectors[:, order]
angle = np.degrees(np.arctan2(*eigenvectors[:,0][::-1]))
width, height = 2 * np.sqrt(eigenvalues)
ellipse = patches.Ellipse(mean, width, height, angle=angle,
edgecolor=colors[i], facecolor='none',
linewidth=3, linestyle='--')
ax.add_patch(ellipse)
plt.title('GMM聚类结果', fnotallow=16)
plt.xlabel('特征1', fnotallow=14)
plt.ylabel('特征2', fnotallow=14)
plt.legend(title='簇类别')
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()
GMM聚类结果图:
结果分析:
通过上述代码,我们生成了一个包含三个簇的二维数据集,并使用GMM进行聚类。结果显示,GMM能够准确地识别出数据中的三个簇,并通过等高线展示了各个高斯成分的分布情况。相比于K-Means,GMM在处理具有不同形状和大小的簇时表现出更高的灵活性和准确性。
四、总结
高斯混合模型(GMM)作为一种基于概率的聚类方法,能够有效地处理复杂数据分布和模糊边界的聚类任务。通过期望最大化(EM)算法,GMM能够迭代地估计模型参数,实现对数据的准确聚类。本文通过理论介绍和实战案例,展示了GMM在机器学习中的应用及其优势。尽管GMM在处理高维数据和选择适当的簇数时可能面临挑战,但其灵活性和概率解释能力使其成为聚类分析中不可或缺的工具。
在实际应用中,结合领域知识选择合适的模型参数和评估指标,可以进一步提升GMM的聚类效果。同时,结合其他机器学习方法,如降维技术和特征工程,可以增强GMM在复杂数据场景下的表现。
本文转载自宝宝数模AI,作者:BBSM
