Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Zahra Rajabi
pymdptoolbox
Commits
5b2b87c9
Commit
5b2b87c9
authored
Mar 13, 2014
by
Steven Cordwell
Browse files
Merge branch 'tictactoe'
Conflicts: src/examples/firemdp.py
parents
76bacbab
9583a8ba
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/examples/firemdp.py
View file @
5b2b87c9
...
...
@@ -386,4 +386,3 @@ if __name__ == "__main__":
else
:
sdp
=
solveMDP
()
printPolicy
(
sdp
.
policy
[:,
0
])
src/examples/tictactoe.py
View file @
5b2b87c9
# -*- coding: utf-8 -*-
#import mdp
import
numpy
as
np
from
scipy.sparse
import
dok_matrix
def
str_base
(
num
,
base
,
numerals
=
'0123456789abcdefghijklmnopqrstuvwxyz'
):
if
base
<
2
or
base
>
len
(
numerals
):
raise
ValueError
(
"str_base: base must be between 2 and %i"
%
len
(
numerals
))
from
mdptoolbox
import
mdp
ACTIONS
=
9
STATES
=
3
**
ACTIONS
PLAYER
=
1
OPPONENT
=
2
WINS
=
([
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
],
[
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
],
[
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
],
[
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
],
[
1
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
1
],
[
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
])
# The valid number of cells belonging to either the player or the opponent:
# (player, opponent)
OWNED_CELLS
=
((
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
),
(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
3
,
4
))
def
convertIndexToTuple
(
state
):
""""""
return
(
tuple
(
int
(
x
)
for
x
in
np
.
base_repr
(
state
,
3
,
9
)[
-
9
::]))
def
convertTupleToIndex
(
state
):
""""""
return
(
int
(
""
.
join
(
str
(
x
)
for
x
in
state
),
3
))
def
getLegalActions
(
state
):
""""""
return
(
tuple
(
x
for
x
in
range
(
ACTIONS
)
if
state
[
x
]
==
0
))
def
getTransitionAndRewardArrays
():
""""""
P
=
[
dok_matrix
((
STATES
,
STATES
))
for
a
in
range
(
ACTIONS
)]
#R = spdok((STATES, ACTIONS))
R
=
np
.
zeros
((
STATES
,
ACTIONS
))
# Naive approach, iterate through all possible combinations
for
a
in
range
(
ACTIONS
):
for
s
in
range
(
STATES
):
state
=
convertIndexToTuple
(
s
)
if
not
isValid
(
state
):
# There are no defined moves from an invalid state, so
# transition probabilities cannot be calculated. However,
# P must be a square stochastic matrix, so assign a
# probability of one to the invalid state transitioning
# back to itself.
P
[
a
][
s
,
s
]
=
1
# Reward is 0
else
:
s1
,
p
,
r
=
getTransitionProbabilities
(
state
,
a
)
P
[
a
][
s
,
s1
]
=
p
R
[
s
,
a
]
=
r
P
[
a
]
=
P
[
a
].
tocsr
()
#R = R.tolil()
return
(
P
,
R
)
def
getTransitionProbabilities
(
state
,
action
):
"""
Parameters
----------
state : tuple
The state
action : int
The action
if
num
==
0
:
return
'0'
Returns
-------
s1, p, r : tuple of two lists and an int
s1 are the next states, p are the probabilities, and r is the reward
if
num
<
0
:
sign
=
'-'
num
=
-
num
"""
#assert isValid(state)
assert
0
<=
action
<
ACTIONS
if
not
isLegal
(
state
,
action
):
# If the action is illegal, then transition back to the same state but
# incur a high negative reward
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
-
10
)
# Update the state with the action
state
=
list
(
state
)
state
[
action
]
=
PLAYER
if
isWon
(
state
,
PLAYER
):
# If the player's action is a winning move then transition to the
# winning state and receive a reward of 1.
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
1
)
elif
isDraw
(
state
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
0
)
# Now we search through the opponents moves, and calculate transition
# probabilities based on maximising the opponents chance of winning..
s1
=
[]
p
=
[]
legal_a
=
getLegalActions
(
state
)
for
a
in
legal_a
:
state
[
a
]
=
OPPONENT
# If the opponent is going to win, we assume that the winning move will
# be chosen:
if
isWon
(
state
,
OPPONENT
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
-
1
)
elif
isDraw
(
state
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
0
)
# Otherwise we assume the opponent will select a move with uniform
# probability across potential moves:
s1
.
append
(
convertTupleToIndex
(
state
))
p
.
append
(
1.0
/
len
(
legal_a
))
state
[
a
]
=
0
# During non-terminal play states the reward is 0.
return
(
s1
,
p
,
0
)
def
getReward
(
state
,
action
):
""""""
if
not
isLegal
(
state
,
action
):
return
-
100
state
=
list
(
state
)
state
[
action
]
=
PLAYER
if
isWon
(
state
,
PLAYER
):
return
1
elif
isWon
(
state
,
OPPONENT
):
return
-
1
else
:
sign
=
''
result
=
''
while
num
:
result
=
numerals
[
num
%
(
base
)]
+
result
num
//=
base
return
sign
+
result
return
0
def
isDraw
(
state
):
""""""
try
:
state
.
index
(
0
)
return
False
except
ValueError
:
return
True
class
TicTacToeMDP
(
object
):
def
isLegal
(
state
,
action
):
""""""
def
__init__
(
self
):
""""""
self
.
P
=
[
None
]
*
9
for
a
in
xrange
(
9
):
self
.
P
[
a
]
=
{}
self
.
R
=
{}
# some board states are equal, just rotations of other states
self
.
rotorder
=
[]
#self.rotorder.append([0, 1, 2, 3, 4, 5, 6, 7, 8])
self
.
rotorder
.
append
([
6
,
3
,
0
,
7
,
4
,
1
,
8
,
5
,
2
])
self
.
rotorder
.
append
([
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
])
self
.
rotorder
.
append
([
2
,
5
,
8
,
1
,
4
,
7
,
0
,
3
,
6
])
# The valid number of cells belonging to either the player or the
# opponent: (player, opponent)
self
.
nXO
=
((
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
),
(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
3
,
4
))
# The winning positions
self
.
wins
=
([
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
],
[
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
],
[
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
],
[
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
],
[
1
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
1
],
[
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
])
def
rotate
(
self
,
state
):
#rotations = []
identity
=
[]
#rotations.append(state)
identity
.
append
(
int
(
""
.
join
(
str
(
x
)
for
x
in
state
),
3
))
for
k
in
range
(
3
):
#rotations.append(tuple(state[self.rotorder[k][kk]]
# for kk in xrange(9)))
# Convert the state from base 3 number to integer.
#identity.append(int("".join(str(x) for x in rotations[k + 1]), 3))
identity
.
append
(
int
(
""
.
join
(
str
(
state
[
self
.
rotorder
[
k
][
kk
]])
for
kk
in
xrange
(
9
)),
3
))
# return the rotation with the smallest identity number
#idx = identity.index(min(identity))
#return (identity[idx], rotations[idx])
return
min
(
identity
)
def
unrotate
(
self
,
move
,
rotation
):
rotation
-=
1
# return the move
return
self
.
rotorder
[
rotation
][
move
]
def
isLegal
(
self
,
state
,
action
):
""""""
if
state
[
action
]
==
0
:
return
True
else
:
return
False
def
isWon
(
self
,
state
,
who
):
""""""
# Check to see if there are any wins
for
w
in
self
.
wins
:
S
=
sum
(
1
if
(
w
[
k
]
==
1
and
state
[
k
]
==
who
)
else
0
for
k
in
xrange
(
9
))
if
S
==
3
:
# We have a win
return
True
# There were no wins so return False
if
state
[
action
]
==
0
:
return
True
else
:
return
False
def
isWon
(
state
,
who
):
"""Test if a tic-tac-toe game has been won.
def
isDraw
(
self
,
state
):
""""""
try
:
state
.
index
(
0
)
return
False
except
ValueError
:
return
True
except
:
raise
Assumes that the board is in a legal state.
Will test if the value 1 is in any winning combination.
def
isValid
(
self
,
state
):
""""""
# S1 is the sum of the player's cells
S1
=
sum
(
1
if
x
==
1
else
0
for
x
in
state
)
# S2 is the sum of the opponent's cells
S2
=
sum
(
1
if
x
==
2
else
0
for
x
in
state
)
if
(
S1
,
S2
)
in
self
.
nXO
:
"""
for
w
in
WINS
:
S
=
sum
(
1
if
(
w
[
k
]
==
state
[
k
]
==
who
)
else
0
for
k
in
range
(
ACTIONS
))
if
S
==
3
:
# We have a win
return
True
else
:
return
False
def
getReward
(
self
,
s
):
if
self
.
isWon
(
s
,
1
):
return
1
elif
self
.
isWon
(
s
,
2
):
return
-
1
else
:
return
0
def
run
(
self
):
""""""
l
=
(
0
,
1
,
2
)
# Iterate through a generator of all the combinations
for
s
in
((
a0
,
a1
,
a2
,
a3
,
a4
,
a5
,
a6
,
a7
,
a8
)
for
a0
in
l
for
a1
in
l
for
a2
in
l
for
a3
in
l
for
a4
in
l
for
a5
in
l
for
a6
in
l
for
a7
in
l
for
a8
in
l
):
if
self
.
isValid
(
s
):
s_idn
=
self
.
rotate
(
s
)
if
not
self
.
R
.
has_key
(
s_idn
):
self
.
R
[
s_idn
]
=
self
.
getReward
(
s
)
self
.
transition
(
s
)
# Convert P and R to ijv lists
# Iterate through up to the theorectically maxmimum value of s
for
s
in
xrange
(
int
(
'222211110'
,
3
)):
print
s
# return (P, R)
def
toTuple
(
self
,
state
):
""""""
state
=
str_base
(
state
,
3
)
state
=
''
.
join
(
'0'
for
x
in
range
(
9
-
len
(
state
)))
+
state
return
tuple
(
int
(
x
)
for
x
in
state
)
def
transition
(
self
,
state
):
""""""
#TODO: the state needs to be rotated before anything else is done!!!
idn_s
=
int
(
""
.
join
(
str
(
x
)
for
x
in
state
),
3
)
legal_a
=
[
x
for
x
in
xrange
(
9
)
if
state
[
x
]
==
0
]
for
a
in
legal_a
:
s
=
[
x
for
x
in
state
]
s
[
a
]
=
1
is_won
=
self
.
isWon
(
s
,
1
)
legal_m
=
[
x
for
x
in
xrange
(
9
)
if
s
[
x
]
==
0
]
for
m
in
legal_m
:
s_new
=
[
x
for
x
in
s
]
s_new
[
m
]
=
2
idn_s_new
=
self
.
rotate
(
s_new
)
if
not
self
.
P
[
a
].
has_key
((
idn_s
,
idn_s_new
)):
self
.
P
[
a
][(
idn_s
,
idn_s_new
)]
=
len
(
legal_m
)
# There were no wins so return False
return
False
def
isValid
(
state
):
""""""
# S1 is the sum of the player's cells
S1
=
sum
(
1
if
x
==
PLAYER
else
0
for
x
in
state
)
# S2 is the sum of the opponent's cells
S2
=
sum
(
1
if
x
==
OPPONENT
else
0
for
x
in
state
)
if
(
S1
,
S2
)
in
OWNED_CELLS
:
return
True
else
:
return
False
if
__name__
==
"__main__"
:
P
,
R
=
TicTacToeMDP
().
run
()
#ttt = mdp.ValueIteration(P, R, 1)
P
,
R
=
getTransitionAndRewardArrays
()
ttt
=
mdp
.
ValueIteration
(
P
,
R
,
1
)
print
(
ttt
.
policy
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment