# GitHub Gist iwyoo/de45d3f933d45b667267653d044a4cd0
# Author: @iwyoo, created December 10, 2017.
# Gathering values using sparse binary form indices.
import numpy as np
import tensorflow as tf
# Toy inputs. Each row of `val` holds candidate values; the matching row of
# `pos` is a 0/1 mask whose k-th one should receive the k-th value of that row.
# Every mask row here has exactly as many ones as `val` has values per row.
val = np.array(
    [[[1, 2, 3], [4, 5, 6]],
     [[1, 2, 3], [4, 5, 6]]])
pos = np.array(
    [[[1, 0, 0, 1, 1], [0, 1, 1, 0, 1]],
     [[0, 1, 0, 1, 1], [1, 1, 1, 0, 0]]])
# print() as a function call runs on both Python 2 and Python 3.
print(val.shape)  # (2, 2, 3)
print(pos.shape)  # (2, 2, 5)
val_input = tf.placeholder(tf.int32, val.shape)
pos_input = tf.placeholder(tf.int32, pos.shape)


def gather_by_binary_mask(values, mask):
    """Place each row of `values` into the slots marked by a binary `mask`.

    For every row, the k-th one in `mask` receives the k-th entry of the
    matching row of `values`; slots where `mask` is zero become zero.
    Example: values row [1, 2, 3] with mask row [1, 0, 0, 1, 1] gives
    [1, 0, 0, 2, 3].

    Args:
        values: tensor of shape [..., W] with the candidate values.
        mask: 0/1 tensor of shape [..., K] whose leading dimensions match
            those of `values`. Each mask row must contain at most W ones;
            otherwise the gather index runs past the end of the value row,
            which tf.gather_nd rejects on CPU.

    Returns:
        Tensor with the shape of `mask` and the dtype of `values`.
    """
    width = tf.shape(values)[-1]
    num_slots = tf.shape(mask)[-1]

    # Flatten all leading dimensions so every row is processed independently.
    flat_values = tf.reshape(values, [-1, width])
    flat_mask = tf.cast(tf.reshape(mask, [-1, num_slots]), tf.int32)

    # Running count of ones, minus one, is the value index each slot takes,
    # e.g. [1, 0, 0, 1, 1] -> [0, 0, 0, 1, 2]. Slots before the first one
    # would get -1; clamp them to 0 (the mask multiply below zeroes them).
    col_index = tf.maximum(tf.cumsum(flat_mask, axis=1) - 1, 0)

    # Pair every column index with its row number to form gather_nd indices.
    row_index, _ = tf.meshgrid(tf.range(tf.shape(flat_mask)[0]),
                               tf.range(num_slots), indexing='ij')
    index = tf.stack([row_index, col_index], axis=2)

    gathered = tf.reshape(tf.gather_nd(flat_values, index), tf.shape(mask))
    return gathered * tf.cast(mask, gathered.dtype)


### model start
val_output = gather_by_binary_mask(val_input, pos_input)
with tf.Session() as sess:
    # Feed the NumPy arrays into the placeholders and evaluate the gather.
    # For the toy data above this prints
    # [[[1 0 0 2 3] [0 4 5 0 6]] [[0 1 0 2 3] [4 5 6 0 0]]].
    print(sess.run(val_output,
                   {val_input: val, pos_input: pos}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment