kok202
PyTorch - MNIST VAE

2019. 1. 24. 21:53

import os   # used below to create the output directories

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image






''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Loading the MNIST data
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST(root='../DATA_MNIST/',
                       train=True,
                       transform=transform,
                       download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=100,
                                          shuffle=True)
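
As a quick sanity check (my addition, not part of the original post), one batch from the loader can be inspected. ToTensor already scales pixels into [0, 1], which is exactly the range the BCE reconstruction loss below expects:

# Sanity check: shapes and value range of one batch.
sample_images, sample_labels = next(iter(data_loader))
print(sample_images.shape)   # torch.Size([100, 1, 28, 28])
print(sample_images.min().item(), sample_images.max().item())   # 0.0 1.0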






''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Defining the VAE encoder and decoder
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.layer_encoder1    = nn.Linear(784, 128)
        self.layer_encoder2_mu = nn.Linear(128, 100)
        self.layer_encoder2_va = nn.Linear(128, 100)     # z dimension is 100
        self.layer_decoder1    = nn.Linear(100, 128)
        self.layer_decoder2    = nn.Linear(128, 784)
        self.relu    = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.z_dimension = 100

    def Encoder(self, input_x):
        z = self.layer_encoder1(input_x)
        z = self.relu(z)
        # No activation on the mu / log-variance heads: mu may be negative,
        # and the log-variance must be free to go below zero.
        z_mu = self.layer_encoder2_mu(z)
        z_va = self.layer_encoder2_va(z)
        return z_mu, z_va                # z_va is log(sigma^2)

    def Decoder(self, input_z):
        x = self.layer_decoder1(input_z)
        x = self.relu(x)
        x = self.layer_decoder2(x)
        x = self.sigmoid(x)
        return x

    def forward(self, input_x):
        z_mu, z_va = self.Encoder(input_x)
        # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I).
        # randn_like matches the batch size and device of z_mu, so the last
        # (possibly smaller) batch also works.
        z_epsilon = torch.randn_like(z_mu)
        z_sample = z_mu + torch.exp(z_va / 2) * z_epsilon
        output = self.Decoder(z_sample)
        return output, z_mu, z_va
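
A minimal smoke test of the module (my addition; the expected shapes follow from the layer definitions above):

# Smoke test: a fake batch of 4 flattened 28x28 images in [0, 1].
test_vae = VAE()
test_x = torch.rand(4, 784)
test_recon, test_mu, test_logvar = test_vae(test_x)
print(test_recon.shape, test_mu.shape, test_logvar.shape)   # (4, 784) (4, 100) (4, 100)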






''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Defining the optimizer and loss function
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

model = VAE()
if torch.cuda.is_available():
    model = model.cuda()     # keep the model on the same device as to_var's tensors
# Sum-reduced BCE: setting .size_average after construction no longer works in
# recent PyTorch; pass reduction='sum' instead.
reconstruction_criterion = nn.BCELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=1e-3)
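
For reference (my addition): the regularization term used in the training loop below is the closed-form KL divergence between the approximate posterior N(mu, sigma^2) and the prior N(0, I), namely 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar), summed over the latent dimensions and averaged over the batch. A quick check that it vanishes when the posterior equals the prior:

# KL sanity check: zero when mu = 0 and logvar = 0 (posterior == prior).
test_mu = torch.zeros(1, 100)
test_logvar = torch.zeros(1, 100)
test_kl = torch.mean(0.5 * torch.sum(torch.exp(test_logvar) + test_mu ** 2 - 1 - test_logvar, 1))
print(test_kl.item())   # 0.0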







''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Defining the to_var and denorm helpers
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

# Note: denorm maps [-1, 1] back to [0, 1]; it goes unused below because
# ToTensor already produces values in [0, 1].
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)






''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Training
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

os.makedirs('./data_MNIST_Create', exist_ok=True)   # generated sample images
os.makedirs('./save_MNIST', exist_ok=True)          # model checkpoints

model.train()
epoch_before = -1
for epoch in range(10):
    for i, (images, _) in enumerate(data_loader):
        batch_size = images.size(0)

        ''''''''''''''' Train model '''''''''''''''
        x = to_var(images.view(batch_size, -1))
        x_, z_mu, z_va = model(x)

        optimizer.zero_grad()
        # Both terms are scaled per sample: the sum-reduced BCE is divided by
        # the batch size, and the KL term is averaged over the batch.
        reconstruction_loss = reconstruction_criterion(x_, x) / batch_size
        regularization_loss = torch.mean(0.5 * torch.sum(torch.exp(z_va) + z_mu ** 2 - 1 - z_va, 1))
        loss = reconstruction_loss + regularization_loss
        loss.backward()
        optimizer.step()






        '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
        # Monitoring training
        '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

        if epoch_before != epoch:
            print('Epoch [%d], Data visualization' % (epoch))
            model.eval()
            epoch_before = epoch
            # Sample latent vectors from the N(0, I) prior and decode them
            # to see what the model can generate at this point. to_var keeps
            # the sample on the same device as the model.
            z_sample = to_var(torch.randn(batch_size, 100))
            images = model.Decoder(z_sample)
            images = images.view(images.size(0), 1, 28, 28)
            save_image(images.data, './data_MNIST_Create/model%d.png' % (epoch_before + 1))
            torch.save(model.state_dict(), './save_MNIST/model%d.pkl' % (epoch_before + 1))
            model.train()
        if (i + 1) % 300 == 0:
            # .item() replaces the old .data[0] indexing, which raises an
            # error on 0-dimensional tensors in PyTorch >= 0.5.
            print('Epoch [%d], '
                  'recon_loss: %.4f, '
                  'regul_loss: %.4f, '
                  'total_loss: %.4f'
                  % (epoch,
                     reconstruction_loss.item(),
                     regularization_loss.item(),
                     loss.item()))






'''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# Saving the trained model
'''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

torch.save(model.state_dict(), './save_MNIST/model%d.pkl' % (epoch_before + 1))
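
To reuse the trained weights later (my addition, assuming the save path used above), load the state dict back into a fresh VAE instance:

# Load the last checkpoint written above and switch to inference mode.
checkpoint_path = './save_MNIST/model%d.pkl' % (epoch_before + 1)
restored = VAE()
restored.load_state_dict(torch.load(checkpoint_path))
restored.eval()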





