Commit 8d56b9ee authored by Steven Cordwell's avatar Steven Cordwell
Browse files

break out the storage of P, S, and A from MDP._computePR into MDP._computeP...

break out the storage of P, S, and A from MDP._computePR into MDP._computeP and use this in QLearning
parent ad4c8886
......@@ -230,6 +230,20 @@ class MDP(object):
# self.V = Q.max(axis=1)
# self.policy = Q.argmax(axis=1)
def _computeP(self, P):
# Set self.P as a tuple of length A, with each element storing an S×S
# matrix.
self.A = len(P)
try:
if P.ndim == 3:
self.S = P.shape[1]
else:
self.S = P[0].shape[0]
except AttributeError:
self.S = P[0].shape[0]
# convert P to a tuple of numpy arrays
self.P = tuple([P[aa] for aa in range(self.A)])
def _computePR(self, P, R):
# Compute the reward for the system in one state chosing an action.
# Arguments
......@@ -249,18 +263,8 @@ class MDP(object):
# We assume that P and R define a MDP i,e. assumption is that
# check(P, R) has already been run and doesn't fail.
#
# Set self.P as a tuple of length A, with each element storing an S×S
# matrix.
self.A = len(P)
try:
if P.ndim == 3:
self.S = P.shape[1]
else:
self.S = P[0].shape[0]
except AttributeError:
self.S = P[0].shape[0]
# convert P to a tuple of numpy arrays
self.P = tuple([P[aa] for aa in range(self.A)])
# First compute store P, S, and A
self._computeP(P)
# Set self.R as a tuple of length A, with each element storing an 1×S
# vector.
try:
......@@ -963,16 +967,8 @@ class QLearning(MDP):
# run on it, so check that it defines an MDP
check(transitions, reward)
if (transitions.dtype is object):
self.P = transitions
self.A = self.P.shape[0]
self.S = self.P[0].shape[0]
else: # convert to an object array
self.A = transitions.shape[0]
self.S = transitions.shape[1]
self.P = zeros(self.A, dtype=object)
for aa in range(self.A):
self.P[aa] = transitions[aa, :, :]
# Store P, S, and A
self._computeP(transitions)
self.R = reward
......@@ -1017,11 +1013,9 @@ class QLearning(MDP):
s_new = s_new + 1
p = p + self.P[a][s, s_new]
if (self.R.dtype == object):
try:
r = self.R[a][s, s_new]
elif (self.R.ndim == 3):
r = self.R[a, s, s_new]
else:
except IndexError:
r = self.R[s, a]
# Updating the value of Q
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment