留意:本文运用tensorflow1.x版本进行演示

运用本地Jupyter Notebook搭载TensorFlow相关库进行操作

1. 读取TFRecords文件

其实读取TFRecords文件大体思路与常规文件读取思路(结构行列、读取、解码、批处理行列)比较一致。但是,仍是有一点不相同,在解码操作之前,需求解析Example操作(由于TFRecords比其他文件多了个Example结构),TFRecords文件读取步骤如下所示:

  • 结构文件名行列
  • 读取
  • 解析Example
    • tf.parse_single_example()
      • tf.FixedLenFeature(shape, dtype)
  • 解码
  • 结构批处理行列

接下来,咱们将对TFRecords文件读取中用到的函数进行详细说明:

  • tf.parse_single_example(serialized, features=None, name=None)

    • 用来解析一个单一的Example原型
    • serialized:标量字符串Tensor,一个序列化的Example
    • features:dict字典数据,键为读取的姓名,值为FixedLenFeature
    • return:回来一个键值对组成的字典,键为读取的姓名。想拿到解析后的example数据,需求经过字典形式访问。
  • tf.FixedLenFeature(shape, dtype)

    • 这个函数和上一个函数其实是嵌套运用的,上一个函数中的features参数中的一部分(字典中值的部分)需求用本函数填充
    • shape:输入数据的形状,一般不指定即为空列表
    • dtype:输入数据类型,与存储进文件的类型要相同
    • 类型只能是float32,int64,string

2. 代码演示

导入所需模块,由于本地下载的是Tensorflow2.x版本,想运转Tensorflow1的语法,需求敞开兼容模型,以支撑Tensorflow1语法正常运转。

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os

从读取TFRecords文件的视点,进行函数定义,对已保存到本地的TFRecords文件进行读取。详细代码如下所示:

  • 首先在函数中需求结构文件名行列,经过变量file_queue接纳

  • 然后,运用tf.TFRecordReader()读取器,运用该读取器下面的read办法进行文件读取,运用变量key与value承受元组

  • 接下来以上述介绍的API进行Example解析,能够将中间结果image与label打印出来看看

  • 别忘记敞开会话tf.Session()才干看到详细的值

    • 会话中tf.train.Coordinator()敞开线程
    • sess.run()运转一下用以检查详细的值
    • 收回资源,收回线程
  • 然后是解码操作,咱们能够将其解码成uint8

  • 打印出的是一维数组,咱们需求进行图画调整将其调整成32323

  • 终究,将其放入批处理行列。

class Cifar():
    def __init__(self):
        # 设置图画巨细
        self.height = 32
        self.width = 32
        self.channel = 3
        # 设置图画字节数
        self.image = self.height * self.width * self.channel
        self.label = 1
        self.sample = self.image + self.label
    def read_tfrecords(self):
        """
        读取tfrecords文件
        """
        # 1. 结构文件名行列
        file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])
        # 2. 读取与解码
        # 2.1 读取
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)
        # 2.2 解析example
        feature = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64)
        })
        image = feature["image"]
        label = feature["label"]
        print("read_tf_image:\n", image)
        print("read_tf_label:\n", label)
        # 2.3 解码
        image_decoded = tf.decode_raw(image, tf.uint8)
        print("image_decoded:\n", image_decoded)
        # 图画形状调整
        image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
        # 3. 结构批处理行列
        image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=5, capacity=100)
        print("image_batch:\n", image_batch)
        print("label_batch:\n", label_batch)
        # 敞开会话
        with tf.Session() as sess:
            # 敞开线程
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            image_value, label_value, image_decoded, image_batch, label_batch = sess.run([image, label, image_decoded, image_batch, label_batch])
            print("image_value:\n", image_value)
            print("label_value:\n", label_value)
            # 收回资源
            coord.request_stop()
            coord.join(threads)
        return None
cifar = Cifar()
cifar.read_tfrecords()

部分读取结果如下图所示:

【深度学习】TensorFlow:TFRecords文件读取
本文正在参加「金石方案 . 瓜分6万现金大奖」