Collapsed Gibbs Sampler for LDA

2019-09-22

이 포스트는 “Griffiths, Steyvers - Finding scientific topics, 2004”의 내용을 발췌독 및 정리한 것이다. 이 paper는 collapsed Gibbs sampler를 이용해 LDA의 inference를 보다 효과적으로 수행하는 방법을 제시한다.

Introduction

앞에서는 p(w|α,β)의 계산은 다루기 힘들기 때문에 variational Bayes 방법을 사용하여 model hyperparameter를 추정하였다. 여기서는 다소 다른 접근법을 소개한다. Collapsed Gibbs sampler를 이용한 LDA는 앞의 방법들처럼 β,θ를 추정해야 할 parameter로 보지 않고, multinomial-Dirichlet conjugacy를 이용하여 integrate out한다. 그 다음, 관측된 document w에 대한 topic variable z의 posterior distribution p(z|w)를 Gibbs sampling을 통해 추정한다. β,θ에 대한 추정치는 이 posterior로부터 얻을 수 있다.

그를 위해 topic이 주어졌을 때 word probability를 나타내는 β를 model hyperparameter로 보지 않고, prior를 부여하자. 그 때의 complete probability model은 다음과 같다. 붉은 글씨로 적힌 β가 원래 LDA에서는 model parameter였으며, 여기서는 Dirichlet prior를 새로 부여한 부분이다.

θdDirichlet(α)d=1,,Dzdn|θdMultinomial(θ)d=1,,D,n=1,,NdβiDirichlet(η)i=1,,kwdn|zdn,βMultinomial(βzdn)d=1,,D,n=1,,Nd

이제 우리의 hyperparameter는 α,η가 주어졌을 때, topic assignment zdn의 conditional을 구하자.

Deriving conditionals

어떤 한 document 내에서, d번째 document의 n번째 word에 대한 topic assignment zdn의 conditional은 다음과 같다.

p(zdn|Zdn,W,α,η)=p(Z,W|α,η)p(Zdn,W|α,η)=p(Z,W|α,η)p(Zdn,Wdn|α,η)p(wdn|Zdn,Wdn,α,η)p(Z,W|α,η)p(Zdn,Wdn|α,η)(p(wdn|Zdn,Wdn,α,η) is a constant with respect to zdn)

분자인 p(Z,W|α,η)을 전개해보면 다음과 같다.

p(Z,W|α,η)=θ(1:D)β1:kp(Z,W,β1:k,θ(1:D)|α,η)dβ1:kdθ(1:D)=β1:kp(W|Z,β1:k)p(β1:k|η)dβ1:kθ(1:D)p(Z|θ(1:D))p(θ(1:D)|α)dθ(1:D)

두 적분식을 각각 정리하면 다음과 같다.

1.β1:kp(W|Z,β1:k)p(β1:k|η)dβ1:k=β1:kp(W|Z,β1:k)ki=1p(βi|η)dβ1:k=β1:kDd=1p(wd|zd,β1:k)ki=1p(βi|η)dβ1:k=β1:k(Dd=1ki=1Vj=1βNdn=1wjdnzidnij)(ki=1Dirichlet(βi;η))dβ1:k=β1:k(ki=1Vj=1βDd=1Ndn=1wjdnzidnij)ki=1(Γ(Vj=1ηj)Vj=1Γ(ηj)Vj=1βηj1ij)dβ1:kLet Ξi,j=Dd=1Ndn=1wjdnzidn: counts of jth word in ith topic across all documents.=[Γ(Vj=1ηj)Vj=1Γ(ηj)]kβ1:kki=1Vj=1β(ηj+Ξi,j)1ijdβ1:k=[Γ(Vj=1ηj)Vj=1Γ(ηj)]kki=1Vj=1Γ(ηj+Ξi,j)Γ(Vj=1ηj+Ξi,)=ki=1Beta(η+Ξi)Beta(η)(where  Ξi=[Ξi,1,,Ξi,V]T) 2.θ(1:D)p(Z|θ(1:D))p(θ(1:D)|α)dθ(1:D)=θ(1:D)Dd=1p(zd|θ(d))p(θ(d)|α)dθ(1:D)=θ(1:D)(Dd=1Ndn=1ki=1θzidndi)(Dd=1Dirichlet(θd;α))dθ(1:D)=θ(1:D)(Dd=1ki=1θNdn=1zidndi)Dd=1(Γ(ki=1αi)ki=1Γ(αi)ki=1θαi1di)dθ(1:D)=[Γ(ki=1αi)ki=1Γ(αi)]Dθ(1:D)(Dd=1ki=1θ(αi+Ndn=1zidn)1di)dθ(1:D)Let Ωd,i=Ndn=1zidn: the number of words of ith topic in dth document.=[Γ(ki=1αi)ki=1Γ(αi)]Dθ(1:D)(Dd=1ki=1θ(αi+Ωd,i)1di)dθ(1:D)=[Γ(ki=1αi)ki=1Γ(αi)]DDd=1ki=1Γ(αi+Ωd,i)Γ(ki=1αi+Ωd,)=Dd=1Beta(α+Ωd)Beta(α)(where  Ωd=[Ωd,1,,Ωd,k]T)

따라서 위 식의 분자 p(Z,W|α,η)는 다음과 같다.

p(Z,W|α,η)=ki=1Beta(η+Ξi)Beta(η)Dd=1Beta(α+Ωd)Beta(α)

또한 같은 방법으로 분모인 p(Zdn,Wdn|α,η)도 구할 수 있다.

p(Zdn,Wdn|α,η)=ki=1Beta(η+Ξdni)Beta(η)Dd=1Beta(α+Ωdnd)Beta(α)

Ξdni,j,Ωdnd,id번째 document의 n번째 word와 topic variable을 제외하고 구한 Ξi,j,Ωd,i이다. 이제 우리는 다음과 같이 conditional을 적을 수 있다.

p(zdn|Zdn,W,α,η)p(Z,W|α,η)p(Zdn,Wdn|α,η)=ki=1Beta(η+Ξi)Beta(η)Dd=1Beta(α+Ωd)Beta(α)ki=1Beta(η+Ξdni)Beta(η)Dd=1Beta(α+Ωdnd)Beta(α)=ki=1Beta(η+Ξi)Beta(η+Ξdni)Dd=1Beta(α+Ωd)Beta(α+Ωdnd)

Conditional을 구하고자 하는 ˜d번째 document의 n번째 word가 ˜j이고, 이 word는 topic ˜i에서 생성되었다고 하자. Ξdni,j,Ωdnd,i은 한 word(observation)을 제외하고 구한 Ξi,j,Ωd,i이므로, 다음이 만족한다.

If i=˜i,j=˜j,Ξ˜dni,j=Ξi,j1.else same.If d=˜d,i=˜i,Ω˜dnd,i=Ωd,i1.else same.

이제 첫 번째 multiplicant를 구하자.

ki=1Beta(η+Ξi)Beta(η+Ξ˜dni)=Beta(η+Ξ˜i)Beta(η+Ξ˜dn˜i)=Γ(Vj=1ηj+Ξ˜dn˜i,)Γ(Vj=1ηj+Ξ˜i,)Vj=1Γ(ηj+Ξ˜i,j)Vj=1Γ(ηj+Ξ˜dn˜i,j)=1Vj=1ηj+Ξ˜dn˜i,η˜j+Ξ˜dn˜i,˜j1=η˜j+Ξ˜dn˜i,˜jVj=1(ηj+Ξ˜dn˜i,j)

두 번째 multiplicant는 다음과 같이 구한다.

Dd=1Beta(α+Ωd)Beta(α+Ω˜dnd)=Beta(α+Ω˜d)Beta(α+Ω˜dn˜d)=Γ(ki=1αi+Ω˜dn˜d,)Γ(ki=1αi+Ω˜d,)ki=1Γ(αi+Ω˜d,i)ki=1Γ(αi+Ω˜dn˜d,i)=1ki=1αi+Ω˜dn˜d,α˜i+Ω˜dn˜d,˜i1=α˜i+Ω˜dn˜d,˜iki=1(αi+Ω˜dn˜d,i)

따라서 도출한 conditional은 다음과 같다.

p(z˜dn=˜i|Z˜dn,W,α,η)p(Z,W|α,η)p(Z˜dn,W˜dn|α,η)=ki=1Beta(η+Ξi)Beta(η+Ξ˜dni)Dd=1Beta(α+Ωd)Beta(α+Ω˜dnd)=η˜j+Ξ˜dn˜i,˜jVj=1(ηj+Ξ˜dn˜i,j)α˜i+Ω˜dn˜d,˜iki=1(αi+Ω˜dn˜d,i)

Algorithm

도출한 conditional을 가지고 Gibbs sampling을 수행하는 것은 어렵지 않다. 눈여겨볼 부분은 conditional 식을 알고리즘에 적용하는데 필요한 것은 Ξdni,j,Ωdnd,i, 즉 i 번째 topic에 word j가 지정된 횟수’와 ‘d 번째 document 내에서 i 번째 topic으로부터 생성된 단어의 수‘뿐이다. Collapsed Gibbs sampling을 이용한 LDA의 수행 과정은 다음과 같다.

1. Initialization

  • Count variable, Ξi,j,Ξi,,Ωd,i,Ωd,0으로 설정.

  • for all documents d=1,,D, do

    • for all words n=1,,Nd, do
      • Sample topic variable zdnMultinomial(1k,,1k).
      • Increment : Ξzdn,j=Ξzdn,j+1.
      • Increment : Ξzdn,=Ξzdn,+1.
      • Increment : Ωd,zdn=Ωd,zdn+1.
      • Increment : Ωd,=Ωd,+1.
    • end for
  • end for

2. Run Gibbs sampling

  • while not converged, do
    • for all documents d=1,,D, do
      • for all words n=1,,Nd, do
        • Current topic assignment of k for word wdn=j,
        • Decrement : Ξk,j=Ξk,j1.
        • Decrement : Ξk,=Ξk,1.
        • Decrement : Ωd,k=Ωd,k1.
        • Decrement : Ωd,=Ωd,1.
        • Sample ˜k=zdnp(zdn|Zdn,W,α,η) with decremented Ξi,j,Ξi,,Ωd,i,Ωd,.
        • Increment : Ξ˜k,j=Ξ˜k,j+1.
        • Increment : Ξ˜k,=Ξ˜k,+1.
        • Increment : Ωd,˜k=Ωd,˜k+1.
        • Increment : Ωd,=Ωd,+1.
      • end for
    • end for
  • end while