# -*- coding: utf-8 -*-
"""
Created on Sat Aug 24 15:04:06 2013

@author: steve
"""

import numpy as np

import mdptoolbox

from utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
from utils import R_forest_sparse, P_small, R_small, P_sparse


def test_QLearning_small():
    """Run QLearning on the ``P_small``/``R_small`` fixtures and check results.

    The solver is built with a discount of 0.9. Its ``Q`` and ``V``
    attributes are compared element-wise against reference values using an
    absolute tolerance of ``SMALLNUM``; ``policy`` must match exactly.
    """
    # Fix the global NumPy seed before building the solver. The reference
    # numbers below are presumably tied to this seed -- regenerate them if
    # the seed ever changes.
    np.random.seed(0)
    sdp = mdptoolbox.mdp.QLearning(P_small, R_small, 0.9)
    sdp.run()
    # Plain ndarrays are used instead of the discouraged np.matrix class.
    q = np.array([[33.330108655211646, 40.82109564847122],
                  [34.37431040682546, 29.672368452303164]])
    v = np.array([40.82109564847122, 34.37431040682546])
    p = (1, 0)
    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
    assert sdp.policy == p

def test_QLearning_small_sparse():
    """Run QLearning on the ``P_sparse``/``R_small`` fixtures and check results.

    Uses the same discount (0.9), seed and reference values as
    ``test_QLearning_small``; only the transition fixture differs. ``Q`` and
    ``V`` are compared with an absolute tolerance of ``SMALLNUM``; ``policy``
    must match exactly.
    """
    # Fix the global NumPy seed before building the solver. The reference
    # numbers below are presumably tied to this seed.
    np.random.seed(0)
    sdp = mdptoolbox.mdp.QLearning(P_sparse, R_small, 0.9)
    sdp.run()
    # Plain ndarrays are used instead of the discouraged np.matrix class.
    q = np.array([[33.330108655211646, 40.82109564847122],
                  [34.37431040682546, 29.672368452303164]])
    v = np.array([40.82109564847122, 34.37431040682546])
    p = (1, 0)
    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
    assert sdp.policy == p

def test_QLearning_forest():
    """Run QLearning on the ``P_forest``/``R_forest`` fixtures and check results.

    The solver is built with a discount of 0.96. Its ``Q`` and ``V``
    attributes are compared element-wise against reference values using an
    absolute tolerance of ``SMALLNUM``; ``policy`` must match exactly.
    """
    # Fix the global NumPy seed before building the solver. The reference
    # numbers below are presumably tied to this seed -- regenerate them if
    # the seed ever changes.
    np.random.seed(0)
    sdp = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.96)
    sdp.run()
    # Plain ndarrays are used instead of the discouraged np.matrix class.
    q = np.array([[11.198908998901134, 10.34652034142302],
                  [10.74229967143465, 11.741057920409865],
                  [2.8698000059458546, 12.259732864170232]])
    v = np.array([11.198908998901134, 11.741057920409865,
                  12.259732864170232])
    p = (0, 1, 1)
    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
    assert sdp.policy == p


def test_QLearning_forest_sparse():
    """Run QLearning on the sparse forest fixtures and check the policy.

    The solver is built from ``P_forest_sparse``/``R_forest_sparse`` with a
    discount of 0.96. Only the resulting ``policy`` is checked, against a
    ten-entry reference tuple; ``Q`` and ``V`` are not compared here.
    """
    # Fix the global NumPy seed before building the solver. The reference
    # policy below is presumably tied to this seed.
    np.random.seed(0)
    sdp = mdptoolbox.mdp.QLearning(P_forest_sparse, R_forest_sparse, 0.96)
    sdp.run()
    p = (0, 1, 1, 1, 1, 1, 0, 0, 0, 0)
    assert sdp.policy == p