变分自编码(Variational Auto-Encoder)基本理论与模型实现

变分自编码器(VAE)原理介绍

为什么必须要搞清楚VAE的原理?因为它可以看做是扩散模型(diffusion model)的一个前置方法,或者反过来说,扩散模型可以看做是一个加强版的VAE(从VAE加级联hierarchy到HVAE,再加入马尔可夫性到MHVAE,再加入几个约束:隐变量尺寸和原图一致,编码过程不是网络而是高斯加噪,以及最终处理成高斯分布,就得到了Variational Diffusion Model)。为了从目的论的角度理解扩散模型的设计,需要先了解VAE的基本原理和实现方法。

VAE模型要解决的问题:给定一个数据分布的一些样本,学到这个分布,从而可以用来生成同分布(但是不仅限于已有样本)的新数据。被归为生成模型。

基本思路:基于隐变量(latent space)的先验假设,即数据分布可以用隐变量来生成,只要我们找到一个合适的隐变量分布,以及映射函数,就可以通过从隐空间采样的方式,生成目标分布的样本。

名称释义

  1. variational:变分法,主要来自于计算优化目标时采用了ELBO(证据下界,或称为变分下界VLB),通过优化ELBO逐渐逼近优化目标
  2. auto-encoder:形式上看,模型结构类似auto-encoder模型,即先通过一个encoder模型压缩输入信息,对其进行编码,然后通过一个decoder将编码结果解码到目标域内。

与普通的auto-encoder的区别:普通的AE通常用来做一些明确的pair数据的映射学习任务,比如语义分割、去噪SR等,特点是可以直接学习一对一的映射即可;但VAE是一个生成模型,它如果直接输入x编码得到z,然后映射回x,可想而知模型大概率会学到一个脉冲函数(Dirac delta function),最终没法泛化生成具有多样性的样本,这是我们不希望的。VAE的主要核心和技巧就在如何解决这个问题。

关键公式推导

  1. 如何将预测的问题通过变分法转换成编解码过程?

这里的不等关系的推导来自于琴生不等式,其中左边大于右边的根本原因来自于之间的KL divergence。

  1. 目标函数如何分解成重构损失(reconstruction term)和分布匹配损失(prior matching term)?

VAE的具体实现方式

就是标准高斯, 也是高斯分布,因此只需要预测两个值,即均值和方差即可。

这样一来,上述目标函数中的重构就可以通过再隐变量分布上采样来计算(Monte Carlo estimate),而KL散度可以解析计算(这时高斯分布的便捷性体现出来了,因为两个高斯分布的KL散度直接用就能算出来)

那么这个过程如何用网络实现呢?首先,编码器将输入的训练数据样本映射得到均值和方差两个向量,然后从它们所构造的高斯分布中采样一个,然后将输入到解码器,得到生成的结果,与原始的计算重构损失。

注意到一个问题:这里涉及到了一个采样的操作,这个过程不同于计算分布函数(本质上就是网络模拟映射),是不可微的,为了让网络可以训练,VAE采用了重参数化trick(reparameterization trick),利用高斯分布的性质,要想从采样一个样本,等价于从中采样一个,然后通过: 计算出一个,这个就是服从的一个样本。

极简版VAE代码示例

网络实现代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import numpy as np
import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):
def __init__(self, in_nc, out_nc, k=3, s=2, nlayer=1):
super().__init__()
p = (k - 1) // 2
module_list = list()
for i in range(nlayer):
mid_nc = out_nc if i == nlayer - 1 else in_nc
cur_s = s if i == nlayer - 1 else 1
module_list.extend([
nn.Conv2d(in_nc, mid_nc, kernel_size=k, stride=cur_s, padding=p),
nn.BatchNorm2d(mid_nc),
nn.ReLU(inplace=True)
])
self.block = nn.Sequential(*module_list)

def forward(self, x):
out = self.block(x)
return out

class DeconvBNReLU(nn.Module):
def __init__(self, in_nc, out_nc, k=2, s=2, nlayer=1):
super().__init__()
p = (k - 1) // 2
module_list = list()
for i in range(nlayer):
if i == 0:
module_list.extend([
nn.ConvTranspose2d(in_nc, out_nc, kernel_size=k, stride=s, padding=p),
nn.BatchNorm2d(out_nc),
nn.ReLU(inplace=True)
])
else:
module_list.extend([
nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_nc),
nn.ReLU(inplace=True)
])
self.block = nn.Sequential(*module_list)

def forward(self, x):
out = self.block(x)
return out

class MLP(nn.Module):
def __init__(self, in_nc, mid_nc, out_nc, nlayer=2, is_flatten=True):
super().__init__()
mlp_blocks = list()
for i in range(nlayer):
cur_in = in_nc if i == 0 else mid_nc
cur_out = out_nc if i == nlayer - 1 else mid_nc
mlp_blocks.append(nn.Linear(cur_in, cur_out))
if i < nlayer - 1:
mlp_blocks.append(nn.BatchNorm1d(cur_out))
mlp_blocks.append(nn.ReLU(cur_out))
self.mlp = nn.Sequential(*mlp_blocks)
self.is_flatten = is_flatten

def forward(self, x):
if self.is_flatten:
vec = torch.flatten(x, start_dim=1)
else:
vec = x
out = self.mlp(vec)
return out


class Encoder(nn.Module):
def __init__(self, in_nc=1, in_size=(28, 28),
base_ch=8, num_layer=2,
mlp_nc=128, latent_dim=64):
super().__init__()
in_h, in_w = in_size
self.out_size = (base_ch * (2 ** (num_layer - 1)), in_h // 2 ** num_layer, in_w // 2 ** num_layer)
mlp_in_nc = self.out_size[0] * self.out_size[1] * self.out_size[2]
basic_blocks = list()
for i in range(num_layer):
cur_in = in_nc if i == 0 else base_ch * (2 ** (i - 1))
cur_out = base_ch * (2 ** i)
basic_blocks.append(
ConvBNReLU(cur_in, cur_out, k=3, s=2, nlayer=3)
)
self.encode_block = nn.Sequential(*basic_blocks)
self.calc_mean = MLP(mlp_in_nc, mlp_nc, latent_dim, is_flatten=True)
self.calc_logvar = MLP(mlp_in_nc, mlp_nc, latent_dim, is_flatten=True)

def forward(self, x):
feat = self.encode_block(x)
mean = self.calc_mean(feat)
logvar = self.calc_logvar(feat)
return mean, logvar


class Decoder(nn.Module):
def __init__(self, in_size=(64, 7, 7), out_nc=1,
base_ch=8, num_layer=2,
mlp_nc=128, latent_dim=64):
super().__init__()
mlp_out_nc = in_size[0] * in_size[1] * in_size[2]
self.in_size = in_size
self.mlp = MLP(latent_dim, mlp_nc, mlp_out_nc, nlayer=2, is_flatten=False)
basic_blocks = list()
for i in range(num_layer):
cur_in = base_ch * (2 ** (num_layer - 1 - i))
cur_out = out_nc if i == num_layer - 1 else base_ch * (2 ** (num_layer - 2 - i))
basic_blocks.append(
DeconvBNReLU(cur_in, cur_out, k=2, s=2, nlayer=3)
)
self.decode_block = nn.Sequential(*basic_blocks)

def forward(self, x):
feat = self.mlp(x)
feat = feat.view(-1, *self.in_size)
out = self.decode_block(feat)
return out


class VariationalAutoEncoder(nn.Module):
def __init__(self, in_nc=1, in_size=(28, 28),
base_ch=8, num_layer_enc=2, num_layer_dec=2,
mlp_nc=128, latent_dim=64, decoder_in_size=(16, 7, 7)):
super().__init__()
self.latent_dim = latent_dim
self.encoder = Encoder(in_nc, in_size, base_ch, num_layer_enc, mlp_nc, latent_dim)
self.decoder = Decoder(decoder_in_size, in_nc, base_ch, num_layer_dec, mlp_nc, latent_dim)

def reparam(self, mu, logvar):
std = torch.exp(logvar / 2)
eps = torch.randn_like(std)
return mu + std * eps

def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparam(mu, logvar)
output = self.decoder(z)
return output, mu, logvar

def sample(self, n_samples):
z = torch.randn(n_samples, self.latent_dim)
samples = self.decoder(z)
return samples


if __name__ == "__main__":
vae = VariationalAutoEncoder()
dummy_input = torch.randn(4, 1, 28, 28)
out = vae(dummy_input)
print(out.size())

训练过程采用简单的MNIST数据集,不做引导地随机生成手写数字,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import cv2

from autoencoder import VariationalAutoEncoder

# parameter settings #
transform = transforms.Compose([
transforms.ToTensor(),
])
batch_size = 8
num_epochs = 10
learning_rate = 1e-3
gen_num_samples = 16
# parameter settings #

def VAE_loss(x, recon, mu, logvar, lamda=10.0):
reconstruction_term = torch.mean((x - recon) ** 2)
prior_matching_term = 0.5 * torch.mean(-1 - logvar + torch.exp(logvar) + mu ** 2)
loss = reconstruction_term * lamda + prior_matching_term
return loss

train_dataset = torchvision.datasets.MNIST(root='data', transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

vae_model = VariationalAutoEncoder()
optimizer = Adam(vae_model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
for step, batch_data in enumerate(train_loader):
inputs, _ = batch_data
optimizer.zero_grad()
out, mu, logvar = vae_model(inputs)
loss = VAE_loss(inputs, out, mu, logvar)
loss.backward()
optimizer.step()

if step % 100 == 0:
print(f'Epoch {epoch}/{num_epochs}, Iter {step}/{len(train_loader)}, Loss: {loss.item()}')

samples = vae_model.sample(gen_num_samples)
samples = torch.clamp(samples, 0, 1)
vis = make_grid(samples, nrow=int(np.sqrt(gen_num_samples)))
vis = np.array(torchvision.transforms.ToPILImage()(vis))
cv2.imwrite(f'vis_{epoch}.png', vis)

训练结果如下(epoch=6):

VAE生成结果

下面实验验证一下loss中的两个项目的作用,首先,只采用reconstruction term,不加入prior matching term:

只用重构误差损失的生成结果

从效果上来看,更加清晰不模糊,但是随机高斯分布上sample过decoder的结果没有明确的数字的语义。这个很好理解,因为我们没有约束中间的隐变量是标准高斯,因此从标准高斯采样自然是没有意义的。

然后,只采用prior matching term,不做重构损失的约束,效果如下:

只采用分布约束损失的生成结果

可以看出,模型完全没有学到编解码关系,相当于只约束了编码的分布,没有与原图对应,因此也就没有语义。

进一步思考:为什么不直接约束映射结果为标准高斯,然后解码结果为原图,而非要将每个训练集样本映射到某个mu和sigma的高斯分布,然后再让它逼近标准高斯?这样的结果不还是映射到一个标准高斯么?

对此的一点个人理解:实际上,VAE最终的结果就是希望映射到标准高斯,这样就可以直接采样生成了。但是训练过程为什么要迂回一步呢?我们设想如果直接约束latent为标准高斯,那应该如何施加这个目标呢?因为每个样本对应的都是标准高斯的一个sample,如果要从sample约束分布是很困难的,因为我们sample这个过程不能用映射的格式写出来。但是如果不用映射,真的去标准高斯采样的话,那么有无法区分那个映射到哪个了。VAE的方法则巧妙解决了这个问题,它通过重参数化过程真的将采样用解析的映射表示出来了,那么如何将采样结果与原始输入建立关联呢?自然就是将映射的值放到这个解析映射的参数上(即重参数化trick中的 中的)。于是 就成了要映射的目标值,而采样过程也被转化成了一个可微分的过程,最后只要能让映射的接近0和1,最终模型就可以从标准高斯分布采样生成。