机器学习-基于attention机制来实现对Image Caption图像描述实验

机器学习-基于attention机制来实现对Image Caption图像描述实验

实验目的

基于attention机制来实现对Image Caption图像描述

实验内容

1.了解一下RNN的Encoder-Decoder结构

在最原始的RNN结构中,输入序列和输出序列必须是严格等长的。但在机器翻译等任务中,源语言句子的长度和目标语言句子的长度往往不同,因此我们需要将原始序列映射为一个不同长度的序列。Encoder-Decoder模型就解决了这样一个长度不一致的映射问题。

1

2.模型架构训练

在Image Caption输入的图像代替了之前机器翻译中的输入的单词序列,图像是一系列的像素值,我们需要从使用图像特征提取常用的CNN从图像中提取出相应的视觉特征,然后使用Decoder将该特征解码成输出序列,下图是论文的网络结构,特征提取采用的是CNN,Decoder部分,将RNN换成了性能更好的LSTM,输入还是word embedding,每步的输出是单词表中所有单词的概率。

实验数据和程序清单

import json
 
# 加载数据集标注
with open("annotations/captions_train2014.json", "r") as f:
    annotations = json.load(f)
 
# 提取图像文件名和描述
image_path_to_caption = {}
for val in annotations["annotations"]:
    caption = f"<start> {val['caption']} <end>"
    image_path = "train2014/" + "COCO_train2014_" + "%012d.jpg" % (val["image_id"])
    if image_path in image_path_to_caption:
        image_path_to_caption[image_path].append(caption)
    else:
        image_path_to_caption[image_path] = [caption]
 
image_paths = list(image_path_to_caption.keys())
归一化处理
import tensorflow as tf
 
def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (299, 299))
    img = tf.keras.applications.inception_v3.preprocess_input(img)
return img, image_path
模型构建
from tensorflow.keras.applications import InceptionV3
 
encoder = InceptionV3(weights="imagenet", include_top=False)
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model
 
embedding_dim = 256
vocab_size = 10000  # 您可以根据需要调整词汇表大小
max_length = 40  # 您可以根据需要调整最大描述长度
 
# 解码器输入
input_caption = Input(shape=(max_length,))
embedding = Embedding(vocab_size, embedding_dim)(input_caption)
lstm_output = LSTM(256)(embedding)
output_caption = Dense(vocab_size, activation="softmax")(lstm_output)
 
# 定义解码器模型
decoder = Model(inputs=input_caption, outputs=output_caption)
模型训练
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
 
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
 
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
 
return tf.reduce_mean(loss_)
@tf.function
def train_step(img_tensor, target):
    loss = 0
    hidden = decoder.reset_state(batch_size=target.shape[0])
 
    dec_input = tf.expand_dims([tokenizer.word_index["<start>"]] * target.shape[0], 1)
 
    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for i in range(1, target.shape[1]):
            predictions = decoder([features, hidden, dec_input])
            loss += loss_function(target[:, i], predictions)
 
            dec_input = tf.expand_dims(target[:, i], 1)
 
    total_loss = loss / int(target.shape[1])
 
    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
 
return loss, total_loss
import time
 
epochs = 10
batch_size = 64
buffer_size = 1000
 
dataset = tf.data.Dataset.from_tensor_slices((image_paths, captions))
dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
 
for epoch in range(epochs):
    start = time.time()
    total_loss = 0
 
    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss
 
        if batch % 100 == 0:
            print(f"Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy() / int(target.shape[1]):.4f}")
 
    print(f"Epoch {epoch+1} Loss {total_loss/len(image_paths):.6f}")
    print(f"Time taken for 1 epoch: {time.time() - start:.2f} sec\n")

可视化:


import matplotlib.pyplot as plt
import numpy as np

def plot_attention(image_path, result, attention_plot):
   img = plt.imread(image_path)

   fig = plt.figure(figsize=(10, 10))
   len_result = len(result)
   for i in range(len_result):
       temp_att = np.resize(attention_plot[i], (8, 8))
       grid_size = max(np.ceil(len_result / 2), 2)
       ax = fig.add_subplot(grid_size, grid_size, i + 1)
       ax.set_title(result[i])
       imgplot = ax.imshow(img)
       ax.imshow(temp_att, cmap="gray", alpha=0.6, extent=imgplot.get_extent())

   plt.tight_layout()
   plt.show()

plot_attention(image_path, result, attention_plot)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/297253.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机毕业设计-----SSM宠物商城带后台管理系统

项目介绍 该项目为前后台项目&#xff0c;分为普通用户与管理员两种角色&#xff0c;前台普通用户登录&#xff0c;后台管理员登录&#xff1b; 用户角色包含以下功能&#xff1a; 加入购物车,发表留言,提交订单,查看订单信息,会员注册,登录页面等功能。 管理员角色包含以下…

python爬虫,简单的requests的get请求,百度搜索实例

1、百度搜索实例 import requests url = https://www.baidu.com/s? # key_word = 迪丽热巴 key_word = input(输入搜索内容:) headers = {User-Agent: Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36 } …

DS|静态查找

题目一&#xff1a;DS静态查找 -- 顺序查找 题目描述&#xff1a; 给出一个队列和要查找的数值&#xff0c;找出数值在队列中的位置&#xff0c;队列位置从1开始 要求使用带哨兵的顺序查找算法 输入要求&#xff1a; 第一行输入n&#xff0c;表示队列有n个数据 第二行输入…

PySimpleGUI图形界面实例|PDF表格转换Excel文件

实例要求&#xff1a; 使用PySimpleGUI做一个把单位考勤系统导出的pdf文件合并输出Excel的应用&#xff0c;故事出自&#xff1a;https://hannyang.blog.csdn.net/article/details/135395946 当时时间紧&#xff0c;没有好好做界面且输出csv文件了事。今天趁周六休息&#xf…

linux中最常用的用户信息命令

文章目录 linux中最常用的用户信息命令还有谁 last语法一般使用方法查看最近登陆的三个用户省略hostname显示最后一列显示主机IP地址 我是谁 whoami谁&#xff1f;who默认使用系统的运行时间显示表头信息显示登录的人员及总数 什么&#xff1f;谁&#xff1f;w (who & what…

Demo:基于elementplus的弹窗嵌套表单进行二次封装

基于elementplus的弹窗嵌套表单进行二次封装 所见即所得&#xff1a;简单封装方便工作 ProForm.vue代码&#xff1a; <!--* Author: 忆往昔* LastEditTime: 2024-01-6 14:36:00* email: 15871856064163.com --> <template><div class"penk-form-contain…

【SpringBoot实战专题】「开发实战系列」全方位攻克你的技术盲区之Spring定义Jackson转换Null的方法和实现案例

Spring自动定义Jackson转换Null得方法 背景MessageConverter 使用Jackson原生方式处理空字段&#xff08;次重点方案&#xff09;ObjectMapper的配置选项通过使用注解的方式 MappingJackson2HttpMessageConverter&#xff08;重点方案&#xff09;创建MappingJackson2HttpMessa…

WebStorm 创建一个Vue项目

一、下载并安装WebStorm 步骤一 步骤二 选择激活方式 激活码&#xff1a; I2A0QUY8VU-eyJsaWNlbnNlSWQiOiJJMkEwUVVZOFZVIiwibGljZW5zZWVOYW1lIjoiVU5JVkVSU0lEQURFIEVTVEFEVUFMIERFIENBTVBJTkFTIiwiYXNzaWduZWVOYW1lIjoiVGFvYmFv77yaSkVU5YWo5a625qG25rAIOa0uW3peS9nOWup…

Easycode模板,基于官方提供的Mybatis-plus模板改造

目录结构 模板亮点 1、接口类默认继承实体类 实体类不做任何修改保证类与表统一 2、实体类涵盖多种注解 日期格式编码、Long类型转String、字段自动填充 3、自带insertOrUpdateBatch方法 导入方式 {"author" : "Wsong","version" : "1.2.8…

SpringBoot+RocketMQ集群(dledger)部署完整学习笔记

文章目录 前言一、单台集群部署二、多台集群部署1.修改配置2.dashboard修改 三、整合springboot1.引入pom和修改yml2.编写消费者3.编写生产者4.测试效果 总结 前言 RocketMQ集群方式有好几种 官网地址 https://rocketmq.apache.org/zh/docs/4.x/deployment/01deploy 2m-2s-asy…

ATTCK视角下的信息收集:主机发现

目录 1、利用协议主动探测主机存活 利用ICMP发现主机 利用ARP发现主机 利用NetBIOS协议发现主机 利用TCP/UDP发现主机 利用DNS协议发现主机 利用PRC协议发现主机程序 2、被动主机存活检测 利用Browser主机探测存活主机 利用ip段探测主机存活 利用net命令探测主机存活…

unity C# 中通俗易懂LINQ使用案例

文章目录 1. 从数组或列表中查询元素**&#xff1a;2. **排序与分组**&#xff1a;3. **连接多个数据源**&#xff1a;4. **聚合操作**&#xff1a;5. **分页查询**&#xff1a;6. **多条件查询**&#xff1a;7. **转换和投影&#xff08;Select&#xff09;**&#xff1a;8. *…

jdbc源码研究

JDBC介绍 JDBC&#xff08;Java Data Base Connectivity,java数据库连接&#xff09;是一种用于执行SQL语句的Java API&#xff0c;可以为多种关系数据库提供统一访问&#xff0c;它由一组用Java语言编写的类和接口组成。 开发者不必为每家数据通信协议的不同而疲于奔命&#…

竞赛保研 基于深度学习的人脸专注度检测计算系统 - opencv python cnn

文章目录 1 前言2 相关技术2.1CNN简介2.2 人脸识别算法2.3专注检测原理2.4 OpenCV 3 功能介绍3.1人脸录入功能3.2 人脸识别3.3 人脸专注度检测3.4 识别记录 4 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 基于深度学习的人脸专注度…

Flask 会员列表展示

感谢编程浪子师傅的源码信息分享 web/controllers/member/Member.py # -*- coding: utf-8 -*- from flask import Blueprint,request,redirect,jsonify from common.libs.Helper import ops_render,iPagination,getCurrentDate,getDictFilterField,selectFilterObj from comm…

【uniapp】APP打包上架应用商-注意事项

初雪云-uniapp启动图自定义生成&#xff08;支持一键生成storyboard&#xff09; 一、修改App端上传图片/视频 uni.uploadFile let thatthis; uni.chooseImage({count: 1,sourceType: [camera,album],sizeType: [compressed, original],success: rey > {uni.showLoading({ t…

Linux操作系统——进程控制(一) 进程创建和进程终止

进程创建 fork函数 在linux中fork函数时非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程为子进程&#xff0c;而原进程为父进程。 #include <unistd.h> pid_t fork(void); 返回值&#xff1a;自进程中返回0&#xff0c;父进程返回子进程id&#xff…

BetaFlight开源代码之电流校准

BetaFlight开源代码之电流校准 1. 源由2. 分析2.1 常规逻辑2.2 数据流2.3 采样电路2.3.1 采样实现2.3.2 采样原理2.3.3 Layout参考2.3.4 INA169芯片2.3.5 INA169 Near-Zero Vsense 3. 原理4. 示例4.1 实测&转换数据4.2 线性拟合-小电流4.3 线性拟合-大电流4.4 大电流/小电流…

HDMI彩条显示实验与方块移动实验

一、HDMI接口简介 一种数字音视频接口标准&#xff0c;提供高质量的数字音视频传输&#xff0c;同时支持多通道音频、高分辨率视频和其他数据传输功能。提供更高的数据传输带宽&#xff08;带宽&#xff1a;1s内传输多少比特数据&#xff09; 数字传输&#xff1a; HDMI是一种全…

【VRTK】启用多种VR设备的Passthrough功能

【背景】 透视可以让VR头盔展现AR能力,通过VRTK,可以快速实现多种设备平台可用的透视功能。包括主流的Oculus,Pico等。整个不成不需要自己写代码。 【操作】 针对WaveXR,点击场景中的CameraRigsWaveXR-》WaveRig-》Camera Offset-》Main Camera,追加一个新组件,名为Und…