@gurusura
Forked from kastnerkyle/painless_q.py
Created October 26, 2022 06:06
Revisions

  1. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 16 additions and 13 deletions.
    29 changes: 16 additions & 13 deletions painless_q.py
    @@ -108,7 +108,7 @@ def show_q():
     # Core algorithm
     gamma = 0.8
     alpha = 1.
    -n_episodes = 20
    +n_episodes = 1E3
     n_states = 6
     n_actions = 6
     epsilon = 0.05
    @@ -125,22 +125,25 @@ def show_q():
             #show_q()
         while not goal:
             # epsilon greedy
    +        valid_moves = r[current_state] >= 0
             if random_state.rand() < epsilon:
    -            valid_moves = r[current_state] >= 0
    -            actions = list(range(n_actions))
    +            actions = np.array(list(range(n_actions)))
    +            actions = actions[valid_moves == True]
    +            if type(actions) is int:
    +                actions = [actions]
                 random_state.shuffle(actions)
    -            # shuffle to match actions
    -            valid_moves = valid_moves[actions]
    -            for i in range(len(valid_moves)):
    -                # choose the first action which is valid
    -                # out of the shuffled list
    -                if valid_moves[i] == True:
    -                    action = actions[i]
    -                    break
    -            # action is the move to next state
    +            action = actions[0]
                 next_state = action
             else:
    -            action = np.argmax(q[current_state, :])
    +            if np.sum(q[current_state]) > 0:
    +                action = np.argmax(q[current_state])
    +            else:
    +                # Don't allow invalid moves at the start
    +                # Just take a random move
    +                actions = np.array(list(range(n_actions)))
    +                actions = actions[valid_moves == True]
    +                random_state.shuffle(actions)
    +                action = actions[0]
                 next_state = action
             reward = update_q(current_state, next_state, action,
                               alpha=alpha, gamma=gamma)
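The reworked loop masks the candidate actions down to the valid moves for the current state before sampling, and, while the Q row for that state is still all zeros, falls back to a random valid move instead of argmax. A minimal sketch of that selection logic, assuming a hypothetical choose_action helper that is not part of the gist:

    import numpy as np

    def choose_action(q_row, reward_row, epsilon, rng):
        # valid moves are the transitions with non-negative reward
        valid_actions = np.arange(len(reward_row))[reward_row >= 0]
        if rng.rand() < epsilon or np.sum(q_row) <= 0:
            # explore, or fall back to a random valid move while
            # the Q row for this state is still all zeros
            return int(rng.choice(valid_actions))
        # otherwise act greedily on the current Q estimates
        return int(np.argmax(q_row))

    rng = np.random.RandomState(1999)
    print(choose_action(np.zeros(6), np.array([-1., -1., -1., -1., 0., -1.]),
                        epsilon=0.05, rng=rng))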
  2. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions painless_q.py
    @@ -72,7 +72,7 @@ def show_q():
                     for start, stop in zip(start_idx, end_idx)]
         values = np.array(q[q > 0])
         # bump up values for viz
    -    values = values / 50
    +    values = values
         lc = LineCollection(segments,
                             zorder=0, cmap=plt.cm.hot_r)
         lc.set_array(values)
    @@ -121,8 +121,8 @@ def show_q():
         if e % int(n_episodes / 10.) == 0 and e > 0:
             pass
             # uncomment this to see plots each monitoring
    -        show_traverse()
    -        show_q()
    +        #show_traverse()
    +        #show_q()
         while not goal:
             # epsilon greedy
             if random_state.rand() < epsilon:
  3. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 7 additions and 3 deletions.
    10 changes: 7 additions & 3 deletions painless_q.py
    @@ -25,6 +25,9 @@ def update_q(state, next_state, action, alpha, gamma):
         qsa = q[state, action]
         new_q = qsa + alpha * (rsa + gamma * max(q[next_state, :]) - qsa)
         q[state, action] = new_q
    +    # renormalize row to be between 0 and 1
    +    rn = q[state][q[state] > 0] / np.sum(q[state][q[state] > 0])
    +    q[state][q[state] > 0] = rn
         return r[state, action]


    @@ -105,7 +108,7 @@ def show_q():
     # Core algorithm
     gamma = 0.8
     alpha = 1.
    -n_episodes = 50
    +n_episodes = 20
     n_states = 6
     n_actions = 6
     epsilon = 0.05
    @@ -118,8 +121,8 @@ def show_q():
         if e % int(n_episodes / 10.) == 0 and e > 0:
             pass
             # uncomment this to see plots each monitoring
    -        #show_traverse()
    -        #show_q()
    +        show_traverse()
    +        show_q()
         while not goal:
             # epsilon greedy
             if random_state.rand() < epsilon:
    @@ -146,5 +149,6 @@ def show_q():
                 goal = True
             current_state = next_state

    +print(q)
     show_traverse()
     show_q()
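The lines added to update_q rescale the positive entries of the updated row so they sum to 1. Because every positive entry is divided by the same positive constant, the row's argmax (and hence the greedy policy) is unchanged; the values simply stay bounded instead of growing with the 100 goal reward. A toy illustration of that renormalization step, with made-up numbers:

    import numpy as np

    q_row = np.array([0., 12., 0., 3., 0., 5.])
    positive = q_row > 0
    # same operation as the added lines: positives rescaled to sum to 1
    q_row[positive] = q_row[positive] / np.sum(q_row[positive])
    print(q_row)  # [0.   0.6  0.   0.15 0.   0.25]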
  4. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 15 additions and 8 deletions.
    23 changes: 15 additions & 8 deletions painless_q.py
    @@ -1,5 +1,7 @@
     # Author: Kyle Kastner
     # License: BSD 3-Clause
    +# Implementing http://mnemstudio.org/path-finding-q-learning-tutorial.htm
    +# Q-learning formula from http://sarvagyavaish.github.io/FlappyBirdRL/
     # Visualization based on code from Gael Varoquaux [email protected]
     # http://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html

    @@ -8,6 +10,7 @@
     from matplotlib.collections import LineCollection


    +# defines the reward/connection graph
     r = np.array([[-1, -1, -1, -1, 0, -1],
                   [-1, -1, -1, 0, -1, 100],
                   [-1, -1, -1, 0, -1, -1],
    @@ -17,8 +20,10 @@
     q = np.zeros_like(r)


    -def update_q(state, next_state, action, gamma):
    -    new_q = r[state, action] + gamma * max(q[next_state, :])
    +def update_q(state, next_state, action, alpha, gamma):
    +    rsa = r[state, action]
    +    qsa = q[state, action]
    +    new_q = qsa + alpha * (rsa + gamma * max(q[next_state, :]) - qsa)
         q[state, action] = new_q
         return r[state, action]

    @@ -90,15 +95,16 @@ def show_q():
                 x = x + .05

             plt.text(x, y, name, size=10,
    -             horizontalalignment=horizontalalignment,
    -             verticalalignment=verticalalignment,
    -             bbox=dict(facecolor='w',
    -                       edgecolor=plt.cm.spectral(float(len(coords))),
    -                       alpha=.6))
    +                 horizontalalignment=horizontalalignment,
    +                 verticalalignment=verticalalignment,
    +                 bbox=dict(facecolor='w',
    +                           edgecolor=plt.cm.spectral(float(len(coords))),
    +                           alpha=.6))
         plt.show()

     # Core algorithm
     gamma = 0.8
    +alpha = 1.
     n_episodes = 50
     n_states = 6
     n_actions = 6
    @@ -133,7 +139,8 @@ def show_q():
             else:
                 action = np.argmax(q[current_state, :])
             next_state = action
    -        reward = update_q(current_state, next_state, action, gamma=gamma)
    +        reward = update_q(current_state, next_state, action,
    +                          alpha=alpha, gamma=gamma)
             # Goal state has reward 100
             if reward > 1:
                 goal = True
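The rewritten update_q is the standard temporal-difference Q-learning update cited in the new header comment, written with an explicit learning rate:

    Q(s, a) \leftarrow Q(s, a) + \alpha \left( r(s, a) + \gamma \max_{a'} Q(s', a') - Q(s, a) \right)

With alpha = 1. as set here, the correction term replaces the old estimate completely and the update reduces to the previous r[state, action] + gamma * max(q[next_state, :]) rule, so this revision changes the interface more than the behavior.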
  5. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 87 additions and 71 deletions.
    158 changes: 87 additions & 71 deletions painless_q.py
    @@ -22,8 +22,84 @@ def update_q(state, next_state, action, gamma):
         q[state, action] = new_q
         return r[state, action]


    +def show_traverse():
    +    # show all the greedy traversals
    +    for i in range(len(q)):
    +        current_state = i
    +        traverse = "%i -> " % current_state
    +        n_steps = 0
    +        while current_state != 5 and n_steps < 20:
    +            next_state = np.argmax(q[current_state])
    +            current_state = next_state
    +            traverse += "%i -> " % current_state
    +            n_steps = n_steps + 1
    +        # cut off final arrow
    +        traverse = traverse[:-4]
    +        print("Greedy traversal for starting state %i" % i)
    +        print(traverse)
    +        print("")
    +
    +
    +def show_q():
    +    # show all the valid/used transitions
    +    coords = np.array([[2, 2],
    +                       [4, 2],
    +                       [5, 3],
    +                       [4, 4],
    +                       [2, 4],
    +                       [5, 2]])
    +    # invert y axis for display
    +    coords[:, 1] = max(coords[:, 1]) - coords[:, 1]
    +
    +    plt.figure(1, facecolor='w', figsize=(10, 8))
    +    plt.clf()
    +    ax = plt.axes([0., 0., 1., 1.])
    +    plt.axis('off')
    +
    +    plt.scatter(coords[:, 0], coords[:, 1], c='r')
    +
    +    start_idx, end_idx = np.where(q > 0)
    +    segments = [[coords[start], coords[stop]]
    +                for start, stop in zip(start_idx, end_idx)]
    +    values = np.array(q[q > 0])
    +    # bump up values for viz
    +    values = values / 50
    +    lc = LineCollection(segments,
    +                        zorder=0, cmap=plt.cm.hot_r)
    +    lc.set_array(values)
    +    ax.add_collection(lc)
    +
    +    verticalalignment = 'top'
    +    horizontalalignment = 'left'
    +    for i in range(len(coords)):
    +        x = coords[i][0]
    +        y = coords[i][1]
    +        name = str(i)
    +        if i == 1:
    +            y = y - .05
    +            x = x + .05
    +        elif i == 3:
    +            y = y - .05
    +            x = x + .05
    +        elif i == 4:
    +            y = y - .05
    +            x = x + .05
    +        else:
    +            y = y + .05
    +            x = x + .05
    +
    +        plt.text(x, y, name, size=10,
    +             horizontalalignment=horizontalalignment,
    +             verticalalignment=verticalalignment,
    +             bbox=dict(facecolor='w',
    +                       edgecolor=plt.cm.spectral(float(len(coords))),
    +                       alpha=.6))
    +    plt.show()
    +
    +# Core algorithm
     gamma = 0.8
    -n_episodes = 1E4
    +n_episodes = 50
     n_states = 6
     n_actions = 6
     epsilon = 0.05
    @@ -33,6 +109,11 @@ def update_q(state, next_state, action, gamma):
         random_state.shuffle(states)
         current_state = states[0]
         goal = False
    +    if e % int(n_episodes / 10.) == 0 and e > 0:
    +        pass
    +        # uncomment this to see plots each monitoring
    +        #show_traverse()
    +        #show_q()
         while not goal:
             # epsilon greedy
             if random_state.rand() < epsilon:
    @@ -42,6 +123,8 @@ def update_q(state, next_state, action, gamma):
                 # shuffle to match actions
                 valid_moves = valid_moves[actions]
                 for i in range(len(valid_moves)):
    +                # choose the first action which is valid
    +                # out of the shuffled list
                     if valid_moves[i] == True:
                         action = actions[i]
                         break
    @@ -51,77 +134,10 @@ def update_q(state, next_state, action, gamma):
                 action = np.argmax(q[current_state, :])
                 next_state = action
             reward = update_q(current_state, next_state, action, gamma=gamma)
    +        # Goal state has reward 100
             if reward > 1:
                 goal = True
             current_state = next_state

    -# show all the greedy traversals
    -for i in range(n_states):
    -    current_state = i
    -    traverse = "%i -> " % current_state
    -    n_steps = 0
    -    while current_state != 5:
    -        next_state = np.argmax(q[current_state])
    -        current_state = next_state
    -        traverse += "%i -> " % current_state
    -    # cut off final arrow
    -    traverse = traverse[:-4]
    -    print("Greedy traversal for starting state %i" % i)
    -    print(traverse)
    -    print("")
    -
    -# show all the valid/used transitions
    -coords = np.array([[2, 2],
    -                   [4, 2],
    -                   [5, 3],
    -                   [4, 4],
    -                   [2, 4],
    -                   [5, 2]])
    -# invert y axis for display
    -coords[:, 1] = max(coords[:, 1]) - coords[:, 1]
    -
    -plt.figure(1, facecolor='w', figsize=(10, 8))
    -plt.clf()
    -ax = plt.axes([0., 0., 1., 1.])
    -plt.axis('off')
    -
    -plt.scatter(coords[:, 0], coords[:, 1], c='r')
    -
    -start_idx, end_idx = np.where(q > 0)
    -segments = [[coords[start], coords[stop]]
    -            for start, stop in zip(start_idx, end_idx)]
    -values = np.array(q[q > 0])
    -# bump up values for viz
    -values = 10 + values / values.max()
    -lc = LineCollection(segments,
    -                    zorder=0, cmap=plt.cm.hot_r)
    -lc.set_array(values)
    -lc.set_linewidths(.3 * values)
    -ax.add_collection(lc)
    -
    -verticalalignment = 'top'
    -horizontalalignment = 'left'
    -for i in range(len(coords)):
    -    x = coords[i][0]
    -    y = coords[i][1]
    -    name = str(i)
    -    if i == 1:
    -        y = y - .05
    -        x = x + .05
    -    elif i == 3:
    -        y = y - .05
    -        x = x + .05
    -    elif i == 4:
    -        y = y - .05
    -        x = x + .05
    -    else:
    -        y = y + .05
    -        x = x + .05
    -
    -    plt.text(x, y, name, size=10,
    -             horizontalalignment=horizontalalignment,
    -             verticalalignment=verticalalignment,
    -             bbox=dict(facecolor='w',
    -                       edgecolor=plt.cm.spectral(float(len(coords))),
    -                       alpha=.6))
    -plt.show()
    +show_traverse()
    +show_q()
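Factoring the printout and plot into show_traverse() and show_q() lets them be called during training as well as at the end, and the new n_steps < 20 cap keeps the greedy walk finite if the policy still cycles before convergence. An equivalent way to inspect the result, sketched here with a hypothetical greedy_policy helper and a toy table rather than the gist's global q:

    import numpy as np

    def greedy_policy(q_table):
        # one best action per state, the same choice show_traverse() follows
        return {state: int(np.argmax(q_table[state])) for state in range(len(q_table))}

    q_table = np.zeros((6, 6))
    q_table[0, 4] = 0.8
    q_table[4, 5] = 1.0
    print(greedy_policy(q_table))  # {0: 4, 1: 0, 2: 0, 3: 0, 4: 5, 5: 0}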
  6. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion painless_q.py
    @@ -55,7 +55,7 @@ def update_q(state, next_state, action, gamma):
                 goal = True
             current_state = next_state

    -
    +# show all the greedy traversals
     for i in range(n_states):
         current_state = i
         traverse = "%i -> " % current_state
    @@ -70,6 +70,7 @@ def update_q(state, next_state, action, gamma):
         print(traverse)
         print("")

    +# show all the valid/used transitions
     coords = np.array([[2, 2],
                        [4, 2],
                        [5, 3],
  7. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 57 additions and 1 deletion.
    58 changes: 57 additions & 1 deletion painless_q.py
    @@ -1,8 +1,12 @@
     # Author: Kyle Kastner
     # License: BSD 3-Clause
    +# Visualization based on code from Gael Varoquaux [email protected]
    +# http://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html

     import numpy as np
     import matplotlib.pyplot as plt
    +from matplotlib.collections import LineCollection

    +
     r = np.array([[-1, -1, -1, -1, 0, -1],
                   [-1, -1, -1, 0, -1, 100],
    @@ -66,5 +70,57 @@ def update_q(state, next_state, action, gamma):
         print(traverse)
         print("")

    -plt.matshow(q)
    +coords = np.array([[2, 2],
    +                   [4, 2],
    +                   [5, 3],
    +                   [4, 4],
    +                   [2, 4],
    +                   [5, 2]])
    +# invert y axis for display
    +coords[:, 1] = max(coords[:, 1]) - coords[:, 1]
    +
    +plt.figure(1, facecolor='w', figsize=(10, 8))
    +plt.clf()
    +ax = plt.axes([0., 0., 1., 1.])
    +plt.axis('off')
    +
    +plt.scatter(coords[:, 0], coords[:, 1], c='r')
    +
    +start_idx, end_idx = np.where(q > 0)
    +segments = [[coords[start], coords[stop]]
    +            for start, stop in zip(start_idx, end_idx)]
    +values = np.array(q[q > 0])
    +# bump up values for viz
    +values = 10 + values / values.max()
    +lc = LineCollection(segments,
    +                    zorder=0, cmap=plt.cm.hot_r)
    +lc.set_array(values)
    +lc.set_linewidths(.3 * values)
    +ax.add_collection(lc)
    +
    +verticalalignment = 'top'
    +horizontalalignment = 'left'
    +for i in range(len(coords)):
    +    x = coords[i][0]
    +    y = coords[i][1]
    +    name = str(i)
    +    if i == 1:
    +        y = y - .05
    +        x = x + .05
    +    elif i == 3:
    +        y = y - .05
    +        x = x + .05
    +    elif i == 4:
    +        y = y - .05
    +        x = x + .05
    +    else:
    +        y = y + .05
    +        x = x + .05
    +
    +    plt.text(x, y, name, size=10,
    +             horizontalalignment=horizontalalignment,
    +             verticalalignment=verticalalignment,
    +             bbox=dict(facecolor='w',
    +                       edgecolor=plt.cm.spectral(float(len(coords))),
    +                       alpha=.6))
     plt.show()
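The new plot draws one line segment per positive entry of q and lets a colormap encode its value through matplotlib's LineCollection. A minimal self-contained sketch of that pattern on toy data (three points, two weighted edges), not the gist's room layout:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection

    points = np.array([[0., 0.], [1., 0.], [1., 1.]])
    weights = np.array([[0., 2., 0.],
                        [0., 0., 5.],
                        [0., 0., 0.]])
    start_idx, end_idx = np.where(weights > 0)
    segments = [[points[s], points[e]] for s, e in zip(start_idx, end_idx)]
    lc = LineCollection(segments, cmap=plt.cm.hot_r)
    lc.set_array(weights[weights > 0])  # one scalar per segment drives the color
    fig, ax = plt.subplots()
    ax.add_collection(lc)
    ax.scatter(points[:, 0], points[:, 1], c='r')
    plt.show()

One caveat for re-running the gist today: plt.cm.spectral appears only in older matplotlib releases; newer versions ship the same colormap as nipy_spectral.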
  8. @kastnerkyle revised this gist May 29, 2016. 1 changed file with 69 additions and 1 deletion.
    70 changes: 69 additions & 1 deletion painless_q.py
    @@ -1,2 +1,70 @@
     # Author: Kyle Kastner
    -# License: BSD 3-Clause
    +# License: BSD 3-Clause
    +
    +import numpy as np
    +import matplotlib.pyplot as plt
    +
    +r = np.array([[-1, -1, -1, -1, 0, -1],
    +              [-1, -1, -1, 0, -1, 100],
    +              [-1, -1, -1, 0, -1, -1],
    +              [-1, 0, 0, -1, 0, -1],
    +              [ 0, -1, -1, 0, -1, 100],
    +              [-1, 0, -1, -1, 0, 100]]).astype("float32")
    +q = np.zeros_like(r)
    +
    +
    +def update_q(state, next_state, action, gamma):
    +    new_q = r[state, action] + gamma * max(q[next_state, :])
    +    q[state, action] = new_q
    +    return r[state, action]
    +
    +gamma = 0.8
    +n_episodes = 1E4
    +n_states = 6
    +n_actions = 6
    +epsilon = 0.05
    +random_state = np.random.RandomState(1999)
    +for e in range(int(n_episodes)):
    +    states = list(range(n_states))
    +    random_state.shuffle(states)
    +    current_state = states[0]
    +    goal = False
    +    while not goal:
    +        # epsilon greedy
    +        if random_state.rand() < epsilon:
    +            valid_moves = r[current_state] >= 0
    +            actions = list(range(n_actions))
    +            random_state.shuffle(actions)
    +            # shuffle to match actions
    +            valid_moves = valid_moves[actions]
    +            for i in range(len(valid_moves)):
    +                if valid_moves[i] == True:
    +                    action = actions[i]
    +                    break
    +            # action is the move to next state
    +            next_state = action
    +        else:
    +            action = np.argmax(q[current_state, :])
    +            next_state = action
    +        reward = update_q(current_state, next_state, action, gamma=gamma)
    +        if reward > 1:
    +            goal = True
    +        current_state = next_state
    +
    +
    +for i in range(n_states):
    +    current_state = i
    +    traverse = "%i -> " % current_state
    +    n_steps = 0
    +    while current_state != 5:
    +        next_state = np.argmax(q[current_state])
    +        current_state = next_state
    +        traverse += "%i -> " % current_state
    +    # cut off final arrow
    +    traverse = traverse[:-4]
    +    print("Greedy traversal for starting state %i" % i)
    +    print(traverse)
    +    print("")
    +
    +plt.matshow(q)
    +plt.show()
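The reward matrix r is the whole environment here: r[state, action] is -1 when there is no door from room state to room action, 0 when there is a door, and 100 when the door leads into the goal room (state 5, which also has a 100-reward self-loop). A small sketch that just prints the graph encoded by those entries:

    import numpy as np

    r = np.array([[-1, -1, -1, -1, 0, -1],
                  [-1, -1, -1, 0, -1, 100],
                  [-1, -1, -1, 0, -1, -1],
                  [-1, 0, 0, -1, 0, -1],
                  [ 0, -1, -1, 0, -1, 100],
                  [-1, 0, -1, -1, 0, 100]]).astype("float32")

    for state in range(r.shape[0]):
        doors = np.where(r[state] >= 0)[0]
        print("room %i connects to %s" % (state, list(doors)))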
  9. @kastnerkyle created this gist Mar 6, 2016.
    2 changes: 2 additions & 0 deletions painless_q.py
    @@ -0,0 +1,2 @@
    +# Author: Kyle Kastner
    +# License: BSD 3-Clause