【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像

el/2024/3/2 12:03:49

GAN 生成动漫头像

  • 1. 获取数据
  • 2. 用GAN生成
    • 2.1 Generator
    • 2.2 Discriminator
    • 2.3 其它细节
    • 2.4 训练思路
  • 3. 全部代码
  • 4. 结果展示与分析
  • 小结

深度卷积生成对抗网络(DCGAN):Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

1. 获取数据

原来书里的下载链接即原来知乎的何之源分享链接失效了,一篇简书里有下载地址,一共是51223张图片,尺寸是96×96×3,总大小272 MB

利用python查看图片的大小:

法I:

from PIL import Image
image = Image.open(dir + path[0])
imgSize = image.size  #大小
print(imgSize)
(96, 96)

法II:

import cv2
img = cv2.imread(dir + path[0])
sp = img.shape
print(sp)
(96, 96, 3)

2. 用GAN生成

2.1 Generator

CONVTRANSPOSE2D

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros')

官方文档,这是逆卷积,关于卷积可看之前的CNN猫狗二分类

在由多个输入平面组成的输入图像上应用二维转置卷积算子

这个模块可以看作是Conv2d相对于其输入的梯度。它也被称为分式卷积或解卷积(虽然它不是实际的解卷操作)

  • stride控制交叉相关的步幅
  • padding控制两边隐含的零填充量,以便进行dilation * (kernel_size - 1) - padding。详情请看下面的说明
  • output_padding 控制添加到输出形状一侧的额外尺寸。详情请看下面的说明
  • dilation控制核点之间的间距,也就是所谓的à trous算法。它比较难描述,但这个链接有一个很好的可视化的扩张作用
  • groups控制输入和输出之间的连接,in_channels和out_channels都必须被分组所除

对于本实验:

  • 输入维度:noiseSize × 1 × 1
  • kernel_size=4, stride=1, padding=0,第一次变化后:(n_generator_feature * 8) × 4 × 4
  • 当kernel_size=4, stride=2, padding=1时,输入的宽高刚好是第一次的两倍,第二次变化后:(n_generator_feature * 4) × 8 × 8
  • 第三次变化后:(n_generator_feature * 2) × 16 × 16
  • 第四次变化后:n_generator_feature × 32 × 32
  • 最后一层采用kernel_size=5, stride=3, padding=1,为了将32 × 32变为96 × 96
  • 最后用Tanh将输出图片的像素归一化到-1~1,如果希望归一化到0~1,需要使用Sigmoid
Generator = NetGenerator()
x = torch.rand(1, noiseSize, 1, 1)
y = Generator(x)
print(y.size())
torch.Size([1, 3, 96, 96])

2.2 Discriminator

LeakyReLU

官方手册,inplace=True表示进行覆盖运算

Discriminator = NetDiscriminator()
x = torch.rand(1, 3, 96, 96)
y = Discriminator(x)
print(y.size())
torch.Size([1])

可以看出判别器和生成器的额网络几乎是对称的,从卷积核大小到padding、stride等设置,需要注意的是生成器的激活函数是ReLU,而判别器使用的是LeakyLeRU,二者并无本质区别,这里的选择更多的是经验总结。每一个样本经过判别器后,输出一个0~1的数,表示的是这个样本是真图片的概率

2.3 其它细节

torch.utils.data.DataLoader官方文档

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)中文文档
betas (Tuple[float, float], 可选) – 平滑常数:用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)

torch.nn.BCELoss():计算target 和output 间的二值交叉熵(Binary Cross Entropy)官方文档,计算公式可见此

关于(tqdm.tqdm)可见官方文档

2.4 训练思路

  • 训练判别器
    • 对于真图片,输出尽可能是1
    • 对于假图片,输出尽可能是0
  • 训练生成器
    • 对于假图片,输出尽可能是1

这里需要注意以下几点。

  • 训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
  • 在训练判别器时,需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
  • 在训练判别器时,需要反向传播两次,一次是希望把真图片判为1,一次是希望把假图片判为0。也可以将这两者的数据放到一个batch中,进行一次前向传播和一次反向传播即可。但是人们发现,在一个batch中只包含真图片或只包含假图片的做法最好。
  • 对于假图片,在训练判别器时,我们希望它输出0;而在训练生成器时,我们希望它输出1.因此可以看到一对看似矛盾的代码 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。其实这也很好理解,判别器希望能够把假图片判别为fake_label,而生成器则希望能把他判别为true_label,判别器和生成器互相对抗提升。

接下来就是一些可视化的代码。每次可视化使用的噪声都是固定的fix_noises,因为这样便于我们比较对于相同的输入,生成器生成的图片是如何一步步提升的。另外,由于我们对输入的图片进行了归一化处理(-1~1),在可视化时则需要将它还原成原来的scale(0~1)

3. 全部代码

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = '/mnt/Data1/ysc/GAN'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 256
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)real_image = real_image.cuda()if (i + 1) % d_every == 0:optimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i))# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0))i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.show()if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布if torch.cuda.is_available() == True:print('Cuda is available!')Generator.cuda()Discriminator.cuda()criterion.cuda()true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()fix_noises, noises = fix_noises.cuda(), noises.cuda()plot_epoch = [1,5,10,20,50,100,199]for i in range(200):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

4. 结果展示与分析

在第1,5,10,20,50,100,199分别打印结果如下所示,这里第0代没有打印:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 刚开始训练的图像比较模糊(1个epoch),但是可以看出图像已经有面部轮廓
  • 继续训练数个epoch之后,生成的图多了很多细节信息,包括头发、颜色等,但是总体还是模糊
  • 训练数个epoch之后,细节继续完善,包括头发的纹理、眼睛的细节等,但还是有不少涂抹的痕迹
  • 训练数个epoch时,已经能看出明显的面部轮廓和细节,但还是有涂抹现象,并且有些细节不够合理,例如眼睛一大一小,面部轮廓扭曲严重
  • 当训练到最大epoch会后,图片的细节已经十分完善,线条更加流畅,轮廓更清晰,虽然还有一些不合理之处,但是已经有不少图片能够以假乱真了

类似的生成动漫头像的项目还有《用DRGAN生成高清的动漫头像》,效果很好,但遗憾的是,由于论文中使用的数据涉及版权问题,未能公开。这篇论文主要改进包括使用了更高质量的图片和更深、更复杂的模型

GAN可以应用到不同的生成图片场景中,只要将训练图片改成其他类型的图片即可,例如LSUN房客图片集、MNIST手写数据集或CIFAR10数据集等。事实上,上述模型还有很大的改进空间。在这里,我们使用的全卷积网络只有四层,模型比较浅,而在ResNet的论文发表之后,也有不少研究者尝试在GAN的网络结构中引入Residual Block结构,并取得了不错的视觉效果。感兴趣可以尝试将示例代码中的单层卷积改为Residual Block,相信可以取得不错的效果

今年来,GAN的一个重大突破在于理论研究。论文《Towards Principled Methods for Training Generative Adversarial Networks》从理论的角度分析了GAN为何难以训练,作者随后在另一篇论文《Wasserstein GAN》中针对性地提出了一个更好的解决方案。但是这篇论文在部分技术细节上的实现过于随意,所以随后又有人有针对性地提出了《Improved Training of Wasserstein GANs》,更好地训练WGAN。后面两篇论文分别用PyTorch和TensorFlow实现,代码可以在GitHub上搜索到。笔者当初也尝试用100行左右的代码实现了Wasserstein GAN,该兴趣可以去了解

随着GAN研究的逐渐成熟,人们也尝试把GAN用于工业实际问题之中,而在众多相关论文中,最令人深刻的就是《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》,论文中提出了一种新的GAN结构称为CycleGAN。CycleGAN利用GAN实现风格迁移、黑白图像彩色化,以及马和斑马互相转化等,效果十分出众。论文的作者用PyTorch实现了所有的代码,并开源在GitHub上,感兴趣可以自行查阅

小结

GAN生成的结果还是比较理想吧,就是一个简单的GAN的结构,其中的Generator的反卷积与Discriminator的卷积可以琢磨一下,模型只是简单的训练,并没有保存和测试


http://www.ngui.cc/el/3526008.html

相关文章

【语音信号处理】2语音信号实践——LSTM(hidden、output)、Attention、语音可视化

语音信号处理 深度学习1. LSTM-hidden 实现细节2. LSTM-output 实现细节3. Attention4. 语音可视化5. 全部代码小结1. LSTM-hidden 实现细节 关于class torch.utils.data.Dataset官方文档&#xff0c; 当ATCH_SIZE 128&#xff0c;HIDDEN_SIZE 64&#xff0c;最大迭代次数…

【NLP】文献翻译4——CH-SIMS:中文多模态情感分析数据集与细粒度的模态注释

CH-SIMS: A Chinese Multimodal Sentiment Analysis Dataset with Fine-grained Annotations of Modality摘要1. 介绍2. 相关工作2.1 多模态数据集2.2 多模态情感分析2.3 多任务学习3. CH-SIMS 数据集3.1 数据获取3.2 标注3.3 特征提取4. 多模式多任务学习框架4.1 单模态子网4.…

【NLP】文献翻译5——用自我监督的多任务学习学习特定模式的表征,用于多模态情感分析

Learning Modality-Specific Representations with Self-Supervised Multi-Task Learning for Multimodal Sentiment Analysis摘要1. 介绍2. 相关工作2.1 多模态情感分析2.2 Transformer and BERT2.3 多任务学习3. 方法论3.1 任务设定3.2 总体架构3.3 ULGM3.4 优化目标4. 实验环…

【PyTorch】13 Image Caption:让神经网络看图讲故事

图像描述1、数据集获取2、文本数据处理3、图像数据处理4、训练5、全部代码6、总结1、数据集获取 数据来自&#xff1a;AI challenger 2017 图像描述数据集 百度网盘: https://pan.baidu.com/s/1g1XaPKzNvOurH9M44p1qrw 提取码: bag3 这里由于原训练集太大&#xff0c;这里仅使…

【PyTorch】14 AI艺术家:神经网络风格迁移

风格迁移 Style Transfer1、数据集2、原理简介3、用Pytorch实现风格迁移4、结果展示5、全部代码小结详细可参考此CSDN 1、数据集 使用COCO数据集&#xff0c;官方网站点此&#xff0c;下载点此&#xff0c;共13.5GB&#xff0c;82783张图片 2、原理简介 风格迁移分为两类&a…

【20210906】让实验室服务器运行本地python代码

从零开始配置实验室电脑的python环境1. 电脑信息2. 电脑环境配置&#xff08;1&#xff09;Pycharm&#xff08;2&#xff09;anaconda&#xff08;3&#xff09;配置Anacondapycharm环境3. 服务器环境配置小结在实验室刚刚装好的DELL电脑&#xff0c;设备规格&#xff1a;Vost…

【20210910】让实验室服务器在Anaconda环境运行本地python代码

从零开始配置服务器的python环境1. 下载Anaconda Linux2. Pycharm3. 配置服务器上Python环境4. tmux应用5. Anaconda环境小结1. 下载Anaconda Linux 可以查看服务器的Linux版本&#xff1a; cat /proc/versionLinux version 5.11.13-arch1-1 (linuxarchlinux) (gcc (GCC) 10.…

【20210919】LaTex入门:overleaf使用

overleaf在线编辑Latex1. 使用overleaf2. 一些问题小结1. 使用overleaf 2. 一些问题 overleaf官网 首先注册一下 上传模板编译报错&#xff1a; 解决办法&#xff1a; Select “menu” – “Compiler” – “XeLatex”.Compiled again, successfully. 模板感觉太复杂了&…

【20210916】GMM入门

高斯混合模型GMM&#xff08;Gaussian Mixture Model&#xff09;1. 模型介绍2. 极大似然估计MLE&#xff08;Maximum Likelihood Estimate&#xff09;3. EM求解&#xff08;Expectation-Maximization Algorithm&#xff09;(1) EM算法&#xff08;期望最大算法&#xff09;公…

【20210920】HMM入门

隐马尔可夫模型 Hidden Markov Model1. 马尔可夫过程简介2. {A、B、π\piπ}3. 引入α,β\alpha,\betaα,β便于Evaluate4. EM算法参数学习小结本文参考的视频链接 首先要知道什么式序列&#xff08;Series&#xff09;&#xff0c;什么是集合&#xff08;Set&#xff09; 时…