DH. AI

VAE (Auto-Encoding Variational Bayes) 코드 구현 본문

[딥러닝]/[Generative model]

VAE (Auto-Encoding Variational Bayes) 코드 구현

도환 2023. 3. 13. 17:29

데이터 셋은 MNIST를 이용한다.

전체 코드

VAE 논문 리뷰

VAE의 Encoder

x를 입력받아 간단한 Linear Layer를 거친 후 z, mu, logvar를 내보낸다. z는 Decoder에 들어갈 값이고, mu와 logvar는 Regularization term을 계산할 때 쓰인다.  logvar :  log(σ²)

 

def reparameterization(mu, logvar):
    std = torch.exp(logvar/2) # logvar :  log(σ²)
    eps = torch.randn_like(std)
    return mu + eps * std


class Encoder(nn.Module):
    def __init__(self, x_dim=img_size**2, h_dim=hidden_dim, z_dim=latent_dim):
        super(Encoder, self).__init__()

        # 1st hidden layer
        self.fc1 = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2)
        )

        # 2nd hidden layer
        self.fc2 = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2)
        )

        # output layer
        self.mu = nn.Linear(h_dim, z_dim)
        
        self.logvar = nn.Linear(h_dim, z_dim)

    def forward(self, x):
        x = self.fc2(self.fc1(x))

        mu = F.relu(self.mu(x))
        logvar = F.relu(self.logvar(x))

        z = reparameterization(mu, logvar)
        return z, mu, logvar

 

VAE의 Decoder

Encoder에서 출력된 z를 입력받아 x_reconst를 출력한다. x_reconst는 reconstruction Error를 계산할 때 사용된다. 디코더의 확률분포 가정을 어떻게 하느냐에 따라서 마지막 활성화 함수 부분이나 출력 부분이 달라질 수 있다.

 

class Decoder(nn.Module):
    def __init__(self, x_dim=img_size**2, h_dim=hidden_dim, z_dim=latent_dim):
        super(Decoder, self).__init__()

        # 1st hidden layer
        self.fc1 = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        # 2nd hidden layer
        self.fc2 = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2)
        )

        # output layer
        self.fc3 = nn.Linear(h_dim, x_dim)

    def forward(self, z):
        z = self.fc2(self.fc1(z))
        x_reconst = F.sigmoid(self.fc3(z)) 
        
        # 디코더를 Bernoulli(p)로 가정했기 때문에 디코더의 출력은 p이다.
        # 출력값을 0~1으로 만들기 위해서 시그모이드를 적용
        # 만약 디코더의 분포를 정규분포로 가정한다면 디코더의 출력은 정규분포의 파라미터인
        # μ와 σ가 될 것이고 마지막 Layer의 Activation도 Relu로 구현했을 것이다.
        return x_reconst

 

 

출처 : Smart Design Lab 강남우 교수님의 강의

아래 Train 코드에서 loss 부분은 위 그림을 참고하자. - 수식 유도

Train

 

for epoch in range(n_epochs):
    train_loss = 0
    for i, (x, _) in enumerate(train_dataloader):
        # forward
        x = x.view(-1, img_size**2)
        x = x.to(device)
        z, mu, logvar = encoder(x)
        x_reconst = decoder(z)

        # compute reconstruction loss and KL divergence
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # reduction : sum -> 출력값의 합을 내보냄
        
        # Regularization
        kl_div = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)

        # backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Loss를 Reconstruction error와 Regularization으로 정의하고 학습을 진행한다.

 

Decoder Reconstruction

 

Reconstruction 결과

홀수 column : 원래 Test이미지

짝수 column : Reconstrunction 결과

 

결과가 조금 blur 한데, 이 결과에 대한 해석으로 GAN과 VAE의 차이점을 비교한 해석을 보자.


GAN도 생성모델이지만 blur한 현상은 잘 나타나지 않는다. Gan은 Discriminator에 이미지 전체가 입력되고, 이 이미지가 진짜인지 가짜인지 즉 0 또는 1의 값이 출력되는 방면, VAE는 픽셀마다의 Reconstrunction Error를 모두 계산하여 더한다.

따라서 VAE는 픽셀별 error 값을 평균적으로 줄이는 방향으로 학습하게 되어 결과가 blur 하게 나타난다는 해석이다.

VAE와 달리 GAN은 이미지 전체를 보고 진짜 이미지라고 판별되도록 학습되는 Adversarial Loss를 쓴다.

 

이론만 공부했을때는 모델 흐름에 대한 완전한 이해가 부족했는데 역시 코드를 한번 보니 이해가 잘 된다! 

 

 

 

 

hong_journey 님의 블로그를 참고하였습니다.