Commit fb1bc8ea authored by Steven Cordwell's avatar Steven Cordwell
Browse files

working value iteration algorithm

parent eb4cb1fa
......@@ -2,22 +2,28 @@
import sqlite3
from time import time
class MDPSQLite(object):
""""""
def __init__(self, db, discount, initial_V=0):
def __init__(self, db, discount, epsilon, max_iter, initial_V=0):
self.discount = discount
self.conn = sqlite3.connect(db)
self.cur = self.conn.cursor()
self.cur.execute("SELECT value FROM info WHERE name='states'")
self.epsilon = epsilon
self.max_iter = max_iter
self.itr = 0
# The database stuff
self._conn = sqlite3.connect(db)
self._cur = self._conn.cursor()
self._cur.execute("SELECT value FROM info WHERE name='states'")
try:
self.S = self.cur.fetchone()[0]
self.S = self._cur.fetchone()[0]
except TypeError:
raise ValueError("Cannot determine number of states from database. "
"There is no name 'states' in table 'info'.")
self.cur.execute("SELECT value FROM info WHERE name='actions'")
self._cur.execute("SELECT value FROM info WHERE name='actions'")
try:
self.A = self.cur.fetchone()[0]
self.A = self._cur.fetchone()[0]
except TypeError:
raise ValueError("Cannot determine number of actions from database. "
"There is no name 'actions' in table 'info'.")
......@@ -25,7 +31,7 @@ class MDPSQLite(object):
self._initResults(initial_V)
def _initQ(self):
self.cur.executescript('''
self._cur.executescript('''
DROP TABLE IF EXISTS Q;
CREATE TABLE Q (state INTEGER, action INTEGER, value REAL);''')
for a in range(self.A):
......@@ -33,39 +39,46 @@ class MDPSQLite(object):
action = [a] * self.S
value = [None] * self.S
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
self.cur.executemany(cmd, zip(state, action, value))
self.conn.commit()
self._cur.executemany(cmd, zip(state, action, value))
self._conn.commit()
def _initResults(self, initial_V):
self.cur.executescript('''
self._cur.executescript('''
DROP TABLE IF EXISTS policy;
DROP TABLE IF EXISTS V;
DROP TABLE IF EXISTS Vprev;
CREATE TABLE policy (state INTEGER, action INTEGER);
CREATE TABLE V (state INTEGER, value REAL);''')
CREATE TABLE V (state INTEGER, value REAL);
CREATE TABLE Vprev (state INTEGER, value REAL);''')
cmd1 = "INSERT INTO V(state, value) VALUES(?, ?)"
cmd2 = "INSERT INTO policy(state, action) VALUES(?, ?)"
cmd3 = "INSERT INTO Vprev(state, value) VALUES(?, ?)"
state = range(self.S)
action = [None] * self.S
self.cur.executemany(cmd2, zip(state, action))
nones = [None] * self.S
values = zip(state, nones)
del nones
self._cur.executemany(cmd2, values)
self._cur.executemany(cmd3, values)
del values
if initial_V==0:
V = [0] * self.S
self.cur.executemany(cmd1, zip(state, V))
self._cur.executemany(cmd1, zip(state, V))
else:
try:
self.cur.executemany(cmd1, zip(state, V))
self._cur.executemany(cmd1, zip(state, V))
except:
raise ValueError("V is of unsupported type, use a list or tuple.")
self.conn.commit()
self._conn.commit()
def __del__(self):
self.cur.executescript('''
DROP TABLE IF EXISTS Q;
DROP TABLE IF EXISTS V;
DROP TABLE IF EXISTS policy;''')
self.cur.close()
self.conn.close()
#self._cur.executescript('''
# DROP TABLE IF EXISTS Q;
# DROP TABLE IF EXISTS V;
# DROP TABLE IF EXISTS policy;''')
self._cur.close()
self._conn.close()
def bellmanOperator(self):
def _bellmanOperator(self):
g = str(self.discount)
for a in range(self.A):
P = "transition%s" % a
......@@ -86,21 +99,25 @@ class MDPSQLite(object):
" ) AS C "\
" WHERE Q.state = C.state) "\
" WHERE action = "+str(a)+";"
self.cur.execute(cmd)
self.conn.commit()
self.calculateValue()
self._cur.execute(cmd)
self._conn.commit()
self._calculateValue()
def calculatePolicy(self):
def _calculatePolicy(self):
"""This implements argmax() over the actions of Q."""
cmd = '''SELECT state, action
FROM (SELECT state, action, MAX(value)
FROM Q
GROUP BY state)
GROUP BY state;'''
self.cur.execute(cmd)
self.conn.commit()
cmd = '''
UPDATE policy
SET action = (
SELECT action
FROM (SELECT state, action, MAX(value)
FROM Q
GROUP BY state) AS A
WHERE policy.state = A.state
GROUP BY state);'''
self._cur.execute(cmd)
self._conn.commit()
def calculateValue(self):
def _calculateValue(self):
"""This is max() over the actions of Q."""
cmd = '''
UPDATE V
......@@ -109,28 +126,78 @@ class MDPSQLite(object):
FROM Q
WHERE V.state = Q.state
GROUP BY state);'''
self.cur.execute(cmd)
self.conn.commit()
self._cur.execute(cmd)
self._conn.commit()
def _getSpan(self):
cmd = '''
SELECT (MAX(A.value) - MIN(A.value))
FROM (
SELECT (V.value - Vprev.value) as value
FROM V, Vprev
WHERE V.state = Vprev.state) AS A;'''
self._cur.execute(cmd)
span = self._cur.fetchone()
if span is not None:
return span[0]
def getPolicyValue(self):
"""Get the policy and value vectors."""
self.cur.execute("SELECT action FROM policy")
r = self.cur.fetchall()
self._cur.execute("SELECT action FROM policy")
r = self._cur.fetchall()
policy = [x[0] for x in r]
self.cur.execute("SELECT value FROM V")
r = self.cur.fetchall()
self._cur.execute("SELECT value FROM V")
r = self._cur.fetchall()
value = [x[0] for x in r]
return policy, value
def randomQ(self):
def _randomQ(self):
from numpy.random import random
for a in range(self.A):
state = range(self.S)
action = [a] * self.S
value = random(self.S).tolist()
cmd = "INSERT INTO Q VALUES(?, ?, ?)"
self.cur.executemany(cmd, zip(state, action, value))
self.conn.commit()
self._cur.executemany(cmd, zip(state, action, value))
self._conn.commit()
class ValueIterationSQLite(DatabaseManager):
pass
class ValueIterationSQLite(MDPSQLite):
""""""
def __init__(self, db, discount, epsilon=0.01, max_iter=1000,
initial_value=0):
MDPSQLite.__init__(self, db, discount, epsilon, max_iter, initial_value)
if self.discount < 1:
self.thresh = epsilon * (1 - self.discount) / self.discount
else:
self.thresh = epsilon
self._iterate()
def _iterate(self):
self.time = time()
done = False
while not done:
self.itr += 1
self._copyPreviousValue()
self._bellmanOperator()
variation = self._getSpan()
if variation < self.thresh:
done = True
elif (self.itr == self.max_iter):
done = True
self._calculatePolicy()
def _copyPreviousValue(self):
cmd = '''
UPDATE Vprev
SET value = (
SELECT value
FROM V
WHERE Vprev.state = V.state);'''
self._cur.execute(cmd)
self._conn.commit()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment