From a936b239cf91a9727d15ab94fbe1b08e68c685b3 Mon Sep 17 00:00:00 2001 From: James Qin Date: Tue, 20 Jun 2017 16:54:29 -0700 Subject: [PATCH] Support reuse cuDNNLSTM-trained checkpoints by multi-layer LSTM(Block)Cell * Add tensor stitching/partition in RNNParamSaveable for saving/restoring properly shaped and formatted weights/biases to share w/ LSTM(Block)Cell * Add remapped names for canonical tensors during saving. * Unittests PiperOrigin-RevId: 159634913 --- tensorflow/contrib/cudnn_rnn/BUILD | 2 + .../python/kernel_tests/cudnn_rnn_ops_test.py | 241 +++++++++++++++++- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 229 ++++++++++++++++- tensorflow/contrib/rnn/python/ops/lstm_ops.py | 12 +- tensorflow/python/ops/rnn_cell_impl.py | 5 +- 5 files changed, 465 insertions(+), 24 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index b1caac476a2..fc473d3380d 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -87,6 +87,8 @@ cuda_py_test( additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 08ec3076e49..0e51ab99353 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -20,8 +20,13 @@ from __future__ import print_function import os import unittest +import numpy as np + from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework.test_util import TensorFlowTestCase @@ -29,10 +34,14 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib @@ -69,7 +78,8 @@ class CudnnRNNTest(TensorFlowTestCase): model: a CudnnRNN model. """ params_saveable = cudnn_rnn_ops.RNNParamsSaveable( - model.params_to_canonical, model.canonical_to_params, [params]) + model, model.params_to_canonical, model.canonical_to_params, [params], + "rnn") ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) def _testSaveRestoreVariable(self, rnn_mode): @@ -93,6 +103,218 @@ class CudnnRNNTest(TensorFlowTestCase): params_v_restored = sess.run(params) self.assertAllEqual(params_v, params_v_restored) + def _create_equivalent_canonical_rnn(self, + cudnn_model, + inputs, + use_block_cell, + scope="rnn"): + if cudnn_model.rnn_mode is not "lstm": + raise ValueError("%s is not supported!" 
% cudnn_model.rnn_mode) + + num_units = cudnn_model.num_units + num_layers = cudnn_model.num_layers + + # To reuse cuDNN-trained models, must set + # forget_bias, clip_cell = 0, False + # In LSTMCell and LSTMBlockCell, forget_bias is added in addition to learned + # bias, whereas cuDNN does not apply the additional bias. + if use_block_cell: + # pylint: disable=g-long-lambda + single_cell = lambda: lstm_ops.LSTMBlockCell(num_units, forget_bias=0, + clip_cell=False) + # pylint: enable=g-long-lambda + else: + single_cell = lambda: rnn_cell_impl.LSTMCell(num_units, forget_bias=0) + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + + def _build_forward_cudnn_model(self, + rnn_mode, + num_layers, + num_units, + input_data, + is_training=False): + input_data_shape = input_data.get_shape().with_rank(3) + batch_size = input_data_shape[1].value + input_size = input_data_shape[2].value + model = self._CreateModel(rnn_mode, num_layers, num_units, input_size) + + # Set zero init input states + input_h = constant_op.constant( + np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) + has_input_c = (rnn_mode == "lstm") + if has_input_c: + input_c = constant_op.constant( + np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) + + # Set rnn params + params_size_t = model.params_size() + params = variables.Variable( + random_ops.random_uniform([params_size_t]), validate_shape=False) + args = { + "input_data": input_data, + "input_h": input_h, + "params": params, + "is_training": is_training + } + if has_input_c: + args["input_c"] = input_c + # Build cell + output_tuple = model(**args) + + # Create savable objects for params + self._create_params_savable(params, model) + + return output_tuple, model, params + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCheckpointReusableByCanonicalLSTMCells(self): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + "rnn_mode": "lstm" + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + "rnn_mode": "lstm" + }, + { + "num_layers": 2, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + "rnn_mode": "lstm" + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + "rnn_mode": "lstm" + }, + ] + for cfg in configs: + self._testCheckpointReusableByCanonicalLSTMCells( + cfg["num_layers"], + cfg["seq_length"], + cfg["num_units"], + cfg["input_size"], + cfg["batch_size"], + cfg["rnn_mode"], + use_block_cell=False) + self._testCheckpointReusableByCanonicalLSTMCells( + cfg["num_layers"], + cfg["seq_length"], + cfg["num_units"], + cfg["input_size"], + cfg["batch_size"], + cfg["rnn_mode"], + use_block_cell=True) + + def _testCheckpointReusableByCanonicalLSTMCells( + self, num_layers, seq_length, num_units, input_size, batch_size, rnn_mode, + use_block_cell): + np.random.seed(0) + # Train graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + input_data = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model( + rnn_mode, num_layers, num_units, input_data, is_training=True) + target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None) + total_sum = 
sum(map(math_ops.reduce_sum, output_tuple)) + + loss_op = losses.log_loss(labels=target_output, predictions=total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + # Train Cudnn model + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + targets = np.random.rand() + sess.run( + train_op, feed_dict={input_data: inputs, + target_output: targets}) + + save_path = os.path.join(self.get_temp_dir(), + ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + cudnn_params_v = sess.run(cudnn_params) + + # cuDNN inference graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + cudnn_inputs = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + (cudnn_output_tuple, cudnn_model, + cudnn_params) = self._build_forward_cudnn_model( + rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False) + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + restored_cudnn_params_v = sess.run(cudnn_params) + self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v) + + # Cudnn inference + (cudnn_output, cudnn_output_h, cudnn_output_c) = sess.run( + cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input}) + + # LSTMBlockCell inference graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + cell_inputs = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + (output, states) = self._create_equivalent_canonical_rnn( + cudnn_model, cell_inputs, use_block_cell) + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + output_v, states_v = sess.run( + [output, states], feed_dict={cell_inputs: inference_input}) + + # output across timestamps are packed into one tensor. 
+ self.assertAllClose(cudnn_output, output_v, atol=1e-6, rtol=1e-6) + + for i in range(num_layers): + # output_h + self.assertAllClose( + cudnn_output_h[i, :], states_v[i].h, atol=1e-6, rtol=1e-6) + # output_c + self.assertAllClose( + cudnn_output_c[i, :], states_v[i].c, atol=1e-6, rtol=1e-6) + def _testSaveRestoreOutput(self, rnn_mode): num_layers = 2 num_units = 7 @@ -187,9 +409,13 @@ class CudnnRNNTest(TensorFlowTestCase): batch_size, seq_length, dir_count, dropout, expected, tolerance): random_seed.set_random_seed(5678) - model = self._CreateModel(rnn_mode, num_layers, num_units, input_size, - input_mode="auto_select", - dropout=dropout) + model = self._CreateModel( + rnn_mode, + num_layers, + num_units, + input_size, + input_mode="auto_select", + dropout=dropout) has_input_c = (rnn_mode == "lstm") params_size_t = model.params_size() input_data = array_ops.ones([seq_length, batch_size, input_size]) @@ -216,7 +442,7 @@ class CudnnRNNTest(TensorFlowTestCase): if has_input_c: output_c_sum = math_ops.reduce_sum(output_c) total_sum += output_c_sum - with self.test_session(use_gpu=True) as sess: + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: sess.run(variables.global_variables_initializer()) total_sum_v = sess.run([total_sum]) @@ -310,8 +536,8 @@ class CudnnRNNTest(TensorFlowTestCase): os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) has_input_c = (rnn_mode == "lstm") random_seed.set_random_seed(1234) - model = self._CreateModel(rnn_mode, num_layers, num_units, input_size, - dropout=dropout) + model = self._CreateModel( + rnn_mode, num_layers, num_units, input_size, dropout=dropout) params_size_t = model.params_size() input_data = variables.Variable( random_ops.random_uniform([seq_length, batch_size, input_size])) @@ -417,6 +643,7 @@ class CudnnRNNTest(TensorFlowTestCase): }, }, ] + ops.reset_default_graph() with ops.Graph().as_default(): for config in test_configs: rnn_mode = config["rnn_mode"] diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index cc0c7b08296..0437467f3fb 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -16,7 +16,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops from tensorflow.contrib.util import loader @@ -46,9 +45,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): """SaveableObject implementation that handles the RNN params variable.""" def __init__(self, + cudnn_rnn, params_to_canonical, canonical_to_params, param_variables, + base_variable_scope=None, name="params_canonical"): """Creates a RNNParamsSaveable object. @@ -75,6 +76,7 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. Args: + cudnn_rnn: cudnn RNN class instance. params_to_canonical: a function to convert params from a specific format for cuDNN or other RNN ops to the canonical format. _CudnnRNN.params_to_canonical() should be provided here. @@ -87,25 +89,42 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): For cuDNN RNN ops, this is a single merged variable for both weights and biases; for other RNN ops, this might be multiple unmerged or partially merged variables respectively for weights and biases. 
+ base_variable_scope: a string, name of outer variable scope, used as + part of prefix of names of saved variables. name: the name of the RNNParamsSaveable object. """ # There is only a single merged parameter variable for cuDNN when saving. + self._cudnn_rnn = cudnn_rnn weights, biases = params_to_canonical(param_variables[0]) + weights, biases, = self._transform_canonical(weights, biases) + weight_names, biase_names = self._transformed_canonical_names( + weights, biases) self._canonical_to_params = canonical_to_params self._variables = param_variables # We currently don't use slice_spec. It might be useful in a distributed # setting where each parameter server node stores a slice of variable, # instead of having the master pull all slices and then save them. slice_spec = "" + params = weights + biases + param_names = weight_names + biase_names + if base_variable_scope: + param_names = ["%s/%s" % (base_variable_scope, pn) for pn in param_names] specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param.name) - for param in itertools.chain(weights, biases) + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, param_names) ] super(RNNParamsSaveable, self).__init__(None, specs, name) def restore(self, restored_tensors, restored_shapes): - weights = restored_tensors[:len(restored_tensors) // 2] - biases = restored_tensors[len(restored_tensors) // 2:] + if (self._cudnn_rnn.direction == "unidirectional" and + self._cudnn_rnn.rnn_mode == "lstm"): + assert len(restored_tensors) % 4 == 0 + weights = restored_tensors[:len(restored_tensors) // 4] + biases = restored_tensors[len(restored_tensors) // 4:] + else: + weights = restored_tensors[:len(restored_tensors) // 2] + biases = restored_tensors[len(restored_tensors) // 2:] + weights, biases = self._untransform_canonical(weights, biases) params = self._canonical_to_params(weights, biases) if not isinstance(params, tuple): params = (params,) @@ -115,6 +134,159 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): ] return control_flow_ops.group(*assign_ops) + def _switch_inner(self, array, base_idx): + array[base_idx + 1], array[base_idx + 2] = (array[base_idx + 2], + array[base_idx + 1]) + + def _transform_canonical(self, weights, biases): + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + return weights, biases + return self._transform_lstm_canonical(weights, biases) + + def _transformed_canonical_names(self, weights, biases): + """Return canonical names for fused weight and bias tensors.""" + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + assert len(weights) == len(biases) + return ([w.name for w in weights], [b.name for b in biases]) + else: + w_names, b_names = [], [] + assert len(weights) * 3 == len(biases) + num_layers = self._cudnn_rnn.num_layers + # TODO(jamesqin): get rid of multi_rnn_cell when num_layers is 1 + for i in range(num_layers): + # One fused weight tensor each layer. + w_names.append("multi_rnn_cell/cell_%d/lstm_cell/kernel" % i) + # Three fused bias tensors each layer: + # the 1st is for LSTMBlockCell restore; the latter two sum up to the + # 1st, and are used for cuDNN restore. + b_names.append("multi_rnn_cell/cell_%d/lstm_cell/bias" % i) + b_names.extend([ + "multi_rnn_cell/cell_%d/lstm_cell/bias_cudnn_%d" % (i, j) + for j in range(2) + ]) + return w_names, b_names + + def _transform_lstm_canonical(self, weights, biases): + """Create fused lstm canonical params. 
+ + Produce properly-shaped monolithic weight and bias tensors to share between + cuDNN and non-platform specific LSTM cells (w/o peephole). + Args: + weights: a list of Tensors recovered from cuDNN params_to_canonical. + biases: a list of Tensors recovered from cuDNN params_to_canonical. + Returns: + Two lists of tensors, one for weight and bias each. + The weight list contains num_layers tensors and bias one contains 3 * + num_layers tensors. Both original and combined biases since cuDNN biases + are not restorable from the fused version. + """ + transformed_weights, transformed_biases = [], [] + for i in range(self._cudnn_rnn.num_layers): + base_idx = i * 8 + num_units = self._cudnn_rnn.num_units + input_size = self._cudnn_rnn.input_size if i == 0 else num_units + # cuDNN tensor shapes per time_step: + # input.shape: [batch_size, input_size], + # input_weights.shape: [num_units, input_size] (first layer) + # [num_units, num_units] (other layers) + # state_weights.shape: [num_units, num_units] + # biases.shape: [num_units] + # + # General LSTM cells compute gate functions using: + # [x, h_prev] * weights + biases + # Therefore for each layer, they expect + # weight.shape: [input_size + num_units, 4 * num_units] (first_layer) + # [num_units + num_units, 4 * num_units] (other layers) + # bias.shape: [4 * num_units] + + # Stitch weights together in this layer. + stitched_w = [] + for j in range(4): + stitched_w.append( + array_ops.concat( + [ + array_ops.reshape(weights[base_idx + j], + [num_units, input_size]), + array_ops.reshape(weights[base_idx + j + 4], + [num_units, num_units]) + ], + axis=1)) + # cuDNN weights are in ifco order, convert to icfo order. + self._switch_inner(stitched_w, 0) + transformed_weights.append( + array_ops.transpose(array_ops.concat(stitched_w, axis=0))) + + # Stitch biases together in this layer. + # Convert to icfo order. + self._switch_inner(biases, base_idx) + self._switch_inner(biases, base_idx + 4) + # The bias for layer input. + b_in = array_ops.concat(biases[base_idx:base_idx + 4], axis=0) + # The bias for recurrent input. + b_rec = array_ops.concat(biases[base_idx + 4:base_idx + 8], axis=0) + + transformed_biases.extend([b_in + b_rec, b_in, b_rec]) + return transformed_weights, transformed_biases + + def _untransform_canonical(self, transformed_weights, transformed_biases): + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + return transformed_weights, transformed_biases + return self._untransform_lstm_canonical(transformed_weights, + transformed_biases) + + def _untransform_lstm_canonical(self, transformed_weights, + transformed_biases): + """The reverse procedure of _transform_lstm_canonical(). + + Args: + transformed_weights: a list of tensors, one for each layer. + transformed_biases: a list of tensors , 3 for each layer: the 2nd for + layer input, the 3rd for recurrent input, the 1st is the sum of the + latter two. + Returns: + Two lists of tensors for weights and biases respectively. + There are 8 tensors per weight and per bias for each layer: + tensor 0-3 are applied to the input from the previous layer; + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. 
+ """ + weights, biases = [], [] + assert 3 * len(transformed_weights) == len(transformed_biases) + for i in range(len(transformed_weights)): + num_units = self._cudnn_rnn.num_units + input_size = self._cudnn_rnn.input_size if i == 0 else num_units + # weights applied on layer inputs. + wi = array_ops.slice(transformed_weights[i], [0, 0], + [input_size, 4 * num_units]) + # weights applied on recurrent inputs. + wr = array_ops.slice(transformed_weights[i], [input_size, 0], + [num_units, 4 * num_units]) + wi_list = array_ops.split(wi, 4, axis=1) + wr_list = array_ops.split(wr, 4, axis=1) + + for j in range(len(wi_list)): + wi_list[j] = array_ops.reshape(array_ops.transpose(wi_list[j]), [-1]) + wr_list[j] = array_ops.reshape(array_ops.transpose(wr_list[j]), [-1]) + # canonical weights are in icfo order, convert to ifco order for cuDNN. + self._switch_inner(wi_list, 0) + self._switch_inner(wr_list, 0) + weights.extend(wi_list) + weights.extend(wr_list) + + base_idx = 3 * i + bi_list = array_ops.split(transformed_biases[base_idx + 1], 4, axis=0) + br_list = array_ops.split(transformed_biases[base_idx + 2], 4, axis=0) + # canonical weights are in icfo order, convert to ifco order for cuDNN. + self._switch_inner(bi_list, 0) + self._switch_inner(br_list, 0) + biases.extend(bi_list) + biases.extend(br_list) + return weights, biases + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and @@ -199,6 +371,26 @@ class _CudnnRNN(object): if self._seed is None and self._seed2 is None: self._seed, self._seed2 = 0, 0 + @property + def input_size(self): + return self._input_size + + @property + def num_units(self): + return self._num_units + + @property + def num_layers(self): + return self._num_layers + + @property + def rnn_mode(self): + return self._rnn_mode + + @property + def direction(self): + return self._direction + def params_size(self): """Calculates the size of the opaque parameter buffer needed for this model. @@ -222,9 +414,12 @@ class _CudnnRNN(object): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. - input_h: the initial hidden state for h. + input_data: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. + A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. @@ -308,7 +503,7 @@ class CudnnLSTM(_CudnnRNN): num_layers, num_units, input_size, - input_mode="auto_select", + input_mode="linear_input", direction="unidirectional", dropout=0., seed=0): @@ -344,9 +539,12 @@ class CudnnLSTM(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. - input_h: the initial hidden state for h. - input_c: the initial hidden state for c. + input_data: the input sequence to the LSTM model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + input_c: the initial hidden state for c. A Tensor of the same shape as + input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. 
@@ -368,7 +566,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): num_layers, num_units, input_size, - input_mode="auto_select", + input_mode="linear_input", direction="unidirectional", dropout=0., seed=0): @@ -390,6 +588,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. """ + super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, num_layers, @@ -404,8 +603,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. - input_h: the initial hidden state for h. + input_data: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index c41b5793fc9..97b9dcc905d 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -58,7 +58,7 @@ def _lstm_block_cell(x, ```python xh = [x, h_prev] - [i, f, ci, o] = xh * w + b + [i, ci, f, o] = xh * w + b f = f + forget_bias if not use_peephole: @@ -93,7 +93,7 @@ def _lstm_block_cell(x, The weight matrix for output gate peephole connection. forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. cell_clip: An optional `float`. Defaults to `3`. - Value to clip the 'cs' value to. + Value to clip the 'cs' value to. Disable by setting to negative value. use_peephole: An optional `bool`. Defaults to `False`. Whether to use peephole weights. name: A name for the operation (optional). @@ -341,17 +341,24 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): def __init__(self, num_units, forget_bias=1.0, + clip_cell=True, use_peephole=False): """Initialize the basic LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). + clip_cell: boolean, whether to apply cell clipping. See + `_lstm_block_cell()` for details. use_peephole: Whether to use peephole connections or not. + + When restoring from CudnnLSTM-trained checkpoints, must set the following: + forget_bias, clip_cell, use_peephole = 0, False, False """ self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole + self._clip_cell = clip_cell self._names = { "W": "kernel", "b": "bias", @@ -400,6 +407,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): wco=wco, wcf=wcf, forget_bias=self._forget_bias, + cell_clip=None if self._clip_cell else -1, use_peephole=self._use_peephole) new_state = rnn_cell_impl.LSTMStateTuple(cs, h) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 49a4aba4735..ca69cddae25 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -345,6 +345,8 @@ class BasicLSTMCell(RNNCell): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. 
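
To make the `forget_bias` / `clip_cell` requirements above concrete, here is a rough sketch of the restore side using `LSTMBlockCell`, mirroring `_create_equivalent_canonical_rnn` in the test. The checkpoint path and sizes are placeholders; this is not part of the patch.

```python
# Illustrative sketch (not part of the patch): restoring a CudnnLSTM-trained
# checkpoint into a stack of LSTMBlockCells.
import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import lstm_ops

num_layers, num_units, input_size = 2, 4, 5
seq_length, batch_size = 3, 6

inputs = tf.placeholder(tf.float32, [seq_length, batch_size, input_size])

# forget_bias=0 and clip_cell=False so the cell math matches cuDNN exactly.
cells = [lstm_ops.LSTMBlockCell(num_units, forget_bias=0.0, clip_cell=False)
         for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)

# scope="rnn" must match the base_variable_scope passed to RNNParamsSaveable.
outputs, states = tf.nn.dynamic_rnn(
    cell, inputs, dtype=tf.float32, time_major=True, scope="rnn")

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, "/tmp/cudnn-rnn-lstm-test")  # placeholder path
```
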
@@ -444,7 +446,8 @@ class LSTMCell(RNNCell):
         Use a variable_scope partitioner instead.
       forget_bias: Biases of the forget gate are initialized by default to 1
         in order to reduce the scale of forgetting at the beginning of
-        the training.
+        the training. Must be set to `0.0` manually when restoring from
+        CudnnLSTM-trained checkpoints.
       state_is_tuple: If True, accepted and returned states are 2-tuples of
         the `c_state` and `m_state`.  If False, they are concatenated
         along the column axis.  This latter behavior will soon be deprecated.
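
The same restore path works with the plain `LSTMCell` (the `use_block_cell=False` branch of the new test); only `forget_bias` needs to be zeroed, since `LSTMCell` applies no cell clipping by default. A minimal sketch with placeholder sizes, not part of the patch:

```python
# Illustrative sketch (not part of the patch): the use_block_cell=False path.
import tensorflow as tf

num_layers, num_units = 2, 4  # placeholder sizes

# Only forget_bias must be zeroed; LSTMCell does no cell clipping by default.
cells = [tf.contrib.rnn.LSTMCell(num_units, forget_bias=0.0)
         for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)
# Build tf.nn.dynamic_rnn(cell, ..., scope="rnn") and restore exactly as in
# the LSTMBlockCell sketch above.
```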