#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Tue Feb 27 02:46:51 2018 @author: memo very quick & simple dictionary / json based graph builder for tensorflow ( inspired by https://github.com/dribnet/discgen/blob/master/discgen/vae.py#L43-L163 ) """ from __future__ import print_function from __future__ import division import tensorflow as tf import numpy as np from pprint import pprint import msa.tf.ops def example(): tf.reset_default_graph() # dict of dicts { { : kwargs }, ... } default_op_args = { 'conv2d' : { 'padding':'same', 'kernel_size':(3,3), 'strides':(1,1) }, 'conv2d_transpose' : { 'kernel_size':(2,2), 'strides':(2,2) }, } # list of dicts [ {'op':, kwargs }, ... ] encoder_ops_info = [ { 'op':'conv2d', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'kernel_size':(2,2), 'strides':(2,2), 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'identity', 'name':'pre_z_conv' }, { 'op':'flatten' }, { 'op':'dense', 'units':1024 }, { 'op':'batch_norm' }, { 'op':'relu' }, ] decoder_ops_info = [ { 'op':'dense', 'units':128, 'name':'z' }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'dense', 'units':0, 'name':'post_z_flat' }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'tf.reshape', 'name':'post_z_conv' }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d_transpose','filters':256 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d_transpose', 'filters':128 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d_transpose', 'filters':64 }, { 'op':'batch_norm' }, { 'op':'relu' }, { 'op':'conv2d', 'kernel_size':(1,1), 'filters':3 }, { 'op':'tanh', 'name':'output'} ] x = tf.placeholder(tf.float32, [None, 64, 64, 3]) # build encoder with tf.variable_scope('encoder'): encoder_ops, errors = build_graph(x, encoder_ops_info, default_op_args) # TODO: THIS BIT IS UGLY. is there a better way of automating all of this? # need to get the conv shape before flattening. search encoder tensors by name pre_z_conv = get_tensors_by_name(encoder_ops, 'pre_z_conv')[0] # write to decoder_ops_info # flattened shape is multiplication of all dims except for batch size get_ops_by_name(decoder_ops_info, 'post_z_flat')[0]['units'] = np.prod(pre_z_conv.shape[1:]) # first conv op after flat layer needs write shape. get_ops_by_name(decoder_ops_info, 'post_z_conv')[0]['shape'] = tf.shape(pre_z_conv) # now build decoder with tf.variable_scope('decoder'): decoder_ops, errors = build_graph(encoder_ops[-1], decoder_ops_info, default_op_args) return encoder_ops, decoder_ops ''' Output: -------------------------------------------------------------------------------- > msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 64} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_2/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_2/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_1:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_2:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 128} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_4/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_3:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_5/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 256} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_6/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_5:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.identity {'name': 'pre_z_conv'} --> Tensor("encoder/pre_z_conv:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.flatten {} --> Tensor("encoder/Flatten/flatten/Reshape:0", shape=(?, 262144), dtype=float32) > msa.tf.ops.dense {'units': 1024} --> Tensor("encoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_7/batchnorm/add_1:0", shape=(?, 1024), dtype=float32) > msa.tf.ops.relu {} --> Tensor("encoder/Relu_6:0", shape=(?, 1024), dtype=float32) -------------------------------------------------------------------------------- 23 ops added -------------------------------------------------------------------------------- > msa.tf.ops.dense {'units': 128, 'name': 'z'} --> Tensor("decoder/z/BiasAdd:0", shape=(?, 128), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization/batchnorm/add_1:0", shape=(?, 128), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu:0", shape=(?, 128), dtype=float32) > msa.tf.ops.dense {'units': 1024} --> Tensor("decoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_2/batchnorm/add_1:0", shape=(?, 1024), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_1:0", shape=(?, 1024), dtype=float32) > tf.reshape {'shape': , 'name': 'post_z_conv'} --> Tensor("decoder/post_z_conv:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_2:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_3:0", shape=(?, 32, 32, 256), dtype=float32) > msa.tf.ops.conv2d_transpose {'filters': 256} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32) > msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_2/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_5:0", shape=(?, 64, 64, 128), dtype=float32) > msa.tf.ops.conv2d_transpose {'filters': 128} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_2/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_7/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_6:0", shape=(?, 128, 128, 128), dtype=float32) > msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_8/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_7:0", shape=(?, 128, 128, 64), dtype=float32) > msa.tf.ops.conv2d_transpose {'filters': 64} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_3/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_9/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.relu {} --> Tensor("decoder/Relu_8:0", shape=(?, 256, 256, 64), dtype=float32) > msa.tf.ops.conv2d {'kernel_size': (1, 1), 'filters': 3} + defaults {'padding': 'same', 'strides': (1, 1)} --> Tensor("decoder/conv2d_4/BiasAdd:0", shape=(?, 256, 256, 3), dtype=float32) > msa.tf.ops.tanh {'name': 'output'} --> Tensor("decoder/output:0", shape=(?, 256, 256, 3), dtype=float32) -------------------------------------------------------------------------------- 29 ops added ''' #%% namespaces=[ '', 'msa.tf.ops', 'tf', 'tf.layers', 'tf.nn', 'tf.contrib.layers' ] def get_tensors_by_name(tensors, name): '''given a list of tensors, return any tensor which has matching name''' return filter(lambda x: name in x.name, tensors) def get_ops_by_name(ops_info, name): '''given a list of op info dicts, return any op dict which has matching name''' return filter(lambda x: 'name' in x and name in x['name'], ops_info) def build_graph(input_T, ops_info, default_op_args=None, verbose=True): print('-'*80) errors = [] def handle_error(msg, op_dict): print('\n** ERROR', msg, op_dict,'\n') errors.append( {msg : op_dict} ) t = input_T ops = [] for op_dict in ops_info: if type(op_dict) == dict: if 'op' in op_dict: op_str = op_dict['op'] # get dict for this layer op_fn = None fn_path = None for namespace in namespaces: try: fn_path = '.'.join([namespace, op_str]) if namespace else op_str op_fn = eval(fn_path) break except: pass if op_fn: # get op args excluding op name args = { k:v for k,v in op_dict.items() if k!='op' } if verbose: print('>', fn_path, args, end=' ') extra_args = None if default_op_args and op_str in default_op_args: # check for defaults op_defaults = default_op_args[op_str] # defaults dict for this op type extra_args = { k:v for k,v in op_defaults.items() if k not in args } if extra_args: if verbose: print('+ defaults', extra_args, end=' ') args.update(extra_args) try: t = op_fn(t, **args) print('-->', t) ops.append(t) except Exception as e: handle_error(fn_path + ' : ' + str(e), op_dict) else: # if op_fn: handle_error('function not found', op_dict) else: # if 'op' in op_dict: handle_error('missing op key', op_dict) else: # type(op_dict) == dict: handle_error('unknown entry type', op_dict) print('-'*80) print('{} ops added'.format(len(ops))) if len(errors) > 0: print('{} errors found:'.format(len(errors))) pprint(errors) return ops, errors #%% if __name__ == "__main__": encoder_ops, decoder_ops = example()