$$ \huge{\underline{\textbf{ Policy Iteration }}} $$


Implementation of Policy Iteration from Sutton and Barto (2018) Reinforcement Learning: An Introduction, chapter 4.3. The book is available for free on the authors' website.
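
For reference, the two updates performed by the code below, written in the book's notation, are the policy evaluation sweep

$$ V(s) \leftarrow \sum_{s',\,r} p(s', r \mid s, \pi(s)) \big[ r + \gamma V(s') \big] $$

and the greedy policy improvement

$$ \pi(s) \leftarrow \arg\max_a \sum_{s',\,r} p(s', r \mid s, a) \big[ r + \gamma V(s') \big] $$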


In [1]:
# Naive implementation (for loops are slow), but matches the pseudocode box in the book
def policy_iter(env, gamma, theta):
    """Policy Iteration Algorithm
    
    Params:
        env - environment with the following required members:
            env.nb_states  - number of states
            env.nb_actions - number of actions
            env.model      - transition probabilities and rewards for all states and actions, see note #1
        gamma (float) - discount factor
        theta (float) - convergence threshold for policy evaluation
    """
    
    # 1. Initialization
    V = np.zeros(env.nb_states)
    pi = np.zeros(env.nb_states, dtype=int)  # deterministic policy, initialised to always pick action 0
    
    while True:
    
        # 2. Policy Evaluation
        while True:
            delta = 0
            for s in range(env.nb_states):
                v = V[s]
                V[s] = sum_sr(env, V=V, s=s, a=pi[s], gamma=gamma)
                delta = max(delta, abs(v - V[s]))
            if delta < theta: break

        # 3. Policy Improvement
        policy_stable = True
        for s in range(env.nb_states):
            old_action = pi[s]
            pi[s] = np.argmax([sum_sr(env, V=V, s=s, a=a, gamma=gamma)  # list comprehension
                               for a in range(env.nb_actions)])
            if old_action != pi[s]: policy_stable = False
        if policy_stable: break
    
    return V, pi

Helper Functions:

In [2]:
def sum_sr(env, V, s, a, gamma):
    """Calc state-action value for state 's' and action 'a'"""
    tmp = 0  # accumulates the state-action value q(s, a)
    for p, s_, r, _ in env.model[s][a]:     # see note #1 !
        # p  - transition probability from (s, a) to s'
        # s_ - next state s'
        # r  - reward on transition from (s, a) to s'
        # _  - done flag, not needed here
        tmp += p * (r + gamma * V[s_])
    return tmp
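
As a quick illustration of how sum_sr consumes the model structure described in note #1 below, here is a minimal sketch with a hypothetical one-transition model (not the FrozenLake data), using the same (p, s', r, done) tuple format:

import numpy as np
from types import SimpleNamespace

# Hypothetical toy model: from state 0, action 0 leads to terminal state 1
# with probability 1.0 and reward 1.0 (same nested format as env.model).
toy_env = SimpleNamespace(model={0: {0: [(1.0, 1, 1.0, True)]}})
V_toy = np.zeros(2)
print(sum_sr(toy_env, V=V_toy, s=0, a=0, gamma=0.9))   # 1.0 * (1.0 + 0.9 * 0.0) = 1.0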

Note #1

The env.model member is taken directly from the OpenAI Gym API for FrozenLake-v0 (where it is called env.P, see below). It is a nested structure describing transition probabilities and expected rewards, for example:

>>> env.model[6][0]
[(0.3333333333333333, 2, 0.0, False),
 (0.3333333333333333, 5, 0.0, True),
 (0.3333333333333333, 10, 0.0, False)]

This has the following meaning (see the sketch after this list for how to unpack the tuples):

  • from state 6, taking action 0, there is a 0.33 probability of transitioning to state 2, with reward 0.0; the transition is non-terminal
  • from state 6, taking action 0, there is a 0.33 probability of transitioning to state 5, with reward 0.0; the transition is terminal, the MDP ends
  • from state 6, taking action 0, there is a 0.33 probability of transitioning to state 10, with reward 0.0; the transition is non-terminal
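
A minimal sketch of unpacking these tuples by name (it assumes env.model has been set up as in the cells further below):

for prob, next_state, reward, done in env.model[6][0]:
    print(f"p={prob:.2f}  s'={next_state}  r={reward}  terminal={done}")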


Solve FrozenLake-v0

Using the OpenAI Gym FrozenLake-v0 environment (see the Gym documentation for a description).

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import gym
In [4]:
env = gym.make('FrozenLake-v0')
env.reset()
env.render()
SFFF
FHFH
FFFH
HFFG

Rename some members, but don't break stuff!

In [5]:
if not hasattr(env, 'nb_states'):  env.nb_states = env.env.nS
if not hasattr(env, 'nb_actions'): env.nb_actions = env.env.nA
if not hasattr(env, 'model'):      env.model = env.env.P
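
A quick sanity check on the aliases (the 4x4 map rendered above has 16 states, and FrozenLake has 4 actions, matching the arrow map used later):

assert env.nb_states == 16   # 4x4 grid rendered above
assert env.nb_actions == 4   # left / down / right / up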

Perform policy iteration

In [6]:
V, pi = policy_iter(env, gamma=1.0, theta=1e-8)
print(V.reshape([4, -1]))
[[0.82352925 0.82352919 0.82352915 0.82352913]
 [0.82352926 0.         0.52941165 0.        ]
 [0.82352929 0.82352932 0.7647058  0.        ]
 [0.         0.88235288 0.94117644 0.        ]]

Show the optimal policy, which essentially steers around the holes in the lake

In [7]:
a2w = {0:'<', 1:'v', 2:'>', 3:'^'}
policy_arrows = np.array([a2w[x] for x in pi])
print(policy_arrows.reshape([-1, 4]))
[['<' '^' '^' '^']
 ['<' '<' '<' '<']
 ['^' 'v' '<' '<']
 ['<' '>' 'v' '<']]

Check correctness

In [8]:
correct_V = np.array([[0.82352941, 0.82352941, 0.82352941, 0.82352941],
                      [0.82352941, 0.        , 0.52941176, 0.        ],
                      [0.82352941, 0.82352941, 0.76470588, 0.        ],
                      [0.        , 0.88235294, 0.94117647, 0.        ]])
correct_policy_arrows = np.array([['<', '^', '^', '^'],
                                  ['<', '<', '<', '<'],
                                  ['^', 'v', '<', '<'],
                                  ['<', '>', 'v', '<']])
assert np.allclose(V.reshape([4,-1]), correct_V)
assert np.all(policy_arrows.reshape([4,-1]) == correct_policy_arrows)
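
As an extra, empirical check, the greedy policy can be rolled out in the environment to estimate its success rate. This is a minimal sketch assuming the classic Gym API used above (env.reset returns the initial state and env.step returns a 4-tuple); the estimated rate will vary with the slippery dynamics and the number of episodes:

n_episodes = 1000
total_reward = 0.0
for _ in range(n_episodes):
    s = env.reset()          # initial state (an int in the classic Gym API)
    done = False
    while not done:
        s, r, done, _ = env.step(pi[s])   # follow the greedy policy
    total_reward += r        # FrozenLake gives reward 1.0 only on reaching the goal
print('estimated success rate:', total_reward / n_episodes)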