# -*- coding: utf-8 -*-
"""Tic-tac-toe modelled as a Markov decision process.

Every board is encoded as a base-3 number with one digit per cell
(0 = empty, ``PLAYER`` = 1, ``OPPONENT`` = 2), giving ``3**9`` states. An
action is the index of the cell the player marks. The opponent's reply is
folded into the transition: a winning reply is always taken, otherwise the
reply is uniformly random over the empty cells.
"""

import numpy as np
from scipy.sparse import dok_matrix

try:
    from mdptoolbox import mdp
except ImportError:
    # The solver is only needed by the ``__main__`` entry point; building the
    # transition and reward arrays requires just numpy and scipy.
    mdp = None

ACTIONS = 9
STATES = 3**ACTIONS
PLAYER = 1
OPPONENT = 2
# One 0/1 mask per winning line: three rows, three columns, two diagonals.
WINS = ([1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [1, 0, 0, 1, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 1, 0, 0, 1],
        [1, 0, 0, 0, 1, 0, 0, 0, 1],
        [0, 0, 1, 0, 1, 0, 1, 0, 0])

# The valid number of cells belonging to either the player or the opponent:
# (player, opponent)
OWNED_CELLS = ((0, 0),
               (1, 1),
               (2, 2),
               (3, 3),
               (4, 4),
               (0, 1),
               (1, 2),
               (2, 3),
               (3, 4))


def convertIndexToTuple(state):
    """Decode a state index into a tuple of ``ACTIONS`` cell values.

    Parameters
    ----------
    state : int
        Index in ``range(STATES)``.

    Returns
    -------
    tuple of int
        Cell values, most significant base-3 digit first (the inverse of
        ``convertTupleToIndex``).
    """
    # Pad with ACTIONS zeros and keep the last ACTIONS digits so that small
    # indices still decode to a full board.
    digits = np.base_repr(state, 3, ACTIONS)[-ACTIONS:]
    return tuple(int(x) for x in digits)


def convertTupleToIndex(state):
    """Encode a sequence of cell values as its base-3 state index.

    Parameters
    ----------
    state : sequence of int
        Cell values, most significant digit first.

    Returns
    -------
    int
        The state index (the inverse of ``convertIndexToTuple``).
    """
    return int("".join(str(x) for x in state), 3)


def getLegalActions(state):
    """Return the indices of the empty cells of ``state`` as a tuple."""
    return tuple(x for x in range(ACTIONS) if state[x] == 0)


def getTransitionAndRewardArrays():
    """Build the transition and reward arrays for the whole state space.

    Returns
    -------
    P : list of scipy.sparse.csr_matrix
        One ``(STATES, STATES)`` row-stochastic matrix per action.
    R : numpy.ndarray
        ``(STATES, ACTIONS)`` array of rewards.
    """
    P = [dok_matrix((STATES, STATES)) for _ in range(ACTIONS)]
    R = np.zeros((STATES, ACTIONS))
    # Naive approach: visit every possible board. Decoding and validating a
    # board does not depend on the action, so do it once per state.
    for s in range(STATES):
        state = convertIndexToTuple(s)
        if not isValid(state):
            # There are no defined moves from an invalid state, so transition
            # probabilities cannot be calculated. However, P must be a square
            # stochastic matrix, so the invalid state transitions back to
            # itself with probability one (its reward stays 0).
            for a in range(ACTIONS):
                P[a][s, s] = 1
            continue
        for a in range(ACTIONS):
            s1, p, r = getTransitionProbabilities(state, a)
            for s_next, prob in zip(s1, p):
                P[a][s, s_next] = prob
            R[s, a] = r
    return [Pa.tocsr() for Pa in P], R


def getTransitionProbabilities(state, action):
    """Compute the successor states of ``state`` under ``action``.

    Parameters
    ----------
    state : tuple
        The state
    action : int
        The action (the cell the player marks)

    Returns
    -------
    s1, p, r : tuple of two lists and an int
        s1 are the next state indices, p are their probabilities, and r is
        the reward.

    Notes
    -----
    - A finished game (won by either side, or a full board) loops back to
      itself with reward 0.
    - An illegal move loops back with reward -10.
    - A winning move by the player earns 1.
    - The opponent takes a winning reply if one exists (reward -1).
      Otherwise the opponent picks uniformly among the empty cells.
    """
    assert 0 <= action < ACTIONS
    if isTerminal(state):
        # The game is already over: nothing the player does changes the
        # board or earns a reward.
        return [convertTupleToIndex(state)], [1], 0
    if not isLegal(state, action):
        # If the action is illegal, then transition back to the same state
        # but incur a high negative reward.
        return [convertTupleToIndex(state)], [1], -10
    # Update the state with the action.
    state = list(state)
    state[action] = PLAYER
    if isWon(state, PLAYER):
        # A winning move transitions to the winning board with reward 1.
        return [convertTupleToIndex(state)], [1], 1
    if isDraw(state):
        return [convertTupleToIndex(state)], [1], 0
    # Now search through the opponent's replies. The board is not full here,
    # so there is at least one reply.
    legal_a = getLegalActions(state)
    prob = 1.0 / len(legal_a)
    s1 = []
    p = []
    for a in legal_a:
        state[a] = OPPONENT
        # If the opponent can win, we assume the winning reply is chosen.
        if isWon(state, OPPONENT):
            return [convertTupleToIndex(state)], [1], -1
        if isDraw(state):
            return [convertTupleToIndex(state)], [1], 0
        # Otherwise every reply is equally likely.
        s1.append(convertTupleToIndex(state))
        p.append(prob)
        state[a] = 0
    # During non-terminal play the reward is 0.
    return s1, p, 0


def getReward(state, action):
    """Return the immediate reward for the player marking ``action``.

    Returns -100 for an illegal move, 1 if the move wins and -1 if the
    opponent has a completed line. Otherwise returns 0.

    NOTE(review): this helper is not used by ``getTransitionAndRewardArrays``.
    Its illegal-move penalty (-100) differs from the -10 used in
    ``getTransitionProbabilities``.
    """
    if not isLegal(state, action):
        return -100
    state = list(state)
    state[action] = PLAYER
    if isWon(state, PLAYER):
        return 1
    elif isWon(state, OPPONENT):
        return -1
    else:
        return 0


def isDraw(state):
    """Return True when no cell of ``state`` is empty.

    This does not check for a winner; callers test ``isWon`` first.
    """
    return 0 not in state


def isLegal(state, action):
    """Return True if cell ``action`` of ``state`` is empty."""
    return state[action] == 0


def isWon(state, who):
    """Test if a tic-tac-toe game has been won by ``who``.

    Returns True when every cell of at least one line in ``WINS`` holds the
    value ``who``.
    """
    # Compare the board against ``who`` only on the cells selected by the
    # 0/1 mask. The mask value itself must not be compared with ``who``,
    # otherwise OPPONENT (2) could never match.
    return any(
        all(state[k] == who for k in range(ACTIONS) if line[k])
        for line in WINS
    )


def isTerminal(state):
    """Return True if the game on ``state`` is over (a win or a full board)."""
    return isWon(state, PLAYER) or isWon(state, OPPONENT) or isDraw(state)


def isValid(state):
    """Return True if the (player, opponent) cell counts appear in OWNED_CELLS."""
    # S1 is the number of the player's cells.
    S1 = sum(1 for x in state if x == PLAYER)
    # S2 is the number of the opponent's cells.
    S2 = sum(1 for x in state if x == OPPONENT)
    return (S1, S2) in OWNED_CELLS


if __name__ == "__main__":
    if mdp is None:
        raise ImportError("mdptoolbox is required to solve the MDP")
    P, R = getTransitionAndRewardArrays()
    ttt = mdp.ValueIteration(P, R, 1)
    # Constructing ValueIteration only sets up the problem; run() performs
    # the iterations that populate ``policy``.
    ttt.run()
    print(ttt.policy)