拓展神经网络八股(入门级)

自制数据集

minst等数据集是别人打包好的,如果是本领域的数据集。自制数据集。

替换

把图片路径和标签文件输入到函数里,并返回输入特征和标签

只需要把图片灰度值数据拼接到特征列表,标签添加到标签列表,提取操作函数如下:

def generateds(path, txt):
    f = open(txt, 'r')
    contents = f.readlines() #读取所有行
    f.close()
    x, y_ = [], []
    for content in contents:
        value = content.split()
        img_path = path + value[0]#找到图片索引路径
        img = Image.open(img_path) #图片打开
        img = np.array(img.convert('L')) # 图片变为8位灰度的npy格式的数据集                    
        img = img / 255.
        x.append(img)
        y_.append(value[1])
        print('loading:' + content) # 打印状态提示
    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_astype(np.int64)
    return x, y_

 完整代码

import tensorflow as tf
from PIL import Image
import numpy as np
import os

train_path = './fashion_image_label/fashion_train_jpg_60000/'
train_txt = './fashion_image_label/fashion_train_jpg_60000.txt'
x_train_savepath = './fashion_image_label/fashion_x_train.npy'
y_train_savepath = './fashion_image_label/fahion_y_train.npy'

test_path = './fashion_image_label/fashion_test_jpg_10000/'
test_txt = './fashion_image_label/fashion_test_jpg_10000.txt'
x_test_savepath = './fashion_image_label/fashion_x_test.npy'
y_test_savepath = './fashion_image_label/fashion_y_test.npy'


def generateds(path, txt):
    f = open(txt, 'r')
    contents = f.readlines()  # 按行读取
    f.close()
    x, y_ = [], []
    for content in contents:
        value = content.split()  # 以空格分开,存入数组
        img_path = path + value[0]
        img = Image.open(img_path)
        img = np.array(img.convert('L'))
        img = img / 255.
        x.append(img)
        y_.append(value[1])
        print('loading : ' + content)

    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_.astype(np.int64)
    return x, y_


if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(
        x_test_savepath) and os.path.exists(y_test_savepath):
    print('-------------Load Datasets-----------------')
    x_train_save = np.load(x_train_savepath)
    y_train = np.load(y_train_savepath)
    x_test_save = np.load(x_test_savepath)
    y_test = np.load(y_test_savepath)
    x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))
    x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:
    print('-------------Generate Datasets-----------------')
    x_train, y_train = generateds(train_path, train_txt)
    x_test, y_test = generateds(test_path, test_txt)

    print('-------------Save Datasets-----------------')
    x_train_save = np.reshape(x_train, (len(x_train), -1))
    x_test_save = np.reshape(x_test, (len(x_test), -1))
    np.save(x_train_savepath, x_train_save)
    np.save(y_train_savepath, y_train)
    np.save(x_test_savepath, x_test_save)
    np.save(y_test_savepath, y_test)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

数据增强

如果数据量过少,模型见识不足。增加数据,提高泛化力。

用来应对因为拍照角度不同引起的图片变形

image_gen_train=tf,keras.preprocessing,image.ImageDataGenneratorP(...)

image_gen)train,fit(x_train)

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,使数据和网络结构匹配

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机45度旋转
    width_shift_range=.15,  # 宽度偏移
    height_shift_range=.15,  # 高度偏移
    horizontal_flip=True,  # 水平翻转
    zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()

 因为是标准MINST数据集,因此在准确度上看不出来,需要在具体应用中才能体现

断点续训

实时保存最优模型

 保存模型参数可以使用tensorflow提供的ModelCheckpoint(filepath=checkpoint_save,

                              save_weight_only,sabe_best_only)

参数提取

获取各层网络最优参数,可以在各个平台实现应用

model.trainable_variables 返回模型中可训练参数

acc/loss可视化

查看训练效果

history=model.fit()

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt

np.set_printoptions(threshold=np.inf)

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1) 画出第一列
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2) #画出第二列
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

应用程序

给图识物

给出一张图片,输出预测结果

1.复现模型 Sequential加载模型

2.加载参数 load_weights(model_save_path)

3.预测结果

我们需要对颜色取反,我们的训练图片是黑底白字

减少了背景噪声的影响

from PIL import Image
import numpy as np
import tensorflow as tf

type = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

model_save_path = './checkpoint/fashion.ckpt'
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
                                        
model.load_weights(model_save_path)

preNum = int(input("input the number of test pictures:"))
for i in range(preNum):
    image_path = input("the path of test picture:")
    img = Image.open(image_path)
    img=img.resize((28,28),Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))
    img_arr = 255 - img_arr  #每个像素点= 255 - 各自点当前灰度值
    img_arr = img_arr/255.0
    x_predict = img_arr[tf.newaxis,...]

    result = model.predict(x_predict)
    pred=tf.argmax(result, axis=1)
    print('\n')
    print(type[int(pred)])


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/788387.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

25.无源蜂鸣器驱动设计

相对于有源蜂鸣器,无源蜂鸣器的成本更低,声音频率可控。而有源蜂鸣器因其内部 自带振荡源,只要加上适当的直流电源即可发声,程序控制较为方便。 (1)设计定义:设计一个无源蜂鸣器的驱动程序&…

基于泰坦尼克号生还数据进行 Spark 分析

基于泰坦尼克号生还数据进行 Spark 分析 在这篇博客中,我们将展示如何使用 Apache Spark 分析著名的泰坦尼克号数据集。通过这篇教程,您将学习如何处理数据、分析乘客的生还情况,并生成有价值的统计信息。 数据解析 • PassengerId &#…

ctfshow-web入门-文件上传(web164、web165)图片二次渲染绕过

web164 和 web165 的利用点都是二次渲染,一个是 png,一个是 jpg 目录 1、web164 2、web165 二次渲染: 网站服务器会对上传的图片进行二次处理,对文件内容进行替换更新,根据原有图片生成一个新的图片,这样…

EasyCVR视频汇聚平台:存储系统怎么选?分布式存储vs.集中式存储的区别在哪?

在当今的数字化时代,安防监控已成为维护社会秩序和公共安全的重要手段。随着监控设备的普及和监控数据的不断增加,如何高效、安全地存储和管理这些视频数据,成为了安防行业面临的重要挑战。EasyCVR视频存储系统凭借其卓越的性能和灵活的架构&…

综合安全防护

题目 1,DMZ区内的服务器,办公区仅能在办公时间内(9:00-18:00)可以访问,生产区的设备全天可以访问. 2,生产区不允许访问互联网,办公区和游客区允许访问互联网 3,办公区设备10.0.2.10不允许访问DMz区的FTP服务器和HTTP服务器,仅能ping通10.0.3.10 4,办公区分为市场部和研发部,研…

pnpm workspace使用教程【Monorepo项目】

目录 前言一、pnpm简介特点:对比 二、 创建项目添加文件 pnpm-workspace.yaml目录结构pnpm workspace: 协议修改配置文件执行 安装 三、命令解析执行包命令所有包操作命令 四、实例代码 前言 前面两篇,我们讲了 yarn workspace 和 lerna , …

局域网远程共享桌面如何实现

在局域网内实现远程共享桌面,可以通过以下几种方法: 一、使用Windows自带的远程桌面功能: 首先,在需要被控制的电脑上右键点击“此电脑”,选择“属性”。 进入计算机属性界面后,点击“高级系统设置”&am…

记录excel表生成一列按七天一个周期的方法

使用excel生成每七天一个周期的列。如下图所示: 针对第一列的生成办法,使用如下函数: TEXT(DATE(2024,1,1)(ROW()-2)*7,"yyyy/m/d")&" - "&TEXT(DATE(2024,1,1)(ROW()-1)*7-1,"yyyy/m/d") 特此记录。…

一文实践强化学习训练游戏ai--doom枪战游戏实践

一文实践强化学习训练游戏ai–doom枪战游戏实践 上次文章写道下载doom的环境并尝试了简单的操作,这次让我们来进行对象化和训练、验证,如果你有基础,可以直接阅读本文,不然请你先阅读Doom基础知识,其中包含了下载、动作…

android CameraX构建相机拍照

Android CameraX 是一个 Jetpack 支持库,旨在简化相机应用的开发工作。它提供了一致且易用的API接口,适用于大多数Android设备,并可向后兼容至Android 5.0(API级别21)。 CameraX解决了在多种设备上实现相机功能时所遇…

14-56 剑和诗人30 - IaC、PaC 和 OaC 在云成功中的作用

介绍 随着各大企业在 2024 年加速采用云计算,基础设施即代码 (IaC)、策略即代码 (PaC) 和优化即代码 (OaC) 已成为成功实现云迁移、IT 现代化和业务转型的关键功能。 让我在云计划的背景下全面了解这些代码功能的当前状态。我们将研究现代云基础设施趋势、IaC、Pa…

java:获取当前的日期和时间

// 获取当前的日期和时间LocalDateTime now LocalDateTime.now();// 定义日期时间格式化器DateTimeFormatter formatter DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");// 格式化日期时间String formattedDateTime now.format(formatter);// 打印结果Syste…

【数据结构和算法的概念等】

目录 一、数据结构1、数据结构的基本概念2、数据结构的三要素2.1 数据的逻辑结构2.2 数据的存储(物理)结构2.3 数据的运算 二、算法1、算法概念2、算法的特性及特点3、算法分析 一、数据结构 1、数据结构的基本概念 数据: 是所有能输入到计…

前端八股文 对事件循环的理解

对事件循环的理解 思维导图 图示 实际案例的执行过程 总结

能源电子领域2区SCI,版面稀缺,即将截稿,无版面费!

【SciencePub学术】今天小编给大家推荐1本能源电子领域的SCI!影响因子1.0-2.0之间,最重要的是审稿周期较短,对急投的学者较为友好! 能源电子类SCI 01 / 期刊概况 【期刊简介】IF:1.0-2.0,JCR2区&#xf…

【C++】C++入门基础--引用,inline,nullptr

文章目录 前言一、引用?1.1 引用的概念和定义1.2 引用的特性1.3 引用的使用1.4 const引用(常引用)1.5 指针和引用的关系 二、inline2.1inline概念和定义2.2 inline使用2.3 inline注意事项 三.nullptr总结 前言 上一篇文章我们介绍了C中的命名…

枚举对象序列化规则(将Java枚举转换为JSON字符串的步骤)

文章目录 引言I 案例分析1.1 接口签名计算1.2 请求对象1.3 枚举对象序列化II 在JSON中以枚举的code值来表示枚举的实现方式2.1 自定义toString方法返回code引言 在Java中,每个对象都有一个toString方法,用于返回该对象的字符串表示。默认情况下,Enum类的toString方法返回的…

C语言笔记30 •单链表经典算法OJ题-2.移除链表元素•

移除链表元素 1.问题 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 2.代码实现&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <stdlib.h&g…

【RHCE】转发服务器实验

1.在本地主机上操作 2.在客户端操作设置主机的IP地址为dns 3.测试,客户机是否能ping通

特征及特征选择

1、特征&#xff08;Feature&#xff09;是什么&#xff1f; 特征是数据集中的一个可量化的属性或变量&#xff0c;用于描述数据点的特性。 特征可以是连续的数值&#xff0c;如身高、体重等&#xff0c;也可以是离散的类别&#xff0c;如性别、种族等。 常见的特征有边缘、角、…