找回密码
 会员注册
查看: 44|回复: 0

深度学习GAN生成对抗网络原理推导+代码实现(Python)

[复制链接]

2

主题

0

回帖

7

积分

新手上路

积分
7
发表于 2024-9-12 18:13:23 | 显示全部楼层 |阅读模式
1、前言本文将讲近些年来挺火的一个生成模型GAN生成对抗网络\boxed{\mathbf{GAN生成对抗网络}}GAN生成对抗网络​,其特殊的思路解法实在让人啧啧称奇。数学基础:【概率论与数理统计知识复习-哔哩哔哩】视频:【生成对抗网络GAN原理解析-哔哩哔哩】2、原理2.1、GAN的运行机理在传统的生成模型中,我们总是对我们的训练数据(或观测变量和隐变量)进行建模,得到概率分布,然后进行数据的生成。可GAN却不是这样,其利用神经网络这个函数逼近器,求解出了模型中概率分布的参数在不知道概率分布是什么的情况下\boxed{在不知道概率分布是什么的情况下}在不知道概率分布是什么的情况下​。其主要思想是,从一个简单的概率分布中采样,得到样本经过神经网络变换,得到一个新的样本,我们就假设这个样本就来自我们需要求解的概率分布中。然后用神经网络去辨别其是来自真实分布,还是我们要求解的概率分布。先来看模型图我们的训练数据xxx是来自真实分布对应图中P(data)\boxed{\mathbf{对应图中P(data)}}对应图中P(data)​,我们记作PdataP_{data}Pdata​,训练数据都是从PdataP_{data}Pdata​中采样得来(图中上半部分的x)。而我们从简单的概率分布中抽样P(z)P(z)P(z)如正态分布\boxed{\mathbf{如正态分布}}如正态分布​,让所得的样本经过一个神经网络G(z)G(z)G(z),得到一个新的样本xxx,这个样本就来自我们的需要求解的概率分布,我们记作PgP_{g}Pg​。然后将两个x给神经网络D(x)D(x)D(x)判断真伪,让它区分这个x是来自PdataP_{data}Pdata​还是PgP_gPg​,其输出样本来自PdataP_{data}Pdata​的概率。依据所得信息使用梯度下降更新神经网络参数,G(z)G(z)G(z)也是如此。而G(z)G(z)G(z)被称为生成器(用于生成样本)\boxed{\mathbf{(用于生成样本)}}(用于生成样本)​,D(x)D(x)D(x)被称为判别器用于判别样本真伪\boxed{\mathbf{用于判别样本真伪}}用于判别样本真伪​。2.2、目标函数损失函数来自判别器和生成器\boxed{\mathbf{损失函数来自判别器和生成器}}损失函数来自判别器和生成器​对于判别器\boxed{对于判别器}对于判别器​当样本来自PdataP_{data}Pdata​,我们要让所得的概率越大越好;当样本来自pgp_gpg​,我们要让其概率越小越好,即①max⁡DD(xi)②min⁡DD(G(zi))①\max\limits_{D}D(x_i)\\②\min\limits_{D}D(G(z_i))①Dmax​D(xi​)②Dmin​D(G(zi​))将最小化换成最大化max⁡D[1−D(G(zi))]\max\limits_{D}[1-D(G(z_i))]Dmax​[1−D(G(zi​))]所以单个样本判别器的损失函数可以写成max⁡D{D(xi)+[1−D(G(zi))]}\max\limits_{D}\left\{D(x_i)+[1-D(G(z_i))]\right\}Dmax​{D(xi​)+[1−D(G(zi​))]}对于所有样本N,我们希望均值最大max⁡D{1N∑i=1ND(xi)+1N∑i=1N[1−D(G(zi))]}\max_D\left\{\frac{1}{N}\sum\limits_{i=1}^ND(x_i)+\frac{1}{N}\sum\limits_{i=1}^N[1-D(G(z_i))]\right\}Dmax​{N1​i=1∑N​D(xi​)+N1​i=1∑N​[1−D(G(zi​))]}写成期望形式(并取log最大不改变最大值\boxed{不改变最大值}不改变最大值​),得到判别器的损失函数(x∼pdatax\simp_{data}x∼pdata​表示样本来自真实分布,PzP_zPz​表示正态分布)max⁡D{Ex∼pdata[log⁡D(x)]+Ez∼Pz[log⁡(1−D(G(z)))]}\boxed{\max\limits_{D}\left\{\mathbb{E}_{x\simp_{data}}\left[\logD(x)\right]+\mathbb{E}_{z\simP_z}\left[\log(1-D(G(z)))\right]\right\}}Dmax​{Ex∼pdata​​[logD(x)]+Ez∼Pz​​[log(1−D(G(z)))]}​接着,我们在上面讲到过,G(z)表示的是,采用一个z,经过一个神经网络,得到一个伪造出来的x。这个伪造的x服从分布PgP_gPg​。那么我们就可以把第二个期望改写成x的表达式,于是便可得到max⁡D{Ex∼pdata[log⁡D(x)]+Ex∼Pg[log⁡(1−D(x))]}\boxed{\max\limits_{D}\left\{\mathbb{E}_{x\simp_{data}}\left[\logD(x)\right]+\mathbb{E}_{x\simP_g}\left[\log(1-D(x))\right]\right\}}Dmax​{Ex∼pdata​​[logD(x)]+Ex∼Pg​​[log(1−D(x))]}​对于生成器\boxed{对于生成器}对于生成器​它希望生成的样本让判别器判别为真的概率越大越好,所以直接设计成(将最大写成最小)min⁡GEx∼Pg[log⁡(1−D(x))]\boxed{\min\limits_{G}\mathbb{E}_{x\simP_g}\left[\log(1-D(x))\right]}Gmin​Ex∼Pg​​[log(1−D(x))]​所以最终的目标函数可以写成min⁡Gmax⁡D{Ex∼pdata[log⁡D(x)]+Ex∼Pg[log⁡(1−D(x))]}\min\limits_{G}\max\limits_{D}\left\{\mathbb{E}_{x\simp_{data}}\left[\logD(x)\right]+\mathbb{E}_{x\simP_g}\left[\log(1-D(x))\right]\right\}Gmin​Dmax​{Ex∼pdata​​[logD(x)]+Ex∼Pg​​[log(1−D(x))]}3、最优解求解得到了目标函数,我们很显然还需要证明其存在最优解。并且最优解的PgP_gPg​是否和PdataP_{data}Pdata​无限接近先求里层(关于D求最大)\boxed{先求里层(关于D求最大)}先求里层(关于D求最大)​Ex∼pdata[log⁡D(x)]+Ex∼Pg[log⁡(1−D(x))]=∫xlog⁡D(x)Pdata(x)dx+∫xlog⁡(1−D(x))Pg(x)dx=∫x[log⁡D(x)Pdata(x)+log⁡(1−D(x))Pg(x)]dx==Ex∼pdata[logD(x)]+Ex∼Pg[log(1−D(x))]∫xlogD(x)Pdata(x)dx+∫xlog(1−D(x))Pg(x)dx∫x[logD(x)Pdata(x)+log(1−D(x))Pg(x)]dxEx∼pdata[log⁡D(x)]+Ex∼Pg[log⁡(1−D(x))]=∫xlog⁡D(x)Pdata(x)dx+∫xlog⁡(1−D(x))Pg(x)dx=∫x[log⁡D(x)Pdata(x)+log⁡(1−D(x))Pg(x)]dx==​Ex∼pdata​​[logD(x)]+Ex∼Pg​​[log(1−D(x))]∫x​logD(x)Pdata​(x)dx+∫x​log(1−D(x))Pg​(x)dx∫x​[logD(x)Pdata​(x)+log(1−D(x))Pg​(x)]dx​要求积分最大,就是要求里面的每一个最大max⁡D[log⁡D(x)Pdata(x)+log⁡(1−D(x))Pg(x)]\max_D\left[{\logD(x)P_{data}(x)+\log(1-D(x))P_g(x)}\right]Dmax​[logD(x)Pdata​(x)+log(1−D(x))Pg​(x)]求导∂∂DlogD(x)Pdata(x)+log⁡(1−D(x))Pg(x)=1D(x)Pdata(x)−11−D(x)Pg(x)=∂∂DlogD(x)Pdata(x)+log(1−D(x))Pg(x)1D(x)Pdata(x)−11−D(x)Pg(x)∂∂DlogD(x)Pdata(x)+log⁡(1−D(x))Pg(x)=1D(x)Pdata(x)−11−D(x)Pg(x)=​∂D∂​logD(x)Pdata​(x)+log(1−D(x))Pg​(x)D(x)1​Pdata​(x)−1−D(x)1​Pg​(x)​整理得D(x)=Pdata(x)Pg(x)+Pdata(x)\boxed{D(x)=\frac{P_{data}(x)}{P_{g}(x)+P_{data}(x)}}D(x)=Pg​(x)+Pdata​(x)Pdata​(x)​​将其代入目标函数,并且关于外层G求最小\boxed{将其代入目标函数,并且关于外层G求最小}将其代入目标函数,并且关于外层G求最小​min⁡G∫x[log⁡Pdata(x)Pg(x)+Pdata(x)Pdata(x)+log⁡(1−Pdata(x)Pg(x)+Pdata(x))Pg(x)]dx=min⁡G[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2∗12)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2∗12)Pg(x)dx]=min⁡G[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+∫log⁡12Pdata(x)dx+∫log⁡12Pg(x)dx]=min⁡G[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡12∫Pdata(x)dx+log⁡12∫Pg(x)dx]=min⁡G[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡12+log⁡12]=min⁡G[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡14]=min⁡GKL(Pdata(x)∣∣Pdata(x)+Pg(x)2)+KL(Pg(x)∣∣Pdata(x)+Pg(x)2)−log⁡4======minG∫x[logPdata(x)Pg(x)+Pdata(x)Pdata(x)+log(1−Pdata(x)Pg(x)+Pdata(x))Pg(x)]dxminG⎡⎣∫xlog⎛⎝Pdata(x)Pg(x)+Pdata(x)2∗12⎞⎠Pdata(x)dx+∫xlog⎛⎝Pg(x)Pg(x)+Pdata(x)2∗12⎞⎠Pg(x)dx⎤⎦minG⎡⎣∫xlog⎛⎝Pdata(x)Pg(x)+Pdata(x)2⎞⎠Pdata(x)dx+∫xlog⎛⎝Pg(x)Pg(x)+Pdata(x)2⎞⎠Pg(x)dx+∫log12Pdata(x)dx+∫log12Pg(x)dx⎤⎦minG⎡⎣∫xlog⎛⎝Pdata(x)Pg(x)+Pdata(x)2⎞⎠Pdata(x)dx+∫xlog⎛⎝Pg(x)Pg(x)+Pdata(x)2⎞⎠Pg(x)dx+log12∫Pdata(x)dx+log12∫Pg(x)dx⎤⎦minG⎡⎣∫xlog⎛⎝Pdata(x)Pg(x)+Pdata(x)2⎞⎠Pdata(x)dx+∫xlog⎛⎝Pg(x)Pg(x)+Pdata(x)2⎞⎠Pg(x)dx+log12+log12⎤⎦minG⎡⎣∫xlog⎛⎝Pdata(x)Pg(x)+Pdata(x)2⎞⎠Pdata(x)dx+∫xlog⎛⎝Pg(x)Pg(x)+Pdata(x)2⎞⎠Pg(x)dx+log14⎤⎦minGKL(Pdata(x)||Pdata(x)+Pg(x)2)+KL(Pg(x)||Pdata(x)+Pg(x)2)−log4minG∫x[log⁡Pdata(x)Pg(x)+Pdata(x)Pdata(x)+log⁡(1−Pdata(x)Pg(x)+Pdata(x))Pg(x)]dx=minG[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2∗12)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2∗12)Pg(x)dx]=minG[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+∫log⁡12Pdata(x)dx+∫log⁡12Pg(x)dx]=minG[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡12∫Pdata(x)dx+log⁡12∫Pg(x)dx]=minG[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡12+log⁡12]=minG[∫xlog⁡(Pdata(x)Pg(x)+Pdata(x)2)Pdata(x)dx+∫xlog⁡(Pg(x)Pg(x)+Pdata(x)2)Pg(x)dx+log⁡14]=minGKL(Pdata(x)||Pdata(x)+Pg(x)2)+KL(Pg(x)||Pdata(x)+Pg(x)2)−log⁡4======​Gmin​∫x​[logPg​(x)+Pdata​(x)Pdata​(x)​Pdata​(x)+log(1−Pg​(x)+Pdata​(x)Pdata​(x)​)Pg​(x)]dxGmin​[∫x​log(2Pg​(x)+Pdata​(x)​Pdata​(x)​∗21​)Pdata​(x)dx+∫x​log(2Pg​(x)+Pdata​(x)​Pg​(x)​∗21​)Pg​(x)dx]Gmin​[∫x​log(2Pg​(x)+Pdata​(x)​Pdata​(x)​)Pdata​(x)dx+∫x​log(2Pg​(x)+Pdata​(x)​Pg​(x)​)Pg​(x)dx+∫log21​Pdata​(x)dx+∫log21​Pg​(x)dx]Gmin​[∫x​log(2Pg​(x)+Pdata​(x)​Pdata​(x)​)Pdata​(x)dx+∫x​log(2Pg​(x)+Pdata​(x)​Pg​(x)​)Pg​(x)dx+log21​∫Pdata​(x)dx+log21​∫Pg​(x)dx]Gmin​[∫x​log(2Pg​(x)+Pdata​(x)​Pdata​(x)​)Pdata​(x)dx+∫x​log(2Pg​(x)+Pdata​(x)​Pg​(x)​)Pg​(x)dx+log21​+log21​]Gmin​[∫x​log(2Pg​(x)+Pdata​(x)​Pdata​(x)​)Pdata​(x)dx+∫x​log(2Pg​(x)+Pdata​(x)​Pg​(x)​)Pg​(x)dx+log41​]Gmin​KL(Pdata​(x)∣∣2Pdata​(x)+Pg​(x)​)+KL(Pg​(x)∣∣2Pdata​(x)+Pg​(x)​)−log4​KL(p∣∣q)=∫xplog⁡pqdxKL(p||q)=\int_xp\log\frac{p}{q}dxKL(p∣∣q)=∫x​plogqp​dx,KL散度是衡量概率分布ppp和qqq的相似程度,其大于等于0,当其相似程度一样时,则散度为0,也就是我们要求的最小值。小补充\boxed{小补充}小补充​2JS(Pdata(x)∣∣Pg(x))=KL(Pdata(x)∣∣Pdata(x)+Pg(x)2)+KL(Pg(x)∣∣Pdata(x)+Pg(x)2)\boxed{\mathbf{2JS\left(P_{data}(x)||P_g(x)\right)=KL\left(P_{data}(x)||\frac{P_{data}(x)+P_{g}(x)}{2}\right)+KL\left(P_{g}(x)||\frac{P_{data}(x)+P_{g}(x)}{2}\right)}}2JS(Pdata​(x)∣∣Pg​(x))=KL(Pdata​(x)∣∣2Pdata​(x)+Pg​(x)​)+KL(Pg​(x)∣∣2Pdata​(x)+Pg​(x)​)​JS(p∣∣q)JS(p||q)JS(p∣∣q)被称为JS散度,其仍然是大于等于0的。所以是一样的。所以Pdata(x)=Pg(x)+Pdata2→Pdata=Pg(x)P_{data}(x)=\frac{P_g(x)+P_{data}}{2}\rightarrowP_{data}=P_g(x)Pdata​(x)=2Pg​(x)+Pdata​​→Pdata​=Pg​(x)由此可见,目标函数最优值能够让Pg逼近Pdata\boxed{\mathbb{由此可见,目标函数最优值能够让P_g逼近P_{data}}}由此可见,目标函数最优值能够让Pg​逼近Pdata​​,并且当其相等时,有D(x)=Pdata(x)Pg(x)+Pdata(x)=12\boxed{D(x)=\frac{P_{data}(x)}{P_{g}(x)+P_{data}(x)}}=\frac{1}{2}D(x)=Pg​(x)+Pdata​(x)Pdata​(x)​​=21​也就是判别器再也无法判断出样本是来自PdataP_{data}Pdata​还是PgP_gPg​4、代码实现结果如下​效果一般,在其他变种优化有很多比这个好的,感兴趣的读者自行查阅。importtorchfromtorchvision.datasetsimportMNISTfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderfromtqdmimporttqdmimportmatplotlib.pyplotaspltclassGenerate_Model(torch.nn.Module):'''生成器'''def__init__(self):super().__init__()self.fc=torch.nn.Sequential(torch.nn.Linear(in_features=128,out_features=256),torch.nn.Tanh(),torch.nn.Linear(in_features=256,out_features=512),torch.nn.ReLU(),torch.nn.Linear(in_features=512,out_features=784),torch.nn.Tanh())defforward(self,x):x=self.fc(x)returnxclassDistinguish_Model(torch.nn.Module):'''判别器'''def__init__(self):super().__init__()self.fc=torch.nn.Sequential(torch.nn.Linear(in_features=784,out_features=512),torch.nn.Tanh(),torch.nn.Linear(in_features=512,out_features=256),torch.nn.Tanh(),torch.nn.Linear(in_features=256,out_features=128),torch.nn.Tanh(),torch.nn.Linear(in_features=128,out_features=1),torch.nn.Sigmoid())defforward(self,x):x=self.fc(x)returnxdeftrain():device=torch.device("cuda:0"iftorch.cuda.is_available()else"cpu")#判断是否存在可用GPUtransformer=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=0.5,std=0.5)])#图片标准化train_data=MNIST("./data",transform=transformer,download=True)#载入图片dataloader=DataLoader(train_data,batch_size=64,num_workers=4,shuffle=True)#将图片放入数据加载器D=Distinguish_Model().to(device)#实例化判别器G=Generate_Model().to(device)#实例化生成器D_optim=torch.optim.Adam(D.parameters(),lr=1e-4)#为判别器设置优化器G_optim=torch.optim.Adam(G.parameters(),lr=1e-4)#为生成器设置优化器loss_fn=torch.nn.BCELoss()#损失函数epochs=100#迭代100次forepochinrange(epochs):dis_loss_all=0#记录判别器损失损失gen_loss_all=0#记录生成器损失loader_len=len(dataloader)#数据加载器长度forstep,dataintqdm(enumerate(dataloader),desc="第{}轮".format(epoch),total=loader_len):#先计算判别器损失sample,label=data#获取样本,舍弃标签sample=sample.reshape(-1,784).to(device)#重塑图片sample_shape=sample.shape[0]#获取批次数量#从正态分布中抽样sample_z=torch.normal(0,1,size=(sample_shape,128),device=device)Dis_true=D(sample)#判别器判别真样本true_loss=loss_fn(Dis_true,torch.ones_like(Dis_true))#计算损失fake_sample=G(sample_z)#生成器通过正态分布抽样生成数据Dis_fake=D(fake_sample.detach())#判别器判别伪样本fake_loss=loss_fn(Dis_fake,torch.zeros_like(Dis_fake))#计算损失Dis_loss=true_loss+fake_loss#真假加起来D_optim.zero_grad()Dis_loss.backward()#反向传播D_optim.step()#生成器损失Dis_G=D(fake_sample)#判别器判别G_loss=loss_fn(Dis_G,torch.ones_like(Dis_G))#计算损失G_optim.zero_grad()G_loss.backward()#反向传播G_optim.step()withtorch.no_grad():dis_loss_all+=Dis_loss#判别器累加损失gen_loss_all+=G_loss#生成器累加损失withtorch.no_grad():dis_loss_all=dis_loss_all/loader_lengen_loss_all=gen_loss_all/loader_lenprint("判别器损失为:{}".format(dis_loss_all))print("生成器损失为:{}".format(gen_loss_all))torch.save(G,"./model/G.pth")#保存模型torch.save(D,"./model/D.pth")#保存模型if__name__=='__main__':#train()#训练模型model_G=torch.load("./model/G.pth",map_location=torch.device("cpu"))#载入模型fake_z=torch.normal(0,1,size=(10,128))#抽样数据result=model_G(fake_z).reshape(-1,28,28)#生成数据result=result.detach().numpy()#绘制foriinrange(10):plt.subplot(2,5,i+1)plt.imshow(result[i])plt.gray()plt.show()1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161175、结束以上,就是GAN生成对抗网络的全部内容了,如有问题,还望指出。阿里嘎多
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 会员注册

本版积分规则

QQ|手机版|心飞设计-版权所有:微度网络信息技术服务中心 ( 鲁ICP备17032091号-12 )|网站地图

GMT+8, 2024-12-26 02:07 , Processed in 0.434157 second(s), 26 queries .

Powered by Discuz! X3.5

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表