【PyTorch】生成对抗网络

生成对抗网络是什么

概念

Generative Adversarial Nets,简称GAN
GAN:生成对抗网络 —— 一种可以生成特定分布数据的模型
《Generative Adversarial Nets》 Ian J Goodfellow-2014

GAN网络结构

Recent Progress on Generative Adversarial Networks (GANs): A Survey
在这里插入图片描述

How Generative Adversarial Networks and Their Variants Work: An Overview
在这里插入图片描述

Generative Adversarial Networks_ A Survey and Taxonomy

在这里插入图片描述

GAN的训练

训练目的

  1. 对于D:对真样本输出高概率
  2. 对于G:输出使D会给出高概率的数据

GAN 的训练和监督学习训练模式的差异

在监督学习的训练模式中,训练数经过模型得到输出值,然后使用损失函数计算输出值与标签之间的差异,根据差异值进行反向传播,更新模型的参数,如下图所示。
在这里插入图片描述
在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。如下图所示。
在这里插入图片描述

具体训练过程

step1:训练D
输入:真实数据加G生成的假数据
输出:二分类概率

step2:训练G
输入:随机噪声z
输出:分类概率——D(G(z))

在这里插入图片描述

DCGAN

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
在这里插入图片描述

Discriminator:卷积结构的模型
Generator:卷积结构的模型

DCGAN 的定义如下:

from collections import OrderedDict
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)


class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=128):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

    def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/886728.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

爬虫——同步与异步加载

一、同步加载 同步模式--阻塞模式(就是会阻止你浏览器的一个后续加载)停止了后续的解析 因此停止了后续的文件加载(图像) 比如hifini音乐网站 二、异步加载 异步加载--xhr(重点) 比如腾讯新闻,腾讯招聘等 三、同…

系统规划与管理——1信息系统综合知识(3)

文章目录 1.3 信息系统1.3.1 信息系统定义1.3.2 信息系统的生命周期1.3.3 信息系统常用的开发方法 1.3 信息系统 1.3.1 信息系统定义 信息系统是一种以处理信息为目的的专门的系统类型。信息系统可以是手工的,也可以是计算机化的。计算机化的信息系统的组成部件包…

【D3.js in Action 3 精译_025】3.4 让 D3 数据适应屏幕(中)—— 线性比例尺的用法

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一部分 D3.js 基础知识 第一章 D3.js 简介(已完结) 1.1 何为 D3.js?1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践(上)1.3 数据可…

HTML:相关概念以及标签

目录 什么是网页? 什么是HTML语言? 语法规范 HTML基本结构标签 DOCTYPE,lang以及字符集 HTML常用标签 5>图像标签(重要) 除此之外还有几个调整图片属性的标签 图像标签总结 什么是网页? 我们平时使用电脑和手机都是离不开网站和网页的,那么什么是网页呢?什么又是网…

cocotb报错收集

1、原因是定义测试类的时候,idle_inserter的名字不一样 函数修正后 函数修正前

电脑显示mfc140u.dll丢失怎么办,分享4个有效的解决方法

1. mfc140u.dll 简介 1.1 定义与作用 mfc140u.dll 是 Microsoft Foundation Class (MFC) 库中的一个动态链接库文件,它是 MFC 库在 Unicode 版本中的一个特定实现。MFC 是微软为 Windows 平台开发的一套 C 类库,封装了众多 Windows API 函数&#xff0…

定时器定时中断定时器外部中断

基础背景:TIM定时中断-CSDN博客 TIM的函数 // 恢复缺省设置 void TIM_DeInit(TIM_TypeDef* TIMx); // 时基单元初始化,第一个参数TIMx选择某个定时器,第二个参数是结构体,包含了配置时基单元的一些参数。 void TIM_TimeBaseInit…

了解华为计算产品线,昇腾的业务都有哪些?

🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ 随着 ChatGPT 的现象级爆红,它引领了 AI 大模型时代的深刻变革,进而造成 AI 算力资源日益紧缺。与此同时,中美贸易战的持续也使得 AI 算力国产化适配成为必然趋势。 …

golang grpc初体验

grpc 是一个高性能、开源和通用的 RPC 框架,面向服务端和移动端,基于 HTTP/2 设计。目前支持c、java和go,分别是grpc、grpc-java、grpc-go,目前c版本支持c、c、node.js、ruby、python、objective-c、php和c#。grpc官网 grpc-go P…

Visual Studio 字体与主题推荐

个人推荐,仅供参考: 主题:One Monokai VS Theme 链接:One Monokai VS Theme - Visual Studio Marketplacehttps://marketplace.visualstudio.com/items?itemNameazemoh.onemonokai 效果: 字体:JetBrain…

[RabbitMQ] Spring Boot整合RabbitMQ

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

Scrapy 爬虫的大模型支持

使用 Scrapy 时,你可以轻松使用大型语言模型 (LLM) 来自动化或增强你的 Web 解析。 有多种使用 LLM 来帮助进行 Web 抓取的方法。在本指南中,我们将在每个页面上调用一个 LLM,从中抽取我们定义的一组属性,而无需编写任何选择器或…

C++和OpenGL实现3D游戏编程【连载13】——多重纹理混合详解

🔥C++和OpenGL实现3D游戏编程【目录】 1、本节要实现的内容 前面说过纹理贴图能够大幅提升游戏画面质量,但纹理贴图是没有叠加的。在一些游戏场景中,要求将非常不同的多个纹理(如泥泞的褐色地面、绿草植密布的地面、碎石遍布的地面)叠加(混合)起来显示,实现纹理间能够…

多区域OSPF路由协议

前言 之前也有过关于OSPF路由协议的博客,但都不是很满意,不是很完整。现在也是听老师讲解完OSPF路由协议,感触良多,所以这里重新整理一遍。这次应该是会满意的 一些相关概念 链路状态 链路指路由器上的一个接口,链路状…

【社保通-注册安全分析报告-滑动验证加载不正常导致安全隐患】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

新手教学系列——爬虫异步并发注意事项

引言 爬虫是网络数据采集中不可或缺的工具,很多程序员在入门时会遇到这样的问题:为什么我的爬虫这么慢?尤其在面对大量数据时,单线程爬虫的速度可能让人捶胸顿足。随着爬虫规模的增大,异步并发成为了提高爬取效率的关键。然而,异步并发并不像表面看起来那么简单,如果没…

初识Linux · 进程替换

目录 前言: 1 直接看代码和现象 2 解释原理 3 将代码改成多进程版本 4 认识所有函数并使用 前言: 由前面的章节学习,我们已经了解了进程状态,进程终止以及进程等待,今天,我们学习进程替换。进程替换我…

Python:import语句的使用(详细解析)(一)

相关阅读 Pythonhttps://blog.csdn.net/weixin_45791458/category_12403403.html?spm1001.2014.3001.5482 import语句是Python中一个很重要的机制,允许在一个文件中访问另一个文件的函数、类、变量等,本文就将进行详细介绍。 在具体谈论import语句前&a…

hbuilderx+uniapp+Android宠物用品商城领养服务系统的设计与实现 微信小程序沙箱支付

目录 项目介绍支持以下技术栈:具体实现截图HBuilderXuniappmysql数据库与主流编程语言java类核心代码部分展示登录的业务流程的顺序是:数据库设计性能分析操作可行性技术可行性系统安全性数据完整性软件测试详细视频演示源码获取方式 项目介绍 顾客 领养…