RNN checkpoint migration tool

Change: 155158477
This commit is contained in:
Shanqing Cai 2017-05-04 19:11:23 -08:00 committed by TensorFlower Gardener
parent afd69fc26f
commit 44cf98028b
4 changed files with 365 additions and 1 deletions

View File

@ -304,6 +304,7 @@ filegroup(
exclude = [
"**/METADATA",
"**/OWNERS",
"tools/**",
],
),
visibility = ["//tensorflow:__subpackages__"],
@ -351,3 +352,27 @@ tf_kernel_library(
"//third_party/eigen3",
],
)
py_binary(
name = "checkpoint_convert",
srcs = ["python/tools/checkpoint_convert.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
py_test(
name = "checkpoint_convert_test",
size = "small",
srcs = ["python/tools/checkpoint_convert_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":checkpoint_convert",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,231 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints using RNNCells to new name convention.
Usage:
python checkpoint_convert [--write_v1_checkpoint] \
'/path/to/checkpoint' '/path/to/new_checkpoint'
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import re
import sys
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
############################################################################
# contrib/rnn/python/ops/core_rnn_cell_impl.py
# BasicRNNCell
('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
# GRUCell
('gru_cell/weights', 'gru_cell/kernel'),
('gru_cell/biases', 'gru_cell/bias'),
('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
('gru_cell/gates/biases', 'gru_cell/gates/bias'),
('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
# BasicLSTMCell
('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
# LSTMCell
('lstm_cell/weights', 'lstm_cell/kernel'),
('lstm_cell/biases', 'lstm_cell/bias'),
('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
# OutputProjectionWrapper
('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
# InputProjectionWrapper
('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/lstm_ops.py
# LSTMBlockFusedCell ??
('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/rnn_cell.py
# LayerNormBasicLSTMCell
('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
# UGRNNCell, not found in g3, but still need it?
('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
('ugrnn_cell/biases', 'ugrnn_cell/bias'),
# NASCell
('nas_rnn/weights', 'nas_rnn/kernel'),
('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
# IntersectionRNNCell
('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
('intersection_rnn_cell/in_projection/weights',
'intersection_rnn_cell/in_projection/kernel'),
('intersection_rnn_cell/in_projection/biases',
'intersection_rnn_cell/in_projection/bias'),
# PhasedLSTMCell
('phased_lstm_cell/mask_gates/weights',
'phased_lstm_cell/mask_gates/kernel'),
('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
('phased_lstm_cell/output_gate/weights',
'phased_lstm_cell/output_gate/kernel'),
('phased_lstm_cell/output_gate/biases',
'phased_lstm_cell/output_gate/bias'),
# AttentionCellWrapper
('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
('attention_cell_wrapper/attn_output_projection/weights',
'attention_cell_wrapper/attn_output_projection/kernel'),
('attention_cell_wrapper/attn_output_projection/biases',
'attention_cell_wrapper/attn_output_projection/bias'),
('attention_cell_wrapper/attention/weights',
'attention_cell_wrapper/attention/kernel'),
('attention_cell_wrapper/attention/biases',
'attention_cell_wrapper/attention/bias'),
])
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
('LSTMCell/W_', 'lstm_cell/weights/part_'),
('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
('GRUCell/W_', 'gru_cell/weights/part_'),
('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
])
def _rnn_name_replacement(var_name):
for pattern in _RNN_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
break
return var_name
def _rnn_name_replacement_sharded(var_name):
for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern,
_RNN_SHARDED_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
return var_name
def _split_sharded_vars(name_shape_map):
"""Split shareded variables.
Args:
name_shape_map: A dict from variable name to variable shape.
Returns:
not_sharded: Names of the non-sharded variables.
sharded: Names of the sharded varibales.
"""
sharded = []
not_sharded = []
for name in name_shape_map:
if re.match(name, '_[0-9]+$'):
if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
sharded.append(name)
else:
not_sharded.append(name)
else:
not_sharded.append(name)
return not_sharded, sharded
def convert_names(checkpoint_from_path,
checkpoint_to_path,
write_v1_checkpoint=False):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with ops.Graph().as_default():
logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
not_sharded, sharded = _split_sharded_vars(name_shape_map)
new_variable_map = {}
conversion_map = {}
for var_name in not_sharded:
new_var_name = _rnn_name_replacement(var_name)
tensor = reader.get_tensor(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
for var_name in sharded:
new_var_name = _rnn_name_replacement_sharded(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
write_version = (saver_pb2.SaverDef.V1
if write_v1_checkpoint else saver_pb2.SaverDef.V2)
saver = saver_lib.Saver(new_variable_map, write_version=write_version)
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
saver.save(sess, checkpoint_to_path)
logging.info('Summary:')
logging.info(' Converted %d variable name(s).' % len(new_variable_map))
return new_variable_map, conversion_map
def main(_):
convert_names(
FLAGS.checkpoint_from_path,
FLAGS.checkpoint_to_path,
write_v1_checkpoint=FLAGS.write_v1_checkpoint)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument('checkpoint_from_path', type=str,
help='Path to source checkpoint to be read in.')
parser.add_argument('checkpoint_to_path', type=str,
help='Path to checkpoint to be written out.')
parser.add_argument('--write_v1_checkpoint', action='store_true',
help='Write v1 checkpoint')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,108 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for checkpoint converter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import tempfile
from tensorflow.contrib.rnn.python.tools import checkpoint_convert
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class CheckpointConvertTest(test.TestCase):
def setUp(self):
self._old_ckpt_path = tempfile.mktemp()
self._new_ckpt_path = tempfile.mktemp()
ops.reset_default_graph()
def tearDown(self):
for file_name in glob.glob(self._old_ckpt_path + "*"):
os.remove(file_name)
for file_name in glob.glob(self._new_ckpt_path + "*"):
os.remove(file_name)
def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
variables.Variable(20.0, name=old_name)
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
self.assertItemsEqual(
["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
new_var_map.keys())
self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertItemsEqual(["a"], new_var_map.keys())
self.assertFalse(conversion_map)
def testConversionToV1Succeeds(self):
variables.Variable(10.0, name="a")
variables.Variable(
20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
self.assertItemsEqual(
["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
new_var_map.keys())
self.assertEqual(
{list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
conversion_map)
if __name__ == "__main__":
test.main()

View File

@ -57,7 +57,7 @@ BLACKLIST = [
"//tensorflow/contrib/factorization/examples:mnist.py",
"//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long
"//tensorflow/contrib/bayesflow:reinforce_simple_example",
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py" # pylint:disable=line-too-long
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
]