Introduction

The paper proposes the Self-Attention Generative Adversarial Network (SAGAN), which allows attention-driven, long-range dependency modeling for image generation. Traditional convolutional GANs generate high-resolution details as a function of only spatially local points in lower-resolution feature maps. In SAGAN, details can be generated using cues from all feature locations. Moreover, the discriminator can check that features in distant portions of the image are consistent with each other. In addition, recent work has shown that generator conditioning affects GAN performance; leveraging this insight, spectral normalization is applied to the GAN generator as well, which is found to stabilize training. The proposed SAGAN outperforms prior work on the challenging ImageNet dataset, boosting the best published Inception score from 36.8 to 52.52 and reducing the Fréchet Inception Distance from 27.62 to 18.65. Visualization of the attention layers shows that the generator leverages neighborhoods that correspond to object shapes rather than local regions of fixed shape.

Introducing self-attention


The ablation experiments in the paper show that, within the convolutional stacks of the generator and discriminator, self-attention is most effective when applied to the middle-to-high resolution feature maps. The procedure: 1×1 convolutions map the input feature map into three embedding spaces (query, key, and value), attention weights are computed between every pair of spatial positions, the value features are aggregated with those weights, and the result is added back to the input through a learnable scale γ, as written out below.
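In the paper's notation (which the Self_Attn module in the code section follows), with f, g and h the 1×1 query, key and value convolutions and N = W × H spatial positions:

  s_{ij} = f(x_i)^T g(x_j)
  β_{j,i} = exp(s_{ij}) / Σ_i exp(s_{ij})      (how much position i is attended to when synthesizing position j)
  o_j = Σ_i β_{j,i} h(x_i)
  y_j = γ · o_j + x_j                          (γ is a learnable scalar initialized to 0)

The paper factors the value path into two 1×1 convolutions; the reference implementation below uses a single value convolution that outputs C channels directly.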


Loss

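SAGAN trains with the hinge version of the adversarial loss; for the unconditional setup used in the training code at the end of this post it reads:

  L_D = E_{x~p_data}[ max(0, 1 - D(x)) ] + E_{z~p_z}[ max(0, 1 + D(G(z))) ]
  L_G = - E_{z~p_z}[ D(G(z)) ]

The discriminator pushes real samples above +1 and generated samples below -1, while the generator simply maximizes the discriminator's score on its samples.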

Stabilization strategies

  1. Apply spectral normalization in both the generator and the discriminator.

  2. Use different learning rates for the generator and the discriminator (TTUR): 0.0001 and 0.0004 respectively.

  3. Train with the Adam optimizer (the paper uses β1 = 0 and β2 = 0.9); a minimal setup is sketched right after this list.
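A minimal sketch of that optimizer setup (assuming the Generator and Discriminator instances from the code below are called G and D; the β values are the ones reported in the paper):

import torch

# Two time-scale update rule (TTUR): slow generator, fast discriminator
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))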

Code

Spectral normalization

import torch
import torch.nn as nn
from torch.nn import Parameter


def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Wraps a module and rescales its weight by the largest singular value,
    estimated with power iteration."""

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # power iteration: v <- W^T u / ||W^T u||, u <- W v / ||W v||
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = u^T W v approximates the largest singular value of W
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # keep the raw weight as "weight_bar" and recompute the normalized
        # "weight" on every forward pass
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
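As a quick usage check (a minimal sketch; SpectralNorm is the wrapper defined above and the layer sizes are only illustrative), wrap any weight-bearing layer and call it like the layer itself:

conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
x = torch.randn(8, 3, 64, 64)
y = conv(x)      # the weight is divided by its estimated largest singular value
print(y.shape)   # torch.Size([8, 64, 32, 32])

Recent PyTorch releases also ship a built-in torch.nn.utils.spectral_norm that serves the same purpose.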

Self-attention

class Self_Attn(nn.Module):
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x: input feature maps (B x C x W x H)
        returns:
            out: self-attention value + input feature
            attention: B x N x N (N is Width * Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # [B, N, C//8]
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # [B, C//8, N]
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # [B, N, N]
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # [B, C, N]

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # [B, C, N]
        out = out.view(m_batchsize, C, width, height)  # [B, C, W, H]

        out = self.gamma * out + x
        return out, attention
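A minimal shape check for the attention layer (the tensor sizes here are only illustrative):

attn = Self_Attn(in_dim=128, activation='relu')
feat = torch.randn(4, 128, 16, 16)   # [B, C, W, H]
out, attention = attn(feat)
print(out.shape)         # torch.Size([4, 128, 16, 16])
print(attention.shape)   # torch.Size([4, 256, 256]), N = 16 * 16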

Generator

import numpy as np


class Generator(nn.Module):
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3  # 3
        mult = 2 ** repeat_num  # 8

        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult  # 64*8

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)  # 64*4

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)  # 64*2
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)  # 64

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)  # [N, 100, 1, 1]
        out = self.l1(z)           # [N, 64*8, 4, 4]
        out = self.l2(out)         # [N, 64*4, 8, 8]
        out = self.l3(out)         # [N, 64*2, 16, 16]
        out, p1 = self.attn1(out)  # self-attention on the 16x16 feature map
        out = self.l4(out)         # [N, 64, 32, 32]
        out, p2 = self.attn2(out)  # self-attention on the 32x32 feature map
        out = self.last(out)       # [N, 3, 64, 64]
        return out, p1, p2
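Sampling from the generator (a minimal sketch using the classes defined above; the shapes assume image_size=64):

G = Generator(batch_size=64, image_size=64, z_dim=100, conv_dim=64)
z = torch.randn(64, 100)                   # a batch of noise vectors
fake_images, attn1_map, attn2_map = G(z)
print(fake_images.shape)                   # torch.Size([64, 3, 64, 64])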

Discriminator

class Discriminator(nn.Module):
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))

        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):          # [B, 3, 64, 64]
        out = self.l1(x)           # [B, 64, 32, 32]
        out = self.l2(out)         # [B, 128, 16, 16]
        out = self.l3(out)         # [B, 256, 8, 8]
        out, p1 = self.attn1(out)
        out = self.l4(out)         # [B, 512, 4, 4]
        out, p2 = self.attn2(out)
        out = self.last(out)       # [B, 1, 1, 1]
        return out.squeeze(), p1, p2
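Scoring those samples with the discriminator (again a minimal sketch; fake_images comes from the generator snippet above):

D = Discriminator(batch_size=64, image_size=64, conv_dim=64)
scores, d_attn1, d_attn2 = D(fake_images)
print(scores.shape)   # torch.Size([64]), one hinge score per image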

Loss

# Excerpt from the training loop: self.D / self.G are the networks above;
# tensor2var is a small helper in the reference repo that moves the tensor to the GPU.

# ---- discriminator hinge loss ----
d_out_real, dr1, dr2 = self.D(real_images)
d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()   # max(0, 1 - D(x))

z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images, gf1, gf2 = self.G(z)
d_out_fake, df1, df2 = self.D(fake_images)
d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()   # max(0, 1 + D(G(z)))

d_loss = d_loss_real + d_loss_fake

# ---- generator loss ----
z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
fake_images, _, _ = self.G(z)
g_out_fake, _, _ = self.D(fake_images)
g_loss_fake = -g_out_fake.mean()                          # -D(G(z))
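Putting the pieces together, one training iteration then looks roughly like this (a minimal sketch; d_optimizer and g_optimizer are the two Adam optimizers from the stabilization section, and d_loss / g_loss_fake are computed as above):

# discriminator update
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# generator update
g_optimizer.zero_grad()
g_loss_fake.backward()
g_optimizer.step()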

References:

github.com/heykeetae/S…

arxiv.org/abs/1805.08…

