20250220-代码笔记01-class CVRPEnv

文章目录

  • 前言
  • 一、def __init__(self, **env_params):
    • 函数功能
    • 函数代码
  • 二、use_saved_problems(self, filename, device)
    • 函数功能
    • 函数代码
  • 三、load_problems(self, batch_size, aug_factor=1)
    • 函数功能
    • 函数代码
    • use_saved_problems 与 load_problems 之间的关系
  • 四、reset(self)
    • 函数功能
    • 函数代码
  • 五、pre_step(self)
    • 函数功能
    • 函数代码
  • 六、step(self, selected)
    • 函数功能
    • 函数代码
  • 七、_get_travel_distance(self)
    • 函数功能
    • 问题
      • 什么是“滚动”?
    • 函数代码
  • 附件
    • 代码(全):CVRPEnv.py
    • 代码:一、def __init__(self, **env_params)


前言

对CVRPEnv.py中的类(class CVRPEnv)代码的学习。
代码地址如下:

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py


一、def init(self, **env_params):

函数功能

这段代码是CVRPEnv类的初始化方法,主要用于初始化与**车辆路径问题(CVRP)**环境相关的各个参数和变量。

参数思维导图链接
在这里插入图片描述

函数代码

    def __init__(self, **env_params):

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  #提取问题规模
        self.pomo_size = env_params['pomo_size']        #POMO 智能体数量

        self.FLAG__use_saved_problems = False           #设置是否使用保存的问题实例
        self.saved_depot_xy = None                      #配送中心(depot)的坐标
        self.saved_node_xy = None                       #节点(客户或城市)的坐标
        self.saved_node_demand = None                   #保存节点的需求量
        self.saved_index = None                         #保存节点的索引

        # Const @Load_Problem
        ####################################
        self.batch_size = None  
        self.BATCH_IDX = None   
        self.POMO_IDX = None    
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+1)
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None

        self.regret_mask_matrix = None
        self.add_mask_matrix = None

        self.time_step=0 


二、use_saved_problems(self, filename, device)

函数功能

函数的功能是加载预先保存的问题实例,并将这些问题实例的数据保存到类的属性中。
它会从指定的文件中读取问题数据,包括配送中心的位置(depot_xy)节点的位置(node_xy)节点的需求量(node_demand),然后将这些数据存储在类的属性中,以供后续使用。

函数思维导图链接
在这里插入图片描述

函数代码

 def use_saved_problems(self, filename, device):                
        self.FLAG__use_saved_problems = True 

        loaded_dict = torch.load(filename, map_location=device) #加载保存的问题实例
        self.saved_depot_xy = loaded_dict['depot_xy']           #解析加载的数据
        self.saved_node_xy = loaded_dict['node_xy']             #
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0


三、load_problems(self, batch_size, aug_factor=1)

函数功能

该函数用于加载**车辆路径问题(CVRP)**实例,包括:

  1. 动态生成问题实例 或 从预加载数据中提取问题
  2. 数据增强
  3. 初始化索引和状态变量
  4. 存储到环境变量

工作方式

  • 如果 self.FLAG__use_saved_problemsTrue,则从通过 use_saved_problems 加载的预先保存的问题实例中提取数据(self.saved_depot_xy, self.saved_node_xy, self.saved_node_demand),并更新索引 self.saved_index
  • 如果 self.FLAG__use_saved_problemsFalse,则动态生成问题实例。使用 get_random_problems() 方法生成指定 batch_sizeproblem_size 的问题数据。
  • load_problems 还支持数据增强,通过指定 aug_factor 来增强生成的数据(目前仅支持 aug_factor=8),扩展批次数量并改变问题实例的坐标和需求。

函数功能思维导图链接
在这里插入图片描述

函数代码

 def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        #加载问题实例
        if not self.FLAG__use_saved_problems:
            #动态生成模式
            depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
        else:
            #预加载模式,从保存的实例数据中提取问题
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
            self.saved_index += batch_size

        #数据增强
        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError
            
        #合并配送中心和节点数据
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        #初始化批量索引和 POMO 索引
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        #更新重置状态和步骤状态
        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.node_demand = node_demand

        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX

use_saved_problems 与 load_problems 之间的关系

  • use_saved_problems 作为数据加载的前置条件

    • use_saved_problems 主要负责加载已经保存好的问题实例文件(比如一个torch.save()保存的文件),并将这些数据存储到环境中的特定变量中(例如 self.saved_depot_xyself.saved_node_xy)。

    • 一旦执行了use_saved_problems,它设置了 self.FLAG__use_saved_problems = True,这意味着在后续的操作中,环境会从保存的数据中加载问题实例,而不是重新生成问题。

    • 但是use_saved_problems 本身并不负责加载具体的问题实例数据它只是为后续的加载操作(如 load_problems)提供了指示标志

  • load_problems使用 use_saved_problems 加载的数据:

    • load_problems执行数据加载和问题生成的主函数,它根据 self.FLAG__use_saved_problems 的值,决定是从保存的数据中提取问题实例,还是生成新的随机问题实例。
    • self.FLAG__use_saved_problems = True 时,load_problems 会从 self.saved_depot_xyself.saved_node_xyself.saved_node_demand 等变量中读取数据,并根据需要为每个批次的问题实例做进一步处理(如索引的更新、数据增强等)。
    • 如果 self.FLAG__use_saved_problems = False,则 load_problems 会使用 get_random_problems() 来动态生成问题数据。

四、reset(self)

函数功能

reset 函数的主要目的是将环境的状态变量重置为初始值,通常在每个新的训练回合或实验开始时调用。该函数确保环境处于一个已知的初始状态,以便智能体能够从一个干净的状态开始进行决策和学习。

函数参数思维导图
在这里插入图片描述

函数代码

 def reset(self):
        #重置选择计数
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)

        #重置当前节点
        self.current_node = None
        # shape: (batch, pomo)  

        #重置已选择的节点列表
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        #初始化是否在配送中心
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        # 初始化负载
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)

        #初始化访问掩码
        self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+1)

        #初始化负无穷掩码
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+1)

        #初始化完成状态
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        #初始化其他状态变量
        self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
        self.mode = torch.full((self.batch_size, self.pomo_size), 0)
        self.last_current_node = None
        self.last_load = None
        self.time_step=0

        reward = None
        done = False
        return self.reset_state, reward, done


五、pre_step(self)

函数功能

pre_step 函数是环境中的一个预处理步骤,用于在每个时间步之前设置必要的状态信息。
通常,在强化学习环境中,每个时间步会根据当前状态和动作进行更新,pre_step 函数则为每个时间步提供所需的状态,供后续的决策和学习过程使用。

函数功能思维导图
在这里插入图片描述

函数代码

    def pre_step(self):
        #重置 selected_count
        self.step_state.selected_count = 0
        #复制当前负载
        self.step_state.load = self.load
        #设置当前节点
        self.step_state.current_node = self.current_node
        #更新掩码状态
        self.step_state.ninf_mask = self.ninf_mask
        
        #返回步骤状态、奖励和完成标志
        reward = None
        done = False
        return self.step_state, reward, done


六、step(self, selected)

函数功能

这个函数的主要功能是在每个时间步(step)中更新智能体的状态,执行任务、处理负载、选择节点等,最终返回当前的状态、奖励和是否完成任务的标志。

函数功能与参数的思维导图链接

在这里插入图片描述

函数代码

def step(self, selected):
        # selected.shape: (batch, pomo)

        #时间步数控制
        if self.time_step<4:

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)

            #特定时间步的操作
            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
            
            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask


        #时间步大于等于 4 的复杂操作
        else:
            #动作模式分类
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)

            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            #更新选择计数
            self.selected_count = self.selected_count+1
            #后悔模式
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            #节点更新
            self.last_is_depot = (self.last_current_node == 0)

            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            #更新已选择节点列表
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            #更新负载
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            #扩展需求列表 demand_list 
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load= self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0


            # 更新负无穷掩码
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+1)
            self.ninf_mask[demand_too_large] = float('-inf')

            # 更新完成状态
            # 检查哪些智能体已经完成所有节点的访问。
            # 更新完成标记 self.finished。
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            #更新模式
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # 更新完成后的掩码调整
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # 更新步骤状态
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask



        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done


七、_get_travel_distance(self)

函数功能

_get_travel_distance 函数的主要功能是计算每个智能体(POMO智能体)在每个时间步所选择的节点之间的旅行距离。

函数参数和流程图链接

在这里插入图片描述

问题

什么是“滚动”?

“滚动”是对张量或数组进行操作的一种方式,它通过沿特定维度(通常是时间维度)移动元素,从而生成一个新的数组或张量。

例子
设我们有一个一维张量表示时间步的节点选择情况:

tensor = torch.tensor([1, 2, 3, 4, 5])

如果我们对这个张量进行滚动操作,沿着时间维度向右滚动1步:

rolled_tensor = tensor.roll(dims=0, shifts=1)

这时,rolled_tensor 将变成:

tensor([5, 1, 2, 3, 4])

函数代码

  def _get_travel_distance(self):

        m1 = (self.selected_node_list==self.problem_size+1)
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        m3 = m1.roll(dims=2, shifts=1)
        m4 = ~(m2|m3)

        selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)

        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)

        travel_distances = torch.zeros((self.batch_size, self.pomo_size))

        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:,:,t].unsqueeze(2)).nonzero()
            add3_index = (m3[:,:,t].unsqueeze(2)).nonzero()

            travel_distances[add1_index[:,0],add1_index[:,1]] = travel_distances[add1_index[:,0],add1_index[:,1]].clone()+((self.depot_node_xy[add1_index[:,0],self.selected_node_list[add1_index[:,0],add1_index[:,1],t],:]-self.depot_node_xy[add1_index[:,0],selected_node_list_right[add1_index[:,0],add1_index[:,1],t],:])**2).sum(1).sqrt()

            travel_distances[add3_index[:,0],add3_index[:,1]] = travel_distances[add3_index[:,0],add3_index[:,1]].clone()+((self.depot_node_xy[add3_index[:,0],self.selected_node_list[add3_index[:,0],add3_index[:,1],t],:]-self.depot_node_xy[add3_index[:,0],selected_node_list_right2[add3_index[:,0],add3_index[:,1],t],:])**2).sum(1).sqrt()



        return travel_distances



附件

代码(全):CVRPEnv.py

返回:前言

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPEnv.py


from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    depot_xy: torch.Tensor = None
    # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None
    # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None
    # shape: (batch, problem)


@dataclass
class Step_State:
    BATCH_IDX: torch.Tensor = None      #表示批次的索引
    POMO_IDX: torch.Tensor = None       #表示 POMO 算法中的多智能体索引
    # shape: (batch, pomo)
    selected_count: int = None          #表示当前已经选中的节点数量
    load: torch.Tensor = None           #表示当前负载状态
    # shape: (batch, pomo)
    current_node: torch.Tensor = None   #表示当前正在访问的节点编号
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None      #表示负无穷掩码
    # shape: (batch, pomo, problem+1)


class CVRPEnv:               
    def __init__(self, **env_params):

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  #提取问题规模
        self.pomo_size = env_params['pomo_size']        #POMO 智能体数量

        self.FLAG__use_saved_problems = False           #设置是否使用保存的问题实例
        self.saved_depot_xy = None                      #配送中心(depot)的坐标
        self.saved_node_xy = None                       #节点(客户或城市)的坐标
        self.saved_node_demand = None                   #保存节点的需求量
        self.saved_index = None                         #保存节点的索引

        # Const @Load_Problem
        ####################################
        self.batch_size = None  
        self.BATCH_IDX = None   
        self.POMO_IDX = None    
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+1)
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None

        self.regret_mask_matrix = None
        self.add_mask_matrix = None

        self.time_step=0

    #加载保存的问题实例数据 
    def use_saved_problems(self, filename, device):                
        self.FLAG__use_saved_problems = True 

        loaded_dict = torch.load(filename, map_location=device) #加载保存的问题实例
        self.saved_depot_xy = loaded_dict['depot_xy']           #解析加载的数据
        self.saved_node_xy = loaded_dict['node_xy']             #
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0

    def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        #加载问题实例
        if not self.FLAG__use_saved_problems:
            #动态生成模式
            depot_xy, node_xy, node_demand = get_random_problems(batch_size, self.problem_size)
        else:
            #预加载模式,从保存的实例数据中提取问题
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index+batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index+batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index+batch_size]
            self.saved_index += batch_size

        #数据增强
        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError
            
        #合并配送中心和节点数据
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        #初始化批量索引和 POMO 索引
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        #更新重置状态和步骤状态
        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.node_demand = node_demand

        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX

    def reset(self):
        #重置选择计数
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)

        #重置当前节点
        self.current_node = None
        # shape: (batch, pomo)  

        #重置已选择的节点列表
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        #初始化是否在配送中心
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        # 初始化负载
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)

        #初始化访问掩码
        self.visited_ninf_flag = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.visited_ninf_flag[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+1)

        #初始化负无穷掩码
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size+2))
        self.ninf_mask[:, :, self.problem_size+1] = float('-inf')
        # shape: (batch, pomo, problem+1)

        #初始化完成状态
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        #初始化其他状态变量
        self.regret_count = torch.zeros((self.batch_size, self.pomo_size))
        self.mode = torch.full((self.batch_size, self.pomo_size), 0)
        self.last_current_node = None
        self.last_load = None
        self.time_step=0

        reward = None
        done = False
        return self.reset_state, reward, done

    def pre_step(self):
        #重置 selected_count
        self.step_state.selected_count = 0
        #复制当前负载
        self.step_state.load = self.load
        #设置当前节点
        self.step_state.current_node = self.current_node
        #更新掩码状态
        self.step_state.ninf_mask = self.ninf_mask
        
        #返回步骤状态、奖励和完成标志
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        # selected.shape: (batch, pomo)

        #时间步数控制
        if self.time_step<4:

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)

            #特定时间步的操作
            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
            
            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask


        #时间步大于等于 4 的复杂操作
        else:
            #动作模式分类
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)

            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            #更新选择计数
            self.selected_count = self.selected_count+1
            #后悔模式
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            #节点更新
            self.last_is_depot = (self.last_current_node == 0)

            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            #更新已选择节点列表
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            #更新负载
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            #扩展需求列表 demand_list 
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load= self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0


            # 更新负无穷掩码
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+1)
            self.ninf_mask[demand_too_large] = float('-inf')

            # 更新完成状态
            # 检查哪些智能体已经完成所有节点的访问。
            # 更新完成标记 self.finished。
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            #更新模式
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # 更新完成后的掩码调整
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # 更新步骤状态
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask



        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

    def _get_travel_distance(self):

        m1 = (self.selected_node_list==self.problem_size+1)
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        m3 = m1.roll(dims=2, shifts=1)
        m4 = ~(m2|m3)

        selected_node_list_right = self.selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.selected_node_list.roll(dims=2, shifts=3)

        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)

        travel_distances = torch.zeros((self.batch_size, self.pomo_size))

        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:,:,t].unsqueeze(2)).nonzero()
            add3_index = (m3[:,:,t].unsqueeze(2)).nonzero()

            travel_distances[add1_index[:,0],add1_index[:,1]] = travel_distances[add1_index[:,0],add1_index[:,1]].clone()+((self.depot_node_xy[add1_index[:,0],self.selected_node_list[add1_index[:,0],add1_index[:,1],t],:]-self.depot_node_xy[add1_index[:,0],selected_node_list_right[add1_index[:,0],add1_index[:,1],t],:])**2).sum(1).sqrt()

            travel_distances[add3_index[:,0],add3_index[:,1]] = travel_distances[add3_index[:,0],add3_index[:,1]].clone()+((self.depot_node_xy[add3_index[:,0],self.selected_node_list[add3_index[:,0],add3_index[:,1],t],:]-self.depot_node_xy[add3_index[:,0],selected_node_list_right2[add3_index[:,0],add3_index[:,1],t],:])**2).sum(1).sqrt()



        return travel_distances




代码:一、def init(self, **env_params)

    def __init__(self, **env_params):

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']  #提取问题规模
        self.pomo_size = env_params['pomo_size']        #POMO 智能体数量

        self.FLAG__use_saved_problems = False           #设置是否使用保存的问题实例
        self.saved_depot_xy = None                      #配送中心(depot)的坐标
        self.saved_node_xy = None                       #节点(客户或城市)的坐标
        self.saved_node_demand = None                   #保存节点的需求量
        self.saved_index = None                         #保存节点的索引

        # Const @Load_Problem
        ####################################
        self.batch_size = None  
        self.BATCH_IDX = None   
        self.POMO_IDX = None    
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+1)
        self.finished = None
        # shape: (batch, pomo)

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

        # regret
        ####################################
        self.mode = None
        self.last_current_node = None
        self.last_load = None
        self.regret_count = None

        self.regret_mask_matrix = None
        self.add_mask_matrix = None

        self.time_step=0 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/975136.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

