RNN checkpoint migration tool
Change: 155158477
This commit is contained in:
parent
afd69fc26f
commit
44cf98028b
@ -304,6 +304,7 @@ filegroup(
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
"tools/**",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
@ -351,3 +352,27 @@ tf_kernel_library(
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "checkpoint_convert",
|
||||
srcs = ["python/tools/checkpoint_convert.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "checkpoint_convert_test",
|
||||
size = "small",
|
||||
srcs = ["python/tools/checkpoint_convert_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":checkpoint_convert",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
231
tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
Normal file
231
tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
Normal file
@ -0,0 +1,231 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Convert checkpoints using RNNCells to new name convention.
|
||||
|
||||
Usage:
|
||||
|
||||
python checkpoint_convert [--write_v1_checkpoint] \
|
||||
'/path/to/checkpoint' '/path/to/new_checkpoint'
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
|
||||
############################################################################
|
||||
# contrib/rnn/python/ops/core_rnn_cell_impl.py
|
||||
# BasicRNNCell
|
||||
('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
|
||||
('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
|
||||
# GRUCell
|
||||
('gru_cell/weights', 'gru_cell/kernel'),
|
||||
('gru_cell/biases', 'gru_cell/bias'),
|
||||
('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
|
||||
('gru_cell/gates/biases', 'gru_cell/gates/bias'),
|
||||
('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
|
||||
('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
|
||||
# BasicLSTMCell
|
||||
('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
|
||||
('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
|
||||
# LSTMCell
|
||||
('lstm_cell/weights', 'lstm_cell/kernel'),
|
||||
('lstm_cell/biases', 'lstm_cell/bias'),
|
||||
('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
|
||||
('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
|
||||
# OutputProjectionWrapper
|
||||
('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
|
||||
('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
|
||||
# InputProjectionWrapper
|
||||
('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
|
||||
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
|
||||
############################################################################
|
||||
# contrib/rnn/python/ops/lstm_ops.py
|
||||
# LSTMBlockFusedCell ??
|
||||
('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
|
||||
('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
|
||||
############################################################################
|
||||
# contrib/rnn/python/ops/rnn_cell.py
|
||||
# LayerNormBasicLSTMCell
|
||||
('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
|
||||
('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
|
||||
# UGRNNCell, not found in g3, but still need it?
|
||||
('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
|
||||
('ugrnn_cell/biases', 'ugrnn_cell/bias'),
|
||||
# NASCell
|
||||
('nas_rnn/weights', 'nas_rnn/kernel'),
|
||||
('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
|
||||
# IntersectionRNNCell
|
||||
('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
|
||||
('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
|
||||
('intersection_rnn_cell/in_projection/weights',
|
||||
'intersection_rnn_cell/in_projection/kernel'),
|
||||
('intersection_rnn_cell/in_projection/biases',
|
||||
'intersection_rnn_cell/in_projection/bias'),
|
||||
# PhasedLSTMCell
|
||||
('phased_lstm_cell/mask_gates/weights',
|
||||
'phased_lstm_cell/mask_gates/kernel'),
|
||||
('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
|
||||
('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
|
||||
('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
|
||||
('phased_lstm_cell/output_gate/weights',
|
||||
'phased_lstm_cell/output_gate/kernel'),
|
||||
('phased_lstm_cell/output_gate/biases',
|
||||
'phased_lstm_cell/output_gate/bias'),
|
||||
# AttentionCellWrapper
|
||||
('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
|
||||
('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
|
||||
('attention_cell_wrapper/attn_output_projection/weights',
|
||||
'attention_cell_wrapper/attn_output_projection/kernel'),
|
||||
('attention_cell_wrapper/attn_output_projection/biases',
|
||||
'attention_cell_wrapper/attn_output_projection/bias'),
|
||||
('attention_cell_wrapper/attention/weights',
|
||||
'attention_cell_wrapper/attention/kernel'),
|
||||
('attention_cell_wrapper/attention/biases',
|
||||
'attention_cell_wrapper/attention/bias'),
|
||||
])
|
||||
|
||||
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
|
||||
('LSTMCell/W_', 'lstm_cell/weights/part_'),
|
||||
('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
|
||||
('GRUCell/W_', 'gru_cell/weights/part_'),
|
||||
('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
|
||||
])
|
||||
|
||||
|
||||
def _rnn_name_replacement(var_name):
|
||||
for pattern in _RNN_NAME_REPLACEMENTS:
|
||||
if pattern in var_name:
|
||||
old_var_name = var_name
|
||||
var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
|
||||
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
|
||||
break
|
||||
return var_name
|
||||
|
||||
|
||||
def _rnn_name_replacement_sharded(var_name):
|
||||
for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
|
||||
if pattern in var_name:
|
||||
old_var_name = var_name
|
||||
var_name = var_name.replace(pattern,
|
||||
_RNN_SHARDED_NAME_REPLACEMENTS[pattern])
|
||||
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
|
||||
return var_name
|
||||
|
||||
|
||||
def _split_sharded_vars(name_shape_map):
|
||||
"""Split shareded variables.
|
||||
|
||||
Args:
|
||||
name_shape_map: A dict from variable name to variable shape.
|
||||
|
||||
Returns:
|
||||
not_sharded: Names of the non-sharded variables.
|
||||
sharded: Names of the sharded varibales.
|
||||
"""
|
||||
sharded = []
|
||||
not_sharded = []
|
||||
for name in name_shape_map:
|
||||
if re.match(name, '_[0-9]+$'):
|
||||
if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
|
||||
sharded.append(name)
|
||||
else:
|
||||
not_sharded.append(name)
|
||||
else:
|
||||
not_sharded.append(name)
|
||||
return not_sharded, sharded
|
||||
|
||||
|
||||
def convert_names(checkpoint_from_path,
|
||||
checkpoint_to_path,
|
||||
write_v1_checkpoint=False):
|
||||
"""Migrates the names of variables within a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_from_path: Path to source checkpoint to be read in.
|
||||
checkpoint_to_path: Path to checkpoint to be written out.
|
||||
write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps the new variable names to the Variable objects.
|
||||
A dictionary that maps the old variable names to the new variable names.
|
||||
"""
|
||||
with ops.Graph().as_default():
|
||||
logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
|
||||
name_shape_map = reader.get_variable_to_shape_map()
|
||||
not_sharded, sharded = _split_sharded_vars(name_shape_map)
|
||||
new_variable_map = {}
|
||||
conversion_map = {}
|
||||
for var_name in not_sharded:
|
||||
new_var_name = _rnn_name_replacement(var_name)
|
||||
tensor = reader.get_tensor(var_name)
|
||||
var = variables.Variable(tensor, name=var_name)
|
||||
new_variable_map[new_var_name] = var
|
||||
if new_var_name != var_name:
|
||||
conversion_map[var_name] = new_var_name
|
||||
for var_name in sharded:
|
||||
new_var_name = _rnn_name_replacement_sharded(var_name)
|
||||
var = variables.Variable(tensor, name=var_name)
|
||||
new_variable_map[new_var_name] = var
|
||||
if new_var_name != var_name:
|
||||
conversion_map[var_name] = new_var_name
|
||||
|
||||
write_version = (saver_pb2.SaverDef.V1
|
||||
if write_v1_checkpoint else saver_pb2.SaverDef.V2)
|
||||
saver = saver_lib.Saver(new_variable_map, write_version=write_version)
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
|
||||
saver.save(sess, checkpoint_to_path)
|
||||
|
||||
logging.info('Summary:')
|
||||
logging.info(' Converted %d variable name(s).' % len(new_variable_map))
|
||||
return new_variable_map, conversion_map
|
||||
|
||||
|
||||
def main(_):
|
||||
convert_names(
|
||||
FLAGS.checkpoint_from_path,
|
||||
FLAGS.checkpoint_to_path,
|
||||
write_v1_checkpoint=FLAGS.write_v1_checkpoint)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register('type', 'bool', lambda v: v.lower() == 'true')
|
||||
parser.add_argument('checkpoint_from_path', type=str,
|
||||
help='Path to source checkpoint to be read in.')
|
||||
parser.add_argument('checkpoint_to_path', type=str,
|
||||
help='Path to checkpoint to be written out.')
|
||||
parser.add_argument('--write_v1_checkpoint', action='store_true',
|
||||
help='Write v1 checkpoint')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
108
tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
Normal file
108
tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unit tests for checkpoint converter."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from tensorflow.contrib.rnn.python.tools import checkpoint_convert
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
class CheckpointConvertTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._old_ckpt_path = tempfile.mktemp()
|
||||
self._new_ckpt_path = tempfile.mktemp()
|
||||
ops.reset_default_graph()
|
||||
|
||||
def tearDown(self):
|
||||
for file_name in glob.glob(self._old_ckpt_path + "*"):
|
||||
os.remove(file_name)
|
||||
for file_name in glob.glob(self._new_ckpt_path + "*"):
|
||||
os.remove(file_name)
|
||||
|
||||
def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
|
||||
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
|
||||
new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
|
||||
self.assertTrue(old_name)
|
||||
self.assertTrue(new_name)
|
||||
self.assertNotEqual(old_name, new_name)
|
||||
for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
|
||||
new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
|
||||
self.assertTrue(old_name)
|
||||
self.assertTrue(new_name)
|
||||
self.assertNotEqual(old_name, new_name)
|
||||
|
||||
def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
|
||||
variables.Variable(10.0, name="a")
|
||||
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
|
||||
variables.Variable(20.0, name=old_name)
|
||||
with session.Session() as sess:
|
||||
saver = saver_lib.Saver()
|
||||
sess.run(variables.global_variables_initializer())
|
||||
saver.save(sess, self._old_ckpt_path)
|
||||
|
||||
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||
self._old_ckpt_path, self._new_ckpt_path)
|
||||
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
|
||||
self.assertItemsEqual(
|
||||
["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
|
||||
new_var_map.keys())
|
||||
self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
|
||||
|
||||
def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
|
||||
variables.Variable(10.0, name="a")
|
||||
with session.Session() as sess:
|
||||
saver = saver_lib.Saver()
|
||||
sess.run(variables.global_variables_initializer())
|
||||
saver.save(sess, self._old_ckpt_path)
|
||||
|
||||
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||
self._old_ckpt_path, self._new_ckpt_path)
|
||||
self.assertItemsEqual(["a"], new_var_map.keys())
|
||||
self.assertFalse(conversion_map)
|
||||
|
||||
def testConversionToV1Succeeds(self):
|
||||
variables.Variable(10.0, name="a")
|
||||
variables.Variable(
|
||||
20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
|
||||
|
||||
with session.Session() as sess:
|
||||
saver = saver_lib.Saver()
|
||||
sess.run(variables.global_variables_initializer())
|
||||
saver.save(sess, self._old_ckpt_path)
|
||||
|
||||
new_var_map, conversion_map = checkpoint_convert.convert_names(
|
||||
self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
|
||||
self.assertItemsEqual(
|
||||
["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
|
||||
new_var_map.keys())
|
||||
self.assertEqual(
|
||||
{list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
|
||||
list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
|
||||
conversion_map)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -57,7 +57,7 @@ BLACKLIST = [
|
||||
"//tensorflow/contrib/factorization/examples:mnist.py",
|
||||
"//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long
|
||||
"//tensorflow/contrib/bayesflow:reinforce_simple_example",
|
||||
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py" # pylint:disable=line-too-long
|
||||
"//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user