Commit 303d218c authored by Steven Cordwell
Browse files

Refactor the rewards-matrix handling in MDP.computePR() to make it clearer

parent 0918be59
...@@ -560,7 +560,7 @@ class MDP(object): ...@@ -560,7 +560,7 @@ class MDP(object):
# check(P, R) has already been run and doesn't fail. # check(P, R) has already been run and doesn't fail.
# make P be an object array with (S, S) shaped array elements # make P be an object array with (S, S) shaped array elements
if (P.dtype == object): if (P.dtype is object):
self.P = P self.P = P
self.A = self.P.shape[0] self.A = self.P.shape[0]
self.S = self.P[0].shape[0] self.S = self.P[0].shape[0]
...@@ -572,17 +572,18 @@ class MDP(object): ...@@ -572,17 +572,18 @@ class MDP(object):
self.P[aa] = P[aa, :, :] self.P[aa] = P[aa, :, :]
# make R have the shape (S, A) # make R have the shape (S, A)
if ((R.ndim == 2) and (not R.dtype is object)): if R.dtype is object:
# R already has shape (S, A) # R is object shaped (A,) with each element shaped (S, S)
self.R = R
else:
# R has shape (A, S, S) or object shaped (A,) with each element
# shaped (S, S)
self.R = zeros((self.S, self.A)) self.R = zeros((self.S, self.A))
if (R.dtype is object): for aa in range(self.A):
for aa in range(self.A): self.R[:, aa] = multiply(P[aa], R[aa]).sum(1)
self.R[:, aa] = multiply(P[aa], R[aa]).sum(1) else:
if R.ndim == 2:
# R already has shape (S, A)
self.R = R
else: else:
# R has shape (A, S, S)
self.R = zeros((self.S, self.A))
for aa in range(self.A): for aa in range(self.A):
self.R[:, aa] = multiply(P[aa], R[aa, :, :]).sum(1) self.R[:, aa] = multiply(P[aa], R[aa, :, :]).sum(1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment