Commit c82e6ae4 authored by Steven Cordwell's avatar Steven Cordwell
Browse files

reduce some code duplication in unit tests

parent 7874abac
......@@ -1539,37 +1539,10 @@ class ValueIterationGS(ValueIteration):
def __init__(self, transitions, reward, discount, epsilon=0.01, max_iter=10, initial_value=0):
""""""
MDP.__init__(self, transitions, reward, discount, max_iter)
# initialization of optional arguments
if (initial_value == 0):
self.V = matrix(zeros((self.S, 1)))
else:
if (initial_value.size != self.S):
raise ValueError("The initial value must be length S")
self.V = matrix(initial_value)
if epsilon <= 0:
raise ValueError("epsilon must be greater than 0")
if discount == 1:
print('PyMDPtoolbox WARNING: check conditions of convergence.'
'With no discount, convergence is not always assumed.')
if (discount < 1):
# compute a bound for the number of iterations
self.boundIter(epsilon)
print('MDP Toolbox WARNING: max_iter is bounded by %s') % self.max_iter
# computation of threshold of variation for V for an epsilon-optimal policy
self.thresh = epsilon * (1 - self.discount) / self.discount
else: # discount == 1
# threshold of variation for V for an epsilon-optimal policy
self.thresh = epsilon
self.iter = 0
ValueIteration.__init__(self, transitions, reward, discount, epsilon, max_iter, initial_value)
def iterate(self):
""""""
V = self.V
done = False
......@@ -1578,17 +1551,21 @@ class ValueIterationGS(ValueIteration):
self.time = time()
#Q = array(())
while not done:
self.iter = self.iter + 1
Vprev = self.V
Q = array(())
for s in range(self.S):
for a in range(self.A):
Q[a] = self.R[s,a] + self.discount * self.P[a][s,:] * self.V
self.V[s] = max(Q)
variation = getSpan(V - Vprev)
variation = getSpan(self.V - Vprev)
if self.verbose:
print(" %s %s" % (self.iter, variation))
......
......@@ -6,7 +6,7 @@ Created on Sun May 27 23:16:57 2012
"""
from mdp import check, checkSquareStochastic, exampleForest, exampleRand, MDP
from mdp import PolicyIteration, ValueIteration
from mdp import PolicyIteration, ValueIteration, ValueIterationGS
from numpy import absolute, array, eye, matrix, zeros
from numpy.random import rand
from scipy.sparse import eye as speye
......@@ -129,16 +129,18 @@ def test_checkSquareStochastic_eye_sparse():
assert checkSquareStochastic(P) == None
# exampleForest
Pf, Rf = exampleForest()
def test_exampleForest_shape():
P, R = exampleForest()
assert (P == array([[[0.1, 0.9, 0.0],
def test_exampleForest_P_shape():
assert (Pf == array([[[0.1, 0.9, 0.0],
[0.1, 0.0, 0.9],
[0.1, 0.0, 0.9]],
[[1, 0, 0],
[1, 0, 0],
[1, 0, 0]]])).all()
assert (R == array([[0, 0],
def test_exampleForest_R_shape():
assert (Rf == array([[0, 0],
[0, 1],
[4, 2]])).all()
......@@ -148,22 +150,26 @@ def test_exampleForest_check():
# exampleRand
def test_exampleRand_dense_shape():
P, R = exampleRand(STATES, ACTIONS)
P, R = exampleRand(STATES, ACTIONS)
def test_exampleRand_dense_P_shape():
assert (P.shape == (ACTIONS, STATES, STATES))
def test_exampleRand_dense_R_shape():
assert (R.shape == (ACTIONS, STATES, STATES))
def test_exampleRand_dense_check():
P, R = exampleRand(STATES, ACTIONS)
assert check(P, R) == None
def test_exampleRand_sparse_shape():
P, R = exampleRand(STATES, ACTIONS, is_sparse=True)
P, R = exampleRand(STATES, ACTIONS, is_sparse=True)
def test_exampleRand_sparse_P_shape():
assert (P.shape == (ACTIONS, ))
def test_exampleRand_sparse_R_shape():
assert (R.shape == (ACTIONS, ))
def test_exampleRand_sparse_check():
P, R = exampleRand(STATES, ACTIONS, is_sparse=True)
assert check(P, R) == None
P = array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]])
......@@ -233,14 +239,12 @@ def test_PolicyIteration_init_policy0():
assert (a.policy == p).all()
def test_PolicyIteration_init_policy0_exampleForest():
P, R = exampleForest()
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
p = matrix('0; 1; 0')
assert (a.policy == p).all()
def test_PolicyIteration_computePpolicyPRpolicy_exampleForest():
P, R = exampleForest()
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
P1 = matrix('0.1 0.9 0; 1 0 0; 0.1 0 0.9')
R1 = matrix('0; 1; 4')
Ppolicy, Rpolicy = a.computePpolicyPRpolicy()
......@@ -248,29 +252,26 @@ def test_PolicyIteration_computePpolicyPRpolicy_exampleForest():
assert (absolute(Rpolicy - R1) < SMALLNUM).all()
def test_PolicyIteration_evalPolicyIterative_exampleForest():
P, R = exampleForest()
v0 = matrix('0; 0; 0')
v1 = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304')
p = matrix('0; 1; 0')
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
assert (absolute(a.value - v0) < SMALLNUM).all()
a.evalPolicyIterative()
assert (absolute(a.value - v1) < SMALLNUM).all()
assert (a.policy == p).all()
def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest():
P, R = exampleForest()
v = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304')
p = matrix('0; 0; 0')
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
a.evalPolicyIterative()
policy, value = a.bellmanOperator()
assert (policy == p).all()
assert (absolute(a.value - v) < SMALLNUM).all()
def test_PolicyIteration_iterative_exampleForest():
P, R = exampleForest()
a = PolicyIteration(P, R, 0.9, eval_type=1)
a = PolicyIteration(Pf, Rf, 0.9, eval_type=1)
V = matrix('26.2439058351861 29.4839058351861 33.4839058351861')
p = matrix('0 0 0')
itr = 2
......@@ -280,15 +281,13 @@ def test_PolicyIteration_iterative_exampleForest():
assert a.iter == itr
def test_PolicyIteration_evalPolicyMatrix_exampleForest():
P, R = exampleForest()
v_pol = matrix('4.47513812154696; 5.02762430939227; 23.17243384704857')
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
a.evalPolicyMatrix()
assert (absolute(a.value - v_pol) < SMALLNUM).all()
def test_PolicyIteration_matrix_exampleForest():
P, R = exampleForest()
a = PolicyIteration(P, R, 0.9)
a = PolicyIteration(Pf, Rf, 0.9)
V = matrix('26.2440000000000 29.4840000000000 33.4840000000000')
p = matrix('0 0 0')
itr = 2
......@@ -297,6 +296,8 @@ def test_PolicyIteration_matrix_exampleForest():
assert (array(a.policy) == p).all()
assert a.iter == itr
def test_ValueIterationGS():
#def test_JacksCarRental():
# S = 21 ** 2
# A = 11
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment