Imitation Learning学习记录(理论例程)

前言

最近还是衔接着之前的学习记录,这次打算开始学习模仿学习的相关原理,参考的开源资料为

TeaPearce/Counter-Strike_Behavioural_Cloning: IEEE CoG & NeurIPS workshop paper ‘Counter-Strike Deathmatch with Large-Scale Behavioural Cloning’ (github.com)
[2104.04258] Counter-Strike Deathmatch with Large-Scale Behavioural Cloning (arxiv.org)

简单来说,行为克隆就是利用已有的人为示范数据作为输入来训练出一个策略,策略就会输出指定的动作,然而,行为克隆只能学习到专家的行为,而无法进行探索和自主学习。这意味着行为克隆的性能受限于专家的行为水平,并且可能无法适应新的、未在专家演示中出现过的情况。通过引入奖励函数,可以在行为克隆中加入一定的探索和自主学习能力。奖励函数可以根据当前状态和采取的动作来评估行为的好坏,并为模型提供反馈信号。通过优化奖励函数,可以使模型学习到更好的策略,并且能够适应新的情况和环境。奖励函数在行为克隆中起到了指导和调整模型学习的作用。其中,文章用到的例程的网络结构如下

请添加图片描述

本文打算从数据获取、模型训练、效果展示三个部分展开介绍

数据获取

在这个过程中作者使用了Game State Integration(GSI)技术来获取在线数据。通过GSI,作者可以从游戏中获取实时的游戏状态信息,包括玩家、队伍、武器、位置等各种数据。具体来说,作者可能使用了Valve提供的GSI接口来获取游戏状态信息。这些信息可以用于后续的行为克隆和分析工作。

这个过程中的核心代码如下

# now find the requried process and where two modules (dll files) are in RAM  
hwin_csgo = win32gui.FindWindow(0, ('counter-Strike: Global Offensive'))  
if(hwin_csgo):  
    pid=win32process.GetWindowThreadProcessId(hwin_csgo)  
    handle = pymem.Pymem()  
    handle.open_process_from_id(pid[1])  
    csgo_entry = handle.process_base  
else:  
    print('CSGO wasnt found')  
    os.system('pause')  
    sys.exit()  
  
# now find two dll files needed  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  
    tmp=next(list_of_modules)  
    # used to be client_panorama.dll, moved to client.dll during 2020  
    if(tmp.name=="client.dll"):  
        print('found client.dll')  
        off_clientdll=tmp.lpBaseOfDll  
        break  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  
    tmp=next(list_of_modules)  
    if(tmp.name=="engine.dll"):  
        print('found engine.dll')  
        off_enginedll=tmp.lpBaseOfDll  
        break

大致逻辑为:

  1. 查找CSGO进程:

    • 使用win32gui.FindWindow查找名为’counter-Strike: Global Offensive’的窗口句柄。
    • 如果找到窗口句柄,则通过win32process.GetWindowThreadProcessId获取与该窗口关联的进程ID。
    • 使用pymem.Pymem()创建一个进程内存访问对象,并通过open_process_from_id方法打开该进程。
    • 如果CSGO进程未找到,则打印消息并退出程序。
  2. 查找client.dll和engine.dll:

    • 使用handle.list_modules()获取进程中的所有模块列表。
    • 遍历模块列表,查找名为"client.dll"的模块,并获取动态链接库的基地址(lpBaseOfDll)。
    • 注意:这里使用了两次handle.list_modules()来分别查找两个DLL文件,但实际上你可以只调用一次并将结果存储在列表中,然后遍历这个列表来查找两个DLL。
    • 类似地,代码还查找名为"engine.dll"的模块,并获取其基地址。

找到窗口和动态链接库以后就可以开始录像并通过GSI或者RAM来访问键位等游戏信息,得到的数据类型大致为

  • frame_i_x: 图像信息
  • frame_i_xaux: 包含在前一个时间步骤中应用的动作,以及血量、弹药和团队。用于更好地帮助智能体寻找敌人以及适应当前情况
  • frame_i_y: 对应键盘以及鼠标的动作
  • frame_i_helperarr: 在格式kill_flag, death_flag中,每个变量都是二元变量,例如[[1,0]],意味着玩家击杀一次,但在该时间步内没有死亡

其中,具体的键位信息如下:

# how many slots were used for each action type?  
n_keys = 11 # number of keyboard outputs, w,s,a,d,space,ctrl,shift,1,2,3,r  
n_clicks = 2 # number of mouse buttons, left, right  
n_mouse_x = len(mouse_x_possibles) # number of outputs on mouse x axis  
n_mouse_y = len(mouse_y_possibles) # number of outputs on mouse y axis  
n_extras = 3 # number of extra aux inputs, eg health, ammo, team. others could be weapon, kills, deaths  
aux_input_length = n_keys+n_clicks+1+1+n_extras # aux uses continuous input for mouse this is multiplied by ACTIONS_PREV elsewhere

一个帧所包含的具体信息值如下:

请添加图片描述
请添加图片描述

模型训练

网络结构

输入先进入一个预训练好的EfficientNetB0模型,该模型在ImageNet数据集上进行了训练。并加上了时间序列信息,接下来将提取好的特征输入进一个带有时序信息的ConvLSTM网络

base_model = EfficientNetB0(weights='imagenet',input_shape=(input_shape[1:]),include_top=False,drop_connect_rate=0.2)
if 'drop' in model_name:  
    if 'big' in model_name:  
        x = ConvLSTM2D(filters=512,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)  
    else:  
        x = ConvLSTM2D(filters=256,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)

输出的信息为

# set up outputs, sepearate outputs will allow seperate losses to be applied  
output_1 = TimeDistributed(Dense(n_keys, activation='sigmoid'))(dense_5)  
output_2 = TimeDistributed(Dense(n_clicks, activation='sigmoid'))(dense_5)  
output_3 = TimeDistributed(Dense(n_mouse_x, activation='softmax'))(dense_5) # softmax since mouse is mutually exclusive  
output_4 = TimeDistributed(Dense(n_mouse_y, activation='softmax'))(dense_5)   
output_5 = TimeDistributed(Dense(1, activation='linear'))(dense_5)   
# output_all = concatenate([output_1,output_2,output_3,output_4], axis=-1)  
output_all = concatenate([output_1,output_2,output_3,output_4,output_5], axis=-1)

损失函数

  1. 键盘按键损失(loss1a, loss1b, loss1c, loss1d
    • loss1a:计算 WASD 键(通常用于游戏中的移动)的二进制交叉熵损失。
    • loss1b:计算空格键的二进制交叉熵损失。
    • loss1c:计算重新加载键(如游戏中的“R”键)的二进制交叉熵损失。
    • loss1d(注释掉的部分):原本可能用于计算其他键盘按键的损失,但在提供的代码中,它被重新定义为武器切换键(1, 2, 3)的损失。
  2. 鼠标点击损失(loss2a, loss2b
    • loss2a:计算鼠标左键点击的二进制交叉熵损失。
    • loss2b:计算鼠标右键点击的二进制交叉熵损失(如果n_clicks大于1的话)。
  3. 鼠标移动损失(loss3, loss4
    • loss3:计算鼠标在 X 轴上的移动损失。由于鼠标移动是互斥的(即鼠标不能同时处于多个位置),因此使用了分类交叉熵损失(categorical_crossentropy)。
    • loss4:计算鼠标在 Y 轴上的移动损失,同样使用了分类交叉熵损失。

除此之外,还有一个loss_crit损失函数,

loss_crit = 10*losses.MSE(y_true[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  
                   + GAMMA*y_pred[:,1:,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  
                   ,y_pred[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1])

这是一个基于时序差分(Temporal Difference, TD)的均方误差(Mean Squared Error, MSE)损失函数,用于强化学习中的值函数逼近。它计算了当前时间步的奖励(或值)与下一个时间步的预测奖励(或值)之和(经过折扣因子 GAMMA 调整后)与当前时间步的预测奖励(或值)之间的均方误差。这种损失函数允许神经网络学习如何根据当前状态和环境信息来预测未来的奖励或值,从而优化策略或值函数。在这个特定的实现中,损失还乘以了一个系数(如10),可能是为了调整该损失在总损失中的相对权重。

奖励函数如下,奖励为 R(杀敌数,死亡数,子弹数)

reward_i = kill_i - 0.5*dead_i - 0.01*shoot_i # this is reward function  
y[i,j,-2:] = (reward_i,0.) # 0. is a placeholder for original advantage

效果展示

通过e2e.yml文件配置虚拟环境,更改了游戏内的窗口分辨率,设置了一些其他的参数,运行dm_run_agent.py以后在自己的电脑上成功复现

请添加图片描述

总结

本次基于Counter-Strike Deathmatch with Large-Scale Behavioural Cloning这个开源项目系统地学习了一下行为克隆的基本流程,从数据采集、模型训练以及损失函数的定义到最终复现,拓宽了我对RL的认知,在日后也能够更好地迁移到Robotic,逻辑如下:

  1. 数据收集:首先,需要收集人类专家在特定任务中的行为数据。这些数据通常包括机器人所处的状态(如位置、姿态、环境信息等)以及对应的人类专家在该状态下所采取的动作(如移动方向、操作指令等)。这些数据构成了行为克隆算法的训练集。
  2. 模型训练:使用收集到的数据训练一个模型,如神经网络模型。这个模型将学习从状态到动作的映射关系,即根据机器人当前的状态预测应该执行的动作。在训练过程中,模型会不断优化其参数,以最小化预测动作与真实动作之间的差异。
  3. 模型部署:训练好的模型可以部署到机器人上,用于指导机器人的行为。当机器人遇到新的状态时,它会将当前状态输入到模型中,模型会输出一个预测的动作。机器人将根据这个预测的动作来执行相应的操作。
  4. 反馈与调整:在机器人执行动作的过程中,可以通过收集反馈信息来进一步调整模型。例如,可以观察机器人执行动作后的效果,如果效果不理想,则可以收集新的数据并重新训练模型,以提高其性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/616842.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

现代制造之3D打印技术进行零件加工

现代制造 有现代技术支撑的制造业,即无论是制造还是服务行业,添了现代两个字不过是因为有了现代科学技术的支撑,如发达的通信方式,不断发展的互联网,信息化程度加强了,因此可以为这两个行业增加了不少优势…

简单易懂的Java Queue入门教程!

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

600/天,海外项目值班,接不接?

朋友介绍了一个海外项目,广告系统短期维护,刚上线需要维护14天也就是2个星期,费用单价600/天,主要工作内容:北京晚上12点-早上8点值班,如果有问题及时响应并修复。 如果我年轻10岁,这个项目我倒…

【网站项目】SpringBoot803房屋租赁管理系统

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

C++入门指南(上)

目录 ​编辑 一、祖师爷画像 二、什么是C 三、C发展史 四、C在工作领域的应用 1. 操作系统以及大型系统软件开发 2. 服务器端开发 3. 游戏开发 4. 嵌入式和物联网领域 5. 数字图像处理 6. 人工智能 7. 分布式应用 五、如何快速上手C 一、祖师爷画像 本贾尼斯特劳斯…

docker修改默认安装路径

docker安装之后默认在 /etc/docker 在/etc/docker 文件下有一个daemon -json 没有就新增 {"registry-mirrors": ["https://kfwkfulq.mirror.aliyuncs.com","https://2lqq34jg.mirror.aliyuncs.com","https://pee6w651.mirror.aliyuncs.c…

续篇——源码部署LAMP环境上线项目——禅道项目

上篇:LNMP环境部署WordPress——使用源码包安装方式部署环境-CSDN博客 目录 一.前提准备 1. 名词区别 2. 下载项目软件包 3. 上传项目源码到虚拟机并解压 二.安装Apache 1. 环境清理 2.关闭Nginx 3. 下载Apache 4. 下载APR组件 4.1 安装apr 4.2 安装apr-util组件 5…

算法学习012-不同路径 c++动态规划算法实现 中小学算法思维学习 信奥算法解析

目录 C不同路径 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、运行结果 五、考点分析 六、推荐资料 C不同路径 一、题目要求 1、编程实现 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” &#xff09…

德国储能项目锂电池储能集装箱突发火灾:安全挑战再引关注

2024年4月27日,德国尼尔莫尔商业区的一起锂电池储能集装箱火灾事件引起了全球关注。这起事故不仅导致两名消防员在救援过程中受伤,更暴露了储能系统在安全领域亟待解决的重要问题。 根据德国消防队的出警记录,火灾发生在晚上9点前不久。消防人…

在Linux操作系统中LVM逻辑券管理指令

1.PV物理券相关指令 1.查看机器中的PV pvscan 命令 这个叫做/dev/sda2 的PV,被加入到了名叫centos的卷组中,并且这个券组的大小是小于19.51GB 2.创建物理券 pvcreate 磁盘/分区名称 pvcreate /dev/sdc 3.删除物理券 pvremove 磁盘/分区名称 2.…

5.10.3 使用 Transformer 进行端到端对象检测(DETR)

框架的主要成分称为 DEtection TRansformer 或 DETR,是基于集合的全局损失,它通过二分匹配强制进行独特的预测,以及 Transformer 编码器-解码器架构。 DETR 会推理对象与全局图像上下文的关系,以直接并行输出最终的预测集。 1. …

欢乐钓鱼大师自动钓鱼,游戏辅助!

在探索《欢乐钓鱼大师》的世界时,一项备受关注的功能是陀螺仪模式。这是一种利用手机陀螺仪传感器来增强游戏体验的功能,通过模拟真实的钓鱼动作,让玩家更深入地沉浸在游戏的世界中,感受到更加逼真的钓鱼体验。在本篇攻略中&#…

【全开源】JAVA同城组局同城找搭子系统源码支持微信小程序微信公众号H5 APP

让你周末不孤单 发布活动:用户可以发布自己想要进行的活动,包括活动类型、时间、地点等信息,方便其他用户查找和参与。搜索搭档:用户可以根据活动类型、时间、地点等信息,搜索附近的搭档,快速找到志同道合…

Github 2024-05-12 php开源项目日报 Top10

根据Github Trendings的统计,今日(2024-05-12统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量PHP项目10Filament: 加速Laravel开发的完美起点 创建周期:1410 天开发语言:PHP协议类型:MIT LicenseStar数量:12228 个Fork数量:1990 次关…

工程师工具箱系列(2)hasor

文章目录 工程师工具箱系列(2)hasor简介特点环境准备引入依赖数据库脚本文件配置Hasor配置 运行测试小结 工程师工具箱系列(2)hasor 简介 Hasor有着自己的独立的生命周期与Spring的不同,是一套完整的体系,提供了注入DataQL、Dataway、hasor-web等等&am…

半小时搞懂STM32面经知识点——IIC

1.IIC 1.1什么是IIC? 同步半双工通信协议,适用于小数据和短距离传输。 1.2 IIC需要几条线? IIC总共有2条通信总线(SDA,SCL),SCL为时钟同步线,用于主机和从机间数据同步操作;SDA为…

SSM【Spring SpringMVC Mybatis】—— Spring(一)

目录 1、初识Spring 1.1 Spring简介 1.2 搭建Spring框架步骤 1.3 Spring特性 1.5 bean标签详解 2、SpringIOC底层实现 2.1 BeanFactory与ApplicationContexet 2.2 图解IOC类的结构 3、Spring依赖注入数值问题【重点】 3.1 字面量数值 3.2 CDATA区 3.3 外部已声明be…

JVM之运行时数据区

Java虚拟机在运行时管理的内存区域被称为运行时数据区。 程序计数器: 也叫pc寄存器,每个线程会通过程序计数器记录当前要执行的字节码指令的地址。程序计数器在运行时是不会发生内存溢出的,因为每个线程只存储一个固定长度的内存地址。 JAVA虚…

Electron学习笔记(四)

文章目录 相关笔记笔记说明 六、数据1、使用本地文件持久化数据(1) 用户数据目录(2) 读写本地文件(3) 第三方库 2、读写受限访问的 Cookie3、清空浏览器缓存 相关笔记 Electron学习笔记(一)Electron学习笔记(二)Electron学习笔记…

文献阅读——ESG的说与做

The cost of hypocrisy: Does corporate ESG decoupling reduce labor investment efficiency? 主要学习借鉴ESG decoupling 如何构造的!!! 1.引言 在碳峰值和碳中和目标的背景下,尽管上市公司ESG信息披露率不断上升&#xff0c…