Last active
August 16, 2019 06:39
-
-
Save kastnerkyle/9db1e88569c4358f11304dcdce05c9ab to your computer and use it in GitHub Desktop.
Revisions
-
kastnerkyle revised this gist
Dec 29, 2017 . 1 changed file with 372 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1 +1,372 @@ # Based on tutorial from https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/ # Author: Kyle Kastner # License: BSD 3-Clause from __future__ import print_function import random import copy import numpy as np import time import argparse import sys global_random = np.random.RandomState(1989) class Board(object): def __init__(self): self.player_symbols = ["X", "O"] def start(self): board = [" "] * 9 return tuple(board) def current_player(self, board): player = None player_counts = [0, 0] for n in range(len(self.player_symbols)): all_syms = [b for b in board if b == self.player_symbols[n]] player_counts[n] = len(all_syms) if player_counts[0] == player_counts[1]: return 1 else: return 2 def is_available(self, board, move): if board[move] != " ": return False else: return True def legal_moves(self, board): move_opts = [0, 1, 2, 3, 4, 5, 6, 7, 8] return [mo for mo in move_opts if self.is_available(board, mo)] def next_state(self, board, move): new_board = copy.copy(list(board)) player = self.current_player(board) new_board[move] = self.player_symbols[player - 1] return tuple(new_board) def is_complete(self, board_history): # -1 tie # 0 continue # 1 player1 win # 2 player2 win board = board_history[-1] board_not_full = len([b for b in board if b != " "]) != len(board) game_won = False winner = "" # simple check for wins # horizontal if board[0] == board[1] == board[2] and board[0] != " ": game_won = True winner = board[0] elif board[3] == board[4] == board[5] and board[3] != " ": game_won = True winner = board[3] elif board[6] == board[7] == board[8] and board[6] != " ": game_won = True winner = board[6] # vertical elif board[0] == board[3] == board[6] and board[0] != " ": game_won = True winner = board[0] elif board[1] == board[4] == board[7] and board[1] != " ": game_won = True winner = board[1] elif board[2] == board[5] == board[8] and board[2] != " ": game_won = True winner = board[2] # diagonal elif board[0] == board[4] == board[8] and board[0] != " ": game_won = True winner = board[0] elif board[2] == board[4] == board[6] and board[2] != " ": game_won = True winner = board[2] if board_not_full: if game_won: if winner == self.player_symbols[0]: return 1 else: return 2 else: return 0 elif not board_not_full: if game_won: if winner == self.player_symbols[0]: return 1 else: return 2 else: return -1 def draw(self, board): tmp_board = copy.copy(list(board)) for i in range(len(tmp_board)): if tmp_board[i] == " ": tmp_board[i] = "({})".format(i) else: tmp_board[i] = " {} ".format(tmp_board[i]) print(' | |') print(' ' + tmp_board[6] + ' | ' + tmp_board[7] + ' | ' + tmp_board[8]) print(' | |') print('----------------') print(' | |') print(' ' + tmp_board[3] + ' | ' + tmp_board[4] + ' | ' + tmp_board[5]) print(' | |') print('----------------') print(' | |') print(' ' + tmp_board[0] + ' | ' + tmp_board[1] + ' | ' + tmp_board[2]) print(' | |') print("") class MCTS(object): def __init__(self, board, state_history=None, runtime_s=10, horizon=100, ucb_weight=1.4, verbose=False): # policy can be ucb1, random self.board = board self.runtime_s = runtime_s self.horizon = horizon self.ucb_weight = ucb_weight self.plays = {} self.rewards = {} self.verbose = verbose if state_history is None: self.state_history = [board.start()] else: self.state_history = state_history def update(self, state): self.state_history.append(state) def estimate(self, policy="uct"): if policy == "uct": pass elif policy == "random": pass else: raise ValueError("Unknown value policy={}".format(policy)) self.max_depth = 0 state = self.state_history[-1] player = self.board.current_player(self.state_history[-1]) moves = self.board.legal_moves(self.state_history[-1]) if len(moves) == 0: return elif len(moves) == 1: return moves[0] games = 0 start_time = time.time() last_t = start_time print("Player {}({})'s turn".format(player, self.board.player_symbols[player - 1])) while time.time() - start_time < self.runtime_s: this_t = time.time() if this_t - last_t > 1: last_t = this_t print("Calculating...") if policy == "uct": self.rollout_uct() elif policy == "random": self.rollout_random() else: raise ValueError("Unknown value policy={}".format(policy)) games += 1 print("") end_time = time.time() - start_time moves_states = [(m, self.board.next_state(state, m)) for m in moves] if self.verbose: print("Number of sim games {}, total time {}".format(games, end_time)) percent, move = max( (self.rewards.get((player, S), 0) / float(self.plays.get((player, S), 1)), p) for p, S in moves_states) if self.verbose: for x in sorted( ((100 * self.rewards.get((player, S), 0) / float(self.plays.get((player, S), 1)), self.plays.get((player, S), 1), self.rewards.get((player, S), 0), p) for p, S in moves_states), reverse=True): print("{3}: {0:.2f}% ({2} / {1})".format(*x)) return move def rollout_uct(self): plays, rewards = self.plays, self.rewards exploration = self.ucb_weight visited_states = {} states_copy = copy.copy(self.state_history) state = states_copy[-1] player = self.board.current_player(states_copy[-1]) expand = True for t in range(self.horizon + 1): moves = self.board.legal_moves(states_copy[-1]) moves_states = [(p, self.board.next_state(state, p)) for p in moves] if all(plays.get((player, S)) for p, S in moves_states): log_total = np.log(sum(plays[(player, S)] for p, S in moves_states)) basic_move_triples = [((rewards[(player, S)] / float(plays[(player, S)])) + exploration * np.sqrt(log_total / float(plays[(player, S)])), p, S) for p, S in moves_states] value, move, state = max(basic_move_triples) else: # random move global_random.shuffle(moves) move = moves[0] state = self.board.next_state(states_copy[-1], move) states_copy.append(state) if expand and (player, state) not in self.plays: expand = False self.plays[(player, state)] = 0 self.rewards[(player, state)] = 0 if t > self.max_depth: self.max_depth = t visited_states[(player, state)] = None player = self.board.current_player(states_copy[-1]) complete = self.board.is_complete(states_copy) if complete != 0: break for player, state in visited_states.keys(): if (player, state) not in self.plays: continue self.plays[(player, state)] += 1 if player == complete: self.rewards[(player, state)] += 1 # ties if complete == -1: self.rewards[(player, state)] += 1 def rollout_random(self): visited_states = {} states_copy = copy.copy(self.state_history) state = states_copy[-1] player = self.board.current_player(states_copy[-1]) expand = True for t in range(self.horizon + 1): moves = self.board.legal_moves(states_copy[-1]) global_random.shuffle(moves) move = moves[0] state = self.board.next_state(states_copy[-1], move) states_copy.append(state) if expand and (player, state) not in self.plays: expand = False self.plays[(player, state)] = 0 self.rewards[(player, state)] = 0 if t > self.max_depth: self.max_depth = t visited_states[(player, state)] = None player = self.board.current_player(states_copy[-1]) complete = self.board.is_complete(states_copy) if complete != 0: break for player, state in visited_states.keys(): if (player, state) not in self.plays: continue self.plays[(player, state)] += 1 if player == complete: self.rewards[(player, state)] += 1 # ties if complete == -1: self.rewards[(player, state)] += 1 if __name__ == "__main__": parser = argparse.ArgumentParser(description="Demo of player/MCTS for TIC TAC TOE. Use -i to play against the machine, or -a to make the machine play itself.") parser.add_argument("-i", "--interactive", action="store_true", default=False, help="Play against the computer yourself. For a beatable computer, try -r .1 or lower") parser.add_argument("-a", "--automatic", action="store_true", default=False, help="Play the computer against itself.") parser.add_argument("-r", "--roundtime", type=float, default=3) parser.add_argument("-n", "--no_verbose", action="store_false", default=True) args = parser.parse_args() automatic = args.automatic interactive = args.interactive if not automatic and not interactive: parser.print_help() sys.exit(1) if automatic and interactive: print("Must choose either -i or -a, not both!") sys.exit(1) while True: board = Board() roundtime = args.roundtime verbose = args.no_verbose mcts1 = MCTS(board, runtime_s=roundtime, verbose=verbose) mcts2 = MCTS(board, runtime_s=roundtime, verbose=verbose) board_history = [board.start()] if not args.interactive: # randomly switch player order (still called 1, 2 but symbols swap) if global_random.randint(18888) % 2: board.player_symbols = ["O", "X"] print("Player {} ({}), Player {}, ({})".format(1, board.player_symbols[0], 2, board.player_symbols[1])) # inner game loop while True: complete = board.is_complete(board_history) board.draw(board_history[-1]) if complete == 0: player = board.current_player(board_history[-1]) if player == 1: if args.interactive: move = "100" get_move = True while get_move: print("Human player {}, next move? (0-8)".format(1)) move = raw_input() move_opts = ["0", "1", "2", "3", "4", "5", "6", "7", "8"] if board.is_available(board_history[-1], int(move)) and str(move) in move_opts: get_move = False break move = int(move) else: move = mcts1.estimate(policy="uct") else: move = mcts2.estimate(policy="uct") """ # True random, naive moves moves = board.legal_moves(board_history[-1]) global_random.shuffle(moves) move = moves[0] """ next_board = board.next_state(board_history[-1], move) board_history.append(next_board) mcts1.update(next_board) mcts2.update(next_board) else: # evaluate print("Game over!") if complete == 1: print("Player 1 wins.") elif complete == 2: print("Player 2 wins.") elif complete == -1: print("Tie game.") break print("Play again (y/n)?") choice = raw_input() if str(choice) != "y": print("Thanks for playing!") break -
kastnerkyle created this gist
Dec 29, 2017 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1 @@ None