目录
1.wgan产生背景
(1)超参数敏感
(2)模型崩塌
2.wgan主要解决的问题
- 引入了一种新的分布距离度量方法:wasserstein距离,也称为(earth-mover distance)简称em距离,表示从一个分布变换到另一个分布的最小代价。
- 定义了一种称为wasserstein gan的gan形式,该形式使em距离的合理有效近似最小化,并且本文从理论上证明了相应的优化问题是合理的。
-
wgan解决了gans的主要训练问题。特别是,训练wgan不需要维护在鉴别器和生成器的训练中保持谨慎的平衡,并且也不需要对网络架构进行仔细的设计。模式在gans中典型的下降现象也显著减少。wgan最引人注目的实际好处之一是能够通过训练鉴别器进行运算来连续地估计em距离。绘制这些学习曲线不仅对调试和超参数搜索,但也与观察到的样品质量。
3.不同距离的度量方式
(1)方式一
(2)方式二
(3)方式三
(4)方式四
4.wgan原理
(1)p和q分布下的距离计算
(2)em距离转换优化目标推导
(3)判别器和生成器的优化目标
5.wgan训练算法
具体实现代码如下:
for epoch in range(num_epochs):
for batch_idx,(data,_) in enumerate(dataloader):
data = data.to(device)
cur_batch_size = data.shape[0]
#train: critic : max[critic(real)] - e[critic(fake)]
loss_critic = 0
for _ in range(critic_iterations):
noise = torch.randn(size = (cur_batch_size,z_dim,1,1),device=device)
fake_img = gen(noise)
#使用reshape主要是将最后的维度从[1,1,1,1]=>[1]
critic_real = critic(data).reshape(-1)
critic_fake = critic(fake_img).reshape(-1)
loss_critic = (torch.mean(critic_real)- torch.mean(critic_fake))
opt_critic.zero_grad()
loss_critic.backward(retain_graph=true)
opt_critic.step()
#clip critic weight between -0.01 , 0.01
for p in critic.parameters():
p.data.clamp_(-weight_clip,weight_clip)
#将维度从[1,1,1,1]=>[1]
gen_fake = critic(fake_img).reshape(-1)
#max e[critic(gen_fake)] <-> min -e[critic(gen_fake)]
loss_gen = -torch.mean(gen_fake)
opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
6.wgan网络结构
7.数据集下载
8.wgan代码实现
9.mainwindow窗口显示生成器生成的图片
10.模型下载
参考文章:
发表评论