Commit 5b2b87c9 authored by Steven Cordwell
Browse files

Merge branch 'tictactoe'

Conflicts:
	src/examples/firemdp.py
parents 76bacbab 9583a8ba
...@@ -386,4 +386,3 @@ if __name__ == "__main__": ...@@ -386,4 +386,3 @@ if __name__ == "__main__":
else: else:
sdp = solveMDP() sdp = solveMDP()
printPolicy(sdp.policy[:, 0]) printPolicy(sdp.policy[:, 0])
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
#import mdp import numpy as np
from scipy.sparse import dok_matrix
def str_base(num, base, numerals = '0123456789abcdefghijklmnopqrstuvwxyz'): from mdptoolbox import mdp
if base < 2 or base > len(numerals):
raise ValueError("str_base: base must be between 2 and %i" % ACTIONS = 9
len(numerals)) STATES = 3**ACTIONS
# Markers stored in the cells of a board state. A value of 0 marks an empty
# cell (getLegalActions treats cells equal to 0 as available moves).
PLAYER = 1
OPPONENT = 2
# Masks of the winning lines: each entry has a 1 in the three cells that make
# up one line. isWon counts, per mask, the cells that are set in the mask and
# held by the side being tested; a count of 3 is a win.
# NOTE(review): the first three masks look like rows, the next three like
# columns and the last two like diagonals, assuming a row-major 3x3 board --
# confirm.
WINS = ([1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 1, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 1, 0, 0, 1],
[1, 0, 0, 0, 1, 0, 0, 0, 1],
[0, 0, 1, 0, 1, 0, 1, 0, 0])
# The valid number of cells belonging to either the player or the opponent:
# (player, opponent)
# isValid accepts a state only if its (player, opponent) cell counts appear
# in this table. Every listed pair has the opponent holding either the same
# number of cells as the player or exactly one more.
OWNED_CELLS = ((0, 0),
(1, 1),
(2, 2),
(3, 3),
(4, 4),
(0, 1),
(1, 2),
(2, 3),
(3, 4))
def convertIndexToTuple(state):
    """Expand a state index into a tuple of nine base-3 digits.

    Parameters
    ----------
    state : int
        The index of a board state.

    Returns
    -------
    tuple of int
        The nine least significant base-3 digits of ``state``, most
        significant digit first.

    """
    # Pad with nine leading zeros so that there are always at least nine
    # characters, then keep only the last nine of them.
    padded = np.base_repr(state, base=3, padding=9)
    digits = padded[-9:]
    return tuple(map(int, digits))
def convertTupleToIndex(state):
    """Collapse a sequence of base-3 digits into a single integer index.

    Parameters
    ----------
    state : sequence of int
        The cell values of a board, most significant digit first.

    Returns
    -------
    int
        The value of the digits of ``state`` read as a base-3 number.

    """
    digits = "".join(map(str, state))
    return int(digits, 3)
def getLegalActions(state):
    """List the actions whose target cell is still empty.

    Parameters
    ----------
    state : sequence of int
        The cell values of a board.

    Returns
    -------
    tuple of int
        Every cell number in ``range(ACTIONS)`` whose value in ``state``
        equals 0, in ascending order.

    """
    empty_cells = []
    for cell in range(ACTIONS):
        if state[cell] == 0:
            empty_cells.append(cell)
    return tuple(empty_cells)
def getTransitionAndRewardArrays():
    """Build the transition matrices and the reward array of the game.

    Every combination of state index and action is visited (the naive
    approach).

    Returns
    -------
    (P, R) : tuple
        ``P`` is a list holding one ``STATES`` x ``STATES`` sparse matrix
        (CSR format) per action, ``R`` is a ``STATES`` x ``ACTIONS`` numpy
        array of rewards.

    """
    R = np.zeros((STATES, ACTIONS))
    P = []
    for action in range(ACTIONS):
        # Fill a DOK matrix entry by entry, then convert it to CSR once the
        # action has been processed for every state.
        matrix = dok_matrix((STATES, STATES))
        for index in range(STATES):
            board = convertIndexToTuple(index)
            if isValid(board):
                successors, probabilities, reward = \
                    getTransitionProbabilities(board, action)
                matrix[index, successors] = probabilities
                R[index, action] = reward
            else:
                # No moves are defined from an invalid state, so there are no
                # transition probabilities to calculate. The matrix must
                # still be square and stochastic, so the invalid state
                # transitions back to itself with probability one. Its reward
                # stays 0.
                matrix[index, index] = 1
        P.append(matrix.tocsr())
    return (P, R)
def getTransitionProbabilities(state, action):
    """Compute the outcome of the player taking ``action`` in ``state``.

    Parameters
    ----------
    state : tuple
        The state
    action : int
        The action

    Returns
    -------
    s1, p, r : tuple of two lists and an int
        s1 are the next states, p are the probabilities, and r is the reward

    """
    assert 0 <= action < ACTIONS
    if not isLegal(state, action):
        # An illegal action leaves the state unchanged but incurs a high
        # negative reward.
        return ([convertTupleToIndex(state)], [1], -10)
    # Apply the player's move to a mutable copy of the board.
    board = list(state)
    board[action] = PLAYER
    if isWon(board, PLAYER):
        # A winning move transitions to the winning state with a reward of 1.
        return ([convertTupleToIndex(board)], [1], 1)
    if isDraw(board):
        return ([convertTupleToIndex(board)], [1], 0)
    # Search through the opponent's replies. A reply that wins for the
    # opponent is assumed to be chosen with certainty; otherwise every reply
    # is assumed to be equally likely.
    next_states = []
    probabilities = []
    replies = getLegalActions(board)
    for reply in replies:
        board[reply] = OPPONENT
        successor = convertTupleToIndex(board)
        if isWon(board, OPPONENT):
            return ([successor], [1], -1)
        if isDraw(board):
            return ([successor], [1], 0)
        next_states.append(successor)
        probabilities.append(1.0 / len(replies))
        # Undo the reply before trying the next one.
        board[reply] = 0
    # During non-terminal play the reward is 0.
    return (next_states, probabilities, 0)
def getReward(state, action):
    """Return the reward for the player taking ``action`` in ``state``.

    Parameters
    ----------
    state : sequence of int
        The cell values of a board.
    action : int
        The cell the player moves into.

    Returns
    -------
    int
        -100 for an illegal action, 1 if the board is won by the player
        after the move, -1 if it is won by the opponent, otherwise 0.

    """
    # NOTE(review): the illegal-move penalty here is -100, whereas
    # getTransitionProbabilities uses -10 -- confirm which one is intended.
    if not isLegal(state, action):
        return -100
    board = list(state)
    board[action] = PLAYER
    if isWon(board, PLAYER):
        return 1
    if isWon(board, OPPONENT):
        return -1
    return 0
result = ''
while num:
result = numerals[num % (base)] + result
num //= base
return sign + result
def isDraw(state):
    """Test whether the board has no empty cells left.

    Parameters
    ----------
    state : container of int
        The cell values of a board, where 0 marks an empty cell.

    Returns
    -------
    bool
        True if no cell of ``state`` equals 0, otherwise False. Whether
        either side has won is not checked here.

    """
    # A plain membership test replaces the previous try/except around
    # ``state.index(0)``; it gives the same answer for lists and tuples and
    # also works for containers that have no ``index`` method.
    return 0 not in state
class TicTacToeMDP(object): def isLegal(state, action):
"""""" """"""
if state[action] == 0:
def __init__(self): return True
"""""" else:
self.P = [None] * 9
for a in xrange(9):
self.P[a] = {}
self.R = {}
# some board states are equal, just rotations of other states
self.rotorder = []
#self.rotorder.append([0, 1, 2, 3, 4, 5, 6, 7, 8])
self.rotorder.append([6, 3, 0, 7, 4, 1, 8, 5, 2])
self.rotorder.append([8, 7, 6, 5, 4, 3, 2, 1, 0])
self.rotorder.append([2, 5, 8, 1, 4, 7, 0, 3, 6])
# The valid number of cells belonging to either the player or the
# opponent: (player, opponent)
self.nXO = ((0, 0),
(1, 1),
(2, 2),
(3, 3),
(4, 4),
(0, 1),
(1, 2),
(2, 3),
(3, 4))
# The winning positions
self.wins = ([1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 1, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 1, 0, 0, 1],
[1, 0, 0, 0, 1, 0, 0, 0, 1],
[0, 0, 1, 0, 1, 0, 1, 0, 0])
def rotate(self, state):
    """Return the smallest index among a board and its three rotations.

    Parameters
    ----------
    state : sequence of int
        The nine cell values of a board.

    Returns
    -------
    int
        The minimum, over the board itself and the board re-ordered by each
        of the first three entries of ``self.rotorder``, of the cell values
        read as a base-3 number.

    """
    # Index of the board as given.
    identity = [int("".join(str(x) for x in state), 3)]
    # ``range`` is used instead of ``xrange`` so that this also runs on
    # Python 3.
    for k in range(3):
        order = self.rotorder[k]
        # Re-order the cells and convert the base 3 number to an integer.
        identity.append(int("".join(str(state[order[kk]])
                                    for kk in range(9)), 3))
    # Return the smallest identity number.
    return min(identity)
def unrotate(self, move, rotation):
    """Look up a move in one of the orderings held in ``self.rotorder``.

    ``rotation`` is reduced by one before it is used to select the ordering;
    the entry of that ordering at position ``move`` is returned.
    """
    ordering = self.rotorder[rotation - 1]
    return ordering[move]
def isLegal(self, state, action):
    """Return True if the cell ``action`` of ``state`` equals 0, else False."""
    return bool(state[action] == 0)
def isWon(self, state, who):
""""""
# Check to see if there are any wins
for w in self.wins:
S = sum(1 if (w[k] == 1 and state[k] == who) else 0
for k in xrange(9))
if S == 3:
# We have a win
return True
# There were no wins so return False
return False return False
def isWon(state, who):
"""Test if a tic-tac-toe game has been won.
def isDraw(self, state): Assumes that the board is in a legal state.
"""""" Will test if the value 1 is in any winning combination.
try:
state.index(0)
return False
except ValueError:
return True
except:
raise
def isValid(self, state): """
"""""" for w in WINS:
# S1 is the sum of the player's cells S = sum(1 if (w[k] == state[k] == who) else 0
S1 = sum(1 if x == 1 else 0 for x in state) for k in range(ACTIONS))
# S2 is the sum of the opponent's cells if S == 3:
S2 = sum(1 if x == 2 else 0 for x in state) # We have a win
if (S1, S2) in self.nXO:
return True return True
else: # There were no wins so return False
return False return False
def getReward(self, s): def isValid(state):
if self.isWon(s, 1): """"""
return 1 # S1 is the sum of the player's cells
elif self.isWon(s, 2): S1 = sum(1 if x == PLAYER else 0 for x in state)
return -1 # S2 is the sum of the opponent's cells
else: S2 = sum(1 if x == OPPONENT else 0 for x in state)
return 0 if (S1, S2) in OWNED_CELLS:
return True
def run(self): else:
"""""" return False
l = (0,1,2)
# Iterate through a generator of all the combinations
for s in ((a0,a1,a2,a3,a4,a5,a6,a7,a8) for a0 in l for a1 in l
for a2 in l for a3 in l for a4 in l for a5 in l
for a6 in l for a7 in l for a8 in l):
if self.isValid(s):
s_idn = self.rotate(s)
if not self.R.has_key(s_idn):
self.R[s_idn] = self.getReward(s)
self.transition(s)
# Convert P and R to ijv lists
# Iterate through up to the theoretical maximum value of s
for s in xrange(int('222211110',3)):
print s
# return (P, R)
def toTuple(self, state):
    """Convert a state index into a tuple of base-3 digits.

    The index is rendered in base 3 with ``str_base`` and left-padded with
    '0' characters to a width of nine before each character is converted to
    an int.
    """
    digits = str_base(state, 3)
    return tuple(int(ch) for ch in digits.rjust(9, '0'))
def transition(self, state):
    """Record the states reachable from ``state`` after one full turn.

    For each empty cell the player could take, and each empty cell the
    opponent could then take, an entry keyed by ``(index of state, index
    returned by self.rotate for the resulting board)`` is added to
    ``self.P[player_cell]`` unless it already exists. The stored value is
    the number of replies available to the opponent, not a probability.

    Parameters
    ----------
    state : sequence of int
        The nine cell values of a board.

    """
    #TODO: the state needs to be rotated before anything else is done!!!
    # TODO: a winning move by the player is not treated specially yet; the
    # opponent's replies are recorded regardless.
    idn_s = int("".join(str(x) for x in state), 3)
    # ``range`` and ``in`` replace the Python 2 only ``xrange`` and
    # ``dict.has_key`` so that this runs on Python 3 as well.
    legal_a = [x for x in range(9) if state[x] == 0]
    for a in legal_a:
        s = list(state)
        s[a] = 1
        legal_m = [x for x in range(9) if s[x] == 0]
        for m in legal_m:
            s_new = list(s)
            s_new[m] = 2
            idn_s_new = self.rotate(s_new)
            if (idn_s, idn_s_new) not in self.P[a]:
                self.P[a][(idn_s, idn_s_new)] = len(legal_m)
if __name__ == "__main__": if __name__ == "__main__":
P, R = TicTacToeMDP().run() P, R = getTransitionAndRewardArrays()
#ttt = mdp.ValueIteration(P, R, 1) ttt = mdp.ValueIteration(P, R, 1)
print(ttt.policy)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment