From ffaef478fd3ad63b355b0d6a17ee0c7541a650b2 Mon Sep 17 00:00:00 2001
From: Yasser Gonzalez
Date: Wed, 21 Jan 2015 13:04:36 -0500
Subject: [PATCH] Add skip_check argument in MDP and subclasses

Allows skipping the check of the transition and reward matrices when
creating an MDP instance (or an instance of any of its subclasses).
---
 src/mdptoolbox/mdp.py | 93 +++++++++++++++++++++++++++++++++----------
 1 file changed, 71 insertions(+), 22 deletions(-)

diff --git a/src/mdptoolbox/mdp.py b/src/mdptoolbox/mdp.py
index a90671d..ef7c26b 100644
--- a/src/mdptoolbox/mdp.py
+++ b/src/mdptoolbox/mdp.py
@@ -130,6 +130,10 @@ class MDP(object):
         this many iterations have elapsed. This must be greater than 0 if
         specified. Subclasses of ``MDP`` may pass ``None`` in the case where
         the algorithm does not use a maximum number of iterations.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Attributes
     ----------
@@ -164,7 +168,8 @@ class MDP(object):
 
     """
 
-    def __init__(self, transitions, reward, discount, epsilon, max_iter):
+    def __init__(self, transitions, reward, discount, epsilon, max_iter,
+                 skip_check=False):
         # Initialise a MDP based on the input parameters.
 
         # if the discount is None then the algorithm is assumed to not use it
@@ -188,9 +193,11 @@ class MDP(object):
         self.epsilon = float(epsilon)
         assert self.epsilon > 0, "Epsilon must be greater than 0."
 
-        # we run a check on P and R to make sure they are describing an MDP. If
-        # an exception isn't raised then they are assumed to be correct.
-        _util.check(transitions, reward)
+        if not skip_check:
+            # We run a check on P and R to make sure they are describing an MDP.
+            # If an exception isn't raised then they are assumed to be correct.
+            _util.check(transitions, reward)
+
         self.S, self.A = _computeDimensions(transitions)
         self.P = self._computeTransition(transitions)
         self.R = self._computeReward(reward, transitions)
@@ -332,6 +339,10 @@ class FiniteHorizon(MDP):
         Number of periods. Must be greater than 0.
     h : array, optional
         Terminal reward. Default: a vector of zeros.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -366,12 +377,14 @@ class FiniteHorizon(MDP):
 
     """
 
-    def __init__(self, transitions, reward, discount, N, h=None):
+    def __init__(self, transitions, reward, discount, N, h=None,
+                 skip_check=False):
         # Initialise a finite horizon MDP.
         self.N = int(N)
         assert self.N > 0, "N must be greater than 0."
         # Initialise the base class
-        MDP.__init__(self, transitions, reward, discount, None, None)
+        MDP.__init__(self, transitions, reward, discount, None, None,
+                     skip_check=skip_check)
         # remove the iteration counter, it is not meaningful for backwards
         # induction
         del self.iter
@@ -423,6 +436,10 @@ class _LP(MDP):
         details.
     h : array, optional
         Terminal reward. Default: a vector of zeros.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -449,7 +466,7 @@ class _LP(MDP):
 
     """
 
-    def __init__(self, transitions, reward, discount):
+    def __init__(self, transitions, reward, discount, skip_check=False):
         # Initialise a linear programming MDP.
         # import some functions from cvxopt and set them as object methods
         try:
@@ -460,7 +477,8 @@ class _LP(MDP):
             raise ImportError("The python module cvxopt is required to use "
                               "linear programming functionality.")
         # initialise the MDP. epsilon and max_iter are not needed
-        MDP.__init__(self, transitions, reward, discount, None, None)
+        MDP.__init__(self, transitions, reward, discount, None, None,
+                     skip_check=skip_check)
         # Set the cvxopt solver to be quiet by default, but ...
         # this doesn't do what I want it to do c.f. issue #3
         if not self.verbose:
@@ -523,6 +541,10 @@ class PolicyIteration(MDP):
         Type of function used to evaluate policy. 0 or "matrix" to solve
         as a set of linear equations. 1 or "iterative" to solve iteratively.
         Default: 0.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -558,11 +580,12 @@ class PolicyIteration(MDP):
     """
 
     def __init__(self, transitions, reward, discount, policy0=None,
-                 max_iter=1000, eval_type=0):
+                 max_iter=1000, eval_type=0, skip_check=False):
         # Initialise a policy iteration MDP.
         #
         # Set up the MDP, but don't need to worry about epsilon values
-        MDP.__init__(self, transitions, reward, discount, None, max_iter)
+        MDP.__init__(self, transitions, reward, discount, None, max_iter,
+                     skip_check=skip_check)
         # Check if the user has supplied an initial policy. If not make one.
         if policy0 is None:
             # Initialise the policy to the one which maximises the expected
@@ -803,6 +826,10 @@ class PolicyIterationModified(PolicyIteration):
     max_iter : int, optional
         Maximum number of iterations. See the documentation for the ``MDP``
         class for details. Default is 10.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -830,7 +857,7 @@ class PolicyIterationModified(PolicyIteration):
     """
 
     def __init__(self, transitions, reward, discount, epsilon=0.01,
-                 max_iter=10):
+                 max_iter=10, skip_check=False):
         # Initialise a (modified) policy iteration MDP.
 
         # Maybe its better not to subclass from PolicyIteration, because the
@@ -839,7 +866,7 @@ class PolicyIterationModified(PolicyIteration):
         # is needed from the PolicyIteration class is the _evalPolicyIterative
         # function. Perhaps there is a better way to do it?
         PolicyIteration.__init__(self, transitions, reward, discount, None,
-                                 max_iter, 1)
+                                 max_iter, 1, skip_check=skip_check)
 
         # PolicyIteration doesn't pass epsilon to MDP.__init__() so we will
         # check it here
@@ -916,6 +943,10 @@ class QLearning(MDP):
     n_iter : int, optional
         Number of iterations to execute. This is ignored unless it is an
         integer greater than the default value. Defaut: 10,000.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -967,7 +998,8 @@ class QLearning(MDP):
 
     """
 
-    def __init__(self, transitions, reward, discount, n_iter=10000):
+    def __init__(self, transitions, reward, discount, n_iter=10000,
+                 skip_check=False):
         # Initialise a Q-learning MDP.
 
         # The following check won't be done in MDP()'s initialisation, so let's
@@ -975,9 +1007,10 @@ class QLearning(MDP):
         self.max_iter = int(n_iter)
         assert self.max_iter >= 10000, "'n_iter' should be greater than 10000."
 
-        # We don't want to send this to MDP because _computePR should not be
-        # run on it, so check that it defines an MDP
-        _util.check(transitions, reward)
+        if not skip_check:
+            # We don't want to send this to MDP because _computePR should not be
+            # run on it, so check that it defines an MDP
+            _util.check(transitions, reward)
 
         # Store P, S, and A
         self.S, self.A = _computeDimensions(transitions)
@@ -1076,6 +1109,10 @@ class RelativeValueIteration(MDP):
     max_iter : int, optional
         Maximum number of iterations. See the documentation for the ``MDP``
         class for details. Default: 1000.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -1123,10 +1160,12 @@ class RelativeValueIteration(MDP):
 
     """
 
-    def __init__(self, transitions, reward, epsilon=0.01, max_iter=1000):
+    def __init__(self, transitions, reward, epsilon=0.01, max_iter=1000,
+                 skip_check=False):
         # Initialise a relative value iteration MDP.
 
-        MDP.__init__(self, transitions, reward, None, epsilon, max_iter)
+        MDP.__init__(self, transitions, reward, None, epsilon, max_iter,
+                     skip_check=skip_check)
 
         self.epsilon = epsilon
         self.discount = 1
@@ -1215,6 +1254,10 @@ class ValueIteration(MDP):
         documentation for the ``MDP`` class for further details.
     initial_value : array, optional
         The starting value function. Default: a vector of zeros.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attributes
     ---------------
@@ -1291,10 +1334,11 @@ class ValueIteration(MDP):
     """
 
     def __init__(self, transitions, reward, discount, epsilon=0.01,
-                 max_iter=1000, initial_value=0):
+                 max_iter=1000, initial_value=0, skip_check=False):
         # Initialise a value iteration MDP.
 
-        MDP.__init__(self, transitions, reward, discount, epsilon, max_iter)
+        MDP.__init__(self, transitions, reward, discount, epsilon, max_iter,
+                     skip_check=skip_check)
 
         # initialization of optional arguments
         if initial_value == 0:
@@ -1422,6 +1466,10 @@ class ValueIterationGS(ValueIteration):
         and ``ValueIteration`` classes for details. Default: computed.
     initial_value : array, optional
         The starting value function. Default: a vector of zeros.
+    skip_check : bool
+        By default we run a check on the ``transitions`` and ``rewards``
+        arguments to make sure they describe a valid MDP. You can set this
+        argument to True in order to skip this check.
 
     Data Attribues
     --------------
@@ -1453,10 +1501,11 @@ class ValueIterationGS(ValueIteration):
     """
 
    def __init__(self, transitions, reward, discount, epsilon=0.01,
-                 max_iter=10, initial_value=0):
+                 max_iter=10, initial_value=0, skip_check=False):
         # Initialise a value iteration Gauss-Seidel MDP.
 
-        MDP.__init__(self, transitions, reward, discount, epsilon, max_iter)
+        MDP.__init__(self, transitions, reward, discount, epsilon, max_iter,
+                     skip_check=skip_check)
 
         # initialization of optional arguments
         if initial_value == 0:
--
GitLab
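
For reviewers, a minimal usage sketch (not part of the patch). It assumes
the patch has been applied to an installed copy of pymdptoolbox and uses
the forest-management example module that ships with the toolbox; the
discount factor of 0.96 is arbitrary.

    import mdptoolbox.example
    import mdptoolbox.mdp

    # Small forest-management example MDP shipped with the toolbox:
    # P has shape (A, S, S), R has shape (S, A).
    P, R = mdptoolbox.example.forest()

    # Default behaviour: P and R are validated by _util.check() before
    # the solver is set up.
    vi = mdptoolbox.mdp.ValueIteration(P, R, 0.96)
    vi.run()

    # With this patch: skip the validation, e.g. when P and R are known
    # to be valid and construction time matters for large state spaces.
    vi_fast = mdptoolbox.mdp.ValueIteration(P, R, 0.96, skip_check=True)
    vi_fast.run()

    assert vi.policy == vi_fast.policy  # same MDP, same optimal policy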