Reference: the Reinforcement Learning book, course, and code by Wang Shusen (王树森).
1. Basic Concepts
Discounted return:

$$U_t=R_t+\gamma\cdot R_{t+1}+\gamma^2\cdot R_{t+2}+\cdots+\gamma^{n-t}\cdot R_n.$$
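The recursion $U_t=R_t+\gamma\,U_{t+1}$ makes the return easy to compute from the end of an episode backwards. A minimal Python sketch (the function name and the example rewards are illustrative, not from the book):

```python
def discounted_return(rewards, gamma):
    """Discounted return U_t for t = 0, using U_t = R_t + gamma * U_{t+1}."""
    u = 0.0
    for r in reversed(rewards):  # accumulate from the last reward backwards
        u = r + gamma * u
    return u

# Example: rewards [1, 1, 1] with gamma = 0.9 give 1 + 0.9 + 0.81 = 2.71.
print(discounted_return([1.0, 1.0, 1.0], gamma=0.9))
```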
Action-value function:

$$Q_\pi(s_t,a_t)=\mathbb{E}\Big[U_t\,\Big|\,S_t=s_t,A_t=a_t\Big].$$
Optimal action-value function (maximizing over $\pi$ removes the dependence on $\pi$):

$$Q^\star(s_t,a_t)=\max_{\pi}Q_\pi(s_t,a_t),\quad\forall\, s_t\in\mathcal{S},\ a_t\in\mathcal{A}.$$
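A policy that is greedy with respect to $Q^\star$ picks $\arg\max_a Q^\star(s_t,a)$ in every state, which is exactly how the trained network is used later in the code. A tabular sketch, assuming `q_table` maps (state, action) pairs to value estimates (an illustrative data structure, not the book's):

```python
def greedy_action(q_table, state, actions):
    """Return the action with the largest estimated Q*(state, action)."""
    return max(actions, key=lambda a: q_table[(state, a)])
```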
2. Q-learning Algorithm
$$\begin{aligned}
&\text{Observe a transition } (s_t,a_t,r_t,s_{t+1}). \\
&\text{TD target: } y_t=r_t+\gamma\cdot\max_{a}Q^\star(s_{t+1},a). \\
&\text{TD error: } \delta_t=Q^\star(s_t,a_t)-y_t. \\
&\text{Update: } Q^\star(s_t,a_t)\leftarrow Q^\star(s_t,a_t)-\alpha\cdot\delta_t.
\end{aligned}$$
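For small, discrete state and action spaces these three steps can be implemented directly with a lookup table. A minimal sketch, assuming `Q` is a `defaultdict(float)` keyed by (state, action) and `actions` lists the available actions (the names and the example transition are illustrative):

```python
from collections import defaultdict

def q_learning_step(Q, transition, actions, gamma=0.9, alpha=0.1):
    """One tabular Q-learning update on a transition (s, a, r, s_next)."""
    s, a, r, s_next = transition
    y = r + gamma * max(Q[(s_next, b)] for b in actions)  # TD target
    delta = Q[(s, a)] - y                                  # TD error
    Q[(s, a)] -= alpha * delta                             # update

Q = defaultdict(float)
q_learning_step(Q, transition=("s0", 0, 1.0, "s1"), actions=[0, 1])
```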
3. DQN Algorithm
$$\begin{aligned}
&\text{Observe a transition } (s_t,a_t,r_t,s_{t+1}). \\
&\text{TD target: } y_t=r_t+\gamma\cdot\max_{a}Q(s_{t+1},a;\mathbf{w}). \\
&\text{TD error: } \delta_t=Q(s_t,a_t;\mathbf{w})-y_t. \\
&\text{Update: } \mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\delta_t\cdot\frac{\partial Q(s_t,a_t;\mathbf{w})}{\partial\mathbf{w}}.
\end{aligned}$$
Note: DQN uses a neural network $Q(s,a;\mathbf{w})$ to approximate the optimal action-value function $Q^\star$.
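Written against the network weights $\mathbf{w}$, the same three steps amount to one gradient step. A minimal PyTorch sketch for a single transition (batching, the target network, and experience replay are omitted here; `q_net` stands for a network such as the `QNet` defined in the full implementation below):

```python
import torch

def dqn_update(q_net, optimizer, s, a, r, s_next, gamma=0.99):
    """One DQN gradient step on a single transition (s, a, r, s_next)."""
    with torch.no_grad():                    # no gradient flows through the TD target
        y = r + gamma * q_net(s_next).max()  # y_t = r_t + gamma * max_a Q(s_{t+1}, a; w)
    delta = q_net(s)[a] - y                  # TD error delta_t
    loss = 0.5 * delta ** 2                  # minimizing delta^2/2 yields w <- w - alpha * delta * dQ/dw
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage sketch: optimizer = torch.optim.SGD(q_net.parameters(), lr=1e-3)
```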
4. Experience Replay (a technique for training DQN)
Advantages: collected experience can be reused; consecutive transitions are strongly correlated, and sampling uniformly at random from the replay buffer reduces this correlation.
Hyperparameter: the capacity (length) of the replay buffer.
- Find $\mathbf{w}$ by minimizing $L(\mathbf{w})=\frac{1}{T}\sum_{t=1}^{T}\frac{\delta_t^2}{2}$.
- Stochastic gradient descent (SGD):
  - Randomly sample a transition $(s_i,a_i,r_i,s_{i+1})$ from the buffer.
  - Compute the TD error $\delta_i$.
  - Stochastic gradient: $\mathbf{g}_i=\frac{\partial\,\delta_i^2/2}{\partial\mathbf{w}}=\delta_i\cdot\frac{\partial Q(s_i,a_i;\mathbf{w})}{\partial\mathbf{w}}$.
  - SGD step: $\mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\mathbf{g}_i$.
Note: in practice, minibatch SGD is usually used instead: sample several transitions at a time and compute the stochastic gradient of the averaged minibatch loss.
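A sketch of one such minibatch step, assuming `buffer` is a list of transitions and `td_error(q_net, s, a, r, s_next)` returns the differentiable TD error $\delta_i$ (both names are illustrative; the complete version with a proper replay buffer and target network follows in the next section):

```python
import random

def minibatch_sgd_step(q_net, optimizer, buffer, td_error, batch_size=32):
    """Average the per-transition losses delta_i^2 / 2 over a random minibatch."""
    batch = random.sample(buffer, batch_size)  # uniform sampling breaks temporal correlation
    loss = sum(0.5 * td_error(q_net, *t) ** 2 for t in batch) / batch_size
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```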
**5. Code Implementation: DQN Trained with Experience Replay**
"""4.3节DQN算法实现。
"""
import argparse
from collections import defaultdict
import os
import random
from dataclasses import dataclass, field
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNet(nn.Module):
    """Q-network.

    Input: state features.
    Output: num_action action values.
    """

    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DQN:
    def __init__(self, dim_state=None, num_action=None, discount=0.9):
        self.discount = discount
        self.Q = QNet(dim_state, num_action)
        self.target_Q = QNet(dim_state, num_action)
        self.target_Q.load_state_dict(self.Q.state_dict())

    def get_action(self, state):
        qvals = self.Q(state)
        return qvals.argmax()

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # Q-values of the taken actions for s_batch, a_batch.
        qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # Evaluate next_s_batch with the target Q-network.
        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
        # MSE between the TD target and the current Q-values.
        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
        return loss

def soft_update(target, source, tau=0.01):
    """Update target by target = tau * source + (1 - tau) * target."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

@dataclass
class ReplayBuffer:
    maxsize: int
    size: int = 0
    state: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_state: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def push(self, state, action, reward, done, next_state):
        """Store a transition; overwrite the oldest entry once the buffer is full.

        :param state: current state
        :param action: action taken
        :param reward: reward received
        :param done: whether the episode ended
        :param next_state: next state
        """
        if self.size < self.maxsize:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            position = self.size % self.maxsize
            self.state[position] = state
            self.action[position] = action
            self.reward[position] = reward
            self.done[position] = done
            self.next_state[position] = next_state
        self.size += 1

    def sample(self, n):
        total_number = self.size if self.size < self.maxsize else self.maxsize
        indices = np.random.randint(total_number, size=n)
        state = [self.state[i] for i in indices]
        action = [self.action[i] for i in indices]
        reward = [self.reward[i] for i in indices]
        done = [self.done[i] for i in indices]
        next_state = [self.next_state[i] for i in indices]
        return state, action, reward, done, next_state

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.no_cuda:
        torch.cuda.manual_seed(args.seed)

def train(args, env, agent):
    replay_buffer = ReplayBuffer(10_000)
    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
    optimizer.zero_grad()

    epsilon = 1
    epsilon_max = 1
    epsilon_min = 0.1
    episode_reward = 0
    episode_length = 0
    max_episode_reward = -float("inf")
    log = defaultdict(list)
    log["loss"].append(0)

    agent.Q.train()
    state, _ = env.reset(seed=args.seed)
    for i in range(args.max_steps):
        # Epsilon-greedy exploration; act randomly during the warmup phase.
        if np.random.rand() < epsilon or i < args.warmup_steps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(torch.from_numpy(state).to(args.device))
            action = action.item()
        env.render()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_length += 1

        replay_buffer.push(state, action, reward, done, next_state)
        state = next_state

        if done:
            log["episode_reward"].append(episode_reward)
            log["episode_length"].append(episode_length)
            print(
                f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log['loss'][-1]:.1e}, epsilon={epsilon:.3f}")
            # Save the model whenever a higher episode reward is reached.
            if episode_reward > max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.Q.state_dict(), save_path)
                max_episode_reward = episode_reward

            episode_reward = 0
            episode_length = 0
            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, 1e-1)
            state, _ = env.reset()

        if i > args.warmup_steps:
            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
            bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
            ba = torch.tensor(ba, dtype=torch.long).to(args.device)
            br = torch.tensor(br, dtype=torch.float32).to(args.device)
            bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
            bns = torch.tensor(bns, dtype=torch.float32).to(args.device)

            loss = agent.compute_loss(bs, ba, br, bd, bns)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            log["loss"].append(loss.item())
            soft_update(agent.target_Q, agent.Q)

    # Plot the training curves.
    plt.plot(log["loss"])
    plt.yscale("log")
    plt.savefig(f"{args.output_dir}/loss.png", bbox_inches="tight")
    plt.close()

    plt.plot(np.cumsum(log["episode_length"]), log["episode_reward"])
    plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
    plt.close()

def eval(args, env, agent):
    # Reload the best model saved during training.
    agent = DQN(args.dim_state, args.num_action)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.Q.load_state_dict(torch.load(model_path))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset()
    for i in range(5000):
        episode_length += 1
        action = agent.get_action(torch.from_numpy(state)).item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        env.render()

        episode_reward += reward
        state = next_state

        if done:
            print(f"episode reward={episode_reward}, episode length={episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of actions.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--max_steps", default=50000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument("--warmup_steps", default=5000, type=int, help="Warmup steps without training.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--epsilon_decay", default=1 / 1000, type=float, help="Epsilon-greedy algorithm decay coefficient.")
    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    args.do_train = True
    args.do_eval = True
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)  # make sure the model/plot directory exists

    env = gym.make(args.env, render_mode="human")
    set_seed(args)

    agent = DQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
    agent.Q.to(args.device)
    agent.target_Q.to(args.device)

    if args.do_train:
        train(args, env, agent)

    if args.do_eval:
        eval(args, env, agent)


if __name__ == "__main__":
    main()
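The script requires gym >= 0.26 (it uses `render_mode` and the terminated/truncated step API). Assuming it is saved as, say, `dqn.py` (the filename is not given in the source), running `python dqn.py` trains and then evaluates, since `main` force-enables both `--do_train` and `--do_eval` regardless of the flags.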
Output: training saves the best model to `output/model.bin`, writes `loss.png` and `episode_reward.png` to the output directory, and prints per-episode reward, length, and epsilon to the console.