Commit b7bfb004 authored by Steven Cordwell

define the iteration methods as 'private' and call them from the __init__ function

parent 429a16b4
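
The change follows one pattern in every solver class: the public iterate() method becomes _iterate(), and the constructor calls it, so the algorithm runs as soon as an instance is created. A minimal sketch of that pattern (the ExampleSolver class below is illustrative only, not part of the toolbox):

```python
class ExampleSolver(object):
    """Illustrative only: shows the __init__-calls-_iterate() pattern."""

    def __init__(self, transitions, reward, discount):
        self.P = transitions
        self.R = reward
        self.discount = discount
        self.V = None
        # Call the iteration method, so constructing an instance runs the algorithm
        self._iterate()

    def _iterate(self):
        """'Private' by the leading-underscore convention; each algorithm overrides it."""
        self.V = 0  # placeholder for the real computation
```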
@@ -818,9 +818,9 @@ class MDP(object):
                 raise
         self.R = tuple(self.R)

-    def iterate(self):
+    def _iterate(self):
         """Raise error because child classes should implement this function."""
-        raise NotImplementedError("You should create an iterate() method.")
+        raise NotImplementedError("You should create an _iterate() method.")

     def setSilent(self):
         """Set the MDP algorithm to silent mode."""
@@ -903,7 +903,10 @@ class FiniteHorizon(MDP):
         if h is not None:
             self.V[:, N] = h
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the finite horizon algorithm."""
         self.time = time()
@@ -978,8 +981,11 @@ class LP(MDP):
         # this doesn't do what I want it to do c.f. issue #3
         if not self.verbose:
             solvers.options['show_progress'] = False
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the linear programming algorithm."""
         self.time = time()
         # The objective is to resolve : min V / V >= PR + discount*P*V
@@ -1051,7 +1057,6 @@ class PolicyIteration(MDP):
     >>> import mdp
     >>> P, R = mdp.exampleRand(5, 3)
     >>> pi = mdp.PolicyIteration(P, R, 0.9)
-    >>> pi.iterate()

     """
@@ -1099,6 +1104,9 @@ class PolicyIteration(MDP):
                              "evaluation or 1 for iterative evaluation. "
                              "The strings 'matrix' and 'iterative' can also "
                              "be used.")
+
+        # Call the iteration method
+        self._iterate()

     def _computePpolicyPRpolicy(self):
         """Compute the transition matrix and the reward matrix for a policy.
...@@ -1243,7 +1251,7 @@ class PolicyIteration(MDP): ...@@ -1243,7 +1251,7 @@ class PolicyIteration(MDP):
self.V = self._lin_eq( self.V = self._lin_eq(
(self._speye(self.S, self.S) - self.discount * Ppolicy), Rpolicy) (self._speye(self.S, self.S) - self.discount * Ppolicy), Rpolicy)
def iterate(self): def _iterate(self):
"""Run the policy iteration algorithm.""" """Run the policy iteration algorithm."""
if self.verbose: if self.verbose:
@@ -1368,8 +1376,11 @@ class PolicyIterationModified(PolicyIteration):
         else:
             # min(min()) is not right
             self.V = 1 / (1 - discount) * self.R.min() * ones((self.S, 1))
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the modified policy iteration algorithm."""
         if self.verbose:
@@ -1448,7 +1459,6 @@ class QLearning(MDP):
     >>> random.seed(0)
     >>> P, R = mdp.exampleForest()
     >>> ql = mdp.QLearning(P, R, 0.96)
-    >>> ql.iterate()
     >>> ql.Q
     array([[ 68.80977389, 46.62560314],
            [ 72.58265749, 43.1170545 ],
@@ -1465,7 +1475,6 @@ class QLearning(MDP):
     >>> R = np.array([[5, 10], [-1, 2]])
     >>> random.seed(0)
     >>> ql = mdp.QLearning(P, R, 0.9)
-    >>> ql.iterate()
     >>> ql.Q
     array([[ 36.63245946, 42.24434307],
            [ 35.96582807, 32.70456417]])
@@ -1511,7 +1520,10 @@ class QLearning(MDP):
         self.Q = zeros((self.S, self.A))
         self.mean_discrepancy = []
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the Q-learning algoritm."""
         discrepancy = []
@@ -1613,7 +1625,6 @@ class RelativeValueIteration(MDP):
     >>> import mdp
     >>> P, R = exampleForest()
     >>> rvi = mdp.RelativeValueIteration(P, R)
-    >>> rvi.iterate()
     >>> rvi.average_reward
     2.4300000000000002
     >>> rvi.policy
@@ -1625,8 +1636,7 @@ class RelativeValueIteration(MDP):
     >>> import numpy as np
     >>> P = np.array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]])
     >>> R = np.array([[5, 10], [-1, 2]])
-    >>> vi = mdp.RelativeValueIteration(P, R)
-    >>> rvi.iterate()
+    >>> rvi = mdp.RelativeValueIteration(P, R)
     >>> rvi.V
     (10.0, 3.885235246411831)
     >>> rvi.average_reward
@@ -1650,8 +1660,11 @@ class RelativeValueIteration(MDP):
         self.gain = 0 # self.U[self.S]
         self.average_reward = None
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the relative value iteration algorithm."""
         done = False
@@ -1743,10 +1756,10 @@ class ValueIteration(MDP):
     ---------------
     V : value function
         A vector which stores the optimal value function. Prior to calling the
-        iterate() method it has a value of None. Shape is (S, ).
+        _iterate() method it has a value of None. Shape is (S, ).
     policy : epsilon-optimal policy
         A vector which stores the optimal policy. Prior to calling the
-        iterate() method it has a value of None. Shape is (S, ).
+        _iterate() method it has a value of None. Shape is (S, ).
     iter : number of iterations taken to complete the computation
         An integer
     time : used CPU time
@@ -1754,8 +1767,6 @@ class ValueIteration(MDP):
     Methods
     -------
-    iterate()
-        Starts the loop for the algorithm to be completed.
     setSilent()
         Sets the instance to silent mode.
     setVerbose()
@@ -1774,7 +1785,6 @@ class ValueIteration(MDP):
     >>> vi = mdp.ValueIteration(P, R, 0.96)
     >>> vi.verbose
     False
-    >>> vi.iterate()
     >>> vi.V
     (5.93215488, 9.38815488, 13.38815488)
     >>> vi.policy
@@ -1789,7 +1799,6 @@ class ValueIteration(MDP):
     >>> P = np.array([[[0.5, 0.5],[0.8, 0.2]],[[0, 1],[0.1, 0.9]]])
     >>> R = np.array([[5, 10], [-1, 2]])
     >>> vi = mdp.ValueIteration(P, R, 0.9)
-    >>> vi.iterate()
     >>> vi.V
     (40.04862539271682, 33.65371175967546)
     >>> vi.policy
@@ -1807,7 +1816,6 @@ class ValueIteration(MDP):
     >>> P[1] = sparse([[0, 1],[0.1, 0.9]])
     >>> R = np.array([[5, 10], [-1, 2]])
     >>> vi = mdp.ValueIteration(P, R, 0.9)
-    >>> vi.iterate()
     >>> vi.V
     (40.04862539271682, 33.65371175967546)
     >>> vi.policy
@@ -1845,6 +1853,9 @@ class ValueIteration(MDP):
         else: # discount == 1
             # threshold of variation for V for an epsilon-optimal policy
             self.thresh = epsilon
+
+        # Call the iteration method
+        self._iterate()

     def _boundIter(self, epsilon):
         """Compute a bound for the number of iterations.
@@ -1895,7 +1906,7 @@ class ValueIteration(MDP):
         self.max_iter = int(ceil(max_iter))

-    def iterate(self):
+    def _iterate(self):
         """Run the value iteration algorithm."""
         if self.verbose:
@@ -1982,8 +1993,10 @@ class ValueIterationGS(ValueIteration):
         ValueIteration.__init__(self, transitions, reward, discount, epsilon,
                                 max_iter, initial_value)
+
+        # Call the iteration method
+        self._iterate()

-    def iterate(self):
+    def _iterate(self):
         """Run the value iteration Gauss-Seidel algorithm."""
         done = False
...
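
As the updated docstrings above show, the usage pattern loses the explicit iterate() call: results are available right after construction. A sketch of the new usage, based on the ValueIteration doctest above:

```python
import mdp

P, R = mdp.exampleForest()
vi = mdp.ValueIteration(P, R, 0.96)
# No vi.iterate() any more: __init__ has already run the algorithm.
print(vi.V)
print(vi.policy)
```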