-
-
Save gurusura/a839266a245971cda9d71f1f08061b6d to your computer and use it in GitHub Desktop.
Revisions
-
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 16 additions and 13 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -108,7 +108,7 @@ def show_q(): # Core algorithm gamma = 0.8 alpha = 1. n_episodes = 1E3 n_states = 6 n_actions = 6 epsilon = 0.05 @@ -125,22 +125,25 @@ def show_q(): #show_q() while not goal: # epsilon greedy valid_moves = r[current_state] >= 0 if random_state.rand() < epsilon: actions = np.array(list(range(n_actions))) actions = actions[valid_moves == True] if type(actions) is int: actions = [actions] random_state.shuffle(actions) action = actions[0] next_state = action else: if np.sum(q[current_state]) > 0: action = np.argmax(q[current_state]) else: # Don't allow invalid moves at the start # Just take a random move actions = np.array(list(range(n_actions))) actions = actions[valid_moves == True] random_state.shuffle(actions) action = actions[0] next_state = action reward = update_q(current_state, next_state, action, alpha=alpha, gamma=gamma) -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 3 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -72,7 +72,7 @@ def show_q(): for start, stop in zip(start_idx, end_idx)] values = np.array(q[q > 0]) # bump up values for viz values = values lc = LineCollection(segments, zorder=0, cmap=plt.cm.hot_r) lc.set_array(values) @@ -121,8 +121,8 @@ def show_q(): if e % int(n_episodes / 10.) == 0 and e > 0: pass # uncomment this to see plots each monitoring #show_traverse() #show_q() while not goal: # epsilon greedy if random_state.rand() < epsilon: -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 7 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -25,6 +25,9 @@ def update_q(state, next_state, action, alpha, gamma): qsa = q[state, action] new_q = qsa + alpha * (rsa + gamma * max(q[next_state, :]) - qsa) q[state, action] = new_q # renormalize row to be between 0 and 1 rn = q[state][q[state] > 0] / np.sum(q[state][q[state] > 0]) q[state][q[state] > 0] = rn return r[state, action] @@ -105,7 +108,7 @@ def show_q(): # Core algorithm gamma = 0.8 alpha = 1. n_episodes = 20 n_states = 6 n_actions = 6 epsilon = 0.05 @@ -118,8 +121,8 @@ def show_q(): if e % int(n_episodes / 10.) == 0 and e > 0: pass # uncomment this to see plots each monitoring show_traverse() show_q() while not goal: # epsilon greedy if random_state.rand() < epsilon: @@ -146,5 +149,6 @@ def show_q(): goal = True current_state = next_state print(q) show_traverse() show_q() -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 15 additions and 8 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,7 @@ # Author: Kyle Kastner # License: BSD 3-Clause # Implementing http://mnemstudio.org/path-finding-q-learning-tutorial.htm # Q-learning formula from http://sarvagyavaish.github.io/FlappyBirdRL/ # Visualization based on code from Gael Varoquaux [email protected] # http://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html @@ -8,6 +10,7 @@ from matplotlib.collections import LineCollection # defines the reward/connection graph r = np.array([[-1, -1, -1, -1, 0, -1], [-1, -1, -1, 0, -1, 100], [-1, -1, -1, 0, -1, -1], @@ -17,8 +20,10 @@ q = np.zeros_like(r) def update_q(state, next_state, action, alpha, gamma): rsa = r[state, action] qsa = q[state, action] new_q = qsa + alpha * (rsa + gamma * max(q[next_state, :]) - qsa) q[state, action] = new_q return r[state, action] @@ -90,15 +95,16 @@ def show_q(): x = x + .05 plt.text(x, y, name, size=10, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, bbox=dict(facecolor='w', edgecolor=plt.cm.spectral(float(len(coords))), alpha=.6)) plt.show() # Core algorithm gamma = 0.8 alpha = 1. n_episodes = 50 n_states = 6 n_actions = 6 @@ -133,7 +139,8 @@ def show_q(): else: action = np.argmax(q[current_state, :]) next_state = action reward = update_q(current_state, next_state, action, alpha=alpha, gamma=gamma) # Goal state has reward 100 if reward > 1: goal = True -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 87 additions and 71 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -22,8 +22,84 @@ def update_q(state, next_state, action, gamma): q[state, action] = new_q return r[state, action] def show_traverse(): # show all the greedy traversals for i in range(len(q)): current_state = i traverse = "%i -> " % current_state n_steps = 0 while current_state != 5 and n_steps < 20: next_state = np.argmax(q[current_state]) current_state = next_state traverse += "%i -> " % current_state n_steps = n_steps + 1 # cut off final arrow traverse = traverse[:-4] print("Greedy traversal for starting state %i" % i) print(traverse) print("") def show_q(): # show all the valid/used transitions coords = np.array([[2, 2], [4, 2], [5, 3], [4, 4], [2, 4], [5, 2]]) # invert y axis for display coords[:, 1] = max(coords[:, 1]) - coords[:, 1] plt.figure(1, facecolor='w', figsize=(10, 8)) plt.clf() ax = plt.axes([0., 0., 1., 1.]) plt.axis('off') plt.scatter(coords[:, 0], coords[:, 1], c='r') start_idx, end_idx = np.where(q > 0) segments = [[coords[start], coords[stop]] for start, stop in zip(start_idx, end_idx)] values = np.array(q[q > 0]) # bump up values for viz values = values / 50 lc = LineCollection(segments, zorder=0, cmap=plt.cm.hot_r) lc.set_array(values) ax.add_collection(lc) verticalalignment = 'top' horizontalalignment = 'left' for i in range(len(coords)): x = coords[i][0] y = coords[i][1] name = str(i) if i == 1: y = y - .05 x = x + .05 elif i == 3: y = y - .05 x = x + .05 elif i == 4: y = y - .05 x = x + .05 else: y = y + .05 x = x + .05 plt.text(x, y, name, size=10, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, bbox=dict(facecolor='w', edgecolor=plt.cm.spectral(float(len(coords))), alpha=.6)) plt.show() # Core algorithm gamma = 0.8 n_episodes = 50 n_states = 6 n_actions = 6 epsilon = 0.05 @@ -33,6 +109,11 @@ def update_q(state, next_state, action, gamma): random_state.shuffle(states) current_state = states[0] goal = False if e % int(n_episodes / 10.) == 0 and e > 0: pass # uncomment this to see plots each monitoring #show_traverse() #show_q() while not goal: # epsilon greedy if random_state.rand() < epsilon: @@ -42,6 +123,8 @@ def update_q(state, next_state, action, gamma): # shuffle to match actions valid_moves = valid_moves[actions] for i in range(len(valid_moves)): # choose the first action which is valid # out of the shuffled list if valid_moves[i] == True: action = actions[i] break @@ -51,77 +134,10 @@ def update_q(state, next_state, action, gamma): action = np.argmax(q[current_state, :]) next_state = action reward = update_q(current_state, next_state, action, gamma=gamma) # Goal state has reward 100 if reward > 1: goal = True current_state = next_state show_traverse() show_q() -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 2 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -55,7 +55,7 @@ def update_q(state, next_state, action, gamma): goal = True current_state = next_state # show all the greedy traversals for i in range(n_states): current_state = i traverse = "%i -> " % current_state @@ -70,6 +70,7 @@ def update_q(state, next_state, action, gamma): print(traverse) print("") # show all the valid/used transitions coords = np.array([[2, 2], [4, 2], [5, 3], -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 57 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,8 +1,12 @@ # Author: Kyle Kastner # License: BSD 3-Clause # Visualization based on code from Gael Varoquaux [email protected] # http://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html import numpy as np import matplotlib.pyplot as plt from matplotlib.collections import LineCollection r = np.array([[-1, -1, -1, -1, 0, -1], [-1, -1, -1, 0, -1, 100], @@ -66,5 +70,57 @@ def update_q(state, next_state, action, gamma): print(traverse) print("") coords = np.array([[2, 2], [4, 2], [5, 3], [4, 4], [2, 4], [5, 2]]) # invert y axis for display coords[:, 1] = max(coords[:, 1]) - coords[:, 1] plt.figure(1, facecolor='w', figsize=(10, 8)) plt.clf() ax = plt.axes([0., 0., 1., 1.]) plt.axis('off') plt.scatter(coords[:, 0], coords[:, 1], c='r') start_idx, end_idx = np.where(q > 0) segments = [[coords[start], coords[stop]] for start, stop in zip(start_idx, end_idx)] values = np.array(q[q > 0]) # bump up values for viz values = 10 + values / values.max() lc = LineCollection(segments, zorder=0, cmap=plt.cm.hot_r) lc.set_array(values) lc.set_linewidths(.3 * values) ax.add_collection(lc) verticalalignment = 'top' horizontalalignment = 'left' for i in range(len(coords)): x = coords[i][0] y = coords[i][1] name = str(i) if i == 1: y = y - .05 x = x + .05 elif i == 3: y = y - .05 x = x + .05 elif i == 4: y = y - .05 x = x + .05 else: y = y + .05 x = x + .05 plt.text(x, y, name, size=10, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment, bbox=dict(facecolor='w', edgecolor=plt.cm.spectral(float(len(coords))), alpha=.6)) plt.show() -
kastnerkyle revised this gist
May 29, 2016 . 1 changed file with 69 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,2 +1,70 @@ # Author: Kyle Kastner # License: BSD 3-Clause import numpy as np import matplotlib.pyplot as plt r = np.array([[-1, -1, -1, -1, 0, -1], [-1, -1, -1, 0, -1, 100], [-1, -1, -1, 0, -1, -1], [-1, 0, 0, -1, 0, -1], [ 0, -1, -1, 0, -1, 100], [-1, 0, -1, -1, 0, 100]]).astype("float32") q = np.zeros_like(r) def update_q(state, next_state, action, gamma): new_q = r[state, action] + gamma * max(q[next_state, :]) q[state, action] = new_q return r[state, action] gamma = 0.8 n_episodes = 1E4 n_states = 6 n_actions = 6 epsilon = 0.05 random_state = np.random.RandomState(1999) for e in range(int(n_episodes)): states = list(range(n_states)) random_state.shuffle(states) current_state = states[0] goal = False while not goal: # epsilon greedy if random_state.rand() < epsilon: valid_moves = r[current_state] >= 0 actions = list(range(n_actions)) random_state.shuffle(actions) # shuffle to match actions valid_moves = valid_moves[actions] for i in range(len(valid_moves)): if valid_moves[i] == True: action = actions[i] break # action is the move to next state next_state = action else: action = np.argmax(q[current_state, :]) next_state = action reward = update_q(current_state, next_state, action, gamma=gamma) if reward > 1: goal = True current_state = next_state for i in range(n_states): current_state = i traverse = "%i -> " % current_state n_steps = 0 while current_state != 5: next_state = np.argmax(q[current_state]) current_state = next_state traverse += "%i -> " % current_state # cut off final arrow traverse = traverse[:-4] print("Greedy traversal for starting state %i" % i) print(traverse) print("") plt.matshow(q) plt.show() -
kastnerkyle created this gist
Mar 6, 2016 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,2 @@ # Author: Kyle Kastner # License: BSD 3-Clause