使用高斯混合模型拆分多模態(tài)分布
本文介紹如何使用高斯混合模型將一維多模態(tài)分布拆分為多個(gè)分布。
高斯混合模型(Gaussian Mixture Models,簡稱GMM)是一種在統(tǒng)計(jì)和機(jī)器學(xué)習(xí)領(lǐng)域中常用的概率模型,用于對(duì)復(fù)雜數(shù)據(jù)分布進(jìn)行建模和分析。GMM 是一種生成模型,它假設(shè)觀測(cè)數(shù)據(jù)是由多個(gè)高斯分布組合而成的,每個(gè)高斯分布稱為一個(gè)分量,這些分量通過權(quán)重來控制其在數(shù)據(jù)中的貢獻(xiàn)。
生成具有多模態(tài)分布的數(shù)據(jù)
當(dāng)一個(gè)數(shù)據(jù)集顯示出多個(gè)不同的峰值或模態(tài)時(shí),通常會(huì)出現(xiàn)顯示出多個(gè)不同的峰值或模態(tài),每個(gè)模態(tài)代表分布中一個(gè)突出的數(shù)據(jù)點(diǎn)簇或集中。這些模式可以看作是數(shù)據(jù)值更可能出現(xiàn)的高密度區(qū)域。
我們將使用numpy生成的一維數(shù)組。
import numpy as np
dist_1 = np.random.normal(10, 3, 1000)
dist_2 = np.random.normal(30, 5, 4000)
dist_3 = np.random.normal(45, 6, 500)
multimodal_dist = np.concatenate((dist_1, dist_2, dist_3), axis=0)
讓我們把一維的數(shù)據(jù)分布形象化。
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.hist(multimodal_dist, bins=50, alpha=0.5)
plt.show()
使用高斯混合模型拆分多模態(tài)分布
下面我們將通過使用高斯混合模型計(jì)算每個(gè)分布的均值和標(biāo)準(zhǔn)差,將多模態(tài)分布分離回三個(gè)原始分布。高斯混合模型是一種可用于數(shù)據(jù)聚類的概率無監(jiān)督模型。它使用期望最大化算法估計(jì)密度區(qū)域。
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_compnotallow=3)
gmm.fit(multimodal_dist.reshape(-1, 1))
means = gmm.means_
# Conver covariance into Standard Deviation
standard_deviations = gmm.covariances_**0.5
# Useful when plotting the distributions later
weights = gmm.weights_
print(f"Means: {means}, Standard Deviations: {standard_deviations}")
#Means: [29.4, 10.0, 38.9], Standard Deviations: [4.6, 3.1, 7.9]
我們已經(jīng)得到了均值和標(biāo)準(zhǔn)差,可以對(duì)原始分布進(jìn)行建模??梢钥吹诫m然平均值和標(biāo)準(zhǔn)差可能不完全正確,但它們提供了一個(gè)接近的估計(jì)。
把我們的估計(jì)和原始數(shù)據(jù)比較一下。
from scipy.stats import norm
fig, axes = plt.subplots(nrows=3, ncols=1, sharex='col', figsize=(6.4, 7))
for bins, dist in zip([14, 34, 26], [dist_1, dist_2, dist_3]):
axes[0].hist(dist, bins=bins, alpha=0.5)
axes[1].hist(multimodal_dist, bins=50, alpha=0.5)
x = np.linspace(min(multimodal_dist), max(multimodal_dist), 100)
for mean, covariance, weight in zip(means, standard_deviations, weights):
pdf = weight*norm.pdf(x, mean, std)
plt.plot(x.reshape(-1, 1), pdf.reshape(-1, 1), alpha=0.5)
plt.show()
總結(jié)
高斯混合模型是一個(gè)強(qiáng)大的工具,可以用來對(duì)復(fù)雜的數(shù)據(jù)分布進(jìn)行建模和分析,同時(shí)也是許多機(jī)器學(xué)習(xí)算法的基礎(chǔ)之一。它的應(yīng)用范圍涵蓋了多個(gè)領(lǐng)域,能夠解決各種數(shù)據(jù)建模和分析的問題。
這種方法可以作為一種特征工程技術(shù)來估計(jì)輸入變量內(nèi)子分布的置信區(qū)間。