Commit a1cf5c8d authored by Steven Cordwell
Browse files

Fixes to make sure that the dot product method of numpy arrays is called

parent 65167871
......@@ -1650,7 +1650,7 @@ class RelativeValueIteration(MDP):
self.epsilon = epsilon
self.discount = 1
self.V = zeros((self.S, 1))
self.V = zeros(self.S)
self.gain = 0 # self.U[self.S]
self.average_reward = None
......@@ -2003,10 +2003,9 @@ class ValueIterationGS(ValueIteration):
Vprev = self.V.copy()
for s in range(self.S):
Q = []
for a in range(self.A):
Q.append(float(self.R[a][s] +
self.discount * self.P[a][s, :] * self.V))
Q = [float(self.R[a][s]+
self.discount * self.P[a][s, :].dot(self.V))
for a in range(self.A)]
self.V[s] = max(Q)
......@@ -2031,7 +2030,7 @@ class ValueIterationGS(ValueIteration):
for s in range(self.S):
Q = zeros(self.A)
for a in range(self.A):
Q[a] = self.R[a][s] + self.P[a][s,:] * self.discount * self.V
Q[a] = self.R[a][s] + self.discount * self.P[a][s,:].dot(self.V)
self.V[s] = Q.max()
self.policy.append(int(Q.argmax()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment