Commit 74c48a4b authored by Steven Cordwell's avatar Steven Cordwell
Browse files

add checkSquareStochastic function to sql mdp

parent 0c0fde1d
......@@ -9,6 +9,7 @@ from numpy import arange
from numpy.random import permutation, random, randint
def exampleRand(S, A):
"""WARNING: This will delete a database with the same name as 'db'."""
db = "MDP-%sx%s.db" % (S, A)
if os.path.exists(db):
os.remove(db)
......@@ -20,14 +21,14 @@ def exampleRand(S, A):
INSERT INTO info VALUES('states', %s);
INSERT INTO info VALUES('actions', %s);''' % (S, A)
c.executescript(cmd)
for a in range(A):
for a in range(1, A+1):
cmd = '''
CREATE TABLE transition%s (row INTEGER, col INTEGER, prob REAL);
CREATE TABLE reward%s (state INTEGER PRIMARY KEY ASC, val REAL);''' % (a, a)
c.executescript(cmd)
cmd = "INSERT INTO reward%s(val) VALUES(?)" % a
c.executemany(cmd, zip(random(S).tolist()))
for s in xrange(S):
for s in xrange(1, S+1):
# to be usefully represented as a sparse matrix, the number of
# nonzero entries should be less than 1/3 of dimesion of the
# matrix, so S/3
......@@ -36,7 +37,7 @@ def exampleRand(S, A):
# ==> 10000 loops, best of 3: 141 us per loop
# timeit (90894*np.ones(20330, dtype=int)).tolist()
# ==> 1000 loops, best of 3: 548 us per loop
col = (permutation(arange(S))[0:n]).tolist()
col = (permutation(arange(1,S+1))[0:n]).tolist()
val = random(n)
val = (val / val.sum()).tolist()
cmd = "INSERT INTO transition%s VALUES(?, ?, ?)" % a
......@@ -68,14 +69,45 @@ class MDP(object):
except TypeError:
raise ValueError("Cannot determine number of actions from database. "
"There is no name 'actions' in table 'info'.")
self._checkSquareStochastic()
self._initQ()
self._initResults(initial_V)
def _checkSquareStochastic(self):
# check that the columns of the transition matrices sum to one
for a in range(1, self.A + 1):
P = "transition%s" % a
cmd = "SELECT SUM(s) " \
" FROM (" \
" SELECT ABS(SUM(prob)-1)<10e-12 AS s" \
" FROM "+P+"" \
" GROUP BY row);"
self._cur.execute(cmd)
try:
if self._cur.fetchone()[0] != self.S:
raise ValueError("The transition matrix for action %s " \
"is not stochastic." % a)
except TypeError:
raise StandardError("The check stochastic query for a=%s " \
"failed." % a)
cmd = "SELECT MAX(row) FROM " + P
self._cur.execute(cmd)
row_max = self._cur.fetchone()[0]
if int(row_max) != self.S:
raise ValueError("The transition matrix for action %s is " \
"not square: row_max = %s" % (a, row_max))
cmd = "SELECT MAX(col) FROM " + P
self._cur.execute(cmd)
col_max = self._cur.fetchone()[0]
if int(col_max) > row_max:
raise ValueError("The transition matrix for action %a id " \
"not square: col_max = %s" % (a, col_max))
def _initQ(self):
self._delQ()
self._cur.execute("CREATE TABLE Q (state INTEGER, action INTEGER, value REAL);")
for a in range(self.A):
state = range(self.S)
for a in range(1, self.A + 1):
state = xrange(1, self.S + 1)
action = [a] * self.S
value = [None] * self.S
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
......@@ -198,8 +230,8 @@ class MDP(object):
return policy, value
def _randomQ(self):
for a in range(self.A):
state = range(self.S)
for a in range(1,self.A+1):
state = xrange(1,self.S+1)
action = [a] * self.S
value = random(self.S).tolist()
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
......@@ -218,10 +250,10 @@ class ValueIteration(MDP):
else:
self.thresh = epsilon
self._iterate()
#self._iterate()
def __del__(self):
MDP.__del__()
MDP.__del__(self)
def _iterate(self):
self.time = time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment