#!/usr/bin/env python
# Max Jaderberg 21/5/14
# Replace the weights of one layer in a serialized Caffe net with weights
# loaded from a MATLAB .mat file, and write the result to a new net file.

from proto import caffe_pb2
import numpy as np
import scipy.io
import os

# path to the input net and to the net written out with replaced weights
net_path = './caffe_imagenet_train_iter_755000'
new_net_path = './caffe_imagenet_train_iter_755000_approxconv2_2x'
# name of the layer whose first blob is overwritten
replace_name = 'conv2'
# .mat file holding the replacement weights under the variable name 'w'
replace_weights = 'conv2_scheme12x.mat'

# load net
net = caffe_pb2.NetParameter()
with open(net_path, 'rb') as fid:
    net.ParseFromString(fid.read())

# load weights
w = scipy.io.loadmat(replace_weights)['w']

# Find the first layer with the requested name (the original list.index
# lookup also resolved to the first match). Fail loudly rather than
# silently writing an unmodified copy of the net when the name is absent.
target_layer = None
for conn in net.layers:
    if conn.layer.name == replace_name:
        target_layer = conn.layer
        break
if target_layer is None:
    raise ValueError('layer %r not found in %s' % (replace_name, net_path))

# NOTE(review): blobs[0] is presumably the weight blob (blobs[1] usually
# being the bias) -- confirm for this layer type.
origblob = target_layer.blobs[0]

# Flatten in C (row-major) index order -- the same order `.flat` iterates.
# NOTE(review): this assumes `w` was saved from MATLAB with its dimensions
# already arranged to match the blob's memory layout -- confirm.
new_data = w.astype(float).ravel()

# A blob whose element count changes would no longer agree with its stored
# shape, so refuse to write such a net.
if new_data.size != len(origblob.data):
    raise ValueError(
        'weight count mismatch for layer %r: %s has %d values, blob has %d'
        % (replace_name, replace_weights, new_data.size, len(origblob.data)))

del origblob.data[:]
origblob.data.extend(new_data.tolist())

# Write the modified net back to disk.
with open(new_net_path, 'wb') as f:
    f.write(net.SerializeToString())