Commit 8a57e24e authored by Steven Cordwell's avatar Steven Cordwell
Browse files

clean up and extend the tests

parent 99ac82ac
...@@ -5,35 +5,45 @@ Created on Sat Aug 24 15:04:06 2013 ...@@ -5,35 +5,45 @@ Created on Sat Aug 24 15:04:06 2013
@author: steve @author: steve
""" """
from random import seed as randseed
import numpy as np import numpy as np
import mdptoolbox import mdptoolbox
from utils import SMALLNUM, P_forest, R_forest, P_small, R_small from utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
from utils import R_forest_sparse, P_small, R_small
def test_QLearning_small(): def test_QLearning_small():
randseed(0)
np.random.seed(0) np.random.seed(0)
a = mdptoolbox.mdp.QLearning(P_small, R_small, 0.9) sdp = mdptoolbox.mdp.QLearning(P_small, R_small, 0.9)
q = np.matrix("39.9336909966907 43.175433380901488; " q = np.matrix("33.330108655211646, 40.82109564847122; "
"36.943942243204454 35.42568055796341") "34.37431040682546, 29.672368452303164")
v = np.matrix("43.17543338090149, 36.943942243204454") v = np.matrix("40.82109564847122, 34.37431040682546")
p = np.matrix("1 0") p = np.matrix("1 0")
assert (np.absolute(a.Q - q) < SMALLNUM).all() assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
assert (np.absolute(np.array(a.V) - v) < SMALLNUM).all() assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
assert (np.array(a.policy) == p).all() assert (np.array(sdp.policy) == p).all()
def test_QLearning_forest():
np.random.seed(0)
sdp = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.96)
q = np.matrix("11.198908998901134, 10.34652034142302; "
"10.74229967143465, 11.741057920409865; "
"2.8698000059458546, 12.259732864170232")
v = np.matrix("11.198908998901134, 11.741057920409865, 12.259732864170232")
p = np.matrix("0 1 1")
assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
assert (np.array(sdp.policy) == p).all()
def test_QLearning_exampleForest(): #FIXME: This is wrong as the number of states in this is util.STATES, not 3
randseed(0) def test_QLearning_forest_sparse():
np.random.seed(0) np.random.seed(0)
a = mdptoolbox.mdp.QLearning(P_forest, R_forest, 0.9) sdp = mdptoolbox.mdp.QLearning(P_forest_sparse, R_forest_sparse, 0.96)
q = np.matrix("26.209597296761608, 18.108253687076136; " q = np.matrix("11.198908998901134, 10.34652034142302; "
"29.54356354184715, 18.116618509050486; " "10.74229967143465, 11.741057920409865; "
"33.61440797109655, 25.1820819845856") "2.8698000059458546, 12.259732864170232")
v = np.matrix("26.209597296761608, 29.54356354184715, 33.61440797109655") v = np.matrix("11.198908998901134, 11.741057920409865, 12.259732864170232")
p = np.matrix("0 0 0") p = np.matrix("0 1 1")
assert (np.absolute(a.Q - q) < SMALLNUM).all() assert (np.absolute(sdp.Q - q) < SMALLNUM).all()
assert (np.absolute(np.array(a.V) - v) < SMALLNUM).all() assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
assert (np.array(a.policy) == p).all() assert (np.array(sdp.policy) == p).all()
\ No newline at end of file
...@@ -9,20 +9,32 @@ import numpy as np ...@@ -9,20 +9,32 @@ import numpy as np
import mdptoolbox import mdptoolbox
from utils import SMALLNUM, P_forest, R_forest, P_small, R_small from utils import SMALLNUM, STATES, P_forest, R_forest, P_forest_sparse
from utils import R_forest_sparse, P_rand, R_rand, P_rand_sparse, R_rand_sparse
from utils import P_small, R_small
def test_ValueIteration_boundIter(): def test_ValueIteration_small():
inst = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01) sdp = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
assert (inst.max_iter == 28)
def test_ValueIteration_iterate():
inst = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9, 0.01)
v = np.array((40.048625392716822, 33.65371175967546)) v = np.array((40.048625392716822, 33.65371175967546))
assert (np.absolute(np.array(inst.V) - v) < SMALLNUM).all() assert (sdp.max_iter == 28)
assert (inst.policy == (1, 0)) assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
assert (inst.iter == 26) assert (sdp.policy == (1, 0))
assert (sdp.iter == 26)
def test_ValueIteration_exampleForest():
a = mdptoolbox.mdp.ValueIteration(P_forest, R_forest, 0.96) def test_ValueIteration_forest():
assert (a.policy == np.array([0, 0, 0])).all() sdp = mdptoolbox.mdp.ValueIteration(P_forest, R_forest, 0.96)
assert a.iter == 4 assert (np.array(sdp.policy) == np.array([0, 0, 0])).all()
assert sdp.iter == 4
def test_ValueIteration_forest_sparse():
sdp = mdptoolbox.mdp.ValueIteration(P_forest_sparse, R_forest_sparse, 0.96)
assert (np.array(sdp.policy) == np.array([0] * STATES)).all()
assert sdp.iter == 14
def test_ValueIteration_rand():
sdp = mdptoolbox.mdp.ValueIteration(P_rand, R_rand, 0.9)
assert sdp.policy
def test_ValueIteration_rand_sparse():
sdp = mdptoolbox.mdp.ValueIteration(P_rand_sparse, R_rand_sparse, 0.9)
assert sdp.policy
...@@ -24,6 +24,12 @@ P_sparse[1] = sp.sparse.csr_matrix([[0, 1],[0.1, 0.9]]) ...@@ -24,6 +24,12 @@ P_sparse[1] = sp.sparse.csr_matrix([[0, 1],[0.1, 0.9]])
P_forest, R_forest = mdptoolbox.example.forest() P_forest, R_forest = mdptoolbox.example.forest()
P_forest_sparse, R_forest_sparse = mdptoolbox.example.forest(S=STATES,
is_sparse=True)
np.random.seed(0)
P_rand, R_rand = mdptoolbox.example.rand(STATES, ACTIONS) P_rand, R_rand = mdptoolbox.example.rand(STATES, ACTIONS)
P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS, is_sparse=True) np.random.seed(0)
P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS,
is_sparse=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment