Heartbeat

Comet is a machine learning platform helping data scientists, ML engineers, and deep learning engineers build better models faster


Exploring Variational Autoencoders (VAEs) for Image Compression

9 min read · May 31, 2023


A black and white grid of MNIST handwritten digits (photo credit: Tensorflow.org)

Introduction

What are Variational Autoencoders (VAEs)?

Flow chart contrasting a deterministic autoencoder with a probabilistic variational autoencoder (source: Understanding Variational Autoencoders (VAEs) by Joseph Rocca)
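In short, a VAE pairs a probabilistic encoder q(z|x) with a decoder p(x|z), and trains both jointly by maximizing the evidence lower bound (ELBO) — a reconstruction term plus a KL penalty that keeps the latent distribution close to a standard normal prior:

```latex
\mathcal{L}(\theta, \phi; x) =
\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
- D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
```

The loss function implemented later in this post is the negative of this bound, minimized with gradient descent.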


Implementing VAE for Image Compression

Requirements

Implementation

Install and import libraries

!pip install comet_ml
# Import comet_ml before torch so Comet's auto-logging hooks are installed
from comet_ml import Experiment

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt

experiment = Experiment(api_key="your_api_key", project_name="vae-project")

Load and Preprocess Data

# ToTensor scales pixel values to [0, 1], which is what the binary
# cross-entropy reconstruction loss below expects
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_data = MNIST(root='./data', train=True, transform=transform, download=True)
test_data = MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)
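As a quick sanity check on this configuration: MNIST has 60,000 training and 10,000 test images of shape 1×28×28, so with `batch_size=128` each training epoch spans 469 mini-batches (the last one partial). This can be verified without downloading anything:

```python
# Shapes and batch counts implied by the loaders above (fixed properties
# of MNIST plus the batch_size=128 setting; no download needed)
batch_size = 128
n_train, n_test = 60_000, 10_000
image_shape = (1, 28, 28)  # channels, height, width

# Ceil division: the last training batch is partial (60000 % 128 != 0)
train_batches = -(-n_train // batch_size)
print(train_batches)  # → 469
```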

Define the Model

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Two stride-2 convolutions downsample 28x28 -> 14x14 -> 7x7,
        # matching the 7 * 7 * 64 flatten below
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 256)
        self.fc21 = nn.Linear(256, 128)  # mean of the latent distribution
        self.fc22 = nn.Linear(256, 128)  # log-variance of the latent distribution

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = x.view(-1, 7 * 7 * 64)
        x = nn.functional.relu(self.fc1(x))
        mu = self.fc21(x)
        logvar = self.fc22(x)
        return mu, logvar
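The encoder is where the compression happens: each 784-pixel image is summarized by a 128-dimensional latent code. A back-of-envelope ratio (assuming, for illustration, that only the latent mean `mu` is stored per image):

```python
# Dimensionality reduction implied by the encoder above
input_dims = 28 * 28   # 784 pixel values per MNIST image
latent_dims = 128      # size of mu returned by the encoder
ratio = input_dims / latent_dims
print(ratio)  # → 6.125
```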

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 7 * 7 * 64)
        # Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28,
        # mirroring the encoder
        self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, z):
        z = nn.functional.relu(self.fc1(z))
        z = nn.functional.relu(self.fc2(z))
        z = z.view(-1, 64, 7, 7)
        z = nn.functional.relu(self.conv1(z))
        # Return raw logits; the loss below uses binary_cross_entropy_with_logits
        z = self.conv2(z)
        return z

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        mu, logvar = self.encoder(x)
        # Reparameterization trick: sample z = mu + eps * std with eps ~ N(0, 1)
        # so gradients can flow back through the sampling step
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
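The reparameterization trick in `forward` can be illustrated in plain Python (a hypothetical single-dimension example, not part of the model): sampling `z = mu + eps * std` with `eps` drawn from N(0, 1) produces samples whose mean approaches `mu`, while keeping `mu` and `logvar` differentiable inputs:

```python
import math
import random

random.seed(42)

mu, logvar = 2.0, 0.0            # one latent dimension; std = exp(0.5 * 0) = 1
std = math.exp(0.5 * logvar)

# Draw eps ~ N(0, 1) and shift/scale it, exactly as VAE.forward does
samples = [mu + random.gauss(0.0, 1.0) * std for _ in range(100_000)]
mean = sum(samples) / len(samples)
# mean is close to mu = 2.0
```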

Define the loss function

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: BCE between decoder logits and target pixels in [0, 1]
    BCE = nn.functional.binary_cross_entropy_with_logits(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    # KL divergence between N(mu, sigma^2) and the standard normal prior, in closed form
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
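The `KLD` line is the closed-form KL divergence between the diagonal Gaussian N(mu, sigma²) and the standard normal prior; per latent dimension it equals -0.5 · (1 + log σ² - μ² - σ²). A tiny standalone check of that formula (plain Python, for illustration only):

```python
import math

def kl_term(mu, logvar):
    # Same expression as the KLD line above, for a single latent dimension
    return -0.5 * (1 + logvar - mu**2 - math.exp(logvar))

print(kl_term(0.0, 0.0))  # the prior itself: KL = 0.0
print(kl_term(1.0, 0.0))  # unit shift in the mean: KL = 0.5
```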

Initialize the model and define the optimizer

# Select the device and initialize the model (the latent size of 128 is
# fixed inside the Encoder/Decoder definitions above)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

Train the model

num_epochs = 10       # number of training epochs
log_interval = 100    # log every 100 batches

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            current_loss = train_loss / (batch_idx + 1)
            print('Epoch: {} [{}/{} ({:.0f}%)]\tTraining Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                current_loss))
            experiment.log_metric('train_loss', current_loss, step=(epoch * len(train_loader) + batch_idx))

Evaluate the model

model.eval()
test_loss = 0
with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        recon_batch, mu, logvar = model(data)
        test_loss += vae_loss(recon_batch, data, mu, logvar).item()

test_loss /= len(test_loader.dataset)
print('Test Loss: {:.6f}'.format(test_loss))
experiment.log_metric('test_loss', test_loss)

# visualize the reconstructed images from the last test batch
num_images = 8
# The decoder outputs logits, so apply sigmoid to get pixel values in [0, 1]
recon_images = torch.sigmoid(recon_batch)
for i in range(num_images):
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(data[i][0].cpu(), cmap='gray')
    axs[0].set_title('Original Image')
    axs[1].imshow(recon_images[i][0].cpu().numpy(), cmap='gray')
    axs[1].set_title('Reconstructed Image')
    experiment.log_figure(figure_name=f"Reconstructed Image {i+1}", figure=fig)
    plt.close(fig)
Original MNIST handwritten digits alongside their reconstructions, as generated by the variational autoencoder

Conclusion

Resources


Published in Heartbeat

Written by Boluwatife Victor O.

A technical writer who embodies the finesse of Art, science, & persuasion. Expert in cryptocurrency/web3, AI, & software development.
