VAE (variational autoencoder) is a generative model that uses neural networks as function approximators to model a continuous latent variable with an intractable posterior distribution. If you are interested in the theory of VAEs, I suggest reading the original paper or this awesome tutorial by Carl Doersch. In this tutorial I explain how to implement a VAE in PyTorch.
As with any other machine learning technique, we need four main building blocks:
- data
- model
- training
- inference
Data
For data, let’s use the MNIST dataset. PyTorch’s torchvision package makes it easy to create training and test datasets for MNIST:
```python
import torch
from torchvision import datasets, transforms

# training
BATCH_SIZE = 100
trainset = datasets.MNIST('./data/', train=True, download=True,
                          transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)
# test
testset = datasets.MNIST('./data/', train=False, download=True,
                         transform=transforms.ToTensor())
# the test loader must wrap testset (not trainset); shuffling is not needed for testing
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)
```
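Later code relies on the shape of each batch (for example, reconstructions are reshaped back into 28x28 images), so it can help to check what a `DataLoader` yields. The sketch below uses a small synthetic `TensorDataset` as a stand-in for MNIST (an assumption, so it runs without downloading anything): 250 fake images with `batch_size=100` show that each batch has shape `(batch, 1, 28, 28)` and that the last batch can be smaller.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 250 single-channel 28x28 images with values in [0, 1),
# the same layout and range that ToTensor produces for real MNIST digits.
fake_images = torch.rand(250, 1, 28, 28)
fake_labels = torch.randint(0, 10, (250,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=100, shuffle=False)

batch_shapes = [tuple(images.shape) for images, labels in loader]
print(batch_shapes)  # [(100, 1, 28, 28), (100, 1, 28, 28), (50, 1, 28, 28)]
```

MNIST's 60,000 training and 10,000 test images divide evenly by 100, so a partial last batch does not occur in this tutorial, but reshaping with `view(-1, 1, 28, 28)` rather than a hard-coded batch size avoids the issue in general.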
Before proceeding, let’s visualize some data. For that I use `torchvision.utils.make_grid`, which tiles multiple images into a single grid image:
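```python
import matplotlib.pyplot as plt
import torchvision

def show_image(img):
    # img is a 2-D (H x W) tensor; move it to the CPU so matplotlib can read it
    plt.imshow(img.cpu(), cmap='gray')
    plt.show()

def show_images(images):
    # make_grid tiles a batch (B x 1 x 28 x 28) into one 3 x H x W image whose
    # three channels are identical copies, so showing channel 0 is enough
    images = torchvision.utils.make_grid(images)
    show_image(images[0])

dataiter = iter(trainloader)
images, labels = next(dataiter)  # dataiter.next() no longer exists in recent PyTorch
show_images(images)
```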
Network Architecture
Similar to a denoising autoencoder, a VAE has an encoder and a decoder.
1. Encoder
The encoder maps an image to a latent variable $z$ with a normal distribution. A normal distribution is fully described by its mean $m$ and standard deviation $s$, so the role of the neural network is to learn a function from an image to $m$ and $s$ (in practice, the second output layer predicts the log-variance $\log s^2$ rather than $s$ itself). This implicitly means we are learning a function from an image to a probability distribution over $z$. We implement this function approximator with linear layers and a ReLU nonlinearity:
```python
# in __init__: layer definitions
self.fc1 = nn.Linear(784, 400)
self.fc2m = nn.Linear(400, latent_variable_dim)  # predicts the mean m
self.fc2s = nn.Linear(400, latent_variable_dim)  # predicts the log-variance log(s^2)

# in forward: the encoding pass
x = input.view(-1, 784)  # flatten each 28x28 MNIST image into a 784-vector
x = torch.relu(self.fc1(x))
log_var = self.fc2s(x)
m = self.fc2m(x)
```
Here `latent_variable_dim` is the dimension of the normal distribution of the latent variable $z$.
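As a self-contained check of the shapes involved, here is the encoder on its own as a minimal module. The class name `Encoder` is hypothetical; the layers mirror the encoder snippet. For a batch of images it returns one mean vector and one log-variance vector per image, and the standard deviation is recovered as $s = \exp(\tfrac{1}{2}\log s^2)$, which is always positive:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Hypothetical stand-alone version of the VAE encoder: image -> (m, log s^2)."""

    def __init__(self, latent_variable_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2m = nn.Linear(400, latent_variable_dim)  # mean m
        self.fc2s = nn.Linear(400, latent_variable_dim)  # log-variance log(s^2)

    def forward(self, input):
        x = input.view(-1, 784)  # flatten 28x28 images
        x = torch.relu(self.fc1(x))
        return self.fc2m(x), self.fc2s(x)

encoder = Encoder(latent_variable_dim=2)
m, log_var = encoder(torch.rand(5, 1, 28, 28))  # a batch of 5 random "images"
s = torch.exp(0.5 * log_var)  # standard deviation
print(m.shape, log_var.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
```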
2. Decoder
The decoder takes the encoded value $z$, which in theory is referred to as the latent variable, and decodes it into an image. Therefore, the role of the decoder is to learn a function that maps a value of $z$ to a vector of 784 real values (one per pixel of a 28x28 image). Note that $z$ is in fact a random variable, but here we work with a single realization (a.k.a. a sampled value) of it:
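```python
# fc3: latent_variable_dim -> 400, fc4: 400 -> 784 (both defined in __init__)
x = torch.relu(self.fc3(z))
x = torch.sigmoid(self.fc4(x))  # sigmoid keeps every pixel value in [0, 1]
```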
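Similarly, a stand-alone decoder (the class name `Decoder` is hypothetical; the layers mirror `fc3` and `fc4` of the full model) maps each latent vector to 784 numbers in $[0, 1]$ thanks to the final sigmoid, which can then be reshaped into a 28x28 image:

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Hypothetical stand-alone version of the VAE decoder: z -> 784 pixel intensities."""

    def __init__(self, latent_variable_dim):
        super().__init__()
        self.fc3 = nn.Linear(latent_variable_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, z):
        x = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(x))  # squash every pixel into [0, 1]

decoder = Decoder(latent_variable_dim=2)
z = torch.randn(3, 2)  # three sampled latent vectors
images = decoder(z).view(-1, 1, 28, 28)  # reshape to a batch of images
print(images.shape)  # torch.Size([3, 1, 28, 28])
```

Keeping the outputs in $[0, 1]$ is what makes binary cross-entropy usable as the reconstruction loss, since the MNIST pixels produced by `ToTensor` lie in the same range.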
Let’s put everything together as a PyTorch neural network module:
```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_variable_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2m = nn.Linear(400, latent_variable_dim)  # mean m
        self.fc2s = nn.Linear(400, latent_variable_dim)  # log-variance log(s^2)
        self.fc3 = nn.Linear(latent_variable_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def reparameterize(self, log_var, mu):
        s = torch.exp(0.5 * log_var)  # standard deviation from the log-variance
        eps = torch.randn_like(s)  # iid standard normal noise, same shape as s
        return eps.mul(s).add_(mu)  # z = mu + s * eps

    def forward(self, input):
        x = input.view(-1, 784)
        x = torch.relu(self.fc1(x))
        log_var = self.fc2s(x)
        m = self.fc2m(x)
        z = self.reparameterize(log_var, m)
        x = self.decode(z)
        return x, m, log_var

    def decode(self, z):
        x = torch.relu(self.fc3(z))
        x = torch.sigmoid(self.fc4(x))  # pixel intensities in [0, 1]
        return x
```
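The `reparameterize` method implements the reparameterization trick: instead of sampling $z \sim \mathcal{N}(m, s^2)$ directly, it samples $\epsilon \sim \mathcal{N}(0, 1)$ and sets $z = m + s\,\epsilon$, so gradients can flow back into $m$ and $\log s^2$. The sketch below checks this with autograd on arbitrary example values, and also shows why the noise must come from `torch.randn_like` (standard normal) rather than `torch.rand_like` (uniform on $[0, 1)$):

```python
import torch

torch.manual_seed(0)

mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_var = torch.tensor([0.0, 2.0], requires_grad=True)

s = torch.exp(0.5 * log_var)  # standard deviation
eps = torch.randn_like(s)  # noise that does not depend on the parameters
z = mu + s * eps  # reparameterized sample
z.sum().backward()

# dz/dmu = 1 and dz/dlog_var = 0.5 * s * eps, so the sample is differentiable
print(mu.grad)  # tensor([1., 1.])
print(torch.allclose(log_var.grad, 0.5 * s.detach() * eps))  # True

# randn_like vs rand_like: only the former has mean ~0 and std ~1
big = torch.zeros(100_000)
print(torch.randn_like(big).mean().item(), torch.randn_like(big).std().item())
print(torch.rand_like(big).mean().item())  # close to 0.5, not 0
```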
The network as a whole encodes an image and then reconstructs that image from its code. The goal is therefore to make the reconstruction as close as possible to the original image, which is one part of the loss function, as I explain in the next section.
Training
To train the network we need a loss function. The VAE loss combines two types of losses:
- Reconstruction loss: simply the cross-entropy (CE) or mean squared error (MSE) between the decoded image and the original image.
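- KL divergence: this loss term is for the latent variable $z$. We want $P(z \mid \text{input})$ to be as close as possible to a standard normal distribution (mean 0, variance 1). Since $z$ is normally distributed with mean $m$ and variance $s^2$, i.e. $z \sim \mathcal{N}(m, s^2)$, the KL divergence has a simple closed form, summed over the $d$ latent dimensions (where $d$ is `latent_variable_dim`):

$$D_{KL}\big(\mathcal{N}(m, s^2)\,\|\,\mathcal{N}(0, 1)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log s_j^2 - m_j^2 - s_j^2\right)$$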
In code, we can implement the full loss as follows:
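```python
import torch
import torch.nn.functional as F

def loss(input_image, recon_image, mu, log_var):
    # reconstruction term: binary cross-entropy summed over all pixels in the batch
    CE = F.binary_cross_entropy(recon_image, input_image.view(-1, 784), reduction='sum')
    # closed-form KL divergence between N(mu, exp(log_var)) and N(0, 1)
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return KLD + CE
```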
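Note that we keep the log of the variance, $\log s^2$ (`log_var` in the code), instead of the variance itself: the network can output any real number for it, whereas a variance must be positive.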
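A quick way to gain confidence in the KL term is to compare the closed-form expression used in `loss`, `-0.5 * sum(1 + log_var - mu^2 - exp(log_var))`, against PyTorch's `torch.distributions.kl_divergence` for two normal distributions. The inputs below are arbitrary test values; note that `Normal` takes the standard deviation, not the variance:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([[0.3, -1.2, 0.0]])
log_var = torch.tensor([[0.5, -0.7, 0.0]])

# closed form used in the VAE loss
kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# reference: KL(N(mu, s^2) || N(0, 1)), summed over the latent dimensions
s = torch.exp(0.5 * log_var)
reference = kl_divergence(Normal(mu, s), Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()
print(torch.allclose(kld, reference))  # True

# the KL term vanishes exactly when the encoder outputs a standard normal
zeros = torch.zeros(1, 3)
kld_standard = -0.5 * torch.sum(1 + zeros - zeros.pow(2) - zeros.exp())
print(kld_standard.item() == 0.0)  # True
```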
Now that we have the loss function, training is simple: read the data, run it through the network, calculate the loss, and backpropagate. We use the Adam optimizer to update the weights with the resulting gradients:
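```python
import matplotlib.pyplot as plt
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## train
vae = VAE(40).to(device)  # 40-dimensional latent variable z
optimizer = torch.optim.Adam(vae.parameters())  # Adam with its default settings
train_loss = []
for epoch in range(5):
    for i, data in enumerate(trainloader, 0):
        images, labels = data
        images = images.to(device)
        optimizer.zero_grad()
        recon_image, mu, log_var = vae(images)  # forward returns (x, m, log-variance)
        l = loss(images, recon_image, mu, log_var)
        l.backward()
        train_loss.append(l.item() / len(images))  # per-image loss
        optimizer.step()
plt.plot(train_loss)
plt.show()
```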
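A common sanity check before a full training run (not part of the original tutorial) is to train on a single small batch and confirm that the loss drops quickly. The sketch below is self-contained: it re-declares a compact copy of the model and loss (the names `TinyVAE` and `vae_loss` are hypothetical) and uses a synthetic black-and-white pattern instead of MNIST, so it runs on the CPU in a few seconds:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyVAE(nn.Module):
    """Hypothetical compact copy of the tutorial's VAE (same layer sizes)."""

    def __init__(self, latent_variable_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2m = nn.Linear(400, latent_variable_dim)
        self.fc2s = nn.Linear(400, latent_variable_dim)
        self.fc3 = nn.Linear(latent_variable_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, input):
        x = torch.relu(self.fc1(input.view(-1, 784)))
        m, log_var = self.fc2m(x), self.fc2s(x)
        z = m + torch.exp(0.5 * log_var) * torch.randn_like(m)  # reparameterize
        return torch.sigmoid(self.fc4(torch.relu(self.fc3(z)))), m, log_var

def vae_loss(input_image, recon_image, mu, log_var):
    CE = F.binary_cross_entropy(recon_image, input_image.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return KLD + CE

# one fixed batch: 32 copies of the same random black-and-white pattern
pattern = (torch.rand(1, 1, 28, 28) > 0.5).float()
images = pattern.repeat(32, 1, 1, 1)

model = TinyVAE(2)
optimizer = torch.optim.Adam(model.parameters())
losses = []
for step in range(100):
    optimizer.zero_grad()
    recon_image, mu, log_var = model(images)
    l = vae_loss(images, recon_image, mu, log_var)
    l.backward()
    optimizer.step()
    losses.append(l.item() / len(images))  # per-image loss, as in the training loop

print(losses[0], losses[-1])  # the second number should be much smaller
```

If the loss does not fall on such an easy, repeated batch, the bug is usually in the model or loss wiring (for example, swapped `mu`/`log_var` arguments) rather than in the data.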
Testing
Testing is very similar to training, except that no gradients are computed or applied. In addition, once in a while we look at some of the reconstructed images to see how well the encoder and decoder are doing:
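```python
with torch.no_grad():  # no gradients are needed at test time
    for i, data in enumerate(testloader, 0):
        images, labels = data
        images = images.to(device)
        recon_image, mu, log_var = vae(images)
        # reshape flat 784-vectors back into 1 x 28 x 28 images for plotting
        recon_image_ = recon_image.view(-1, 1, 28, 28).cpu()
        if i % 100 == 0:
            show_images(recon_image_)
```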
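Besides looking at reconstructions, it is often useful to report a single number on the test set. The helper below (`evaluate` is a hypothetical name, not from the tutorial) averages the per-image loss over a whole loader under `torch.no_grad()`. It assumes the model returns `(reconstruction, mu, log_var)` in that order, as `VAE.forward` does, and takes the loss function as an argument:

```python
import torch

def evaluate(model, loader, loss_fn, device):
    """Average per-image loss of `model` over every batch in `loader`."""
    model.eval()  # no dropout/batch-norm here, but a good habit for evaluation
    total, count = 0.0, 0
    with torch.no_grad():  # no gradients are needed at test time
        for images, labels in loader:
            images = images.to(device)
            recon_image, mu, log_var = model(images)
            total += loss_fn(images, recon_image, mu, log_var).item()
            count += len(images)
    model.train()  # switch back in case training continues
    return total / count
```

With the tutorial's objects this would be called as `evaluate(vae, testloader, loss, device)`; because `loss` sums over the batch, dividing by the number of images gives a per-image figure comparable to `train_loss`.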
Effect of the Dimension of the Latent Variable
Before moving on to generation, let’s quickly look at the effect of the dimension of $z$ (the latent variable) on the training error and on image reconstruction.
By comparing Figure 2-(a) and Figure 2-(b), we can see that a higher-dimensional latent variable helps generate sharper and more realistic MNIST images. When the latent dimension is small, the encoder has to squeeze all the information in an image into a very low-dimensional vector, which results in highly lossy compression.
Just Generate
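Now let’s say we want to generate some images. For that we only need the decoder part of the network, to which we pass some realizations of the latent variable $z$. The codes below are two-dimensional, so this step assumes a model trained with a 2-dimensional latent variable, i.e. `vae = VAE(2)`: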
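```python
with torch.no_grad():
    # hand-picked 2-D codes; the decoder must come from a model built with VAE(2)
    z = [[0, 0], [0, 1], [1, 1], [-1, -1], [-0.9, -0.9], [-0.5, -0.5]]
    sample_images = vae.decode(torch.tensor(z, dtype=torch.float32, device=device))
    sample_images_ = sample_images.view(len(z), 1, 28, 28).cpu()
    print(sample_images_.size())  # torch.Size([6, 1, 28, 28])
    show_images(sample_images_)
```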
As one can see, `[-1, -1]` seems to be a code for a handwritten 3. Since we map an image to a distribution rather than to one specific vector, a small ball around `[-1, -1]`, such as `[-0.9, -0.9]`, is also decoded to a handwritten 3. But as we move further away, the 3 becomes an 8, as can be seen for the vector `[-0.5, -0.5]`.
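The hand-picked codes probe the latent space point by point. Two common ways to explore it more systematically (not shown in the original tutorial) are to draw $z$ from the standard normal prior $\mathcal{N}(0, I)$ that the KL term pushes the encoder towards, and to decode evenly spaced points on a straight line between two codes, for example from `[-1, -1]` to `[-0.5, -0.5]`, to watch one digit morph into another. A minimal sketch, assuming `decode` is any function mapping an `(n, d)` tensor of codes to `(n, 784)` pixels, such as `vae.decode`; both helper names are hypothetical:

```python
import torch

def sample_from_prior(decode, n, latent_dim, device='cpu'):
    """Decode n codes drawn from the standard normal prior N(0, I)."""
    with torch.no_grad():
        z = torch.randn(n, latent_dim, device=device)
        return decode(z).view(n, 1, 28, 28).cpu()

def interpolate(decode, z_start, z_end, steps, device='cpu'):
    """Decode `steps` evenly spaced codes on the segment from z_start to z_end."""
    z_start = torch.tensor(z_start, dtype=torch.float32, device=device)
    z_end = torch.tensor(z_end, dtype=torch.float32, device=device)
    t = torch.linspace(0, 1, steps, device=device).unsqueeze(1)  # shape (steps, 1)
    z = (1 - t) * z_start + t * z_end  # broadcasts to shape (steps, latent_dim)
    with torch.no_grad():
        return decode(z).view(steps, 1, 28, 28).cpu()
```

With a trained 2-D model, `show_images(interpolate(vae.decode, [-1, -1], [-0.5, -0.5], steps=8, device=device))` would display the transition between the two codes discussed above, and `show_images(sample_from_prior(vae.decode, 16, 2, device=device))` a batch of freshly generated digits.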
Hope you enjoyed the tutorial. Please leave comments and feedback.