使用自编码器进行图片检索

1 自编码器简介

自编码器最初提出的目的是为了学习到数据的主要规律,进行数据压缩,或特征提取,是一类无监督学习的算法。后续各种改进的算法,使得自编码器具有生成能力,例如变分自编码器(VAE)的图像生成能力。

本次放的代码都是线性版本,在我的仓库中还有cnn版本,毕竟线性版本只能处理一下这种mnist小数据集。地址:https://github.com/Guoxn1/ai。

2 线性自编码器

自动编码器是由线性层构成的,它看起来就像是一个普通的深度神经网络DNN。

特点如下:输出层的神经元数量往往与输入层的神经元数量一致、网络架构往往呈对称性,且中间结构简单、两边结构复杂。

v2-f0252095abc6b14b38cf184f65d5e0b0_r

从输入层开始压缩数据、直至架构中心的部分被称为编码器Encoder,编码器的职责是从原始数据中提取必要的信息,从原始数据中提纯出的信息被称之为编码Code或隐式表示。从编码开始拓展数据、直至输出层的部分被称为解码器Decoder,解码器的输出一般被称为重构数据,解码器的职责是将提取出的信息还原为原来的结构。

结合minst数据集,写一个用线性自编码器进行图片检索的demo。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from torchvision import datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np


# 定义线性自编码器
class Line_coder(nn.Module):
def __init__(self):
super(Line_coder,self).__init__()

self.encoder = nn.Sequential(
nn.Linear(28*28,256),
nn.ReLU(True),
nn.Linear(256,64),
nn.ReLU(True),
nn.Linear(64,16),
nn.ReLU(True),
nn.Linear(16,3)
)

self.decoder = nn.Sequential(
nn.Linear(3,16),
nn.ReLU(True),
nn.Linear(16,64),
nn.ReLU(True),
nn.Linear(64,256),
nn.ReLU(),
nn.Linear(256,28*28),
nn.Tanh()
)

def forward(self,X):
X = self.encoder(X)
X = self.decoder(X)

return X


def data_load():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5,))
])

train_dataset = datasets.MNIST(root="../data",train=True,download=True,transform=transform)
test_dataset = datasets.MNIST(root="../data",download=True,train=False,transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=1,shuffle=True)

return train_loader,test_loader


def train(model,loss_fn,optim,epochs,train_loader,device):
model.train()
model.to(device)
for epoch in range(epochs):
for img,_ in train_loader:
# 扁平化 成为一维向量 以便输入到线性网络中
img = img.view(img.size(0),-1)
img = img.to(device)

output = model(img)
loss = loss_fn(output,img)

# 反向传播
optim.zero_grad()
loss.backward()
optim.step()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

def retriveed_images(query_image,train_loader,model,n,device):

query_image = query_image.view(query_image.size(0),-1).to(device)
query_feature = model.encoder(query_image)
distances = []
for img,_ in train_loader:
img = img.view(img.size(0),-1).to(device)
features = model.encoder(img)
dist = torch.norm(features-query_feature,dim=1)
distances.extend(list(zip(dist.cpu().detach().numpy(),img.cpu().detach().numpy())))

distances.sort(key=lambda x:x[0])
return [x[1] for x in distances[:n]]


def visualize_retrieval(query_image, retrieved_images):
plt.figure(figsize=(10, 2))

# 显示查询图片
plt.subplot(1, len(retrieved_images) + 1, 1)
plt.imshow(query_image.reshape(28, 28), cmap='gray')
plt.title('Query Image')
plt.axis('off')

# 显示检索到的图片
for i, img in enumerate(retrieved_images, 2):
plt.subplot(1, len(retrieved_images) + 1, i)
plt.imshow(img.reshape(28, 28), cmap='gray')
plt.title(f'Retrieved {i-1}')
plt.axis('off')

plt.show()

def test(model,test_loader,train_loader,device):
model.eval()
model.to(device)
for img,_ in test_loader:
query_image = img.view(img.size(0),-1).to(device)
break
retriveed_image = retriveed_images(query_image,train_loader,model,5,device)
visualize_retrieval(query_image.cpu().squeeze(), [img.squeeze() for img in retriveed_image])