041集——封装之:新建图层(CAD—C#二次开发入门)

如图所示&#xff1a;增加一个图层“新图层”&#xff0c;颜色为红&#xff08;1&#xff09;&#xff0c;当图层颜色定义为黄&#xff08;2&#xff09;时&#xff0c;直接覆盖之前图层颜色&#xff0c;图层名不变。 代码如下&#xff1a; /// </summary>/// <param …

极简入门,本地部署dify低代码平台构建AI Agent大模型全流程(使用教程、微案例、配置详解、架构图解析)

文章目录 一、环境搭建1.1 安装VMware-workstationCentOS7.91.2 安装宝塔1.3 安装docker及改镜像、安装dify1.4 配置模型供应商 二、dify快速上手体验2.1 知识库2.2 微案例&#xff1a;基于知识库的助手 三、dify知识库配置详解3.1 分片策略3.2 父子分段3.3 索引方法3.4 检索结…

STM32-心知天气项目

一、项目需求 使用 ESP8266 通过 HTTP 获取天气数据&#xff08;心知天气&#xff09;&#xff0c;并显示在 OLED 屏幕上。 按键 1 &#xff1a;循环切换今天 / 明天 / 后天天气数据&#xff1b; 按键 2 &#xff1a;更新天气。 二、项目框图 三、cjson作用 https://gi…

ROS2 应用:按键控制 MoveIt2 中 Panda 机械臂关节位置

视频讲解 ROS2 应用&#xff1a;按键控制 MoveIt2 中 Panda 机械臂关节位置 创建 ROS 2 包 进入工作空间的 src 目录&#xff0c;然后创建一个新的 Python 包&#xff1a; ros2 pkg create --build-type ament_python panda_joint_control --dependencies rclpy control_msgs…

初学者如何设置以及使用富文本编辑器[eclipse版]

手把手教你设置富文本编辑器 参考来源&#xff1a;UEditor Docs 初学者按我的步骤来就可以啦 一、设置ueditor编辑器 1.提取文件[文章最底部有链接提取方式] 2.解压文件并放到自己项目中&#xff0c;在WebContent目录下&#xff1a; 3. 修改jar包位置路径 到--> 注意&a…

springboot系列十四: 注入Servlet, Filter, Listener + 内置Tomcat配置和切换 + 数据库操作

文章目录 注入Servlet, Filter, Listener官方文档基本介绍使用注解方式注入使用RegistrationBean方法注入DispatcherServlet详解 内置Tomcat配置和切换基本介绍内置Tomcat配置通过application.yml完成配置通过类配置 切换Undertow 数据库操作 JdbcHikariDataSource需求分析应用…

【数据结构初阶第十五节】堆的应用(堆排序 + Top-K问题)

必须有为成功付出代价的决心&#xff0c;然后想办法付出这个代价。云边有个稻草人-CSDN博客 对于本节我们要提前掌握前一节课堆的相关实现才能学好本次的知识&#xff0c;一定要多画图多敲代码看看实现的效果是啥&#xff08;Crazy&#xff01;&#xff09;开始吧&#xff01; …

deepseek自动化代码生成

使用流程 效果第一步&#xff1a;注册生成各种大模型的API第二步&#xff1a;注册成功后生成API第三步&#xff1a;下载vscode在vscode中下载agent&#xff0c;这里推荐使用cline 第四步&#xff1a;安装完成后&#xff0c;设置模型信息第一步选择API provider&#xff1a; Ope…

Scrapy:Downloader下载器设计详解

Scrapy下载器设计详解 1. 整体架构 Scrapy的下载器(Downloader)是整个爬虫框架的核心组件之一&#xff0c;负责处理所有网络请求的下载工作。它的主要职责是&#xff1a; 管理并发请求实现请求调度处理下载延迟维护下载槽(Slot) 官方文档&#xff1a;Settings中的Downloader配…

【IO】java IO流的类型及IO模型

文章目录 分类字节流输入流输出流 字符流输入流输出流 字节缓冲流字符缓冲流4中常见的IO模型BIO&#xff08;同步阻塞模型&#xff09;同步非阻塞模型NIO&#xff08;多路复用模型&#xff09;AIO异步 分类 根据数据流向分为&#xff1a;输入流、输出流&#xff08;以内存为中…

计算机视觉:主流数据集整理

第一章&#xff1a;计算机视觉中图像的基础认知 第二章&#xff1a;计算机视觉&#xff1a;卷积神经网络(CNN)基本概念(一) 第三章&#xff1a;计算机视觉&#xff1a;卷积神经网络(CNN)基本概念(二) 第四章&#xff1a;搭建一个经典的LeNet5神经网络(附代码) 第五章&#xff1…

八股文实战之JUC:静态方法的锁和普通方法的锁

1、对于staic同步方法锁住的是class类模板&#xff08;Class对象&#xff09; 对象是线程&#xff08;调用者&#xff09; 调用者只有获取资源的锁才能调用 2、普通同步方法 锁住的资源是class对象 对象是线程&#xff08;调用者&#xff09;即&#xff1a; 静态同步方法&a…

EasyRTC:基于WebRTC与P2P技术,开启智能硬件音视频交互的全新时代

在数字化浪潮的席卷下&#xff0c;智能硬件已成为我们日常生活的重要组成部分&#xff0c;从智能家居到智能穿戴&#xff0c;从工业物联网到远程协作&#xff0c;设备间的互联互通已成为不可或缺的趋势。然而&#xff0c;高效、低延迟且稳定的音视频交互一直是智能硬件领域亟待…

VSCode - VSCode 切换自动换行

VSCode 自动换行 1、基本介绍 在 VSCode 中&#xff0c;启用自动换行可以让长行代码自动折行显示&#xff0c;避免水平滚动条频繁使用&#xff0c;提升代码阅读体验 如果禁用自动换行&#xff0c;长行代码就需要手动结合水平滚动条来阅读 2、演示 启用自动换行 禁用自动换…

编程小白冲Kaggle每日打卡(12)--kaggle学堂:<机器学习简介>模型如何工作

Kaggle官方课程链接&#xff1a;How Models Work 本专栏旨在Kaggle官方课程的汉化&#xff0c;让大家更方便地看懂。 How Models Work 第一步&#xff0c;如果你是机器学习的新手。 Introduction 我们将从概述机器学习模型的工作原理和使用方法开始。如果你以前做过统计建模…

IDEA安装deepseek最新教程2025

IDEA引入DeepSeek 将 IntelliJ IDEA&#xff08;JetBrains 开发的 Java 集成开发环境&#xff09;与 DeepSeek&#xff08;深度求索的技术能力&#xff09;结合&#xff0c;通常涉及利用 AI 技术增强开发效率或扩展 IDE 功能,安装完成后&#xff0c;结合 IntelliJ IDEA 的开发…

安科瑞能源物联网平台助力企业实现绿色低碳转型

安科瑞顾强 随着全球能源结构的转型和“双碳”目标的推进&#xff0c;能源管理正朝着智能化、数字化的方向快速发展。安科瑞电气股份有限公司推出的微电网智慧能源管理平台&#xff08;EMS 3.0&#xff09;&#xff0c;正是这一趋势下的创新解决方案。该平台集成了物联网&…

Ansible 学习笔记

这里写自定义目录标题 基本架构文件结构安装查看版本 Ansible 配置相关文件主机清单写法 基本架构 Ansible 是基于Python实现的&#xff0c;默认使用22端口&#xff0c; 文件结构 安装 查看用什么语言写的用一下命令 查看版本 Ansible 配置相关文件 主机清单写法

android,flutter 混合开发,pigeon通信,传参

文章目录 app效果native和flutter通信的基础知识1. 编解码器 一致性和完整性&#xff0c;安全性&#xff0c;性能优化2. android代码3. dart代码 1. 创建flutter_module2.修改 Android 项目的 settings.gradle&#xff0c;添加 Flutter module3. 在 Android app 的 build.gradl…

怎么在Github上readme文件里面怎么插入图片?

环境&#xff1a; Github 问题描述&#xff1a; 怎么在Github上readme文件里面怎么插入图片&#xff1f; https://github.com/latiaoge/AI-Sphere-Butler/tree/master 解决方案&#xff1a; 1.相对路径引用 上传图片到仓库 将图片文件&#xff08;如 .png/.jpg&#xff…