Commit bfb2d6cb by Steven Cordwell

### allow reward to be specified as a 1-d vector

parent e25e076b
 ... @@ -231,15 +231,22 @@ class MDP(object): ... @@ -231,15 +231,22 @@ class MDP(object): # Set self.R as a tuple of length A, with each element storing an 1×S # Set self.R as a tuple of length A, with each element storing an 1×S # vector. # vector. try: try: if R.ndim == 2: if R.ndim == 1: r = array(R).reshape(self.S) self.R = tuple([r for aa in range(self.A)]) elif R.ndim == 2: self.R = tuple([array(R[:, aa]).reshape(self.S) self.R = tuple([array(R[:, aa]).reshape(self.S) for aa in range(self.A)]) for aa in range(self.A)]) else: else: self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S) self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S) for aa in range(self.A)]) for aa in range(self.A)]) except AttributeError: except AttributeError: self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S) if len(R) == self.A: for aa in range(self.A)]) self.R = tuple([multiply(P[aa], R[aa]).sum(1).reshape(self.S) for aa in range(self.A)]) else: r = array(R).reshape(self.S) self.R = tuple([r for aa in range(self.A)]) def run(self): def run(self): # Raise error because child classes should implement this function. # Raise error because child classes should implement this function. ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment