Commit ad4c8886 authored by Steven Cordwell's avatar Steven Cordwell
Browse files

Merge branch 'assertstmts'

parents d4a32136 6630e624
...@@ -79,8 +79,8 @@ def forest(S=3, r1=4, r2=2, p=0.1, is_sparse=False): ...@@ -79,8 +79,8 @@ def forest(S=3, r1=4, r2=2, p=0.1, is_sparse=False):
Examples Examples
-------- --------
>>> import mdp >>> import mdptoolbox.example
>>> P, R = mdp.exampleForest() >>> P, R = mdptoolbox.example.forest()
>>> P >>> P
array([[[ 0.1, 0.9, 0. ], array([[[ 0.1, 0.9, 0. ],
[ 0.1, 0. , 0.9], [ 0.1, 0. , 0.9],
...@@ -164,8 +164,8 @@ def rand(S, A, is_sparse=False, mask=None): ...@@ -164,8 +164,8 @@ def rand(S, A, is_sparse=False, mask=None):
Examples Examples
-------- --------
>>> import mdp >>> import mdptoolbox.example
>>> P, R = mdp.exampleRand(5, 3) >>> P, R = mdptoolbox.example.rand(5, 3)
""" """
# making sure the states and actions are more than one # making sure the states and actions are more than one
......
...@@ -156,37 +156,22 @@ class MDP(object): ...@@ -156,37 +156,22 @@ class MDP(object):
# if the discount is None then the algorithm is assumed to not use it # if the discount is None then the algorithm is assumed to not use it
# in its computations # in its computations
if type(discount) in (int, float): if discount is not None:
if (discount <= 0) or (discount > 1): self.discount = float(discount)
raise ValueError("Discount rate must be in ]0; 1]") assert 0.0 < self.discount <= 1.0, "Discount rate must be in ]0; 1]"
else: if self.discount == 1:
if discount == 1: print("PyMDPtoolbox WARNING: check conditions of convergence. "
print("PyMDPtoolbox WARNING: check conditions of " "With no discount, convergence is not always assumed.")
"convergence. With no discount, convergence is not "
"always assumed.")
self.discount = discount
elif discount is not None:
raise ValueError("PyMDPtoolbox: the discount must be a positive "
"real number less than or equal to one.")
# if the max_iter is None then the algorithm is assumed to not use it # if the max_iter is None then the algorithm is assumed to not use it
# in its computations # in its computations
if type(max_iter) in (int, float): if max_iter is not None:
if max_iter <= 0: self.max_iter = int(max_iter)
raise ValueError("The maximum number of iterations must be " assert self.max_iter > 0, "The maximum number of iterations " \
"greater than 0") "must be greater than 0."
else:
self.max_iter = max_iter
elif max_iter is not None:
raise ValueError("PyMDPtoolbox: max_iter must be a positive real "
"number greater than zero.")
# check that epsilon is something sane # check that epsilon is something sane
if type(epsilon) in (int, float): if epsilon is not None:
if epsilon <= 0: self.epsilon = float(epsilon)
raise ValueError("PyMDPtoolbox: epsilon must be greater than " assert self.epsilon > 0, "Epsilon must be greater than 0."
"0.")
elif epsilon is not None:
raise ValueError("PyMDPtoolbox: epsilon must be a positive real "
"number greater than zero.")
# we run a check on P and R to make sure they are describing an MDP. If # we run a check on P and R to make sure they are describing an MDP. If
# an exception isn't raised then they are assumed to be correct. # an exception isn't raised then they are assumed to be correct.
check(transitions, reward) check(transitions, reward)
...@@ -226,10 +211,10 @@ class MDP(object): ...@@ -226,10 +211,10 @@ class MDP(object):
else: else:
# make sure the user supplied V is of the right shape # make sure the user supplied V is of the right shape
try: try:
if V.shape not in ((self.S,), (1, self.S)): assert V.shape in ((self.S,), (1, self.S)), "V is not the " \
raise ValueError("bellman: V is not the right shape.") "right shape (Bellman operator)."
except AttributeError: except AttributeError:
raise TypeError("bellman: V must be a numpy array or matrix.") raise TypeError("V must be a numpy array or matrix.")
# Looping through each action the the Q-value matrix is calculated. # Looping through each action the the Q-value matrix is calculated.
# P and V can be any object that supports indexing, so it is important # P and V can be any object that supports indexing, so it is important
# that you know they define a valid MDP before calling the # that you know they define a valid MDP before calling the
...@@ -274,34 +259,20 @@ class MDP(object): ...@@ -274,34 +259,20 @@ class MDP(object):
self.S = P[0].shape[0] self.S = P[0].shape[0]
except AttributeError: except AttributeError:
self.S = P[0].shape[0] self.S = P[0].shape[0]
except: # convert P to a tuple of numpy arrays
raise self.P = tuple([P[aa] for aa in range(self.A)])
# convert Ps to matrices
self.P = []
for aa in xrange(self.A):
self.P.append(P[aa])
self.P = tuple(self.P)
# Set self.R as a tuple of length A, with each element storing an 1×S # Set self.R as a tuple of length A, with each element storing an 1×S
# vector. # vector.
try: try:
if R.ndim == 2: if R.ndim == 2:
self.R = [] self.R = tuple([array(R[:, aa]).reshape(self.S)
for aa in xrange(self.A): for aa in range(self.A)])
self.R.append(array(R[:, aa]).reshape(self.S))
else: else:
raise AttributeError self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S)
except AttributeError: for aa in xrange(self.A)])
self.R = []
for aa in xrange(self.A):
try:
self.R.append(P[aa].multiply(R[aa]).sum(1).reshape(self.S))
except AttributeError: except AttributeError:
self.R.append(multiply(P[aa],R[aa]).sum(1).reshape(self.S)) self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S)
except: for aa in xrange(self.A)])
raise
except:
raise
self.R = tuple(self.R)
def _iterate(self): def _iterate(self):
# Raise error because child classes should implement this function. # Raise error because child classes should implement this function.
...@@ -371,10 +342,8 @@ class FiniteHorizon(MDP): ...@@ -371,10 +342,8 @@ class FiniteHorizon(MDP):
def __init__(self, transitions, reward, discount, N, h=None): def __init__(self, transitions, reward, discount, N, h=None):
# Initialise a finite horizon MDP. # Initialise a finite horizon MDP.
if N < 1: self.N = int(N)
raise ValueError('PyMDPtoolbox: N must be greater than 0') assert self.N > 0, 'PyMDPtoolbox: N must be greater than 0.'
else:
self.N = N
# Initialise the base class # Initialise the base class
MDP.__init__(self, transitions, reward, discount, None, None) MDP.__init__(self, transitions, reward, discount, None, None)
# remove the iteration counter, it is not meaningful for backwards # remove the iteration counter, it is not meaningful for backwards
......
...@@ -7,6 +7,8 @@ Created on Sun Aug 18 14:30:09 2013 ...@@ -7,6 +7,8 @@ Created on Sun Aug 18 14:30:09 2013
from numpy import absolute, ones from numpy import absolute, ones
SMALLNUM = 10e-12
# These need to be fixed so that we use classes derived from Error. # These need to be fixed so that we use classes derived from Error.
mdperr = { mdperr = {
"mat_nonneg" : "mat_nonneg" :
...@@ -250,7 +252,7 @@ def checkSquareStochastic(Z): ...@@ -250,7 +252,7 @@ def checkSquareStochastic(Z):
# check that the matrix is square, and that each row sums to one # check that the matrix is square, and that each row sums to one
if s1 != s2: if s1 != s2:
raise InvalidMDPError(mdperr["mat_square"]) raise InvalidMDPError(mdperr["mat_square"])
elif (absolute(Z.sum(axis=1) - ones(s2))).max() > 10e-12: elif (absolute(Z.sum(axis=1) - ones(s2))).max() > SMALLNUM:
raise InvalidMDPError(mdperr["mat_stoch"]) raise InvalidMDPError(mdperr["mat_stoch"])
# make sure that there are no values less than zero # make sure that there are no values less than zero
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment