Commit c82e6ae4 by Steven Cordwell

### reduce some code duplication in unit tests

parent 7874abac
 ... ... @@ -1539,37 +1539,10 @@ class ValueIterationGS(ValueIteration): def __init__(self, transitions, reward, discount, epsilon=0.01, max_iter=10, initial_value=0): """""" MDP.__init__(self, transitions, reward, discount, max_iter) # initialization of optional arguments if (initial_value == 0): self.V = matrix(zeros((self.S, 1))) else: if (initial_value.size != self.S): raise ValueError("The initial value must be length S") self.V = matrix(initial_value) if epsilon <= 0: raise ValueError("epsilon must be greater than 0") if discount == 1: print('PyMDPtoolbox WARNING: check conditions of convergence.' 'With no discount, convergence is not always assumed.') if (discount < 1): # compute a bound for the number of iterations self.boundIter(epsilon) print('MDP Toolbox WARNING: max_iter is bounded by %s') % self.max_iter # computation of threshold of variation for V for an epsilon-optimal policy self.thresh = epsilon * (1 - self.discount) / self.discount else: # discount == 1 # threshold of variation for V for an epsilon-optimal policy self.thresh = epsilon self.iter = 0 ValueIteration.__init__(self, transitions, reward, discount, epsilon, max_iter, initial_value) def iterate(self): """""" V = self.V done = False ... ... @@ -1578,17 +1551,21 @@ class ValueIterationGS(ValueIteration): self.time = time() #Q = array(()) while not done: self.iter = self.iter + 1 Vprev = self.V Q = array(()) for s in range(self.S): for a in range(self.A): Q[a] = self.R[s,a] + self.discount * self.P[a][s,:] * self.V self.V[s] = max(Q) variation = getSpan(V - Vprev) variation = getSpan(self.V - Vprev) if self.verbose: print(" %s %s" % (self.iter, variation)) ... ...
 ... ... @@ -6,7 +6,7 @@ Created on Sun May 27 23:16:57 2012 """ from mdp import check, checkSquareStochastic, exampleForest, exampleRand, MDP from mdp import PolicyIteration, ValueIteration from mdp import PolicyIteration, ValueIteration, ValueIterationGS from numpy import absolute, array, eye, matrix, zeros from numpy.random import rand from scipy.sparse import eye as speye ... ... @@ -129,16 +129,18 @@ def test_checkSquareStochastic_eye_sparse(): assert checkSquareStochastic(P) == None # exampleForest Pf, Rf = exampleForest() def test_exampleForest_shape(): P, R = exampleForest() assert (P == array([[[0.1, 0.9, 0.0], def test_exampleForest_P_shape(): assert (Pf == array([[[0.1, 0.9, 0.0], [0.1, 0.0, 0.9], [0.1, 0.0, 0.9]], [[1, 0, 0], [1, 0, 0], [1, 0, 0]]])).all() assert (R == array([[0, 0], def test_exampleForest_R_shape(): assert (Rf == array([[0, 0], [0, 1], [4, 2]])).all() ... ... @@ -148,22 +150,26 @@ def test_exampleForest_check(): # exampleRand def test_exampleRand_dense_shape(): P, R = exampleRand(STATES, ACTIONS) P, R = exampleRand(STATES, ACTIONS) def test_exampleRand_dense_P_shape(): assert (P.shape == (ACTIONS, STATES, STATES)) def test_exampleRand_dense_R_shape(): assert (R.shape == (ACTIONS, STATES, STATES)) def test_exampleRand_dense_check(): P, R = exampleRand(STATES, ACTIONS) assert check(P, R) == None def test_exampleRand_sparse_shape(): P, R = exampleRand(STATES, ACTIONS, is_sparse=True) P, R = exampleRand(STATES, ACTIONS, is_sparse=True) def test_exampleRand_sparse_P_shape(): assert (P.shape == (ACTIONS, )) def test_exampleRand_sparse_R_shape(): assert (R.shape == (ACTIONS, )) def test_exampleRand_sparse_check(): P, R = exampleRand(STATES, ACTIONS, is_sparse=True) assert check(P, R) == None P = array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]]) ... ... @@ -233,14 +239,12 @@ def test_PolicyIteration_init_policy0(): assert (a.policy == p).all() def test_PolicyIteration_init_policy0_exampleForest(): P, R = exampleForest() a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) p = matrix('0; 1; 0') assert (a.policy == p).all() def test_PolicyIteration_computePpolicyPRpolicy_exampleForest(): P, R = exampleForest() a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) P1 = matrix('0.1 0.9 0; 1 0 0; 0.1 0 0.9') R1 = matrix('0; 1; 4') Ppolicy, Rpolicy = a.computePpolicyPRpolicy() ... ... @@ -248,29 +252,26 @@ def test_PolicyIteration_computePpolicyPRpolicy_exampleForest(): assert (absolute(Rpolicy - R1) < SMALLNUM).all() def test_PolicyIteration_evalPolicyIterative_exampleForest(): P, R = exampleForest() v0 = matrix('0; 0; 0') v1 = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304') p = matrix('0; 1; 0') a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) assert (absolute(a.value - v0) < SMALLNUM).all() a.evalPolicyIterative() assert (absolute(a.value - v1) < SMALLNUM).all() assert (a.policy == p).all() def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest(): P, R = exampleForest() v = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304') p = matrix('0; 0; 0') a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) a.evalPolicyIterative() policy, value = a.bellmanOperator() assert (policy == p).all() assert (absolute(a.value - v) < SMALLNUM).all() def test_PolicyIteration_iterative_exampleForest(): P, R = exampleForest() a = PolicyIteration(P, R, 0.9, eval_type=1) a = PolicyIteration(Pf, Rf, 0.9, eval_type=1) V = matrix('26.2439058351861 29.4839058351861 33.4839058351861') p = matrix('0 0 0') itr = 2 ... ... @@ -280,15 +281,13 @@ def test_PolicyIteration_iterative_exampleForest(): assert a.iter == itr def test_PolicyIteration_evalPolicyMatrix_exampleForest(): P, R = exampleForest() v_pol = matrix('4.47513812154696; 5.02762430939227; 23.17243384704857') a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) a.evalPolicyMatrix() assert (absolute(a.value - v_pol) < SMALLNUM).all() def test_PolicyIteration_matrix_exampleForest(): P, R = exampleForest() a = PolicyIteration(P, R, 0.9) a = PolicyIteration(Pf, Rf, 0.9) V = matrix('26.2440000000000 29.4840000000000 33.4840000000000') p = matrix('0 0 0') itr = 2 ... ... @@ -297,6 +296,8 @@ def test_PolicyIteration_matrix_exampleForest(): assert (array(a.policy) == p).all() assert a.iter == itr def test_ValueIterationGS(): #def test_JacksCarRental(): # S = 21 ** 2 # A = 11 ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!