Support reusing CudnnLSTM-trained checkpoints with multi-layer LSTM(Block)Cell

* Add tensor stitching/partitioning in RNNParamsSaveable so weights/biases are saved and restored in the shapes and formats shared with LSTM(Block)Cell.
* Add remapped names for canonical tensors during saving.
* Unit tests.

PiperOrigin-RevId: 159634913
parent 4be287671a
commit a936b239cf
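Taken together, the change lets a checkpoint written from the cuDNN opaque parameter buffer be read back by ordinary LSTM cells. Below is a minimal, hypothetical sketch of that round trip, assembled from the unit test added in this diff; the contrib module paths and the forget_bias=0 / clip_cell=False settings come straight from the diff, while the saving/restoring steps (shown as comments) would normally happen in separate graphs and sessions, as the test does.

```python
# Hypothetical sketch only; mirrors the unit test added in this commit.
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variables

num_layers, num_units, input_size = 2, 4, 5

# cuDNN side: one opaque parameter buffer, registered through the new
# RNNParamsSaveable so the checkpoint holds canonically shaped tensors.
model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, input_size)
params = variables.Variable(
    random_ops.random_uniform([model.params_size()]), validate_shape=False)
params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
    model, model.params_to_canonical, model.canonical_to_params, [params],
    "rnn")
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
# ... train, then save with tf.train.Saver; the checkpoint now stores
# rnn/multi_rnn_cell/cell_*/lstm_cell/{kernel,bias} entries.

# Canonical side (typically a separate graph, as in the test): the same
# checkpoint restores into a stack of LSTMBlockCells, provided forget_bias=0
# and clip_cell=False so nothing extra is applied on top of the cuDNN-trained
# parameters.
cell = rnn_cell_impl.MultiRNNCell(
    [lstm_ops.LSTMBlockCell(num_units, forget_bias=0, clip_cell=False)
     for _ in range(num_layers)])
inputs = array_ops.placeholder(
    dtypes.float32, shape=[None, None, input_size])  # time-major inputs
outputs, states = rnn.dynamic_rnn(
    cell, inputs, dtype=dtypes.float32, time_major=True, scope="rnn")
# ... restore with tf.train.Saver and run inference.
```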
@@ -87,6 +87,8 @@ cuda_py_test(
additional_deps = [
":cudnn_rnn_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/python/ops/losses:losses",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",

@@ -20,8 +20,13 @@ from __future__ import print_function

import os
import unittest
import numpy as np

from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework.test_util import TensorFlowTestCase
@@ -29,10 +34,14 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib


@@ -69,7 +78,8 @@ class CudnnRNNTest(TensorFlowTestCase):
model: a CudnnRNN model.
"""
params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
model.params_to_canonical, model.canonical_to_params, [params])
model, model.params_to_canonical, model.canonical_to_params, [params],
"rnn")
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)

def _testSaveRestoreVariable(self, rnn_mode):
@@ -93,6 +103,218 @@ class CudnnRNNTest(TensorFlowTestCase):
params_v_restored = sess.run(params)
self.assertAllEqual(params_v, params_v_restored)

def _create_equivalent_canonical_rnn(self,
cudnn_model,
inputs,
use_block_cell,
scope="rnn"):
if cudnn_model.rnn_mode is not "lstm":
raise ValueError("%s is not supported!" % cudnn_model.rnn_mode)

num_units = cudnn_model.num_units
num_layers = cudnn_model.num_layers

# To reuse cuDNN-trained models, must set
# forget_bias, clip_cell = 0, False
# In LSTMCell and LSTMBlockCell, forget_bias is added in addition to learned
# bias, whereas cuDNN does not apply the additional bias.
if use_block_cell:
# pylint: disable=g-long-lambda
single_cell = lambda: lstm_ops.LSTMBlockCell(num_units, forget_bias=0,
clip_cell=False)
# pylint: enable=g-long-lambda
else:
single_cell = lambda: rnn_cell_impl.LSTMCell(num_units, forget_bias=0)
cell = rnn_cell_impl.MultiRNNCell(
[single_cell() for _ in range(num_layers)])
return rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope)

def _build_forward_cudnn_model(self,
rnn_mode,
num_layers,
num_units,
input_data,
is_training=False):
input_data_shape = input_data.get_shape().with_rank(3)
batch_size = input_data_shape[1].value
input_size = input_data_shape[2].value
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size)

# Set zero init input states
input_h = constant_op.constant(
np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)
has_input_c = (rnn_mode == "lstm")
if has_input_c:
input_c = constant_op.constant(
np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)

# Set rnn params
params_size_t = model.params_size()
params = variables.Variable(
random_ops.random_uniform([params_size_t]), validate_shape=False)
args = {
"input_data": input_data,
"input_h": input_h,
"params": params,
"is_training": is_training
}
if has_input_c:
args["input_c"] = input_c
# Build cell
output_tuple = model(**args)

# Create savable objects for params
self._create_params_savable(params, model)

return output_tuple, model, params

@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def testCheckpointReusableByCanonicalLSTMCells(self):
configs = [
{
"num_layers": 1,
"seq_length": 3,
"num_units": 4,
"input_size": 5,
"batch_size": 6,
"rnn_mode": "lstm"
},
{
"num_layers": 2,
"seq_length": 8,
"num_units": 4,
"input_size": 8,
"batch_size": 16,
"rnn_mode": "lstm"
},
{
"num_layers": 2,
"seq_length": 3,
"num_units": 4,
"input_size": 5,
"batch_size": 6,
"rnn_mode": "lstm"
},
{
"num_layers": 1,
"seq_length": 2,
"num_units": 2,
"input_size": 4,
"batch_size": 1,
"rnn_mode": "lstm"
},
]
for cfg in configs:
self._testCheckpointReusableByCanonicalLSTMCells(
cfg["num_layers"],
cfg["seq_length"],
cfg["num_units"],
cfg["input_size"],
cfg["batch_size"],
cfg["rnn_mode"],
use_block_cell=False)
self._testCheckpointReusableByCanonicalLSTMCells(
cfg["num_layers"],
cfg["seq_length"],
cfg["num_units"],
cfg["input_size"],
cfg["batch_size"],
cfg["rnn_mode"],
use_block_cell=True)

def _testCheckpointReusableByCanonicalLSTMCells(
self, num_layers, seq_length, num_units, input_size, batch_size, rnn_mode,
use_block_cell):
np.random.seed(0)
# Train graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
input_data = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model(
rnn_mode, num_layers, num_units, input_data, is_training=True)
target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None)
total_sum = sum(map(math_ops.reduce_sum, output_tuple))

loss_op = losses.log_loss(labels=target_output, predictions=total_sum)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2)
train_op = optimizer.minimize(loss_op)

saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

# Train Cudnn model
with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
# Train 128 steps
num_steps = 128
for _ in range(num_steps):
inputs = np.random.rand(seq_length, batch_size,
input_size).astype(np.float32)
targets = np.random.rand()
sess.run(
train_op, feed_dict={input_data: inputs,
target_output: targets})

save_path = os.path.join(self.get_temp_dir(),
("cudnn-rnn-%s-test" % rnn_mode))
save_v = saver.save(sess, save_path)
self.assertEqual(save_path, save_v)
cudnn_params_v = sess.run(cudnn_params)

# cuDNN inference graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
cudnn_inputs = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
(cudnn_output_tuple, cudnn_model,
cudnn_params) = self._build_forward_cudnn_model(
rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False)
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

inference_input = np.random.rand(seq_length, batch_size,
input_size).astype(np.float32)
with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
saver.restore(sess, save_path)
restored_cudnn_params_v = sess.run(cudnn_params)
self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v)

# Cudnn inference
(cudnn_output, cudnn_output_h, cudnn_output_c) = sess.run(
cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input})

# LSTMBlockCell inference graph
with ops.Graph().as_default():
random_seed.set_random_seed(299)
cell_inputs = array_ops.placeholder(
dtypes.float32, shape=[seq_length, batch_size, input_size])
(output, states) = self._create_equivalent_canonical_rnn(
cudnn_model, cell_inputs, use_block_cell)
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

with self.test_session(
use_gpu=True, graph=ops.get_default_graph()) as sess:
saver.restore(sess, save_path)

# BlockCell inference
output_v, states_v = sess.run(
[output, states], feed_dict={cell_inputs: inference_input})

# output across timestamps are packed into one tensor.
self.assertAllClose(cudnn_output, output_v, atol=1e-6, rtol=1e-6)

for i in range(num_layers):
# output_h
self.assertAllClose(
cudnn_output_h[i, :], states_v[i].h, atol=1e-6, rtol=1e-6)
# output_c
self.assertAllClose(
cudnn_output_c[i, :], states_v[i].c, atol=1e-6, rtol=1e-6)

def _testSaveRestoreOutput(self, rnn_mode):
num_layers = 2
num_units = 7
@@ -187,9 +409,13 @@ class CudnnRNNTest(TensorFlowTestCase):
batch_size, seq_length, dir_count, dropout,
expected, tolerance):
random_seed.set_random_seed(5678)
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
input_mode="auto_select",
dropout=dropout)
model = self._CreateModel(
rnn_mode,
num_layers,
num_units,
input_size,
input_mode="auto_select",
dropout=dropout)
has_input_c = (rnn_mode == "lstm")
params_size_t = model.params_size()
input_data = array_ops.ones([seq_length, batch_size, input_size])
@@ -216,7 +442,7 @@ class CudnnRNNTest(TensorFlowTestCase):
if has_input_c:
output_c_sum = math_ops.reduce_sum(output_c)
total_sum += output_c_sum
with self.test_session(use_gpu=True) as sess:
with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
total_sum_v = sess.run([total_sum])

@@ -310,8 +536,8 @@ class CudnnRNNTest(TensorFlowTestCase):
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
has_input_c = (rnn_mode == "lstm")
random_seed.set_random_seed(1234)
model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
dropout=dropout)
model = self._CreateModel(
rnn_mode, num_layers, num_units, input_size, dropout=dropout)
params_size_t = model.params_size()
input_data = variables.Variable(
random_ops.random_uniform([seq_length, batch_size, input_size]))
@@ -417,6 +643,7 @@ class CudnnRNNTest(TensorFlowTestCase):
},
},
]
ops.reset_default_graph()
with ops.Graph().as_default():
for config in test_configs:
rnn_mode = config["rnn_mode"]

@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools

from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.util import loader
@@ -46,9 +45,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation that handles the RNN params variable."""

def __init__(self,
cudnn_rnn,
params_to_canonical,
canonical_to_params,
param_variables,
base_variable_scope=None,
name="params_canonical"):
"""Creates a RNNParamsSaveable object.

@@ -75,6 +76,7 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate.

Args:
cudnn_rnn: cudnn RNN class instance.
params_to_canonical: a function to convert params from a specific format
for cuDNN or other RNN ops to the canonical format.
_CudnnRNN.params_to_canonical() should be provided here.
@@ -87,25 +89,42 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
For cuDNN RNN ops, this is a single merged variable for both weights
and biases; for other RNN ops, this might be multiple unmerged or
partially merged variables respectively for weights and biases.
base_variable_scope: a string, name of outer variable scope, used as
part of prefix of names of saved variables.
name: the name of the RNNParamsSaveable object.
"""
# There is only a single merged parameter variable for cuDNN when saving.
self._cudnn_rnn = cudnn_rnn
weights, biases = params_to_canonical(param_variables[0])
weights, biases, = self._transform_canonical(weights, biases)
weight_names, biase_names = self._transformed_canonical_names(
weights, biases)
self._canonical_to_params = canonical_to_params
self._variables = param_variables
# We currently don't use slice_spec. It might be useful in a distributed
# setting where each parameter server node stores a slice of variable,
# instead of having the master pull all slices and then save them.
slice_spec = ""
params = weights + biases
param_names = weight_names + biase_names
if base_variable_scope:
param_names = ["%s/%s" % (base_variable_scope, pn) for pn in param_names]
specs = [
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param.name)
for param in itertools.chain(weights, biases)
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name)
for param, param_name in zip(params, param_names)
]
super(RNNParamsSaveable, self).__init__(None, specs, name)

def restore(self, restored_tensors, restored_shapes):
weights = restored_tensors[:len(restored_tensors) // 2]
biases = restored_tensors[len(restored_tensors) // 2:]
if (self._cudnn_rnn.direction == "unidirectional" and
self._cudnn_rnn.rnn_mode == "lstm"):
assert len(restored_tensors) % 4 == 0
weights = restored_tensors[:len(restored_tensors) // 4]
biases = restored_tensors[len(restored_tensors) // 4:]
else:
weights = restored_tensors[:len(restored_tensors) // 2]
biases = restored_tensors[len(restored_tensors) // 2:]
weights, biases = self._untransform_canonical(weights, biases)
params = self._canonical_to_params(weights, biases)
if not isinstance(params, tuple):
params = (params,)
@@ -115,6 +134,159 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
]
return control_flow_ops.group(*assign_ops)

def _switch_inner(self, array, base_idx):
array[base_idx + 1], array[base_idx + 2] = (array[base_idx + 2],
array[base_idx + 1])

def _transform_canonical(self, weights, biases):
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
return weights, biases
return self._transform_lstm_canonical(weights, biases)

def _transformed_canonical_names(self, weights, biases):
"""Return canonical names for fused weight and bias tensors."""
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
assert len(weights) == len(biases)
return ([w.name for w in weights], [b.name for b in biases])
else:
w_names, b_names = [], []
assert len(weights) * 3 == len(biases)
num_layers = self._cudnn_rnn.num_layers
# TODO(jamesqin): get rid of multi_rnn_cell when num_layers is 1
for i in range(num_layers):
# One fused weight tensor each layer.
w_names.append("multi_rnn_cell/cell_%d/lstm_cell/kernel" % i)
# Three fused bias tensors each layer:
# the 1st is for LSTMBlockCell restore; the latter two sum up to the
# 1st, and are used for cuDNN restore.
b_names.append("multi_rnn_cell/cell_%d/lstm_cell/bias" % i)
b_names.extend([
"multi_rnn_cell/cell_%d/lstm_cell/bias_cudnn_%d" % (i, j)
for j in range(2)
])
return w_names, b_names

def _transform_lstm_canonical(self, weights, biases):
"""Create fused lstm canonical params.

Produce properly-shaped monolithic weight and bias tensors to share between
cuDNN and non-platform specific LSTM cells (w/o peephole).
Args:
weights: a list of Tensors recovered from cuDNN params_to_canonical.
biases: a list of Tensors recovered from cuDNN params_to_canonical.
Returns:
Two lists of tensors, one for weight and bias each.
The weight list contains num_layers tensors and bias one contains 3 *
num_layers tensors. Both original and combined biases since cuDNN biases
are not restorable from the fused version.
"""
transformed_weights, transformed_biases = [], []
for i in range(self._cudnn_rnn.num_layers):
base_idx = i * 8
num_units = self._cudnn_rnn.num_units
input_size = self._cudnn_rnn.input_size if i == 0 else num_units
# cuDNN tensor shapes per time_step:
# input.shape: [batch_size, input_size],
# input_weights.shape: [num_units, input_size] (first layer)
# [num_units, num_units] (other layers)
# state_weights.shape: [num_units, num_units]
# biases.shape: [num_units]
#
# General LSTM cells compute gate functions using:
# [x, h_prev] * weights + biases
# Therefore for each layer, they expect
# weight.shape: [input_size + num_units, 4 * num_units] (first_layer)
# [num_units + num_units, 4 * num_units] (other layers)
# bias.shape: [4 * num_units]

# Stitch weights together in this layer.
stitched_w = []
for j in range(4):
stitched_w.append(
array_ops.concat(
[
array_ops.reshape(weights[base_idx + j],
[num_units, input_size]),
array_ops.reshape(weights[base_idx + j + 4],
[num_units, num_units])
],
axis=1))
# cuDNN weights are in ifco order, convert to icfo order.
self._switch_inner(stitched_w, 0)
transformed_weights.append(
array_ops.transpose(array_ops.concat(stitched_w, axis=0)))

# Stitch biases together in this layer.
# Convert to icfo order.
self._switch_inner(biases, base_idx)
self._switch_inner(biases, base_idx + 4)
# The bias for layer input.
b_in = array_ops.concat(biases[base_idx:base_idx + 4], axis=0)
# The bias for recurrent input.
b_rec = array_ops.concat(biases[base_idx + 4:base_idx + 8], axis=0)

transformed_biases.extend([b_in + b_rec, b_in, b_rec])
return transformed_weights, transformed_biases

def _untransform_canonical(self, transformed_weights, transformed_biases):
if (self._cudnn_rnn.direction != "unidirectional" or
self._cudnn_rnn.rnn_mode != "lstm"):
return transformed_weights, transformed_biases
return self._untransform_lstm_canonical(transformed_weights,
transformed_biases)

def _untransform_lstm_canonical(self, transformed_weights,
transformed_biases):
"""The reverse procedure of _transform_lstm_canonical().

Args:
transformed_weights: a list of tensors, one for each layer.
transformed_biases: a list of tensors, 3 for each layer: the 2nd for
layer input, the 3rd for recurrent input, the 1st is the sum of the
latter two.
Returns:
Two lists of tensors for weights and biases respectively.
There are 8 tensors per weight and per bias for each layer:
tensor 0-3 are applied to the input from the previous layer;
tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate;
tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate;
tensor 3 and 7 the output gate.
"""
weights, biases = [], []
assert 3 * len(transformed_weights) == len(transformed_biases)
for i in range(len(transformed_weights)):
num_units = self._cudnn_rnn.num_units
input_size = self._cudnn_rnn.input_size if i == 0 else num_units
# weights applied on layer inputs.
wi = array_ops.slice(transformed_weights[i], [0, 0],
[input_size, 4 * num_units])
# weights applied on recurrent inputs.
wr = array_ops.slice(transformed_weights[i], [input_size, 0],
[num_units, 4 * num_units])
wi_list = array_ops.split(wi, 4, axis=1)
wr_list = array_ops.split(wr, 4, axis=1)

for j in range(len(wi_list)):
wi_list[j] = array_ops.reshape(array_ops.transpose(wi_list[j]), [-1])
wr_list[j] = array_ops.reshape(array_ops.transpose(wr_list[j]), [-1])
# canonical weights are in icfo order, convert to ifco order for cuDNN.
self._switch_inner(wi_list, 0)
self._switch_inner(wr_list, 0)
weights.extend(wi_list)
weights.extend(wr_list)

base_idx = 3 * i
bi_list = array_ops.split(transformed_biases[base_idx + 1], 4, axis=0)
br_list = array_ops.split(transformed_biases[base_idx + 2], 4, axis=0)
# canonical weights are in icfo order, convert to ifco order for cuDNN.
self._switch_inner(bi_list, 0)
self._switch_inner(br_list, 0)
biases.extend(bi_list)
biases.extend(br_list)
return weights, biases
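Editorial aside, not part of the diff: a rough NumPy sketch of the per-layer stitch that _transform_lstm_canonical performs above, under the shapes spelled out in its comments. cuDNN's canonical form per layer is 8 weight matrices and 8 bias vectors in i, f, c, o order (4 acting on the layer input, 4 on the recurrent input); the transform joins each input/recurrent pair, swaps the f and c blocks to the i, c, f, o order the canonical cells use, and stacks the result into the single [input_size + num_units, 4 * num_units] kernel LSTMBlockCell expects. Three bias tensors are kept per layer (b_in + b_rec for the cell, plus b_in and b_rec separately) because the cuDNN split cannot be recovered from the sum alone; the fused tensors are saved under the canonical names multi_rnn_cell/cell_i/lstm_cell/{kernel,bias}, which is what lets a plain Saver in the canonical graph find them.

```python
import numpy as np

def stitch_lstm_layer(weights, biases, input_size, num_units):
  """NumPy sketch of one layer of _transform_lstm_canonical (illustrative only).

  weights: 8 cuDNN matrices in i, f, c, o order; weights[0:4] act on the layer
    input (each [num_units, input_size]), weights[4:8] on the recurrent input
    (each [num_units, num_units]).
  biases: 8 cuDNN vectors of shape [num_units], same ordering.
  Returns (kernel, fused_bias, b_in, b_rec) as described above.
  """
  # Join each gate's input and recurrent matrices:
  # [num_units, input_size + num_units] per gate.
  per_gate = [
      np.concatenate([weights[j].reshape(num_units, input_size),
                      weights[j + 4].reshape(num_units, num_units)], axis=1)
      for j in range(4)
  ]
  # cuDNN order is i, f, c, o; the canonical cells expect i, c, f, o.
  per_gate[1], per_gate[2] = per_gate[2], per_gate[1]
  # Stack the gates and transpose: [input_size + num_units, 4 * num_units].
  kernel = np.concatenate(per_gate, axis=0).T

  b = list(biases)
  b[1], b[2] = b[2], b[1]         # input biases: i, f, c, o -> i, c, f, o
  b[5], b[6] = b[6], b[5]         # recurrent biases: i, f, c, o -> i, c, f, o
  b_in = np.concatenate(b[0:4])   # [4 * num_units]
  b_rec = np.concatenate(b[4:8])  # [4 * num_units]
  return kernel, b_in + b_rec, b_in, b_rec
```

For the first layer input_size is the model's input size; for deeper layers it equals num_units, matching the shape comments in the code above.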
_cudnn_rnn_common_doc_string = """
Cudnn RNN has an opaque parameter buffer that can be used for inference and
@@ -199,6 +371,26 @@ class _CudnnRNN(object):
if self._seed is None and self._seed2 is None:
self._seed, self._seed2 = 0, 0

@property
def input_size(self):
return self._input_size

@property
def num_units(self):
return self._num_units

@property
def num_layers(self):
return self._num_layers

@property
def rnn_mode(self):
return self._rnn_mode

@property
def direction(self):
return self._direction

def params_size(self):
"""Calculates the size of the opaque parameter buffer needed for this model.

@@ -222,9 +414,12 @@ class _CudnnRNN(object):
"""Runs the forward step for the RNN model.

Args:
input_data: the input sequence to the RNN model.
input_h: the initial hidden state for h.
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.

@@ -308,7 +503,7 @@ class CudnnLSTM(_CudnnRNN):
num_layers,
num_units,
input_size,
input_mode="auto_select",
input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -344,9 +539,12 @@ class CudnnLSTM(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.

Args:
input_data: the input sequence to the LSTM model.
input_h: the initial hidden state for h.
input_c: the initial hidden state for c.
input_data: the input sequence to the LSTM model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. A Tensor of the same shape as
input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.

@@ -368,7 +566,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
num_layers,
num_units,
input_size,
input_mode="auto_select",
input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -390,6 +588,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
dropout: whether to enable dropout. With it is 0, dropout is disabled.
seed: the seed used for initializing dropout.
"""

super(_CudnnRNNNoInputC, self).__init__(
self._rnn_mode,
num_layers,
@@ -404,8 +603,10 @@ class _CudnnRNNNoInputC(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.

Args:
input_data: the input sequence to the LSTM model.
input_h: the initial hidden state for h.
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.

@@ -58,7 +58,7 @@ def _lstm_block_cell(x,

```python
xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
[i, ci, f, o] = xh * w + b
f = f + forget_bias

if not use_peephole:
@@ -93,7 +93,7 @@ def _lstm_block_cell(x,
The weight matrix for output gate peephole connection.
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
cell_clip: An optional `float`. Defaults to `3`.
Value to clip the 'cs' value to.
Value to clip the 'cs' value to. Disable by setting to negative value.
use_peephole: An optional `bool`. Defaults to `False`.
Whether to use peephole weights.
name: A name for the operation (optional).
@@ -341,17 +341,24 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
def __init__(self,
num_units,
forget_bias=1.0,
clip_cell=True,
use_peephole=False):
"""Initialize the basic LSTM cell.

Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
clip_cell: boolean, whether to apply cell clipping. See
`_lstm_block_cell()` for details.
use_peephole: Whether to use peephole connections or not.

When restoring from CudnnLSTM-trained checkpoints, must set the following:
forget_bias, clip_cell, use_peephole = 0, False, False
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._use_peephole = use_peephole
self._clip_cell = clip_cell
self._names = {
"W": "kernel",
"b": "bias",
@@ -400,6 +407,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
wco=wco,
wcf=wcf,
forget_bias=self._forget_bias,
cell_clip=None if self._clip_cell else -1,
use_peephole=self._use_peephole)

new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
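A brief editorial aside on the forget_bias/clip_cell coupling documented above: the canonical cells add forget_bias as a constant on top of the learned bias (the `f = f + forget_bias` step in the _lstm_block_cell pseudo-code), while a cuDNN-trained checkpoint already folds any forget-gate offset into the learned bias itself. A tiny sketch of the forget gate, in standard LSTM notation rather than library code, shows why a non-zero default would shift every restored gate:

```python
import numpy as np

def forget_gate(x, h_prev, w_f, b_f, forget_bias=0.0):
  """Forget-gate activation as the canonical cells compute it (sketch only)."""
  logits = np.concatenate([x, h_prev]) @ w_f + b_f
  # `forget_bias` is an extra constant on top of the learned bias b_f.
  return 1.0 / (1.0 + np.exp(-(logits + forget_bias)))

# A CudnnLSTM checkpoint already carries its forget-gate offset inside b_f, so
# the cells it is restored into must use forget_bias=0 (and, for LSTMBlockCell,
# clip_cell=False, since the cuDNN path applies no cell clipping) to reproduce
# the cuDNN computation exactly.
```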

@@ -345,6 +345,8 @@ class BasicLSTMCell(RNNCell):
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must set to `0.0` manually when restoring from CudnnLSTM-trained
checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
@@ -444,7 +446,8 @@ class LSTMCell(RNNCell):
Use a variable_scope partitioner instead.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training.
the training. Must set it manually to `0.0` when restoring from
CudnnLSTM trained checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.