High-Dimensional Continuous Control Using Generalized Advantage Estimation
Abstract
Two main challenges of policy gradient methods:
- The large number of samples typically required (due to the high variance of the gradient estimates)
- The difficulty of obtaining stable and steady improvement despite the non-stationarity of the incoming data
Solutions:
- Using value functions to reduce variance at the cost of some bias
- Using a trust region optimization procedure for both the policy and the value function
Introduction
A key source of difficulty is the long time delay between actions and their effects on rewards, which is known as the credit assignment problem or distal reward problem.
High variance: can be mitigated by using more samples
High bias: more pernicious; it can cause the algorithm to fail to converge, or to converge to a poor solution that is not even a local optimum
The paper proposes Generalized Advantage Estimation (GAE), which significantly reduces variance while maintaining a tolerable level of bias.
Preliminaries
Policy gradient methods maximize the expected total reward by repeatedly estimating its gradient \(g:=\nabla_\theta \mathbb{E}\left[\sum_{t=0}^{\infty} r_t\right]\).
Gradient estimation: \[ g=\mathbb{E}\left[\sum_{t=0}^{\infty} \Psi_t \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)\right] \] where \(\Psi_t\) may be one of the following:
- \(\sum_{t=0}^{\infty} r_t\): total reward of the trajectory
- \(\sum_{t^{\prime}=t}^{\infty} r_{t^{\prime}}\): reward following action \(a_t\)
- \(\sum_{t^{\prime}=t}^{\infty} r_{t^{\prime}}-b\left(s_t\right)\): baselined version of previous formula
- \(Q^\pi\left(s_t, a_t\right)\): state-action value function
- \(A^\pi\left(s_t, a_t\right)\): advantage function (yields almost the lowest possible variance)
- \(r_t+V^\pi\left(s_{t+1}\right)-V^\pi\left(s_t\right)\): TD residual
where
\[
\begin{aligned}
V^\pi\left(s_t\right) &:= \mathbb{E}_{s_{t+1: \infty},\, a_{t: \infty}}\left[\sum_{l=0}^{\infty} r_{t+l}\right] \\
Q^\pi\left(s_t, a_t\right) &:= \mathbb{E}_{s_{t+1: \infty},\, a_{t+1: \infty}}\left[\sum_{l=0}^{\infty} r_{t+l}\right] \\
A^\pi\left(s_t, a_t\right) &:= Q^\pi\left(s_t, a_t\right)-V^\pi\left(s_t\right) \quad \text{(advantage function)}
\end{aligned}
\]
Discounted formulation of MDPs:
Value functions
\[
\begin{aligned}
V^{\pi, \gamma}\left(s_t\right) &:= \mathbb{E}_{s_{t+1: \infty},\, a_{t: \infty}}\left[\sum_{l=0}^{\infty} \gamma^l r_{t+l}\right] \\
Q^{\pi, \gamma}\left(s_t, a_t\right) &:= \mathbb{E}_{s_{t+1: \infty},\, a_{t+1: \infty}}\left[\sum_{l=0}^{\infty} \gamma^l r_{t+l}\right] \\
A^{\pi, \gamma}\left(s_t, a_t\right) &:= Q^{\pi, \gamma}\left(s_t, a_t\right)-V^{\pi, \gamma}\left(s_t\right)
\end{aligned}
\]
Discounted approximation of the policy gradient
\[
g^\gamma := \mathbb{E}_{\substack{s_{0:\infty} \\ a_{0:\infty}}}\left[\sum_{t=0}^{\infty} A^{\pi, \gamma}\left(s_t, a_t\right) \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)\right]
\]
We want to obtain an unbiased estimate of \(g^\gamma\), which is itself a biased estimate of the policy gradient of the undiscounted MDP.
Definition 1. The estimator \(\hat A_t\) is \(\gamma\)-just if
\[
\mathbb{E}_{\substack{s_{0: \infty} \\ a_{0: \infty}}}\left[\hat{A}_t\left(s_{0: \infty}, a_{0: \infty}\right) \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)\right]=\mathbb{E}_{\substack{s_{0: \infty} \\ a_{0: \infty}}}\left[A^{\pi, \gamma}\left(s_t, a_t\right) \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)\right]
\]
It follows that if \(\hat A_t\) is \(\gamma\)-just for all \(t\), then
\[
\mathbb{E}_{\substack{s_{0: \infty} \\ a_{0: \infty}}}\left[\sum_{t=0}^{\infty} \hat{A}_t\left(s_{0: \infty}, a_{0: \infty}\right) \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right)\right]=g^\gamma
\]
Proposition 1. Suppose that \(\hat A_t\) can be written in the form \(\hat A_t(s_{0:\infty}, a_{0:\infty})=Q_t(s_{0:\infty}, a_{0:\infty})-b_t(s_{0:t}, a_{0:t-1})\) such that for all \((s_t,a_t)\), \(\mathbb{E}_{s_{t+1: \infty}, a_{t+1: \infty} \mid s_t, a_t}\left[Q_t\left(s_{t: \infty}, a_{t: \infty}\right)\right]=Q^{\pi, \gamma}\left(s_t, a_t\right)\). Then \(\hat A_t\) is \(\gamma\)-just.
The following expressions are \(\gamma\)-just advantage estimators for \(\hat A_t\) (a minimal code sketch of the first one, the discounted return, appears after this list):
- \(\sum_{l=0}^{\infty} \gamma^l r_{t+l}\)
- \(A^{\pi, \gamma}\left(s_t, a_t\right)\)
- \(Q^{\pi, \gamma}\left(s_t, a_t\right)\)
- \(r_t+\gamma V^{\pi, \gamma}\left(s_{t+1}\right)-V^{\pi, \gamma}\left(s_t\right)\)
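As an illustration of the first estimator, here is a minimal NumPy sketch (the function name and the finite-trajectory truncation are my own assumptions, not from the paper) that computes the discounted return \(\sum_{l=0}^{\infty} \gamma^l r_{t+l}\) for every \(t\):
```python
import numpy as np

# Illustrative helper, not from the paper.
def discounted_returns(rewards, gamma: float) -> np.ndarray:
    """Return-to-go sum_{l>=0} gamma^l * r_{t+l} for each t of a finite trajectory."""
    rewards = np.asarray(rewards, dtype=np.float64)
    returns = np.zeros_like(rewards)
    running = 0.0
    # Accumulate backwards: G_t = r_t + gamma * G_{t+1}
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Example: discounted_returns([1, 1, 1], gamma=0.9) -> [2.71, 1.9, 1.0]
```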
Advantage Function Estimation
We want an accurate estimate \(\hat A_t\) of \(A^{\pi, \gamma}(s_t, a_t)\), which is then used to construct a policy gradient estimator: \[ \hat{g}=\frac{1}{N} \sum_{n=1}^N \sum_{t=0}^{\infty} \hat{A}_t^n \nabla_\theta \log \pi_\theta\left(a_t^n \mid s_t^n\right) \] where \(n\) indexes over a batch of episodes.
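A minimal sketch of this batch estimator, assuming the per-timestep advantage estimates \(\hat A_t^n\) and score-function gradients \(\nabla_\theta \log \pi_\theta(a_t^n \mid s_t^n)\) are already available as arrays (all names here are illustrative, not from the paper):
```python
import numpy as np

# Illustrative helper: accumulates g_hat = (1/N) * sum_n sum_t A_hat[n][t] * grad_log_pi[n][t].
def policy_gradient_estimate(advantages, grad_log_probs):
    """
    advantages:     list of N 1-D arrays; advantages[n][t] = A_hat_t for episode n
    grad_log_probs: list of N lists of arrays; grad_log_probs[n][t] is
                    grad_theta log pi(a_t | s_t), shape (num_params,)
    """
    N = len(advantages)
    g_hat = np.zeros_like(grad_log_probs[0][0], dtype=np.float64)
    for adv_n, grads_n in zip(advantages, grad_log_probs):
        for a_t, grad_t in zip(adv_n, grads_n):
            g_hat += a_t * grad_t
    return g_hat / N
```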
GAE
Let \(V\) be an approximate value function
Define \(\delta_t^V=r_t+\gamma V\left(s_{t+1}\right)-V\left(s_t\right)\), the TD residual of \(V\) with discount \(\gamma\), which can be considered an estimate of the advantage of the action \(a_t\).
In fact, if \(V=V^{\pi, \gamma}\), then it is a \(\gamma\)-just advantage estimator of \(A^{\pi, \gamma}\): \[ \begin{aligned} \mathbb{E}_{s_{t+1}}\left[\delta_t^{V^{\pi, \gamma}}\right] & =\mathbb{E}_{s_{t+1}}\left[r_t+\gamma V^{\pi, \gamma}\left(s_{t+1}\right)-V^{\pi, \gamma}\left(s_t\right)\right] \\ & =\mathbb{E}_{s_{t+1}}\left[Q^{\pi, \gamma}\left(s_t, a_t\right)-V^{\pi, \gamma}\left(s_t\right)\right]=A^{\pi, \gamma}\left(s_t, a_t\right) \end{aligned} \] However, it is only \(\gamma\)-just for \(V=V^{\pi,\gamma}\); otherwise it yields a biased estimate.
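A vectorized sketch of \(\delta_t^V\) for a finite trajectory, assuming `values` holds \(V(s_0), \dots, V(s_T)\) (including a bootstrap value for the state after the last step) and `rewards` holds \(r_0, \dots, r_{T-1}\) (names are my own):
```python
import numpy as np

# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), for t = 0 .. T-1
def td_residuals(rewards, values, gamma: float) -> np.ndarray:
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)  # length T+1
    return rewards + gamma * values[1:] - values[:-1]
```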
Notation: \(\hat A_t^{(k)}\), the sum of \(k\) of these \(\delta\) terms:
\[
\begin{aligned}
\hat{A}_t^{(1)} &:= \delta_t^V &&= -V\left(s_t\right)+r_t+\gamma V\left(s_{t+1}\right) \\
\hat{A}_t^{(2)} &:= \delta_t^V+\gamma \delta_{t+1}^V &&= -V\left(s_t\right)+r_t+\gamma r_{t+1}+\gamma^2 V\left(s_{t+2}\right) \\
\hat{A}_t^{(3)} &:= \delta_t^V+\gamma \delta_{t+1}^V+\gamma^2 \delta_{t+2}^V &&= -V\left(s_t\right)+r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+\gamma^3 V\left(s_{t+3}\right) \\
\hat{A}_t^{(k)} &:= \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}^V &&= -V\left(s_t\right)+r_t+\gamma r_{t+1}+\cdots+\gamma^{k-1} r_{t+k-1}+\gamma^k V\left(s_{t+k}\right)
\end{aligned}
\]
Taking \(k \rightarrow \infty\) gives
\[
\hat{A}_t^{(\infty)}=\sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V=-V\left(s_t\right)+\sum_{l=0}^{\infty} \gamma^l r_{t+l}
\]
Definition of \(\operatorname{GAE}(\gamma, \lambda)\): the exponentially weighted average of the \(k\)-step estimators,
\[
\begin{aligned}
\hat{A}_t^{\mathrm{GAE}(\gamma, \lambda)}:= & (1-\lambda)\left(\hat{A}_t^{(1)}+\lambda \hat{A}_t^{(2)}+\lambda^2 \hat{A}_t^{(3)}+\ldots\right) \\
= & (1-\lambda)\left(\delta_t^V+\lambda\left(\delta_t^V+\gamma \delta_{t+1}^V\right)+\lambda^2\left(\delta_t^V+\gamma \delta_{t+1}^V+\gamma^2 \delta_{t+2}^V\right)+\ldots\right) \\
= & (1-\lambda)\left(\delta_t^V\left(1+\lambda+\lambda^2+\ldots\right)+\gamma \delta_{t+1}^V\left(\lambda+\lambda^2+\lambda^3+\ldots\right)\right. \\
& \left.\quad+\gamma^2 \delta_{t+2}^V\left(\lambda^2+\lambda^3+\lambda^4+\ldots\right)+\ldots\right) \\
= & (1-\lambda)\left(\delta_t^V\left(\frac{1}{1-\lambda}\right)+\gamma \delta_{t+1}^V\left(\frac{\lambda}{1-\lambda}\right)+\gamma^2 \delta_{t+2}^V\left(\frac{\lambda^2}{1-\lambda}\right)+\ldots\right) \\
= & \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}^V
\end{aligned}
\]
Special cases:
\[
\begin{aligned}
& \operatorname{GAE}(\gamma, 0): \quad \hat{A}_t:=\delta_t^V=r_t+\gamma V\left(s_{t+1}\right)-V\left(s_t\right) \\
& \operatorname{GAE}(\gamma, 1): \quad \hat{A}_t:=\sum_{l=0}^{\infty} \gamma^l \delta_{t+l}^V=\sum_{l=0}^{\infty} \gamma^l r_{t+l}-V\left(s_t\right)
\end{aligned}
\]
\(\operatorname{GAE}(\gamma,1)\) is \(\gamma\)-just regardless of the accuracy of \(V\), but it has high variance due to the sum of terms.
\(\operatorname{GAE}(\gamma, 0)\) is \(\gamma\)-just only if \(V=V^{\pi,\gamma}\); otherwise it induces bias, but it has much lower variance.
\(0<\lambda<1\) makes a compromise between bias and variance.
Finally, \[ g^\gamma \approx \mathbb{E}\left[\sum_{t=0}^{\infty} \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) \hat{A}_t^{\mathrm{GAE}(\gamma, \lambda)}\right]=\mathbb{E}\left[\sum_{t=0}^{\infty} \nabla_\theta \log \pi_\theta\left(a_t \mid s_t\right) \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}^V\right] \] where equality holds when \(\lambda = 1\).
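In practice, \(\hat{A}_t^{\mathrm{GAE}(\gamma, \lambda)}\) is computed for a whole trajectory in a single backward pass, using the recursion that follows directly from the definition:
\[
\hat{A}_t^{\mathrm{GAE}(\gamma, \lambda)} = \delta_t^V + \gamma \lambda\, \hat{A}_{t+1}^{\mathrm{GAE}(\gamma, \lambda)}
\]
with the usual convention \(\hat{A}_T^{\mathrm{GAE}(\gamma, \lambda)} = 0\) at the end of a finite (or truncated) rollout; this is the form used in the code sketch below.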
Result

Code
Below is a minimal NumPy sketch of a `generalized_advantage_estimate` function; the argument names and the body (the backward recursion above) are assumptions filled in around the truncated stub.
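```python
import numpy as np

# NOTE: only the function name comes from the original stub; the signature and body are assumed.
def generalized_advantage_estimate(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Compute GAE(gamma, lambda) advantages for one trajectory.

    rewards:    array of r_0 .. r_{T-1}
    values:     array of V(s_0) .. V(s_{T-1}) from the approximate value function
    last_value: bootstrap value V(s_T) for the state after the last step
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.append(np.asarray(values, dtype=np.float64), last_value)  # length T+1
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float64)
    gae = 0.0
    # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages
```
For rollouts that end in a true terminal state, `last_value` should be 0; for truncated rollouts it is the value estimate of the final observed state. The quantity `advantages + values[:T]` then gives the corresponding targets for value-function regression.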
Reference
- https://zhuanlan.zhihu.com/p/45107835
- https://zhuanlan.zhihu.com/p/577598804
- https://towardsdatascience.com/generalized-advantage-estimate-maths-and-code-b5d5bd3ce737
- https://arxiv.org/abs/1506.02438