Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Last active August 16, 2019 06:39
Show Gist options
  • Save kastnerkyle/9db1e88569c4358f11304dcdce05c9ab to your computer and use it in GitHub Desktop.

Revisions

  1. kastnerkyle revised this gist Dec 29, 2017. 1 changed file with 372 additions and 1 deletion.
    373 changes: 372 additions & 1 deletion mcts_tictactoe.py
    Original file line number Diff line number Diff line change
    @@ -1 +1,372 @@
    None
    # Based on tutorial from https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/
    # Author: Kyle Kastner
    # License: BSD 3-Clause
    from __future__ import print_function
    import random
    import copy
    import numpy as np
    import time
    import argparse
    import sys

# Module-wide RNG with a fixed seed so simulated games are reproducible
# across runs; used for rollout move shuffling and player-order swaps.
global_random = np.random.RandomState(1989)

    class Board(object):
    def __init__(self):
    self.player_symbols = ["X", "O"]

    def start(self):
    board = [" "] * 9
    return tuple(board)

    def current_player(self, board):
    player = None
    player_counts = [0, 0]
    for n in range(len(self.player_symbols)):
    all_syms = [b for b in board if b == self.player_symbols[n]]
    player_counts[n] = len(all_syms)

    if player_counts[0] == player_counts[1]:
    return 1
    else:
    return 2

    def is_available(self, board, move):
    if board[move] != " ":
    return False
    else:
    return True

    def legal_moves(self, board):
    move_opts = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    return [mo for mo in move_opts if self.is_available(board, mo)]

    def next_state(self, board, move):
    new_board = copy.copy(list(board))
    player = self.current_player(board)
    new_board[move] = self.player_symbols[player - 1]
    return tuple(new_board)

    def is_complete(self, board_history):
    # -1 tie
    # 0 continue
    # 1 player1 win
    # 2 player2 win
    board = board_history[-1]
    board_not_full = len([b for b in board if b != " "]) != len(board)
    game_won = False
    winner = ""
    # simple check for wins
    # horizontal
    if board[0] == board[1] == board[2] and board[0] != " ":
    game_won = True
    winner = board[0]
    elif board[3] == board[4] == board[5] and board[3] != " ":
    game_won = True
    winner = board[3]
    elif board[6] == board[7] == board[8] and board[6] != " ":
    game_won = True
    winner = board[6]
    # vertical
    elif board[0] == board[3] == board[6] and board[0] != " ":
    game_won = True
    winner = board[0]
    elif board[1] == board[4] == board[7] and board[1] != " ":
    game_won = True
    winner = board[1]
    elif board[2] == board[5] == board[8] and board[2] != " ":
    game_won = True
    winner = board[2]
    # diagonal
    elif board[0] == board[4] == board[8] and board[0] != " ":
    game_won = True
    winner = board[0]
    elif board[2] == board[4] == board[6] and board[2] != " ":
    game_won = True
    winner = board[2]

    if board_not_full:
    if game_won:
    if winner == self.player_symbols[0]:
    return 1
    else:
    return 2
    else:
    return 0
    elif not board_not_full:
    if game_won:
    if winner == self.player_symbols[0]:
    return 1
    else:
    return 2
    else:
    return -1

    def draw(self, board):
    tmp_board = copy.copy(list(board))
    for i in range(len(tmp_board)):
    if tmp_board[i] == " ":
    tmp_board[i] = "({})".format(i)
    else:
    tmp_board[i] = " {} ".format(tmp_board[i])
    print(' | |')
    print(' ' + tmp_board[6] + ' | ' + tmp_board[7] + ' | ' + tmp_board[8])
    print(' | |')
    print('----------------')
    print(' | |')
    print(' ' + tmp_board[3] + ' | ' + tmp_board[4] + ' | ' + tmp_board[5])
    print(' | |')
    print('----------------')
    print(' | |')
    print(' ' + tmp_board[0] + ' | ' + tmp_board[1] + ' | ' + tmp_board[2])
    print(' | |')
    print("")

    class MCTS(object):
    def __init__(self, board, state_history=None, runtime_s=10, horizon=100,
    ucb_weight=1.4,
    verbose=False):
    # policy can be ucb1, random
    self.board = board
    self.runtime_s = runtime_s
    self.horizon = horizon
    self.ucb_weight = ucb_weight
    self.plays = {}
    self.rewards = {}
    self.verbose = verbose
    if state_history is None:
    self.state_history = [board.start()]
    else:
    self.state_history = state_history

    def update(self, state):
    self.state_history.append(state)

    def estimate(self, policy="uct"):
    if policy == "uct":
    pass
    elif policy == "random":
    pass
    else:
    raise ValueError("Unknown value policy={}".format(policy))

    self.max_depth = 0
    state = self.state_history[-1]
    player = self.board.current_player(self.state_history[-1])
    moves = self.board.legal_moves(self.state_history[-1])

    if len(moves) == 0:
    return
    elif len(moves) == 1:
    return moves[0]

    games = 0
    start_time = time.time()
    last_t = start_time
    print("Player {}({})'s turn".format(player, self.board.player_symbols[player - 1]))
    while time.time() - start_time < self.runtime_s:
    this_t = time.time()
    if this_t - last_t > 1:
    last_t = this_t
    print("Calculating...")
    if policy == "uct":
    self.rollout_uct()
    elif policy == "random":
    self.rollout_random()
    else:
    raise ValueError("Unknown value policy={}".format(policy))
    games += 1
    print("")
    end_time = time.time() - start_time

    moves_states = [(m, self.board.next_state(state, m)) for m in moves]
    if self.verbose:
    print("Number of sim games {}, total time {}".format(games, end_time))

    percent, move = max(
    (self.rewards.get((player, S), 0) / float(self.plays.get((player, S), 1)), p) for p, S in moves_states)

    if self.verbose:
    for x in sorted(
    ((100 * self.rewards.get((player, S), 0) / float(self.plays.get((player, S), 1)),
    self.plays.get((player, S), 1),
    self.rewards.get((player, S), 0),
    p) for p, S in moves_states), reverse=True):

    print("{3}: {0:.2f}% ({2} / {1})".format(*x))

    return move

    def rollout_uct(self):
    plays, rewards = self.plays, self.rewards
    exploration = self.ucb_weight

    visited_states = {}
    states_copy = copy.copy(self.state_history)
    state = states_copy[-1]
    player = self.board.current_player(states_copy[-1])

    expand = True
    for t in range(self.horizon + 1):
    moves = self.board.legal_moves(states_copy[-1])
    moves_states = [(p, self.board.next_state(state, p)) for p in moves]

    if all(plays.get((player, S)) for p, S in moves_states):
    log_total = np.log(sum(plays[(player, S)] for p, S in moves_states))
    basic_move_triples = [((rewards[(player, S)] / float(plays[(player, S)])) + exploration * np.sqrt(log_total / float(plays[(player, S)])), p, S) for p, S in moves_states]
    value, move, state = max(basic_move_triples)
    else:
    # random move
    global_random.shuffle(moves)
    move = moves[0]
    state = self.board.next_state(states_copy[-1], move)

    states_copy.append(state)

    if expand and (player, state) not in self.plays:
    expand = False
    self.plays[(player, state)] = 0
    self.rewards[(player, state)] = 0
    if t > self.max_depth:
    self.max_depth = t

    visited_states[(player, state)] = None

    player = self.board.current_player(states_copy[-1])
    complete = self.board.is_complete(states_copy)
    if complete != 0:
    break

    for player, state in visited_states.keys():
    if (player, state) not in self.plays:
    continue
    self.plays[(player, state)] += 1
    if player == complete:
    self.rewards[(player, state)] += 1
    # ties
    if complete == -1:
    self.rewards[(player, state)] += 1

    def rollout_random(self):
    visited_states = {}
    states_copy = copy.copy(self.state_history)
    state = states_copy[-1]
    player = self.board.current_player(states_copy[-1])

    expand = True
    for t in range(self.horizon + 1):
    moves = self.board.legal_moves(states_copy[-1])
    global_random.shuffle(moves)
    move = moves[0]
    state = self.board.next_state(states_copy[-1], move)
    states_copy.append(state)
    if expand and (player, state) not in self.plays:
    expand = False
    self.plays[(player, state)] = 0
    self.rewards[(player, state)] = 0
    if t > self.max_depth:
    self.max_depth = t

    visited_states[(player, state)] = None

    player = self.board.current_player(states_copy[-1])
    complete = self.board.is_complete(states_copy)
    if complete != 0:
    break

    for player, state in visited_states.keys():
    if (player, state) not in self.plays:
    continue
    self.plays[(player, state)] += 1
    if player == complete:
    self.rewards[(player, state)] += 1
    # ties
    if complete == -1:
    self.rewards[(player, state)] += 1

    if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo of player/MCTS for TIC TAC TOE. Use -i to play against the machine, or -a to make the machine play itself.")
    parser.add_argument("-i", "--interactive", action="store_true",
    default=False,
    help="Play against the computer yourself. For a beatable computer, try -r .1 or lower")
    parser.add_argument("-a", "--automatic", action="store_true",
    default=False,
    help="Play the computer against itself.")
    parser.add_argument("-r", "--roundtime", type=float, default=3)
    parser.add_argument("-n", "--no_verbose", action="store_false", default=True)

    args = parser.parse_args()

    automatic = args.automatic
    interactive = args.interactive
    if not automatic and not interactive:
    parser.print_help()
    sys.exit(1)

    if automatic and interactive:
    print("Must choose either -i or -a, not both!")
    sys.exit(1)

    while True:
    board = Board()
    roundtime = args.roundtime
    verbose = args.no_verbose
    mcts1 = MCTS(board, runtime_s=roundtime, verbose=verbose)
    mcts2 = MCTS(board, runtime_s=roundtime, verbose=verbose)

    board_history = [board.start()]

    if not args.interactive:
    # randomly switch player order (still called 1, 2 but symbols swap)
    if global_random.randint(18888) % 2:
    board.player_symbols = ["O", "X"]
    print("Player {} ({}), Player {}, ({})".format(1, board.player_symbols[0], 2, board.player_symbols[1]))

    # inner game loop
    while True:
    complete = board.is_complete(board_history)
    board.draw(board_history[-1])
    if complete == 0:
    player = board.current_player(board_history[-1])
    if player == 1:
    if args.interactive:
    move = "100"
    get_move = True
    while get_move:
    print("Human player {}, next move? (0-8)".format(1))
    move = raw_input()
    move_opts = ["0", "1", "2", "3", "4", "5", "6", "7", "8"]
    if board.is_available(board_history[-1], int(move)) and str(move) in move_opts:
    get_move = False
    break
    move = int(move)
    else:
    move = mcts1.estimate(policy="uct")
    else:
    move = mcts2.estimate(policy="uct")
    """
    # True random, naive moves
    moves = board.legal_moves(board_history[-1])
    global_random.shuffle(moves)
    move = moves[0]
    """
    next_board = board.next_state(board_history[-1], move)
    board_history.append(next_board)
    mcts1.update(next_board)
    mcts2.update(next_board)
    else:
    # evaluate
    print("Game over!")
    if complete == 1:
    print("Player 1 wins.")
    elif complete == 2:
    print("Player 2 wins.")
    elif complete == -1:
    print("Tie game.")
    break

    print("Play again (y/n)?")
    choice = raw_input()
    if str(choice) != "y":
    print("Thanks for playing!")
    break
  2. kastnerkyle created this gist Dec 29, 2017.
    1 change: 1 addition & 0 deletions mcts_tictactoe.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1 @@
    None