本文为稀土技术社区首发签约文章,30天内制止转载,30天后未获授权制止转载,侵权必究!

一、前语

对图画不了解的人经常梦想去除马赛克是能够完结的,严格意义来说这确实是无法完结的。而深度学习是呈现,让去除马赛克成为可能。

为了理解去除马赛克有多难,咱们需求知道马赛克是什么。观感上,马赛克便是方块感。当咱们观察图画像素时, 马赛克表现为下图的情况:

如何去除图片马赛克?

原图右下角有十字,而增加马赛克后右下角一片都变成了同一像素,如果咱们没保存原图,那么咱们无法复原,也不知道是否复原了原图。由于原图现已被破坏了,这也是为什么马赛克是不行修正的。

神经网络又是怎么让修正成为可能呢?其实无论什么办法的修正,都是一种估量,而不是真正的修正。神经网络去除马赛克的操作其实是生成马赛克那部分内容,然后代替马赛克,然后到达修正的作用。

这种修正并不是复原,而是想象。假如咱们对一张人脸打了马赛克,神经网络能够去除马赛克,但是去除后的人脸不再是本来那个人了。

二、完结原理

2.1 自编码器

图画修正的办法有很多,比方自编码器。自编码器是一种自监督模型,结构简略,不需求人为打标,收敛敏捷。其结构如图:

如何去除图片马赛克?

编码器部分便是用于下采样的卷积网络,编码器会把图片编码成一个向量,而解码器则运用转置卷积把编码向量上采样成和原图大小共同的图片,最终咱们把原图和生成成果的MSE作为丢失函数进行优化。当模型练习好后,就能够用编码器对图片进行编码。

2.2 自编码器去除马赛克

那自编码器和去除马赛克有什么联络呢?其实非常简略,便是原本咱们是输入原图,希望解码器能输出原图。这是出于咱们希望模型学习怎么编码图片的原图。而现在咱们想要模型去除马赛克,此时咱们要做的便是把马赛克图片作为输入,而原图作为输出,这样来练习就能够到达去除马赛克的作用了:

如何去除图片马赛克?

关于关于这种完结能够参考:/post/721068…

2.3 自编码器的问题

自编码器有个很明显的问题,便是图片通过编码器后会丢失信息,而解码器的成果天然也会存在一些问题。这样既达不到去除马赛克的功能,连复原的原图都有一些含糊。

这儿能够运用FPN的思想来改善,当自编码器加入FPN后,就得到了UNet网络结构

2.4 UNet网络

UNet结构和自编码器相似,是一个先下再上的结构。和自编码器不同的时,UNet会运用编码器的每个输出,将各个输出与解码器的输入进行concatenate,这样就能更好地保存原图信息。其结构如下图:

如何去除图片马赛克?

UNet原本是用于图画切割的网络,这儿咱们用它来去除马赛克。

在UNet中,有几个部分咱们分别来看看。

2.4.1 ConvBlock

在UNet中,有很多接连卷积的操作,这儿咱们作为一个Block(蓝色箭头),它能够完结为一个层,用PyTorch完结如下:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    def forward(self, inputs):
        return self.model(inputs)

这儿其实便是两次卷积操作,这儿的意图是提取当前感受野的特征。

2.4.2 ConvDown

通过接连卷积后,会运用卷积网络对图片进行下采样,这儿把stride设置为2即可让图片缩小为本来的1/2。咱们同样能够完结为层:

class ConvDown(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
    def forward(self, inputs):
        return self.model(inputs)

这儿只有一个卷积,而且stride被设置为了2。

2.4.3 ConvUp

接下来是解码器部分,这儿多了一个上选用的操作,咱们能够用转置卷积完结,代码如下:

class ConvUp(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2),
            nn.ReLU()
        )
    def forward(self, inputs):
        return self.model(inputs)

上面是层能够把图片尺寸扩大为2倍,一起把特征图数量缩小到1/2。这儿缩小特征图的操作是为了concatenate操作,后面详细说。

三、完好完结

首要,导入需求用的模块:

import os
import random
import torch
from torch import nn
from torch import optim
from torch.utils import data
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import ToTensor
from PIL import Image, ImageDraw, ImageFilter
from torchvision.utils import make_grid

下面开始具体完结。

3.1 创立Dataset

首要创立本次任务需求的数据集,分布大致相同的图片即可,代码如下:

class ReConstructionDataset(data.Dataset):
    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        # 图画预处理
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # 保持一切图片的途径
        self.image_paths = []
        # 读取根目录,把一切图片途径放入image_paths
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                self.image_paths.append(os.path.join(root, file))
    def __getitem__(self, item):
        # 读取图片,并预处理
        image = Image.open(self.image_paths[item])
        return self.trans(self.create_blur(image)), self.trans(image)
    def __len__(self):
        return len(self.image_paths)
    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left_corner = (random.randint(0, image.size[0] - box_size), random.randint(0, image.size[1] - box_size))
        lower_right_corner = (upper_left_corner[0] + box_size, upper_left_corner[1] + box_size)
        draw.rectangle([lower_right_corner, upper_left_corner], fill=0)
        masked_image = Image.composite(image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        if return_mask:
            return masked_image, mask
        else:
            return masked_image

Dataset的完结与以往根本共同,完结init、getitem、len办法,这儿咱们还完结了一个create_blur办法,该办法用于生成矩形马赛克(实际上是高斯含糊)。下面是create_blur办法生成的图片:

如何去除图片马赛克?

3.2 网络构建

这儿咱们需求运用前面的几个子单元,先完结编码器,代码如下:

class UNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)
    def forward(self, inputs):
        f0 = self.blk0(inputs)
        d0 = self.down0(f0)
        f1 = self.blk1(d0)
        d1 = self.down1(f1)
        f2 = self.blk2(d1)
        d2 = self.down2(f2)
        f3 = self.blk3(d2)
        d3 = self.down3(f3)
        f4 = self.blk4(d3)
        return f0, f1, f2, f3, f4

这儿便是ConvBlok和ConvDown的n次组合,最终会得到一个102444的特征图。在forward中,咱们返回了5个ConvBlok返回的成果,由于在解码器中咱们需求全部运用。

接下来是解码器部分,这儿与编码器相反,代码如下:

class UNetDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)
    def forward(self, inputs):
        f0, f1, f2, f3, f4 = inputs
        u3 = self.up3(f4)
        df2 = self.blk3(torch.concat((f3, u3), dim=1))
        u2 = self.up2(df2)
        df1 = self.blk2(torch.concat((f2, u2), dim=1))
        u1 = self.up1(df1)
        df0 = self.blk1(torch.concat((f1, u1), dim=1))
        u0 = self.up0(df0)
        f = self.blk0(torch.concat((f0, u0), dim=1))
        return torch.tanh(self.last_conv(f))

解码器的inputs为编码器的5组特征图,在forward时需求与上采样成果concatenate。

最终,整个网络组合起来,代码如下:

class ReConstructionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()
    def forward(self, inputs):
        fs = self.encoder(inputs)
        return self.decoder(fs)

3.3 网络练习

现在各个部分都完结了,能够开始练习网络:

device = "cuda" if torch.cuda.is_available() else "cpu"
def train(model, dataloader, optimizer, criterion, epochs):
    model = model.to(device)
    for epoch in range(epochs):
        for iter, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (iter + 1) % 100 == 1:
                print("epoch: %s, iter: %s, loss: %s" % (epoch + 1, iter + 1, loss.item()))
                with torch.no_grad():
                    outputs = make_grid(outputs)
                    img = outputs.cpu().numpy().transpose(1, 2, 0)
                    plt.imshow(img)
                    plt.show()
        torch.save(model.state_dict(), '../outputs/reconstruction.pth')
if __name__ == '__main__':
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64)
    unet = ReConstructionNetwork()
    optimizer = optim.Adam(auto_encoder.parameters(), lr=0.0002)
    criterion = nn.MSELoss()
    train(unet, dataloader, optimizer, criterion, 20)

练习完结后,就能够用来去除马赛克了,代码如下:


dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64, shuffle=True)
unet = ReConstructionNetwork().to(device)
unet.load_state_dict(torch.load('../outputs/reconstruction.pth'))
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
        outputs = torch.concatenate([images, masked_images, outputs], dim=-1)
        outputs = make_grid(outputs)
        img = outputs.cpu().numpy().transpose(1, 2, 0)
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        Image.fromarray(img).show()

下面是生成成果。左侧为原图,中间为增加马赛克后的图片,右侧则是去除马赛克后的成果:

如何去除图片马赛克?

全体来说作用比较不错。本文的办法不只能够用来去除马赛克,还能够完结图画重构。比方老化的图片、被墨汁污染的图片等,都能够用本文的办法完结重构。别的,本文的数据有限,完结作用并不通用,有需求的读者能够移步CodeFormer项目:github.com/sczhou/Code…