Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Zahra Rajabi
pymdptoolbox
Commits
fb1bc8ea
Commit
fb1bc8ea
authored
May 21, 2013
by
Steven Cordwell
Browse files
working value iteration algorithm
parent
eb4cb1fa
Changes
1
Hide whitespace changes
Inline
Side-by-side
util.py
View file @
fb1bc8ea
...
...
@@ -2,22 +2,28 @@
import
sqlite3
from
time
import
time
class
MDPSQLite
(
object
):
""""""
def
__init__
(
self
,
db
,
discount
,
initial_V
=
0
):
def
__init__
(
self
,
db
,
discount
,
epsilon
,
max_iter
,
initial_V
=
0
):
self
.
discount
=
discount
self
.
conn
=
sqlite3
.
connect
(
db
)
self
.
cur
=
self
.
conn
.
cursor
()
self
.
cur
.
execute
(
"SELECT value FROM info WHERE name='states'"
)
self
.
epsilon
=
epsilon
self
.
max_iter
=
max_iter
self
.
itr
=
0
# The database stuff
self
.
_conn
=
sqlite3
.
connect
(
db
)
self
.
_cur
=
self
.
_conn
.
cursor
()
self
.
_cur
.
execute
(
"SELECT value FROM info WHERE name='states'"
)
try
:
self
.
S
=
self
.
cur
.
fetchone
()[
0
]
self
.
S
=
self
.
_
cur
.
fetchone
()[
0
]
except
TypeError
:
raise
ValueError
(
"Cannot determine number of states from database. "
"There is no name 'states' in table 'info'."
)
self
.
cur
.
execute
(
"SELECT value FROM info WHERE name='actions'"
)
self
.
_
cur
.
execute
(
"SELECT value FROM info WHERE name='actions'"
)
try
:
self
.
A
=
self
.
cur
.
fetchone
()[
0
]
self
.
A
=
self
.
_
cur
.
fetchone
()[
0
]
except
TypeError
:
raise
ValueError
(
"Cannot determine number of actions from database. "
"There is no name 'actions' in table 'info'."
)
...
...
@@ -25,7 +31,7 @@ class MDPSQLite(object):
self
.
_initResults
(
initial_V
)
def
_initQ
(
self
):
self
.
cur
.
executescript
(
'''
self
.
_
cur
.
executescript
(
'''
DROP TABLE IF EXISTS Q;
CREATE TABLE Q (state INTEGER, action INTEGER, value REAL);'''
)
for
a
in
range
(
self
.
A
):
...
...
@@ -33,39 +39,46 @@ class MDPSQLite(object):
action
=
[
a
]
*
self
.
S
value
=
[
None
]
*
self
.
S
cmd
=
"INSERT INTO Q VALUES(?, ?, ?)"
self
.
cur
.
executemany
(
cmd
,
zip
(
state
,
action
,
value
))
self
.
conn
.
commit
()
self
.
_
cur
.
executemany
(
cmd
,
zip
(
state
,
action
,
value
))
self
.
_
conn
.
commit
()
def
_initResults
(
self
,
initial_V
):
self
.
cur
.
executescript
(
'''
self
.
_
cur
.
executescript
(
'''
DROP TABLE IF EXISTS policy;
DROP TABLE IF EXISTS V;
DROP TABLE IF EXISTS Vprev;
CREATE TABLE policy (state INTEGER, action INTEGER);
CREATE TABLE V (state INTEGER, value REAL);'''
)
CREATE TABLE V (state INTEGER, value REAL);
CREATE TABLE Vprev (state INTEGER, value REAL);'''
)
cmd1
=
"INSERT INTO V(state, value) VALUES(?, ?)"
cmd2
=
"INSERT INTO policy(state, action) VALUES(?, ?)"
cmd3
=
"INSERT INTO Vprev(state, value) VALUES(?, ?)"
state
=
range
(
self
.
S
)
action
=
[
None
]
*
self
.
S
self
.
cur
.
executemany
(
cmd2
,
zip
(
state
,
action
))
nones
=
[
None
]
*
self
.
S
values
=
zip
(
state
,
nones
)
del
nones
self
.
_cur
.
executemany
(
cmd2
,
values
)
self
.
_cur
.
executemany
(
cmd3
,
values
)
del
values
if
initial_V
==
0
:
V
=
[
0
]
*
self
.
S
self
.
cur
.
executemany
(
cmd1
,
zip
(
state
,
V
))
self
.
_
cur
.
executemany
(
cmd1
,
zip
(
state
,
V
))
else
:
try
:
self
.
cur
.
executemany
(
cmd1
,
zip
(
state
,
V
))
self
.
_
cur
.
executemany
(
cmd1
,
zip
(
state
,
V
))
except
:
raise
ValueError
(
"V is of unsupported type, use a list or tuple."
)
self
.
conn
.
commit
()
self
.
_
conn
.
commit
()
def
__del__
(
self
):
self
.
cur
.
executescript
(
'''
DROP TABLE IF EXISTS Q;
DROP TABLE IF EXISTS V;
DROP TABLE IF EXISTS policy;'''
)
self
.
cur
.
close
()
self
.
conn
.
close
()
#
self.
_
cur.executescript('''
#
DROP TABLE IF EXISTS Q;
#
DROP TABLE IF EXISTS V;
#
DROP TABLE IF EXISTS policy;''')
self
.
_
cur
.
close
()
self
.
_
conn
.
close
()
def
bellmanOperator
(
self
):
def
_
bellmanOperator
(
self
):
g
=
str
(
self
.
discount
)
for
a
in
range
(
self
.
A
):
P
=
"transition%s"
%
a
...
...
@@ -86,21 +99,25 @@ class MDPSQLite(object):
" ) AS C "
\
" WHERE Q.state = C.state) "
\
" WHERE action = "
+
str
(
a
)
+
";"
self
.
cur
.
execute
(
cmd
)
self
.
conn
.
commit
()
self
.
calculateValue
()
self
.
_
cur
.
execute
(
cmd
)
self
.
_
conn
.
commit
()
self
.
_
calculateValue
()
def
calculatePolicy
(
self
):
def
_
calculatePolicy
(
self
):
"""This implements argmax() over the actions of Q."""
cmd
=
'''SELECT state, action
FROM (SELECT state, action, MAX(value)
FROM Q
GROUP BY state)
GROUP BY state;'''
self
.
cur
.
execute
(
cmd
)
self
.
conn
.
commit
()
cmd
=
'''
UPDATE policy
SET action = (
SELECT action
FROM (SELECT state, action, MAX(value)
FROM Q
GROUP BY state) AS A
WHERE policy.state = A.state
GROUP BY state);'''
self
.
_cur
.
execute
(
cmd
)
self
.
_conn
.
commit
()
def
calculateValue
(
self
):
def
_
calculateValue
(
self
):
"""This is max() over the actions of Q."""
cmd
=
'''
UPDATE V
...
...
@@ -109,28 +126,78 @@ class MDPSQLite(object):
FROM Q
WHERE V.state = Q.state
GROUP BY state);'''
self
.
cur
.
execute
(
cmd
)
self
.
conn
.
commit
()
self
.
_cur
.
execute
(
cmd
)
self
.
_conn
.
commit
()
def
_getSpan
(
self
):
cmd
=
'''
SELECT (MAX(A.value) - MIN(A.value))
FROM (
SELECT (V.value - Vprev.value) as value
FROM V, Vprev
WHERE V.state = Vprev.state) AS A;'''
self
.
_cur
.
execute
(
cmd
)
span
=
self
.
_cur
.
fetchone
()
if
span
is
not
None
:
return
span
[
0
]
def
getPolicyValue
(
self
):
"""Get the policy and value vectors."""
self
.
cur
.
execute
(
"SELECT action FROM policy"
)
r
=
self
.
cur
.
fetchall
()
self
.
_
cur
.
execute
(
"SELECT action FROM policy"
)
r
=
self
.
_
cur
.
fetchall
()
policy
=
[
x
[
0
]
for
x
in
r
]
self
.
cur
.
execute
(
"SELECT value FROM V"
)
r
=
self
.
cur
.
fetchall
()
self
.
_
cur
.
execute
(
"SELECT value FROM V"
)
r
=
self
.
_
cur
.
fetchall
()
value
=
[
x
[
0
]
for
x
in
r
]
return
policy
,
value
def
randomQ
(
self
):
def
_
randomQ
(
self
):
from
numpy.random
import
random
for
a
in
range
(
self
.
A
):
state
=
range
(
self
.
S
)
action
=
[
a
]
*
self
.
S
value
=
random
(
self
.
S
).
tolist
()
cmd
=
"INSERT INTO Q VALUES(?, ?, ?)"
self
.
cur
.
executemany
(
cmd
,
zip
(
state
,
action
,
value
))
self
.
conn
.
commit
()
self
.
_
cur
.
executemany
(
cmd
,
zip
(
state
,
action
,
value
))
self
.
_
conn
.
commit
()
class
ValueIterationSQLite
(
DatabaseManager
):
pass
class
ValueIterationSQLite
(
MDPSQLite
):
""""""
def
__init__
(
self
,
db
,
discount
,
epsilon
=
0.01
,
max_iter
=
1000
,
initial_value
=
0
):
MDPSQLite
.
__init__
(
self
,
db
,
discount
,
epsilon
,
max_iter
,
initial_value
)
if
self
.
discount
<
1
:
self
.
thresh
=
epsilon
*
(
1
-
self
.
discount
)
/
self
.
discount
else
:
self
.
thresh
=
epsilon
self
.
_iterate
()
def
_iterate
(
self
):
self
.
time
=
time
()
done
=
False
while
not
done
:
self
.
itr
+=
1
self
.
_copyPreviousValue
()
self
.
_bellmanOperator
()
variation
=
self
.
_getSpan
()
if
variation
<
self
.
thresh
:
done
=
True
elif
(
self
.
itr
==
self
.
max_iter
):
done
=
True
self
.
_calculatePolicy
()
def
_copyPreviousValue
(
self
):
cmd
=
'''
UPDATE Vprev
SET value = (
SELECT value
FROM V
WHERE Vprev.state = V.state);'''
self
.
_cur
.
execute
(
cmd
)
self
.
_conn
.
commit
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment