{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import cm as cm, pyplot as plt, gridspec as gridspec\n",
    "from matplotlib.patches import ConnectionPatch\n",
    "%matplotlib inline\n",
    "\n",
    "class MatplotlibGridDisplay:\n",
    "    \"\"\"Draw a grid of small images with optional arrows between cells.\n",
    "\n",
    "    Cells are addressed either as zero-based (row, col), with row 0 at the\n",
    "    top, or as one-based (x, y), with y == 1 on the bottom row.\n",
    "    Trajectories are expressed in (x, y) coordinates.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, cols):\n",
    "        \"\"\"Store the grid dimensions used by the coordinate conversions.\"\"\"\n",
    "        self.rows, self.cols = rows, cols\n",
    "        self.axes_order = {}\n",
    "\n",
    "    @staticmethod\n",
    "    def _prepare_axis(ax):\n",
    "        \"\"\"Remove ticks and keep only the spines on the outer grid border.\n",
    "\n",
    "        Returns the axes that was passed in.\n",
    "        \"\"\"\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_visible(False)\n",
    "\n",
    "        # Axes.is_first_row() and friends were deprecated in Matplotlib 3.4\n",
    "        # and later removed; SubplotSpec offers the same queries.  Older\n",
    "        # releases only have them on the axes, hence the fallback.\n",
    "        spec = ax.get_subplotspec()\n",
    "        cell = spec if hasattr(spec, 'is_first_row') else ax\n",
    "        if cell.is_first_row():\n",
    "            ax.spines['top'].set_visible(True)\n",
    "        if cell.is_last_row():\n",
    "            ax.spines['bottom'].set_visible(True)\n",
    "        if cell.is_first_col():\n",
    "            ax.spines['left'].set_visible(True)\n",
    "        if cell.is_last_col():\n",
    "            ax.spines['right'].set_visible(True)\n",
    "\n",
    "        return ax\n",
    "\n",
    "    def _xy_to_rowcol(self, x, y):\n",
    "        \"\"\"Convert one-based (x, y) to zero-based (row, col).\"\"\"\n",
    "        return self.rows - y, x - 1\n",
    "\n",
    "    def _rowcol_to_xy(self, row, col):\n",
    "        \"\"\"Convert zero-based (row, col) to one-based (x, y).\"\"\"\n",
    "        return col + 1, self.rows - row\n",
    "\n",
    "    def connect_axes(self, fig, ax1, ax2, order='forward'):\n",
    "        \"\"\"Link the data origins of two axes with a red arrow and two dots.\n",
    "\n",
    "        The arrow head is drawn at ax2 in either mode.  With\n",
    "        order='forward' the patch is owned by ax1 (A end in ax1, B end in\n",
    "        ax2); with any other value it is owned by ax2 and the ends are\n",
    "        swapped, which also moves the shrinkB gap to the ax1 end.\n",
    "        fig is accepted for call compatibility and is not used.\n",
    "        \"\"\"\n",
    "        origin = (0., 0.)\n",
    "        if order == 'forward':\n",
    "            ax_a, ax_b, style = ax1, ax2, '->'\n",
    "        else:\n",
    "            ax_a, ax_b, style = ax2, ax1, '<-'\n",
    "\n",
    "        con = ConnectionPatch(xyA=origin, xyB=origin,\n",
    "                              coordsA='data', coordsB='data',\n",
    "                              axesA=ax_a, axesB=ax_b, color='red',\n",
    "                              mutation_scale=40, arrowstyle=style,\n",
    "                              shrinkB=5)\n",
    "        ax_a.add_artist(con)\n",
    "\n",
    "        # Mark both end points with a red dot.\n",
    "        ax1.plot(*origin, 'ro', markersize=10)\n",
    "        ax2.plot(*origin, 'ro', markersize=10)\n",
    "\n",
    "    def add_trajectory(self, fig, axes_grid, traj):\n",
    "        \"\"\"Draw arrows between consecutive cells of one trajectory.\n",
    "\n",
    "        traj is an iterable of (x, y) pairs that are keys of axes_grid.\n",
    "        A trajectory with fewer than two points draws nothing.\n",
    "        \"\"\"\n",
    "        points = [(x, y) for x, y in traj]\n",
    "        for idx, (start, end) in enumerate(zip(points, points[1:])):\n",
    "            ax1, ax2 = axes_grid[start], axes_grid[end]\n",
    "            # Each arrow is owned by its source axes, so that axes is stacked\n",
    "            # just above the destination, and earlier segments above later ones.\n",
    "            ax1.set_zorder(-2*idx+1)\n",
    "            ax2.set_zorder(-2*idx)\n",
    "            self.connect_axes(fig, ax1, ax2)\n",
    "\n",
    "    def add_trajectories(self, fig, axes_grid, traj_lst):\n",
    "        \"\"\"Draw every trajectory in traj_lst; None means nothing to draw.\"\"\"\n",
    "        if traj_lst is None:\n",
    "            return\n",
    "        for traj in traj_lst:\n",
    "            self.add_trajectory(fig, axes_grid, traj)\n",
    "\n",
    "    def render(self, data, phi_shape, traj_lst=None,\n",
    "               interpolation='None', cmap=cm.viridis, vmin=None, vmax=None):\n",
    "        \"\"\"Plot every cell of data as an image and overlay the trajectories.\n",
    "\n",
    "        data must be three-dimensional, (H, W, D); each data[row, col] is\n",
    "        reshaped to phi_shape before being shown.  traj_lst is an optional\n",
    "        iterable of trajectories, each a sequence of one-based (x, y) cells.\n",
    "        interpolation, cmap, vmin and vmax are forwarded to imshow.\n",
    "\n",
    "        The (x, y) keys are computed from self.rows, so the instance should\n",
    "        be created with rows equal to data.shape[0].\n",
    "\n",
    "        Returns (fig, ax), where ax is the last axes that was created.\n",
    "        \"\"\"\n",
    "        H, W, _ = data.shape\n",
    "\n",
    "        # Matplotlib spells interpolation modes in lower case; fold the case\n",
    "        # so that the historical default 'None' is accepted.\n",
    "        if isinstance(interpolation, str):\n",
    "            interpolation = interpolation.lower()\n",
    "\n",
    "        # One subplot per grid cell, with no margins and no gaps.\n",
    "        fig = plt.figure(figsize=(W*2, H*2))\n",
    "        gs = gridspec.GridSpec(H, W)\n",
    "        gs.update(wspace=0., hspace=0., left=0., right=1., bottom=0., top=1.)\n",
    "        axes_grid = {}\n",
    "\n",
    "        for row in range(H):\n",
    "            for col in range(W):\n",
    "                ax = fig.add_subplot(gs[row, col])\n",
    "                ax.imshow(data[row, col].reshape(*phi_shape),\n",
    "                          interpolation=interpolation, cmap=cmap,\n",
    "                          vmin=vmin, vmax=vmax)\n",
    "                self._prepare_axis(ax)\n",
    "                axes_grid[self._rowcol_to_xy(row, col)] = ax\n",
    "\n",
    "        self.add_trajectories(fig, axes_grid, traj_lst)\n",
    "        return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(
,\n", " )" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAATUAAAE1CAYAAACGH3cEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAC81JREFUeJzt3X+M1/V9wPHXV2947iZMDgGtlIJybrRyGgmCS2ytirRbFmpxXbJ1zhhnoouelCV0poJNtqaLVrbWTrKRVdPOrbOItjXcubX/LI1xI4GKVg/kl9Xy69DjvEVx8Nkf7uRQ5L7f7/H9wesej7/0m/eXvAIfn/f6fr7fL5aKogiALE5r9AAAJ5OoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpBKSyWHx5XOKFqjrVazkNxg29vRMnFio8fgFHXolV/uL4rinJHOVRS11miLy0tXVz8VY9rWv5rf6BE4he3oWraznHNefgKpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBpUYNLAQERRNHoMTkDUoAJf+5d/ix/evyqu2fy8uDUpUYMK/NnNfxrfXHhN3PVUt7g1qZZGDwCnkuK006JnzsXx9Cc+Htdufj7ueqo77lzfE3+7aGH8+8dnR5RKjR5xzBO1Ueos9sbtsTE+Fgffe2xHjI8H45LYVJrcwMmopZMVt/m9W2Ll2nXRsXvPe4/1Tp0SK69fHM90zKrV+KmVigpW5/GlicXlpatrOM6p5Y+KF+LGeCEiIoZfwkO/ow/H7PheaXbd52pWWx+Y3+gRaqZ05Ehcu/n5uHP903HktFJZcfvz9T2xdH3Pu88f9vjQ9fONRQvjW4sW1m7oU8yOrmUbiqKYO9I599Sq1FnsjRvjhSjFsRdk/P+/lyLixnghOou99R+Ouhva3H5vWVdZ99zm926Jpet7Tnj9LF3fE/N7t9Rh+lxErUq3x8ayzt1W5jlyKDduK9euK+vXW1HmOY5yT61KH4uDH/gJ+36liJgRB2NW8Xo9Rmp6ra/8stEj1NVrZ58dy//whrh868vx5Sd+GMuf/FGs+eSV8ejvLIiO3XvKun4uGnavjfKIWh10xYZGj9AU3v7XFxs9QsMUEfHR/X3xlXVPxmPzRrwtxCiIWh3cXrqm0SM0ha3L8r5R8GEu3b4j7ux+OiYODsaKJdfHY/Pmxjst/rOrJb+7VdoR40d8CVpExPYYX6+RaCJDMbtwz5548Npr4pZ5Nx0Ts96pU0Z8CVpExEtTp9R81my8UVClB+OSss59u8xz5HDp9h3xnYf+Ib75yHeje87FcdXdy+PRK+Z/YDtbef3isn69e8s8x1E2tSptKk2Oh4vZI35OzQdwx4aRNrP3e6ZjVnxj0cIRP6fmA7iVE7VR+F5pdmwuJsVtsTFmDPtGwfYYH9/2jYIxodKYDfetRQvjv2fOiBVr1x3zLudLU6fEvb5RUDXfKKBuMn2j4P0xG+0bAE/e90D85RduiM3Tzj+JU+ZS7jcKbGpQgTm7dsXSp7qr2syoD38aUIEbnvmv6J5zsZg1MX8qUIGv/MHnGz0CI/CRDiAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1I
BVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSETUgFVEDUhE1IBVRA1IRNSAVUQNSaank8P+c8Xa88YUFtZrllPbIjx6Iv16wJF5sn9boUZrWmt9f3egRmtb0f9wf91y5NgYubm30KE3r6q6YXs65ija1M35jYnXTAIzeznIOefkJpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAqogakImpQB7++7VCU/rdo2PPHElGDOpixqi/Oe7S/que2vH44LluyK1oOHjnJU+UkalAHu24+O6Y/dCBKb1cepo+ueT32Ljor3pl4eg0my0fUoA4GOlvjzd86I877/sGKntfy+uH4yD+/ETtvm1ijyfIRNaiT7Xe0V7ytDW1pb53/azWcLBdRgzqpdFuzpVVH1KCOKtnWbGnVETWoo3K3NVta9UQN6qycbc2WVj1RgzobaVuzpY2OqEEDnGhbs6WNjqhBA3zYtmZLGz1RgwY53rZmSxs9UYMGef+2dvpBW9rJIGpNbsJbb8bpRw43egxqZGhbiyNFTH1iwJZ2EohakzpvoC/u/tn347F1X4/zB/oaPQ41MrStjdt/OCZ3v2lLOwlErckMxezhH6+KvjPPiiWLl8fOCZMbPRY1tP2O9mjdezgOLDjTlnYStDR6AN513kBf3PTcf8Sndj0XP7joiliyeHn0t7Y1eizqYKCzNfZf1RbbuiY1epQURG2ULvvVllj27ONxQf+eeOTHqyIi4uUJU+K+eZ+LDefOGvH5Yja2/ebPBqPj3n3RtuVQtP90MCIiBmeNi94V58QbV7gOqiFqo3Dzpp64dVN3RESUhj1+Qf+e+PunH4rVndfFms6Fx32umDH97/pi5qp375cOv37athyKS//41djW1R4772hvzHCnMFGr0mW/2hK3buo+5mIcMvTYrZu6Y+PkGcdsbGJGxLsb2sxVfSe8fmau6ov+ua02tgp5o6BKy559vKxzX3p2XUQc/w2Ahy79jKCNUR337ivv3FfLO8dRNrUqXdC/57g/ZYcrRcSF/bvjb376TzF399b4yfQ58eVPfjHeHHdmnDt4IM4dPFCPUZvGWc+91egRmkbblkNlXT9tvYfqMU4qolYHV72yOfadOT4u6ns1Og681uhxGuYjO20d1J6o1cGffLYrbvl5T3QceDUe/sSn44lZl8eh08fe55G+vnJ1o0doGlfN7G30CGm5p1allydMiZH+17JFRGydMDV+MWlaLP30zfEXn7opFrz2Uqx9/Gtxw4v/GeMOv1OPUWlCg7PGlXX9DHaMq8c4qYhale6b97myzt0/b/F7/yxuDOldcU555+4p7xxHiVqVNpw7K1Z3XhdFxAd+4g49trrzuuN+AFfceOOKttjW1X7C62dbV7uPc1TBPbVRWNO5MDZOnhFfenZdXNi/+73Ht06YGvfPWzziNwqG4vbb+1+JW37eEzdu/smYvuc21uy8oz3657ZGx1f3HfMu52DHuOi9xzcKqlUqipFe2R/VNmlaMft376rhOGPbUNyyvqHgjQJG4+qZvRuKopg70jkvP5vIh70s9fepQflErQkNj9sle7f7+9SgAu6pNbFfTJoWd1/5xUaPAacUmxqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoA
amIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqYgakIqoAamIGpCKqAGpiBqQiqgBqZSKoij/cKm0LyJ21m4ckpserh+qN70oinNGOlRR1ACanZefQCqiBqQiakAqogakImpAKqIGpCJqQCqiBqQiakAq/wd9PCIrVuNL8gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# Demo: a 2x2 grid holding one scalar per cell, each drawn as a 1x1 image.\n",
    "cell_values = np.array([[0., 0.5], [0.3, 0.9]])[:, :, np.newaxis]\n",
    "# Closed loop through all four cells, in one-based (x, y) coordinates.\n",
    "loop = [(1, 1), (1, 2), (2, 2), (2, 1), (1, 1)]\n",
    "\n",
    "m = MatplotlibGridDisplay(2, 2)\n",
    "m.render(data=cell_values, phi_shape=(1, 1), traj_lst=[loop], vmin=0., vmax=1.)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:irl]",
   "language": "python",
   "name": "conda-env-irl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}