当前位置: 代码网 > 科技>人工智能>神经网络 > WGAN基本原理及Pytorch实现WGAN

WGAN基本原理及Pytorch实现WGAN

2024年08月05日 神经网络 我要评论
WGAN基本原理及Pytorch实现WGAN

目录

1.wgan产生背景

(1)超参数敏感

(2)模型崩塌

2.wgan主要解决的问题

3.不同距离的度量方式

(1)方式一

(2)方式二

(3)方式三

(4)方式四

4.wgan原理

(1)p和q分布下的距离计算 

(2)em距离转换优化目标推导

(3)判别器和生成器的优化目标

5.wgan训练算法 

6.wgan网络结构

7.数据集下载

8.wgan代码实现 

9.mainwindow窗口显示生成器生成的图片

10.模型下载 


gan原理及pytorch框架实现gan(比较容易理解)

pytorch框架实现dcgan(比较容易理解)

cyclegan的基本原理以及pytorch框架实现

1.wgan产生背景

(1)超参数敏感

(2)模型崩塌

2.wgan主要解决的问题

  • 引入了一种新的分布距离度量方法:wasserstein距离,也称为(earth-mover distance)简称em距离,表示从一个分布变换到另一个分布的最小代价。
  • 定义了一种称为wasserstein gan的gan形式,该形式使em距离的合理有效近似最小化,并且本文从理论上证明了相应的优化问题是合理的。
  • wgan解决了gans的主要训练问题。特别是,训练wgan不需要维护在鉴别器和生成器的训练中保持谨慎的平衡,并且也不需要对网络架构进行仔细的设计。模式在gans中典型的下降现象也显著减少。wgan最引人注目的实际好处之一是能够通过训练鉴别器进行运算来连续地估计em距离。绘制这些学习曲线不仅对调试和超参数搜索,但也与观察到的样品质量。

3.不同距离的度量方式

(1)方式一

(2)方式二

 

(3)方式三

 

(4)方式四

 

4.wgan原理

(1)p和q分布下的距离计算 

 

(2)em距离转换优化目标推导

 

(3)判别器和生成器的优化目标

 

 

5.wgan训练算法 

具体实现代码如下:

for epoch in range(num_epochs):
    for batch_idx,(data,_) in enumerate(dataloader):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        #train: critic : max[critic(real)] - e[critic(fake)]
        loss_critic = 0
        for _ in range(critic_iterations):
            noise = torch.randn(size = (cur_batch_size,z_dim,1,1),device=device)
            fake_img = gen(noise)
            #使用reshape主要是将最后的维度从[1,1,1,1]=>[1]
            critic_real = critic(data).reshape(-1)
            critic_fake = critic(fake_img).reshape(-1)

            loss_critic = (torch.mean(critic_real)- torch.mean(critic_fake))
            opt_critic.zero_grad()
            loss_critic.backward(retain_graph=true)
            opt_critic.step()

            #clip critic weight between -0.01 , 0.01
            for p in critic.parameters():
                p.data.clamp_(-weight_clip,weight_clip)

        #将维度从[1,1,1,1]=>[1]
        gen_fake = critic(fake_img).reshape(-1)
        #max e[critic(gen_fake)] <-> min -e[critic(gen_fake)]
        loss_gen = -torch.mean(gen_fake)
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

6.wgan网络结构

 

7.数据集下载

 

8.wgan代码实现 

9.mainwindow窗口显示生成器生成的图片

 

10.模型下载 

参考文章:

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com