Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save LeiWangR/43d69532d646a69d3188e34160aee773 to your computer and use it in GitHub Desktop.
Save LeiWangR/43d69532d646a69d3188e34160aee773 to your computer and use it in GitHub Desktop.
Extracting detection features with the TensorFlow Object Detection API.
"""This file extracts faster-rcnn features and bounding box coordinates"""
import pdb
import argparse
import numpy as np
import tensorflow as tf
import PIL.Image as PILI
def session(sess, feat_conv, feat_avg, boxes, classes, scores, image_tensor, image):
    """Run one forward pass and return the squeezed outputs.

    Evaluates the five fetches (`feat_conv`, `feat_avg`, `boxes`, `classes`,
    `scores`) through `sess.run`, feeding `image` to `image_tensor`.

    Every result has all of its singleton dimensions removed with
    ``.squeeze()``; the class result is additionally cast to ``np.int32``.

    Returns:
        A 5-tuple ``(feat_conv, feat_avg, boxes, classes, scores)`` of the
        squeezed results, in that order.
    """
    fetches = [feat_conv, feat_avg, boxes, classes, scores]
    raw_outputs = sess.run(fetches, feed_dict={image_tensor: image})
    # Unpacking enforces that exactly five results came back.
    conv_arr, avg_arr, box_arr, cls_arr, score_arr = (
        result.squeeze() for result in raw_outputs
    )
    return conv_arr, avg_arr, box_arr, cls_arr.astype(np.int32), score_arr
def load_graph(graph, ckpt_path):
    """Import the serialized graph stored at `ckpt_path` into `graph`.

    The file is read in binary mode, parsed into a ``tf.GraphDef`` and
    imported with an empty name prefix (``name=''``), all while `graph`
    is the default graph. Nothing is returned; `graph` is populated in place.
    """
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(ckpt_path, 'rb') as pb_file:
            graph_def.ParseFromString(pb_file.read())
            tf.import_graph_def(graph_def, name='')
if __name__ == '__main__':
    # Command-line entry point: load a frozen Faster R-CNN graph, run it on a
    # single image, and leave the extracted features / detections in
    # `feat_conv`, `feat_avg`, `boxes`, `classes` and `scores`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--img', metavar='', type=str, default=None, help='Image path.')
    parser.add_argument('--model', metavar='', type=str, default='frcnn_res101', help='frcnn_incresv2 or frcnn_res101.')
    args, unparsed = parser.parse_known_args()
    if unparsed:
        raise SystemExit('Unknown argument: {}'.format(unparsed))
    # Fail before the (slow) model load instead of crashing inside PIL later.
    if args.img is None:
        raise SystemExit('Missing argument: --img')

    # model name -> (frozen graph path, second-stage conv feature tensor name).
    # Shapes noted by the original author:
    #   frcnn_incresv2: feat_conv (100, 8, 8, 1536), feat_avg (100, 1, 1, 1536)
    #   frcnn_res101:   feat_conv (100, 7, 7, 2048), feat_avg (100, 1, 1, 2048)
    model_specs = {
        'frcnn_incresv2': (
            './faster_rcnn_inception_resnet_v2_atrous_coco_2017_11_08/frozen_inference_graph.pb',
            'SecondStageFeatureExtractor/InceptionResnetV2/Conv2d_7b_1x1/Relu:0',
        ),
        'frcnn_res101': (
            './faster_rcnn_resnet101_coco_2017_11_08/frozen_inference_graph.pb',
            'SecondStageFeatureExtractor/resnet_v1_101/block4/unit_3/bottleneck_v1/Relu:0',
        ),
    }
    if args.model not in model_specs:
        raise SystemExit('Unknown model: {}'.format(args.model))
    ckpt_path, feat_conv_name = model_specs[args.model]

    graph = tf.Graph()
    load_graph(graph, ckpt_path)
    # Input placeholder, noted as (1, ?, ?, 3) by the original author.
    image_tensor = graph.get_tensor_by_name('image_tensor:0')
    feat_conv = graph.get_tensor_by_name(feat_conv_name)
    feat_avg = graph.get_tensor_by_name('SecondStageBoxPredictor/AvgPool:0')
    boxes = graph.get_tensor_by_name('detection_boxes:0')
    scores = graph.get_tensor_by_name('detection_scores:0')
    classes = graph.get_tensor_by_name('detection_classes:0')
    print('model: {}'.format(args.model))

    # Let the session grow GPU memory on demand rather than grabbing it all.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # The context manager releases the session's resources on exit.
    with tf.Session(config=config, graph=graph) as sess:
        print('Detect a single image')
        # Force 3-channel RGB so grayscale / RGBA / palette files match the
        # 3-channel input; `convert` returns a loaded copy, so the file
        # handle can be closed right away.
        with PILI.open(args.img) as pil_image:
            image = np.asarray(pil_image.convert('RGB'))
        # Add the leading batch dimension and run the network once.
        feat_conv, feat_avg, boxes, classes, scores = session(
            sess, feat_conv, feat_avg, boxes, classes, scores, image_tensor,
            np.expand_dims(image, 0))
    print('Done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment