From 44cf98028b635ff3dd4145df263b0706ba663924 Mon Sep 17 00:00:00 2001
From: Shanqing Cai
Date: Thu, 4 May 2017 19:11:23 -0800
Subject: [PATCH] RNN checkpoint migration tool

Change: 155158477
---
 tensorflow/contrib/rnn/BUILD                       |  25 ++
 .../rnn/python/tools/checkpoint_convert.py         | 231 ++++++++++++++++++
 .../python/tools/checkpoint_convert_test.py        | 108 ++++++++
 .../tools/pip_package/pip_smoke_test.py            |   2 +-
 4 files changed, 365 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
 create mode 100644 tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py

diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index ab443eab6f6..9d67563eddd 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -304,6 +304,7 @@ filegroup(
         exclude = [
             "**/METADATA",
             "**/OWNERS",
+            "tools/**",
         ],
     ),
     visibility = ["//tensorflow:__subpackages__"],
@@ -351,3 +352,27 @@ tf_kernel_library(
         "//third_party/eigen3",
     ],
 )
+
+py_binary(
+    name = "checkpoint_convert",
+    srcs = ["python/tools/checkpoint_convert.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+    ],
+)
+
+py_test(
+    name = "checkpoint_convert_test",
+    size = "small",
+    srcs = ["python/tools/checkpoint_convert_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":checkpoint_convert",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
new file mode 100644
index 00000000000..1e29114b0cc
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
@@ -0,0 +1,231 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Convert checkpoints using RNNCells to new name convention.
+
+Usage:
+
+  python checkpoint_convert.py [--write_v1_checkpoint] \
+      '/path/to/checkpoint' '/path/to/new_checkpoint'
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import collections
+import re
+import sys
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import app
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saver as saver_lib
+
+_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
+    ############################################################################
+    # contrib/rnn/python/ops/core_rnn_cell_impl.py
+    # BasicRNNCell
+    ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
+    ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
+    # GRUCell
+    ('gru_cell/weights', 'gru_cell/kernel'),
+    ('gru_cell/biases', 'gru_cell/bias'),
+    ('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
+    ('gru_cell/gates/biases', 'gru_cell/gates/bias'),
+    ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
+    ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
+    # BasicLSTMCell
+    ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
+    ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
+    # LSTMCell
+    ('lstm_cell/weights', 'lstm_cell/kernel'),
+    ('lstm_cell/biases', 'lstm_cell/bias'),
+    ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
+    ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
+    # OutputProjectionWrapper
+    ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
+    ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
+    # InputProjectionWrapper
+    ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
+    ('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
+    ############################################################################
+    # contrib/rnn/python/ops/lstm_ops.py
+    # LSTMBlockFusedCell
+    ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
+    ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
+    ############################################################################
+    # contrib/rnn/python/ops/rnn_cell.py
+    # LayerNormBasicLSTMCell
+    ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
+    ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
+    # UGRNNCell
+    ('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
+    ('ugrnn_cell/biases', 'ugrnn_cell/bias'),
+    # NASCell
+    ('nas_rnn/weights', 'nas_rnn/kernel'),
+    ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
+    # IntersectionRNNCell
+    ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
+    ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
+    ('intersection_rnn_cell/in_projection/weights',
+     'intersection_rnn_cell/in_projection/kernel'),
+    ('intersection_rnn_cell/in_projection/biases',
+     'intersection_rnn_cell/in_projection/bias'),
+    # PhasedLSTMCell
+    ('phased_lstm_cell/mask_gates/weights',
+     'phased_lstm_cell/mask_gates/kernel'),
+    ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
+    ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
+    ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
+    ('phased_lstm_cell/output_gate/weights',
+     'phased_lstm_cell/output_gate/kernel'),
+    ('phased_lstm_cell/output_gate/biases',
+     'phased_lstm_cell/output_gate/bias'),
+    # AttentionCellWrapper
+    ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
+    ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
+    ('attention_cell_wrapper/attn_output_projection/weights',
+     'attention_cell_wrapper/attn_output_projection/kernel'),
+    ('attention_cell_wrapper/attn_output_projection/biases',
+     'attention_cell_wrapper/attn_output_projection/bias'),
+    ('attention_cell_wrapper/attention/weights',
+     'attention_cell_wrapper/attention/kernel'),
+    ('attention_cell_wrapper/attention/biases',
+     'attention_cell_wrapper/attention/bias'),
+])
+
+_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
+    ('LSTMCell/W_', 'lstm_cell/weights/part_'),
+    ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
+    ('GRUCell/W_', 'gru_cell/weights/part_'),
+    ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
+])
+
+
+def _rnn_name_replacement(var_name):
+  for pattern in _RNN_NAME_REPLACEMENTS:
+    if pattern in var_name:
+      old_var_name = var_name
+      var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
+      logging.info('Converted: %s --> %s' % (old_var_name, var_name))
+      break
+  return var_name
+
+
+def _rnn_name_replacement_sharded(var_name):
+  for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
+    if pattern in var_name:
+      old_var_name = var_name
+      var_name = var_name.replace(pattern,
+                                  _RNN_SHARDED_NAME_REPLACEMENTS[pattern])
+      logging.info('Converted: %s --> %s' % (old_var_name, var_name))
+  return var_name
+
+
+def _split_sharded_vars(name_shape_map):
+  """Split sharded variables.
+
+  Args:
+    name_shape_map: A dict from variable name to variable shape.
+
+  Returns:
+    not_sharded: Names of the non-sharded variables.
+    sharded: Names of the sharded variables.
+  """
+  sharded = []
+  not_sharded = []
+  for name in name_shape_map:
+    if re.search('_[0-9]+$', name):
+      if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
+        sharded.append(name)
+      else:
+        not_sharded.append(name)
+    else:
+      not_sharded.append(name)
+  return not_sharded, sharded
+
+
+def convert_names(checkpoint_from_path,
+                  checkpoint_to_path,
+                  write_v1_checkpoint=False):
+  """Migrates the names of variables within a checkpoint.
+
+  Args:
+    checkpoint_from_path: Path to source checkpoint to be read in.
+    checkpoint_to_path: Path to checkpoint to be written out.
+    write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
+
+  Returns:
+    A dictionary that maps the new variable names to the Variable objects.
+    A dictionary that maps the old variable names to the new variable names.
+  """
+  with ops.Graph().as_default():
+    logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
+    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
+    name_shape_map = reader.get_variable_to_shape_map()
+    not_sharded, sharded = _split_sharded_vars(name_shape_map)
+    new_variable_map = {}
+    conversion_map = {}
+    for var_name in not_sharded:
+      new_var_name = _rnn_name_replacement(var_name)
+      tensor = reader.get_tensor(var_name)
+      var = variables.Variable(tensor, name=var_name)
+      new_variable_map[new_var_name] = var
+      if new_var_name != var_name:
+        conversion_map[var_name] = new_var_name
+    for var_name in sharded:
+      new_var_name = _rnn_name_replacement_sharded(var_name)
+      var = variables.Variable(reader.get_tensor(var_name), name=var_name)
+      new_variable_map[new_var_name] = var
+      if new_var_name != var_name:
+        conversion_map[var_name] = new_var_name
+
+    write_version = (saver_pb2.SaverDef.V1
+                     if write_v1_checkpoint else saver_pb2.SaverDef.V2)
+    saver = saver_lib.Saver(new_variable_map, write_version=write_version)
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
+      saver.save(sess, checkpoint_to_path)
+
+  logging.info('Summary:')
+  logging.info('  Converted %d variable name(s).' % len(new_variable_map))
+  return new_variable_map, conversion_map
+
+
+def main(_):
+  convert_names(
+      FLAGS.checkpoint_from_path,
+      FLAGS.checkpoint_to_path,
+      write_v1_checkpoint=FLAGS.write_v1_checkpoint)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  parser.add_argument('checkpoint_from_path', type=str,
+                      help='Path to source checkpoint to be read in.')
+  parser.add_argument('checkpoint_to_path', type=str,
+                      help='Path to checkpoint to be written out.')
+  parser.add_argument('--write_v1_checkpoint', action='store_true',
+                      help='Write v1 checkpoint')
+  FLAGS, unparsed = parser.parse_known_args()
+
+  app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
new file mode 100644
index 00000000000..e2fc2fa80ea
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
@@ -0,0 +1,108 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for checkpoint converter."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import tempfile
+
+from tensorflow.contrib.rnn.python.tools import checkpoint_convert
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class CheckpointConvertTest(test.TestCase):
+
+  def setUp(self):
+    self._old_ckpt_path = tempfile.mktemp()
+    self._new_ckpt_path = tempfile.mktemp()
+    ops.reset_default_graph()
+
+  def tearDown(self):
+    for file_name in glob.glob(self._old_ckpt_path + "*"):
+      os.remove(file_name)
+    for file_name in glob.glob(self._new_ckpt_path + "*"):
+      os.remove(file_name)
+
+  def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
+    for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
+      new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
+      self.assertTrue(old_name)
+      self.assertTrue(new_name)
+      self.assertNotEqual(old_name, new_name)
+    for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
+      new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
+      self.assertTrue(old_name)
+      self.assertTrue(new_name)
+      self.assertNotEqual(old_name, new_name)
+
+  def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
+    variables.Variable(10.0, name="a")
+    for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
+      variables.Variable(20.0, name=old_name)
+    with session.Session() as sess:
+      saver = saver_lib.Saver()
+      sess.run(variables.global_variables_initializer())
+      saver.save(sess, self._old_ckpt_path)
+
+    new_var_map, conversion_map = checkpoint_convert.convert_names(
+        self._old_ckpt_path, self._new_ckpt_path)
+    self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
+    self.assertItemsEqual(
+        ["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
+        new_var_map.keys())
+    self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
+
+  def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
+    variables.Variable(10.0, name="a")
+    with session.Session() as sess:
+      saver = saver_lib.Saver()
+      sess.run(variables.global_variables_initializer())
+      saver.save(sess, self._old_ckpt_path)
+
+    new_var_map, conversion_map = checkpoint_convert.convert_names(
+        self._old_ckpt_path, self._new_ckpt_path)
+    self.assertItemsEqual(["a"], new_var_map.keys())
+    self.assertFalse(conversion_map)
+
+  def testConversionToV1Succeeds(self):
+    variables.Variable(10.0, name="a")
+    variables.Variable(
+        20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
+
+    with session.Session() as sess:
+      saver = saver_lib.Saver()
+      sess.run(variables.global_variables_initializer())
+      saver.save(sess, self._old_ckpt_path)
+
+    new_var_map, conversion_map = checkpoint_convert.convert_names(
+        self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
+    self.assertItemsEqual(
+        ["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
+        new_var_map.keys())
+    self.assertEqual(
+        {list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
+             list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
+        conversion_map)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 4bb5c1b73c6..459d6ee3284 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -57,7 +57,7 @@ BLACKLIST = [
     "//tensorflow/contrib/factorization/examples:mnist.py",
     "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",  # pylint:disable=line-too-long
     "//tensorflow/contrib/bayesflow:reinforce_simple_example",
-    "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py"  # pylint:disable=line-too-long
+    "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py",  # pylint:disable=line-too-long
 ]
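
Usage sketch (illustrative only, not part of the patch): with this change applied,
the converter can be driven either via the command line shown in its module
docstring or programmatically through convert_names(). The checkpoint paths below
are hypothetical placeholders.

  # Rename old-style RNN variables in a checkpoint; both paths are placeholders.
  from tensorflow.contrib.rnn.python.tools import checkpoint_convert

  new_var_map, conversion_map = checkpoint_convert.convert_names(
      '/tmp/old_rnn_model.ckpt',        # source checkpoint with old names
      '/tmp/converted_rnn_model.ckpt')  # destination checkpoint with new names

  # conversion_map maps each renamed variable's old name to its new name,
  # e.g. 'lstm_cell/weights' -> 'lstm_cell/kernel'.
  for old_name, new_name in sorted(conversion_map.items()):
    print('%s --> %s' % (old_name, new_name))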