Support reusing CudnnLSTM-trained checkpoints with multi-layer LSTM(Block)Cell

* Add tensor stitching/partitioning in RNNParamsSaveable for saving/restoring properly shaped and formatted weights/biases that can be shared with LSTM(Block)Cell
* Add remapped names for canonical tensors during saving.
* Unittests

PiperOrigin-RevId: 159634913
Authored by James Qin on 2017-06-20 16:54:29 -07:00, committed by TensorFlower Gardener
parent 4be287671a
commit a936b239cf
5 changed files with 465 additions and 24 deletions
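For orientation, the sketch below shows the round trip this change enables: train and save with CudnnLSTM through the updated RNNParamsSaveable, then restore the same checkpoint into a stack of LSTMBlockCells. It is a minimal, hedged illustration assembled from the test code in this diff; the save path, layer sizes, and training loop are placeholders, and it assumes the TF 1.x contrib APIs touched here.

```python
# Minimal sketch of the checkpoint-reuse flow exercised by the new unittests.
# Assumptions: TF 1.x contrib APIs as modified in this change; `save_path`,
# layer sizes, and the training loop are placeholders.
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops

num_layers, num_units, input_size = 2, 4, 8
save_path = "/tmp/cudnn-lstm-ckpt"  # placeholder

# Phase 1: train with CudnnLSTM, whose weights/biases live in one opaque buffer.
with tf.Graph().as_default():
  model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, input_size)
  params = tf.Variable(
      tf.random_uniform([model.params_size()]), validate_shape=False)
  # New RNNParamsSaveable signature: pass the model and the outer variable
  # scope ("rnn") so the saved tensors get LSTM(Block)Cell-compatible names.
  saveable = cudnn_rnn_ops.RNNParamsSaveable(
      model, model.params_to_canonical, model.canonical_to_params,
      [params], "rnn")
  tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... run training steps on model(...) outputs here ...
    saver.save(sess, save_path)

# Phase 2: restore the same checkpoint into a canonical LSTM stack.
with tf.Graph().as_default():
  cells = [lstm_ops.LSTMBlockCell(num_units, forget_bias=0.0, clip_cell=False)
           for _ in range(num_layers)]
  inputs = tf.placeholder(tf.float32, [None, None, input_size])  # time-major
  outputs, state = tf.nn.dynamic_rnn(
      tf.contrib.rnn.MultiRNNCell(cells), inputs,
      dtype=tf.float32, time_major=True, scope="rnn")
  saver = tf.train.Saver()
  with tf.Session() as sess:
    # Maps the cuDNN buffer onto rnn/multi_rnn_cell/cell_*/lstm_cell/{kernel,bias}.
    saver.restore(sess, save_path)
```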


@@ -87,6 +87,8 @@ cuda_py_test(
additional_deps = [
":cudnn_rnn_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/python/ops/losses:losses",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",


@@ -20,8 +20,13 @@ from __future__ import print_function
import os
import unittest
import numpy as np
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework.test_util import TensorFlowTestCase
@@ -29,10 +34,14 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib
@@ -69,7 +78,8 @@ class CudnnRNNTest(TensorFlowTestCase):
model: a CudnnRNN model.
"""
params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
model.params_to_canonical, model.canonical_to_params, [params])
model, model.params_to_canonical, model.canonical_to_params, [params],
"rnn")
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
def _testSaveRestoreVariable(self, rnn_mode):
@@ -93,6 +103,218 @@ class CudnnRNNTest(TensorFlowTestCase):
params_v_restored = sess.run(params)
self.assertAllEqual(params_v, params_v_restored)
def _create_equivalent_canonical_rnn(self,
cudnn_model,
inputs,
use_block_cell,
scope="rnn"):
if cudnn_model.rnn_mode != "lstm":
raise ValueError("%s is not supported!" % cudnn_model.rnn_mode)
num_units = cudnn_model.num_units
num_layers = cudnn_model.num_layers
# To reuse cuDNN-trained models, forget_bias and clip_cell must be set to
# 0 and False respectively.
# In LSTMCell and LSTMBlockCell, forget_bias is added on top of the learned
# bias, whereas cuDNN does not apply this additional bias.
if use_block_cell:
# pylint: disable=g-long-lambda
single_cell = lambda: lstm_ops.LSTMBlockCell(num_units, forget_bias=0,
clip_cell=False)
# pylint: enable=g-long-lambda
else:
single_cell = lambda: rnn_cell_impl.LSTMCell(num_units, forget_bias=0)
cell = rnn_cell_impl.MultiRNNCell(
[single_cell() for _ in range(num_layers)])
return rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope)
def _build_forward_cudnn_model(self,
rnn_mode,
num_layers,
num_units,
input_data,
is_training=False):
input_data_shape = input_data.get_shape().with_rank(3)
batch_size = input_data_shape[1].value
input_size = input_data_shape[2].value
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size)
# Set zero init input states
input_h = constant_op.constant(
np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)
has_input_c = (rnn_mode == "lstm")
if has_input_c:
input_c = constant_op.constant(
np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)
# Set rnn params
params_size_t = model.params_size()
params = variables.Variable(
random_ops.random_uniform([params_size_t]), validate_shape=False)
args = {
"input_data": input_data,
"input_h": input_h,
"params": params,
"is_training": is_training
}
if has_input_c:
args["input_c"] = input_c
# Build cell
output_tuple = model(**args)
# Create savable objects for params
self._create_params_savable(params, model)
return output_tuple, model, params
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def testCheckpointReusableByCanonicalLSTMCells(self):
configs = [
{
"num_layers": 1,
"seq_length": 3,
"num_units": 4,
"input_size": 5,
"batch_size": 6,
"rnn_mode": "lstm"
},
{
"num_layers": 2,
"seq_length": 8,
"num_units": 4,
"input_size": 8,
"batch_size": 16,
"rnn_mode": "lstm"
},
{
"num_layers": 2,
"seq_length": 3,
"num_units": 4,
"input_size": 5,
"batch_size": 6,
"rnn_mode": "lstm"
},
{
"num_layers": 1,
"seq_length": 2,
"num_units": 2,
"input_size": 4,
"batch_size": 1,
"rnn_mode": "lstm"
},
]
for cfg in configs:
self._testCheckpointReusableByCanonicalLSTMCells(
cfg["num_layers"],
cfg["seq_length"],
cfg["num_units"],
cfg["input_size"],
cfg["batch_size"],
cfg["rnn_mode"],
use_block_cell=False)
self._testCheckpointReusableByCanonicalLSTMCells(
cfg["num_layers"],
cfg["seq_length"],
cfg["num_units"],
cfg["input_size"],
cfg["batch_size"],
cfg["rnn_mode"],
use_block_cell=True)
def _testCheckpointReusableByCanonicalLSTMCells(
self, num_layers, seq_length, num_units, input_size, batch_size, rnn_mode,
use_block_cell):
np.random.seed(0)
# Train graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
input_data = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model(
rnn_mode, num_layers, num_units, input_data, is_training=True)
target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None)
total_sum = sum(map(math_ops.reduce_sum, output_tuple))
loss_op = losses.log_loss(labels=target_output, predictions=total_sum)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2)
train_op = optimizer.minimize(loss_op)
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
# Train Cudnn model
with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
# Train 128 steps
num_steps = 128
for _ in range(num_steps):
inputs = np.random.rand(seq_length, batch_size,
input_size).astype(np.float32)
targets = np.random.rand()
sess.run(
train_op, feed_dict={input_data: inputs,
target_output: targets})
save_path = os.path.join(self.get_temp_dir(),
("cudnn-rnn-%s-test" % rnn_mode))
save_v = saver.save(sess, save_path)
self.assertEqual(save_path, save_v)
cudnn_params_v = sess.run(cudnn_params)
# cuDNN inference graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
cudnn_inputs = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
(cudnn_output_tuple, cudnn_model,
cudnn_params) = self._build_forward_cudnn_model(
rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False)
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
inference_input = np.random.rand(seq_length, batch_size,
input_size).astype(np.float32)
with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
saver.restore(sess, save_path)
restored_cudnn_params_v = sess.run(cudnn_params)
self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v)
# Cudnn inference
(cudnn_output, cudnn_output_h, cudnn_output_c) = sess.run(
cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input})
# LSTMBlockCell inference graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
cell_inputs = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
(output, states) = self._create_equivalent_canonical_rnn(
cudnn_model, cell_inputs, use_block_cell)
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
saver.restore(sess, save_path)
# BlockCell inference
output_v, states_v = sess.run(
[output, states], feed_dict={cell_inputs: inference_input})
# Outputs across time steps are packed into one tensor.
self.assertAllClose(cudnn_output, output_v, atol=1e-6, rtol=1e-6)
for i in range(num_layers):
# output_h
self.assertAllClose(
cudnn_output_h[i, :], states_v[i].h, atol=1e-6, rtol=1e-6)
# output_c
self.assertAllClose(
cudnn_output_c[i, :], states_v[i].c, atol=1e-6, rtol=1e-6)
def _testSaveRestoreOutput(self, rnn_mode):
num_layers = 2
num_units = 7
@@ -187,9 +409,13 @@ class CudnnRNNTest(TensorFlowTestCase):
batch_size, seq_length, dir_count, dropout,
expected, tolerance):
random_seed.set_random_seed(5678)
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
input_mode="auto_select",
dropout=dropout)
model = self._CreateModel(
rnn_mode,
num_layers,
num_units,
input_size,
input_mode="auto_select",
dropout=dropout)
has_input_c = (rnn_mode == "lstm")
params_size_t = model.params_size()
input_data = array_ops.ones([seq_length, batch_size, input_size])
@@ -216,7 +442,7 @@ class CudnnRNNTest(TensorFlowTestCase):
if has_input_c:
output_c_sum = math_ops.reduce_sum(output_c)
total_sum += output_c_sum
with self.test_session(use_gpu=True) as sess:
with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
total_sum_v = sess.run([total_sum])
@@ -310,8 +536,8 @@ class CudnnRNNTest(TensorFlowTestCase):
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
has_input_c = (rnn_mode == "lstm")
random_seed.set_random_seed(1234)
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
dropout=dropout)
model = self._CreateModel(
rnn_mode, num_layers, num_units, input_size, dropout=dropout)
params_size_t = model.params_size()
input_data = variables.Variable(
random_ops.random_uniform([seq_length, batch_size, input_size]))
@@ -417,6 +643,7 @@ class CudnnRNNTest(TensorFlowTestCase):
},
},
]
ops.reset_default_graph()
with ops.Graph().as_default():
for config in test_configs:
rnn_mode = config["rnn_mode"]


@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.util import loader
@@ -46,9 +45,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation that handles the RNN params variable."""
def __init__(self,
cudnn_rnn,
params_to_canonical,
canonical_to_params,
param_variables,
base_variable_scope=None,
name="params_canonical"):
"""Creates a RNNParamsSaveable object.
@@ -75,6 +76,7 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate.
Args:
cudnn_rnn: cudnn RNN class instance.
params_to_canonical: a function to convert params from a specific format
for cuDNN or other RNN ops to the canonical format.
_CudnnRNN.params_to_canonical() should be provided here.
@@ -87,25 +89,42 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
For cuDNN RNN ops, this is a single merged variable for both weights
and biases; for other RNN ops, this might be multiple unmerged or
partially merged variables respectively for weights and biases.
base_variable_scope: a string, the name of the outer variable scope, used
as a prefix for the names of the saved variables.
name: the name of the RNNParamsSaveable object.
"""
# There is only a single merged parameter variable for cuDNN when saving.
self._cudnn_rnn = cudnn_rnn
weights, biases = params_to_canonical(param_variables[0])
weights, biases = self._transform_canonical(weights, biases)
weight_names, bias_names = self._transformed_canonical_names(
weights, biases)
self._canonical_to_params = canonical_to_params
self._variables = param_variables
# We currently don't use slice_spec. It might be useful in a distributed
# setting where each parameter server node stores a slice of the variable,
# instead of having the master pull all slices and then save them.
slice_spec = ""
params = weights + biases
param_names = weight_names + bias_names
if base_variable_scope:
param_names = ["%s/%s" % (base_variable_scope, pn) for pn in param_names]
specs = [
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param.name)
for param in itertools.chain(weights, biases)
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name)
for param, param_name in zip(params, param_names)
]
super(RNNParamsSaveable, self).__init__(None, specs, name)
def restore(self, restored_tensors, restored_shapes):
weights = restored_tensors[:len(restored_tensors) // 2]
biases = restored_tensors[len(restored_tensors) // 2:]
if (self._cudnn_rnn.direction == "unidirectional" and
self._cudnn_rnn.rnn_mode == "lstm"):
assert len(restored_tensors) % 4 == 0
weights = restored_tensors[:len(restored_tensors) // 4]
biases = restored_tensors[len(restored_tensors) // 4:]
else:
weights = restored_tensors[:len(restored_tensors) // 2]
biases = restored_tensors[len(restored_tensors) // 2:]
weights, biases = self._untransform_canonical(weights, biases)
params = self._canonical_to_params(weights, biases)
if not isinstance(params, tuple):
params = (params,)
@@ -115,6 +134,159 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
]
return control_flow_ops.group(*assign_ops)
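As a concrete check of the quarter split above, a small sketch (with an assumed num_layers) of the tensor counts restore() expects:

```python
# For a unidirectional cuDNN LSTM, each layer contributes 1 fused kernel and
# 3 bias tensors (fused, input half, recurrent half), so the kernels are the
# first quarter of restored_tensors; other configurations save one weight per
# bias, hence the halved split in the else branch.
num_layers = 2                    # hypothetical
num_weights = num_layers          # one fused kernel per layer
num_biases = 3 * num_layers       # bias, bias_cudnn_0, bias_cudnn_1 per layer
total = num_weights + num_biases  # == 4 * num_layers restored tensors
assert total % 4 == 0 and total // 4 == num_weights
```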
def _switch_inner(self, array, base_idx):
array[base_idx + 1], array[base_idx + 2] = (array[base_idx + 2],
array[base_idx + 1])
def _transform_canonical(self, weights, biases):
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
return weights, biases
return self._transform_lstm_canonical(weights, biases)
def _transformed_canonical_names(self, weights, biases):
"""Return canonical names for fused weight and bias tensors."""
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
assert len(weights) == len(biases)
return ([w.name for w in weights], [b.name for b in biases])
else:
w_names, b_names = [], []
assert len(weights) * 3 == len(biases)
num_layers = self._cudnn_rnn.num_layers
# TODO(jamesqin): get rid of multi_rnn_cell when num_layers is 1
for i in range(num_layers):
# One fused weight tensor each layer.
w_names.append("multi_rnn_cell/cell_%d/lstm_cell/kernel" % i)
# Three fused bias tensors each layer:
# the 1st is for LSTMBlockCell restore; the latter two sum up to the
# 1st, and are used for cuDNN restore.
b_names.append("multi_rnn_cell/cell_%d/lstm_cell/bias" % i)
b_names.extend([
"multi_rnn_cell/cell_%d/lstm_cell/bias_cudnn_%d" % (i, j)
for j in range(2)
])
return w_names, b_names
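For illustration, the names produced by the scheme above for an assumed two-layer model with base_variable_scope="rnn" (the prefixing itself happens in __init__) could be reconstructed as:

```python
# Hypothetical reconstruction of the checkpoint keys for num_layers=2,
# base_variable_scope="rnn"; the real names come from the method above.
num_layers, scope = 2, "rnn"
kernels = ["%s/multi_rnn_cell/cell_%d/lstm_cell/kernel" % (scope, i)
           for i in range(num_layers)]
biases = ["%s/multi_rnn_cell/cell_%d/lstm_cell/%s" % (scope, i, suffix)
          for i in range(num_layers)
          for suffix in ("bias", "bias_cudnn_0", "bias_cudnn_1")]
# The kernel and bias entries match the variables an LSTMBlockCell stack
# creates under the same "rnn" scope; the bias_cudnn_* halves exist only so
# the cuDNN parameter buffer can be rebuilt on restore.
```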
def _transform_lstm_canonical(self, weights, biases):
"""Create fused lstm canonical params.
Produce properly-shaped monolithic weight and bias tensors to share between
cuDNN and non-platform specific LSTM cells (w/o peephole).
Args:
weights: a list of Tensors recovered from cuDNN params_to_canonical.
biases: a list of Tensors recovered from cuDNN params_to_canonical.
Returns:
Two lists of tensors, one for weight and bias each.
The weight list contains num_layers tensors and bias one contains 3 *
num_layers tensors. Both original and combined biases since cuDNN biases
are not restorable from the fused version.
"""
transformed_weights, transformed_biases = [], []
for i in range(self._cudnn_rnn.num_layers):
base_idx = i * 8
num_units = self._cudnn_rnn.num_units
input_size = self._cudnn_rnn.input_size if i == 0 else num_units
# cuDNN tensor shapes per time_step:
# input.shape: [batch_size, input_size],
# input_weights.shape: [num_units, input_size] (first layer)
# [num_units, num_units] (other layers)
# state_weights.shape: [num_units, num_units]
# biases.shape: [num_units]
#
# General LSTM cells compute gate functions using:
# [x, h_prev] * weights + biases
# Therefore for each layer, they expect
# weight.shape: [input_size + num_units, 4 * num_units] (first_layer)
# [num_units + num_units, 4 * num_units] (other layers)
# bias.shape: [4 * num_units]
# Stitch weights together in this layer.
stitched_w = []
for j in range(4):
stitched_w.append(
array_ops.concat(
[
array_ops.reshape(weights[base_idx + j],
[num_units, input_size]),
array_ops.reshape(weights[base_idx + j + 4],
[num_units, num_units])
],
axis=1))
# cuDNN weights are in ifco order, convert to icfo order.
self._switch_inner(stitched_w, 0)
transformed_weights.append(
array_ops.transpose(array_ops.concat(stitched_w, axis=0)))
# Stitch biases together in this layer.
# Convert to icfo order.
self._switch_inner(biases, base_idx)
self._switch_inner(biases, base_idx + 4)
# The bias for layer input.
b_in = array_ops.concat(biases[base_idx:base_idx + 4], axis=0)
# The bias for recurrent input.
b_rec = array_ops.concat(biases[base_idx + 4:base_idx + 8], axis=0)
transformed_biases.extend([b_in + b_rec, b_in, b_rec])
return transformed_weights, transformed_biases
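To make the per-layer stitching concrete, here is a standalone numpy sketch (hypothetical sizes, not part of this change) of the shape bookkeeping described in the comments above:

```python
# Numpy sketch of one layer's transformation: 8 cuDNN canonical tensors
# (i, f, c, o input weights, then i, f, c, o recurrent weights) are stitched
# into one [input_size + num_units, 4 * num_units] kernel in icfo order.
import numpy as np

input_size, num_units = 3, 2  # hypothetical sizes
w = [np.random.rand(num_units, input_size) for _ in range(4)] + \
    [np.random.rand(num_units, num_units) for _ in range(4)]
b = [np.random.rand(num_units) for _ in range(8)]

# Concatenate [input, recurrent] per gate: 4 blocks of shape
# [num_units, input_size + num_units].
stitched = [np.concatenate([w[j], w[j + 4]], axis=1) for j in range(4)]
# ifco -> icfo: swap the two inner gates, then stack and transpose.
stitched[1], stitched[2] = stitched[2], stitched[1]
kernel = np.concatenate(stitched, axis=0).T
assert kernel.shape == (input_size + num_units, 4 * num_units)

# Biases: reorder each half ifco -> icfo, then keep the fused sum plus both
# halves (the halves are needed to rebuild the cuDNN buffer on restore).
b[1], b[2] = b[2], b[1]
b[5], b[6] = b[6], b[5]
b_in, b_rec = np.concatenate(b[0:4]), np.concatenate(b[4:8])
fused_bias = b_in + b_rec
assert fused_bias.shape == (4 * num_units,)
```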
def _untransform_canonical(self, transformed_weights, transformed_biases):
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
return transformed_weights, transformed_biases
return self._untransform_lstm_canonical(transformed_weights,
transformed_biases)
def _untransform_lstm_canonical(self, transformed_weights,
transformed_biases):
"""The reverse procedure of _transform_lstm_canonical().
Args:
transformed_weights: a list of tensors, one for each layer.
transformed_biases: a list of tensors, 3 for each layer: the 2nd is for the
layer input, the 3rd for the recurrent input, and the 1st is the sum of the
latter two.
Returns:
Two lists of tensors for weights and biases respectively.
There are 8 weight tensors and 8 bias tensors for each layer:
tensors 0-3 are applied to the input from the previous layer;
tensors 4-7 to the recurrent input. Tensors 0 and 4 are for the input gate;
tensors 1 and 5 for the forget gate; tensors 2 and 6 for the new memory gate;
tensors 3 and 7 for the output gate.
"""
weights, biases = [], []
assert 3 * len(transformed_weights) == len(transformed_biases)
for i in range(len(transformed_weights)):
num_units = self._cudnn_rnn.num_units
input_size = self._cudnn_rnn.input_size if i == 0 else num_units
# weights applied on layer inputs.
wi = array_ops.slice(transformed_weights[i], [0, 0],
[input_size, 4 * num_units])
# weights applied on recurrent inputs.
wr = array_ops.slice(transformed_weights[i], [input_size, 0],
[num_units, 4 * num_units])
wi_list = array_ops.split(wi, 4, axis=1)
wr_list = array_ops.split(wr, 4, axis=1)
for j in range(len(wi_list)):
wi_list[j] = array_ops.reshape(array_ops.transpose(wi_list[j]), [-1])
wr_list[j] = array_ops.reshape(array_ops.transpose(wr_list[j]), [-1])
# canonical weights are in icfo order, convert to ifco order for cuDNN.
self._switch_inner(wi_list, 0)
self._switch_inner(wr_list, 0)
weights.extend(wi_list)
weights.extend(wr_list)
base_idx = 3 * i
bi_list = array_ops.split(transformed_biases[base_idx + 1], 4, axis=0)
br_list = array_ops.split(transformed_biases[base_idx + 2], 4, axis=0)
# canonical biases are in icfo order, convert to ifco order for cuDNN.
self._switch_inner(bi_list, 0)
self._switch_inner(br_list, 0)
biases.extend(bi_list)
biases.extend(br_list)
return weights, biases
_cudnn_rnn_common_doc_string = """
Cudnn RNN has an opaque parameter buffer that can be used for inference and
@@ -199,6 +371,26 @@ class _CudnnRNN(object):
if self._seed is None and self._seed2 is None:
self._seed, self._seed2 = 0, 0
@property
def input_size(self):
return self._input_size
@property
def num_units(self):
return self._num_units
@property
def num_layers(self):
return self._num_layers
@property
def rnn_mode(self):
return self._rnn_mode
@property
def direction(self):
return self._direction
def params_size(self):
"""Calculates the size of the opaque parameter buffer needed for this model.
@@ -222,9 +414,12 @@ class _CudnnRNN(object):
"""Runs the forward step for the RNN model.
Args:
input_data: the input sequence to the RNN model.
input_h: the initial hidden state for h.
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
@@ -308,7 +503,7 @@ class CudnnLSTM(_CudnnRNN):
num_layers,
num_units,
input_size,
input_mode="auto_select",
input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -344,9 +539,12 @@ class CudnnLSTM(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.
Args:
input_data: the input sequence to the LSTM model.
input_h: the initial hidden state for h.
input_c: the initial hidden state for c.
input_data: the input sequence to the LSTM model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. A Tensor of the same shape as
input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
@@ -368,7 +566,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
num_layers,
num_units,
input_size,
input_mode="auto_select",
input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -390,6 +588,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
dropout: whether to enable dropout. When it is 0, dropout is disabled.
seed: the seed used for initializing dropout.
"""
super(_CudnnRNNNoInputC, self).__init__(
self._rnn_mode,
num_layers,
@@ -404,8 +603,10 @@ class _CudnnRNNNoInputC(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.
Args:
input_data: the input sequence to the LSTM model.
input_h: the initial hidden state for h.
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.


@@ -58,7 +58,7 @@ def _lstm_block_cell(x,
```python
xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
[i, ci, f, o] = xh * w + b
f = f + forget_bias
if not use_peephole:
@@ -93,7 +93,7 @@ def _lstm_block_cell(x,
The weight matrix for output gate peephole connection.
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
cell_clip: An optional `float`. Defaults to `3`.
Value to clip the 'cs' value to.
Value to clip the 'cs' value to. Disable by setting it to a negative value.
use_peephole: An optional `bool`. Defaults to `False`.
Whether to use peephole weights.
name: A name for the operation (optional).
@@ -341,17 +341,24 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
def __init__(self,
num_units,
forget_bias=1.0,
clip_cell=True,
use_peephole=False):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
clip_cell: boolean, whether to apply cell clipping. See
`_lstm_block_cell()` for details.
use_peephole: Whether to use peephole connections or not.
When restoring from CudnnLSTM-trained checkpoints, these must be set to
forget_bias=0, clip_cell=False, use_peephole=False.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._use_peephole = use_peephole
self._clip_cell = clip_cell
self._names = {
"W": "kernel",
"b": "bias",
@@ -400,6 +407,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
wco=wco,
wcf=wcf,
forget_bias=self._forget_bias,
cell_clip=None if self._clip_cell else -1,
use_peephole=self._use_peephole)
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
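A short sketch (constructor arguments only; num_units is a placeholder) of how the new clip_cell flag is meant to be used alongside forget_bias when consuming cuDNN checkpoints:

```python
from tensorflow.contrib.rnn.python.ops import lstm_ops

# clip_cell=False is forwarded to the op as cell_clip=-1 (clipping disabled),
# matching cuDNN; forget_bias=0.0 avoids double-adding the forget-gate bias.
cell = lstm_ops.LSTMBlockCell(num_units=4, forget_bias=0.0,
                              clip_cell=False, use_peephole=False)
```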


@@ -345,6 +345,8 @@ class BasicLSTMCell(RNNCell):
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must be set to `0.0` manually when restoring from CudnnLSTM-trained
checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
@@ -444,7 +446,8 @@ class LSTMCell(RNNCell):
Use a variable_scope partitioner instead.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training.
the training. Must be set to `0.0` manually when restoring from
CudnnLSTM-trained checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.
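And the corresponding sketch for the canonical cells documented here (num_units is a placeholder):

```python
from tensorflow.python.ops import rnn_cell_impl

num_units = 4  # placeholder
# forget_bias=0.0 so no extra bias is added on top of the restored cuDNN bias.
lstm = rnn_cell_impl.LSTMCell(num_units, forget_bias=0.0)
basic = rnn_cell_impl.BasicLSTMCell(num_units, forget_bias=0.0)
```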