두 정규분포 N(μ, σ²)와 N(0, 1) 사이의 KL 발산 유도

 

두 정규분포 N(μ, σ²)와 N(0, 1) 사이의 KL 발산 유도

...
# 대충 인코더 구현부 ...

self.mu = tf.layers.dense(h, self.z_size, name="enc_fc_mu")
self.logvar = tf.layers.dense(h, self.z_size, name="enc_fc_log_var")
self.sigma = tf.exp(self.logvar / 2.0)
self.epsilon = tf.random_normal([self.batch_size, self.z_size])
self.z = self.mu + self.sigma * self.epsilon

# 대충 디코더 구현부 ...

if self.is_training:
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self.r_loss = tf.reduce_sum(tf.square(self.x - self.y), reduction_indices = [1,2,3])
    self.r_loss = tf.reduce_mean(self.r_loss)
    self.kl_loss = - 0.5 * tf.reduce_sum((1 + self.logvar - tf.square(self.mu) - tf.exp(self.logvar)), reduction_indices = 1)
...

변분 자동 인코더(VAE)의 샘플 코드를 살펴보던 중, 위 코드처럼 self.kl_loss를 계산하는 특정 수식이 눈에 들어왔다. KL 발산 손실이 왜 이런 형태의 수식으로 사용되는지 궁금증이 생겼다. 아래 코드 라인은 이해하기 어려웠다.

self.kl_loss = - 0.5 * tf.reduce_sum((1 + self.logvar - tf.square(self.mu) - tf.exp(self.logvar)), reduction_indices = 1)

필자는 이를 직접 유도해보면서 KL 발산에 대해 이해하고자 했다. 나와 같은 궁금증을 가진 독자에게도 도움이 되기를 희망한다.

KL 발산 정의

KL 발산(Kullback-Leibler divergence)은 두 확률 분포 간의 차이를 측정하는 지표로, 변분 자동 인코더(VAE)와 같은 생성 모델에서 핵심적인 역할을 한다. 특히 VAE의 목적 함수는 잠재 변수 분포 $q(z \mid x)$가 사전 분포 $p(z)$와 얼마나 차이나는지를 측정하는 KL 발산 항을 포함한다. 이 글에서는 가장 흔히 사용되는 두 정규분포 $N(\mu, \sigma^2)$와 표준정규분포 $N(0, 1)$ 사이의 KL 발산을 단계별로 유도한다.

두 확률 분포 $q(z)$와 $p(z)$ 사이의 KL 발산은 다음과 같이 정의된다.

\[KL(q(z \mid x) \parallel p(z)) = \int q(z \mid x) \log \frac{q(z \mid x)}{p(z)} \, dz\]

이 값은 항상 0 이상이며, 두 분포가 완전히 같을 때만 0이 된다. 여기서 $q(z \mid x)$는 x가 주어졌을 때 z가 따르는 분포를 의미하며, $p(z)$는 사전 분포를 나타낸다.

정규분포의 확률 밀도 함수

KL 발산 유도에 사용할 두 정규분포의 확률 밀도 함수는 다음과 같다.

데이터 x가 주어졌을 때 잠재 변수 z의 분포 $q(z \mid x) = N(\mu, \sigma^2)$: 이 분포는 인코더 네트워크가 입력 데이터 x를 받아 추정하며, 평균 $\mu$와 분산 $\sigma^2$가 x에 의해 결정된다. \(q(z \mid x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right)\)

잠재 변수 z의 기본적인 분포 (사전 분포) $p(z) = N(0, 1)$: 학습 시작 전에 잠재 변수 z가 따를 것이라고 가정한 표준정규분포이다. 평균이 0, 분산이 1이다. \(p(z) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{z^2}{2}\right)\)

여기서 $\mu$와 $\sigma^2$는 인코더 네트워크의 출력으로 결정되는 값이다. z는 잠재 변수이며, x는 관측 데이터(예: 이미지)를 나타낸다. $q(z \mid x)$는 확률 밀도 함수로, z 값에 따라 확률 밀도를 반환하는 함수이다.

KL 발산 유도 단계

1. 적분식 설정

KL 발산 정의에 두 확률 밀도 함수를 대입한다.

\[KL(q(z \mid x) \parallel p(z)) = \int \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right) \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right)}{\frac{1}{\sqrt{2\pi}} \exp\left(-\frac{z^2}{2}\right)} \, dz\]

2. 로그 항 분해

로그 내부를 정리한다. $\log(A/B) = \log A - \log B$와 $\log(\exp(X)) = X$ 성질을 사용한다.

\[\log \frac{q(z \mid x)}{p(z)} = \log \left( \frac{\frac{1}{\sqrt{2\pi\sigma^2}}}{\frac{1}{\sqrt{2\pi}}} \cdot \frac{\exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right)}{\exp\left(-\frac{z^2}{2}\right)} \right)\] \[= \log \left( \frac{\sqrt{2\pi}}{\sqrt{2\pi\sigma^2}} \right) + \log \left( \exp\left(-\frac{(z-\mu)^2}{2\sigma^2}\right) \cdot \exp\left(\frac{z^2}{2}\right) \right)\] \[= \log \left( \frac{1}{\sqrt{\sigma^2}} \right) + \left(-\frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2}\right)\]

여기서 $\log \frac{1}{\sqrt{\sigma^2}} = \log \frac{1}{\sigma} = -\log \sigma$와 같이 $\sigma$에 대한 항이 분리된다.

괄호 안의 두 번째 항을 전개하고 z에 대해 정리한다.

\[-\frac{(z-\mu)^2}{2\sigma^2} + \frac{z^2}{2} = -\frac{z^2 - 2\mu z + \mu^2}{2\sigma^2} + \frac{z^2}{2}\] \[= z^2\left(\frac{\sigma^2 - 1}{2\sigma^2}\right) + z\left(\frac{\mu}{\sigma^2}\right) - \left(\frac{\mu^2}{2\sigma^2}\right)\]

3. 적분 분리

로그 항이 두 부분으로 나뉘었으므로, KL 발산 적분도 두 부분으로 분리하여 계산한다.

\[KL = \int q(z \mid x) \left( -\log \sigma \right) dz + \int q(z \mid x) \left( z^2\left(\frac{\sigma^2 - 1}{2\sigma^2}\right) + z\left(\frac{\mu}{\sigma^2}\right) - \left(\frac{\mu^2}{2\sigma^2}\right) \right) dz\]

4. 첫 번째 적분 계산

첫 번째 적분 항에서 $-\log \sigma$는 적분 변수 z와 무관한 상수이므로 적분 밖으로 꺼낼 수 있다.

\[\int q(z \mid x) \left( -\log \sigma \right) dz = -\log \sigma \int q(z \mid x) \, dz = -\log \sigma \cdot 1 = -\frac{1}{2}\log \sigma^2\]

여기서 $\int q(z \mid x) dz = 1$은 확률 밀도 함수의 전체 적분값이 1이라는 성질을 활용했다.

5. 두 번째 적분 계산

두 번째 적분은 기댓값 개념을 활용하여 계산한다. $q(z \mid x)$ 분포 하에서 어떤 함수 $g(z)$의 기댓값은 $E[g(z)] = \int g(z)q(z \mid x)dz$로 정의된다.

여기서 기댓값을 계산해야 하는 함수는 다음과 같다. \(g(z) = z^2\left(\frac{\sigma^2 - 1}{2\sigma^2}\right) + z\left(\frac{\mu}{\sigma^2}\right) - \left(\frac{\mu^2}{2\sigma^2}\right)\)

기댓값의 선형성에 따라 각 항의 기댓값을 계산하여 더한다. \(E[g(z)] = \left(\frac{\sigma^2 - 1}{2\sigma^2}\right)E[z^2] + \left(\frac{\mu}{\sigma^2}\right)E[z] - \left(\frac{\mu^2}{2\sigma^2}\right)E[1]\)

$z \sim N(\mu, \sigma^2)$일 때의 기댓값은 다음과 같다.

  • $E[z] = \mu$ (정규분포의 평균)
  • $E[z^2] = \mu^2 + \sigma^2$ (정규분포의 2차 모멘트)
  • $E[1] = 1$

$E[z]=\mu$와 $E[z^2]=\mu^2+\sigma^2$가 되는 이유를 살펴보자. $z = \mu + \sigma x$로 변수 치환하면, $x$는 표준정규분포 $N(0, 1)$을 따른다.

표준정규분포의 특성:

  • $E[x]=0$ (평균)
  • $E[x^2]=1$ (분산, $E[x^2] - (E[x])^2 = 1$)

이를 활용하면:

  • $E[z] = E[\mu + \sigma x] = \mu + \sigma E[x] = \mu + \sigma \cdot 0 = \mu$
  • $E[z^2] = E[(\mu + \sigma x)^2] = E[\mu^2 + 2\mu\sigma x + \sigma^2 x^2] = \mu^2 + 2\mu\sigma E[x] + \sigma^2 E[x^2] = \mu^2 + 2\mu\sigma \cdot 0 + \sigma^2 \cdot 1 = \mu^2 + \sigma^2$

이 값들을 $E[g(z)]$ 식에 대입하면:

\[E[g(z)] = \left(\frac{\sigma^2 - 1}{2\sigma^2}\right)(\mu^2 + \sigma^2) + \left(\frac{\mu}{\sigma^2}\right)(\mu) - \left(\frac{\mu^2}{2\sigma^2}\right)(1)\] \[= \frac{(\sigma^2 - 1)(\mu^2 + \sigma^2)}{2\sigma^2} + \frac{\mu^2}{\sigma^2} - \frac{\mu^2}{2\sigma^2}\]

식을 정리하면:

\[= \frac{\sigma^2\mu^2 + \sigma^4 - \mu^2 - \sigma^2 + 2\mu^2 - \mu^2}{2\sigma^2}\] \[= \frac{\sigma^2\mu^2 + \sigma^4 - \sigma^2}{2\sigma^2} = \frac{\sigma^2(\mu^2 + \sigma^2 - 1)}{2\sigma^2} = \frac{1}{2}(\mu^2 + \sigma^2 - 1)\]

6. 최종 결과 결합

단계 4와 단계 5의 결과를 합하여 최종 KL 발산 값을 얻는다.

\[KL = \left(-\frac{1}{2}\log \sigma^2\right) + \left(\frac{1}{2}(\mu^2 + \sigma^2 - 1)\right) = \frac{1}{2} (\mu^2 + \sigma^2 - \log \sigma^2 - 1)\]

최종 결과

두 정규분포 $N(\mu, \sigma^2)$와 $N(0, 1)$ 사이의 KL 발산은 다음과 같다.

\[KL(N(\mu, \sigma^2) \parallel N(0, 1)) = \frac{1}{2} \left( \mu^2 + \sigma^2 - \log \sigma^2 - 1 \right)\]

이 공식은 VAE 학습 시 재구성 오차와 함께 최적화되는 KL 손실 항으로 사용된다. 인코더 출력이 표준정규분포에서 크게 벗어나지 않도록 제약하는 역할을 한다.