Collapsed Gibbs Sampler for LDA
이 포스트는 “Griffiths, Steyvers - Finding scientific topics, 2004”의 내용을 발췌독 및 정리한 것이다. 이 paper는 collapsed Gibbs sampler를 이용해 LDA의 inference를 보다 효과적으로 수행하는 방법을 제시한다.
Introduction
앞에서는 p(w|α,β)의 계산은 다루기 힘들기 때문에 variational Bayes 방법을 사용하여 model hyperparameter를 추정하였다. 여기서는 다소 다른 접근법을 소개한다. Collapsed Gibbs sampler를 이용한 LDA는 앞의 방법들처럼 β,θ를 추정해야 할 parameter로 보지 않고, multinomial-Dirichlet conjugacy를 이용하여 integrate out한다. 그 다음, 관측된 document w에 대한 topic variable z의 posterior distribution p(z|w)를 Gibbs sampling을 통해 추정한다. β,θ에 대한 추정치는 이 posterior로부터 얻을 수 있다.
그를 위해 topic이 주어졌을 때 word probability를 나타내는 β를 model hyperparameter로 보지 않고, prior를 부여하자. 그 때의 complete probability model은 다음과 같다. 붉은 글씨로 적힌 β가 원래 LDA에서는 model parameter였으며, 여기서는 Dirichlet prior를 새로 부여한 부분이다.
θd∼Dirichlet(α)d=1,⋯,Dzdn|θd∼Multinomial(θ)d=1,⋯,D,n=1,⋯,Ndβi∼Dirichlet(η)i=1,⋯,kwdn|zdn,β∼Multinomial(βzdn)d=1,⋯,D,n=1,⋯,Nd이제 우리의 hyperparameter는 α,η가 주어졌을 때, topic assignment zdn의 conditional을 구하자.
Deriving conditionals
어떤 한 document 내에서, d번째 document의 n번째 word에 대한 topic assignment zdn의 conditional은 다음과 같다.
p(zdn|Z−dn,W,α,η)=p(Z,W|α,η)p(Z−dn,W|α,η)=p(Z,W|α,η)p(Z−dn,W−dn|α,η)p(wdn|Z−dn,W−dn,α,η)∝p(Z,W|α,η)p(Z−dn,W−dn|α,η)(∵p(wdn|Z−dn,W−dn,α,η) is a constant with respect to zdn)분자인 p(Z,W|α,η)을 전개해보면 다음과 같다.
p(Z,W|α,η)=∫θ(1:D)∫β1:kp(Z,W,β1:k,θ(1:D)|α,η)dβ1:kdθ(1:D)=∫β1:kp(W|Z,β1:k)p(β1:k|η)dβ1:k∫θ(1:D)p(Z|θ(1:D))p(θ(1:D)|α)dθ(1:D)두 적분식을 각각 정리하면 다음과 같다.
1.∫β1:kp(W|Z,β1:k)p(β1:k|η)dβ1:k=∫β1:kp(W|Z,β1:k)k∏i=1p(βi|η)dβ1:k=∫β1:kD∏d=1p(wd|zd,β1:k)k∏i=1p(βi|η)dβ1:k=∫β1:k(D∏d=1k∏i=1V∏j=1β∑Ndn=1wjdnzidnij)(k∏i=1Dirichlet(βi;η))dβ1:k=∫β1:k(k∏i=1V∏j=1β∑Dd=1∑Ndn=1wjdnzidnij)k∏i=1(Γ(∑Vj=1ηj)∏Vj=1Γ(ηj)V∏j=1βηj−1ij)dβ1:kLet Ξi,j=D∑d=1Nd∑n=1wjdnzidn: counts of jth word in ith topic across all documents.=[Γ(∑Vj=1ηj)∏Vj=1Γ(ηj)]k∫β1:kk∏i=1V∏j=1β(ηj+Ξi,j)−1ijdβ1:k=[Γ(∑Vj=1ηj)∏Vj=1Γ(ηj)]kk∏i=1∏Vj=1Γ(ηj+Ξi,j)Γ(∑Vj=1ηj+Ξi,∙)=k∏i=1Beta(η+Ξi)Beta(η)(where Ξi=[Ξi,1,⋯,Ξi,V]T) 2.∫θ(1:D)p(Z|θ(1:D))p(θ(1:D)|α)dθ(1:D)=∫θ(1:D)D∏d=1p(zd|θ(d))p(θ(d)|α)dθ(1:D)=∫θ(1:D)(D∏d=1Nd∏n=1k∏i=1θzidndi)(D∏d=1Dirichlet(θd;α))dθ(1:D)=∫θ(1:D)(D∏d=1k∏i=1θ∑Ndn=1zidndi)D∏d=1(Γ(∑ki=1αi)∏ki=1Γ(αi)k∏i=1θαi−1di)dθ(1:D)=[Γ(∑ki=1αi)∏ki=1Γ(αi)]D∫θ(1:D)(D∏d=1k∏i=1θ(αi+∑Ndn=1zid′n)−1di)dθ(1:D)Let Ωd,i=Nd∑n=1zidn: the number of words of ith topic in dth document.=[Γ(∑ki=1αi)∏ki=1Γ(αi)]D∫θ(1:D)(D∏d=1k∏i=1θ(αi+Ωd,i)−1di)dθ(1:D)=[Γ(∑ki=1αi)∏ki=1Γ(αi)]DD∏d=1∏ki=1Γ(αi+Ωd,i)Γ(∑ki=1αi+Ωd,∙)=D∏d=1Beta(α+Ωd)Beta(α)(where Ωd=[Ωd,1,⋯,Ωd,k]T)따라서 위 식의 분자 p(Z,W|α,η)는 다음과 같다.
p(Z,W|α,η)=k∏i=1Beta(η+Ξi)Beta(η)D∏d′=1Beta(α+Ωd′)Beta(α)또한 같은 방법으로 분모인 p(Z−dn,W−dn|α,η)도 구할 수 있다.
p(Z−dn,W−dn|α,η)=k∏i=1Beta(η+Ξ−dni)Beta(η)D∏d′=1Beta(α+Ω−dnd′)Beta(α)Ξ−dni,j,Ω−dnd,i은 d번째 document의 n번째 word와 topic variable을 제외하고 구한 Ξi,j,Ωd,i이다. 이제 우리는 다음과 같이 conditional을 적을 수 있다.
p(zdn|Z−dn,W,α,η)∝p(Z,W|α,η)p(Z−dn,W−dn|α,η)=∏ki=1Beta(η+Ξi)Beta(η)∏Dd′=1Beta(α+Ωd′)Beta(α)∏ki=1Beta(η+Ξ−dni)Beta(η)∏Dd′=1Beta(α+Ω−dnd′)Beta(α)=k∏i=1Beta(η+Ξi)Beta(η+Ξ−dni)D∏d′=1Beta(α+Ωd′)Beta(α+Ω−dnd′)Conditional을 구하고자 하는 ˜d번째 document의 n번째 word가 ˜j이고, 이 word는 topic ˜i에서 생성되었다고 하자. Ξ−dni,j,Ω−dnd,i은 한 word(observation)을 제외하고 구한 Ξi,j,Ωd,i이므로, 다음이 만족한다.
If i=˜i,j=˜j,Ξ−˜dni,j=Ξi,j−1.else same.If d=˜d,i=˜i,Ω−˜dnd,i=Ωd,i−1.else same.이제 첫 번째 multiplicant를 구하자.
k∏i=1Beta(η+Ξi)Beta(η+Ξ−˜dni)=Beta(η+Ξ˜i)Beta(η+Ξ−˜dn˜i)=Γ(∑Vj=1ηj+Ξ−˜dn˜i,∙)Γ(∑Vj=1ηj+Ξ˜i,∙)∏Vj=1Γ(ηj+Ξ˜i,j)∏Vj=1Γ(ηj+Ξ−˜dn˜i,j)=1∑Vj=1ηj+Ξ−˜dn˜i,∙η˜j+Ξ−˜dn˜i,˜j1=η˜j+Ξ−˜dn˜i,˜j∑Vj=1(ηj+Ξ−˜dn˜i,j)두 번째 multiplicant는 다음과 같이 구한다.
D∏d=1Beta(α+Ωd)Beta(α+Ω−˜dnd)=Beta(α+Ω˜d)Beta(α+Ω−˜dn˜d)=Γ(∑ki=1αi+Ω−˜dn˜d,∙)Γ(∑ki=1αi+Ω˜d,∙)∏ki=1Γ(αi+Ω˜d,i)∏ki=1Γ(αi+Ω−˜dn˜d,i)=1∑ki=1αi+Ω−˜dn˜d,∙α˜i+Ω−˜dn˜d,˜i1=α˜i+Ω−˜dn˜d,˜i∑ki=1(αi+Ω−˜dn˜d,i)따라서 도출한 conditional은 다음과 같다.
p(z˜dn=˜i|Z−˜dn,W,α,η)∝p(Z,W|α,η)p(Z−˜dn,W−˜dn|α,η)=k∏i=1Beta(η+Ξi)Beta(η+Ξ−˜dni)D∏d=1Beta(α+Ωd)Beta(α+Ω−˜dnd)=η˜j+Ξ−˜dn˜i,˜j∑Vj=1(ηj+Ξ−˜dn˜i,j)α˜i+Ω−˜dn˜d,˜i∑ki=1(αi+Ω−˜dn˜d,i)Algorithm
도출한 conditional을 가지고 Gibbs sampling을 수행하는 것은 어렵지 않다. 눈여겨볼 부분은 conditional 식을 알고리즘에 적용하는데 필요한 것은 Ξ−dni,j,Ω−dnd,i, 즉 ‘i 번째 topic에 word j가 지정된 횟수’와 ‘d 번째 document 내에서 i 번째 topic으로부터 생성된 단어의 수‘뿐이다. Collapsed Gibbs sampling을 이용한 LDA의 수행 과정은 다음과 같다.
1. Initialization
-
Count variable, Ξi,j,Ξi,∙,Ωd,i,Ωd,∙을 0으로 설정.
-
for all documents d=1,⋯,D, do
- for all words n=1,⋯,Nd, do
- Sample topic variable zdn∼Multinomial(1k,⋯,1k).
- Increment : Ξzdn,j=Ξzdn,j+1.
- Increment : Ξzdn,∙=Ξzdn,∙+1.
- Increment : Ωd,zdn=Ωd,zdn+1.
- Increment : Ωd,∙=Ωd,∙+1.
- end for
- for all words n=1,⋯,Nd, do
-
end for
2. Run Gibbs sampling
- while not converged, do
- for all documents d=1,⋯,D, do
- for all words n=1,⋯,Nd, do
- Current topic assignment of k for word wdn=j,
- Decrement : Ξk,j=Ξk,j−1.
- Decrement : Ξk,∙=Ξk,∙−1.
- Decrement : Ωd,k=Ωd,k−1.
- Decrement : Ωd,∙=Ωd,∙−1.
- Sample ˜k=zdn∼p(zdn|Z−dn,W,α,η) with decremented Ξi,j,Ξi,∙,Ωd,i,Ωd,∙.
- Increment : Ξ˜k,j=Ξ˜k,j+1.
- Increment : Ξ˜k,∙=Ξ˜k,∙+1.
- Increment : Ωd,˜k=Ωd,˜k+1.
- Increment : Ωd,∙=Ωd,∙+1.
- end for
- for all words n=1,⋯,Nd, do
- end for
- for all documents d=1,⋯,D, do
- end while