TensorFlow2从磁盘读取图片数据集的示例(tf.data.Dataset.list_files)

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.resnet import ResNet50
from pathlib import Path
import numpy as np

#数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = Path(os.path.join(base_dir,'train'))
file_pattern = os.path.join(train_dir,'*/*.jpg')
image_count = len(list(train_dir.glob('*/*.jpg')))

list_ds = tf.data.Dataset.list_files(file_pattern,shuffle = False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
for f in list_ds.take(5):
  print(f.numpy())
  
class_names = np.array(sorted([item.name for item in train_dir.glob('*') ]))
print(class_names)

val_size = int(image_count * 0.2)
train_data = list_ds.skip(val_size)
validation_data = list_ds.take(val_size)
print(tf.data.experimental.cardinality(train_data).numpy())
print(tf.data.experimental.cardinality(validation_data).numpy())


def get_label(file_path):
  parts = tf.strings.split(file_path, os.path.sep)
  one_hot = parts[-2] == class_names
  return tf.argmax(one_hot)

def decode_img(img):
  img = tf.io.decode_jpeg(img, channels=3)
  return tf.image.resize(img, [64, 64])

def process_path(file_path):
  label = get_label(file_path)
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

train_data = train_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
validation_data = validation_data.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

for image, label in train_data.take(2):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(4)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

train_data = configure_for_performance(train_data)
validation_data = configure_for_performance(validation_data)


save_model_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model_resnet50_cats_and_dogs.h5', save_freq='epoch')

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(64, 64, 3))
base_model.trainable = True
    
model = tf.keras.models.Sequential([
    base_model,
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(l=0.01)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',optimizer = Adam(lr=1e-3),metrics = ['acc'])

history = model.fit(train_data.repeat(),steps_per_epoch=100,epochs=50,validation_data=validation_data.repeat(),validation_steps=50,verbose=1,callbacks = [save_model_cb])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/104577.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

外包干了3个月,技术退步明显。。。。。

先说一下自己的情况,本科生生,19年通过校招进入广州某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测…

最简单一个跳板机实现

一、背景 生产环境有很多服务器需要让开发人员登录上去做一些基本的操作,比如查看日志什么的,一般小厂对安全要求不是很高,就会直接把ROOT账号给到所有开发人员,但这样当员工离职后,需要去修改ROOT密码是一件很麻烦的…

nodejs使用es-batis

使用方法 创建连接 因为它只支持非连接池所以每次都要创建连接 let dao new MySqlDaoContext({charset: "utf8",host: "localhost",user: "root",password: "root",database: "test",});await dao.initialize();dao in…

0022Java程序设计-ssm微信小程序社区互助平台

文章目录 **摘要**目录系统设计开发环境 摘要 首先,论文一开始便是清楚的论述了小程序的研究内容。其次剖析系统需求分析,弄明白“做什么”,分析包括业务分析、业务流程分析、用例分析,更进一步明确系统的需求。然后在明白了小程序的需求基础上,需要进一步地设计系…

微信批量添加好友,让你的人脉迅速增长

在这个数字化时代,微信作为中国最流行的社交平台之一,已经成为了人们生活中不可或缺的一部分。它的广泛使用为我们提供了无限的社交可能性。你是否曾为了扩大人脉圈子而犯愁?今天,我将向你揭示一个高效添加微信好友的秘密武器&…

(免费领源码)java#Springboot#mysql装修选购网站99192-计算机毕业设计项目选题推荐

摘 要 随着科学技术,计算机迅速的发展。在如今的社会中,市场上涌现出越来越多的新型的产品,人们有了不同种类的选择拥有产品的方式,而电子商务就是随着人们的需求和网络的发展涌动出的产物,电子商务网站是建立在企业与…

Oracle通过透明网关查询SQL Server 报错ORA-00904

Oracle通过透明网关查询SQL Server 报错ORA-00904 问题描述: 只有全表扫描SELECT * 时SQL语句可以正常执行 添加WHERE条件或指定列名查询,查询语句就报错 问题原因: 字段大小写和SQLSERVER中定义的不一致导致查询异常 解决办法: 给…

Java中JVM、JRE和JDK三者有什么区别和联系?

任何语言或者软件的运行都需要环境。就像人要生活在空气中,鱼要活在水中,喜阴植物就不能放在阳光下暴晒一样,任何对象个体的存在都离不开其所需要的环境,编程语言亦是一样的。 java 语言的开发运行,也离不开 Java 语言…

windows如何查看电脑IP地址

介绍两种查看电脑IP的方式 一、第一种方式 1、在电脑左下角搜索网络连接 2、鼠标右键你目前连接的网络(wifi就点wifi 网线就点以太网);选择里面的状态。 3、点击详细信息,这里的IPv4地址就是你电脑的IP。 二、第二种 1、win…

出差学小白知识No5:|Ubuntu上关联GitLab账号并下载项目(ssh key配置)

1 注冊自己的gitlab账户 有手就行 2 ubuntu安装git ,并查看版本 sudo apt-get install git git --version 3 vim ~/.ssh/config Host gitlab.example.com User your_username Port 22 IdentityFile ~/.ssh/id_rsa PreferredAuthentications publickey 替换gitl…

CC001:CC照片建模

摘要:CC照片建模原理是通过从图像中提取特征点和特征描述符,然后根据特征点的匹配来计算相机的位姿,从而生成三维点云数据。最后,借助网格重建和纹理映射的方法,将点云转换为带有纹理的三维网格模型。 实验数据&#x…

每日一题 2520. 统计能整除数字的位数(简单)

简单题频率好高,预测一波明天困难 class Solution:def countDigits(self, num: int) -> int:ans 0for i in str(num):if num % int(i) 0:ans 1return ans

DC-7 靶机

DC_7 信息搜集 存活检测 详细扫描 后台网页扫描 网页信息搜集 搜索相关信息 在配置中发现了用户名密码字样 $username "dc7user"; $password "MdR3xOgB7#dW";ssh 登录 尝试使用获取的账密进行登录 网页登录失败 尝试 ssh 登录 成功登录 登陆今后提…

mac安装jdk

1、下载jdk(我的电脑要下载arm版,截图不对) Java Downloads | Oraclehttps://www.oracle.com/java/technologies/downloads/#jdk17-mac 2、双击安装

性能测试用例和测试结果

性能测试用例和测试结果 一 核心业务功能的TPS测试1.1 登录接口测试用例1.2 进入首页接口测试用例1.3 添加购物车接口测试用例1.4 结算和下订单接口测试用例1.5 系统资源使用率1.6 单接口测试中一个测试的各个成员接口要单独做性能统计 二 业务流程(多接口组合&…

HarmonyOS 快速入门TypeScript

1.什么是TypeScript,它和JavaScript,ArkTs有什么区别 ArkTS是HarmonyOS优选的主力应用开发语言。它在TypeScript(简称TS)的基础上,匹配ArkUI框架,扩展了声明式UI、状态管理等相应的能力,让开发…

GoLong的学习之路(二)语法之基本数据类型

书接上回:我在GoLong的学习之路(一)中在常量最后说了iota的作用。今天这里我在介绍一下我学习Go语言中基本数据类型。 文章目录 Go中的基本数据类型整型特殊整型数字字面语法 浮点型复数布尔值字符串字符串转义符多行字符字符串的常用操作&am…

汉威科技光纤预警系统,守护油气长输管道“大动脉”

石油、天然气早已成为城市生活中不可或缺的能源。广大车主能快速地加上汽油,千家万户能方便地用上天然气,得益于我国庞大的石油、天然气输送基础设施网络。 我国油气分布西多东少、北多南少,要想把千里、乃至万里之外的石油、天然气输送到中部…

Pytorch整体工作流程代码详解(新手入门)

一、前言 本文详细介绍Pytorch的基本工作流程及代码,以及如何在GPU上训练模型(如下图所示)包括数据准备、模型搭建、模型训练、评估及模型的保存和载入。 适用读者:有一定的Python和机器学习基础的深度学习/Pytorch初学者。 本文…

【FPGA零基础学习之旅#17】搭建串口收发与储存双口RAM系统

🎉欢迎来到FPGA专栏~搭建串口收发与储存双口RAM系统 ☆* o(≧▽≦)o *☆嗨~我是小夏与酒🍹 ✨博客主页:小夏与酒的博客 🎈该系列文章专栏:FPGA学习之旅 文章作者技术和水平有限,如果文中出现错误&#xff0…