def main():
device = "cuda"
model = Line_coder()
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(),lr=1e-3)
epochs = 5
train_loader,test_loader = data_load()
train(model,loss_fn,optim,epochs,train_loader,device)
torch.save(model,"model.pth")

test(model,test_loader,train_loader,device)

main()

这段代码可以选出test_loader中的相似图片。

image-20231207111433820

3 卷积自编码器

将线性层替换为卷积层就是卷积自编码器。

代码部分有一点需要注意,在解码器中式中转置卷积进行维度扩张。

线性编码器会损失很多信息,所以改为卷积自编码器,对图像处理效果可能会更好。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
from torchvision import datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np


# 定义线性自编码器
class conv_coder(nn.Module):
def __init__(self):
super(conv_coder,self).__init__()

self.encoder = nn.Sequential(
nn.Conv2d(1,8,kernel_size=3,padding=1), # [1,28,28] -> [8,28,28]
nn.MaxPool2d(kernel_size=2,stride=2), # [8,28,28] -> [8,14,14]
nn.ReLU(),
nn.Conv2d(8,16,kernel_size=3,padding=1), #[8,14,14] -> [16,14,14]
nn.MaxPool2d(kernel_size=2,stride=2), #[16,14,14] -> [16,7,7]
nn.ReLU(),
nn.Conv2d(16,4,kernel_size=3,padding=1), #[16,7,7] -> [4,7,7]
nn.ReLU(),
nn.Flatten(),
nn.Linear(4*7*7,32),
nn.ReLU(),
nn.Linear(32,3)
)

self.decoder = nn.Sequential(
nn.Linear(3,32),
nn.ReLU(),
nn.Linear(32,4*7*7),
nn.ReLU(),
nn.Unflatten(1,(4,7,7)),
nn.ReLU(),
nn.ConvTranspose2d(4,16,kernel_size=3,padding=1),
nn.ReLU(),
# 最大池化
nn.ConvTranspose2d(16,8,kernel_size=3,padding=1,stride=2,output_padding=1),
nn.ReLU(),
# 最大池化
nn.ConvTranspose2d(8,1,kernel_size=3,padding=1,stride=2,output_padding=1),
nn.Tanh()
)

def forward(self,X):
X = self.encoder(X)
X = self.decoder(X)

return X


def data_load():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5,))
])

train_dataset = datasets.MNIST(root="../data",train=True,download=True,transform=transform)
test_dataset = datasets.MNIST(root="../data",download=True,train=False,transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=1,shuffle=True)

return train_loader,test_loader


def train(model,loss_fn,optim,epochs,train_loader,device):
model.train()
model.to(device)
for epoch in range(epochs):
for img,_ in train_loader:
img = img.to(device)

output = model(img)

loss = loss_fn(output,img)

# 反向传播
optim.zero_grad()
loss.backward()
optim.step()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

def retriveed_images(query_image,train_loader,model,n,device):

query_image = query_image.to(device)
query_feature = model.encoder(query_image)
distances = []
for img,_ in train_loader:
img = img.to(device)
features = model.encoder(img)
dist = torch.norm(features-query_feature,dim=1)
distances.extend(list(zip(dist.cpu().detach().numpy(),img.cpu().detach().numpy())))

distances.sort(key=lambda x:x[0])
return [x[1] for x in distances[:n]]


def visualize_retrieval(query_image, retrieved_images):
plt.figure(figsize=(10, 2))

# 显示查询图片
plt.subplot(1, len(retrieved_images) + 1, 1)
plt.imshow(query_image.reshape(28, 28), cmap='gray')
plt.title('Query Image')
plt.axis('off')

# 显示检索到的图片
for i, img in enumerate(retrieved_images, 2):
plt.subplot(1, len(retrieved_images) + 1, i)
plt.imshow(img.reshape(28, 28), cmap='gray')
plt.title(f'Retrieved {i-1}')
plt.axis('off')

plt.show()

def test(model,test_loader,train_loader,device):
model.eval()
model.to(device)
for img,_ in test_loader:
query_image = img.to(device)
break
retriveed_image = retriveed_images(query_image,train_loader,model,5,device)
visualize_retrieval(query_image.cpu().squeeze(), [img.squeeze() for img in retriveed_image])

def main():
device = "cuda"
model = conv_coder()
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(),lr=1e-3)
epochs = 5
train_loader,test_loader = data_load()
train(model,loss_fn,optim,epochs,train_loader,device)
torch.save(model,"model_cnn.pth")

test(model,test_loader,train_loader,device)

main()

输出如下:

image-20231207134937945

还是出现了一定的误判,但是整体还是可以的。可能是数据集太简单,过拟合了。

4 变分自动编码器

对于基本自编码器来说,只能够对原始数据进行压缩,不具备生成能力,也就是我们给解码器任意数据作为输入,解码器能够给我们生成我们想要的东西。主要原因是,基本自编码器给定一张图片生成原始图片,从输入到输出都是确定的,没有任何随机的成分,为了使模型表现很好,在不断的迭代训练中,编码器的输出也就是解码器的输入会趋于确定,这样才能让解码器能生成与输入数据更接近的数据,以使损失变得更小。但是这就与生成器的初衷有悖了。

对于VAE来说,编码器的输入是原始数据X,但解码器的输入不是编码器的输出了,而是从满足一定分布中随机抽样出的Z。因此当变分自动编码器被训练好之后,我们可以只取架构中的解码器来使用:只要对解码器输入满足特定分布的随机数Z,解码器就可以生成像从原始数据X中抽样出来的数据,如此就能够实现图像生成。许多论文已经证明,变分自动编码器的生成能力足以与一些生成对抗网络分庭抗礼,但这一架构在生成领域的局限也很明显:与GAN一样,变分自动编码器能够获得的信息只有随机数Z,因此在面临复杂数据时架构会显得有些弱小。

4.1 基本架构

与普通自动编码器一样,变分自动编码器有编码器Encoder与解码器Decoder两大部分组成,原始图像从编码器输入,经编码器后形成隐式表示(Latent Representation),之后隐式表示被输入到解码器、再复原回原始输入的结构。然而,与普通Autoencoders不同的是,变分自用编码器的Encoder与Decoder在数据流上并不是相连的,我们不会直接将Encoder编码后的结果传递给Decoder,而是要使得隐式表示满足既定分布。

①首先,变分自动编码器中的编码器会尽量将样本 X 所携带的所有特征信息的分布转码成类高斯分布。 ②编码器需要输出该类高斯分布的均值 u 与标准差 a 作为编码器的输出。 ③以编码器生成的均值u 与标准差 a 为基础构建高斯分布。 ④从构建的高斯分布中随机采样出一个数值 Z ,将该数值输入解码器。 ⑤解码器基于 Z进行解码,并最终输出与样本的原始特征结构一致的数据,作为VAE的输出X1 。

img

根据以上流程,变分自动编码器的Encoder在输出时,并不会直接输出原始数据的隐式表示,而是会输出从原始数据提炼出的均值 u 和标准差 a 。之后,我们需要建立均值为 u 、标准差为 a 的正态分布,并从该正态分布中抽样出隐式表示z,再将隐式表示z输入到Decoder中进行解码。对隐式表示z而言,它传递给Decoder的就不是原始数据的信息,而只是与原始数据同均值、同标准差的分布中的信息了。

4.2 正向传播、损失函数、重参数化

m个样本,每个样本有5个特征。

正向传播:

img

img

当前的均值和标准差不是真实数据的统计量,而是通过Encoder推断出的、当前样本数据可能服从的任意分布中的属性。我们可以令Encoder的输出层存在3个神经元,这样Encoder就会对每一个样本推断出三对不同的均值和标准差。这个行为相当于对样本数据所属的原始分布进行估计,但给出了三个可能的答案。因此现在,在每个样本下,我们就可以基于三个均值和标准差的组合生成三个不同的正态分布了。

每个样本对应了3个正态分布,而3个正态分布中可以分别抽取出三个数字z,此时每个隐式表示z就是一个形如(m,3)的矩阵。将这一矩阵放入Decoder,则Decoder的输入层也需要有三个神经元。此时,我们的隐式空间就是(m,3)。

对任意的自动编码器而言,隐式空间越大,隐式表示z所携带的信息自然也会越多,自动编码器的表现就可能变得更好,因此在实际使用变分自动编码器的过程中,一个样本上至少都会生成10~100组均值和标准差,隐式表示z的结构一般也是较高维的矩阵。

损失函数:

img

那么现在我们让编码器输出的概率分布和我们的先验的概率分布一样,这样就能够完成我们的生成任务。当然为了保证输出的精度,还需要让模型的输出 X1与模型的输入 X 存在一定的制约关系。

常用KL散度来表示两个分布之间的差异,KL散度越小,分布越接近。

img

由于这个抽样流程的存在,架构中的数据流是断裂的,因此反向传播无法进行。反向传播要求每一层数据之间必有函数关系,而抽样流程不是一个函数关系,因此无法被反向传播。为了解决这一问题,变分自动编码器的原始论文提出了重参数化技巧,这一技巧可以帮助我们在抽样的同时建立Z u a 之间的函数关系,这样就可以令反向传播顺利进行了。

也就是不对mu,sigma的正态分布进行采样,而是对0,1的标准正态分布进行采样,再由 采样值*sigma+mu 来获取z,此时z显然满足mu,sigma的正态分布,且z并非由于采样得到,可由上述公式进行求导和梯度反向传播。

img

4.3 代码实现

以mnist手写数据集为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch

torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 200

device="cuda"



class VariationalEncoder(nn.Module):
def __init__(self, latent_dims):
super(VariationalEncoder, self).__init__()
self.linear1 = nn.Linear(784, 512)
self.linear2 = nn.Linear(512, latent_dims)
self.linear3 = nn.Linear(512, latent_dims)

self.N = torch.distributions.Normal(0, 1)
self.N.loc = self.N.loc # hack to get sampling on the GPU
self.N.scale = self.N.scale
self.kl = 0

def forward(self, x):
x = torch.flatten(x, start_dim=1)
x = F.relu(self.linear1(x))
mu = self.linear2(x).to(device)
sigma = torch.exp(self.linear3(x)).to(device)
sam = self.N.sample(mu.shape).to(device)
z = mu + sigma * sam
self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
return z


class Decoder(nn.Module):
def __init__(self, latent_dims):
super(Decoder, self).__init__()
self.linear1 = nn.Linear(latent_dims, 512)
self.linear2 = nn.Linear(512, 784)

def forward(self, z):
z = F.relu(self.linear1(z))
z = torch.sigmoid(self.linear2(z))
return z.reshape((-1, 1, 28, 28))


class VariationalAutoencoder(nn.Module):
def __init__(self, latent_dims):
super(VariationalAutoencoder, self).__init__()
self.encoder = VariationalEncoder(latent_dims).to(device)
self.decoder = Decoder(latent_dims).to(device)

def forward(self, x):
z = self.encoder(x)
return self.decoder(z)


def train(autoencoder, data, epochs=20):
device="cuda"
autoencoder = autoencoder.to(device)
opt = torch.optim.Adam(autoencoder.parameters())
for epoch in range(epochs):
for x, _ in data:
x = x.to(device) # GPU

x_hat = autoencoder(x)
loss = ((x - x_hat) ** 2).sum() + autoencoder.encoder.kl
opt.zero_grad()
loss.backward()
opt.step()
print(epoch)
return autoencoder


def plot_latent(variational_autoencoder, data, num_batches=100):
for i, (x, y) in enumerate(data):
z = variational_autoencoder.encoder(x.to(device))
z = z.to('cpu').detach().numpy()
plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
if i > num_batches:
plt.colorbar()
break


def plot_reconstructed(autoencoder, r0=(-5, 10), r1=(-10, 5), n=12):
w = 28
img = np.zeros((n * w, n * w))
for i, y in enumerate(np.linspace(*r1, n)):
for j, x in enumerate(np.linspace(*r0, n)):
z = torch.Tensor([[x, y]]).to(device)
x_hat = autoencoder.decoder(z)
x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
img[(n - 1 - i) * w:(n - 1 - i + 1) * w, j * w:(j + 1) * w] = x_hat
plt.imshow(img, extent=[*r0, *r1])


def interpolate(autoencoder, x_1, x_2, n=12):
z_1 = autoencoder.encoder(x_1)
z_2 = autoencoder.encoder(x_2)
z = torch.stack([z_1 + (z_2 - z_1) * t for t in np.linspace(0, 1, n)])
interpolate_list = autoencoder.decoder(z)
interpolate_list = interpolate_list.to('cpu').detach().numpy()

w = 28
img = np.zeros((w, n * w))
for i, x_hat in enumerate(interpolate_list):
img[:, i * w:(i + 1) * w] = x_hat.reshape(28, 28)
plt.imshow(img)
plt.xticks([])
plt.yticks([])


if __name__ == "__main__":
# device="cpu"
latent_dims = 2
vae = VariationalAutoencoder(latent_dims) # GPU
data = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data',
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=128,
shuffle=True)

vae = train(vae, data,epochs=10)
#plot_latent(vae, data)
#plt.show()
plot_reconstructed(vae, r0=(-3, 3), r1=(-3, 3))
plt.show()
x, y = next(iter(data)) # hack to grab a batch
x_1 = x[y == 1][1].to(device) # find a 1
x_2 = x[y == 0][1].to(device) # find a 0

interpolate(vae, x_1, x_2, n=20)
plt.show()
image-20231207213052683

生成出的2的图像,怎么感觉还不如上面的第一个2。

image-20231207214351902

篇幅受限,就不展示cnn_vae了,但是效果可以展示如下,可以看到,效果要好一点。都是根据已有的图片生成的噢。

image-20231207222459181

5 CVAE 条件变分自编码器

我们上面是可以指定生成什么数字的,主要的原理就是给一个相关的图,比如给一个2的图,生成一个2,但是这有很多不便。

论文中标准结构的CVAE,encoder的输入变为原始的手写数字图像和数字类别信息的拼接,输出不变。decoder的输入变为正态分布采样z与数字类别信息的拼接,输出不变。

将数字类别信息作为条件加入到encoder和decoder的输入中,由此来指定数字类别生成对应的数字图片。

img

数字类别输入encoder和decoder前需要经过onehot encoding(独热编码),将0-9这个十类别的单个数字变为10维向量,该向量仅在数字值对应的位置上取1,其他位置取0。(例如:数字0对应的向量就是[1,0,0,0,0,0,0,0,0,0],数字1对应的向量是[0,1,0,0,0,0,0,0,0,0],等等)。

该行为的目的是把0-9这10个数字对应的数据无关化(互相垂直),独立化(平等对待,忽视其值的大小)。例如,onehot encoding前数字0到1,0与9的距离不同,但对于分类任务,所有数字是平等的,onehot后的十维向量,0到1的距离和0到9的距离相同,而且0-9这10个十维向量互相垂直,线性无关,谁也无法用剩余的来表示。

另一个角度,也可将这个10维向量看作是该手写数字图像对应的0-9这十种数字的概率,由于该数字图像的事实结果是0-9的某一个,所以对应维度的概率值为1,其他的维度为0。

除此之外,训练模型的目标与VAE无差别。

训练完成使用时,CVAE的decoder输入为n个标准正态分布的采样值,及指定的数字类别,输出为指定数字类别对应的手写数字图像。

这种加入条件的思路在wavenet,tacotron等声音模型中常使用,把人物的声音特征作为条件加入,可生成相同内容,但不同人声的语音。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch

torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 200

device="cuda"



class VariationalEncoder(nn.Module):
def __init__(self, input_size,hidden_size,latent_dims):
super(VariationalEncoder, self).__init__()
self.linear1 = nn.Linear(input_size, hidden_size)
self.linear2 = nn.Linear(hidden_size, latent_dims)
self.linear3 = nn.Linear(hidden_size, latent_dims)

self.N = torch.distributions.Normal(0, 1)
self.N.loc = self.N.loc # hack to get sampling on the GPU
self.N.scale = self.N.scale
self.kl = 0

def forward(self, x,c):
x = torch.flatten(x, start_dim=1)
x = torch.cat((x,c),dim=1)
x = F.relu(self.linear1(x))
mu = self.linear2(x).to(device)
sigma = torch.exp(self.linear3(x)).to(device)
sam = self.N.sample(mu.shape).to(device)
z = mu + sigma * sam
self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
return z


class Decoder(nn.Module):
def __init__(self, latent_dims,hidden_size,output_size):
super(Decoder, self).__init__()
self.linear1 = nn.Linear(latent_dims, hidden_size)
self.linear2 = nn.Linear(hidden_size, output_size)

def forward(self, z,c):
z = torch.cat((z,c),dim=1)
z = F.relu(self.linear1(z))
z = torch.sigmoid(self.linear2(z))
return z.reshape((-1, 1, 28, 28))


class VariationalAutoencoder(nn.Module):
def __init__(self, input_size,output_size,hidden_size,latent_dims,condition_size):
super(VariationalAutoencoder, self).__init__()
self.encoder = VariationalEncoder(input_size+condition_size,hidden_size,latent_dims).to(device)
self.decoder = Decoder(latent_dims+condition_size,hidden_size,output_size).to(device)

def forward(self, x,c):

z = self.encoder(x,c)

return self.decoder(z,c)


def train(autoencoder, data, epochs=10):
device="cuda"
autoencoder = autoencoder.to(device)
opt = torch.optim.Adam(autoencoder.parameters())
for epoch in range(epochs):
for x, y in data:
x = x.to(device) # GPU
y = F.one_hot(y.to(device),condition_size)
x_hat = autoencoder(x,y)
loss = ((x - x_hat) ** 2).sum() + autoencoder.encoder.kl
opt.zero_grad()
loss.backward()
opt.step()
print(epoch)
return autoencoder



def plot_pre_imgs(cvae,latent_dims):

sample = torch.randn(1,latent_dims).to(device)
fig, ax = plt.subplots(2, 5, figsize=(10, 4))
n = 0
for i in range(2):
for j in range(5):
i_number = n*torch.ones(1).long().to(device)
condit = F.one_hot(i_number,condition_size)#将数字进行onehot encoding
gen = cvae.decoder(sample,condit)[0].view(28,28)#生成
n = n+1
ax[i, j].imshow(gen.cpu().detach().numpy(), cmap='gray')
ax[i, j].axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()


if __name__ == "__main__":
# device="cpu"
condition_size=10
latent_dims = 8
input_size=28*28
output_size= 28*28
hidden_size=512
cvae = VariationalAutoencoder(input_size,output_size,hidden_size,latent_dims,condition_size) # GPU
data = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data',
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=128,
shuffle=True)

cvae = train(cvae, data,epochs=20)
plot_pre_imgs(cvae,latent_dims)
plt.show()
torch.save(cvae,"cvae.pth")
image-20231208210525790

以上图片全是生成的,可见效果还是不错的,但是5和7的生成略微像6和9,可能是过拟合,也可能是模型的问题,但总体还是不错的。

还有一种代码是encoder不给标签,decoder给标签,感觉效果应该一般,没进行code,感觉效果会差。

上面是线性的条件变分自编码器,这边改成卷积的再生成一遍,篇幅原因,代码都在我的仓库中,效果如下:

image-20231208215549456

全是生成的噢,可见,效果还是不错的,记住了空间信息,5和7的问题明显改善。

看到后面还有其他更有力的生成模型,之后会继续更新,这篇没还完噢。


使用自编码器进行图片检索
http://example.com/2023/12/08/使用自编码器进行图片检索/
作者
Guoxin
发布于
2023年12月8日
许可协议