Commit 9e4acf3b authored by Steven Cordwell's avatar Steven Cordwell
Browse files

merge sqlite code

parents 6f70f434 fcce2142
# -*- coding: utf-8 -*-
import os
import sqlite3
from time import time
from numpy import arange
from numpy.random import permutation, random, randint
def exampleForest(S=3, r1=4, r2=2, p=0.1):
pass
def exampleRand(S, A):
"""WARNING: This will delete a database with the same name as 'db'."""
db = "MDP-%sx%s.db" % (S, A)
if os.path.exists(db):
os.remove(db)
conn = sqlite3.connect(db)
with conn:
c = conn.cursor()
cmd = '''
CREATE TABLE info (name TEXT, value INTEGER);
INSERT INTO info VALUES('states', %s);
INSERT INTO info VALUES('actions', %s);''' % (S, A)
c.executescript(cmd)
for a in range(1, A+1):
cmd = '''
CREATE TABLE transition%s (row INTEGER, col INTEGER, prob REAL);
CREATE TABLE reward%s (state INTEGER PRIMARY KEY ASC, val REAL);
''' % (a, a)
c.executescript(cmd)
cmd = "INSERT INTO reward%s(val) VALUES(?)" % a
c.executemany(cmd, zip(random(S).tolist()))
for s in xrange(1, S+1):
# to be usefully represented as a sparse matrix, the number of
# nonzero entries should be less than 1/3 of dimesion of the
# matrix, so S/3
n = randint(1, S//3)
# timeit [90894] * 20330
# ==> 10000 loops, best of 3: 141 us per loop
# timeit (90894*np.ones(20330, dtype=int)).tolist()
# ==> 1000 loops, best of 3: 548 us per loop
col = (permutation(arange(1,S+1))[0:n]).tolist()
val = random(n)
val = (val / val.sum()).tolist()
cmd = "INSERT INTO transition%s VALUES(?, ?, ?)" % a
c.executemany(cmd, zip([s] * n, col, val))
cmd = "CREATE UNIQUE INDEX Pidx%s ON transition%s (row, col);" % (a, a)
c.execute(cmd)
class MDP(object):
""""""
def __init__(self, db, discount, epsilon, max_iter, initial_V=0):
self.discount = discount
self.epsilon = epsilon
self.max_iter = max_iter
self.itr = 0
# The database stuff
self._conn = sqlite3.connect(db)
self._cur = self._conn.cursor()
self._cur.execute("SELECT value FROM info WHERE name='states'")
try:
self.S = self._cur.fetchone()[0]
except TypeError:
raise ValueError("Cannot determine number of states from "
"database. There is no name 'states' in table "
"'info'.")
self._cur.execute("SELECT value FROM info WHERE name='actions'")
try:
self.A = self._cur.fetchone()[0]
except TypeError:
raise ValueError("Cannot determine number of actions from "
"database. There is no name 'actions' in table "
"'info'.")
self._checkSquareStochastic()
self._initQ()
self._initResults(initial_V)
def _checkSquareStochastic(self):
# check that the columns of the transition matrices sum to one
for a in range(1, self.A + 1):
P = "transition%s" % a
cmd = "SELECT SUM(s) " \
" FROM (" \
" SELECT ABS(SUM(prob)-1)<10e-12 AS s" \
" FROM "+P+"" \
" GROUP BY row);"
self._cur.execute(cmd)
try:
if self._cur.fetchone()[0] != self.S:
raise ValueError("The transition matrix for action %s "
"is not stochastic." % a)
except TypeError:
raise StandardError("The check stochastic query for a=%s "
"failed." % a)
cmd = "SELECT MAX(row) FROM " + P
self._cur.execute(cmd)
row_max = self._cur.fetchone()[0]
if int(row_max) != self.S:
raise ValueError("The transition matrix for action %s is "
"not square: row_max = %s" % (a, row_max))
cmd = "SELECT MAX(col) FROM " + P
self._cur.execute(cmd)
col_max = self._cur.fetchone()[0]
if int(col_max) > row_max:
raise ValueError("The transition matrix for action %a id "
"not square: col_max = %s" % (a, col_max))
def _initQ(self):
self._delQ()
self._cur.execute("CREATE TABLE Q (state INTEGER, action INTEGER, "
"value REAL);")
for a in range(1, self.A + 1):
state = xrange(1, self.S + 1)
action = [a] * self.S
value = [None] * self.S
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
self._cur.executemany(cmd, zip(state, action, value))
self._cur.execute("CREATE UNIQUE INDEX Qidx ON Q (state, action);")
self._conn.commit()
def _delQ(self):
self._cur.executescript('''
DROP TABLE IF EXISTS Q;
DROP INDEX IF EXISTS Qidx;''')
def _initResults(self, initial_V):
self._delResults()
self._cur.executescript('''
CREATE TABLE policy (state INTEGER PRIMARY KEY ASC, action INTEGER);
CREATE TABLE V (state INTEGER PRIMARY KEY ASC, value REAL);
CREATE TABLE Vprev (state INTEGER PRIMARY KEY ASC, value REAL);''')
cmd1 = "INSERT INTO V(value) VALUES(?)"
cmd2 = "INSERT INTO policy(action) VALUES(?)"
cmd3 = "INSERT INTO Vprev(value) VALUES(?)"
values = zip([None] * self.S)
self._cur.executemany(cmd2, values)
self._cur.executemany(cmd3, values)
del values
if initial_V==0:
self._cur.executemany(cmd1, zip([0] * self.S))
else:
try:
self._cur.executemany(cmd1, zip(initial_V))
except:
raise ValueError("V is of unsupported type, use a list or "
"tuple.")
self._conn.commit()
def _delResults(self):
self._cur.executescript('''
DROP TABLE IF EXISTS policy;
DROP TABLE IF EXISTS V;
DROP TABLE IF EXISTS Vprev;''')
def __del__(self):
self._delQ()
self._cur.executescript('''
DROP TABLE IF EXISTS Vprev;
VACUUM;''')
self._cur.close()
self._conn.close()
def _bellmanOperator(self):
g = str(self.discount)
for a in range(1, self.A + 1):
P = "transition%s" % a
R = "reward%s" % a
cmd = "" \
"UPDATE Q " \
" SET value = (" \
" SELECT value "\
" FROM (" \
" SELECT R.state AS state, (R.val + B.val) AS value " \
" FROM "+R+" AS R, (" \
" SELECT P.row, "+g+"*SUM(P.prob * V.value) AS val" \
" FROM "+P+" AS P, V " \
" WHERE V.state = P.col " \
" GROUP BY P.row" \
" ) AS B " \
" WHERE R.state = B.row" \
" ) AS C "\
" WHERE Q.state = C.state) "\
" WHERE action = "+str(a)+";"
self._cur.execute(cmd)
self._conn.commit()
self._calculateValue()
def _calculatePolicy(self):
"""This implements argmax() over the actions of Q."""
cmd = '''
UPDATE policy
SET action = (
SELECT action
FROM (SELECT state, action, MAX(value)
FROM Q
GROUP BY state) AS A
WHERE policy.state = A.state
GROUP BY state);'''
self._cur.execute(cmd)
self._conn.commit()
def _calculateValue(self):
"""This is max() over the actions of Q."""
cmd = '''
UPDATE V
SET value = (
SELECT MAX(value)
FROM Q
WHERE V.state = Q.state
GROUP BY state);'''
self._cur.execute(cmd)
self._conn.commit()
def _getSpan(self):
cmd = '''
SELECT (MAX(A.value) - MIN(A.value))
FROM (
SELECT (V.value - Vprev.value) as value
FROM V, Vprev
WHERE V.state = Vprev.state) AS A;'''
self._cur.execute(cmd)
span = self._cur.fetchone()
if span is not None:
return span[0]
def getPolicyValue(self):
"""Get the policy and value vectors."""
self._cur.execute("SELECT action FROM policy")
r = self._cur.fetchall()
policy = [x[0] for x in r]
self._cur.execute("SELECT value FROM V")
r = self._cur.fetchall()
value = [x[0] for x in r]
return policy, value
def _randomQ(self):
for a in range(1,self.A+1):
state = xrange(1,self.S+1)
action = [a] * self.S
value = random(self.S).tolist()
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
self._cur.executemany(cmd, zip(state, action, value))
self._conn.commit()
class ValueIteration(MDP):
""""""
def __init__(self, db, discount, epsilon=0.01, max_iter=1000,
initial_value=0):
MDP.__init__(self, db, discount, epsilon, max_iter, initial_value)
if self.discount < 1:
self.thresh = epsilon * (1 - self.discount) / self.discount
else:
self.thresh = epsilon
#self._iterate()
def __del__(self):
MDP.__del__(self)
def _iterate(self):
self.time = time()
done = False
while not done:
self.itr += 1
self._copyPreviousValue()
self._bellmanOperator()
variation = self._getSpan()
if variation < self.thresh:
done = True
elif (self.itr == self.max_iter):
done = True
# get the optimal policy
self._calculatePolicy()
# calculate the time taken to finish
self.time = time() - self.time
def _copyPreviousValue(self):
cmd = '''
UPDATE Vprev
SET value = (
SELECT value
FROM V
WHERE Vprev.state = V.state);'''
self._cur.execute(cmd)
self._conn.commit()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment