期望最大化技术的直观解释是什么?[关闭]


109

期望最大化(EM)是一种对数据进行分类的概率方法。如果不是分类器,如果我错了,请纠正我。

这种EM技术的直观解释是什么?expectation这里是什么,现在是maximized什么?


12
期望最大化算法是什么?自然生物技术 26,897-899(2008)有一个漂亮的图片,说明如何算法的工作。
chl 2012年

@chl在漂亮图片的b部分中,他们如何获得Z上概率分布的值(即0.45xA,0.55xB等)?
Noob Saibot 2013年

3
你可以看一下这个问题math.stackexchange.com/questions/25111/...
V4R

3
更新了指向 @chl提到的图片的链接
n1k31t4 '18年

Answers:


119

注意:此答案背后的代码可在此处找到。


假设我们从两个不同的组(红色和蓝色)中采样了一些数据:

在此处输入图片说明

在这里,我们可以看到哪个数据点属于红色或蓝色组。这样可以轻松找到表征每个组的参数。例如,红色组的均值约为3,蓝色组的均值约为7(如果需要,我们可以找到确切的均值)。

一般来说,这被称为最大似然估计。给定一些数据,我们将计算最能解释该数据的一个或多个参数的值。

现在想象一下,我们无法看到其值从组的抽样。一切对我们来说都是紫色的:

在此处输入图片说明

在这里,我们有知识,有2个值的组,但我们不知道任何特定值属于哪个组。

我们是否仍可以估计最适合此数据的红色和蓝色组的平均值?

是的,通常我们可以!期望最大化为我们提供了一种实现方法。该算法背后的非常笼统的想法是:

  1. 从每个参数可能的初始估算开始。
  2. 计算每个参数产生数据点的可能性
  3. 根据参数产生的可能性,计算每个数据点的权重,以指示它是红色还是蓝色。将权重与数据(期望值)相结合。
  4. 使用权重调整的数据为参数计算更好的估计值(最大化)。
  5. 重复步骤2到4,直到参数估计收敛为止(该过程停止产生其他估计)。

这些步骤需要进一步说明,因此我将逐步解决上述问题。

示例:估计均值和标准差

在本示例中,我将使用Python,但是如果您不熟悉这种语言,则代码应该很容易理解。

假设我们有两组,红色和蓝色,其值的分布如上图所示。具体来说,每个组均包含一个从正态分布中提取的值,其中包含以下参数:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

这再次是这些红色和蓝色组的图像(以免您不得不向上滚动):

在此处输入图片说明

当我们看到每个点的颜色(即它属于哪个组)时,很容易估计每个组的均值和标准差。我们只是将红色和蓝色值传递给NumPy中的内置函数。例如:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

但是,如果我们有什么不能看的点的颜色?也就是说,每个点都用紫色代替了红色或蓝色。

要尝试恢复红色和蓝色组的均值和标准差参数,我们可以使用期望最大化。

我们的第一步(上面的步骤1)是猜测每组平均值和标准偏差的参数值。我们不必聪明地猜测;我们可以选择任何我们喜欢的数字:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

这些参数估计会产生钟形曲线,如下所示:

在此处输入图片说明

这些是错误的估计。例如,对于有意义的点组,这两种方式(垂直虚线)都远离任何类型的“中间”。我们希望改善这些估计。

下一步(步骤2)是计算每个数据点在当前参数猜测下出现的可能性:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

在这里,我们使用红色和蓝色的均值和标准差的当前猜测,将每个数据点简单地放入正态分布的概率密度函数中。这就告诉我们,例如,我们目前猜测在1.761的数据点是多少更有可能被红色(0.189),比蓝(0.00003)。

对于每个数据点,我们可以将这两个似然值转换为权重(步骤3),以使它们的总和为1,如下所示:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

使用我们当前的估计值和新计算的权重,我们现在可以计算红色和蓝色组的均值和标准差的估计值(步骤4)。

我们使用所有数据点两次计算平均值和标准偏差,但是使用不同的权重:一次用于红色权重,一次用于蓝色权重。

直觉的关键在于,颜色在数据点上的权重越大,数据点对该颜色参数的下一个估计值的影响就越大。这具有沿正确方向“拉动”参数的效果。

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

我们对参数有新的估计。为了再次改善它们,我们可以跳回到步骤2并重复该过程。我们这样做直到估计收敛为止,或者在执行了一些迭代之后(步骤5)。

对于我们的数据,此过程的前五个迭代如下所示(最近的迭代具有更强的外观):

在此处输入图片说明

我们看到均值已经在一些值上收敛,并且曲线的形状(由标准偏差控制)也变得更加稳定。

如果继续进行20次迭代,最终结果如下:

在此处输入图片说明

EM过程已收敛到以下值,结果非常接近实际值(在这里我们可以看到颜色-没有隐藏变量):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

在上面的代码中,您可能已经注意到,新的标准差估算是使用先前迭代的平均值估算来计算的。最终,我们是否首先计算平均值就没有关系,因为我们只是在某个中心点附近找到值的(加权)方差。我们仍将看到参数估计值收敛。


如果我们什至不知道来自何处的正态分布数呢?在这里,您以k = 2分布为例,我们还可以估计k以及k个参数集吗?
stackit

1
@stackit:在这种情况下,我不确定是否存在一种简单直接的通用方法来计算最有可能的k值作为EM过程的一部分。主要问题是,我们需要使用要查找的每个参数的估计值来启动EM,这意味着我们需要在开始之前知道/估计k。但是,这里可以通过EM估计属于一个组的点的比例。也许如果我们高估了k,那么除两组以外的所有组的比例都将降至接近零。我尚未对此进行试验,所以我不知道它在实践中的效果如何。
亚历克斯·赖利

1
@AlexRiley您能多说一些有关计算新的均值和标准差估计值的公式吗?
柠檬

2
@AlexRiley感谢您的解释。为什么要使用旧的均值猜测来计算新的标准偏差估算值?如果首先找到均值的新估计怎么办?
GoodDeeds '18

1
@Lemon GoodDeeds Kaushal-对您对我的问题的最新答复深表歉意。我试图编辑答案以解决您提出的问题。我还可以在此处的笔记本中访问此答案中使用的所有代码 (还包括我所涉及的一些点的更多详细说明)。
Alex Riley

36

EM是一种算法,用于在模型中的某些变量未被观察到时(即,当您具有潜在变量时)最大化似然函数。

您可能会问,如果我们只是试图最大化功能,为什么不仅仅使用现有的机器来最大化功能。好吧,如果您尝试通过求导数并将其设置为零来最大化此值,则会发现在许多情况下,一阶条件没有解。有一个鸡与蛋的问题,要解决模型参数,您需要知道未观察到的数据的分布。但是未观察到的数据的分布是模型参数的函数。

EM尝试通过反复猜测未观察到的数据的分布来解决此问题,然后通过最大化实际似然函数下界的值来估计模型参数,并重复进行直到收敛:

EM算法

从猜测模型参数的值开始

E-步骤:对于每个具有缺失值的数据点,使用模型方程式,根据给定的模型参数当前猜测值和观察到的数据,求解缺失数据的分布(请注意,您正在求解每个缺失的分布值,而不是预期值)。现在我们有了每个缺失值的分布,我们可以针对未观察到的变量计算似然函数的期望值。如果我们对模型参数的猜测是正确的,则该预期可能性将是我们观测数据的实际可能性;如果参数不正确,将只是一个下限。

M步:现在,我们有了一个预期似然函数,其中没有不可观测的变量,可以像在完全观测的情况下那样最大化该函数,以获取模型参数的新估计。

重复直到收敛。


5
我不理解您的E-step。问题的一部分是,在我学习这些东西时,找不到使用相同术语的人。那么,模型方程是什么意思呢?我不知道求解概率分布是什么意思?
user678392 2013年

27

这是理解“期望最大化”算法的简单明了的配方:

1-阅读Do和Batzoglou撰写的EM教程文章

2-您可能脑子里有问号,请看一下此数学堆栈交换页面上的说明。

3-看一下我用Python写的这段代码,它解释了项目1的EM教程纸中的示例:

警告:由于我不是Python开发人员,因此代码可能太乱/不够理想。但这确实起作用。

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

我发现您的程序将A和B的结果都设为0.66,我也使用scala对其进行了实现,还发现结果为0.66,您可以帮忙检查一下吗?
zjffdu

使用电子表格,如果我的初始猜测相等,我只会找到0.66的结果。否则,我可以复制本教程的输出。
soakley

@zjffdu,在返回0.66之前,EM运行了多少次迭代?如果使用相等的值进行初始化,则可能会陷入局部最大值,并且您会发现迭代次数非常少(因为没有任何改进)。
Zhubarb 2013年

您也可以查看这张幻灯片安德鲁·伍(Andrew Ng)和哈佛的课程笔记
Minh Phan,

16

从技术上讲,“ EM”一词的含义不太明确,但是我认为您是指高斯混合建模聚类分析技术,它是通用EM原理的一个实例

实际上,EM聚类分析不是分类器。我知道有人认为聚类是“无监督分类”,但实际上聚类分析是完全不同的。

关键区别以及人们对聚类分析经常会产生的重大误解是:在聚类分析中,没有“正确的解决方案”。这是一种知识发现方法,实际上是要找到新的东西!这使得评估非常棘手。通常使用已知的分类作为参考进行评估,但这并不总是恰当的:您拥有的分类可能会也可能不会反映数据中的内容。

让我举一个例子:您拥有大量的客户数据集,包括性别数据。将数据集与现有类进行比较时,将数据集分为“男性”和“女性”的方法是最佳的。以“预测”的方式认为这很好,因为对于新用户,您现在可以预测其性别。用“知识发现”的方式认为这实际上是不好的,因为您想发现数据中的一些新结构。但是,例如将数据分为老人和孩子的方法,相对于男性/女性类别,其得分可能会更低。但是,这将是一个很好的聚类结果(如果没有给出年龄)。

现在回到EM。从本质上讲,它假设您的数据由多个多元正态分布组成(请注意,这是一个非常强的假设,尤其是当您确定簇数时!)。然后尝试通过交替改进模型和对象对模型的分配来为此找到一个局部最优模型

为了在分类上下文中获得最佳结果,请选择大于类数的聚类数,甚至将聚类仅应用于单个类(以查明该类中是否存在某些结构!)。

假设您想训练一个分类器来区分“汽车”,“自行车”和“卡车”。假设数据完全由3个正态分布组成几乎没有用。但是,您可能会假设有不止一种类型的汽车(以及卡车和自行车)。因此,您不必为这三个类别训练分类器,而是将汽车,卡车和自行车分别分组为10个类别(或者也许是10辆汽车,3辆卡车和3个自行车,无论如何),然后训练一个分类器来区分这30个类别,然后将类结果合并回原始类。您可能还会发现,存在一个很难分类的集群,例如Trikes。他们有点像汽车,有些像自行车。还是送货卡车,它比卡车更像是超大型汽车。


EM的规格不足吗?
sam boosalis

它有多个版本。从技术上讲,您也可以将Lloyd风格的k均值称为“ EM”。您需要指定使用哪种模型
已退出-Anony-Mousse 2013年

2

其他答案很好,我将尝试提供另一种观点并解决问题的直觉部分。

EM(期望最大化)算法是使用对偶性的一类迭代算法的变体

摘录(强调我的):

在数学中,对偶性通常是(但不总是)通过对合运算以一对一的方式将概念,定理或数学结构转换为其他概念,定理或结构: A是B,则B的对偶是A。此类对合有时具有固定点,因此A的对偶是A本身

通常,对象 A的对 B 在某种程度上与A有关,并且保持了某些对称性或兼容性。例如AB = const

采用对偶性(在先前的意义上)的迭代算法示例如下:

  1. 最大公除数的欧几里得算法及其变体
  2. Gram–Schmidt向量基础算法和变体
  3. 算术均值-几何均值不等式及其变体
  4. 期望最大化算法及其变体有关信息几何视图,另请参见此处
  5. (..其他类似的算法..)

以类似的方式,EM算法也可以看作是两个双重最大化步骤

.. [EM]被视为使参数和未观察变量的分布的联合函数最大化。E步相对于未观察变量的分布使此函数最大化。关于参数的M步

在使用对偶的迭代算法中,存在一个平衡(或固定)收敛点的显式(或隐式)假设(对于EM,这使用Jensen不等式证明)

因此,此类算法的概述为:

  1. 类似于E的步骤:针对给定y找到最佳解x保持恒定,。
  2. 类似于M的步骤(对偶):相对于x(如上一步中计算的)保持恒定,找到最佳解y
  3. “终止/收敛准则”步骤:使用更新的xy重复步骤1、2,直到收敛(或达到指定的迭代次数)为止。

注意,当这样的算法收敛到(全局)最优时,它已经找到了在两种意义上最佳的配置(即,在x域/参数和y域/参数中)。然而,该算法只能找到局部最优而不是全局最优。

我会说这是算法概述的直观描述

对于统计参数和应用,其他答案也给出了很好的解释(请参阅此答案中的参考文献)


2

接受的答案引用了Chuong EM论文,该论文在解释EM方面做得不错。还有一个youtube视频,对本文进行了更详细的说明。

回顾一下,这是场景:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

对于第一个试验的问题,直觉上我们认为B产生了它,因为正面的比例非常匹配B的偏见...但是该值只是一个猜测,所以我们不能确定。

考虑到这一点,我喜欢这样考虑EM解决方案:

  • 每次翻盖试验都会“投票”它最喜欢的硬币
    • 这取决于每种硬币适合其分布的程度
    • 或者,从硬币的角度来看,相对于另一枚硬币(基于对数似然性),人们对该试验的期望很高。
  • 根据每个试验对每个硬币的喜欢程度,它可以更新对该硬币参数(偏差)的猜测。
    • 审判越喜欢硬币,它就越能更新硬币的偏见以反映其自身!
    • 从本质上讲,通过在所有试验中组合这些加权更新来更新硬币的偏差,该过程称为(最大化),是指尝试通过一组试验获得每种硬币偏差的最佳猜测。

这可能是过于简单化(甚至在某些层面上根本上是错误的),但是我希望这在直观的层面上有所帮助!


1

EM用于最大化具有潜在变量Z的模型Q的可能性。

这是一个迭代优化。

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-步骤:给定当前的Z估计值,计算预期的对数似然函数

m步:找到使该Q最大化的theta

GMM示例:

e-step:在当前gmm参数估计的情况下,估计每个数据点的标签分配

m-步骤:根据新标签分配最大化新theta

K-means也是一种EM算法,关于K-means的动画有很多解释。


1

使用在Zubarb的答案中引用的Do和Batzoglou的同一篇文章,我在Java中针对该问题实现了EM。对他的回答的评论表明,算法陷入局部最优状态,如果参数thetaA和thetaB相同,则在我的实现中也会发生。

以下是我的代码的标准输出,显示了参数的收敛性。

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

下面是我的EM的Java实现,以解决(Do和Batzoglou,2008)中的问题。实现的核心部分是运行EM直到参数收敛的循环。

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

以下是整个代码。

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.