本文为稀土技术社区首发签约文章,14天内制止转载,14天后未获授权制止转载,侵权必究!

作者简介:秃头小苏,致力于用最通俗的言语描绘问题

往期回忆:对立生成网络GAN系列——GAN原理及手写数字生成小事例   对立生成网络GAN系列——DCGAN简介及人脸图画生成事例   对立生成网络GAN系列——CycleGAN简介及图片春冬改换事例

近期方针:写好专栏的每一篇文章

支撑小苏:点赞、保藏⭐、留言

对立生成网络GAN系列——AnoGAN原理及缺点检测实战

写在前面

  跟着深度学习的开展,现已有许多学者将深度学习应用到物体瑕疵检测中,如列车钢轨的缺点检测、医学影像中各种疾病的检测。可是瑕疵检测使命几乎都存在一个共同的难题——缺点数据太少了。咱们运用这些稀少的缺点数据很难运用深度学习练习一个抱负的模型,往往都需求进行数据扩充,即经过某些手法添加咱们的缺点数据。 【数据扩充咱们感兴趣自己去了解下,GAN网络也是完成数据扩充的干流手法】 上面说到的办法是根据缺点数据来练习的,是有监督的学习,学者们在绵长的研究中,考虑能不能运用一种无监督的办法来完成缺点检测呢?所以啊,AnoGAN就横空出世了,它不需求缺点数据进行练习,而仅运用正常数据练习模型,关于AnoGAN的细节后文详细介绍。

关于GAN网络,我现已介绍了几篇,如下:

  • [1]对立生成网络GAN系列——GAN原理及手写数字生成小事例
  • [2]对立生成网络GAN系列——DCGAN简介及人脸图画生成事例
  • [3]对立生成网络GAN系列——CycleGAN原理

  在阅览本文之前主张咱们对GAN有必定的了解,能够参阅[1]和[2],关于[3]感兴趣的能够看看,本篇文章用不到[3]相关知识。

  准备好了嘛,咱们开端发车了喔。

AnoGAN 原理详解✨✨✨

  首要咱们来看看AnoGAN的全称,即Anomaly Detection with Generative Adversarial Networks,中文是指运用生成对立网络完成反常检测。这篇论文解决的是医学影像中疾病的检测,因为对医学相关内容不了解,本文将完全将该算法从论文中剥离,只介绍算法原理,而不结合论文进行讲述。想要了解论文概况的能够点击☞☞☞查看。

  接下来就随我一起来看看AnoGAN的原理。其实AnoGAN的原理是很简单的,可是我看网上的材料总是说的摸棱两可,我认为首要原因有两点:其一是没有把AnoGAN的原理分步来叙说,其二是有专家视角,它们认为咱们都应该了解,但这对于新手来说了解也的确是有必定难度的。

  在介绍AnoGAN的详细原理时,我先来谈谈AnoGAN的起点,这十分重要,咱们好好感触。咱们知道,DCGAN是将一个噪声或者说一个潜在变量映射成一张图片,在咱们练习DCGAN时,都是运用某一种数据进行的,如[2]中运用的数据都是人脸,那么这些数据都是正常数据,咱们从一个潜在变量经DCGAN后生成的图片应该也都是正常图画。AnoGAN的主意便是我能否将一张图片M映射成某个潜在变量呢,这其实是较难做到的。可是咱们能够在某个空间不断的查找一个潜在变量,使得这个潜在变量生成的图片与图片M尽可能挨近。这便是AnoGAN的起点,咱们可能还不了解这么做的意义,下文为咱们详细介绍。☘☘☘

  AnoGAN其实是分两个阶段进行的,首要是练习阶段,然后是测验阶段,咱们一点点来看:

  • 练习阶段

      练习阶段仅运用正常的数据练习对立生成网络。如咱们运用手写数字中的数字8作为本阶段的数据进行练习,那么8便是正常数据。练习完毕后咱们输入一个向量z,生成网络会将z变成8。不知道咱们有没有发现其实这阶段便是[2]中的DCGAN呢? 【留意:练习阶段现已练习好GAN网络,后面的测验阶段GAN网络的权重是不在改换的】

  • 测验阶段

      在练习阶段咱们现已练习好了一个GAN网络,在这一阶段咱们便是要运用练习好的网络来进行缺点检测。如现在咱们有一个数据6,此为缺点数据 【练习时运用8进行练习,这儿的6即为缺点数据】 。现在咱们要做的便是查找一个潜在变量并让其生成的图片与图片6尽可能挨近,详细完成如下:首要咱们会界说一个潜在变量z,然后经过刚刚练习的好的生成网络,得到假图画G(z),接着G(z)和缺点数据6核算丢失,这时候丢失往往会比较大,咱们不断的更新z值,会使丢失不断的减少,在程序中咱们能够设置更新z的次数,如更新500次后中止,此刻咱们认为将如今的潜在变量z送入生成网络得到的假图画现已和图片6十分像了,所以咱们将z再次送入生成网络,得到G(z)。【注:因为潜在变量z送入的网络是生成图片8的,虽然经过查找使G(z)和6尽可能相像,但仍是存在必定距离,即它们的丢失较大】 终究咱们就能够核算G(z)和图片6的丢失,记为loss1,并将这个丢失作为判别是否有缺点的重要依据。怎样作为判别是否有缺点的重要依据呢?我再举个例子咱们就了解了,现在在测验阶段咱们传入的不是缺点数据,而是正常的数据8,此刻应用相同的办法查找潜在变量z,然后将终究的z送入生成网络,得到G(z),终究核算G(z)和图片8的丢失。 【注:因为潜在变量z送入的网络是生成图片8的,所以终究生成的G(z)能够和数据8很像,即它们的丢失较小】 经过以上剖析, 咱们能够发现当咱们在测验阶段传入缺点图片时终究的丢失大,传入正常图片时的丢失小,这时候咱们就能够设置一个适宜的阈值来判别图画是否有缺点了。 这一段是整个AnoGAN的要点,咱们多考虑考虑,相信你能够了解。我也画了一个此进程的流程图,咱们能够参阅一下,如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战


  读了上文,是不是对AnoGAN大致进程有了必定了解了呢!我觉得咱们练习阶段肯定是没问题的啦,便是一个DCGAN网络,不清楚这个的话主张阅览[2]了解DCGAN网络。测验阶段的难点就在于咱们如何界说丢失函数来更新z值,咱们直接来看论文中此部分的丢失,首要分为两部分,分别是Residual Loss和Discrimination Loss,它们界说如下:

  • Residual Loss

                 R(z)=∑∣x−G(z)∣{\rm{R}}(z) = \sum {|x – G(z)|}

  上式z表明潜在变量,G(z)表明生成的假图画,x表明输入的测验图片。上式表明生成的假图画和输入图片之间的距离。如果生成的图片越挨近x,则R(z)越小。

  • Discrimination Loss

                 D(z)=∑∣f(x)−f(G(z))∣D(z) = \sum {|f(x) – f(G(z))|}

  上式z表明潜在变量,G(z)表明生成的假图画,x表明输入的测验图片。f()表明将经过判别器,然后取判别器某一层的输出成果。 【注:这儿运用的并非判别器的终究输出,而是判别器某层的输出,关于这一点,会在代码讲解时介绍】 这儿能够把判别器当作一个特征提取网络,咱们将生成的假图片和测验图片都输入判别器,看它们提取到特征的差异。相同,如果生成的图片越挨近x,则D(z)越小。

  求得R(z)和D(z)后,咱们界说它们的线性组合作为终究的丢失,如下:

                 Loss(z)=(1−)R(z)+D(z)Loss(z)=(1-\lambda)R(z)+\lambda D(z)

通常,咱们取=0.1\lambda =0.1


  到这儿,AnoGAN的理论部分都介绍完了喔!!!不知道你了解了多少呢?如果觉得有些当地了解还差点儿意思的话,就来看看下面的代码吧,这回对你了解AnoGAN十分有协助。

AnoGAN代码实战

  如果咱们和我相同找过AnoGAN代码的话,可能就会和我有相同的感触,那便是太乱了。怎样说呢,我认为从原理上来说,应该很好完成AnoGAN,可是我看Github上的代码写的挺杂乱,不是很好了解,有的乃至起着AnoGAN的姓名,完成的却是一个简单的DCGAN网络,着实让人有些无语。所以我计划依照自己的思路来完成一个AnoGAN,奈何却呈现了各式各样的Bug,正当我灰心丧气时,看到了一篇外文的博客,写的十分对我的胃口,所以依照它的思路完成了AnoGAN。这儿我仍是想感概一下,我发现许多外文的博客的确写的十分美丽,我想这是值得咱们学习的当地!!!

代码下载地址✨✨✨

  本次我将源码上传到我的Github了,咱们能够阅览README文件了解代码的运用,Github地址如下:

AnoGAN-pytorch完成

  我认为你阅览README文件后现已对这个项目的结构有所了解,我在下文也会帮咱们剖析剖析源码,但更多的时刻咱们应该自己着手去亲身调试,这样你会有不相同的收成。

数据读取✨✨✨

  本次运用的数据为mnist手写数字数据集,咱们下载的是.csv格局的数据,这种格局方便读取。读取数据代码如下:

## 读取练习集数据  (60000,785)
train = pd.read_csv(".\data\mnist_train.csv",dtype = np.float32)
## 读取测验集数据  (10000,785)
test = pd.read_csv(".\data\mnist_test.csv",dtype = np.float32)

  咱们能够来看一下mnist数据集的格局是怎样的,先来看看train中的内容,如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  train的shape为(60000,785),其表明练习会集共有60000个数据,即60000张手写数字的图片,每个数据都有785个值。咱们来剖析一下这785个数值的意义,第一个数值为标签label,表明其表明哪个手写数字,后784个数值为对应数字每个像素的值,手写数字图片巨细为2828,故总共有784个像素值。

  解说完练习集数据的意义,那测验集也是相同的啦,只不过数据较少,只有10000条数据,test的内容如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  咱们需求留意的是,上述的练习集和测验会集的数据咱们今天并不会全部用到。咱们取练习会集的前400个标签为7或8的数据作为AnoGAN的练习集,即7、8都为正常数据。取测验集前600个标签为2、7、8作为测验数据,即测验会集有正常数据(7、8)和反常数据(2),相关代码如下:

# 查询练习数据中标签为7、8的数据,并取前400个
train = train.query("label in [7.0, 8.0]").head(400)
​
# 查询练习数据中标签为7、8的数据,并取前400个
test = test.query("label in [2.0, 7.0, 8.0]").head(600)

  能够看看此刻的train和test的成果:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  在AnoGAN中,咱们是无监督的学习,因此是不需求标签的,经过以下代码去除train和test中的标签:

# 取除标签后的784列数据
train = train.iloc[:,1:].values.astype('float32')
test = test.iloc[:,1:].values.astype('float32')

  去除标签后train和test的成果如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  能够看出,此刻train和test中现已没有了label类,它们的第二个维度也从785变成了784。

  终究,咱们将train和test reshape成图片的格局,即2828,代码如下:

# train:(400,784)-->(400,28,28)
# test:(600,784)-->(600,28,28)
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)

  此刻,train和test的维度产生改换,如下图所示:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  至此,咱们的数据读取部分就为咱们介绍完了,是不是发现挺简单的呢,加油吧!!!

模型建立

模型建立真滴很简单!!!咱们之间看代码吧。

生成模型建立

"""界说生成器网络结构"""
class Generator(nn.Module):
​
 def __init__(self):
  super(Generator, self).__init__()
​
  def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.ReLU(inplace=True), bn=True):
    seq = []
    seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
    if bn is True:
     seq += [nn.BatchNorm2d(out_channel)]
    seq += [activation]
​
    return nn.Sequential(*seq)
​
  seq = []
  seq += [CBA(20, 64*8, stride=1, padding=0)]
  seq += [CBA(64*8, 64*4)]
  seq += [CBA(64*4, 64*2)]
  seq += [CBA(64*2, 64)]
  seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
​
  self.generator_network = nn.Sequential(*seq)
​
 def forward(self, z):
   out = self.generator_network(z)
​
   return out

  为了协助咱们了解,我绘制 了生成网络的结构图,如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

判别模型建立

"""界说判别器网络结构"""
class Discriminator(nn.Module):
​
 def __init__(self):
  super(Discriminator, self).__init__()
​
  def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1, activation=nn.LeakyReLU(0.1, inplace=True)):
    seq = []
    seq += [nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)]
    seq += [nn.BatchNorm2d(out_channel)]
    seq += [activation]
​
    return nn.Sequential(*seq)
​
  seq = []
  seq += [CBA(1, 64)]
  seq += [CBA(64, 64*2)]
  seq += [CBA(64*2, 64*4)]
  seq += [CBA(64*4, 64*8)]
  self.feature_network = nn.Sequential(*seq)
​
  self.critic_network = nn.Conv2d(64*8, 1, kernel_size=4, stride=1)
​
 def forward(self, x):
   out = self.feature_network(x)
​
   feature = out
   feature = feature.view(feature.size(0), -1)
​
   out = self.critic_network(out)
​
   return out, feature
​

  相同,为了方便咱们了解,我也绘制了判别网络的结构图,如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  这儿咱们需求稍稍留意一下,判别网络有两个输出,一个是终究的输出,还有一个是第四个CBA BLOCK提取到的特征,这个在理论部分介绍丢失函数时有提及。

模型练习

数据集加载

class image_data_set(Dataset):
  def __init__(self, data):
    self.images = data[:,:,:,None]
    self.transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),
      transforms.Normalize((0.1307,), (0.3081,))
     ])
​
  def __len__(self):
    return len(self.images)
​
  def __getitem__(self, idx):
    return self.transform(self.images[idx])
    
 # 加载练习数据
 train_set = image_data_set(train)
 train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

  这部分不难,但我提醒咱们留意一下这句:transforms.Resize(64, interpolation=InterpolationMode.BICUBIC),即咱们采用插值算法将原来2828巨细的图片上采样成了6464巨细。 【感兴趣的这儿也能够不对其进行上采样,这样的话咱们需求修正一下上节的模型,能够试试作用喔】

加载模型、界说优化器、丢失函数等参数

# 指定设备
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# batch_size默许128
batch_size = args.batch_size
# 加载模型
G = Generator().to(device)
D = Discriminator().to(device)
​
# 练习形式
G.train()
D.train()
​
# 设置优化器
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))
​
# 界说丢失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')

练习GAN网络

"""
练习
"""# 开端练习
for epoch in range(args.epochs):
  # 界说初始丢失
  log_g_loss, log_d_loss = 0.0, 0.0
  for images in train_loader:
    images = images.to(device)
​
    ## 练习判别器 Discriminator
    # 界说真标签(全1)和假标签(全0)  维度:(batch_size)
    label_real = torch.full((images.size(0),), 1.0).to(device)
    label_fake = torch.full((images.size(0),), 0.0).to(device)
​
    # 界说潜在变量z   维度:(batch_size,20,1,1)
    z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
    # 潜在变量喂入生成网络--->fake_images:(batch_size,1,64,64)
    fake_images = G(z)
​
    # 真图画和假图画送入判别网络,得到d_out_real、d_out_fake  维度:都为(batch_size,1,1,1)
    d_out_real, _ = D(images)
    d_out_fake, _ = D(fake_images)
​
    # 丢失核算
    d_loss_real = criterion(d_out_real.view(-1), label_real)
    d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
    d_loss = d_loss_real + d_loss_fake
​
    # 差错反向传达,更新丢失
    optimizerD.zero_grad()
    d_loss.backward()
    optimizerD.step()
​
    ## 练习生成器 Generator
    # 界说潜在变量z   维度:(batch_size,20,1,1)
    z = torch.randn(images.size(0), 20).to(device).view(images.size(0), 20, 1, 1).to(device)
    fake_images = G(z)
​
    # 假图画喂入判别器,得到d_out_fake  维度:(batch_size,1,1,1)
    d_out_fake, _ = D(fake_images)
​
    # 丢失核算
    g_loss = criterion(d_out_fake.view(-1), label_real)
​
    # 差错反向传达,更新丢失
    optimizerG.zero_grad()
    g_loss.backward()
    optimizerG.step()
​
    ## 累计一个epoch的丢失,判别器丢失和生成器丢失分别存放到log_d_loss、log_g_loss中
    log_d_loss += d_loss.item()
    log_g_loss += g_loss.item()
​
  ## 打印丢失
  print(f'epoch {epoch}, D_Loss:{log_d_loss / 128:.4f}, G_Loss:{log_g_loss / 128:.4f}')
​
​
​
​
## 展示生成器存储的图片,存放在result文件夹下的G_out.jpg
z = torch.randn(8, 20).to(device).view(8, 20, 1, 1).to(device)
fake_images = G(z)
torchvision.utils.save_image(fake_images,f"result\G_out.jpg")

  这部分便是练习一个DCGAN网络,到目前为止其实也都能够认为是DCGAN的内容。咱们能够来看一下输出的G_out.jpg图片:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  这儿咱们能够看到练习是有了作用的,但会发现不是特别好。我剖析有两点原因,其一是咱们的模型不好,且GAN本身就简单呈现形式崩溃的问题;其二是咱们的数据选取的少,在数据读取时练习集咱们只取了前400个数据,但实际上咱们总共能够取12116个,咱们能够测验添加数据,我想数据多了后作用肯定比这个好,咱们快去试试吧!!!

缺点检测✨✨✨

  这部分才是AnoGAN的要点,首要咱们先界说丢失的核算,如下:

## 界说缺点核算的得分
def anomaly_score(input_image, fake_image, D):
# Residual loss 核算
residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
​
# Discrimination loss 核算
_, real_feature = D(input_image)
_, fake_feature = D(fake_image)
discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), (1))
​
# 结合Residual loss和Discrimination loss核算每张图画的丢失
total_loss_by_image = 0.9 * residual_loss + 0.1 * discrimination_loss
# 核算总丢失,行将一个batch的丢失相加
total_loss = total_loss_by_image.sum()
​
return total_loss, total_loss_by_image, residual_loss

  咱们能够对比一下理论部分丢失函数的介绍,看看是不是相同的呢。

  接着咱们就需求不断的查找潜在变量z了,使其与输入图片尽可能挨近,代码如下:

# 加载测验数据
test_set = image_data_set(test)
test_loader = DataLoader(test_set, batch_size=5, shuffle=False)
input_images = next(iter(test_loader)).to(device)
​
# 界说潜在变量z  维度:(5,20,1,1)
z = torch.randn(5, 20).to(device).view(5, 20, 1, 1)
# z的requires_grad参数设置成Ture,让z能够更新
z.requires_grad = True# 界说优化器
z_optimizer = torch.optim.Adam([z], lr=1e-3)
​
# 查找z
for epoch in range(5000):
  fake_images = G(z)
  loss, _, _ = anomaly_score(input_images, fake_images, D)
​
  z_optimizer.zero_grad()
  loss.backward()
  z_optimizer.step()
​
  if epoch % 1000 == 0:
  print(f'epoch: {epoch}, loss: {loss:.0f}')
​

  履行完上述代码后,咱们得到了一个较抱负的潜在变量,这时候再用z来生成图片,并根据生成图片和输入图片来核算丢失,同时,咱们也保存了输入图片和生成图片,并打印了它们之前的丢失,相关代码如下:

  fake_images = G(z)
​
  _, total_loss_by_image, _ = anomaly_score(input_images, fake_images, D)
​
  print(total_loss_by_image.cpu().detach().numpy())
​
  torchvision.utils.save_image(input_images, f"result/Nomal.jpg")
  torchvision.utils.save_image(fake_images, f"result/ANomal.jpg")

  咱们能够来看看终究的成果哦,如下:

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战

  能够看到,当输入图画为2时(此为缺点),生成的图画也是8,它们的丢失最高为464040.44。这时候如果咱们设置一个阈值为430000,高于这个阈值的即为反常图片,低于这个阈值的即为正常图片,那么咱们是不是就能够经过AnoGAN来完成缺点的检测了呢!!!

总结

  到这儿,AnoGAN的所有内容就介绍完了,咱们好好感触感触它的思维,其实是很简单的,可是又十分巧妙。终究我不知道咱们有没有发现AnoGAN一个十分明显的缺点,那便是咱们每次在判别反常时要不断的查找潜在变量z,这是十分耗时的。而许多使命对时刻的要求仍是很高的,所以AnoGAN还有许多能够改善的当地,后续博文我会带咱们继续学习GAN网络在缺点检测中的应用,咱们下期见。

参阅文献

AnoGAN论文

AnoGAN|GAN做图画反常检测的奠基石

GAN 运用 Pytorch 进行反常检测的办法

深度学习论文笔记(反常检测)

如若文章对你有所协助,那就

        

对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战