内容导读:TorchVision Transforms API 扩展晋级,现已支撑方针检测、实例及语义切割以及视频类使命。新 API 尚处于测验阶段,开发者能够试用体验。
本文首发自微信公众号:PyTorch 开发者社区
TorchVision 现已针对 Transforms API 进行了扩展, 详细如下:
-
除用于图画分类外,现在还能够用其进行方针检测、实例及语义切割以及视频分类等使命;
-
支撑从 TorchVision 直接导入 SoTA 数据增强,如 MixUp、 CutMix、Large Scale Jitter 以及 SimpleCopyPaste。
-
支撑运用全新的 functional transforms 转化视频、Bounding box 以及切割掩码 (Segmentation Mask)。
Transforms 当前的局限性
稳定版 TorchVision Transforms API,也也便是咱们常说的 Transforms V1,只支撑单个图画,因而,只适用于分类使命:
from torchvision import transforms
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs = trans(imgs)
上述办法不支撑需求运用 Label 的方针检测、切割或分类 Transforms, 如 MixUp 及 cutMix。这使分类以外的计算机视觉使命都不能用 Transforms API 履行必要的扩展。一起,这也加大了用 TorchVision 原语训练高精度模型的难度。
为了克服这个局限性,TorchVision 在其 reference script 中供给了自定义完成, 用于演示一切使命中的增强是如何履行的。
虽然这种做法使得开发者能够训练出高精度的分类、方针检测及切割模型,但做法比较粗糙,TorchVision 二进制文件中还是不能导入 Transforms。
全新的 Transforms API
Transforms V2 API 支撑视频、bounding box、label 以及切割掩码, 这意味着它为许多计算机视觉使命供给了本地支撑。新的解决计划是一种更为直接的替代计划:
from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)
全新的 Transform Class 无需强制履行特定的次序或结构,就能够接纳恣意数量的输入:
# Already supported:
trans(imgs) # Image Classification
trans(videos) # Video Tasks
trans(imgs_or_videos, labels) # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels) # Object Detection
trans(imgs, bboxes, masks, labels) # Instance Segmentation
trans(imgs, masks) # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels}) # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints) # Keypoint Detection
trans(stereo_images, disparities, masks) # Depth Perception
trans(image1, image2, optical_flows, masks) # Optical Flow
functional API 已经更新,支撑一切输入必要的 signal processing kernel,如 resizing, cropping, affine transforms, padding 等:
from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])
API 运用 Tensor subclassing 来包装输入,附加有用的元数据,并 dispatch 到正确的内核。 利用 TorchData Data Pipe 的 Datasets V2 相关作业完成后,就不再需求手动包装输入了。目前,用户能够经过以下办法手动包装输入:
from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])
除新 API 之外,PyTorch 官方还为 SoTA 研讨中用到的一些数据增强供给了重要完成,如 MixUp、 CutMix、Large Scale Jitter、 SimpleCopyPaste、AutoAugmentation 办法以及一些新的 Geometric、Colour 和 Type Conversion transforms。
该 API 继续支撑 single image 或 batched input image 的 PIL 和 Tensor 后端,并在 functional API 上保留了 JIT-scriptability。这使得图画映射得以从 uint8 延迟到 float, 带来了性能的进一步提升。
它目前能够在 TorchVision 的原型区域 (prototype area) 中运用,并且支撑从 nightly build 版别中导入。经验证,新 API 与从前完成的准确性一致。
当前的局限性
functional API (kernel) 仍然保持 JIT-scriptable 及 fully-BC,Transform Class 供给了相同的接口,却无法运用脚本。
这是因为 Transform Class 运用的是张量子类 (Tensor Subclassing),且接纳恣意数量的输入,这是 JIT 所不支撑的。该局限将在后续版别中不断优化。
一个端到端示
以下是一个新 API 示例,它能够一起运用 PIL 图画和张量。
测验图片:
代码示例:
import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
[[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
[148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
[422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
[435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
[469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
[452, 39, 463, 63], [424, 38, 429, 50]],
format=features.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
[
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480),
]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
—— 完 ——