2009年5月12日火曜日

変分ベイズで混合正規分布推定(Rで実装)

以前のエントリの先駆けとして簡単な実装を。
初期値と推定結果はこんな感じになる。

初期値と

推定結果

参考文献は上田さんの学会誌

 上田修功,ベイズ学習[IV・完] : 変分ベイズ学習の応用例

基本的にこのとおりに実装すれば動く。素晴らしすぎる。
他に参考文献はこの辺り

 荒木佑季,変分ベイズ学習による混合正規分布推定

ソースは続きから。

Rでのソースはこんな感じ。
2変量正規分布での実装である。
基本的に参考文献通りに計算をしているだけ。
参考文献のnotationでΣが分散共分散行列でない点に注意。

ちなみに、plot.gaussは2次元プロット上に正規分布をplotするための関数。
パッケージにもあった気がするが使い勝手が悪く実装した。


vb<-function(x,K=10,maxItr=100,cirCol="red",...){
source("plot.gauss.r")

#initialize
N <- dim(x)[1]
d <- dim(x)[2]
phi_0 <- N / K
xi_0 <- 1.0
eta_0 <- d + 2
nu_0 <- c(mean(x[,1]),mean(x[,2]))
SS_0 <- matrix(c(var(x[,1]),0,0,var(x[,2])),2,2)
B_0 <- SS_0

x_h<-cbind(rep(0,K),rep(0,K)) #R^K*2
N_h <- rep(N/K,length=K)
phi <- rep(phi_0,length=K)
eta <- rep(eta_0,length=K)
p<-rep(1/K,length=K)
f <- eta_0 + N_h + 1 - d #R^K
mu <- list()
B <- list()
sigma <- list()
C_h <- list()
S <- list() #precision matrix
SS <- list() #variance covariance matrix
for(k in 1:K){
mu[[k]]<-c(rnorm(1,nu_0[1],2),rnorm(1,nu_0[2],2))
B[[k]]<-B_0
sigma[[k]] <- B[[k]]/(f[k] * (N_h[k] + xi_0))
C_h[[k]]<-matrix(0,2,2)
S[[k]] <- (eta_0 + N_h[k]) * solve(B[[k]])
SS[[k]] <- solve(S[[k]])
}
#print(list("N_h"=N_h,"phi"=phi,"mu"=mu,"eta"=eta,"f"=f,"B"=B,"sigma"=sigma,"S"=S,"SS"=SS))

#plot input data
#x11()
plot(x,main="Training data & initial points",...)
plot.gauss(p,mu,SS,add=T,col=cirCol,xlab="",ylab="",xaxt="n",yaxt="n",...)



#########################################################

for(t in 1:maxItr){
cat("Iteration ",t,"\n")
#VB-Estep
A0 <- digamma(phi_0 + N_h) - digamma(K*phi_0 + sum(N_h))
A1 <-0
for(j in 1:d){
A1 <- A1 + digamma((eta_0+N_h+1-j)/2)
}

gamma<-matrix(0,N,K)
for(k in 1:K){
for(n in 1:N){
gamma[n,k] <- A0[k] + 1/2 * A1[k] - 1/2 * log(det(B[[k]])) - 1/2 * sum(diag( (eta_0 + N_h[k]) * solve(B[[k]]) %*% ( f[k]/(f[k]-2) * sigma[[k]] + matrix(x[n,]-mu[[k]]) %*% t(matrix(x[n,]-mu[[k]])) ) ))
}
}
z0<-exp(gamma)
z1<-rep(0,N)
for(n in 1:N){z1[n]<-sum(z0[n,])}
z_h<-z0/z1 #R^N*K

#VB-Mstep
for(k in 1:K){
#N_h and x_h must calc at first
N_h[k] <- sum(z_h[,k]) #R^K
x_h[k,] <- c((sum(z_h[,k]*x[,1]) / N_h[k]) , (sum(z_h[,k]*x[,2]) / N_h[k])) #R^K*2

C_h[[k]]<-matrix(0,2,2)
for(n in 1:N){
C_h[[k]] <- C_h[[k]] + z_h[n,k] * matrix(x[n,] - x_h[k,]) %*% t(matrix(x[n,] - x_h[k,]))
}

mu[[k]] <- (N_h[k] * x_h[k,] + xi_0 * nu_0) / (N_h[k] + xi_0)
B[[k]] <- B_0 + C_h[[k]] + (N_h[k] * xi_0)/(N_h[k] + xi_0) * matrix(x_h[k,] - nu_0) %*% t(matrix(x_h[k,] - nu_0))
}

phi <- phi_0 + N_h
eta <- eta_0 + N_h
f <- eta + 1 - d

for(k in 1:K){
sigma[[k]] <- B[[k]]/(f[k] * (N_h[k] + xi_0))
S[[k]] <- (eta_0 + N_h[k]) * solve(B[[k]])
SS[[k]] <- solve(S[[k]])
}
#print(list("N_h"=N_h,"x_h"=x_h,"phi"=phi,"mu"=mu,"eta"=eta,"f"=f,"C_h"=C_h,"B"=B,"sigma"=sigma,"S"=S,"SS"=SS))
}

#plot estimates
pz<-K*phi_0 + sum(N_h)
for(k in 1:K){
p[k]<-(phi_0+N_h[k])/pz
}
print(list("Mixing parameter"=p))

#quartz()
#x11()
quartz()
plot(x,main="Estimated value (VB)",...)
plot.gauss(p,mu,SS,add=T,col=cirCol,xlab="",ylab="",xaxt="n",yaxt="n",...)

return(list("N_h"=N_h,"x_h"=x_h,"phi"=phi,"mu"=mu,"eta"=eta,"f"=f,"C_h"=C_h,"B"=B,"sigma"=sigma,"S"=S,"SS"=SS,"weights"=p))
}



ちなみに、このエントリを記載中にVB for GMMのパッケージを発見

 vabayelMix

マジかよ、、、
パフォーマンスチェックをした上で、分があればちゃんと公開しよう。

1 件のコメント:

Unknown さんのコメント...
このコメントは投稿者によって削除されました。

コメントを投稿