Commit f9228d24 authored by Steven Cordwell's avatar Steven Cordwell
Browse files

ValueIterationGS is now usable

parent c6fd838d
...@@ -33,7 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -33,7 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
""" """
from numpy import absolute, array, diag, matrix, mean, mod, multiply, ndarray from numpy import absolute, array, diag, matrix, mean, mod, multiply, ndarray
from numpy import nonzero, ones, zeros from numpy import ones, zeros
from numpy.random import rand from numpy.random import rand
from math import ceil, log, sqrt from math import ceil, log, sqrt
from random import randint, random from random import randint, random
...@@ -807,7 +807,7 @@ class PolicyIteration(MDP): ...@@ -807,7 +807,7 @@ class PolicyIteration(MDP):
# the rows that use action a. .getA1() is used to make sure that # the rows that use action a. .getA1() is used to make sure that
# ind is a 1 dimensional vector # ind is a 1 dimensional vector
ind = nonzero(self.policy == aa)[0].getA1() ind = (self.policy == aa).nonzero()[0].getA1()
if ind.size > 0: # if no rows use action a, then no point continuing if ind.size > 0: # if no rows use action a, then no point continuing
Ppolicy[ind, :] = self.P[aa][ind, :] Ppolicy[ind, :] = self.P[aa][ind, :]
...@@ -860,7 +860,7 @@ class PolicyIteration(MDP): ...@@ -860,7 +860,7 @@ class PolicyIteration(MDP):
if V0 == 0: if V0 == 0:
policy_V = zeros((self.S, 1)) policy_V = zeros((self.S, 1))
else: else:
raise NotImplementedError("evalPolicyIterative: case V0 != 0 not implemented. Use V0=0 instead.") raise NotImplementedError("evalPolicyIterative: case V0 != 0 not implemented. Use default (V0=0) instead.")
policy_P, policy_R = self.computePpolicyPRpolicy() policy_P, policy_R = self.computePpolicyPRpolicy()
...@@ -1478,7 +1478,7 @@ class ValueIteration(MDP): ...@@ -1478,7 +1478,7 @@ class ValueIteration(MDP):
while not done: while not done:
self.iter = self.iter + 1 self.iter = self.iter + 1
Vprev = self.V Vprev = self.V.copy()
# Bellman Operator: compute policy and value functions # Bellman Operator: compute policy and value functions
self.policy, self.V = self.bellmanOperator() self.policy, self.V = self.bellmanOperator()
...@@ -1557,18 +1557,16 @@ class ValueIterationGS(ValueIteration): ...@@ -1557,18 +1557,16 @@ class ValueIterationGS(ValueIteration):
self.time = time() self.time = time()
#Q = array(())
while not done: while not done:
self.iter = self.iter + 1 self.iter = self.iter + 1
Vprev = self.V Vprev = self.V.copy()
Q = array(())
for s in range(self.S): for s in range(self.S):
Q = []
for a in range(self.A): for a in range(self.A):
Q[a] = self.R[s,a] + self.discount * self.P[a][s,:] * self.V Q.append(float(self.R[s, a] + self.discount * self.P[a][s, :] * self.V))
self.V[s] = max(Q) self.V[s] = max(Q)
variation = getSpan(self.V - Vprev) variation = getSpan(self.V - Vprev)
...@@ -1583,13 +1581,19 @@ class ValueIterationGS(ValueIteration): ...@@ -1583,13 +1581,19 @@ class ValueIterationGS(ValueIteration):
elif self.iter == self.max_iter: elif self.iter == self.max_iter:
done = True done = True
if self.verbose: if self.verbose:
print('MDP Toolbox : iterations stopped by maximum number of iteration condition') print('MDP Toolbox : iterations stopped by maximum number of iteration condition')
self.policy = []
for s in range(self.S): for s in range(self.S):
Q = zeros(self.A)
for a in range(self.A): for a in range(self.A):
Q[a] = self.R[s,a] + self.P[a][s,:] * self.discount * self.V Q[a] = self.R[s,a] + self.P[a][s,:] * self.discount * self.V
self.V[s], self.policy[s,1] = max(Q) self.V[s] = Q.max()
self.policy.append(int(Q.argmax()))
self.time = time() - self.time self.time = time() - self.time
self.V = tuple(array(self.V).reshape(self.S).tolist())
self.policy = tuple(self.policy)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from distutils.core import setup from distutils.core import setup
setup(name="PyMDPtoolbox", setup(name="PyMDPtoolbox",
version="0.7", version="0.8",
description="Python Markov Decision Problem Toolbox", description="Python Markov Decision Problem Toolbox",
author="Steven Cordwell", author="Steven Cordwell",
author_email="steven.cordwell@uqconnect.edu.au", author_email="steven.cordwell@uqconnect.edu.au",
......
...@@ -7,6 +7,7 @@ Created on Sun May 27 23:16:57 2012 ...@@ -7,6 +7,7 @@ Created on Sun May 27 23:16:57 2012
from mdp import check, checkSquareStochastic, exampleForest, exampleRand, MDP from mdp import check, checkSquareStochastic, exampleForest, exampleRand, MDP
from mdp import PolicyIteration, ValueIteration, ValueIterationGS from mdp import PolicyIteration, ValueIteration, ValueIterationGS
from numpy import absolute, array, eye, matrix, zeros from numpy import absolute, array, eye, matrix, zeros
from numpy.random import rand from numpy.random import rand
from scipy.sparse import eye as speye from scipy.sparse import eye as speye
...@@ -150,27 +151,27 @@ def test_exampleForest_check(): ...@@ -150,27 +151,27 @@ def test_exampleForest_check():
# exampleRand # exampleRand
P, R = exampleRand(STATES, ACTIONS) Pr, Rr = exampleRand(STATES, ACTIONS)
def test_exampleRand_dense_P_shape(): def test_exampleRand_dense_P_shape():
assert (P.shape == (ACTIONS, STATES, STATES)) assert (Pr.shape == (ACTIONS, STATES, STATES))
def test_exampleRand_dense_R_shape(): def test_exampleRand_dense_R_shape():
assert (R.shape == (ACTIONS, STATES, STATES)) assert (Rr.shape == (ACTIONS, STATES, STATES))
def test_exampleRand_dense_check(): def test_exampleRand_dense_check():
assert check(P, R) == None assert check(Pr, Rr) == None
P, R = exampleRand(STATES, ACTIONS, is_sparse=True) Prs, Rrs = exampleRand(STATES, ACTIONS, is_sparse=True)
def test_exampleRand_sparse_P_shape(): def test_exampleRand_sparse_P_shape():
assert (P.shape == (ACTIONS, )) assert (Prs.shape == (ACTIONS, ))
def test_exampleRand_sparse_R_shape(): def test_exampleRand_sparse_R_shape():
assert (R.shape == (ACTIONS, )) assert (Rrs.shape == (ACTIONS, ))
def test_exampleRand_sparse_check(): def test_exampleRand_sparse_check():
assert check(P, R) == None assert check(Prs, Rrs) == None
P = array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]]) P = array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]])
R = array([[5, 10], [-1, 2]]) R = array([[5, 10], [-1, 2]])
...@@ -220,7 +221,7 @@ def test_ValueIteration_boundIter(): ...@@ -220,7 +221,7 @@ def test_ValueIteration_boundIter():
def test_ValueIteration_iterate(): def test_ValueIteration_iterate():
inst = ValueIteration(P, R, 0.9, 0.01) inst = ValueIteration(P, R, 0.9, 0.01)
inst.iterate() inst.iterate()
assert (inst.value == (40.048625392716822, 33.65371175967546)) assert (inst.V == (40.048625392716822, 33.65371175967546))
assert (inst.policy == (1, 0)) assert (inst.policy == (1, 0))
assert (inst.iter == 26) assert (inst.iter == 26)
...@@ -255,9 +256,9 @@ def test_PolicyIteration_evalPolicyIterative_exampleForest(): ...@@ -255,9 +256,9 @@ def test_PolicyIteration_evalPolicyIterative_exampleForest():
v1 = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304') v1 = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304')
p = matrix('0; 1; 0') p = matrix('0; 1; 0')
a = PolicyIteration(Pf, Rf, 0.9) a = PolicyIteration(Pf, Rf, 0.9)
assert (absolute(a.value - v0) < SMALLNUM).all() assert (absolute(a.V - v0) < SMALLNUM).all()
a.evalPolicyIterative() a.evalPolicyIterative()
assert (absolute(a.value - v1) < SMALLNUM).all() assert (absolute(a.V - v1) < SMALLNUM).all()
assert (a.policy == p).all() assert (a.policy == p).all()
def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest(): def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest():
...@@ -267,15 +268,15 @@ def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest(): ...@@ -267,15 +268,15 @@ def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest():
a.evalPolicyIterative() a.evalPolicyIterative()
policy, value = a.bellmanOperator() policy, value = a.bellmanOperator()
assert (policy == p).all() assert (policy == p).all()
assert (absolute(a.value - v) < SMALLNUM).all() assert (absolute(a.V - v) < SMALLNUM).all()
def test_PolicyIteration_iterative_exampleForest(): def test_PolicyIteration_iterative_exampleForest():
a = PolicyIteration(Pf, Rf, 0.9, eval_type=1) a = PolicyIteration(Pf, Rf, 0.9, eval_type=1)
V = matrix('26.2439058351861 29.4839058351861 33.4839058351861') v = matrix('26.2439058351861 29.4839058351861 33.4839058351861')
p = matrix('0 0 0') p = matrix('0 0 0')
itr = 2 itr = 2
a.iterate() a.iterate()
assert (absolute(array(a.value) - V) < SMALLNUM).all() assert (absolute(array(a.V) - v) < SMALLNUM).all()
assert (array(a.policy) == p).all() assert (array(a.policy) == p).all()
assert a.iter == itr assert a.iter == itr
...@@ -283,15 +284,15 @@ def test_PolicyIteration_evalPolicyMatrix_exampleForest(): ...@@ -283,15 +284,15 @@ def test_PolicyIteration_evalPolicyMatrix_exampleForest():
v_pol = matrix('4.47513812154696; 5.02762430939227; 23.17243384704857') v_pol = matrix('4.47513812154696; 5.02762430939227; 23.17243384704857')
a = PolicyIteration(Pf, Rf, 0.9) a = PolicyIteration(Pf, Rf, 0.9)
a.evalPolicyMatrix() a.evalPolicyMatrix()
assert (absolute(a.value - v_pol) < SMALLNUM).all() assert (absolute(a.V - v_pol) < SMALLNUM).all()
def test_PolicyIteration_matrix_exampleForest(): def test_PolicyIteration_matrix_exampleForest():
a = PolicyIteration(Pf, Rf, 0.9) a = PolicyIteration(Pf, Rf, 0.9)
V = matrix('26.2440000000000 29.4840000000000 33.4840000000000') v = matrix('26.2440000000000 29.4840000000000 33.4840000000000')
p = matrix('0 0 0') p = matrix('0 0 0')
itr = 2 itr = 2
a.iterate() a.iterate()
assert (absolute(array(a.value) - V) < SMALLNUM).all() assert (absolute(array(a.V) - v) < SMALLNUM).all()
assert (array(a.policy) == p).all() assert (array(a.policy) == p).all()
assert a.iter == itr assert a.iter == itr
...@@ -300,10 +301,12 @@ def test_PolicyIteration_matrix_exampleForest(): ...@@ -300,10 +301,12 @@ def test_PolicyIteration_matrix_exampleForest():
def test_ValueIterationGS_exampleForest(): def test_ValueIterationGS_exampleForest():
a = ValueIterationGS(Pf, Rf, 0.9) a = ValueIterationGS(Pf, Rf, 0.9)
p = matrix('0 0 0') p = matrix('0 0 0')
v = matrix('25.5833879767579 28.8306546355469 32.8306546355469')
itr = 33 itr = 33
a.iterate() a.iterate()
assert (array(a.policy) == p).all() assert (array(a.policy) == p).all()
assert a.iter == itr assert a.iter == itr
assert (absolute(array(a.V) - v) < SMALLNUM).all()
#def test_JacksCarRental(): #def test_JacksCarRental():
# S = 21 ** 2 # S = 21 ** 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment