TFRecords详解

内容目录

  • TFRecords 是什么
  • 序列化(Serialization)
    • tf.data
  • 图像序列化(Serializing Images)
    • tf.Example
    • 函数封装
  • 小结

TFRecords 是什么

TPU拥有八个核心,充当八个独立的工作单元。我们可以通过将数据集分成多个文件或分片(shards),更有效地将数据传输给每个核心。这样,每个核心都可以在需要时获取数据的独立部分。

在TensorFlow中,用于分片的最方便的文件类型是TFRecord。TFRecord是一种包含字节串序列的二进制文件。数据在写入TFRecord之前需要被序列化(编码为字节串)。

在TensorFlow中,最方便的数据序列化方式是使用tf.Example封装数据。这是一种基于谷歌的protobufs的记录格式,但专为TensorFlow设计。它更或多或少地类似于带有一些类型注释的字典。

首先,我们将介绍如何使用TFRecords读取和写入数据。然后,我们将介绍如何使用tf.Example封装数据。

Protobufs(Protocol Buffers),也称为Protocol Buffers语言,是一种由Google开发的数据序列化格式。它可以用于结构化数据的序列化、反序列化以及跨不同平台和语言的数据交换。通过在一个结构体定义文件中定义数据结构,然后使用相应的编译器将其编译为特定语言的类,您可以方便地在不同的系统和编程语言之间共享和传输数据。

序列化(Serialization)

TFRecord是TensorFlow用于存储二进制数据的一种文件类型。TFRecord包含字节串序列。下面是一个非常简单的TFRecord示例:

import tensorflow as tf
import numpy as np

PATH = '/kaggle/working/data.tfrecord'

with tf.io.TFRecordWriter(path=PATH) as f:
    f.write(b'123') # write one record
    f.write(b'xyz314') # write another record

with open(PATH, 'rb') as f:
    print(f.read())

在这里插入图片描述

TFRecord是一系列字节,因此在将数据放入TFRecord之前,我们必须将数据转换为字节串。我们可以使用tf.io.serialize_tensor将张量转换为字节串使用tf.io.parse_tensor将其转换回张量。在解析字符串并将其再次转换为张量时,保持张量的数据类型(在这种情况下为tf.uint8)非常重要,因为您必须在解析过程中指定该数据类型。

x = tf.constant([[1, 2], [3, 4]], dtype=tf.uint8)
print('x:', x, '\n')

x_bytes = tf.io.serialize_tensor(x)
print('x_bytes:', x_bytes, '\n')

print('x:', tf.io.parse_tensor(x_bytes, out_type=tf.uint8))

在这里插入图片描述

tf.data

那么如何将数据集写入TFRecord呢?如果您的数据集由字节串组成,您可以使用data.TFRecordWriter。要再次读取数据集,可以使用data.TFRecordsDataset。

from tensorflow.data import Dataset, TFRecordDataset
from tensorflow.data.experimental import TFRecordWriter

# 创建一个小数据集
ds = Dataset.from_tensor_slices([b'abc', b'123'])

# 写入数据
writer = TFRecordWriter(PATH)
writer.write(ds)
    
# 读取数据集
ds_2 = TFRecordDataset(PATH)
for x in ds_2:
    print(x)

如果您的数据集由张量组成,请首先通过在数据集上映射tf.io.serialize_tensor来进行序列化。然后,在读取数据时,使用tf.io.parse_tensor来将字节串转换回张量。


features = tf.constant([
    [1, 2],
    [3, 4],
    [5, 6],
], dtype=tf.uint8)
ds = Dataset.from_tensor_slices(features)

# 对张量进行序列化操作
# 通过使用 `map` 函数,可以在数据集中的每个张量上应用 `tf.io.serialize_tensor` 进行序列化操作。
ds_bytes = ds.map(tf.io.serialize_tensor)

# 写入数据
writer = TFRecordWriter(PATH)
writer.write(ds_bytes)

# 读取数据(反序列化)
ds_bytes_2 = TFRecordDataset(PATH)
ds_2 = ds_2.map(lambda x: tf.io.parse_tensor(x, out_type=tf.uint8))

# They are the same!
for x in ds:
    print(x)
print()
for x in ds_2:
    print(x)

在这里插入图片描述

# 简化
def parse_serialized(serialized):
    return tf.io.parse_tensor(serialized, out_type=tf.uint8)  # 修改 out_type 根据您的张量数据类型

ds_3 = TFRecordDataset(PATH)

ds_3 = ds_3.map(parse_serialized)

for x in ds_3:
    print(x) #结果和上面一致

图像序列化(Serializing Images)

对图像进行序列化有多种方法:

  • 使用tf.io.serialize_tensor进行原始编码,使用tf.io.parse_tensor进行解码。
  • 使用tf.io.encode_jpeg进行JPEG编码,使用tf.io.decode_jpeg或tf.io.decode_and_crop_jpeg进行解码。
  • 使用tf.io.encode_png进行PNG编码,使用tf.io.decode_png进行解码。

只需确保使用与您选择的编码器相对应的解码器。通常,在使用TPU时,使用JPEG编码对图像进行编码是一个不错的选择,因为这可以对数据进行一定程度的压缩,从而可能提高数据传输速度。

from sklearn.datasets import load_sample_image
import matplotlib.pyplot as plt

# Load numpy array
image_raw = load_sample_image('flower.jpg')
print("Type {} with dtype {}".format(type(image_raw), image_raw.dtype))
plt.imshow(image_raw)
plt.title("Numpy")
plt.show()

在这里插入图片描述

from IPython.display import Image

# jpeg encode / decode
image_jpeg = tf.io.encode_jpeg(image_raw)
print("Type {} with dtype {}".format(type(image_jpeg), image_jpeg.dtype)) 
print("Sample: {}".format(image_jpeg.numpy()[:25])) #显示前25个编码后的字节
Image(image_jpeg.numpy())

在这里插入图片描述

image_raw_2 = tf.io.decode_jpeg(image_jpeg)

print("Type {} with dtype {}".format(type(image_raw_2), image_raw_2.dtype))
plt.imshow(image_raw_2)
plt.title("Numpy")
plt.show()

在这里插入图片描述

tf.Example

如果您有结构化数据,比如成对的图像和标签,该怎么办?TensorFlow还包括用于结构化数据的API,即tf.Example。它们基于谷歌的Protocol Buffers。

一个单独的Example旨在表示数据集中的一个实例,比如一个(图像、标签)对。每个Example都有Features,这被描述为特征名称和值的字典。一个值可以是BytesList、FloatList或Int64List,每个值都包装为单独的Feature。没有用于张量的值类型;相反,使用tf.io.serialize_tensor对张量进行序列化,通过numpy方法获取字节串,并将其编码为BytesList。

以下是我们如何对带有标签的图像数据进行编码的示例:

from tensorflow.train import BytesList, FloatList, Int64List
from tensorflow.train import Example, Features, Feature

# The Data
image = tf.constant([ # this could also be a numpy array
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
])
label = 0
class_name = "Class A"


# Wrap with Feature as a BytesList, FloatList, or Int64List
image_feature = Feature(
    bytes_list=BytesList(value=[
        tf.io.serialize_tensor(image).numpy(),
    ])
)
label_feature = Feature(
    int64_list=Int64List(value=[label]),
)
class_name_feature = Feature(
    bytes_list=BytesList(value=[
        class_name.encode()
    ])
)


# Create a Features dictionary
features = Features(feature={
    'image': image_feature,
    'label': label_feature,
    'class_name': class_name_feature,
})

# Wrap with Example
example = Example(features=features)

print(example)

在这里插入图片描述
查看标签内容
![[Pasted image 20230810140233.png]]![[Pasted image 20230810140309.png]]

一旦所有内容都被编码为一个示例(Example),可以使用SerializeToString方法将其序列化。
![[Pasted image 20230810140347.png]]

函数封装

def make_example(image, label, class_name):
    image_feature = Feature(
        bytes_list=BytesList(value=[
            tf.io.serialize_tensor(image).numpy(),
        ])
    )
    label_feature = Feature(
        int64_list=Int64List(value=[
            label,
        ])
    )
    class_name_feature = Feature(
        bytes_list=BytesList(value=[
            class_name.encode(),
        ])
    )

    features = Features(feature={
        'image': image_feature,
        'label': label_feature,
        'class_name': class_name_feature,
    })
    
    example = Example(features=features)
    
    return example.SerializeToString()

函数使用如下:

example = make_example(
    image=np.array([[1, 2], [3, 4]]),
    label=1,
    class_name="Class B",
)

print(example)

![[Pasted image 20230810140530.png]]

小结

整个过程可能如下所示:

  1. 使用tf.data.Dataset构建数据集。您可以使用from_generatorfrom_tensor_slices方法。
  2. 通过使用make_example遍历数据集来序列化数据集。
  3. 使用io.TFRecordWriterdata.TFRecordWriter将数据集写入TFRecords。

然而,请注意,如果要在数据集的map方法中使用make_example之类的函数,您需要首先使用tf.py_function对其进行包装,因为TensorFlow以图模式执行数据集变换。您可以编写类似以下的代码:

ds_bytes = ds.map(lambda image, label: tf.py_function(func=make_example, inp=[image, label], Tout=tf.string))

其他资料
API文档tf.data.Dataset | TensorFlow v2.13.0。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/73367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

初始多线程

目录 认识线程 线程是什么: 线程与进程的区别 Java中的线程和操作系统线程的关系 创建线程 继承Thread类 实现Runnable接口 其他变形 Thread类及其常见方法 Thread的常见构造方法 Thread类的几个常见属性 Thread类常用的方法 启动一个线程-start() 中断…

[保研/考研机试] KY109 Zero-complexity Transposition 上海交通大学复试上机题 C++实现

描述: You are given a sequence of integer numbers. Zero-complexity transposition of the sequence is the reverse of this sequence. Your task is to write a program that prints zero-complexity transposition of the given sequence. 输入描述&#xf…

Docker的基本概念及镜像加速器的配置

1.Docker的概念 由于代码运行环境不同,代码运行会出现水土不服的情况。运用docker容器会把环境进行打包,避免水土不服。docker是一种容器技术,它解决软件跨环境迁移的问题。 2,安装Docker 3.Docker架构 4.Docker镜像加速器的配…

11、Nvidia显卡驱动、CUDA、cuDNN、Anaconda及Tensorflow Pytorch版本

Nvidia显卡驱动、CUDA、cuDNN、Anaconda及Tensorflow-GPU版本 一、确定版本关系二、安装过程1.安装显卡驱动2、安装CUDA3、安装cudnn4、安装TensorFlow5、安装pytorch 三、卸载 一、确定版本关系 TensorFlow Pytorch推出cuda和cudnn的版本,cuda版本推出驱动可选版本…

【boost网络库从青铜到王者】第二篇:asio网络编程中的socket的监听和连接

文章目录 1、网络编程基本流程2、终端节点endpoint的创建2.1、客户端终端节点endpoint的创建2.2、服务器终端节点endpoint的创建 3、服务器与客户端通信套接字socket的创建4、服务器监听套接字socket的创建5、绑定accpet监听套接字6、客户端连接指定的端点7、服务器接收连接8、…

网工最常犯的9大错误,越早知道越吃香

下午好,我的网工朋友 我们常说,人要学会避免错误,尤其是对在职场生活的打工人来说,更是如此。 学生时代,我们通过错题本收集错误,提高刷题正确率和分数,但到了职场,因为没有量化的…

WebAPIs 第四天

1.日期对象 2.节点操作 3.M端事件 4.JS插件 一.日期对象 实例化时间对象方法时间戳 日期对象:用来表示时间的对象 作用:可以得到当前系统时间 1.1 实例化 ① 概念:在代码中发现了new关键字时,一般将这个操作称为实例化 …

九、多态(2)

本章概要 构造器和多态 构造器调用顺序继承和清理构造器内部多态方法的行为 协变返回类型使用继承设计 替代 vs 扩展向下转型与运行时类型信息 构造器和多态 通常,构造器不同于其他类型的方法。在涉及多态时也是如此。尽管构造器不具有多态性(事实上…

DoIP学习笔记系列:(五)“安全认证”的.dll从何而来?

文章目录 1. “安全认证”的.dll从何而来?1.1 .dll文件base1.2 增加客户需求算法传送门 DoIP学习笔记系列:导航篇 1. “安全认证”的.dll从何而来? 无论是用CANoe还是VFlash,亦或是编辑cdd文件,都需要加载一个与$27服务相关的.dll(Windows的动态库文件),这个文件是从哪…

vscode配置vue用户代码片段

打开vscode软件 选中左下角的设置按钮,再点击用户代码片段(如图) 再选择vue.json文件/新建全局代码片段(如图) 进行相关配置(如下代码) {"Vue2 quickly build template": {&q…

Unity UI.Image 六边形+流光 Shader

效果图 参考代码 Shader"Custom/HexFlowImage" {Properties{[PerRendererData] _MainTex ("Sprite Texture", 2D) "white" {}_Color ("Tint", Color) (1,1,1,1)_StencilComp ("Stencil Comparison", Float) 8_Stencil (…

SQL | 注释

2-注释 2.1-单行注释 select prod_name -- 这是一条行内注释 from products; 使用两个连字符(-- ) 放在行内,两个连字符后的内容即为注释内容。 # 这是一条注释 select prod_name from products; 这种注释方式可能有些数据库不支持,所以使用前应该…

shiro框架基本概念介绍

目录 什么是Shiro: Shiro的核心功能包括: Shiro主要组件及相互作用: Shiro 认证过程: Shiro 授权过程: 资料获取方法 什么是Shiro: Shiro 是一个强大灵活的开源安全框架,可以完全处理身份验证、授权、加密和会话…

funbox3靶场渗透笔记

funbox3靶场渗透笔记 靶机地址 https://download.vulnhub.com/funbox/Funbox3.ova 信息收集 fscan找主机ip192.168.177.199 .\fscan64.exe -h 192.168.177.0/24___ _/ _ \ ___ ___ _ __ __ _ ___| | __/ /_\/____/ __|/ __| __/ _ |/ …

单源最短路的扩展应用

选择最佳线路 有一天,琪琪想乘坐公交车去拜访她的一位朋友。 由于琪琪非常容易晕车,所以她想尽快到达朋友家。 现在给定你一张城市交通路线图,上面包含城市的公交站台以及公交线路的具体分布。 已知城市中共包含 n 个车站(编号…

Azure概念介绍

云计算定义 云计算是一种使用网络进行存储和处理数据的计算方式。它通过将数据和应用程序存储在云端服务器上,使用户能够通过互联网访问和使用这些资源,而无需依赖于本地硬件和软件。 发展历史 云计算的概念最早可以追溯到20世纪60年代的时候&#x…

Stable Diffusion WebUI安装和使用教程(Windows)

目录 下载Stable Diffusion WebUI运行安装程序,双击webui.bat界面启动插件安装(github)模型下载(有些需要魔法)安装过程遇到的大坑总结参考的博客 整个过程坑巨多,我花了一个晚上的时间才全部搞定,本教程针对有编程基础…

网络设备(防火墙、路由器、交换机)日志分析监控

外围网络设备(如防火墙、路由器、交换机等)是关键组件,因为它们控制进出公司网络的流量。因此,监视这些设备的活动有助于 IT 管理员解决操作问题,并保护网络免受攻击者的攻击。通过收集和分析这些设备的日志来监控这些…

苹果Mac像Windows一样使用

一、将磁盘访问设置的像Windows一样: 1.1、点击任务栏第一个按钮打开“访达”,点击菜单栏上的访达-偏好设置: 1.2、勾选“硬盘”,这样macOS的桌面上就会显示一个本地磁盘,之后重命名为磁盘根,相当于window…

部署Springboot项目注意事项

步骤 步骤 1:将数据库内容在云服务器上的数据库部署一份 我使用mariadb;会出现一些不兼容现象;我们需要把默认值删掉 2:配置文件你得修改地方 a:linux是磁盘区分(像我自己项目用来储存验证码的文件我们得换这个配置;…