You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
2.4 KiB
52 lines
2.4 KiB
#!/usr/bin/env python
|
|
"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.
|
|
Usage
|
|
python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file
|
|
|
|
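Saves each variable to a {variable_name}.npy binary file in the current working directory; path separators in the variable name are replaced with underscores.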
|
|
|
|
Note that since Tensorflow 0.11, the binary checkpoint file that contains the value of each parameter is named according to the pattern:
|
|
{model_name}.data-{shard}-of-{num_shards}
|
|
instead of:
|
|
{model_name}.ckpt
|
|
When dealing with checkpoints written by Tensorflow >= 0.11, pass only {model_name} (the checkpoint prefix) to the -m option;
|
|
when dealing with checkpoints written by Tensorflow < 0.11, pass the whole file name {model_name}.ckpt to the -m option.
|
|
|
|
Also note that this script relies on the parameters to be extracted being in the
|
|
'trainable_variables' graph collection. By default, all variables are automatically added to this collection unless
|
|
specified otherwise by the user. Thus, if a user alters this default behavior and/or wants to extract parameters from other
|
|
collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.
|
|
|
|
Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
|
|
"""
|
|
import argparse
|
|
import numpy as np
|
|
import os
|
|
import tensorflow as tf
|
|
|
|
|
|
def sanitize_variable_name(name):
    """Turn a Tensorflow variable name into a string usable as a file name.

    Tensorflow separates variable scopes with '/' on every operating system
    (e.g. 'conv1/weights:0'). Every '/' and every path separator of the
    current platform is replaced with '_', so that ``np.save`` writes a single
    file in the working directory instead of trying to create it inside a
    (missing) subdirectory.

    Args:
        name: Variable name, e.g. as reported by ``tf.Variable.name``.

    Returns:
        The name with scope/path separators replaced by underscores; other
        characters are left unchanged.
    """
    separators = {'/', os.path.sep}
    if os.path.altsep:
        separators.add(os.path.altsep)
    for separator in separators:
        name = name.replace(separator, '_')
    return name


def main():
    """Restore a Tensorflow checkpoint and dump its trainable variables to .npy files."""
    # Parse arguments. The description must be passed by keyword: the first
    # positional parameter of ArgumentParser is `prog`, which would replace the
    # program name shown in the usage line.
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to Tensorflow checkpoint binary file. '
                             'For Tensorflow version >= 0.11, only include model name; '
                             'for Tensorflow version < 0.11, include model name with '
                             '".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True,
                        help='Path to Tensorflow MetaGraph file')
    args = parser.parse_args()

    # Load Tensorflow Net: rebuild the graph described by the MetaGraph file.
    saver = tf.train.import_meta_graph(args.netFile)
    if saver is None:
        # import_meta_graph returns None when the MetaGraph holds no variables,
        # in which case there is nothing to restore or extract.
        raise SystemExit('No variables found in MetaGraph file {0}; nothing to extract.'
                         .format(args.netFile))

    with tf.Session() as sess:
        # Restore variable values from the checkpoint into the imported graph.
        saver.restore(sess, args.modelFile)
        print('Model restored.')

        # Save trainable variables to numpy arrays
        for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            varname = sanitize_variable_name(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))

            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))

            # Dump as binary; np.save appends the '.npy' extension and writes
            # relative to the current working directory.
            np.save(varname, sess.run(t))


if __name__ == "__main__":
    main()
|