mdpsql.py 12.1 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
import os
4
import sqlite3
5

6
7
from time import time

8
from numpy import arange
9
from numpy.random import permutation, random, randint
10

Steven Cordwell's avatar
Steven Cordwell committed
11
def exampleForest(S=3, r1=4, r2=2, p=0.1):
    """Create an SQLite database holding a small 'forest' MDP with 2 actions.

    WARNING: the database file ``MDP-forest-<S>.db`` in the current working
    directory is deleted first if it already exists.

    The following tables are written (states are numbered 1..S):

    - ``info``: the rows ('states', S) and ('actions', 2).
    - ``transition1``: from state s, go to state 1 with probability ``p``
      and to state min(s + 1, S) with probability ``1 - p``.
    - ``transition2``: from every state, go to state 1 with probability 1.
    - ``reward1``: 0 for states 1..S-1 and ``r1`` for state S.
    - ``reward2``: 0 for state 1, 1 for states 2..S-1 and ``r2`` for state S.

    Parameters
    ----------
    S : int
        Number of states, must be at least 2.
    r1 : number
        Reward of action 1 in the last state.
    r2 : number
        Reward of action 2 in the last state.
    p : float
        Probability of falling back to state 1 under action 1.

    Returns
    -------
    str
        The file name of the database that was created.

    Raises
    ------
    ValueError
        If ``S`` is smaller than 2 (the reward vector of action 2 would then
        hold more entries than there are states).
    """
    if S < 2:
        raise ValueError("The number of states S must be greater than 1.")
    db = "MDP-forest-%s.db" % S
    if os.path.exists(db):
        os.remove(db)
    conn = sqlite3.connect(db)
    try:
        # ``with conn`` only commits/rolls back; closing is done in finally.
        with conn:
            c = conn.cursor()
            c.executescript('''
                CREATE TABLE info (name TEXT, value INTEGER);
                CREATE TABLE transition1 (row INTEGER, col INTEGER, prob REAL);
                CREATE TABLE reward1 (state INTEGER PRIMARY KEY ASC, val REAL);
                CREATE TABLE transition2 (row INTEGER, col INTEGER, prob REAL);
                CREATE TABLE reward2 (state INTEGER PRIMARY KEY ASC, val REAL);''')
            c.executemany("INSERT INTO info VALUES(?, ?)",
                          [('states', int(S)), ('actions', 2)])
            # list(range()) so that the list arithmetic below also works on
            # Python 3, where range() is not a list.
            states = list(range(1, S + 1))
            # Action 1: first S entries are the "back to state 1" moves, the
            # next S entries are the "advance one state (capped at S)" moves.
            rows = states * 2
            cols = [1] * S + states[1:] + [S]
            vals = [p] * S + [1 - p] * S
            c.executemany("INSERT INTO transition1 VALUES(?, ?, ?)",
                          zip(rows, cols, vals))
            # Action 2: every state moves to state 1 with certainty.
            c.executemany("INSERT INTO transition2 VALUES(?, ?, ?)",
                          zip(states, [1] * S, [1] * S))
            # The state column is an INTEGER PRIMARY KEY, so inserting only
            # the value numbers the states 1..S in insertion order.
            c.executemany("INSERT INTO reward1(val) VALUES(?)",
                          zip([0] * (S - 1) + [r1]))
            c.executemany("INSERT INTO reward2(val) VALUES(?)",
                          zip([0] + [1] * (S - 2) + [r2]))
            c.executescript('''
                CREATE INDEX Pidx1 ON transition1 (row, col);
                CREATE INDEX Pidx2 ON transition2 (row, col);''')
    finally:
        conn.close()
    # return the database's name
    return db
Steven Cordwell's avatar
Steven Cordwell committed
49

50
def exampleRand(S, A):
    """Create an SQLite database holding a random sparse MDP.

    WARNING: This will delete a database with the same name as the one being
    created, ``MDP-<S>x<A>.db`` in the current working directory.

    For every action ``a`` in 1..A a table ``reward<a>`` with one uniformly
    random value per state and a table ``transition<a>`` are written. Each
    row of a transition matrix gets a random number of nonzero entries in
    randomly chosen columns, normalised so that the row sums to one.

    Parameters
    ----------
    S : int
        Number of states.
    A : int
        Number of actions.

    Returns
    -------
    str
        The file name of the database that was created.
    """
    db = "MDP-%sx%s.db" % (S, A)
    if os.path.exists(db):
        os.remove(db)
    # To be usefully represented as a sparse matrix, the number of nonzero
    # entries per row should be less than 1/3 of the dimension of the matrix.
    # numpy's randint() excludes the upper bound and raises ValueError when
    # it is not greater than the lower bound, so the bound is clamped to 2:
    # matrices with fewer than 6 states get exactly one entry per row.
    upper = max(S // 3, 2)
    states = arange(1, S + 1)
    conn = sqlite3.connect(db)
    try:
        # ``with conn`` only commits/rolls back; closing is done in finally.
        with conn:
            c = conn.cursor()
            c.execute("CREATE TABLE info (name TEXT, value INTEGER)")
            c.executemany("INSERT INTO info VALUES(?, ?)",
                          [('states', int(S)), ('actions', int(A))])
            for a in range(1, A + 1):
                cmd = '''
                    CREATE TABLE transition%s (row INTEGER, col INTEGER, prob REAL);
                    CREATE TABLE reward%s (state INTEGER PRIMARY KEY ASC, val REAL);
                    ''' % (a, a)
                c.executescript(cmd)
                cmd = "INSERT INTO reward%s(val) VALUES(?)" % a
                c.executemany(cmd, zip(random(S).tolist()))
                cmd = "INSERT INTO transition%s VALUES(?, ?, ?)" % a
                for s in range(1, S + 1):
                    n = randint(1, upper)
                    # n distinct random columns; tolist() converts the numpy
                    # integers to Python ints, which sqlite3 can bind.
                    col = permutation(states)[0:n].tolist()
                    val = random(n)
                    val = (val / val.sum()).tolist()
                    c.executemany(cmd, zip([s] * n, col, val))
                cmd = "CREATE UNIQUE INDEX Pidx%s ON transition%s (row, col);" % (a, a)
                c.execute(cmd)
    finally:
        conn.close()
    # return the name of the database
    return db
89

90
91

class MDP(object):
    """Base class for solving an MDP whose model is stored in SQLite.

    The database must contain a table ``info`` with rows named 'states' and
    'actions' and, for each action ``a`` in 1..A, the tables
    ``transition<a>`` (row, col, prob) and ``reward<a>`` (state, val).
    States and actions are numbered from 1. The working tables ``Q``, ``V``,
    ``Vprev`` and ``policy`` are (re)created in the same database.
    """

    def __init__(self, db, discount, epsilon, max_iter, initial_V=0):
        """Open the database, validate the model and create working tables.

        Parameters
        ----------
        db : str
            Path of the SQLite database holding the model.
        discount : float
            Discount factor used by the Bellman operator.
        epsilon : float
            Stored as ``self.epsilon`` for use by subclasses.
        max_iter : int
            Stored as ``self.max_iter`` for use by subclasses.
        initial_V : 0 or sequence of numbers
            Starting value function; 0 means all zeros, otherwise one value
            per state.

        Raises
        ------
        ValueError
            If the number of states/actions cannot be read, a transition
            matrix is not square and row-stochastic, or ``initial_V`` is
            unusable.
        """
        self.discount = discount
        self.epsilon = epsilon
        self.max_iter = max_iter
        self.itr = 0
        # The database stuff
        self._conn = sqlite3.connect(db)
        self._cur = self._conn.cursor()
        self.S = self._readInfo('states')
        self.A = self._readInfo('actions')
        self._checkSquareStochastic()
        self._initQ()
        self._initResults(initial_V)

    def _readInfo(self, name):
        """Return the value stored under ``name`` in table ``info``."""
        self._cur.execute("SELECT value FROM info WHERE name=?", (name,))
        row = self._cur.fetchone()
        if row is None:
            raise ValueError("Cannot determine number of %s from "
                             "database. There is no name '%s' in table "
                             "'info'." % (name, name))
        return row[0]

    def _checkSquareStochastic(self):
        """Check every transition table is S x S with rows summing to one.

        Raises
        ------
        ValueError
            If, for some action, not exactly S rows sum to one (within
            10e-12), the largest row index is not S, or the largest column
            index exceeds the largest row index.
        RuntimeError
            If the stochasticity query returns no result row.
        """
        for a in range(1, self.A + 1):
            P = "transition%s" % a
            # Count the rows whose probabilities sum to one; the comparison
            # evaluates to 1 or 0 per row, so the outer SUM is that count.
            cmd = "SELECT SUM(s) " \
                  "  FROM (" \
                  "       SELECT ABS(SUM(prob)-1)<10e-12 AS s" \
                  "         FROM " + P + "" \
                  "        GROUP BY row);"
            self._cur.execute(cmd)
            result = self._cur.fetchone()
            if result is None:
                # RuntimeError derives from StandardError on Python 2, so
                # existing handlers keep working, and it exists on Python 3.
                raise RuntimeError("The check stochastic query for a=%s "
                                   "failed." % a)
            if result[0] != self.S:
                raise ValueError("The transition matrix for action %s "
                                 "is not stochastic." % a)
            self._cur.execute("SELECT MAX(row) FROM " + P)
            row_max = self._cur.fetchone()[0]
            if int(row_max) != self.S:
                raise ValueError("The transition matrix for action %s is "
                                 "not square: row_max = %s" % (a, row_max))
            self._cur.execute("SELECT MAX(col) FROM " + P)
            col_max = self._cur.fetchone()[0]
            if int(col_max) > row_max:
                raise ValueError("The transition matrix for action %s is "
                                 "not square: col_max = %s" % (a, col_max))

    def _initQ(self):
        """(Re)create table Q with one NULL-valued row per (state, action)."""
        self._delQ()
        self._cur.execute("CREATE TABLE Q (state INTEGER, action INTEGER, "
                          "value REAL);")
        pairs = [(s, a, None)
                 for a in range(1, self.A + 1)
                 for s in range(1, self.S + 1)]
        self._cur.executemany("INSERT INTO Q VALUES(?, ?, ?)", pairs)
        self._cur.execute("CREATE UNIQUE INDEX Qidx ON Q (state, action);")
        self._conn.commit()

    def _delQ(self):
        """Drop table Q and its index if they exist."""
        self._cur.executescript('''
            DROP TABLE IF EXISTS Q;
            DROP INDEX IF EXISTS Qidx;''')

    def _initResults(self, initial_V):
        """(Re)create the tables policy, V and Vprev with one row per state.

        ``policy`` and ``Vprev`` start with NULL entries; ``V`` is filled
        with zeros when ``initial_V`` equals 0, otherwise with the given
        values.

        Raises
        ------
        ValueError
            If ``initial_V`` is neither 0 nor a sequence of S values that
            sqlite3 can store.
        """
        # bool() of an element-wise array comparison raises for arrays with
        # more than one element; such inputs are sequences, not "zero".
        try:
            is_zero = bool(initial_V == 0)
        except (TypeError, ValueError):
            is_zero = False
        if is_zero:
            initial = [0] * self.S
        else:
            try:
                # tolist() turns numpy scalars into plain Python numbers.
                if hasattr(initial_V, "tolist"):
                    initial = list(initial_V.tolist())
                else:
                    initial = list(initial_V)
            except TypeError:
                raise ValueError("V is of unsupported type, use a list or "
                                 "tuple.")
            if len(initial) != self.S:
                raise ValueError("V must have one value per state: expected "
                                 "%s values, got %s." % (self.S, len(initial)))
        self._delResults()
        self._cur.executescript('''
            CREATE TABLE policy (state INTEGER PRIMARY KEY ASC, action INTEGER);
            CREATE TABLE V (state INTEGER PRIMARY KEY ASC, value REAL);
            CREATE TABLE Vprev (state INTEGER PRIMARY KEY ASC, value REAL);''')
        # A real list, not a zip iterator: it is consumed twice, and an
        # iterator would leave Vprev empty on Python 3.
        empty = [(None,)] * self.S
        self._cur.executemany("INSERT INTO policy(action) VALUES(?)", empty)
        self._cur.executemany("INSERT INTO Vprev(value) VALUES(?)", empty)
        try:
            self._cur.executemany("INSERT INTO V(value) VALUES(?)",
                                  [(v,) for v in initial])
        except sqlite3.Error:
            raise ValueError("V is of unsupported type, use a list or "
                             "tuple.")
        self._conn.commit()

    def _delResults(self):
        """Drop the tables policy, V and Vprev if they exist."""
        self._cur.executescript('''
            DROP TABLE IF EXISTS policy;
            DROP TABLE IF EXISTS V;
            DROP TABLE IF EXISTS Vprev;''')

    def __del__(self):
        """Drop the scratch tables Q and Vprev, vacuum and close."""
        # __init__ may have failed before the connection was opened.
        conn = getattr(self, "_conn", None)
        cur = getattr(self, "_cur", None)
        if conn is None:
            return
        try:
            if cur is not None:
                self._delQ()
                cur.executescript('''
                    DROP TABLE IF EXISTS Vprev;
                    VACUUM;''')
                cur.close()
        finally:
            # Close the connection even when the cleanup statements fail.
            conn.close()

    def _bellmanOperator(self):
        """Apply the Bellman operator once, updating tables Q and then V.

        For every action ``a`` the Q values are set to
        ``reward<a>.val + discount * SUM(transition<a>.prob * V.value)``,
        summed over the columns of each row; V is then recomputed as the
        maximum of Q over the actions.
        """
        # Bind the discount and the action as parameters rather than pasting
        # their string forms into the statement; only the table names, which
        # cannot be bound, are formatted in.
        template = (
            "UPDATE Q "
            "   SET value = ("
            "       SELECT value "
            "         FROM ("
            "              SELECT R.state AS state, (R.val + B.val) AS value "
            "                FROM %(R)s AS R, ("
            "                     SELECT P.row, ? * SUM(P.prob * V.value) AS val "
            "                       FROM %(P)s AS P, V "
            "                      WHERE V.state = P.col "
            "                      GROUP BY P.row"
            "                     ) AS B "
            "               WHERE R.state = B.row"
            "              ) AS C "
            "        WHERE Q.state = C.state) "
            " WHERE action = ?;")
        discount = float(self.discount)
        for a in range(1, self.A + 1):
            cmd = template % {"R": "reward%s" % a, "P": "transition%s" % a}
            self._cur.execute(cmd, (discount, a))
        self._conn.commit()
        self._calculateValue()

    def _calculatePolicy(self):
        """This implements argmax() over the actions of Q.

        The inner query relies on SQLite returning, for a bare column next
        to a single MAX() aggregate, the value from the row holding the
        maximum.
        """
        cmd = '''
              UPDATE policy
                 SET action = (
                     SELECT action
                       FROM (SELECT state, action, MAX(value)
                               FROM Q
                              GROUP BY state) AS A
                       WHERE policy.state = A.state
                       GROUP BY state);'''
        self._cur.execute(cmd)
        self._conn.commit()

    def _calculateValue(self):
        """This is max() over the actions of Q."""
        cmd = '''
              UPDATE V
                 SET value = (
                     SELECT MAX(value)
                       FROM Q
                      WHERE V.state = Q.state
                      GROUP BY state);'''
        self._cur.execute(cmd)
        self._conn.commit()

    def _getSpan(self):
        """Return max(V - Vprev) - min(V - Vprev) taken over the states.

        The result is None when the difference is NULL for every state.
        """
        cmd = '''
              SELECT (MAX(A.value) - MIN(A.value))
              FROM (
                   SELECT (V.value - Vprev.value) as value
                     FROM V, Vprev
                    WHERE V.state = Vprev.state) AS A;'''
        self._cur.execute(cmd)
        span = self._cur.fetchone()
        if span is not None:
            return span[0]

    def getPolicyValue(self):
        """Get the policy and value vectors.

        Returns
        -------
        tuple
            ``(policy, value)``, two lists ordered by state number.
        """
        # Without ORDER BY the row order of a SELECT is not guaranteed.
        self._cur.execute("SELECT action FROM policy ORDER BY state")
        policy = [x[0] for x in self._cur.fetchall()]
        self._cur.execute("SELECT value FROM V ORDER BY state")
        value = [x[0] for x in self._cur.fetchall()]
        return policy, value

    def _randomQ(self):
        """Overwrite every entry of Q with a uniformly random value."""
        # Q already holds a row for each (state, action) pair and has a
        # unique index on them, so a plain INSERT would always fail.
        cmd = "INSERT OR REPLACE INTO Q VALUES(?, ?, ?)"
        for a in range(1, self.A + 1):
            state = range(1, self.S + 1)
            action = [a] * self.S
            value = random(self.S).tolist()
            self._cur.executemany(cmd, zip(state, action, value))
        self._conn.commit()
Steven Cordwell's avatar
Steven Cordwell committed
286

287
class ValueIteration(MDP):
    """Solve a discounted MDP stored in SQLite by value iteration.

    The iteration runs from the constructor. Afterwards ``self.itr`` holds
    the number of iterations performed and ``self.time`` the elapsed time in
    seconds; the results can be read with ``getPolicyValue()``.
    """

    def __init__(self, db, discount, epsilon=0.01, max_iter=1000,
                 initial_value=0):
        """Set up the model and run value iteration.

        Parameters
        ----------
        db : str
            Path of the SQLite database holding the model.
        discount : float
            Discount factor.
        epsilon : float
            Tolerance from which the stopping threshold is derived.
        max_iter : int
            Maximum number of iterations.
        initial_value : 0 or sequence of numbers
            Starting value function, passed on to ``MDP.__init__``.
        """
        MDP.__init__(self, db, discount, epsilon, max_iter, initial_value)

        if self.discount == 0:
            # epsilon * (1 - d) / d grows without bound as d approaches 0;
            # use the limit rather than dividing by zero.
            self.thresh = float("inf")
        elif self.discount < 1:
            self.thresh = epsilon * (1 - self.discount) / self.discount
        else:
            self.thresh = epsilon

        self._iterate()

    def __del__(self):
        MDP.__del__(self)

    def _iterate(self):
        """Apply the Bellman operator until the span drops below the
        threshold or ``max_iter`` iterations have been done, then compute
        the policy and record the elapsed time in ``self.time``."""
        self.time = time()
        done = False
        while not done:
            self.itr += 1

            self._copyPreviousValue()
            self._bellmanOperator()
            variation = self._getSpan()

            # A None span (no comparable values) stops the iteration, which
            # is what the Python 2 comparison ``None < thresh`` did; on
            # Python 3 that comparison raises TypeError.
            if variation is None or variation < self.thresh:
                done = True
            # >= rather than == so that a max_iter that is not a positive
            # integer cannot cause an endless loop.
            elif self.itr >= self.max_iter:
                done = True
        # get the optimal policy
        self._calculatePolicy()
        # calculate the time taken to finish
        self.time = time() - self.time

    def _copyPreviousValue(self):
        """Copy the current values of table V into table Vprev."""
        cmd = '''
              UPDATE Vprev
                 SET value = (
                     SELECT value
                       FROM V
                      WHERE Vprev.state = V.state);'''
        self._cur.execute(cmd)
        self._conn.commit()