Merge pull request #24812 from houtoms:google-cudnn-rnn-add-time-major

PiperOrigin-RevId: 237520914
Committed by TensorFlower Gardener on 2019-03-08 14:44:25 -08:00, commit d3b9ce5b4b.
15 changed files with 539 additions and 268 deletions
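
In short, this change threads a time_major flag through the CuDNN RNN tests, Python wrappers, op definitions, and GPU kernels, so callers can feed batch-major tensors directly instead of transposing to time-major form first. As a minimal layout illustration (our own sketch, not part of the diff):

import numpy as np

time_len, batch_size, input_size = 5, 3, 4

# Time-major layout (previous and still-default behavior): axis 0 is time.
inputs_tm = np.random.rand(time_len, batch_size, input_size)

# Batch-major layout (enabled by time_major=False): axis 0 is the batch.
inputs_bm = np.transpose(inputs_tm, (1, 0, 2))

print(inputs_tm.shape)  # (5, 3, 4) == [max_time, batch_size, input_size]
print(inputs_bm.shape)  # (3, 5, 4) == [batch_size, max_time, input_size]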


@ -69,6 +69,8 @@ def RunLSTM(sess,
time,
num_layers=1,
variable_seq_lengths=False,
time_major=True,
dynamic_shape_input=False,
is_training=True,
dropout=0.,
num_dirs=True,
@ -84,11 +86,14 @@ def RunLSTM(sess,
random_seed.set_random_seed(0)
np.random.seed(0)
inputs = variable_scope.get_variable(
"inputs",
initializer=np.random.rand(time, batch_size,
input_size).astype(dtype.as_numpy_dtype),
dtype=dtype)
shape = ([time, batch_size, input_size]
if time_major else [batch_size, time, input_size])
inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
inputs_static = variable_scope.get_variable(
"inputs", initializer=inputs_np, dtype=dtype)
inputs_dynamic = array_ops.placeholder(
dtype, shape=[None, None, None], name="inputs")
inputs = inputs_dynamic if dynamic_shape_input else inputs_static
initial_h_op = variable_scope.get_variable(
"initial_h_op",
initializer=np.random.rand(batch_size,
@ -122,12 +127,12 @@ def RunLSTM(sess,
cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True)
outputs_op, state_tuple_op = rnn.dynamic_rnn(
cell,
inputs,
inputs_static,
sequence_length=lengths,
initial_state=rnn_cell_impl.LSTMStateTuple(
h=initial_h_op, c=initial_c_op),
dtype=dtype,
time_major=True,
time_major=time_major,
scope=None)
# Convert to cudnn opaque param.
@ -135,35 +140,38 @@ def RunLSTM(sess,
num_layers, num_units, input_size)
opaque_params = format_converter.tf_canonical_to_opaque([w, b])
cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0)
cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0)
cu_initial_h_op = array_ops.expand_dims(
initial_h_op, axis=(0 if time_major else 1))
cu_initial_c_op = array_ops.expand_dims(
initial_c_op, axis=(0 if time_major else 1))
cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn(
inputs,
cu_initial_h_op,
cu_initial_c_op,
opaque_params,
sequence_lengths=lengths,
time_major=time_major,
dropout=dropout,
is_training=is_training,
rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
# Remove the trivial 1st dimension.
cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple(
c=array_ops.squeeze(cu_c_op, axis=0),
h=array_ops.squeeze(cu_h_op, axis=0))
c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1),
h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1))
if is_training:
(inp_grad_op, hgrad_op,
cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
outputs_op, [inputs, initial_h_op, initial_c_op, w, b])
outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
(cu_inp_grad_op, cu_hgrad_op,
cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients(
cu_outputs_op,
[inputs, cu_initial_h_op, cu_initial_c_op, opaque_params])
# Remove the trivial 1st dimension
cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0)
cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1)
# Remove the trivial 1st dimension
cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0)
cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1)
cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
opaque_grad_op)
@ -183,10 +191,12 @@ def RunLSTM(sess,
(hgrad_op, cgrad_op), wgrad_op, bgrad_op
])
(cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
cu_bgrad) = sess.run([
cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
(cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
])
cu_bgrad) = sess.run(
[
cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
(cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
],
feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
logging.vlog(1, "outputs: %s" % outputs)
logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@ -205,7 +215,10 @@ def RunLSTM(sess,
cu_bgrad)
else:
outputs, state_tuple = sess.run([outputs_op, state_tuple_op])
cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op])
cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op],
feed_dict=({
inputs: inputs_np
} if dynamic_shape_input else None))
logging.vlog(1, "outputs: %s" % outputs)
logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@ -336,6 +349,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
dtype,
variable_seq_lengths,
time_major,
dynamic_shape_input=False,
rtol=3e-6,
atol=3e-6):
with self.session(use_gpu=True) as sess:
@ -347,7 +362,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
batch_size,
time,
num_layers,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
for s, cu_s in zip(state_tuple, cu_state_tuple):
@ -359,13 +376,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_training(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths):
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
@ -375,16 +395,22 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
time,
num_layers,
dtypes.float32,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_training_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
@ -396,16 +422,21 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
dtypes.float16,
rtol=5e-3,
atol=5e-4,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths):
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
@ -417,7 +448,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
time,
num_layers,
is_training=False,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(outputs, cu_outputs)
# h
@ -426,13 +459,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(state_tuple.c, cu_state_tuple.c)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
@ -445,7 +482,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dtype=dtypes.float16,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
rtol, atol = 5e-3, 5e-4
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
@ -457,13 +496,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
"""Validates that dropout does not affect Cudnn Rnn inference."""
if not context.context().num_gpus():
self.skipTest("No GPUs found")
@ -480,7 +523,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dropout=0.,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
with ops.Graph().as_default() as g:
with self.session(use_gpu=True, graph=g) as sess:
@ -493,7 +538,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dropout=1.,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(cu_outputs, cu_outputs2)
# h
@ -510,6 +557,8 @@ def RunGRU(sess,
num_layers=1,
is_training=True,
variable_seq_lengths=False,
time_major=True,
dynamic_shape_input=False,
dropout=0.,
num_dirs=True,
dtype=dtypes.float32):
@ -524,11 +573,14 @@ def RunGRU(sess,
random_seed.set_random_seed(0)
np.random.seed(0)
inputs = variable_scope.get_variable(
"inputs",
initializer=np.random.rand(time, batch_size,
input_size).astype(dtype.as_numpy_dtype),
dtype=dtype)
shape = ([time, batch_size, input_size]
if time_major else [batch_size, time, input_size])
inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
inputs_static = variable_scope.get_variable(
"inputs", initializer=inputs_np, dtype=dtype)
inputs_dynamic = array_ops.placeholder(
dtype, shape=[None, None, None], name="inputs")
inputs = inputs_dynamic if dynamic_shape_input else inputs_static
initial_h_op = variable_scope.get_variable(
"initial_h_op",
initializer=np.random.rand(batch_size,
@ -573,11 +625,11 @@ def RunGRU(sess,
cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True)
outputs_op, h_op = rnn.dynamic_rnn(
cell,
inputs,
inputs_static,
sequence_length=lengths,
initial_state=initial_h_op,
dtype=dtype,
time_major=True,
time_major=time_major,
scope=None)
ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel]
@ -588,13 +640,15 @@ def RunGRU(sess,
opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0)
cu_initial_h_op = array_ops.expand_dims(
initial_h_op, axis=(0 if time_major else 1))
cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn(
inputs,
cu_initial_h_op,
array_ops.zeros_like(cu_initial_h_op), # not used
opaque_params,
sequence_lengths=lengths,
time_major=time_major,
dropout=dropout,
is_training=is_training,
rnn_mode=cudnn_rnn_ops.CUDNN_GRU)
@ -602,12 +656,12 @@ def RunGRU(sess,
if is_training:
(inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op,
cib_grad_op, chb_grad_op) = gradients_impl.gradients(
outputs_op, [inputs, initial_h_op] + ws + bs)
outputs_op, [inputs_static, initial_h_op] + ws + bs)
(cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients(
cu_outputs_op, [inputs, cu_initial_h_op, opaque_params])
# Remove the trivial 1st dimension
cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0)
cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1)
cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
opaque_grad_op)
@ -627,13 +681,15 @@ def RunGRU(sess,
(gk_grad_op, cik_grad_op, chk_grad_op),
(gb_grad_op, cib_grad_op, chb_grad_op)
])
(cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([
cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op,
(cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op),
(cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op)
])
(cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run(
[
cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op,
(cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op),
(cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op)
],
feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
# Remove the trivial 1st dimension
cu_h = np.squeeze(cu_h, axis=0)
cu_h = np.squeeze(cu_h, axis=0 if time_major else 1)
logging.vlog(1, "outputs: %s" % outputs)
logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@ -651,9 +707,12 @@ def RunGRU(sess,
cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad)
else:
outputs, h = sess.run([outputs_op, h_op])
cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op])
cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op],
feed_dict=({
inputs: inputs_np
} if dynamic_shape_input else None))
# Remove the trivial 1st dimension.
cu_h = np.squeeze(cu_h, axis=0)
cu_h = np.squeeze(cu_h, axis=0 if time_major else 1)
logging.vlog(1, "outputs: %s" % outputs)
logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@ -672,6 +731,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
dtype,
variable_seq_lengths,
time_major,
dynamic_shape_input=False,
rtol=3e-6,
atol=3e-6):
with self.session(use_gpu=True) as sess:
@ -683,7 +744,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
batch_size,
time,
num_layers,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
self.assertAllClose(h, cu_h, rtol=rtol, atol=atol)
@ -695,13 +758,16 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_training(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths):
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
@ -711,16 +777,22 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
time,
num_layers,
dtypes.float32,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_training_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
self._test_training_helper(
@ -732,16 +804,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
dtypes.float16,
rtol=5e-3,
atol=5e-4,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
variable_seq_lengths):
variable_seq_lengths, time_major, dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
@ -753,18 +830,24 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
time,
num_layers,
is_training=False,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(outputs, cu_outputs)
self.assertAllClose(h, cu_h)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference_fp16(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
if not context.context().num_gpus():
self.skipTest("No GPUs found")
with self.session(use_gpu=True) as sess:
@ -777,20 +860,26 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dtype=dtypes.float16,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
rtol, atol = 5e-3, 5e-4
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
self.assertAllClose(h, cu_h, rtol=rtol, atol=atol)
@parameterized.named_parameters(
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
}))
ExpandNamedTestCases(
NAMED_RNN_TESTCASES, **{
"variable_seq_lengths": [True, False],
"time_major": [True, False],
"dynamic_shape_input": [True, False],
}))
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
num_layers, variable_seq_lengths):
num_layers, variable_seq_lengths, time_major,
dynamic_shape_input):
"""Validates that dropout does not affect Cudnn Rnn inference."""
# Hand-picked dropouts are used below (0. and 1.)
if not context.context().num_gpus():
@ -807,7 +896,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dropout=0.,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
with ops.Graph().as_default() as g:
with self.session(use_gpu=True, graph=g) as sess:
@ -820,7 +911,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
num_layers,
is_training=False,
dropout=1.,
variable_seq_lengths=variable_seq_lengths)
variable_seq_lengths=variable_seq_lengths,
time_major=time_major,
dynamic_shape_input=dynamic_shape_input)
self.assertAllClose(cu_outputs, cu_outputs2)
self.assertAllClose(cu_h[0], cu_h2[0])
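
The parameterizations above now cross every named RNN test case with variable_seq_lengths, time_major, and dynamic_shape_input. A rough stand-in for what that expansion amounts to, assuming ExpandNamedTestCases builds the cross-product of flag values over NAMED_RNN_TESTCASES (hypothetical reimplementation, for illustration only; the real helper may differ):

import itertools

def expand_named_test_cases(cases, **flags):
  """Cross each named case with every combination of the given flag values."""
  expanded = []
  keys = sorted(flags)
  for case in cases:
    for values in itertools.product(*(flags[k] for k in keys)):
      new_case = dict(case)
      new_case.update(zip(keys, values))
      # parameterized.named_parameters keys each case by "testcase_name".
      new_case["testcase_name"] = case["testcase_name"] + "".join(
          "_%s_%s" % (k, v) for k, v in zip(keys, values))
      expanded.append(new_case)
  return expanded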


@ -378,20 +378,33 @@ class _CudnnRNN(base_layer.Layer):
inputs,
initial_state=None,
sequence_lengths=None,
time_major=True,
training=True):
"""Runs the forward step for the RNN model.
Args:
inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
inputs: `3-D` tensor. If `time_major` is True (default), the Tensor shape
is [time_len, batch_size, input_size]. If `time_major` is False, the
shape is [batch_size, time_len, input_size].
initial_state: a tuple of tensor(s) of shape
`[num_layers * num_dirs, batch_size, num_units]`. If not provided, use
`[num_layers * num_dirs, batch_size, num_units]` if
`time_major` is True (default) or `[batch_size, num_layers * num_dirs,
num_units]` if `time_major` is False. If not provided, use
zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs.
sequence_lengths: an int32 array representing the variable sequence
lengths in a batch. The size of the array has to equal the
batch_size. If not provided, the same sequence length will be assumed.
time_major: The shape format of the `inputs` and `outputs` Tensors. If
true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
If false, these Tensors must be shaped ['batch_size', 'max_time',
'depth']. By default this function accepts input and emits output in
time-major form. This param is only effective when 'sequence_lengths'
is used.
training: whether this operation will be used in training or inference.
Returns:
output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`.
output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`
if `time_major` is True (default) or `[batch_size, time_len,
num_dirs * num_units]` if `time_major` is False.
It is a `concat([fwd_output, bak_output], axis=2)`.
output_states: a tuple of tensor(s) of the same shape and structure as
`initial_state`.
@ -417,8 +430,8 @@ class _CudnnRNN(base_layer.Layer):
else:
# For model that doesn't take input_c, replace with a dummy tensor.
c = array_ops.constant([], dtype=dtype)
outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel,
sequence_lengths, training)
outputs, (output_h, output_c) = self._forward(
inputs, h, c, self.kernel, sequence_lengths, time_major, training)
if self._rnn_mode == CUDNN_LSTM:
return outputs, (output_h, output_c)
else:
@ -482,7 +495,8 @@ class _CudnnRNN(base_layer.Layer):
dropout=self._dropout,
direction=self._direction)
def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training):
def _forward(self, inputs, h, c, opaque_params, sequence_lengths, time_major,
training):
output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access
inputs,
h,
@ -491,6 +505,7 @@ class _CudnnRNN(base_layer.Layer):
training,
self._rnn_mode,
sequence_lengths=sequence_lengths,
time_major=time_major,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
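
With the flag in place, the layer-level API can consume batch-major input directly. A sketch of a call site, assuming TF 1.x with tf.contrib.cudnn_rnn available and a CUDA-enabled GPU (our example, not from the diff):

import tensorflow as tf
from tensorflow.contrib.cudnn_rnn import CudnnLSTM

batch_size, max_time, input_size = 8, 10, 32
lstm = CudnnLSTM(num_layers=1, num_units=64)

# Batch-major input: [batch_size, max_time, input_size].
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
lengths = tf.fill([batch_size], max_time)

# time_major=False tells the layer (and the underlying CudnnRNNV3 op) to
# accept and emit batch-major tensors; per the docstring it takes effect
# when sequence_lengths is supplied.
outputs, (h, c) = lstm(
    inputs, sequence_lengths=lengths, time_major=False, training=True)
# outputs: [batch_size, max_time, num_dirs * num_units]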


@ -956,6 +956,7 @@ def _cudnn_rnn(inputs,
is_training,
rnn_mode,
sequence_lengths=None,
time_major=True,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
@ -964,10 +965,12 @@ def _cudnn_rnn(inputs,
"""Cudnn RNN.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
@ -977,6 +980,11 @@ def _cudnn_rnn(inputs,
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
@ -1017,6 +1025,14 @@ def _cudnn_rnn(inputs,
}
if sequence_lengths is not None:
args["sequence_lengths"] = sequence_lengths
args["time_major"] = time_major
outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
elif time_major is False:
batch_size = array_ops.shape(inputs)[0]
max_time = array_ops.shape(inputs)[1]
sequence_lengths = array_ops.fill([batch_size], max_time)
args["sequence_lengths"] = sequence_lengths
args["time_major"] = time_major
outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
elif use_cudnn_v2 != "1":
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
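
The dispatch above is the crux of the wrapper change: CudnnRNNV3 is the only kernel variant that carries a time_major attribute, so it is selected whenever sequence lengths are supplied, and also when the caller requests batch-major input without lengths; in that second case full-length sequence lengths are synthesized. A condensed restatement of that control flow (illustrative only, not the diff itself):

def _select_cudnn_kernel(args, sequence_lengths, time_major, batch_size,
                         max_time):
  """Illustrative restatement of the kernel dispatch above."""
  if sequence_lengths is not None:
    # Variable-length batches need the V3 op, which understands both
    # time-major and batch-major layouts.
    args["sequence_lengths"] = sequence_lengths
    args["time_major"] = time_major
    return "CudnnRNNV3"
  if not time_major:
    # Batch-major without explicit lengths: fabricate full-length
    # sequence_lengths so the V3 op can still be used.
    args["sequence_lengths"] = [max_time] * batch_size
    args["time_major"] = time_major
    return "CudnnRNNV3"
  # Time-major, fixed-length input keeps the CudnnRNN / CudnnRNNV2 path.
  return "CudnnRNN"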
@ -1031,6 +1047,7 @@ def cudnn_lstm(inputs,
params,
is_training,
sequence_lengths=None,
time_major=True,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
@ -1039,15 +1056,26 @@ def cudnn_lstm(inputs,
"""Cudnn LSTM.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
sequence_lengths: an int32 array representing the variable sequence lengths
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
'linear_input' (default) always applies a linear projection of input
@ -1060,17 +1088,13 @@ def cudnn_lstm(inputs,
dropout: whether to enable dropout. When it is 0, dropout is disabled.
seed: the op seed used for initializing dropout. See `tf.set_random_seed`
for behavior.
sequence_lengths: an int32 array representing the variable sequence lengths
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
name: name of the operation.
Returns:
outputs, output_h, output_c
"""
return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
sequence_lengths, input_mode, direction, dropout, seed,
name)
sequence_lengths, time_major, input_mode, direction,
dropout, seed, name)
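
Note that the layout flag changes the expected shape of the state tensors as well as the inputs. A shape-only sketch of what a caller of cudnn_lstm would allocate for each layout, for a single-layer unidirectional LSTM (our example values):

import numpy as np

batch_size, max_time, input_size, num_units = 8, 10, 32, 64

# time_major=True (default):
inputs_tm = np.zeros([max_time, batch_size, input_size], np.float32)
h_tm = np.zeros([1, batch_size, num_units], np.float32)  # [num_layers, batch, units]

# time_major=False:
inputs_bm = np.zeros([batch_size, max_time, input_size], np.float32)
h_bm = np.zeros([batch_size, 1, num_units], np.float32)  # [batch, num_layers, units]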
def _cudnn_rnn_no_input_c(inputs,
@ -1079,6 +1103,7 @@ def _cudnn_rnn_no_input_c(inputs,
is_training,
rnn_mode,
sequence_lengths=None,
time_major=True,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
@ -1087,10 +1112,12 @@ def _cudnn_rnn_no_input_c(inputs,
"""Cudnn RNN w/o input_c.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
@ -1098,6 +1125,11 @@ def _cudnn_rnn_no_input_c(inputs,
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
input_mode: indicate whether there is a linear projection between the
input and the actual computation before the first layer. It could be
'linear_input', 'skip_input' or 'auto_select'.
@ -1116,9 +1148,9 @@ def _cudnn_rnn_no_input_c(inputs,
outputs, output_h
"""
input_c = array_ops.constant([], dtype=input_h.dtype)
outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
is_training, rnn_mode, sequence_lengths,
input_mode, direction, dropout, seed, name)
outputs, output_h, _ = _cudnn_rnn(
inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths,
time_major, input_mode, direction, dropout, seed, name)
return outputs, output_h
@ -1127,6 +1159,7 @@ def cudnn_gru(inputs,
params,
is_training,
sequence_lengths=None,
time_major=True,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
@ -1135,10 +1168,12 @@ def cudnn_gru(inputs,
"""Cudnn GRU.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
@ -1153,6 +1188,11 @@ def cudnn_gru(inputs,
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
direction: the direction model that the model operates. Could be either
'unidirectional' or 'bidirectional'
dropout: whether to enable dropout. When it is 0, dropout is disabled.
@ -1163,8 +1203,8 @@ def cudnn_gru(inputs,
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
sequence_lengths, input_mode, direction, dropout,
seed, name)
sequence_lengths, time_major, input_mode,
direction, dropout, seed, name)
def cudnn_rnn_relu(inputs,
@ -1176,14 +1216,17 @@ def cudnn_rnn_relu(inputs,
dropout=0.,
seed=0,
sequence_lengths=None,
time_major=True,
name=None):
"""Cudnn RNN Relu.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
@ -1201,14 +1244,19 @@ def cudnn_rnn_relu(inputs,
sequence_lengths: an int32 array representing the variable sequence lengths
in a batch. The size of the array has to equal the batch_size. If not
provided, the same sequence length will be assumed.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
name: name of the operation.
Returns:
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_RELU, sequence_lengths, input_mode,
direction, dropout, seed, name)
CUDNN_RNN_RELU, sequence_lengths, time_major,
input_mode, direction, dropout, seed, name)
def cudnn_rnn_tanh(inputs,
@ -1216,6 +1264,7 @@ def cudnn_rnn_tanh(inputs,
params,
is_training,
sequence_lengths=None,
time_major=True,
input_mode=CUDNN_INPUT_LINEAR_MODE,
direction=CUDNN_RNN_UNIDIRECTION,
dropout=0.,
@ -1224,10 +1273,12 @@ def cudnn_rnn_tanh(inputs,
"""Cudnn RNN Tanh.
Args:
inputs: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
inputs: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference
input_mode: indicate whether there is a linear projection between the
@ -1242,6 +1293,11 @@ def cudnn_rnn_tanh(inputs,
in a batch. The size of the array has to equal the batch_size. Default to
None, in which case sequences in the batch are assumed to have the same
length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
By default this function accepts input and emits output in time-major
form. This param is only effective when 'sequence_lengths' is used.
direction: the direction model that the model operates. Could be either
'unidirectional' or 'bidirectional'
dropout: whether to enable dropout. When it is 0, dropout is disabled.
@ -1252,8 +1308,8 @@ def cudnn_rnn_tanh(inputs,
outputs, output_h
"""
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
CUDNN_RNN_TANH, sequence_lengths, input_mode,
direction, dropout, seed, name)
CUDNN_RNN_TANH, sequence_lengths, time_major,
input_mode, direction, dropout, seed, name)
def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
@ -1537,22 +1593,32 @@ class _CudnnRNN(object):
input_c,
params,
is_training=True,
sequence_lengths=None):
sequence_lengths=None,
time_major=True):
"""Runs the forward step for the RNN model.
Args:
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
A Tensor of the same shape as input_h.
input_data: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM. A
Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
sequence_lengths: an int32 array representing the variable sequence
lengths in a batch. The size of the array has to equal the batch_size.
Default to None, in which case sequences in the batch are assumed to
have the same length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If
true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
If false, these Tensors must be shaped ['batch_size', 'max_time',
'depth']. By default this function accepts input and emits output in
time-major form. This param is only effective when 'sequence_lengths' is
used.
Returns:
output: the output sequence.
output_h: the final state for h.
@ -1566,6 +1632,7 @@ class _CudnnRNN(object):
is_training,
self._rnn_mode,
sequence_lengths=sequence_lengths,
time_major=time_major,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,
@ -1666,14 +1733,17 @@ class CudnnLSTM(_CudnnRNN):
input_c,
params,
sequence_lengths=None,
time_major=True,
is_training=True):
"""Runs the forward step for the Cudnn LSTM model.
Args:
input_data: the input sequence to the LSTM model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_data: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
input_c: the initial hidden state for c. A Tensor of the same shape as
input_h.
params: the parameter buffer created for this model.
@ -1681,6 +1751,12 @@ class CudnnLSTM(_CudnnRNN):
lengths in a batch. The size of the array has to equal the batch_size.
Default to None, in which case sequences in the batch are assumed to
have the same length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If
true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
If false, these Tensors must be shaped ['batch_size', 'max_time',
'depth']. By default this function accepts input and emits output in
time-major form. This param is only effective when 'sequence_lengths'
is used.
is_training: whether this operation will be used in training or inference.
Returns:
output: the output sequence.
@ -1693,6 +1769,7 @@ class CudnnLSTM(_CudnnRNN):
input_c,
params,
sequence_lengths=sequence_lengths,
time_major=time_major,
is_training=is_training)
return (output, output_h, output_c)
@ -1752,19 +1829,28 @@ class _CudnnRNNNoInputC(_CudnnRNN):
input_h,
params,
sequence_lengths=None,
time_major=True,
is_training=True):
"""Runs the forward step for the Cudnn LSTM model.
Args:
input_data: the input sequence to the RNN model. A Tensor of shape [?,
batch_size, input_size].
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
batch_size, num_units].
input_data: the input sequence to the RNN model. If `time_major` is True
(default), the Tensor shape is [max_time, batch_size, input_size]. If
`time_major` is False, the shape is [batch_size, max_time, input_size].
input_h: the initial hidden state for h. If `time_major` is True
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
`time_major` is False, the shape is [batch_size, num_layers, num_units].
params: the parameter buffer created for this model.
sequence_lengths: an int32 array representing the variable sequence
lengths in a batch. The size of the array has to equal the batch_size.
Default to None, in which case sequences in the batch are assumed to
have the same length, which is inferred from inputs.
time_major: The shape format of the `inputs` and `outputs` Tensors. If
true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
If false, these Tensors must be shaped ['batch_size', 'max_time',
'depth']. By default this function accepts input and emits output in
time-major form. This param is only effective when 'sequence_lengths'
is used.
is_training: whether this operation will be used in training or inference.
Returns:
output: the output sequence.
@ -1777,6 +1863,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
is_training,
self._rnn_mode,
sequence_lengths=sequence_lengths,
time_major=time_major,
input_mode=self._input_mode,
direction=self._direction,
dropout=self._dropout,


@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be
dropout: Dropout probability. When set to 0., dropout is disabled.
seed: The 1st part of a seed to initialize dropout.
seed2: The 2nd part of a seed to initialize dropout.
input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input: If time_major is true, this is a 3-D tensor with the shape of
[seq_length, batch_size, input_size]. If time_major is false, the shape is
[batch_size, seq_length, input_size].
input_h: If time_major is true, this is a 3-D tensor with the shape of
[num_layer * dir, batch_size, num_units]. If time_major is false, the shape
is [batch_size, num_layer * dir, num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
params: A 1-D tensor that contains the weights and biases in an opaque layout.
@ -26,8 +29,9 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout.
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
sequence_lengths: a vector of lengths of each input sequence.
output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
output: If time_major is true, this is a 3-D tensor with the shape of
[seq_length, batch_size, dir * num_units]. If time_major is false, the
shape is [batch_size, seq_length, dir * num_units].
output_h: The same shape as input_h.
output_c: The same shape as input_c for LSTM. An empty tensor for other models.
output_backprop: A 3-D tensor with the same shape as output in the forward pass.
@ -35,6 +39,8 @@ output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
pass.
output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
pass.
time_major: Indicates whether the input/output format is time major or batch
major.
reserve_space: The same reserve_space produced in the forward operation.
input_backprop: The backprop to input in the forward pass. Has the same shape
as input.


@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be
dropout: Dropout probability. When set to 0., dropout is disabled.
seed: The 1st part of a seed to initialize dropout.
seed2: The 2nd part of a seed to initialize dropout.
input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input: If time_major is true, this is a 3-D tensor with the shape of
[seq_length, batch_size, input_size]. If time_major is false, the shape is
[batch_size, seq_length, input_size].
input_h: If time_major is true, this is a 3-D tensor with the shape of
[num_layer * dir, batch_size, num_units]. If time_major is false, the shape
is [batch_size, num_layer * dir, num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
params: A 1-D tensor that contains the weights and biases in an opaque layout.
@ -26,12 +29,15 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout.
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
sequence_lengths: a vector of lengths of each input sequence.
output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
output: If time_major is true, this is a 3-D tensor with the shape of
[seq_length, batch_size, dir * num_units]. If time_major is false, the
shape is [batch_size, seq_length, dir * num_units].
output_h: The same shape as input_h.
output_c: The same shape as input_c for LSTM. An empty tensor for other models.
is_training: Indicates whether this operation is used for inference or
training.
time_major: Indicates whether the input/output format is time major or batch
major.
reserve_space: An opaque tensor that can be used in backprop calculation. It
is only produced if is_training is true.
END


@ -559,7 +559,7 @@ struct RnnScratchSpace {
// Extracts and checks the forward input tensors, parameters, and shapes from the
// OpKernelContext.
Status ExtractForwardInput(OpKernelContext* context,
const CudnnModelTypes& model_types,
const CudnnModelTypes& model_types, bool time_major,
const Tensor** input, const Tensor** input_h,
const Tensor** input_c, const Tensor** params,
CudnnRnnModelShapes* model_shapes) {
@ -573,8 +573,13 @@ Status ExtractForwardInput(OpKernelContext* context,
if ((*input)->dims() != 3) {
return errors::InvalidArgument("RNN input must be a 3-D vector.");
}
model_shapes->max_seq_length = (*input)->dim_size(0);
model_shapes->batch_size = (*input)->dim_size(1);
if (time_major) {
model_shapes->max_seq_length = (*input)->dim_size(0);
model_shapes->batch_size = (*input)->dim_size(1);
} else {
model_shapes->max_seq_length = (*input)->dim_size(1);
model_shapes->batch_size = (*input)->dim_size(0);
}
model_shapes->input_size = (*input)->dim_size(2);
model_shapes->input_shape = (*input)->shape();
model_shapes->dir_count =
@ -585,12 +590,25 @@ Status ExtractForwardInput(OpKernelContext* context,
if ((*input_h)->dims() != 3) {
return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
}
model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
if (time_major) {
model_shapes->num_layers =
(*input_h)->dim_size(0) / model_shapes->dir_count;
} else {
model_shapes->num_layers =
(*input_h)->dim_size(1) / model_shapes->dir_count;
}
model_shapes->num_units = (*input_h)->dim_size(2);
model_shapes->hidden_state_shape =
TensorShape({model_shapes->dir_count * model_shapes->num_layers,
model_shapes->batch_size, model_shapes->num_units});
if (time_major) {
model_shapes->hidden_state_shape =
TensorShape({model_shapes->dir_count * model_shapes->num_layers,
model_shapes->batch_size, model_shapes->num_units});
} else {
model_shapes->hidden_state_shape =
TensorShape({model_shapes->batch_size,
model_shapes->dir_count * model_shapes->num_layers,
model_shapes->num_units});
}
if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
return errors::InvalidArgument(
"Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
@ -604,23 +622,28 @@ Status ExtractForwardInput(OpKernelContext* context,
(*input_c)->shape().DebugString());
}
}
model_shapes->output_shape =
TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
model_shapes->dir_count * model_shapes->num_units});
if (time_major) {
model_shapes->output_shape =
TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
model_shapes->dir_count * model_shapes->num_units});
} else {
model_shapes->output_shape =
TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
model_shapes->dir_count * model_shapes->num_units});
}
return Status::OK();
}
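
The same shape bookkeeping, restated in Python for readability (an illustrative mirror of ExtractForwardInput above, not code from the diff):

def extract_model_shapes(input_shape, input_h_shape, dir_count, time_major):
  """Mirror of the C++ shape extraction above, for illustration."""
  if time_major:
    max_seq_length, batch_size = input_shape[0], input_shape[1]
  else:
    batch_size, max_seq_length = input_shape[0], input_shape[1]
  input_size = input_shape[2]

  # The layer dimension of input_h moves with the layout.
  layer_axis = 0 if time_major else 1
  num_layers = input_h_shape[layer_axis] // dir_count
  num_units = input_h_shape[2]

  if time_major:
    hidden_state_shape = [dir_count * num_layers, batch_size, num_units]
    output_shape = [max_seq_length, batch_size, dir_count * num_units]
  else:
    hidden_state_shape = [batch_size, dir_count * num_layers, num_units]
    output_shape = [batch_size, max_seq_length, dir_count * num_units]
  return dict(max_seq_length=max_seq_length, batch_size=batch_size,
              input_size=input_size, num_layers=num_layers,
              num_units=num_units, hidden_state_shape=hidden_state_shape,
              output_shape=output_shape)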
// Extracts and checks the sequence_lengths, forward input tensors,
// parameters, and shapes from the OpKernelContext.
// Overloaded function to process the sequence_lengths
Status ExtractForwardInput(OpKernelContext* context,
const CudnnModelTypes& model_types,
const CudnnModelTypes& model_types, bool time_major,
const Tensor** input, const Tensor** input_h,
const Tensor** input_c, const Tensor** params,
CudnnRnnModelShapes* model_shapes,
const Tensor** sequence_lengths) {
const Tensor** sequence_lengths,
CudnnRnnModelShapes* model_shapes) {
TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
return ExtractForwardInput(context, model_types, input, input_h, input_c,
params, model_shapes);
return ExtractForwardInput(context, model_types, time_major, input, input_h,
input_c, params, model_shapes);
}
template <typename T>
@ -629,7 +652,7 @@ Status CreateForwardAndBackwardIODescriptors(
std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
const absl::Span<const int>& seq_lengths) {
const absl::Span<const int>& seq_lengths, bool time_major) {
StreamExecutor* executor = context->op_device_context()->stream()->parent();
se::dnn::DataType data_type = ToDataType<T>::value;
@ -639,11 +662,19 @@ Status CreateForwardAndBackwardIODescriptors(
DCHECK_EQ(input_shape.dims(), 3);
if (seq_lengths.data() != nullptr) {
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(0), input_shape.dim_size(1),
input_shape.dim_size(2), seq_lengths, data_type);
TF_RETURN_IF_ERROR(input_desc_s.status());
*input_desc = input_desc_s.ConsumeValueOrDie();
if (time_major) {
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(0), input_shape.dim_size(1),
input_shape.dim_size(2), seq_lengths, time_major, data_type);
TF_RETURN_IF_ERROR(input_desc_s.status());
*input_desc = input_desc_s.ConsumeValueOrDie();
} else {
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(1), input_shape.dim_size(0),
input_shape.dim_size(2), seq_lengths, time_major, data_type);
TF_RETURN_IF_ERROR(input_desc_s.status());
*input_desc = input_desc_s.ConsumeValueOrDie();
}
} else {
auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
input_shape.dim_size(0), input_shape.dim_size(1),
@ -653,19 +684,35 @@ Status CreateForwardAndBackwardIODescriptors(
}
DCHECK_EQ(hidden_state_shape.dims(), 3);
auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
hidden_state_shape.dim_size(2), data_type);
TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
*state_desc = hidden_state_desc_s.ConsumeValueOrDie();
if (time_major) {
auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
hidden_state_shape.dim_size(2), data_type);
TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
*state_desc = hidden_state_desc_s.ConsumeValueOrDie();
} else {
auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
hidden_state_shape.dim_size(2), data_type);
TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
*state_desc = hidden_state_desc_s.ConsumeValueOrDie();
}
DCHECK_EQ(output_shape.dims(), 3);
if (seq_lengths.data() != nullptr) {
auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
output_shape.dim_size(0), output_shape.dim_size(1),
output_shape.dim_size(2), seq_lengths, data_type);
TF_RETURN_IF_ERROR(output_desc_s.status());
*output_desc = output_desc_s.ConsumeValueOrDie();
if (time_major) {
auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
output_shape.dim_size(0), output_shape.dim_size(1),
output_shape.dim_size(2), seq_lengths, time_major, data_type);
TF_RETURN_IF_ERROR(output_desc_s.status());
*output_desc = output_desc_s.ConsumeValueOrDie();
} else {
auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
output_shape.dim_size(1), output_shape.dim_size(0),
output_shape.dim_size(2), seq_lengths, time_major, data_type);
TF_RETURN_IF_ERROR(output_desc_s.status());
*output_desc = output_desc_s.ConsumeValueOrDie();
}
} else {
auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
output_shape.dim_size(0), output_shape.dim_size(1),
@ -687,7 +734,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
const bool is_training,
/* forward outputs, outputs of the function */
Tensor* output, Tensor* output_h, Tensor* output_c,
const Tensor* sequence_lengths,
const Tensor* sequence_lengths, bool time_major,
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
ProfileResult* output_profile_result) {
@ -702,7 +749,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
}
TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
context, model_shapes, &input_desc, &state_desc, &output_desc,
seq_lengths));
seq_lengths, time_major));
auto input_data = AsDeviceMemory<T>(input);
auto input_h_data = AsDeviceMemory<T>(input_h);
@ -750,7 +797,7 @@ Status DoBackward(
const Tensor* output_c_backprop, const Tensor* reserve_space,
/* backprop outputs, output of the function */
Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
Tensor* params_backprop, const Tensor* sequence_lengths,
Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
ScratchAllocator* workspace_allocator,
ProfileResult* output_profile_result) {
std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
@ -764,7 +811,7 @@ Status DoBackward(
}
TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
context, model_shapes, &input_desc, &state_desc, &output_desc,
seq_lengths));
seq_lengths, time_major));
auto input_data = AsDeviceMemory<T>(input);
auto input_h_data = AsDeviceMemory<T>(input_h);
@ -1216,13 +1263,15 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
void Compute(OpKernelContext* context) override {
AlgorithmConfig algo_config;
ComputeAndReturnAlgorithm(context, &algo_config, false);
ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
/*time_major=*/true);
}
protected:
virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
AlgorithmConfig* output_algo_config,
bool var_seq_lengths) {
bool var_seq_lengths,
bool time_major) {
CHECK_NE(output_algo_config, nullptr);
const Tensor* input = nullptr;
@ -1232,14 +1281,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
const Tensor* sequence_lengths = nullptr;
CudnnRnnModelShapes model_shapes;
if (var_seq_lengths) {
OP_REQUIRES_OK(
context, ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes,
&sequence_lengths));
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), time_major,
&input, &input_h, &input_c, &params,
&sequence_lengths, &model_shapes));
} else {
OP_REQUIRES_OK(
context, ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
OP_REQUIRES_OK(context, ExtractForwardInput(
context, model_types(), time_major, &input,
&input_h, &input_c, &params, &model_shapes));
}
RnnInputMode input_mode;
OP_REQUIRES_OK(context,
@ -1278,19 +1327,11 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
*output_algo_config,
&rnn_state_cache_, &rnn_desc_ptr));
if (var_seq_lengths) {
launch_status = DoForward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, is_training_, output, output_h, output_c,
sequence_lengths, &reserve_space_allocator, &workspace_allocator,
/*output_profile_result=*/nullptr);
} else {
launch_status = DoForward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, is_training_, output, output_h, output_c, nullptr,
&reserve_space_allocator, &workspace_allocator,
/*output_profile_result=*/nullptr);
}
launch_status = DoForward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, is_training_, output, output_h, output_c,
sequence_lengths, time_major, &reserve_space_allocator,
&workspace_allocator, /*output_profile_result=*/nullptr);
}
OP_REQUIRES_OK(context, launch_status);
}
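The pre-change if/else around DoForward is gone because sequence_lengths is already nullptr on the fixed-length path (the old else branch passed nullptr explicitly), so a single call now covers both cases; the only genuinely new argument is time_major. The backward kernel below gets the same simplification.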
@ -1372,7 +1413,8 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
void Compute(OpKernelContext* context) override {
AlgorithmConfig best_algo_config;
CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
context, &best_algo_config, false);
context, &best_algo_config, /*var_seq_lengths=*/false,
/*time_major=*/true);
if (!context->status().ok()) {
return;
}
@ -1490,10 +1532,11 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
// Again use temp scratch allocator during profiling.
CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
status = DoForward<T>(
context, *rnn_desc, model_types(), model_shapes, input, input_h,
input_c, params, is_training(), output, output_h, output_c, nullptr,
&reserve_space_allocator, &workspace_allocator, &fwd_profile_result);
status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
input, input_h, input_c, params, is_training(),
                          output, output_h, output_c,
                          /*sequence_lengths=*/nullptr, /*time_major=*/true,
&reserve_space_allocator, &workspace_allocator,
&fwd_profile_result);
if (!status.ok()) {
continue;
}
@ -1506,7 +1549,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
input_c, params, output, output_h, output_c, &output_backprop,
&output_h_backprop, &output_c_backprop, &reserve_space,
&input_backprop, &input_h_backprop, &input_c_backprop,
&params_backprop, nullptr, &workspace_allocator,
          &params_backprop, /*sequence_lengths=*/nullptr,
          /*time_major=*/true, &workspace_allocator,
&bak_profile_result);
if (!status.ok()) {
continue;
@ -1561,15 +1604,22 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
using CudnnRNNKernelCommon::dropout;
using CudnnRNNKernelCommon::HasInputC;
using CudnnRNNKernelCommon::model_types;
bool time_major_;
protected:
bool time_major() { return time_major_; }
public:
explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
: CudnnRNNForwardOp<GPUDevice, T>(context) {}
: CudnnRNNForwardOp<GPUDevice, T>(context) {
OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
}
void Compute(OpKernelContext* context) override {
AlgorithmConfig best_algo_config;
CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
context, &best_algo_config, true);
context, &best_algo_config, /*var_seq_lengths=*/true,
/*time_major=*/time_major());
if (!context->status().ok()) {
return;
}
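Only the V3 kernels consult the new attribute; the V1/V2 paths above pass /*time_major=*/true unconditionally, so their behavior is unchanged. And because the op definitions below register the attr as "time_major: bool = true", GraphDefs serialized before this change keep loading and running time-major.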
@ -1604,11 +1654,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
: CudnnRNNKernelCommon(context) {}
void Compute(OpKernelContext* context) override {
ComputeImpl(context, false);
ComputeImpl(context, /*var_seq_lengths=*/false, /*time_major=*/true);
}
protected:
virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths) {
virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
bool time_major) {
const Tensor* input = nullptr;
const Tensor* input_h = nullptr;
const Tensor* input_c = nullptr;
@ -1616,14 +1667,14 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
const Tensor* sequence_lengths = nullptr;
CudnnRnnModelShapes model_shapes;
if (var_seq_lengths) {
OP_REQUIRES_OK(
context, ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes,
&sequence_lengths));
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), time_major,
&input, &input_h, &input_c, &params,
&sequence_lengths, &model_shapes));
} else {
OP_REQUIRES_OK(
context, ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
OP_REQUIRES_OK(context, ExtractForwardInput(
context, model_types(), time_major, &input,
&input_h, &input_c, &params, &model_shapes));
}
RnnInputMode input_mode;
OP_REQUIRES_OK(context,
@ -1665,22 +1716,13 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
algo_config, &rnn_state_cache_,
&rnn_desc_ptr));
if (var_seq_lengths) {
launch_status = DoBackward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, output, output_h, output_c, output_backprop,
output_h_backprop, output_c_backprop, reserve_space, input_backprop,
input_h_backprop, input_c_backprop, params_backprop,
sequence_lengths, &workspace_allocator,
/*output_profile_result=*/nullptr);
} else {
launch_status = DoBackward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, output, output_h, output_c, output_backprop,
output_h_backprop, output_c_backprop, reserve_space, input_backprop,
input_h_backprop, input_c_backprop, params_backprop, nullptr,
&workspace_allocator, /*output_profile_result=*/nullptr);
}
launch_status = DoBackward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, output, output_h, output_c, output_backprop,
output_h_backprop, output_c_backprop, reserve_space, input_backprop,
input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
time_major, &workspace_allocator,
/*output_profile_result=*/nullptr);
}
OP_REQUIRES_OK(context, launch_status);
}
@ -1827,12 +1869,20 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNBackwardOpV3<GPUDevice, T>
: public CudnnRNNBackwardOp<GPUDevice, T> {
private:
bool time_major_;
protected:
bool time_major() { return time_major_; }
public:
explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
: CudnnRNNBackwardOp<GPUDevice, T>(context) {}
: CudnnRNNBackwardOp<GPUDevice, T>(context) {
OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
}
void Compute(OpKernelContext* context) override {
CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true);
CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(
    context, /*var_seq_lengths=*/true, time_major());
}
};
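The remaining hunks plumb the flag through the op registrations, the Python gradient, and the StreamExecutor/cuDNN layers. For orientation, the two input layouts the V3 kernels now accept (illustrative sizes, not taken from the commit):

    // time_major=true (the default): [max_seq_len, batch_size, input_size]
    // time_major=false:              [batch_size, max_seq_len, input_size]
    TensorShape time_major_input({8, 4, 16});   // seq=8, batch=4, features=16
    TensorShape batch_major_input({4, 8, 16});  // same data, batch-leading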

View File

@ -167,6 +167,7 @@ REGISTER_OP("CudnnRNNV3")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("is_training: bool = true")
.Attr("time_major: bool = true")
.SetShapeFn([](InferenceContext* c) {
auto input_shape = c->input(0);
auto input_h_shape = c->input(1);
@ -292,6 +293,7 @@ REGISTER_OP("CudnnRNNBackpropV3")
.Attr("dropout: float = 0.0")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("time_major: bool = true")
.SetShapeFn([](InferenceContext* c) {
auto input_shape = c->input(0);
auto input_h_shape = c->input(1);
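Within SetShapeFn, the attribute decides which of the two leading input dimensions is the sequence axis. A hedged sketch of the adjustment (the full shape-function body is not shown in this hunk):

    bool time_major;
    TF_RETURN_IF_ERROR(c->GetAttr("time_major", &time_major));
    auto seq_size = c->Dim(input_shape, time_major ? 0 : 1);
    auto batch_size = c->Dim(input_shape, time_major ? 1 : 0);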

View File

@ -97,6 +97,7 @@ def _cudnn_rnn_backwardv3(op, *grads):
dropout=op.get_attr("dropout"),
seed=op.get_attr("seed"),
seed2=op.get_attr("seed2"),
time_major=op.get_attr("time_major"),
rnn_mode=op.get_attr("rnn_mode"),
input_mode=op.get_attr("input_mode"),
direction=op.get_attr("direction")) + (None,)
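Forwarding the attribute here matters: the backward op must interpret every saved tensor with the same layout the forward pass used, and a mismatch would silently swap the sequence and batch dimensions of the gradients.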

View File

@ -1300,7 +1300,8 @@ class CudnnRnnSequenceTensorDescriptor
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
const absl::Span<const int>& seq_lengths, cudnnDataType_t data_type) {
const absl::Span<const int>& seq_lengths, bool time_major,
cudnnDataType_t data_type) {
#if CUDNN_VERSION >= 7201
CHECK_GT(max_seq_length, 0);
int dims[] = {batch_size, data_size, 1};
@ -1313,9 +1314,15 @@ class CudnnRnnSequenceTensorDescriptor
const int* seq_lengths_array = seq_lengths.data();
RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
float padding_fill = 0.0f;
const cudnnRNNDataLayout_t layout =
    time_major ? CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED
               : CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
/*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
/*layout=*/CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
/*layout=*/layout,
/*maxSeqLength=*/max_seq_length,
/*batchSize=*/batch_size, /*vectorSize=*/data_size,
/*seqLengthArray=*/seq_lengths_array,
@ -1849,11 +1856,12 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
CudnnSupport::createRnnSequenceTensorDescriptor(
int max_seq_length, int batch_size, int data_size,
const absl::Span<const int>& seq_lengths, dnn::DataType data_type) {
const absl::Span<const int>& seq_lengths, bool time_major,
dnn::DataType data_type) {
SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
CudnnRnnSequenceTensorDescriptor::Create(
parent_, max_seq_length, batch_size, data_size,
seq_lengths, ToCudnnDataType(data_type)));
seq_lengths, time_major, ToCudnnDataType(data_type)));
return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
}
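A hedged usage sketch of the extended entry point, mirroring the kernel call sites earlier in the diff (variable names assumed):

    auto desc_s = executor->createRnnSequenceTensorDescriptor(
        max_seq_length, batch_size, data_size, seq_lengths,
        /*time_major=*/false, dnn::DataType::kFloat);
    TF_RETURN_IF_ERROR(desc_s.status());
    auto output_desc = desc_s.ConsumeValueOrDie();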

View File

@ -63,6 +63,7 @@ class CudnnSupport : public dnn::DnnSupport {
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
int data_size,
const absl::Span<const int>& seq_lengths,
bool time_major,
dnn::DataType data_type) override;
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>

View File

@ -2070,7 +2070,7 @@ class DnnSupport {
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
int data_size,
const absl::Span<const int>& seq_lengths,
dnn::DataType data_type) {
bool time_major, dnn::DataType data_type) {
return port::Status(port::error::UNIMPLEMENTED,
"createRnnSequenceTensorDescriptor is unimplemented");
}
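The base class keeps an UNIMPLEMENTED default so existing DnnSupport backends compile without changes. A backend that only handles time-major data might override it along these lines (hypothetical subclass and helper):

    port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
    MyDnnSupport::createRnnSequenceTensorDescriptor(
        int max_seq_length, int batch_size, int data_size,
        const absl::Span<const int>& seq_lengths, bool time_major,
        dnn::DataType data_type) {
      if (!time_major) {
        return port::Status(port::error::UNIMPLEMENTED,
                            "batch-major RNN data is not supported");
      }
      // CreateSeqMajorDescriptor is a hypothetical backend helper.
      return CreateSeqMajorDescriptor(max_seq_length, batch_size, data_size,
                                      seq_lengths, data_type);
    }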

View File

@ -411,14 +411,16 @@ StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
int max_seq_length, int batch_size, int data_size,
const absl::Span<const int> &seq_lengths, dnn::DataType data_type) {
const absl::Span<const int> &seq_lengths, bool time_major,
dnn::DataType data_type) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return port::Status(port::error::UNKNOWN,
"Fail to find the dnn implementation.");
}
return dnn_support->createRnnSequenceTensorDescriptor(
max_seq_length, batch_size, data_size, seq_lengths, data_type);
max_seq_length, batch_size, data_size, seq_lengths, time_major,
data_type);
}
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
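As before, the StreamExecutor wrapper only forwards to the registered DnnSupport plugin and surfaces an UNKNOWN status when no DNN library is linked in, so callers must check the returned StatusOr before dereferencing it.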

View File

@ -421,7 +421,7 @@ class StreamExecutor {
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
int data_size,
const absl::Span<const int> &seq_lengths,
dnn::DataType data_type);
bool time_major, dnn::DataType data_type);
// Create an RNN state descriptor that specifies the input or hidden state.
// The caller retains the ownership of the returned descriptor.

View File

@ -738,7 +738,7 @@ tf_module {
}
member_method {
name: "CudnnRNNBackpropV3"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "CudnnRNNCanonicalToParams"
@ -758,7 +758,7 @@ tf_module {
}
member_method {
name: "CudnnRNNV3"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "Cumprod"

View File

@ -738,7 +738,7 @@ tf_module {
}
member_method {
name: "CudnnRNNBackpropV3"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "CudnnRNNCanonicalToParams"
@ -758,7 +758,7 @@ tf_module {
}
member_method {
name: "CudnnRNNV3"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "Cumprod"