Commit 8a57e24e authored by Steven Cordwell's avatar Steven Cordwell

clean up and extend the tests

parent 99ac82ac
@@ -5,35 +5,45 @@ Created on Sat Aug 24 15:04:06 2013

 @author: steve
 """

-from random import seed as randseed

 import numpy as np

 import mdptoolbox

-from utils import SMALLNUM, P_forest, R_forest, P_small, R_small
+from utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
+from utils import R_forest_sparse, P_small, R_small

 def test_QLearning_small():
-    randseed(0)
     np.random.seed(0)
-    a = mdptoolbox.mdp.QLearning(P_small, R_small, 0.9)
-    q = np.matrix("39.9336909966907 43.175433380901488; "
-                  "36.943942243204454 35.42568055796341")
-    v = np.matrix("43.17543338090149, 36.943942243204454")
+    sdp = mdptoolbox.mdp.QLearning(P_small, R_small, 0.9)
+    q = np.matrix("33.330108655211646, 40.82109564847122; "
+                  "34.37431040682546, 29.672368452303164")
+    v = np.matrix("40.82109564847122, 34.37431040682546")
     p = np.matrix("1 0")
-    assert (np.absolute(a.Q - q) < SMALLNUM).all()
-    assert (np.absolute(np.array(a.V) - v) < SMALLNUM).all()
-    assert (np.array(a.policy) == p).all()
+    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
+    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
+    assert (np.array(sdp.policy) == p).all()
+
+def test_QLearning_forest():
+    np.random.seed(0)
+    sdp = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.96)
+    q = np.matrix("11.198908998901134, 10.34652034142302; "
+                  "10.74229967143465, 11.741057920409865; "
+                  "2.8698000059458546, 12.259732864170232")
+    v = np.matrix("11.198908998901134, 11.741057920409865, 12.259732864170232")
+    p = np.matrix("0 1 1")
+    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
+    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
+    assert (np.array(sdp.policy) == p).all()

-def test_QLearning_exampleForest():
-    randseed(0)
+#FIXME: This is wrong as the number of states in this is util.STATES, not 3
+def test_QLearning_forest_sparse():
     np.random.seed(0)
-    a = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.9)
-    q = np.matrix("26.209597296761608, 18.108253687076136; "
-                  "29.54356354184715, 18.116618509050486; "
-                  "33.61440797109655, 25.1820819845856")
-    v = np.matrix("26.209597296761608, 29.54356354184715, 33.61440797109655")
-    p = np.matrix("0 0 0")
-    assert (np.absolute(a.Q - q) < SMALLNUM).all()
-    assert (np.absolute(np.array(a.V) - v) < SMALLNUM).all()
-    assert (np.array(a.policy) == p).all()
\ No newline at end of file
+    sdp = mdptoolbox.mdp.QLearning(P_forest_sparse, R_forest_sparse, 0.96)
+    q = np.matrix("11.198908998901134, 10.34652034142302; "
+                  "10.74229967143465, 11.741057920409865; "
+                  "2.8698000059458546, 12.259732864170232")
+    v = np.matrix("11.198908998901134, 11.741057920409865, 12.259732864170232")
+    p = np.matrix("0 1 1")
+    assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
+    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
+    assert (np.array(sdp.policy) == p).all()
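These QLearning tests only pass because the NumPy RNG is seeded immediately before the learner is constructed: Q-learning explores stochastically, so the learned Q table is a deterministic function of the seed. A minimal sketch of that pattern, assuming this era of the mdptoolbox API, where the algorithm runs during construction as the asserts above imply:

```python
import numpy as np
import mdptoolbox.example
import mdptoolbox.mdp

# Build the default 3-state forest-management example used by the tests above.
P_forest, R_forest = mdptoolbox.example.forest()

np.random.seed(0)  # Q-learning is stochastic; seed first so results repeat
sdp = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.96)

print(sdp.Q)       # learned state-action value table
print(sdp.V)       # state values implied by Q
print(sdp.policy)  # greedy policy, (0, 1, 1) in the test above
```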
@@ -9,20 +9,32 @@ import numpy as np

 import mdptoolbox

-from utils import SMALLNUM, P_forest, R_forest, P_small, R_small
+from utils import SMALLNUM, STATES, P_forest, R_forest, P_forest_sparse
+from utils import R_forest_sparse, P_rand, R_rand, P_rand_sparse, R_rand_sparse
+from utils import P_small, R_small

-def test_ValueIteration_boundIter():
-    inst = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
-    assert (inst.max_iter == 28)
-
-def test_ValueIteration_iterate():
-    inst = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
+def test_ValueIteration_small():
+    sdp = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
     v = np.array((40.048625392716822, 33.65371175967546))
-    assert (np.absolute(np.array(inst.V) - v) < SMALLNUM).all()
-    assert (inst.policy == (1, 0))
-    assert (inst.iter == 26)
-
-def test_ValueIteration_exampleForest():
-    a = mdptoolbox.mdp.ValueIteration(P_forest, R_forest, 0.96)
-    assert (a.policy == np.array([0, 0, 0])).all()
-    assert a.iter == 4
+    assert (sdp.max_iter == 28)
+    assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
+    assert (sdp.policy == (1, 0))
+    assert (sdp.iter == 26)
+
+def test_ValueIteration_forest():
+    sdp = mdptoolbox.mdp.ValueIteration(P_forest, R_forest, 0.96)
+    assert (np.array(sdp.policy) == np.array([0, 0, 0])).all()
+    assert sdp.iter == 4
+
+def test_ValueIteration_forest_sparse():
+    sdp = mdptoolbox.mdp.ValueIteration(P_forest_sparse, R_forest_sparse, 0.96)
+    assert (np.array(sdp.policy) == np.array([0] * STATES)).all()
+    assert sdp.iter == 14
+
+def test_ValueIteration_rand():
+    sdp = mdptoolbox.mdp.ValueIteration(P_rand, R_rand, 0.9)
+    assert sdp.policy
+
+def test_ValueIteration_rand_sparse():
+    sdp = mdptoolbox.mdp.ValueIteration(P_rand_sparse, R_rand_sparse, 0.9)
+    assert sdp.policy
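Folding test_ValueIteration_boundIter into test_ValueIteration_small works because max_iter (the iteration bound ValueIteration computes from the discount and epsilon before iterating) and iter (the iterations actually used) are set on the same object. A sketch of the calls being asserted, with assumed fixture values, since the diff does not show utils.P_small; the classic two-state, two-action example from the MDP toolbox documentation is used as a stand-in:

```python
import numpy as np
import mdptoolbox.mdp

# Assumed stand-ins for utils.P_small / utils.R_small (not shown in this diff).
P_small = np.array([[[0.5, 0.5], [0.8, 0.2]],
                    [[0.0, 1.0], [0.1, 0.9]]])
R_small = np.array([[5, 10], [-1, 2]])

sdp = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
print(sdp.max_iter)  # bound computed from discount and epsilon (28 in the test)
print(sdp.V)         # converged value function
print(sdp.policy)    # greedy policy, (1, 0) in the test
print(sdp.iter)      # iterations actually performed (26 in the test)
```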
@@ -24,6 +24,12 @@ P_sparse[1] = sp.sparse.csr_matrix([[0, 1],[0.1, 0.9]])

 P_forest, R_forest = mdptoolbox.example.forest()
 P_forest_sparse, R_forest_sparse = mdptoolbox.example.forest(S=STATES,
                                                              is_sparse=True)

 np.random.seed(0)
 P_rand, R_rand = mdptoolbox.example.rand(STATES, ACTIONS)
-P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS, is_sparse=True)
+np.random.seed(0)
+P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS,
+                                                       is_sparse=True)
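The utils change reseeds NumPy immediately before the sparse call, so the dense and sparse random problems are drawn from identical RNG state and the corresponding tests see comparable fixtures. A sketch of the pattern, with assumed sizes since utils defines STATES and ACTIONS elsewhere:

```python
import numpy as np
import mdptoolbox.example

STATES, ACTIONS = 10, 3  # assumed sizes; utils defines its own constants

np.random.seed(0)
P_rand, R_rand = mdptoolbox.example.rand(STATES, ACTIONS)

np.random.seed(0)  # reseed so the sparse variant sees the same random stream
P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS,
                                                       is_sparse=True)
```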