雷锋网(公众号:雷锋网)按:本文作者贾志刚,原文载于作者个人博客,雷锋网已获授权。
一. 概述
高斯混合模型(GMM)在图像分割、对象识别、视频分析等方面均有应用,对于任意给定的数据样本集合,根据其分布概率, 可以计算每个样本数据向量的概率分布,从而根据概率分布对其进行分类,但是这些概率分布是混合在一起的,要从中分离出单个样本的概率分布就实现了样本数据聚类,而概率分布描述我们可以使用高斯函数实现,这个就是高斯混合模型-GMM。
这种方法也称为D-EM即基于距离的期望最大化。
二. 算法步骤
1. 初始化变量定义-指定的聚类数目K与数据维度D
2. 初始化均值、协方差、先验概率分布
3. 迭代E-M步骤
– E步计算期望
– M步更新均值、协方差、先验概率分布
-检测是否达到停止条件(最大迭代次数与最小误差满足),达到则退出迭代,否则继续E-M步骤
4. 打印最终分类结果
三. 代码实现
package com.gloomyfish.image.gmm;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
*
* @author gloomy fish
*
*/
public class GMMProcessor {
public final static double MIN_VAR = 1E-10;
public static double[] samples = new double[]{10, 9, 4, 23, 13, 16, 5, 90, 100, 80, 55, 67, 8, 93, 47, 86, 3};
private int dimNum;
private int mixNum;
private double[] weights;
private double[][] m_means;
private double[][] m_vars;
private double[] m_minVars;
/***
*
* @param m_dimNum – 每个样本数据的维度, 对于图像每个像素点来说是RGB三个向量
* @param m_mixNum – 需要分割为几个部分,即高斯混合模型中高斯模型的个数
*/
public GMMProcessor(int m_dimNum, int m_mixNum) {
dimNum = m_dimNum;
mixNum = m_mixNum;
weights = new double[mixNum];
m_means = new double[mixNum][dimNum];
m_vars = new double[mixNum][dimNum];
m_minVars = new double[dimNum];
}
/***
* data – 需要处理的数据
* @param data
*/
public void process(double[] data) {
int m_maxIterNum = 100;
double err = 0.001;
boolean loop = true;
double iterNum = 0;
double lastL = 0;
double currL = 0;
int unchanged = 0;
initParameters(data);
int size = data.length;
double[] x = new double[dimNum];
double[][] next_means = new double[mixNum][dimNum];
double[] next_weights = new double[mixNum];
double[][] next_vars = new double[mixNum][dimNum];
List<DataNode> cList = new ArrayList<DataNode>();
while(loop) {
Arrays.fill(next_weights, 0);
cList.clear();
for(int i=0; i<mixNum; i++) {
Arrays.fill(next_means[i], 0);
Arrays.fill(next_vars[i], 0);
}
lastL = currL;
currL = 0;
for (int k = 0; k < size; k++)
{
for(int j=0;j<dimNum;j++)
x[j]=data[k*dimNum+j];
double p = getProbability(x); // 总的概率密度分布
DataNode dn = new DataNode(x);
dn.index = k;
cList.add(dn);
double maxp = 0;
for (int j = 0; j < mixNum; j++)
{
double pj = getProbability(x, j) * weights[j] / p; // 每个分类的概率密度分布百分比
if(maxp < pj) {
maxp = pj;
dn.cindex = j;
}
next_weights[j] += pj; // 得到后验概率
for (int d = 0; d < dimNum; d++)
{
next_means[j][d] += pj * x[d];
next_vars[j][d] += pj* x[d] * x[d];
}
}
currL += (p > 1E-20) ? Math.log10(p) : -20;
}
currL /= size;
// Re-estimation: generate new weight, means and variances.
for (int j = 0; j < mixNum; j++)
{
weights[j] = next_weights[j] / size;
if (weights[j] > 0)
{
for (int d = 0; d < dimNum; d++)
{
m_means[j][d] = next_means[j][d] / next_weights[j];
m_vars[j][d] = next_vars[j][d] / next_weights[j] – m_means[j][d] * m_means[j][d];
if (m_vars[j][d] < m_minVars[d])
{
m_vars[j][d] = m_minVars[d];
}
}
}
}
// Terminal conditions
iterNum++;
if (Math.abs(currL – lastL) < err * Math.abs(lastL))
{
unchanged++;
}
if (iterNum >= m_maxIterNum || unchanged >= 3)
{
loop = false;
}
}
// print result
System.out.println("=================最终结果=================");
for(int i=0; i<mixNum; i++) {
for(int k=0; k<dimNum; k++) {
System.out.println("[" + i + "]: ");
System.out.println("means : " + m_means[i][k]);
System.out.println("var : " + m_vars[i][k]);
System.out.println();
}
}
// 获取分类
for(int i=0; i<size; i++) {
System.out.println("data[" + i + "]=" + data[i] + " cindex : " + cList.get(i).cindex);
}
}
/**
*
* @param data
*/
private void initParameters(double[] data) {
// 随机方法初始化均值
int size = data.length;
for (int i = 0; i < mixNum; i++)
{
for (int d = 0; d < dimNum; d++)
{
m_means[i][d] = data[(int)(Math.random()*size)];
}
}
// 根据均值获取分类
int[] types = new int[size];
for (int k = 0; k < size; k++)
{
double max = 0;
for (int i = 0; i < mixNum; i++)
{
double v = 0;
for(int j=0;j<dimNum;j++) {
v += Math.abs(data[k*dimNum+j] – m_means[i][j]);
}
if(v > max) {
max = v;
types[k] = i;
}
}
}
double[] counts = new double[mixNum];
for(int i=0; i<types.length; i++) {
counts[types[i]]++;
}
// 计算先验概率权重
for (int i = 0; i < mixNum; i++)
{
weights[i] = counts[i] / size;
}
// 计算每个分类的方差
int label = -1;
int[] Label = new int[size];
double[] overMeans = new double[dimNum];
double[] x = new double[dimNum];
for (int i = 0; i < size; i++)
{
for(int j=0;j<dimNum;j++)
x[j]=data[i*dimNum+j];
label=Label[i];
// Count each Gaussian
counts[label]++;
for (int d = 0; d < dimNum; d++)
{
m_vars[label][d] += (x[d] – m_means[types[i]][d]) * (x[d] – m_means[types[i]][d]);
}
// Count the overall mean and variance.
for (int d = 0; d < dimNum; d++)
{
overMeans[d] += x[d];
m_minVars[d] += x[d] * x[d];
}
}
// Compute the overall variance (* 0.01) as the minimum variance.
for (int d = 0; d < dimNum; d++)
{
overMeans[d] /= size;
m_minVars[d] = Math.max(MIN_VAR, 0.01 * (m_minVars[d] / size – overMeans[d] * overMeans[d]));
}
// Initialize each Gaussian.
for (int i = 0; i < mixNum; i++)
{
if (weights[i] > 0)
{
for (int d = 0; d < dimNum; d++)
{
m_vars[i][d] = m_vars[i][d] / counts[i];
// A minimum variance for each dimension is required.
if (m_vars[i][d] < m_minVars[d])
{
m_vars[i][d] = m_minVars[d];
}
}
}
}
System.out.println("=================初始化=================");
for(int i=0; i<mixNum; i++) {
for(int k=0; k<dimNum; k++) {
System.out.println("[" + i + "]: ");
System.out.println("means : " + m_means[i][k]);
System.out.println("var : " + m_vars[i][k]);
System.out.println();
}
}
}
/***
*
* @param sample – 采样数据点
* @return 该点总概率密度分布可能性
*/
public double getProbability(double[] sample)
{
double p = 0;
for (int i = 0; i < mixNum; i++)
{
p += weights[i] * getProbability(sample, i);
}
return p;
}
/**
* Gaussian Model -> PDF
* @param x – 表示采样数据点向量
* @param j – 表示对对应的第J个分类的概率密度分布
* @return – 返回概率密度分布可能性值
*/
public double getProbability(double[] x, int j)
{
double p = 1;
for (int d = 0; d < dimNum; d++)
{
p *= 1 / Math.sqrt(2 * 3.14159 * m_vars[j][d]);
p *= Math.exp(-0.5 * (x[d] – m_means[j][d]) * (x[d] – m_means[j][d]) / m_vars[j][d]);
}
return p;
}
public static void main(String[] args) {
GMMProcessor filter = new GMMProcessor(1, 2);
filter.process(samples);
}
}
结构类DataNode
package com.gloomyfish.image.gmm;
public class DataNode {
public int cindex; // cluster
public int index;
public double[] value;
public DataNode(double[] v) {
this.value = v;
cindex = -1;
index = -1;
}
}
四. 结果
这里初始中心均值的方法我是通过随机数来实现,GMM算法运行结果跟初始化有很大关系,常见初始化中心点的方法是通过K-Means来计算出中心点。大家可以尝试修改代码基于K-Means初始化参数,我之所以选择随机参数初始,主要是为了省事!
雷锋网相关阅读:
25 行 Python 代码实现人脸检测——OpenCV 技术教程
手把手教你如何用 OpenCV + Python 实现人脸识别
深度学习之神经网络特训班
20年清华大学神经网络授课导师邓志东教授,带你系统学习人工智能之神经网络理论及应用!
课程链接:http://www.mooc.ai/course/65
加入AI慕课学院人工智能学习交流QQ群:624413030,与AI同行一起交流成长
雷锋网版权文章,未经授权禁止转载。详情见。
原创文章,作者:ItWorker,如若转载,请注明出处:https://blog.ytso.com/85680.html