乐高小人分类项目

数据来源

 LEGO Minifigures | Kaggle

建立文件目录

BASE_DIR = 'lego/star-wars-images/'
names = [
    'YODA', 'LUKE SKYWALKER', 'R2-D2', 'MACE WINDU', 'GENERAL GRIEVOUS'
]
tf.random.set_seed(1)

# Read information about dataset
if not os.path.isdir(BASE_DIR + 'train/'):
    for name in names:
        os.makedirs(BASE_DIR + 'train/' + name)
        os.makedirs(BASE_DIR + 'test/' + name)
        os.makedirs(BASE_DIR + 'val/' + name)

 

划分数据集

# Total number of classes in the dataset
orig_folders = ['0001/', '0002/', '0003/', '0004/', '0005/']
for folder_idx, folder in enumerate(orig_folders):
    files = os.listdir(BASE_DIR + folder)
    number_of_images = len([name for name in files])
    n_train = int((number_of_images * 0.6) + 0.5)
    n_valid = int((number_of_images * 0.25) + 0.5)
    n_test = number_of_images - n_train - n_valid
    print(number_of_images, n_train, n_valid, n_test)
    for idx, file in enumerate(files):
        file_name = BASE_DIR + folder + file
        if idx < n_train:
            shutil.move(file_name, BASE_DIR + 'train/' + names[folder_idx])
        elif idx < n_train + n_valid:
            shutil.move(file_name, BASE_DIR + 'val/' + names[folder_idx])
        else:
            shutil.move(file_name, BASE_DIR + 'test/' + names[folder_idx])

训练数据生成

train_gen = ImageDataGenerator(rescale=1./255)
val_gen = ImageDataGenerator(rescale=1./255)
test_gen = ImageDataGenerator(rescale=1./255)

train_batches = train_gen.flow_from_directory(
    'lego/star-wars-images/train',
    target_size=(256, 256),
    class_mode='sparse',
    batch_size=4,
    shuffle=True,
    color_mode='rgb',
    classes=names,
)

 查看其中一批次的样本

train_batch = train_batches[0]
test_batch = train_batches[0]
print(train_batch[0].shape)
print(train_batch[1])
print(test_batch[0].shape)
print(test_batch[1])


def show(batch, pre_labels=None):
    plt.figure(figsize=(10, 10))
    for i in range(4):
        plt.subplot(2, 2, i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(batch[0][i], cmap=plt.cm.binary)
        # extra index
        lbl = names[int(batch[1][i])]
        if pre_labels is not None:
            lbl += '/Pred:' + names[int(pre_labels[i])]
        plt.xlabel(lbl)
    plt.show()


show(test_batch)

模型建立

model = keras.Sequential([
    layers.Conv2D(32, (3, 3), strides=(1, 1), padding='valid', activation='relu', input_shape=(256, 256, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(5),
])
# print(model.summary())

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=['accuracy']
)

model.fit(train_batches, validation_data=val_batches, epochs=30, verbose=2)
model.save('lego_model.h5')

 绘制 loss 和 acc 

# plot loss and acc
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='valid loss')
plt.grid()
plt.legend(fontsize=15)

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='train acc')
plt.plot(history.history['val_accuracy'], label='valid acc')
plt.grid()
plt.legend(fontsize=15);

 模型评估和预测

model.evaluate(test_batches, verbose=2)

predictions = model.predict(test_batches)
predictions = tf.nn.softmax(predictions)
labels = np.argmax(predictions, axis=1)
print(test_batches[0][1])
print(labels[0:4])

show(test_batches[0], labels[0:4])

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/667434.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GPT-4o:新一代人工智能技术的全方位解析引言

目录 &#x1f40b;引言 &#x1f40b;梳理 GPT 各版本之间的内容 &#x1f988;GPT-1&#xff1a;开创性的起点 &#x1f988; GPT-2&#xff1a;参数规模的大幅提升 &#x1f988; GPT-3&#xff1a;参数爆炸与多任务学习 &#x1f988;GPT-4&#xff1a;进一步提升的智…

嵌入式模块学习小记(未分类)

L298N电机驱动板模块 Output A&#xff1a;接DC 电机 1 或步进电机的 A和 A-&#xff1b; Output B&#xff1a;接DC 电机 2 或步进电机的 B和 B-&#xff1b; 5V Enable&#xff1a;如果使用输入电源大于12V的电源&#xff0c;请将跳线帽移除。输入电源小于12V时短接可以提…

【Python面试50题】

1. **基础概念** 1. Python 是解释型还是编译型语言&#xff1f; 2. 什么是 Python 的 GIL&#xff08;全局解释器锁&#xff09;&#xff1f; 3. 如何理解 Python 中的可变与不可变数据类型&#xff1f; 4. 解释一下 Python 中的 pass 语句。 5. Python 中的列…

让低代码平台插上AI的翅膀 - 记开源驰骋AI平台升级

让低代码系统插上AI的翅膀——驰骋低代码开发平台引领新时代 在当今日新月异的科技世界中&#xff0c;人工智能&#xff08;AI&#xff09;已经成为各个行业不可或缺的一部分。从制造业的自动化生产到金融行业的智能风控&#xff0c;再到医疗领域的精准诊断&#xff0c;AI技术…

FPGA-ARM架构与分类

ARM架构&#xff0c;曾称进阶精简指令集机器&#xff08;Advanced RISC Machine&#xff09;更早称作Acorn RISC Machine&#xff0c;是一个32位精简指令集&#xff08;RISC&#xff09;处理器架构。 主要是根据FPGA zynq-7000的芯片编写的知识思维导图总结,废话不多说自取吧 …

GPT LoRA 大模型微调,生成猫耳娘

往期热门专栏回顾 专栏描述Java项目实战介绍Java组件安装、使用&#xff1b;手写框架等Aws服务器实战Aws Linux服务器上操作nginx、git、JDK、VueJava微服务实战Java 微服务实战&#xff0c;Spring Cloud Netflix套件、Spring Cloud Alibaba套件、Seata、gateway、shadingjdbc…

Windows环境安装redis

1、下载redis https://github.com/tporadowski/redis/releases 2、解压 .zip 3、更改文件名 更改文件名称为&#xff1a;redis 4、将本地解压后的redis&#xff0c;作为本地服务器下的应用服务 从redis文件路径下&#xff0c;执行cmd .\redis-server --service-install re…

使用wireshark分析tcp握手过程

开启抓包 tcpdump -i any host 127.0.0.1 and port 123 -w tcp_capture.pcap 使用telnet模拟tcp连接 telnet 127.0.0.1 123 如果地址无法连接&#xff0c;则会一直重试SYN包&#xff0c;各个平台SYN重试间隔并不一致&#xff0c;如下&#xff1a; 异常站点抓包展示&#xff…

word中设置页眉,首页不设置

在设计文档时&#xff0c;有时候会给文档设置页眉&#xff0c;但是一设置&#xff0c;就是每页都会同时设置&#xff0c;大部分都不需要首页设置&#xff0c;那咋么解决呢&#xff0c;请看以下的解说&#xff0c;Come On&#xff01;&#xff01;&#xff01; 1、首先点击头部…

基于SSM的“基于Apriori算法的网络书城”的设计与实现(源码+数据库+文档)

基于SSM的“基于Apriori算法的网络书城”的设计与实现&#xff08;源码数据库文档) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SSM 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 网站功能展示图 首页 商品分类 热销 新品 我的订单 个…

组装电脑(使用老机箱)

昨天同事拿来一台联想 ThinkCentre M6210t的台式机&#xff0c;说计算机实在是太慢了&#xff0c;在只保留主机箱想升级一下。   她拿来了配件&#xff0c;有电源、主板、CPU、CPU风扇、内存条、机箱风扇、硬盘&#xff1a;   主板&#xff1a;华硕 Prime H610M-K D4&#…

FPGA高端项目:FPGA解码MIPI视频+图像缩放+视频拼接,基于MIPI CSI-2 RX Subsystem架构实现,提供4套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐我这里已有的 MIPI 编解码方案本方案在Xilinx Artix7-35T上解码MIPI视频的应用本方案在Xilinx Artix7-100T上解码MIPI视频的应用本方案在Xilinx Kintex7上解码MIPI视频的应用本方案在Xilinx Zynq7000上解码MIPI视频的应用本方案在…

【云原生 | 60】Docker中通过docker-compose部署kafka集群

&#x1f341;博主简介&#xff1a; &#x1f3c5;云计算领域优质创作者 &#x1f3c5;2022年CSDN新星计划python赛道第一名 &#x1f3c5;2022年CSDN原力计划优质作者 &#x1f3c5;阿里云ACE认证高级工程师 &#x1f3c5;阿里云开发者社区专…

基于WIN2016搭建MS2016 ALWAYS ON域控故障转移群集

基于WIN2016搭建MS2016 ALWAYS ON域控故障转移群集 一、前言1、Always On简介2、AD DC域控简介 二、部署实施1、部署环境简介2、搭建流程简介3、域控服务器安装及群集节点加域3.1、安装域控&#xff0c;安装同时会安装DNS系统3.2、执行安装&#xff0c;完成后重启服务器3.3、将…

哇塞!数字营销竟是企业增长的魔法棒!

​嘿&#xff0c;朋友们&#xff01;你们有没有发现“蚓链数字营销”就像一根神奇的魔法棒&#xff0c;为企业带来了超乎想象的市场影响力&#xff01; 首先&#xff0c;蚓链数字营销能够利用互联网和数字技术&#xff0c;精准地定位目标用户群体。比如&#xff0c;通过搜索引擎…

Java整合EasyExcel实战——3(上下列相同合并单元格策略)

参考&#xff1a;https://juejin.cn/post/7322156759443095561?searchId202405262043517631094B7CCB463FDA06https://juejin.cn/post/7322156759443095561?searchId202405262043517631094B7CCB463FDA06 准备条件 依赖 <dependency><groupId>com.alibaba</gr…

数据分析案例一使用Python进行红酒与白酒数据数据分析

源码和数据集链接 以红葡萄酒为例 有两个样本: winequality-red.csv:红葡萄酒样本 winequality-white.csv:白葡萄酒样本 每个样本都有得分从1到10的质量评分&#xff0c;以及若干理化检验的结果 #理化性质字段名称1固定酸度fixed acidity2挥发性酸度volatile acidity3柠檬酸…

【SpringBoot】SpringBoot整合JWT

目录 先说token单点登录&#xff08;SSO&#xff09;简介原理单点登录的优势单点登录流程分布式单点登录方式方式一&#xff1a;session广播机制实现方式二&#xff1a;使用cookieredis实现。方式三&#xff1a;token认证 JWT数字签名JWT的作用JWT和传统Session1、无状态&#…

【Linux 网络】网络基础(三)(其他重要协议或技术:DNS、ICMP、NAT)

一、DNS&#xff08;Domain Name System&#xff09; DNS 是一整套从域名映射到 IP 的系统。 1、DNS 背景 TCP/IP 中使用 IP 地址和端口号来确定网络上的一台主机的一个程序&#xff0c;但是 IP 地址不方便记忆。于是人们发明了一种叫主机名的东西&#xff0c;是一个字符串&…

【Python】解决Python报错:AttributeError: ‘NoneType‘ object has no attribute ‘xxx‘

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…