Commit f9228d24 authored by Steven Cordwell

ValueIterationGS is now usable

parent c6fd838d
@@ -33,7 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 from numpy import absolute, array, diag, matrix, mean, mod, multiply, ndarray
-from numpy import nonzero, ones, zeros
+from numpy import ones, zeros
 from numpy.random import rand
 from math import ceil, log, sqrt
 from random import randint, random
@@ -807,7 +807,7 @@ class PolicyIteration(MDP):
             # the rows that use action a. .getA1() is used to make sure that
             # ind is a 1 dimensional vector
-            ind = nonzero(self.policy == aa)[0].getA1()
+            ind = (self.policy == aa).nonzero()[0].getA1()
             if ind.size > 0: # if no rows use action a, then no point continuing
                 Ppolicy[ind, :] = self.P[aa][ind, :]
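Note on this hunk: the free function nonzero() is swapped for the equivalent ndarray method so the extra numpy import can be dropped. For a numpy matrix, (m == value).nonzero() returns a tuple of index matrices, and .getA1() flattens the row indices into a 1-D array. A standalone sketch of that idiom (not part of the commit):

from numpy import matrix

policy = matrix('0; 1; 0; 1')        # column matrix of chosen actions
aa = 1                               # action of interest
ind = (policy == aa).nonzero()[0]    # row indices, still a 1 x k matrix
print(ind.getA1())                   # flattened to a 1-D ndarray: [1 3]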
@@ -860,7 +860,7 @@ class PolicyIteration(MDP):
         if V0 == 0:
             policy_V = zeros((self.S, 1))
         else:
-            raise NotImplementedError("evalPolicyIterative: case V0 != 0 not implemented. Use V0=0 instead.")
+            raise NotImplementedError("evalPolicyIterative: case V0 != 0 not implemented. Use default (V0=0) instead.")

         policy_P, policy_R = self.computePpolicyPRpolicy()
@@ -1478,7 +1478,7 @@ class ValueIteration(MDP):
         while not done:
             self.iter = self.iter + 1
-            Vprev = self.V
+            Vprev = self.V.copy()
             # Bellman Operator: compute policy and value functions
             self.policy, self.V = self.bellmanOperator()
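Note on this hunk: Vprev = self.V only binds a second name to the same array, so if self.V is ever mutated in place (as ValueIterationGS does below) the "previous" values move with it and the measured span is always zero; .copy() takes an independent snapshot. A minimal demonstration, independent of the toolbox:

from numpy import zeros

V = zeros(3)
Vprev = V                 # alias: both names refer to one array
V[0] = 1.0
print(Vprev[0])           # 1.0 -- the "previous" value silently moved too

V = zeros(3)
Vprev = V.copy()          # independent snapshot
V[0] = 1.0
print(Vprev[0])           # 0.0 -- snapshot preserved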
@@ -1557,18 +1557,16 @@ class ValueIterationGS(ValueIteration):
         self.time = time()
         #Q = array(())
         while not done:
             self.iter = self.iter + 1
-            Vprev = self.V
-            Q = array(())
+            Vprev = self.V.copy()
             for s in range(self.S):
+                Q = []
                 for a in range(self.A):
-                    Q[a] = self.R[s,a] + self.discount * self.P[a][s,:] * self.V
+                    Q.append(float(self.R[s, a] + self.discount * self.P[a][s, :] * self.V))
                 self.V[s] = max(Q)
             variation = getSpan(self.V - Vprev)
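Note on this hunk: the rewritten loop is the Gauss-Seidel variant of value iteration: each V[s] is overwritten as soon as it is computed, so later states in the same sweep already see the updated values, which typically speeds convergence. A self-contained sketch of one sweep, assuming P is a list of S x S arrays (one per action) and R is an S x A array; the names and shapes are illustrative, not the toolbox's internals:

import numpy as np

def gauss_seidel_sweep(P, R, V, discount):
    # One in-place Bellman sweep over V; returns the span of the change.
    S, A = R.shape
    Vprev = V.copy()                  # snapshot for the stopping test
    for s in range(S):
        # Q[a] reads V as it currently stands, including states
        # already updated earlier in this same sweep
        Q = [R[s, a] + discount * P[a][s, :].dot(V) for a in range(A)]
        V[s] = max(Q)
    diff = V - Vprev
    return diff.max() - diff.min()    # span, as in getSpan()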
@@ -1583,13 +1581,19 @@ class ValueIterationGS(ValueIteration):
             elif self.iter == self.max_iter:
                 done = True
                 if self.verbose:
                     print('MDP Toolbox : iterations stopped by maximum number of iteration condition')

+        self.policy = []
         for s in range(self.S):
             Q = zeros(self.A)
             for a in range(self.A):
                 Q[a] = self.R[s,a] + self.P[a][s,:] * self.discount * self.V
-            self.V[s], self.policy[s,1] = max(Q)
+            self.V[s] = Q.max()
+            self.policy.append(int(Q.argmax()))

         self.time = time() - self.time
+        self.V = tuple(array(self.V).reshape(self.S).tolist())
+        self.policy = tuple(self.policy)
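Note on this hunk: the removed line self.V[s], self.policy[s,1] = max(Q) was a leftover of MATLAB's two-output max, which has no Python equivalent; the value and the maximising action must be taken separately, which is what the Q.max()/Q.argmax() pair now does. A sketch of the same greedy extraction under the shapes assumed above:

import numpy as np

def greedy_policy(P, R, V, discount):
    # Recover the greedy policy and its values from a value function V.
    S, A = R.shape
    values, policy = [], []
    for s in range(S):
        Q = np.array([R[s, a] + discount * P[a][s, :].dot(V) for a in range(A)])
        values.append(Q.max())            # best attainable value in s
        policy.append(int(Q.argmax()))    # an action achieving it
    return tuple(values), tuple(policy)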
@@ -3,7 +3,7 @@
 from distutils.core import setup

 setup(name="PyMDPtoolbox",
-      version="0.7",
+      version="0.8",
       description="Python Markov Decision Problem Toolbox",
       author="Steven Cordwell",
       author_email="steven.cordwell@uqconnect.edu.au",
@@ -7,6 +7,7 @@ Created on Sun May 27 23:16:57 2012
 from mdp import check, checkSquareStochastic, exampleForest, exampleRand, MDP
 from mdp import PolicyIteration, ValueIteration, ValueIterationGS
 from numpy import absolute, array, eye, matrix, zeros
+from numpy.random import rand
 from scipy.sparse import eye as speye
@@ -150,27 +151,27 @@ def test_exampleForest_check():
 # exampleRand
-P, R = exampleRand(STATES, ACTIONS)
+Pr, Rr = exampleRand(STATES, ACTIONS)

 def test_exampleRand_dense_P_shape():
-    assert (P.shape == (ACTIONS, STATES, STATES))
+    assert (Pr.shape == (ACTIONS, STATES, STATES))

 def test_exampleRand_dense_R_shape():
-    assert (R.shape == (ACTIONS, STATES, STATES))
+    assert (Rr.shape == (ACTIONS, STATES, STATES))

 def test_exampleRand_dense_check():
-    assert check(P, R) == None
+    assert check(Pr, Rr) == None

-P, R = exampleRand(STATES, ACTIONS, is_sparse=True)
+Prs, Rrs = exampleRand(STATES, ACTIONS, is_sparse=True)

 def test_exampleRand_sparse_P_shape():
-    assert (P.shape == (ACTIONS, ))
+    assert (Prs.shape == (ACTIONS, ))

 def test_exampleRand_sparse_R_shape():
-    assert (R.shape == (ACTIONS, ))
+    assert (Rrs.shape == (ACTIONS, ))

 def test_exampleRand_sparse_check():
-    assert check(P, R) == None
+    assert check(Prs, Rrs) == None

 P = array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]])
 R = array([[5, 10], [-1, 2]])
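Note on this hunk: the renames (P, R to Pr, Rr and Prs, Rrs) matter because the module rebinds P and R to the small two-state example just below, and test functions look module globals up only when they run, so the earlier tests would otherwise assert against the wrong arrays. A minimal illustration of that late binding, independent of the toolbox:

X = "first"

def test_uses_x():
    assert X == "first"    # X is resolved at call time, not at definition time

X = "second"

test_uses_x()              # AssertionError -- the later rebinding wins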
@@ -220,7 +221,7 @@ def test_ValueIteration_boundIter():
 def test_ValueIteration_iterate():
     inst = ValueIteration(P, R, 0.9, 0.01)
     inst.iterate()
-    assert (inst.value == (40.048625392716822, 33.65371175967546))
+    assert (inst.V == (40.048625392716822, 33.65371175967546))
     assert (inst.policy == (1, 0))
     assert (inst.iter == 26)
@@ -255,9 +256,9 @@ def test_PolicyIteration_evalPolicyIterative_exampleForest():
     v1 = matrix('4.47504640074458; 5.02753258879703; 23.17234211944304')
     p = matrix('0; 1; 0')
     a = PolicyIteration(Pf, Rf, 0.9)
-    assert (absolute(a.value - v0) < SMALLNUM).all()
+    assert (absolute(a.V - v0) < SMALLNUM).all()
     a.evalPolicyIterative()
-    assert (absolute(a.value - v1) < SMALLNUM).all()
+    assert (absolute(a.V - v1) < SMALLNUM).all()
     assert (a.policy == p).all()

 def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest():
@@ -267,15 +268,15 @@ def test_PolicyIteration_evalPolicyIterative_bellmanOperator_exampleForest():
     a.evalPolicyIterative()
     policy, value = a.bellmanOperator()
     assert (policy == p).all()
-    assert (absolute(a.value - v) < SMALLNUM).all()
+    assert (absolute(a.V - v) < SMALLNUM).all()

 def test_PolicyIteration_iterative_exampleForest():
     a = PolicyIteration(Pf, Rf, 0.9, eval_type=1)
-    V = matrix('26.2439058351861 29.4839058351861 33.4839058351861')
+    v = matrix('26.2439058351861 29.4839058351861 33.4839058351861')
     p = matrix('0 0 0')
     itr = 2
     a.iterate()
-    assert (absolute(array(a.value) - V) < SMALLNUM).all()
+    assert (absolute(array(a.V) - v) < SMALLNUM).all()
     assert (array(a.policy) == p).all()
     assert a.iter == itr
@@ -283,15 +284,15 @@ def test_PolicyIteration_evalPolicyMatrix_exampleForest():
     v_pol = matrix('4.47513812154696; 5.02762430939227; 23.17243384704857')
     a = PolicyIteration(Pf, Rf, 0.9)
     a.evalPolicyMatrix()
-    assert (absolute(a.value - v_pol) < SMALLNUM).all()
+    assert (absolute(a.V - v_pol) < SMALLNUM).all()

 def test_PolicyIteration_matrix_exampleForest():
     a = PolicyIteration(Pf, Rf, 0.9)
-    V = matrix('26.2440000000000 29.4840000000000 33.4840000000000')
+    v = matrix('26.2440000000000 29.4840000000000 33.4840000000000')
     p = matrix('0 0 0')
     itr = 2
     a.iterate()
-    assert (absolute(array(a.value) - V) < SMALLNUM).all()
+    assert (absolute(array(a.V) - v) < SMALLNUM).all()
     assert (array(a.policy) == p).all()
     assert a.iter == itr
@@ -300,10 +301,12 @@ def test_PolicyIteration_matrix_exampleForest():
 def test_ValueIterationGS_exampleForest():
     a = ValueIterationGS(Pf, Rf, 0.9)
     p = matrix('0 0 0')
+    v = matrix('25.5833879767579 28.8306546355469 32.8306546355469')
     itr = 33
     a.iterate()
     assert (array(a.policy) == p).all()
     assert a.iter == itr
+    assert (absolute(array(a.V) - v) < SMALLNUM).all()

 #def test_JacksCarRental():
 #    S = 21 ** 2
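The new assertions pin down the expected value function for ValueIterationGS on the forest example. As a quick usage sketch mirroring the test (assuming exampleForest's defaults produce the three-state problem used here):

from mdp import exampleForest, ValueIterationGS

P, R = exampleForest()              # three-state forest management example
vi = ValueIterationGS(P, R, 0.9)    # discount factor 0.9
vi.iterate()
print(vi.policy)                    # (0, 0, 0)
print(vi.iter)                      # 33
print(vi.V)                         # approx. (25.583, 28.831, 32.831)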