# -*- coding: utf-8 -*-
import numpy as np
from scipy.sparse import dok_matrix as spdok
from mdptoolbox import mdp
# One action per board cell; every cell holds 0 (empty), PLAYER or OPPONENT,
# so a board is a 9-digit base-3 number and there are 3**9 encodable states.
ACTIONS = 9
STATES = 3 ** ACTIONS
PLAYER = 1
OPPONENT = 2

# 0/1 masks over the nine cells; the 1s mark the three cells that form a
# winning line.
WINS = (
    # Three consecutive cells (rows of a row-major 3x3 board).
    [1, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 1, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 1, 1],
    # Every third cell (columns).
    [1, 0, 0, 1, 0, 0, 1, 0, 0],
    [0, 1, 0, 0, 1, 0, 0, 1, 0],
    [0, 0, 1, 0, 0, 1, 0, 0, 1],
    # The two diagonals.
    [1, 0, 0, 0, 1, 0, 0, 0, 1],
    [0, 0, 1, 0, 1, 0, 1, 0, 0],
)

# The valid (player, opponent) counts of owned cells: either both sides own
# the same number of cells (0..4), or the opponent owns exactly one more
# cell than the player (0..3 versus 1..4).
OWNED_CELLS = (
    tuple((n, n) for n in range(5))
    + tuple((n, n + 1) for n in range(4))
)
def convertIndexToTuple(state):
    """Decode an integer state index into a 9-tuple of base-3 digits.

    The most significant digit comes first. The index is rendered in base 3
    with nine zeros of left padding, and only the last nine characters are
    kept, so small indices are zero-filled to exactly nine cells.
    """
    padded = np.base_repr(state, base=3, padding=9)
    digits = padded[-9:]
    return tuple(map(int, digits))
def convertTupleToIndex(state):
    """Encode a sequence of base-3 digits as a single integer index.

    The cells are concatenated in order (first cell most significant) and
    the resulting digit string is parsed as a base-3 number.
    """
    digits = "".join(map(str, state))
    return int(digits, 3)
def getLegalActions(state):
    """Return a tuple with the indices of all empty cells of ``state``.

    A cell is empty when it holds 0. Indices 0 .. ACTIONS - 1 are examined
    in increasing order.
    """
    empty_cells = []
    for cell in range(ACTIONS):
        if state[cell] == 0:
            empty_cells.append(cell)
    return tuple(empty_cells)
def getTransitionAndRewardArrays():
    """Build the transition and reward matrices of the tic-tac-toe MDP.

    Returns
    -------
    P, R : tuple
        ``P`` is a list with one (STATES x STATES) sparse CSR matrix per
        action, where row ``s`` holds the transition probabilities out of
        state ``s``. ``R`` is a (STATES x ACTIONS) sparse CSC matrix with
        the reward for taking each action in each state.
    """
    P = [spdok((STATES, STATES)) for _ in range(ACTIONS)]
    R = spdok((STATES, ACTIONS))
    # Naive approach: iterate through all possible combinations. The state
    # loop is outermost so that each index is decoded and validated only
    # once rather than once per action.
    for s in range(STATES):
        state = convertIndexToTuple(s)
        if not isValid(state):
            # There are no defined moves from an invalid state, so
            # transition probabilities cannot be calculated. However,
            # P must be a square stochastic matrix, so assign a
            # probability of one to the invalid state transitioning
            # back to itself. The reward stays at 0.
            for a in range(ACTIONS):
                P[a][s, s] = 1
            continue
        for a in range(ACTIONS):
            s1, p, r = getTransitionProbabilities(state, a)
            P[a][s, s1] = p
            R[s, a] = r
    # DOK is convenient for incremental construction; convert to the
    # compressed formats once everything has been filled in.
    P = [matrix.tocsr() for matrix in P]
    R = R.tocsc()
    return (P, R)
def getTransitionProbabilities(state, action):
    """Compute the outcomes of the player taking ``action`` in ``state``.

    Parameters
    ----------
    state : tuple
        The board, one entry per cell (0 empty, PLAYER or OPPONENT).
    action : int
        Index of the cell the player marks.

    Returns
    -------
    s1, p, r : tuple of two lists and an int
        ``s1`` holds the indices of the possible next states, ``p`` the
        matching probabilities, and ``r`` is the reward.

        * Illegal action: stay in the same state, reward -10.
        * Player completes a line: the resulting state, reward 1.
        * Player fills the board: the resulting state, reward 0.
        * An opponent reply that wins is assumed to be chosen with
          certainty, reward -1; a reply that fills the board gives reward 0.
        * Otherwise every opponent reply is equally likely, reward 0.
    """
    assert 0 <= action < ACTIONS

    # An illegal action leaves the board untouched but is heavily penalised.
    if not isLegal(state, action):
        return ([convertTupleToIndex(state)], [1], -10)

    # Apply the player's move on a mutable copy of the board.
    board = list(state)
    board[action] = PLAYER

    # The player's move may end the game immediately.
    if isWon(board, PLAYER):
        return ([convertTupleToIndex(board)], [1], 1)
    if isDraw(board):
        return ([convertTupleToIndex(board)], [1], 0)

    # Try each opponent reply in turn, undoing it before the next one.
    replies = getLegalActions(board)
    next_states = []
    probabilities = []
    for reply in replies:
        board[reply] = OPPONENT
        outcome = convertTupleToIndex(board)
        # A winning reply is assumed to be the one the opponent plays.
        if isWon(board, OPPONENT):
            return ([outcome], [1], -1)
        if isDraw(board):
            return ([outcome], [1], 0)
        # Non-terminal replies share the probability mass uniformly.
        next_states.append(outcome)
        probabilities.append(1.0 / len(replies))
        board[reply] = 0

    # During non-terminal play the reward is 0.
    return (next_states, probabilities, 0)
def getReward(state, action):
    """Return the immediate reward for the player marking cell ``action``.

    The reward is -100 for an illegal action, 1 if the board after the move
    is won by the player, -1 if it is won by the opponent, and 0 otherwise.
    """
    # NOTE(review): the illegal-move penalty here (-100) differs from the
    # -10 used by getTransitionProbabilities -- confirm which is intended.
    if not isLegal(state, action):
        return -100
    board = list(state)
    board[action] = PLAYER
    # The player is checked first, exactly one outcome is reported.
    for side, reward in ((PLAYER, 1), (OPPONENT, -1)):
        if isWon(board, side):
            return reward
    return 0
def isDraw(state):
    """Return True when ``state`` has no empty (0) cell left, else False."""
    return 0 not in state
def isLegal(state, action):
    """Return True when the cell at index ``action`` is empty (equal to 0)."""
    return state[action] == 0
def isWon(state, who):
    """Test if a tic-tac-toe game has been won by ``who``.

    Assumes that the board is in a legal state.

    Parameters
    ----------
    state : sequence of int
        The board, one entry per cell.
    who : int
        The cell value to look for (PLAYER or OPPONENT).

    Returns
    -------
    bool
        True if ``who`` owns all three cells of at least one line in WINS.
    """
    for w in WINS:
        # The entries of ``w`` are a 0/1 mask that only selects the cells of
        # the line; they must not be compared with ``who``. The previous
        # chained test ``w[k] == state[k] == who`` also required the mask
        # value to equal ``who``, so a win for who == 2 was never detected.
        if all(state[k] == who for k in range(ACTIONS) if w[k]):
            # We have a win
            return True
    # There were no wins so return False
    return False
def isValid(state):
    """Return True when the cell counts of ``state`` are an allowed pair.

    The number of cells owned by the player and by the opponent are counted
    and the (player, opponent) pair is looked up in OWNED_CELLS.
    """
    player_cells = sum(cell == PLAYER for cell in state)
    opponent_cells = sum(cell == OPPONENT for cell in state)
    return (player_cells, opponent_cells) in OWNED_CELLS
if __name__ == "__main__":
    # Build the model, solve it with value iteration (discount factor 1)
    # and print the resulting policy.
    P, R = getTransitionAndRewardArrays()
    ttt = mdp.ValueIteration(P, R, 1)
    # Per the pymdptoolbox documentation, constructing the solver only sets
    # it up; run() performs the iterations that populate ``policy``.
    ttt.run()
    print(ttt.policy)