强化学习Agent系列(二)——PyGame虚拟环境创建与Python 贪吃蛇Agent制作实战教学

文章目录

  • 一、前言
  • 二、gymnasium 简单虚拟环境创建
    • 1、gymnasium介绍
    • 2、gymnasium 贪吃蛇简单示例
  • 三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent
    • 1、虚拟环境
    • 2、虚拟环境注册
    • 3、训练程序
    • 4、模型测试
  • 三、卷积虚拟环境
    • 1、卷积神经网络虚拟环境
    • 2、训练代码

一、前言

大家好,未来的开发者们请上座
随着人工智能的发展,强化学习基本会再次来到人们眼前,遂想制作一下相关的教程。强化学习第一步基本离不开虚拟环境的搭建,下面用大家耳熟能详的贪吃蛇游戏为基础,制作一个Agent,完成对这个游戏的绝杀。
万里长城第二步:用python开发贪吃蛇智能体****加粗样式

二、gymnasium 简单虚拟环境创建

1、gymnasium介绍

gymnasium(此前称为gym)是一个由 OpenAI 开发的 Python 库,用于开发和比较强化学习算法。它提供了一组丰富的环境,模拟了各种任务,包括但不限于经典的控制问题、像素级游戏、机器人模拟等。

以下是gymnasium库的一些主要特点:

  1. 环境多样性:gymnasium包含了一系列不同的环境,每个环境都有其独特的观察空间(输入)和动作空间(输出)。这些环境涵盖了从简单的文本控制任务到复杂的三维视觉任务的广泛范围。

  2. 标准化API:gymnasium库提供了一个简单且统一的API来与这些环境交互。这使得研究人员和开发人员可以轻松地用相同的代码测试和比较不同的强化学习算法。

  3. 扩展性:用户可以创建自定义环境并将其集成到gymnasium框架中,这使得库能够适应各种不同的研究需求和应用场景。

  4. 评估标准:gymnasium环境通常包括预定义的评估标准,如累积回报或任务完成时间,这有助于在不同算法间进行公平的比较。

  5. 社区支持:由于gymnasium是由OpenAI推出并得到了强化学习社区的广泛支持,因此有大量的教程、论坛讨论和第三方资源可供学习和参考。

  6. 可视化和监控:gymnasium提供了工具来可视化智能体的性能,并允许监控和记录实验过程,便于分析和调试。

使用gymnasium的基本步骤通常包括:

导入gymnasium库。
创建一个环境实例。
初始化环境。
在一个循环中,根据当前观察值选择动作,执行动作,并接收环境的反馈(新的观察值、奖励、完成状态等)。
结束实验并关闭环境。

2、gymnasium 贪吃蛇简单示例

下面是贪吃蛇虚拟环境的一个简单的示例
在本次示例中,暂未进行任何训练。一切行为主要是从状态空间中随机抽取一个动作并执行。
下面是gymnasium创建的虚拟环境的三个核心函数介绍

  • reset(): 这个函数用于重置环境到初始状态,并返回初始状态的观测值。在开始每个新的episode时,通常会调用这个函数来初始化环境。

  • step(action): 这个函数用于让Agent在环境中执行一个动作(action),并返回四个值:观测值(observation),奖励(reward),是否终止(done),以及额外信息(info)。Agent根据环境返回的信息来决定下一步的动作。

  • render(): 这个函数用于在屏幕上渲染当前环境的状态,通常用于可视化环境以便观察Agent的行为。不是所有的环境都支持渲染,具体取决于环境的实现。

下面是具体的示例代码

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
class SnakeEnv(gym.Env):
    def __init__(self):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH=240
        self.SCREEN_HEIGHT=240
        #蛇的方块大小
        self.snakeCell=10
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH,self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')
        self.action_space=gym.spaces.Discrete(4) #动作空间为4
        self.observation_space=gym.spaces.Box(low=0,high=7,shape=(self.SCREEN_WIDTH,self.SCREEN_HEIGHT),dtype=np.uint8)
    
    # 重启
    def reset(self):
        """
        重置蛇和食物的位置
        """
        # 蛇的初始位置
        self.snake_head=[100,50]
        self.snake_body=[[100,50],[100-self.snakeCell,50],[100-self.snakeCell*2,50]]
        self.len=3

        # 食物的初始位置
        self.food_pos=[random.randint(1,self.SCREEN_WIDTH//10-1)*10,random.randint(1,self.SCREEN_HEIGHT//10-1)*10]

        return self._get_observation()

    # 根据当前状态 和action 执行动作
    def step(self,action):
        # 定义动作到方向的映射
        directionDict={'LEFT':[1,0],'RIGHT':[-1,0],'UP':[0,-1],'DOWN':[0,1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }
        directionTarget=action_to_direction[action]
        nextPosDelay=np.array(directionDict[directionTarget])*self.snakeCell #加的位置
        self.snake_head=list(np.array(self.snake_body[0])+nextPosDelay)

        if self.snake_head in self.snake_body:
           return self._get_observation(), 0, True, False, {}
        self.snake_body.insert(0,self.snake_head)

        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_head:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10)) * 10]
            self.len+=1

        # 弹出
        while self.len<len(self.snake_body):
            self.snake_body.pop()
        # 奖励
        reward,done=self._get_reward()
        truncated = True

        return self._get_observation(), reward, truncated, done, {}

    # 渲染
    def render(self,mode="human"):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE= (255,255,255)
        GREEN = (0,255,0)
        RED = (255,0,0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物
        for pos in self.snake_body:
            pygame.draw.rect(screen,GREEN,pygame.Rect(pos[0],pos[1],self.snakeCell,self.snakeCell))
        pygame.draw.rect(screen,RED,pygame.Rect(self.food_pos[0],self.food_pos[1],self.snakeCell,self.snakeCell))

        pygame.display.update()

    # 获取奖励
    def _get_reward(self):

        # 计算奖励
        reward = 0
        done = False

        # 检查蛇是否吃到食物
        if self.snake_head:
            reward+=10

        # 检查蛇是否撞到墙壁或自身
        head=self.snake_head
        if head[0]<0 or head[0]>self.SCREEN_WIDTH-10 or head[1]<0 or head[1]>self.SCREEN_HEIGHT-10:
            reward = -10
            done = True

        return reward,done

    # 获取当前观察空间
    def _get_observation(self):
        # 获取窗口内容作为观察值
        observation = pygame.display.get_surface()
        # 将观察值调整为指定的宽度和高度
        # observation = pygame.transform.scale(observation, (self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        return observation

def main():
    snakeEnv=Snake()
    snakeEnv.reset()
    done=False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action= snakeEnv.action_space.sample()
        screen, reward, truncated, done,_=snakeEnv.step(action)
        snakeEnv.render()
        time.sleep(0.03)
if __name__=="__main__":
    main()

程序运行截图:
在这里插入图片描述

三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent

在上一步中,你已经创建出你要的虚拟环境了,现在让我们在这个创建好的环境中进行训练吧!

1、虚拟环境

SnakeEnv2.py

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optional


class SnakeEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode="human"):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH = 100
        self.SCREEN_HEIGHT = 100
        # 蛇的方块大小
        self.snakeCell = 10
        # 游戏速度
        self.speed = 12
        self.clock = pygame.time.Clock()
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')

        self.render_mode = render_mode
        self.action_space = gym.spaces.Discrete(4)  # 动作空间为4
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT),
                                                dtype=np.float32)
        # 初始化蛇和食物的位置等属性
        # ...
        self.num_timesepts = 0

    # 重启
    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        重置蛇和食物的位置
        """
        super().reset(seed=seed)
        self.curStep = 0  # 步数
        # 蛇的初始位置
        self.snake_next = [60, 50]
        self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]
        self.len = 3

        # 食物的初始位置
        self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,
                         random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]
        info = {}

        #
        return self._get_observation(), info

    # 根据当前状态 和action 执行动作
    def step(self, action):
        self.num_timesepts += 1  # 步骤统计加1
        # 定义动作到方向的映射
        directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }

        directionTarget = action_to_direction[action]
        nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置
        # 输入action 获取到 snake_next 下一步
        self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)
        if self.snake_next == self.snake_body[1]:
            return self._get_observation(), -0.5, False, False, {}

        # 奖励
        reward, terminated = self._get_reward()
        
        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_next:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]
            self.len += 1




        truncated = False
        info = {}
        # if self.render_mode == "human":
        if self.render_mode == "human" and self.num_timesepts % 5000 > 4000 and self.num_timesepts > 10000:
            self.render()
            for event in pygame.event.get():
                if event == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
        
        if not terminated:
            self.snake_body.insert(0, self.snake_next)
            
        # 弹出
        while self.len < len(self.snake_body):
            self.snake_body.pop()
        return self._get_observation(), reward,terminated,  truncated,{}

    # 渲染
    def render(self):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE = (255, 255, 255)
        GREEN = (0, 255, 0)
        RED = (255, 0, 0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物

        snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)
        for i in range(len(self.snake_body)):
            pos = self.snake_body[len(self.snake_body) - i - 1]
            color = [int(round(component * snakecolor[i])) for component in GREEN]
            pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))
        pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))

        pygame.display.update()

        self.clock.tick(self.speed)
    def GetDic(self,p1,p2):
        return np.linalg.norm(np.array(p1) - np.array(p2))
    # 获取奖励
    def _get_reward(self):
    
        # 计算奖励
        self.curStep += 1  # 步数
        reward = 0
        terminated = False
        flag=0
        # 正向激励
        # 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励
        if self.snake_next == self.food_pos:
            reward += 500 + pow(5, self.len)
            self.curStep = 0
            #print(reward)
        # 负向激励
        # 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励
        head = self.snake_next
        if head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[
            1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:
            reward -= 100 / self.len
            terminated = True
            self.curStep = 0
        
        # 摸鱼步数超过一定值就开始负向奖励
        if self.curStep > 100 * self.len:
            reward -=  1 / self.len

      
      
        # 中向激励
        if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):
            reward += 2 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_body[0],self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.
        else:
            reward -= 1 / self.len
        #print(reward * 0.3)
        if reward<0:
            #print(reward * 0.2)
            pass
     
        #print(reward* 0.2)
        return reward * 0.2, terminated

    # 获取当前观察空间
    def _get_observation(self):
        # 返回观察空间,也就是一个二维数组
        obs = np.zeros((self.SCREEN_WIDTH, self.SCREEN_HEIGHT), dtype=np.float32)
        obs[tuple(np.transpose(self.snake_body))] = np.linspace(0.8, 0.2, len(self.snake_body), dtype=np.float32)
        obs[tuple(self.snake_body[0])] = 1.0
        obs[tuple(self.food_pos)] = -1.0
        return obs


def main():
    snakeEnv = SnakeEnv()
    snakeEnv.reset()
    done = False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action = snakeEnv.action_space.sample()
        screen, reward, truncated, done, _ = snakeEnv.step(action)
        snakeEnv.render()


if __name__ == "__main__":
    main()

2、虚拟环境注册

打开当前项目的site-packages
在这里插入图片描述
找到gymnasium
将其SnakeEnv2.py放置如下,并在init.py中添加调用注册函数
在这里插入图片描述
到这里就注册完毕,可以进行训练了

3、训练程序

snake_train.py 具体代码如下

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):

    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        assert (initial_value > 0.0)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler

# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnv-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-2, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

model = PPO("MlpPolicy", env, verbose=1, device="cuda",

        n_steps=2048,
        batch_size=512,
        n_epochs=4,
        gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            )
# 4、训练模型:

# Set the save directory
num=1
save_dir="trained_models_mlp"
while True:
    save_dir = "trained_models_mlp_{}".format(num)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
        break
    else:
        num +=1

checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")

# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)


model.learn(
    total_timesteps=int(200000),
    callback=[checkpoint_callback]
)


# Restore stdout
sys.stdout = original_stdout

# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))

4、模型测试

对训练好的模型进行测试,可以用如下代码

import time
import random
from sb3_contrib import MaskablePPO
from stable_baselines3 import PPO
from snakecnn23 import SnakeEnv
import pygame

MODEL_PATH=r'H:\AILab\RL\Snaker2\trained_models_cnn\ppo_snake_final'
# Load the trained model
model = MaskablePPO.load(MODEL_PATH)

snakeEnv = SnakeEnv()
for i in range(10):
    obs,info=snakeEnv.reset()
    terminated = False

    while not terminated:
        # 获取事件
        for event in pygame.event.get():
            if event == pygame.QUIT:
                pygame.quit()
        # 从动作空间随机获取一个动作
        action ,_=  model.predict(obs, action_masks=snakeEnv.get_action_mask())

        prev_mask = snakeEnv.get_action_mask()
        action_value=int(action.item())

        obs, reward,  terminated, truncated, _ = snakeEnv.step(action_value)
        snakeEnv.render()

三、卷积虚拟环境

上面的是基于多层感知机,上限有限,可能效果不是很好,可以对其进行一点点改进
核心是修改
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)
和 _get_observation() 观察空间
改完这两个其他的基本不用变

1、卷积神经网络虚拟环境

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optional


class SnakeEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode="human"):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH = 84
        self.SCREEN_HEIGHT = 84
        # 蛇的方块大小
        self.snakeCell = 7
        # 游戏速度
        self.speed = 12
        self.clock = pygame.time.Clock()
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')

        self.render_mode = render_mode
        self.action_space = gym.spaces.Discrete(4)  # 动作空间为4
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),
                                                dtype=np.uint8)
        # 初始化蛇和食物的位置等属性
        # ...
        self.num_timesepts = 0

    # 重启
    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        重置蛇和食物的位置
        """
        super().reset(seed=seed)
        self.curStep = 0  # 步数
        # 蛇的初始位置
        self.snake_next = [60, 50]
        self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]
        self.len = 3

        # 食物的初始位置
        self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,
                         random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]
        info = {}

        #
        return self._get_observation(), info

    # 根据当前状态 和action 执行动作
    def step(self, action):
        self.num_timesepts += 1  # 步骤统计加1
        # 定义动作到方向的映射
        directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }

        directionTarget = action_to_direction[action]
        nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置
        # 输入action 获取到 snake_next 下一步
        self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)
        if self.snake_next == self.snake_body[1]:
            return self._get_observation(), -0.5, False, False, {}

        # 奖励
        reward, terminated = self._get_reward(action)
        
        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_next:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]
            self.len += 1




        truncated = False
        info = {}
        # if self.render_mode == "human":
        if self.render_mode == "human" and self.num_timesepts % 5000 > 4500 and self.num_timesepts > 10000:
            self.render()
            for event in pygame.event.get():
                if event == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
        
        if not terminated:
            self.snake_body.insert(0, self.snake_next)
            
        # 弹出
        while self.len < len(self.snake_body):
            self.snake_body.pop()
        return self._get_observation(), reward,terminated,  truncated,{}

    # 渲染
    def render(self):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE = (255, 255, 255)
        GREEN = (0, 255, 0)
        RED = (255, 0, 0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物

        snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)
        for i in range(len(self.snake_body)):
            pos = self.snake_body[len(self.snake_body) - i - 1]
            color = [int(round(component * snakecolor[i])) for component in GREEN]
            pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))
        pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))

        pygame.display.update()

        self.clock.tick(self.speed)
    def GetDic(self,p1,p2):
        return np.linalg.norm(np.array(p1) - np.array(p2))
    # 获取奖励
    def _get_reward(self,action):
    
        # 计算奖励
        self.curStep += 1  # 步数
        reward = 0
        terminated = False
        flag=0
        # 正向激励
        # 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励
        if self.snake_next == self.food_pos:
            reward += 400 + pow(5, self.len)
            self.curStep = 0
           # print(reward)
            #print(action)
           # print(self.snake_body,self.food_pos,self.snake_next)
        # 负向激励
        # 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励
        head = self.snake_next
        if head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[
            1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:
            reward -= 200 / self.len
            terminated = True
            self.curStep = 0
        
        # 摸鱼步数超过一定值就开始负向奖励
        if self.curStep > 250 * self.len:
            reward -=  1 / self.len

      
      
        # 中向激励
        if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):
            reward += 4 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_next,self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.
        elif self.curStep>50 and self.GetDic(self.snake_next,self.food_pos)>= self.GetDic(self.snake_body[0],self.food_pos):
            reward -= 2 / self.len * self.GetDic(self.snake_next,self.food_pos) /self.SCREEN_WIDTH
        #print(reward * 0.3)
        if reward<0:
            #print(reward * 0.2)
            pass
     
       # print(reward* 0.2)
        return reward * 0.2, terminated

    # 获取当前观察空间
    def _get_observation(self):
        obs = np.zeros((self.SCREEN_WIDTH//self.snakeCell, self.SCREEN_HEIGHT//self.snakeCell), dtype=np.uint8)

        # Set the snake body to gray with linearly decreasing intensity from head to tail.
        newsnake=np.array(self.snake_body)//7
        obs[tuple(np.transpose(newsnake))] = np.linspace(200, 50, len(newsnake), dtype=np.uint8)

        # Stack single layer into 3-channel-image.
        obs = np.stack((obs, obs, obs), axis=-1)

        # Set the snake head to green and the tail to blue
        obs[tuple(newsnake[0])] = [0, 255, 0]
        obs[tuple(newsnake[-1])] = [255, 0, 0]

        # Set the food to red
        obs[np.array(self.food_pos)//7] = [0, 0, 255]

        # Enlarge the observation to 84x84
        obs = np.repeat(np.repeat(obs, self.snakeCell, axis=0), self.snakeCell, axis=1)

        return obs
        


def main():
    snakeEnv = SnakeEnv()
    snakeEnv.reset()
    done = False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action = snakeEnv.action_space.sample()
        screen, reward, truncated, done, _ = snakeEnv.step(action)
        snakeEnv.render()


if __name__ == "__main__":
    main()

2、训练代码

核心是修改算法名,之前用MlpPolicy,现在改为CnnPolicy
其余不变

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler

from stable_baselines3 import PPO
def linear_schedule(initial_value, final_value=0.0):

    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        assert (initial_value > 0.0)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnvcnn-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-3, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

model = PPO(  "CnnPolicy",
            env,
            device="cuda",
            verbose=1,
            n_steps=2048,
            batch_size=512,
            n_epochs=4,
            gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            tensorboard_log=LOG_DIR
            )
# 4、训练模型:

# Set the save directory
num=1
save_dir="trained_models_cnn"
while True:
    save_dir = "trained_models_cnn_{}".format(num)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
        break
    else:
        num +=1

checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")

# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)


model.learn(
    total_timesteps=int(200000),
    callback=[checkpoint_callback]
)


# Restore stdout
sys.stdout = original_stdout

# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))


# 5、测试训练好的模型:
obs = env.reset()

for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated:
        obs = env.reset()
env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/417931.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

300分钟吃透分布式缓存(拉钩教育总结)

开篇寄语 开篇寄语&#xff1a;缓存&#xff0c;你真的用对了吗&#xff1f; 你好&#xff0c;我是你的缓存老师陈波&#xff0c;可能大家对我的网名 fishermen 会更熟悉。 我是资深老码农一枚&#xff0c;经历了新浪微博从起步到当前月活数亿用户的大型互联网系统的技术演进…

NebulaGraph入门

感谢阅读 官方文档链接NebulaGraph简介nGQLnGQL简介占位标识符和占位符值注释实列大小写区分关键字 基本概念以及相关代码实现补充说明图空间语法以及列子创建克隆官方示例代码(创建并克隆)USE语句指定图空间时查看所有SPACESPACE详情CLEAR SPACE删库跑路&#xff08;看玩笑的说…

C语言:字符函数 字符串函数 内存函数

C语言&#xff1a;字符函数 & 字符串函数 & 内存函数 字符函数字符分类函数字符转换函数tolowertoupper 字符串函数strlenstrcpystrcatstrcmpstrstrstrtok 内存函数memcpymemmovememsetmemcmp 字符函数 顾名思义&#xff0c;字符函数就是作用于字符的函数&#xff0c;…

3dgs学习(二)—— 3d高斯与协方差矩阵及其几何意义

协方差矩阵与3d高斯 3d高斯与椭球与协方差矩阵 3d高斯&#xff0c;及3维空间内的正态分布。 通过一元正态分布的坐标系图像不难想象&#xff0c;3维空间中的正态分布点集中在一片椭球空间中&#xff0c;各方向长轴取决于各方向正态分布的方差。 而协方差矩阵通过计算多元之…

好物周刊#42:国产项目管理软件

https://github.com/cunyu1943 村雨遥的好物周刊&#xff0c;记录每周看到的有价值的信息&#xff0c;主要针对计算机领域&#xff0c;每周五发布。 一、项目 1. 菠萝博客 基于 Java 的菠萝博客系统&#xff0c;简单易部署&#xff0c;精致主题&#xff0c;贴心服务&#xf…

本届挑战赛亚军方案:面向微服务架构系统中无标注、多模态运维数据的异常检测、根因定位与可解释性分析

CheerX团队来自于南瑞研究院系统平台研发中心&#xff0c;中心主要从事NUSP电力自动化通用软件平台的关键技术研究与软件研发。 选题分析 图1 研究现状 本次CheerX团队的选题紧密贴合了目前的运维现状。实际运维中存在多种问题导致运维系统的不可用。比如故障发生时&#xff…

【常用】【测速】ptflops库---速度FPS、参数Params、计算复杂度Flops

一、常用名字 中文名字 英文名字 简称 单位 模型参数量 number of parameters. param. (单位B M) 计算复杂度 computational…

【Spring Cloud 进阶】OpenFeign 底层原理解析

参考文章 万字33张图探秘OpenFeign核心架构原理 | 三友SpringCloud OpenFeign源码详细解析Java 代理机制 OpenFeign 是一个精彩的使用动态代理技术的典型案例&#xff0c;通过分析其底层实现原理&#xff0c;我们可以对动态代理技术有进一步的理解。 目录 1. Feign 与 OpenFeig…

VUE3:统计分析页面布局+自适应页面参考

一、布局 <template><div class"container1"><div class"form white"><el-form :inline"true" :rules"rules" :model"queryParams" label-width"80px" ref"querParmRef"><e…

力扣递归:路径总和

思路&#xff1a;此题思路为递归实现&#xff0c;递归思路为&#xff1a;在每层递归的过程中将各个节点的数据记录下来&#xff0c;不断将减少目标数据的值准备进行判断&#xff0c;当进行到叶子节点时要进行判断 /*** Definition for a binary tree node.* struct TreeNode {…

OJ_二叉树最短路径长度

题干 C实现 #define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<vector> using namespace std;struct TreeNode {int num;TreeNode* left;TreeNode* right;TreeNode* parent; };void createTree(vector<TreeNode*>& nodeArr, int n) {for (i…

2000-2022年上市公司绿色专利申请占比/数据

2000-2022年上市公司绿色专利申请占比数据 1、时间&#xff1a;2000-2022年 2、来源&#xff1a;国家知识产权局、WIPO绿色专利清单 3、指标&#xff1a;年份、股票代码、股票简称、行业代码、省份、城市、区县、行政区划代码、城市代码、区县代码、首次上市年份、上市状态、…

又降价啦!2024年阿里云核心产品价格全线下调,最高幅度达55%

2024年3月1日开始&#xff0c;阿里云将开启新一轮的降价政策&#xff0c;核心产品价格全线下调&#xff0c;平均降幅20%&#xff0c;最高幅度达55%&#xff0c;阿里云希望通过此次大规模降价&#xff0c;让更多企业和开发者用上先进的公共云服务&#xff0c;加速云计算在中国各…

深度学习 精选笔记(8)梯度消失和梯度爆炸

学习参考&#xff1a; 动手学深度学习2.0Deep-Learning-with-TensorFlow-bookpytorchlightning ①如有冒犯、请联系侵删。 ②已写完的笔记文章会不定时一直修订修改(删、改、增)&#xff0c;以达到集多方教程的精华于一文的目的。 ③非常推荐上面&#xff08;学习参考&#x…

备战蓝桥杯Day19 - 堆排序基础知识

一、每日一题 - 填充 详细题解 s input() # 输入字符串 n len(s) # 定义字符的长度 judge ["00", "11", "0?", "1?", "?0", "?1", "??"] # 把所有的情况一一列举出来 count 0 # 设置计数…

【Python】PyGameUI控件

哈里前段时间写了一个windows平板上自娱自乐&#xff08;春节和家人一起玩&#xff09;基于pygame的大富翁游戏。 pygame没有按钮之类的UI控件&#xff0c;写起来不怎么顺手。就自己写一个简单的框架。 仓库地址 哈里PygameUi: pygame ui封装自用 (gitee.com) 使用示例 示…

Tomcat 软件和配置文件 基本介绍

一 &#xff0c;web知识 简介 &#xff08;一&#xff09;web技术 1&#xff0c;http协议和 B/S &#xff08;Browser/Server&#xff09;结构 最早出现了CGl (Common Gateway Interface)通用网关接口&#xff0c;通过浏览器中输入URL直接映射到一个服务器端的脚本程序执行&…

从单体服务到微服务:多模式 Web 应用开发记录<三>预初始化属性

相关文章&#xff1a; 多模式 Web 应用开发记录<一>背景&全局变量优化多模式 Web 应用开发记录<二>自己动手写一个 Struts 开头先看一个简单的例子&#xff0c;这是 ftl 文件的一个表单&#xff1a; <form id"validateForm" action"#&quo…

【程序员是如何看待“祖传代码”的?】《代码的遗产:探索程序员眼中的“祖传代码”》

程序员是如何看待“祖传代码”的&#xff1f; 在程序员的世界里&#xff0c;代码不仅仅是构建软件的基石&#xff0c;它们也承载着历史、智慧和技术的演变。在我的编程生涯中&#xff0c;我遇到过许多神奇而独特的“祖传代码”&#xff0c;这些代码如同古老的魔法书&#xff0…

【C语言】三子棋

前言&#xff1a; 三子棋是一种民间传统游戏&#xff0c;又叫九宫棋、圈圈叉叉棋、一条龙、井字棋等。游戏规则是双方对战&#xff0c;双方依次在9宫格棋盘上摆放棋子&#xff0c;率先将自己的三个棋子走成一条线就视为胜利。但因棋盘太小&#xff0c;三子棋在很多时候会出现和…