Commit ec07c4d8 authored by Steven Cordwell's avatar Steven Cordwell
Browse files

made some if statements clearer

parent 66b34533
...@@ -122,10 +122,10 @@ def check(P, R): ...@@ -122,10 +122,10 @@ def check(P, R):
# and each element containing an SxS array. An AxSxS array will be # and each element containing an SxS array. An AxSxS array will be
# be converted to an object array. A numpy object array is similar to a # be converted to an object array. A numpy object array is similar to a
# MATLAB cell array. # MATLAB cell array.
if (not type(P) is ndarray): if type(P) != ndarray:
raise TypeError(mdperr["P_type"]) raise TypeError(mdperr["P_type"])
# also check R # also check R
if (not type(R) is ndarray): if type(R) != ndarray:
raise TypeError(mdperr["R_type"]) raise TypeError(mdperr["R_type"])
# NumPy has an array type of 'object', which is roughly equivalent to # NumPy has an array type of 'object', which is roughly equivalent to
# the MATLAB cell array. These are most useful for storing sparse # the MATLAB cell array. These are most useful for storing sparse
...@@ -137,25 +137,25 @@ def check(P, R): ...@@ -137,25 +137,25 @@ def check(P, R):
# otherwise fail with a message expalining why. # otherwise fail with a message expalining why.
# If it is a normal array then the number of dimensions must be exactly # If it is a normal array then the number of dimensions must be exactly
# three, otherwise fail with a message explaining why. # three, otherwise fail with a message explaining why.
if (P.dtype == object): if P.dtype == object:
if (P.ndim > 1): if P.ndim > 1:
raise ValueError(mdperr["obj_shape"]) raise ValueError(mdperr["obj_shape"])
else: else:
P_is_object = True P_is_object = True
else: else:
if (P.ndim != 3): if P.ndim != 3:
raise ValueError(mdperr["P_shape"]) raise ValueError(mdperr["P_shape"])
else: else:
P_is_object = False P_is_object = False
# As above but for the reward array. A difference is that the reward # As above but for the reward array. A difference is that the reward
# array can have either two or 3 dimensions. # array can have either two or 3 dimensions.
if (R.dtype == object): if R.dtype == object:
if (R.ndim > 1): if R.ndim > 1:
raise ValueError(mdperr["obj_shape"]) raise ValueError(mdperr["obj_shape"])
else: else:
R_is_object = True R_is_object = True
else: else:
if (not R.ndim in (2, 3)): if R.ndim not in (2, 3):
raise ValueError(mdperr["R_shape"]) raise ValueError(mdperr["R_shape"])
else: else:
R_is_object = False R_is_object = False
...@@ -188,7 +188,7 @@ def check(P, R): ...@@ -188,7 +188,7 @@ def check(P, R):
# what was found in the first element, then we need to fail # what was found in the first element, then we need to fail
# telling the user what needs to be fixed. # telling the user what needs to be fixed.
sP0aa, sP1aa = P[aa].shape sP0aa, sP1aa = P[aa].shape
if ((sP0aa != sP0) or (sP1aa != sP1)): if (sP0aa != sP0) or (sP1aa != sP1):
raise ValueError(mdperr["obj_square"]) raise ValueError(mdperr["obj_square"])
else: else:
# if we are using a normal array for this, then the first # if we are using a normal array for this, then the first
...@@ -201,7 +201,7 @@ def check(P, R): ...@@ -201,7 +201,7 @@ def check(P, R):
# probability. Also, if the number of actions is less than one, or the # probability. Also, if the number of actions is less than one, or the
# number of states is less than one, then it also is not a valid # number of states is less than one, then it also is not a valid
# transition probability. # transition probability.
if ((sP0 < 1) or (aP < 1) or (sP0 != sP1)): if (sP0 < 1) or (aP < 1) or (sP0 != sP1):
raise ValueError(mdperr["P_shape"]) raise ValueError(mdperr["P_shape"])
# now we check that each transition matrix is square-stochastic. For # now we check that each transition matrix is square-stochastic. For
# object arrays this is the matrix held in each element, but for # object arrays this is the matrix held in each element, but for
...@@ -222,7 +222,7 @@ def check(P, R): ...@@ -222,7 +222,7 @@ def check(P, R):
sR0aa, sR1aa = R[aa].shape sR0aa, sR1aa = R[aa].shape
if ((sR0aa != sR0) or (sR1aa != sR1)): if ((sR0aa != sR0) or (sR1aa != sR1)):
raise ValueError(mdperr["obj_square"]) raise ValueError(mdperr["obj_square"])
elif (R.ndim == 3): elif R.ndim == 3:
# This indicates that the reward matrices are constructed per # This indicates that the reward matrices are constructed per
# transition, so that the first dimension is the actions and # transition, so that the first dimension is the actions and
# the second two dimensions are the states. # the second two dimensions are the states.
...@@ -236,7 +236,7 @@ def check(P, R): ...@@ -236,7 +236,7 @@ def check(P, R):
sR1 = sR0 sR1 = sR0
# the number of actions must be more than zero, the number of states # the number of actions must be more than zero, the number of states
# must also be more than 0, and the states must agree # must also be more than 0, and the states must agree
if ((sR0 < 1) or (aR < 1) or (sR0 != sR1)): if (sR0 < 1) or (aR < 1) or (sR0 != sR1):
raise ValueError(mdperr["R_shape"]) raise ValueError(mdperr["R_shape"])
# now we check to see that what the transition array is reporting and # now we check to see that what the transition array is reporting and
# what the reward arrar is reporting agree as to the number of actions # what the reward arrar is reporting agree as to the number of actions
...@@ -261,13 +261,13 @@ def checkSquareStochastic(Z): ...@@ -261,13 +261,13 @@ def checkSquareStochastic(Z):
""" """
s1, s2 = Z.shape s1, s2 = Z.shape
if (s1 != s2): if s1 != s2:
raise ValueError(mdperr["mat_square"]) raise ValueError(mdperr["mat_square"])
elif (absolute(Z.sum(axis=1) - ones(s2))).max() > 10**(-12): elif (absolute(Z.sum(axis=1) - ones(s2))).max() > 10e-12:
raise ValueError(mdperr["mat_stoch"]) raise ValueError(mdperr["mat_stoch"])
elif ((type(Z) is ndarray) or (type(Z) is matrix)) and (Z < 0).any(): elif ((type(Z) == ndarray) or (type(Z) == matrix)) and (Z < 0).any():
raise ValueError(mdperr["mat_nonneg"]) raise ValueError(mdperr["mat_nonneg"])
elif (type(Z) is sparse) and (Z.data < 0).any(): elif (type(Z) == sparse) and (Z.data < 0).any():
raise ValueError(mdperr["mat_nonneg"]) raise ValueError(mdperr["mat_nonneg"])
else: else:
return(None) return(None)
...@@ -308,11 +308,11 @@ def exampleForest(S=3, r1=4, r2=2, p=0.1): ...@@ -308,11 +308,11 @@ def exampleForest(S=3, r1=4, r2=2, p=0.1):
[ 4., 2.]]) [ 4., 2.]])
""" """
if (S <= 1): if S <= 1:
raise ValueError(mdperr["S_gt_1"]) raise ValueError(mdperr["S_gt_1"])
if (r1 <= 0) or (r2 <= 0): if (r1 <= 0) or (r2 <= 0):
raise ValueError(mdperr["R_gt_0"]) raise ValueError(mdperr["R_gt_0"])
if (p < 0 or p > 1): if (p < 0) or (p > 1):
raise ValueError(mdperr["prob_in01"]) raise ValueError(mdperr["prob_in01"])
# Definition of Transition matrix P(:,:,1) associated to action Wait # Definition of Transition matrix P(:,:,1) associated to action Wait
# (action 1) and P(:,:,2) associated to action Cut (action 2) # (action 1) and P(:,:,2) associated to action Cut (action 2)
...@@ -368,16 +368,16 @@ def exampleRand(S, A, is_sparse=False, mask=None): ...@@ -368,16 +368,16 @@ def exampleRand(S, A, is_sparse=False, mask=None):
""" """
# making sure the states and actions are more than one # making sure the states and actions are more than one
if (S < 1 or A < 1): if (S < 1) or (A < 1):
raise ValueError(mdperr["SA_gt_1"]) raise ValueError(mdperr["SA_gt_1"])
# the mask needs to be SxS # the mask needs to be SxS
try: try:
if (mask != None) and (mask.shape != (S, S)): if (mask is not None) and (mask.shape != (S, S)):
raise ValueError(mdperr["mask_SbyS"]) raise ValueError(mdperr["mask_SbyS"])
except AttributeError: except AttributeError:
raise TypeError(mdperr["mask_numpy"]) raise TypeError(mdperr["mask_numpy"])
# if the user hasn't specified a mask, then we will make a random one now # if the user hasn't specified a mask, then we will make a random one now
if mask == None: if mask is None:
mask = rand(A, S, S) mask = rand(A, S, S)
for a in range(A): for a in range(A):
r = random() r = random()
...@@ -392,7 +392,7 @@ def exampleRand(S, A, is_sparse=False, mask=None): ...@@ -392,7 +392,7 @@ def exampleRand(S, A, is_sparse=False, mask=None):
for a in range(A): for a in range(A):
PP = mask[a] * rand(S, S) PP = mask[a] * rand(S, S)
for s in range(S): for s in range(S):
if (mask[a, s, :].sum() == 0): if mask[a, s, :].sum() == 0:
PP[s, randint(0, S - 1)] = 1 PP[s, randint(0, S - 1)] = 1
PP[s, :] = PP[s, :] / PP[s, :].sum() PP[s, :] = PP[s, :] / PP[s, :].sum()
P[a] = sparse(PP) P[a] = sparse(PP)
...@@ -405,7 +405,7 @@ def exampleRand(S, A, is_sparse=False, mask=None): ...@@ -405,7 +405,7 @@ def exampleRand(S, A, is_sparse=False, mask=None):
for a in range(A): for a in range(A):
P[a, :, :] = mask[a] * rand(S, S) P[a, :, :] = mask[a] * rand(S, S)
for s in range(S): for s in range(S):
if (mask[a, s, :].sum() == 0): if mask[a, s, :].sum() == 0:
P[a, s, randint(0, S - 1)] = 1 P[a, s, randint(0, S - 1)] = 1
P[a, s, :] = P[a, s, :] / P[a, s, :].sum() P[a, s, :] = P[a, s, :] / P[a, s, :].sum()
R[a, :, :] = mask[a] * (2 * rand(S, S) - ones((S, S), dtype=int)) R[a, :, :] = mask[a] * (2 * rand(S, S) - ones((S, S), dtype=int))
...@@ -439,7 +439,7 @@ class MDP(object): ...@@ -439,7 +439,7 @@ class MDP(object):
"convergence. With no discount, convergence is not " "convergence. With no discount, convergence is not "
"always assumed.") "always assumed.")
self.discount = discount self.discount = discount
elif not discount is None: elif discount is not None:
raise ValueError("PyMDPtoolbox: the discount must be a positive " raise ValueError("PyMDPtoolbox: the discount must be a positive "
"real number less than or equal to one.") "real number less than or equal to one.")
...@@ -450,7 +450,7 @@ class MDP(object): ...@@ -450,7 +450,7 @@ class MDP(object):
raise ValueError(mdperr["maxi_min"]) raise ValueError(mdperr["maxi_min"])
else: else:
self.max_iter = max_iter self.max_iter = max_iter
elif not max_iter is None: elif max_iter is not None:
raise ValueError("PyMDPtoolbox: max_iter must be a positive real " raise ValueError("PyMDPtoolbox: max_iter must be a positive real "
"number greater than zero.") "number greater than zero.")
...@@ -458,7 +458,7 @@ class MDP(object): ...@@ -458,7 +458,7 @@ class MDP(object):
if epsilon <= 0: if epsilon <= 0:
raise ValueError("PyMDPtoolbox: epsilon must be greater than " raise ValueError("PyMDPtoolbox: epsilon must be greater than "
"0.") "0.")
elif not epsilon is None: elif epsilon is not None:
raise ValueError("PyMDPtoolbox: epsilon must be a positive real " raise ValueError("PyMDPtoolbox: epsilon must be a positive real "
"number greater than zero.") "number greater than zero.")
...@@ -489,20 +489,22 @@ class MDP(object): ...@@ -489,20 +489,22 @@ class MDP(object):
(policy, value) : tuple of new policy and its value (policy, value) : tuple of new policy and its value
""" """
# this V should be a reference to the data rather than a copy if V is None:
if V == None: # this V should be a reference to the data rather than a copy
V = self.V V = self.V
else: else:
if not ((type(V) in (ndarray, matrix)) and try:
(V.shape == (self.S, 1))): if V.shape != (self.S, 1):
raise ValueError("V in bellmanOperator needs to be correct.") raise ValueError("V in bellmanOperator needs to be "
"correct.")
except AttributeError:
raise TypeError("bellman: V must be a numpy array or matrix.")
Q = matrix(zeros((self.S, self.A))) Q = matrix(zeros((self.S, self.A)))
for aa in range(self.A): for aa in range(self.A):
Q[:, aa] = self.R[:, aa] + (self.discount * self.P[aa] * V) Q[:, aa] = self.R[:, aa] + (self.discount * self.P[aa] * V)
# Which way is better? if choose the first way, then the classes that # Which way is better?
# call this function must be changed
# 1. Return, (policy, value) # 1. Return, (policy, value)
return (Q.argmax(axis=1), Q.max(axis=1)) return (Q.argmax(axis=1), Q.max(axis=1))
# 2. update self.policy and self.V directly # 2. update self.policy and self.V directly
...@@ -527,11 +529,12 @@ class MDP(object): ...@@ -527,11 +529,12 @@ class MDP(object):
PR(SxA) = reward matrix PR(SxA) = reward matrix
""" """
# we assume that P and R define a MDP i,e. assumption is that # We assume that P and R define a MDP i,e. assumption is that
# check(P, R) has already been run and doesn't fail. # check(P, R) has already been run and doesn't fail.
#
# make P be an object array with (S, S) shaped array elements # Make P be an object array with (S, S) shaped array elements. Save it
if (P.dtype is object): # as a matrix.
if P.dtype == object:
self.P = P self.P = P
self.A = self.P.shape[0] self.A = self.P.shape[0]
self.S = self.P[0].shape[0] self.S = self.P[0].shape[0]
...@@ -540,31 +543,25 @@ class MDP(object): ...@@ -540,31 +543,25 @@ class MDP(object):
self.S = P.shape[1] self.S = P.shape[1]
self.P = zeros(self.A, dtype=object) self.P = zeros(self.A, dtype=object)
for aa in range(self.A): for aa in range(self.A):
self.P[aa] = P[aa, :, :] self.P[aa] = matrix(P[aa, :, :])
# Make R have the shape (S, A) and save it as a matrix
# make R have the shape (S, A) if R.dtype == object:
if R.dtype is object:
# R is object shaped (A,) with each element shaped (S, S) # R is object shaped (A,) with each element shaped (S, S)
self.R = zeros((self.S, self.A)) self.R = matrix(zeros((self.S, self.A)))
for aa in range(self.A): for aa in range(self.A):
self.R[:, aa] = multiply(P[aa], R[aa]).sum(1) self.R[:, aa] = (
multiply(P[aa], R[aa]).sum(1).reshape(self.S, 1))
else: else:
if R.ndim == 2: if R.ndim == 2:
# R already has shape (S, A) # R already has shape (S, A)
self.R = R self.R = matrix(R)
else: else:
# R has shape (A, S, S) # R has shape (A, S, S)
self.R = zeros((self.S, self.A)) self.R = matrix(zeros((self.S, self.A)))
for aa in range(self.A): for aa in range(self.A):
self.R[:, aa] = multiply(P[aa], R[aa, :, :]).sum(1) self.R[:, aa] = (
multiply(P[aa], R[aa, :, :]).sum(1).reshape(self.S, 1))
# convert the arrays to numpy matrices
for aa in range(self.A):
if (type(self.P[aa]) is ndarray):
self.P[aa] = matrix(self.P[aa])
if (type(self.R) is ndarray):
self.R = matrix(self.R)
def iterate(self): def iterate(self):
"""Raise error because child classes should implement this function.""" """Raise error because child classes should implement this function."""
raise NotImplementedError("You should create an iterate() method.") raise NotImplementedError("You should create an iterate() method.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment