Merge pull request #24812 from houtoms:google-cudnn-rnn-add-time-major
PiperOrigin-RevId: 237520914
commit d3b9ce5b4b
@@ -69,6 +69,8 @@ def RunLSTM(sess,
             time,
             num_layers=1,
             variable_seq_lengths=False,
+            time_major=True,
+            dynamic_shape_input=False,
             is_training=True,
             dropout=0.,
             num_dirs=True,
@@ -84,11 +86,14 @@ def RunLSTM(sess,
   random_seed.set_random_seed(0)
   np.random.seed(0)
 
-  inputs = variable_scope.get_variable(
-      "inputs",
-      initializer=np.random.rand(time, batch_size,
-                                 input_size).astype(dtype.as_numpy_dtype),
-      dtype=dtype)
+  shape = ([time, batch_size, input_size]
+           if time_major else [batch_size, time, input_size])
+  inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
+  inputs_static = variable_scope.get_variable(
+      "inputs", initializer=inputs_np, dtype=dtype)
+  inputs_dynamic = array_ops.placeholder(
+      dtype, shape=[None, None, None], name="inputs")
+  inputs = inputs_dynamic if dynamic_shape_input else inputs_static
   initial_h_op = variable_scope.get_variable(
       "initial_h_op",
       initializer=np.random.rand(batch_size,
@@ -122,12 +127,12 @@ def RunLSTM(sess,
     cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True)
     outputs_op, state_tuple_op = rnn.dynamic_rnn(
         cell,
-        inputs,
+        inputs_static,
         sequence_length=lengths,
         initial_state=rnn_cell_impl.LSTMStateTuple(
            h=initial_h_op, c=initial_c_op),
         dtype=dtype,
-        time_major=True,
+        time_major=time_major,
         scope=None)
 
   # Convert to cudnn opaque param.
@@ -135,35 +140,38 @@ def RunLSTM(sess,
       num_layers, num_units, input_size)
   opaque_params = format_converter.tf_canonical_to_opaque([w, b])
 
-  cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0)
-  cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0)
+  cu_initial_h_op = array_ops.expand_dims(
+      initial_h_op, axis=(0 if time_major else 1))
+  cu_initial_c_op = array_ops.expand_dims(
+      initial_c_op, axis=(0 if time_major else 1))
   cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn(
       inputs,
       cu_initial_h_op,
       cu_initial_c_op,
       opaque_params,
       sequence_lengths=lengths,
+      time_major=time_major,
       dropout=dropout,
       is_training=is_training,
       rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
   # Remove the trivial 1st dimension.
   cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple(
-      c=array_ops.squeeze(cu_c_op, axis=0),
-      h=array_ops.squeeze(cu_h_op, axis=0))
+      c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1),
+      h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1))
 
   if is_training:
     (inp_grad_op, hgrad_op,
      cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
-         outputs_op, [inputs, initial_h_op, initial_c_op, w, b])
+         outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
 
     (cu_inp_grad_op, cu_hgrad_op,
      cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients(
         cu_outputs_op,
         [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params])
     # Remove the trivial 1st dimension
-    cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0)
+    cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1)
     # Remove the trivial 1st dimension
-    cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0)
+    cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1)
 
     cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
         opaque_grad_op)
@@ -183,10 +191,12 @@ def RunLSTM(sess,
         (hgrad_op, cgrad_op), wgrad_op, bgrad_op
     ])
     (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
-     cu_bgrad) = sess.run([
-         cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
-         (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
-     ])
+     cu_bgrad) = sess.run(
+         [
+             cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
+             (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
+         ],
+         feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
 
     logging.vlog(1, "outputs: %s" % outputs)
     logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@@ -205,7 +215,10 @@ def RunLSTM(sess,
             cu_bgrad)
   else:
     outputs, state_tuple = sess.run([outputs_op, state_tuple_op])
-    cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op])
+    cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op],
+                                          feed_dict=({
+                                              inputs: inputs_np
+                                          } if dynamic_shape_input else None))
 
     logging.vlog(1, "outputs: %s" % outputs)
     logging.vlog(1, "cu_outputs: %s" % cu_outputs)
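The axis juggling above is easiest to see in isolation. The following standalone numpy sketch (not part of the commit; all sizes are illustrative) mirrors what the updated RunLSTM harness does: inputs are laid out [time, batch, input] when time_major is true and [batch, time, input] otherwise, and the initial state gains its trivial num_layers * num_dirs axis at position 0 or 1 accordingly:

    import numpy as np

    time, batch_size, input_size, num_units = 5, 2, 3, 4

    for time_major in (True, False):
      # Input layout, as in the test harness.
      shape = ([time, batch_size, input_size]
               if time_major else [batch_size, time, input_size])
      inputs_np = np.random.rand(*shape).astype(np.float32)
      # The per-batch state [batch_size, num_units] gets the trivial
      # num_layers * num_dirs axis at 0 (time-major) or 1 (batch-major),
      # matching the expand_dims/squeeze axis flips in the diff.
      state = np.random.rand(batch_size, num_units).astype(np.float32)
      cu_state = np.expand_dims(state, axis=0 if time_major else 1)
      assert np.array_equal(
          np.squeeze(cu_state, axis=0 if time_major else 1), state)
      print(time_major, inputs_np.shape, cu_state.shape)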
@@ -336,6 +349,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
                             num_layers,
                             dtype,
                             variable_seq_lengths,
+                            time_major,
+                            dynamic_shape_input=False,
                             rtol=3e-6,
                             atol=3e-6):
     with self.session(use_gpu=True) as sess:
@@ -347,7 +362,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
           batch_size,
           time,
           num_layers,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
 
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
       for s, cu_s in zip(state_tuple, cu_state_tuple):
@@ -359,13 +376,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
       self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_training(self, num_units, input_size, batch_size, time, num_layers,
-                    variable_seq_lengths):
+                    variable_seq_lengths, time_major, dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     self._test_training_helper(
@@ -375,16 +395,22 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
         time,
         num_layers,
         dtypes.float32,
-        variable_seq_lengths=variable_seq_lengths)
+        variable_seq_lengths=variable_seq_lengths,
+        time_major=time_major,
+        dynamic_shape_input=dynamic_shape_input)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_training_fp16(self, num_units, input_size, batch_size, time,
-                         num_layers, variable_seq_lengths):
+                         num_layers, variable_seq_lengths, time_major,
+                         dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     self._test_training_helper(
@@ -396,16 +422,21 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
         dtypes.float16,
         rtol=5e-3,
         atol=5e-4,
-        variable_seq_lengths=variable_seq_lengths)
+        variable_seq_lengths=variable_seq_lengths,
+        time_major=time_major,
+        dynamic_shape_input=dynamic_shape_input)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference(self, num_units, input_size, batch_size, time, num_layers,
-                     variable_seq_lengths):
+                     variable_seq_lengths, time_major, dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     with self.session(use_gpu=True) as sess:
@@ -417,7 +448,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
           time,
           num_layers,
           is_training=False,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
 
       self.assertAllClose(outputs, cu_outputs)
       # h
@@ -426,13 +459,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
       self.assertAllClose(state_tuple.c, cu_state_tuple.c)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference_fp16(self, num_units, input_size, batch_size, time,
-                          num_layers, variable_seq_lengths):
+                          num_layers, variable_seq_lengths, time_major,
+                          dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     with self.session(use_gpu=True) as sess:
@@ -445,7 +482,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
           num_layers,
           is_training=False,
          dtype=dtypes.float16,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
 
       rtol, atol = 5e-3, 5e-4
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
@@ -457,13 +496,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
           state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
-                                  num_layers, variable_seq_lengths):
+                                  num_layers, variable_seq_lengths, time_major,
+                                  dynamic_shape_input):
     """Validates that dropout does not affect Cudnn Rnn inference."""
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
@@ -480,7 +523,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
             num_layers,
             is_training=False,
             dropout=0.,
-            variable_seq_lengths=variable_seq_lengths)
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input)
 
     with ops.Graph().as_default() as g:
       with self.session(use_gpu=True, graph=g) as sess:
@@ -493,7 +538,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase):
             num_layers,
             is_training=False,
             dropout=1.,
-            variable_seq_lengths=variable_seq_lengths)
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input)
 
     self.assertAllClose(cu_outputs, cu_outputs2)
     # h
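Each decorator above now crosses the named base cases with the new time_major and dynamic_shape_input options. ExpandNamedTestCases itself is not shown in this diff; the following is only a rough sketch of the expansion it is assumed to perform (the helper name and dict layout follow the parameterized.named_parameters convention, everything else is illustrative):

    import itertools

    def expand_named_test_cases_sketch(cases, **extra_options):
      # Assumed behavior: emit one copy of every named case per combination
      # of the extra options, folding the option values into the case name
      # so each generated test keeps a unique testcase_name.
      expanded = []
      keys = sorted(extra_options)
      for case in cases:
        for values in itertools.product(*(extra_options[k] for k in keys)):
          new_case = dict(case, **dict(zip(keys, values)))
          new_case["testcase_name"] = case["testcase_name"] + "".join(
              "_%s_%s" % (k, v) for k, v in zip(keys, values))
          expanded.append(new_case)
      return expanded

    cases = [{"testcase_name": "large", "num_units": 128}]
    print(len(expand_named_test_cases_sketch(
        cases, time_major=[True, False], dynamic_shape_input=[True, False])))
    # -> 4: every named case times each time_major/dynamic_shape_input pair.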
@@ -510,6 +557,8 @@ def RunGRU(sess,
            num_layers=1,
            is_training=True,
            variable_seq_lengths=False,
+           time_major=True,
+           dynamic_shape_input=False,
            dropout=0.,
            num_dirs=True,
            dtype=dtypes.float32):
@@ -524,11 +573,14 @@ def RunGRU(sess,
   random_seed.set_random_seed(0)
   np.random.seed(0)
 
-  inputs = variable_scope.get_variable(
-      "inputs",
-      initializer=np.random.rand(time, batch_size,
-                                 input_size).astype(dtype.as_numpy_dtype),
-      dtype=dtype)
+  shape = ([time, batch_size, input_size]
+           if time_major else [batch_size, time, input_size])
+  inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
+  inputs_static = variable_scope.get_variable(
+      "inputs", initializer=inputs_np, dtype=dtype)
+  inputs_dynamic = array_ops.placeholder(
+      dtype, shape=[None, None, None], name="inputs")
+  inputs = inputs_dynamic if dynamic_shape_input else inputs_static
   initial_h_op = variable_scope.get_variable(
       "initial_h_op",
       initializer=np.random.rand(batch_size,
@@ -573,11 +625,11 @@ def RunGRU(sess,
     cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True)
     outputs_op, h_op = rnn.dynamic_rnn(
         cell,
-        inputs,
+        inputs_static,
         sequence_length=lengths,
         initial_state=initial_h_op,
         dtype=dtype,
-        time_major=True,
+        time_major=time_major,
         scope=None)
 
   ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel]
@@ -588,13 +640,15 @@ def RunGRU(sess,
   opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
 
 
-  cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0)
+  cu_initial_h_op = array_ops.expand_dims(
+      initial_h_op, axis=(0 if time_major else 1))
   cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn(
       inputs,
       cu_initial_h_op,
       array_ops.zeros_like(cu_initial_h_op),  # not used
       opaque_params,
       sequence_lengths=lengths,
+      time_major=time_major,
      dropout=dropout,
       is_training=is_training,
       rnn_mode=cudnn_rnn_ops.CUDNN_GRU)
@@ -602,12 +656,12 @@ def RunGRU(sess,
   if is_training:
     (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op,
      cib_grad_op, chb_grad_op) = gradients_impl.gradients(
-         outputs_op, [inputs, initial_h_op] + ws + bs)
+         outputs_op, [inputs_static, initial_h_op] + ws + bs)
 
     (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients(
         cu_outputs_op, [inputs, cu_initial_h_op, opaque_params])
     # Remove the trivial 1st dimension
-    cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0)
+    cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1)
 
     cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
         opaque_grad_op)
@@ -627,13 +681,15 @@ def RunGRU(sess,
         (gk_grad_op, cik_grad_op, chk_grad_op),
         (gb_grad_op, cib_grad_op, chb_grad_op)
     ])
-    (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([
-        cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op,
-        (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op),
-        (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op)
-    ])
+    (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run(
+        [
+            cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op,
+            (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op),
+            (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op)
+        ],
+        feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
     # Remove the trivial 1st dimension
-    cu_h = np.squeeze(cu_h, axis=0)
+    cu_h = np.squeeze(cu_h, axis=0 if time_major else 1)
 
     logging.vlog(1, "outputs: %s" % outputs)
     logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@@ -651,9 +707,12 @@ def RunGRU(sess,
             cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad)
   else:
     outputs, h = sess.run([outputs_op, h_op])
-    cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op])
+    cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op],
+                                feed_dict=({
+                                    inputs: inputs_np
+                                } if dynamic_shape_input else None))
     # Remove the trivial 1st dimension.
-    cu_h = np.squeeze(cu_h, axis=0)
+    cu_h = np.squeeze(cu_h, axis=0 if time_major else 1)
 
     logging.vlog(1, "outputs: %s" % outputs)
     logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@@ -672,6 +731,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
                             num_layers,
                             dtype,
                             variable_seq_lengths,
+                            time_major,
+                            dynamic_shape_input=False,
                             rtol=3e-6,
                             atol=3e-6):
     with self.session(use_gpu=True) as sess:
@@ -683,7 +744,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
           batch_size,
           time,
           num_layers,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
 
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
       self.assertAllClose(h, cu_h, rtol=rtol, atol=atol)
@@ -695,13 +758,16 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
         self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_training(self, num_units, input_size, batch_size, time, num_layers,
-                    variable_seq_lengths):
+                    variable_seq_lengths, time_major, dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     self._test_training_helper(
@@ -711,16 +777,22 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
         time,
         num_layers,
         dtypes.float32,
-        variable_seq_lengths=variable_seq_lengths)
+        variable_seq_lengths=variable_seq_lengths,
+        time_major=time_major,
+        dynamic_shape_input=dynamic_shape_input)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_training_fp16(self, num_units, input_size, batch_size, time,
-                         num_layers, variable_seq_lengths):
+                         num_layers, variable_seq_lengths, time_major,
+                         dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     self._test_training_helper(
@@ -732,16 +804,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
         dtypes.float16,
         rtol=5e-3,
         atol=5e-4,
-        variable_seq_lengths=variable_seq_lengths)
+        variable_seq_lengths=variable_seq_lengths,
+        time_major=time_major,
+        dynamic_shape_input=dynamic_shape_input)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference(self, num_units, input_size, batch_size, time, num_layers,
-                     variable_seq_lengths):
+                     variable_seq_lengths, time_major, dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     with self.session(use_gpu=True) as sess:
@@ -753,18 +830,24 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
           time,
           num_layers,
          is_training=False,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
       self.assertAllClose(outputs, cu_outputs)
       self.assertAllClose(h, cu_h)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference_fp16(self, num_units, input_size, batch_size, time,
-                          num_layers, variable_seq_lengths):
+                          num_layers, variable_seq_lengths, time_major,
+                          dynamic_shape_input):
     if not context.context().num_gpus():
       self.skipTest("No GPUs found")
     with self.session(use_gpu=True) as sess:
@@ -777,20 +860,26 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
           num_layers,
           is_training=False,
           dtype=dtypes.float16,
-          variable_seq_lengths=variable_seq_lengths)
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input)
 
       rtol, atol = 5e-3, 5e-4
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
       self.assertAllClose(h, cu_h, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
-      ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{
-          "variable_seq_lengths": [True, False],
-      }))
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+          }))
   @unittest.skipUnless(test.is_built_with_cuda(),
                        "Test only applicable when running on GPUs")
   def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
-                                  num_layers, variable_seq_lengths):
+                                  num_layers, variable_seq_lengths, time_major,
+                                  dynamic_shape_input):
     """Validates that dropout does not affect Cudnn Rnn inference."""
     # Hand-picked dropouts are used below (0. and 1.)
     if not context.context().num_gpus():
@@ -807,7 +896,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
             num_layers,
             is_training=False,
             dropout=0.,
-            variable_seq_lengths=variable_seq_lengths)
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input)
 
     with ops.Graph().as_default() as g:
       with self.session(use_gpu=True, graph=g) as sess:
@@ -820,7 +911,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase):
             num_layers,
             is_training=False,
             dropout=1.,
-            variable_seq_lengths=variable_seq_lengths)
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input)
 
     self.assertAllClose(cu_outputs, cu_outputs2)
     self.assertAllClose(cu_h[0], cu_h2[0])
@@ -378,20 +378,33 @@ class _CudnnRNN(base_layer.Layer):
            inputs,
            initial_state=None,
            sequence_lengths=None,
+           time_major=True,
            training=True):
     """Runs the forward step for the RNN model.
 
     Args:
-      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
+      inputs: `3-D` tensor. If `time_major` is True (default), the Tensor shape
+        is [time_len, batch_size, input_size]. If `time_major` is False, the
+        shape is [batch_size, time_len, input_size].
       initial_state: a tuple of tensor(s) of shape
-        `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use
+        `[num_layers * num_dirs, batch_size, num_units]` if
+        `time_major` is True (default) or `[batch_size, num_layers * num_dirs,
+        num_units]` if `time_major` is False. If not provided, use
         zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs.
       sequence_lengths: an int32 array representing the variable sequence
         lengths in a batch. The size of the array has to equal the
        batch_size. If not provided, the same sequence length will be assumed.
+      time_major: The shape format of the `inputs` and `outputs` Tensors. If
+        true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
+        If false, these Tensors must be shaped ['batch_size', 'max_time',
+        'depth']. By default this function accepts input and emits output in
+        time-major form. This param is only effective when 'sequence_lengths'
+        is used.
       training: whether this operation will be used in training or inference.
     Returns:
-      output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`.
+      output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`
+        if `time_major` is True (default) or `[batch_size, time_len,
+        num_dirs * num_units]` if `time_major` is False.
         It is a `concat([fwd_output, bak_output], axis=2)`.
       output_states: a tuple of tensor(s) of the same shape and structure as
         `initial_state`.
@@ -417,8 +430,8 @@ class _CudnnRNN(base_layer.Layer):
     else:
       # For model that doesn't take input_c, replace with a dummy tensor.
       c = array_ops.constant([], dtype=dtype)
-    outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel,
-                                                  sequence_lengths, training)
+    outputs, (output_h, output_c) = self._forward(
+        inputs, h, c, self.kernel, sequence_lengths, time_major, training)
     if self._rnn_mode == CUDNN_LSTM:
       return outputs, (output_h, output_c)
     else:
@@ -482,7 +495,8 @@ class _CudnnRNN(base_layer.Layer):
         dropout=self._dropout,
         direction=self._direction)
 
-  def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training):
+  def _forward(self, inputs, h, c, opaque_params, sequence_lengths, time_major,
+               training):
     output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn(  # pylint:disable=protected-access
         inputs,
         h,
@@ -491,6 +505,7 @@ class _CudnnRNN(base_layer.Layer):
         training,
         self._rnn_mode,
         sequence_lengths=sequence_lengths,
+        time_major=time_major,
        input_mode=self._input_mode,
        direction=self._direction,
        dropout=self._dropout,
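A hedged usage sketch of the layer API after this change. The shapes and hyperparameters are illustrative, the import path assumes the TF 1.x contrib layout this file lives in, and a CUDA-enabled build plus a GPU are required at runtime; per the docstring above, time_major=False is only effective when sequence_lengths is passed:

    import tensorflow as tf
    from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn

    batch_size, max_time, input_size = 8, 20, 32
    # Batch-major inputs: [batch_size, max_time, input_size].
    inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
    seq_len = tf.fill([batch_size], max_time)

    lstm = cudnn_rnn.CudnnLSTM(num_layers=1, num_units=64)
    outputs, (h, c) = lstm(
        inputs, sequence_lengths=seq_len, time_major=False, training=True)
    # outputs: [batch_size, max_time, 64]; h and c carry the
    # num_layers * num_dirs axis at position 1 in batch-major mode.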
@@ -956,6 +956,7 @@ def _cudnn_rnn(inputs,
                is_training,
                rnn_mode,
                sequence_lengths=None,
+               time_major=True,
                input_mode=CUDNN_INPUT_LINEAR_MODE,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
@@ -964,10 +965,12 @@ def _cudnn_rnn(inputs,
   """Cudnn RNN.
 
   Args:
-    inputs: the input sequence to the RNN model. A Tensor of shape [?,
-      batch_size, input_size].
-    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-      batch_size, num_units].
+    inputs: the input sequence to the RNN model. If `time_major` is True
+      (default), the Tensor shape is [max_time, batch_size, input_size]. If
+      `time_major` is False, the shape is [batch_size, max_time, input_size].
+    input_h: the initial hidden state for h. If `time_major` is True
+      (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+      `time_major` is False, the shape is [batch_size, num_layers, num_units].
     input_c: the initial hidden state for c. This is only relevant for LSTM.
       A Tensor of the same shape as input_h.
     params: the parameter buffer created for this model.
@@ -977,6 +980,11 @@ def _cudnn_rnn(inputs,
       in a batch. The size of the array has to equal the batch_size. Default to
       None, in which case sequences in the batch are assumed to have the same
      length, which is inferred from inputs.
+    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+      these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
+      false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
+      By default this function accepts input and emits output in time-major
+      form. This param is only effective when 'sequence_lengths' is used.
     input_mode: indicate whether there is a linear projection between the
       input and the actual computation before the first layer. It could be
       'linear_input', 'skip_input' or 'auto_select'.
@@ -1017,6 +1025,14 @@ def _cudnn_rnn(inputs,
   }
   if sequence_lengths is not None:
     args["sequence_lengths"] = sequence_lengths
+    args["time_major"] = time_major
+    outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
+  elif time_major is False:
+    batch_size = array_ops.shape(inputs)[0]
+    max_time = array_ops.shape(inputs)[1]
+    sequence_lengths = array_ops.fill([batch_size], max_time)
+    args["sequence_lengths"] = sequence_lengths
+    args["time_major"] = time_major
     outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
   elif use_cudnn_v2 != "1":
     outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
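The control flow added above routes every call that needs time_major or sequence_lengths to the CudnnRNNV3 kernel, synthesizing full-length sequence_lengths when a batch-major caller did not supply any. Restated as a standalone, runnable sketch (the function and its return values are illustrative, not part of the API):

    def choose_cudnn_kernel(sequence_lengths, time_major, batch_size, max_time):
      """Sketch of the kernel dispatch in _cudnn_rnn; returns (op, lengths)."""
      if sequence_lengths is not None:
        # Variable-length batches always need the V3 op, which is the only
        # one that accepts sequence_lengths and a time_major attribute.
        return "CudnnRNNV3", sequence_lengths
      if not time_major:
        # Batch-major input without explicit lengths: fabricate full-length
        # sequence_lengths so the V3 op can still be used.
        return "CudnnRNNV3", [max_time] * batch_size
      # Time-major, fixed-length input keeps the legacy kernels.
      return "CudnnRNN", None

    print(choose_cudnn_kernel(None, False, batch_size=4, max_time=7))
    # -> ('CudnnRNNV3', [7, 7, 7, 7])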
@@ -1031,6 +1047,7 @@ def cudnn_lstm(inputs,
                params,
                is_training,
                sequence_lengths=None,
+               time_major=True,
                input_mode=CUDNN_INPUT_LINEAR_MODE,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
@@ -1039,14 +1056,25 @@ def cudnn_lstm(inputs,
   """Cudnn LSTM.
 
   Args:
-    inputs: the input sequence to the RNN model. A Tensor of shape [?,
-      batch_size, input_size].
-    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-      batch_size, num_units].
+    inputs: the input sequence to the RNN model. If `time_major` is True
+      (default), the Tensor shape is [max_time, batch_size, input_size]. If
+      `time_major` is False, the shape is [batch_size, max_time, input_size].
+    input_h: the initial hidden state for h. If `time_major` is True
+      (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+      `time_major` is False, the shape is [batch_size, num_layers, num_units].
     input_c: the initial hidden state for c. This is only relevant for LSTM.
       A Tensor of the same shape as input_h.
     params: the parameter buffer created for this model.
     is_training: whether this operation will be used in training or inference
+    sequence_lengths: an int32 array representing the variable sequence lengths
+      in a batch. The size of the array has to equal the batch_size. Default to
+      None, in which case sequences in the batch are assumed to have the same
+      length, which is inferred from inputs.
+    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+      these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
+      false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
+      By default this function accepts input and emits output in time-major
+      form. This param is only effective when 'sequence_lengths' is used.
     input_mode: indicate whether there is a linear projection between the
       input and the actual computation before the first layer. It could be
       'linear_input', 'skip_input' or 'auto_select'.
@@ -1060,17 +1088,13 @@ def cudnn_lstm(inputs,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
-    sequence_lengths: an int32 array representing the variable sequence lengths
-      in a batch. The size of the array has to equal the batch_size. Default to
-      None, in which case sequences in the batch are assumed to have the same
-      length, which is inferred from inputs.
     name: name of the operation.
   Returns:
     outputs, output_h, output_c
   """
   return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
-                    sequence_lengths, input_mode, direction, dropout, seed,
-                    name)
+                    sequence_lengths, time_major, input_mode, direction,
+                    dropout, seed, name)
 
 
 def _cudnn_rnn_no_input_c(inputs,
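For the functional wrappers, the new keyword threads straight through to _cudnn_rnn. A hedged call sketch (tensor shapes are illustrative; params must be an opaque parameter buffer built with the matching cudnn_rnn opaque-params helpers, the feed for it here is only a stand-in, and a CUDA build with a GPU is required to actually run the op):

    import tensorflow as tf
    from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

    batch_size, max_time, input_size, num_units = 4, 10, 8, 16
    # Batch-major input and state layouts, per the docstrings above.
    inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
    input_h = tf.zeros([batch_size, 1, num_units])
    input_c = tf.zeros_like(input_h)
    seq_len = tf.fill([batch_size], max_time)
    params = tf.placeholder(tf.float32, [None])  # opaque parameter buffer

    outputs, output_h, output_c = cudnn_rnn_ops.cudnn_lstm(
        inputs,
        input_h,
        input_c,
        params,
        is_training=False,
        sequence_lengths=seq_len,
        time_major=False)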
@@ -1079,6 +1103,7 @@ def _cudnn_rnn_no_input_c(inputs,
                           is_training,
                           rnn_mode,
                           sequence_lengths=None,
+                          time_major=True,
                           input_mode=CUDNN_INPUT_LINEAR_MODE,
                           direction=CUDNN_RNN_UNIDIRECTION,
                           dropout=0.,
@@ -1087,10 +1112,12 @@ def _cudnn_rnn_no_input_c(inputs,
   """Cudnn RNN w/o input_c.
 
   Args:
-    inputs: the input sequence to the RNN model. A Tensor of shape [?,
-      batch_size, input_size].
-    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-      batch_size, num_units].
+    inputs: the input sequence to the RNN model. If `time_major` is True
+      (default), the Tensor shape is [max_time, batch_size, input_size]. If
+      `time_major` is False, the shape is [batch_size, max_time, input_size].
+    input_h: the initial hidden state for h. If `time_major` is True
+      (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+      `time_major` is False, the shape is [batch_size, num_layers, num_units].
     params: the parameter buffer created for this model.
     is_training: whether this operation will be used in training or inference
     rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh').
@@ -1098,6 +1125,11 @@ def _cudnn_rnn_no_input_c(inputs,
       in a batch. The size of the array has to equal the batch_size. Default to
       None, in which case sequences in the batch are assumed to have the same
       length, which is inferred from inputs.
+    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+      these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
+      false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
+      By default this function accepts input and emits output in time-major
+      form. This param is only effective when 'sequence_lengths' is used.
     input_mode: indicate whether there is a linear projection between the
       input and the actual computation before the first layer. It could be
       'linear_input', 'skip_input' or 'auto_select'.
@@ -1116,9 +1148,9 @@ def _cudnn_rnn_no_input_c(inputs,
     outputs, output_h
   """
   input_c = array_ops.constant([], dtype=input_h.dtype)
-  outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
-                                    is_training, rnn_mode, sequence_lengths,
-                                    input_mode, direction, dropout, seed, name)
+  outputs, output_h, _ = _cudnn_rnn(
+      inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths,
+      time_major, input_mode, direction, dropout, seed, name)
   return outputs, output_h
 
 
@@ -1127,6 +1159,7 @@ def cudnn_gru(inputs,
               params,
               is_training,
               sequence_lengths=None,
+              time_major=True,
               input_mode=CUDNN_INPUT_LINEAR_MODE,
               direction=CUDNN_RNN_UNIDIRECTION,
               dropout=0.,
@@ -1135,10 +1168,12 @@ def cudnn_gru(inputs,
   """Cudnn GRU.
 
   Args:
-    inputs: the input sequence to the RNN model. A Tensor of shape [?,
-      batch_size, input_size].
-    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-      batch_size, num_units].
+    inputs: the input sequence to the RNN model. If `time_major` is True
+      (default), the Tensor shape is [max_time, batch_size, input_size]. If
+      `time_major` is False, the shape is [batch_size, max_time, input_size].
+    input_h: the initial hidden state for h. If `time_major` is True
+      (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+      `time_major` is False, the shape is [batch_size, num_layers, num_units].
     params: the parameter buffer created for this model.
     is_training: whether this operation will be used in training or inference
     input_mode: indicate whether there is a linear projection between the
@@ -1153,6 +1188,11 @@ def cudnn_gru(inputs,
       in a batch. The size of the array has to equal the batch_size. Default to
       None, in which case sequences in the batch are assumed to have the same
       length, which is inferred from inputs.
+    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+      these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
+      false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
+      By default this function accepts input and emits output in time-major
+      form. This param is only effective when 'sequence_lengths' is used.
     direction: the direction model that the model operates. Could be either
       'unidirectional' or 'bidirectional'
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
@@ -1163,8 +1203,8 @@ def cudnn_gru(inputs,
     outputs, output_h
   """
   return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU,
-                               sequence_lengths, input_mode, direction, dropout,
-                               seed, name)
+                               sequence_lengths, time_major, input_mode,
+                               direction, dropout, seed, name)
 
 
 def cudnn_rnn_relu(inputs,
@@ -1176,14 +1216,17 @@ def cudnn_rnn_relu(inputs,
                    dropout=0.,
                    seed=0,
                    sequence_lengths=None,
+                   time_major=True,
                    name=None):
   """Cudnn RNN Relu.
 
   Args:
-    inputs: the input sequence to the RNN model. A Tensor of shape [?,
-      batch_size, input_size].
-    input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-      batch_size, num_units].
+    inputs: the input sequence to the RNN model. If `time_major` is True
+      (default), the Tensor shape is [max_time, batch_size, input_size]. If
+      `time_major` is False, the shape is [batch_size, max_time, input_size].
+    input_h: the initial hidden state for h. If `time_major` is True
+      (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+      `time_major` is False, the shape is [batch_size, num_layers, num_units].
     params: the parameter buffer created for this model.
     is_training: whether this operation will be used in training or inference
     input_mode: indicate whether there is a linear projection between the
@@ -1201,14 +1244,19 @@ def cudnn_rnn_relu(inputs,
     sequence_lengths: an int32 array representing the variable sequence lengths
       in a batch. The size of the array has to equal the batch_size. If not
       provided, the same sequence length will be assumed.
+    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+      these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
+      false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
+      By default this function accepts input and emits output in time-major
+      form. This param is only effective when 'sequence_lengths' is used.
     name: name of the operation.
 
   Returns:
     outputs, output_h
   """
   return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
-                               CUDNN_RNN_RELU, sequence_lengths, input_mode,
-                               direction, dropout, seed, name)
+                               CUDNN_RNN_RELU, sequence_lengths, time_major,
+                               input_mode, direction, dropout, seed, name)
 
 
 def cudnn_rnn_tanh(inputs,
@ -1216,6 +1264,7 @@ def cudnn_rnn_tanh(inputs,
|
|||||||
params,
|
params,
|
||||||
is_training,
|
is_training,
|
||||||
sequence_lengths=None,
|
sequence_lengths=None,
|
||||||
|
time_major=True,
|
||||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||||
direction=CUDNN_RNN_UNIDIRECTION,
|
direction=CUDNN_RNN_UNIDIRECTION,
|
||||||
dropout=0.,
|
dropout=0.,
|
||||||
@ -1224,10 +1273,12 @@ def cudnn_rnn_tanh(inputs,
|
|||||||
"""Cudnn RNN Tanh.
|
"""Cudnn RNN Tanh.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: the input sequence to the RNN model. A Tensor of shape [?,
|
inputs: the input sequence to the RNN model. If `time_major` is True
|
||||||
batch_size, input_size].
|
(default), the Tensor shape is [max_time, batch_size, input_size]. If
|
||||||
input_h: the initial hidden state for h. A Tensor of shape [num_layers,
|
`time_major` is False, the shape is [batch_size, max_time, input_size].
|
||||||
batch_size, num_units].
|
input_h: the initial hidden state for h. If `time_major` is True
|
||||||
|
(default), the Tensor shape is [num_layers, batch_size, num_units]. If
|
||||||
|
`time_major` is False, the shape is [batch_size, num_layers, num_units].
|
||||||
params: the parameter buffer created for this model.
|
params: the parameter buffer created for this model.
|
||||||
is_training: whether this operation will be used in training or inference
|
is_training: whether this operation will be used in training or inference
|
||||||
input_mode: indicate whether there is a linear projection between the
|
input_mode: indicate whether there is a linear projection between the
|
||||||
@ -1242,6 +1293,11 @@ def cudnn_rnn_tanh(inputs,
|
|||||||
in a batch. The size of the array has to equal the batch_size. Default to
|
in a batch. The size of the array has to equal the batch_size. Default to
|
||||||
None, in which case sequences in the batch are assumed to have the same
|
None, in which case sequences in the batch are assumed to have the same
|
||||||
length, which is inferred from inputs.
|
length, which is inferred from inputs.
|
||||||
|
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
|
||||||
|
these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If
|
||||||
|
false, these Tensors must be shaped ['batch_size', 'max_time', 'depth'].
|
||||||
|
By default this function accepts input and emits output in time-major
|
||||||
|
form. This param is only effective when 'sequence_lengths' is used.
|
||||||
direction: the direction model that the model operates. Could be either
|
direction: the direction model that the model operates. Could be either
|
||||||
'unidirectional' or 'bidirectional'
|
'unidirectional' or 'bidirectional'
|
||||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||||
@ -1252,8 +1308,8 @@ def cudnn_rnn_tanh(inputs,
|
|||||||
outputs, output_h
|
outputs, output_h
|
||||||
"""
|
"""
|
||||||
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
|
return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training,
|
||||||
CUDNN_RNN_TANH, sequence_lengths, input_mode,
|
CUDNN_RNN_TANH, sequence_lengths, time_major,
|
||||||
direction, dropout, seed, name)
|
input_mode, direction, dropout, seed, name)
|
||||||
|
|
||||||
|
|
||||||
def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
||||||
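The two layouts that `time_major` selects between differ only in the order of
the leading axes. A minimal NumPy sketch of the relationship (illustrative
only, not part of the patch), assuming max_time=5, batch_size=2, input_size=8:

    import numpy as np

    max_time, batch_size, input_size = 5, 2, 8
    # Time-major layout (time_major=True, the default): leading axis is time.
    x_time_major = np.random.rand(max_time, batch_size, input_size)
    # Batch-major layout (time_major=False): leading axis is batch.
    x_batch_major = np.transpose(x_time_major, (1, 0, 2))
    assert x_batch_major.shape == (batch_size, max_time, input_size)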
@ -1537,22 +1593,32 @@ class _CudnnRNN(object):
                input_c,
                params,
                is_training=True,
-               sequence_lengths=None):
+               sequence_lengths=None,
+               time_major=True):
     """Runs the forward step for the RNN model.
 
     Args:
-      input_data: the input sequence to the RNN model. A Tensor of shape [?,
-        batch_size, input_size].
-      input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-        batch_size, num_units].
-      input_c: the initial hidden state for c. This is only relevant for LSTM.
-        A Tensor of the same shape as input_h.
+      input_data: the input sequence to the RNN model. If `time_major` is True
+        (default), the Tensor shape is [max_time, batch_size, input_size]. If
+        `time_major` is False, the shape is [batch_size, max_time, input_size].
+      input_h: the initial hidden state for h. If `time_major` is True
+        (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+        `time_major` is False, the shape is [batch_size, num_layers, num_units].
+      input_c: the initial hidden state for c. This is only relevant for LSTM. A
+        Tensor of the same shape as input_h.
       params: the parameter buffer created for this model.
       is_training: whether this operation will be used in training or inference.
       sequence_lengths: an int32 array representing the variable sequence
         lengths in a batch. The size of the array has to equal the batch_size.
         Default to None, in which case sequences in the batch are assumed to
         have the same length, which is inferred from inputs.
+      time_major: The shape format of the `inputs` and `outputs` Tensors. If
+        true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
+        If false, these Tensors must be shaped ['batch_size', 'max_time',
+        'depth']. By default this function accepts input and emits output in
+        time-major form. This param is only effective when 'sequence_lengths' is
+        used.
 
     Returns:
       output: the output sequence.
       output_h: the final state for h.
@ -1566,6 +1632,7 @@ class _CudnnRNN(object):
         is_training,
         self._rnn_mode,
         sequence_lengths=sequence_lengths,
+        time_major=time_major,
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
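As the docstring above spells out, the initial hidden state follows the same
layout rule as the sequence input. A small sketch of that shape rule with a
hypothetical helper (`zero_initial_h` is an illustrative name, not an API in
this patch):

    import numpy as np

    def zero_initial_h(num_layers, batch_size, num_units, time_major=True):
      # The layer axis leads in time-major form; the batch axis leads otherwise.
      shape = ((num_layers, batch_size, num_units) if time_major
               else (batch_size, num_layers, num_units))
      return np.zeros(shape, dtype=np.float32)

    assert zero_initial_h(2, 4, 16, time_major=True).shape == (2, 4, 16)
    assert zero_initial_h(2, 4, 16, time_major=False).shape == (4, 2, 16)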
@ -1666,14 +1733,17 @@ class CudnnLSTM(_CudnnRNN):
            input_c,
            params,
            sequence_lengths=None,
+           time_major=True,
            is_training=True):
     """Runs the forward step for the Cudnn LSTM model.
 
     Args:
-      input_data: the input sequence to the LSTM model. A Tensor of shape [?,
-        batch_size, input_size].
-      input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-        batch_size, num_units].
+      input_data: the input sequence to the RNN model. If `time_major` is True
+        (default), the Tensor shape is [max_time, batch_size, input_size]. If
+        `time_major` is False, the shape is [batch_size, max_time, input_size].
+      input_h: the initial hidden state for h. If `time_major` is True
+        (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+        `time_major` is False, the shape is [batch_size, num_layers, num_units].
       input_c: the initial hidden state for c. A Tensor of the same shape as
         input_h.
       params: the parameter buffer created for this model.
@ -1681,6 +1751,12 @@ class CudnnLSTM(_CudnnRNN):
         lengths in a batch. The size of the array has to equal the batch_size.
         Default to None, in which case sequences in the batch are assumed to
         have the same length, which is inferred from inputs.
+      time_major: The shape format of the `inputs` and `outputs` Tensors. If
+        true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
+        If false, these Tensors must be shaped ['batch_size', 'max_time',
+        'depth']. By default this function accepts input and emits output in
+        time-major form. This param is only effective when 'sequence_lengths'
+        is used.
       is_training: whether this operation will be used in training or inference.
     Returns:
       output: the output sequence.
@ -1693,6 +1769,7 @@ class CudnnLSTM(_CudnnRNN):
         input_c,
         params,
         sequence_lengths=sequence_lengths,
+        time_major=time_major,
         is_training=is_training)
     return (output, output_h, output_c)
 
@ -1752,19 +1829,28 @@ class _CudnnRNNNoInputC(_CudnnRNN):
            input_h,
            params,
            sequence_lengths=None,
+           time_major=True,
            is_training=True):
     """Runs the forward step for the Cudnn LSTM model.
 
     Args:
-      input_data: the input sequence to the RNN model. A Tensor of shape [?,
-        batch_size, input_size].
-      input_h: the initial hidden state for h. A Tensor of shape [num_layers,
-        batch_size, num_units].
+      input_data: the input sequence to the RNN model. If `time_major` is True
+        (default), the Tensor shape is [max_time, batch_size, input_size]. If
+        `time_major` is False, the shape is [batch_size, max_time, input_size].
+      input_h: the initial hidden state for h. If `time_major` is True
+        (default), the Tensor shape is [num_layers, batch_size, num_units]. If
+        `time_major` is False, the shape is [batch_size, num_layers, num_units].
       params: the parameter buffer created for this model.
       sequence_lengths: an int32 array representing the variable sequence
         lengths in a batch. The size of the array has to equal the batch_size.
         Default to None, in which case sequences in the batch are assumed to
         have the same length, which is inferred from inputs.
+      time_major: The shape format of the `inputs` and `outputs` Tensors. If
+        true, these Tensors must be shaped ['max_time', 'batch_size', 'depth'].
+        If false, these Tensors must be shaped ['batch_size', 'max_time',
+        'depth']. By default this function accepts input and emits output in
+        time-major form. This param is only effective when 'sequence_lengths'
+        is used.
       is_training: whether this operation will be used in training or inference.
     Returns:
       output: the output sequence.
@ -1777,6 +1863,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
         is_training,
         self._rnn_mode,
         sequence_lengths=sequence_lengths,
+        time_major=time_major,
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be
 dropout: Dropout probability. When set to 0., dropout is disabled.
 seed: The 1st part of a seed to initialize dropout.
 seed2: The 2nd part of a seed to initialize dropout.
-input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
-    num_units].
+input: If time_major is true, this is a 3-D tensor with the shape of
+    [seq_length, batch_size, input_size]. If time_major is false, the shape is
+    [batch_size, seq_length, input_size].
+input_h: If time_major is true, this is a 3-D tensor with the shape of
+    [num_layer * dir, batch_size, num_units]. If time_major is false, the shape
+    is [batch_size, num_layer * dir, num_units].
 input_c: For LSTM, a 3-D tensor with the shape of
     [num_layer * dir, batch, num_units]. For other models, it is ignored.
 params: A 1-D tensor that contains the weights and biases in an opaque layout.
@ -26,8 +29,9 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout.
     separately. Note that they might not be compatible across different
     generations. So it is a good idea to save and restore
 sequence_lengths: a vector of lengths of each input sequence.
-output: A 3-D tensor with the shape of [seq_length, batch_size,
-    dir * num_units].
+output: If time_major is true, this is a 3-D tensor with the shape of
+    [seq_length, batch_size, dir * num_units]. If time_major is false, the
+    shape is [batch_size, seq_length, dir * num_units].
 output_h: The same shape has input_h.
 output_c: The same shape as input_c for LSTM. An empty tensor for other models.
 output_backprop: A 3-D tensor with the same shape as output in the forward pass.
@ -35,6 +39,8 @@ output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
     pass.
 output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
     pass.
+time_major: Indicates whether the input/output format is time major or batch
+    major.
 reserve_space: The same reserve_space produced in the forward operation.
 input_backprop: The backprop to input in the forward pass. Has the same shape
     as input.
 
@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be
 dropout: Dropout probability. When set to 0., dropout is disabled.
 seed: The 1st part of a seed to initialize dropout.
 seed2: The 2nd part of a seed to initialize dropout.
-input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
-    num_units].
+input: If time_major is true, this is a 3-D tensor with the shape of
+    [seq_length, batch_size, input_size]. If time_major is false, the shape is
+    [batch_size, seq_length, input_size].
+input_h: If time_major is true, this is a 3-D tensor with the shape of
+    [num_layer * dir, batch_size, num_units]. If time_major is false, the shape
+    is [batch_size, num_layer * dir, num_units].
 input_c: For LSTM, a 3-D tensor with the shape of
     [num_layer * dir, batch, num_units]. For other models, it is ignored.
 params: A 1-D tensor that contains the weights and biases in an opaque layout.
@ -26,12 +29,15 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout.
     separately. Note that they might not be compatible across different
     generations. So it is a good idea to save and restore
 sequence_lengths: a vector of lengths of each input sequence.
-output: A 3-D tensor with the shape of [seq_length, batch_size,
-    dir * num_units].
+output: If time_major is true, this is a 3-D tensor with the shape of
+    [seq_length, batch_size, dir * num_units]. If time_major is false, the
+    shape is [batch_size, seq_length, dir * num_units].
 output_h: The same shape has input_h.
 output_c: The same shape as input_c for LSTM. An empty tensor for other models.
 is_training: Indicates whether this operation is used for inferenece or
     training.
+time_major: Indicates whether the input/output format is time major or batch
+    major.
 reserve_space: An opaque tensor that can be used in backprop calculation. It
     is only produced if is_training is true.
 END
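The op docs above tie the output width to the direction count, with the layout
deciding which axis leads. A hedged sketch of the documented shapes
(`rnn_output_shape` is an illustrative name, not part of the patch):

    def rnn_output_shape(seq_length, batch_size, num_units, dir_count,
                         time_major):
      # dir_count is 2 for a bidirectional model and 1 otherwise.
      if time_major:
        return (seq_length, batch_size, dir_count * num_units)
      return (batch_size, seq_length, dir_count * num_units)

    assert rnn_output_shape(5, 2, 16, 2, time_major=True) == (5, 2, 32)
    assert rnn_output_shape(5, 2, 16, 2, time_major=False) == (2, 5, 32)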
@ -559,7 +559,7 @@ struct RnnScratchSpace {
 // Extract and checks the forward input tensors, parameters, and shapes from the
 // OpKernelContext.
 Status ExtractForwardInput(OpKernelContext* context,
-                           const CudnnModelTypes& model_types,
+                           const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
                            CudnnRnnModelShapes* model_shapes) {
@ -573,8 +573,13 @@ Status ExtractForwardInput(OpKernelContext* context,
   if ((*input)->dims() != 3) {
     return errors::InvalidArgument("RNN input must be a 3-D vector.");
   }
+  if (time_major) {
     model_shapes->max_seq_length = (*input)->dim_size(0);
     model_shapes->batch_size = (*input)->dim_size(1);
+  } else {
+    model_shapes->max_seq_length = (*input)->dim_size(1);
+    model_shapes->batch_size = (*input)->dim_size(0);
+  }
   model_shapes->input_size = (*input)->dim_size(2);
   model_shapes->input_shape = (*input)->shape();
   model_shapes->dir_count =
@ -585,12 +590,25 @@ Status ExtractForwardInput(OpKernelContext* context,
   if ((*input_h)->dims() != 3) {
     return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
   }
-  model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
+  if (time_major) {
+    model_shapes->num_layers =
+        (*input_h)->dim_size(0) / model_shapes->dir_count;
+  } else {
+    model_shapes->num_layers =
+        (*input_h)->dim_size(1) / model_shapes->dir_count;
+  }
   model_shapes->num_units = (*input_h)->dim_size(2);
 
+  if (time_major) {
     model_shapes->hidden_state_shape =
         TensorShape({model_shapes->dir_count * model_shapes->num_layers,
                      model_shapes->batch_size, model_shapes->num_units});
+  } else {
+    model_shapes->hidden_state_shape =
+        TensorShape({model_shapes->batch_size,
+                     model_shapes->dir_count * model_shapes->num_layers,
+                     model_shapes->num_units});
+  }
   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
     return errors::InvalidArgument(
         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
@ -604,23 +622,28 @@ Status ExtractForwardInput(OpKernelContext* context,
                                      (*input_c)->shape().DebugString());
     }
   }
+  if (time_major) {
     model_shapes->output_shape =
         TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
                      model_shapes->dir_count * model_shapes->num_units});
+  } else {
+    model_shapes->output_shape =
+        TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
+                     model_shapes->dir_count * model_shapes->num_units});
+  }
   return Status::OK();
 }
 
-// Extract and checks the sequence_lengths, forward input tensors,
-// parameters, and shapes from the OpKernelContext.
+// Overloaded function to process the sequence_lengths
 Status ExtractForwardInput(OpKernelContext* context,
-                           const CudnnModelTypes& model_types,
+                           const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
-                           CudnnRnnModelShapes* model_shapes,
-                           const Tensor** sequence_lengths) {
+                           const Tensor** sequence_lengths,
+                           CudnnRnnModelShapes* model_shapes) {
   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
-  return ExtractForwardInput(context, model_types, input, input_h, input_c,
-                             params, model_shapes);
+  return ExtractForwardInput(context, model_types, time_major, input, input_h,
+                             input_c, params, model_shapes);
 }
 
 template <typename T>
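A Python paraphrase of the shape extraction above (illustrative only; the
kernel itself is C++): the same logical dimensions are read from different
axes depending on the layout.

    def extract_model_dims(input_shape, input_h_shape, dir_count, time_major):
      # Mirrors ExtractForwardInput: axes 0 and 1 swap meaning with the layout.
      if time_major:
        max_seq_length, batch_size = input_shape[0], input_shape[1]
        num_layers = input_h_shape[0] // dir_count
      else:
        max_seq_length, batch_size = input_shape[1], input_shape[0]
        num_layers = input_h_shape[1] // dir_count
      return max_seq_length, batch_size, input_shape[2], num_layers

    # Both layouts describe the same model: 5 steps, batch 2, width 8, 1 layer.
    assert extract_model_dims((5, 2, 8), (1, 2, 16), 1, True) == (5, 2, 8, 1)
    assert extract_model_dims((2, 5, 8), (2, 1, 16), 1, False) == (5, 2, 8, 1)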
@ -629,7 +652,7 @@ Status CreateForwardAndBackwardIODescriptors(
     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
     std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
-    const absl::Span<const int>& seq_lengths) {
+    const absl::Span<const int>& seq_lengths, bool time_major) {
   StreamExecutor* executor = context->op_device_context()->stream()->parent();
   se::dnn::DataType data_type = ToDataType<T>::value;
 
@ -639,11 +662,19 @@ Status CreateForwardAndBackwardIODescriptors(
 
   DCHECK_EQ(input_shape.dims(), 3);
   if (seq_lengths.data() != nullptr) {
+    if (time_major) {
       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
           input_shape.dim_size(0), input_shape.dim_size(1),
-          input_shape.dim_size(2), seq_lengths, data_type);
+          input_shape.dim_size(2), seq_lengths, time_major, data_type);
       TF_RETURN_IF_ERROR(input_desc_s.status());
       *input_desc = input_desc_s.ConsumeValueOrDie();
+    } else {
+      auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
+          input_shape.dim_size(1), input_shape.dim_size(0),
+          input_shape.dim_size(2), seq_lengths, time_major, data_type);
+      TF_RETURN_IF_ERROR(input_desc_s.status());
+      *input_desc = input_desc_s.ConsumeValueOrDie();
+    }
   } else {
     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
         input_shape.dim_size(0), input_shape.dim_size(1),
@ -653,19 +684,35 @@ Status CreateForwardAndBackwardIODescriptors(
   }
 
   DCHECK_EQ(hidden_state_shape.dims(), 3);
+  if (time_major) {
     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
         hidden_state_shape.dim_size(2), data_type);
     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
     *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+  } else {
+    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
+        hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
+        hidden_state_shape.dim_size(2), data_type);
+    TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
+    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+  }
 
   DCHECK_EQ(output_shape.dims(), 3);
   if (seq_lengths.data() != nullptr) {
+    if (time_major) {
       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
           output_shape.dim_size(0), output_shape.dim_size(1),
-          output_shape.dim_size(2), seq_lengths, data_type);
+          output_shape.dim_size(2), seq_lengths, time_major, data_type);
       TF_RETURN_IF_ERROR(output_desc_s.status());
       *output_desc = output_desc_s.ConsumeValueOrDie();
+    } else {
+      auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
+          output_shape.dim_size(1), output_shape.dim_size(0),
+          output_shape.dim_size(2), seq_lengths, time_major, data_type);
+      TF_RETURN_IF_ERROR(output_desc_s.status());
+      *output_desc = output_desc_s.ConsumeValueOrDie();
+    }
   } else {
     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
         output_shape.dim_size(0), output_shape.dim_size(1),
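Whichever layout the caller picks, the descriptor factories above always
receive (max_seq_length, batch_size, data_size) in that order; the batch-major
branches simply swap dim_size(0) and dim_size(1). A one-function sketch
(illustrative naming, not part of the patch):

    def sequence_descriptor_args(shape, time_major):
      # shape is (time, batch, depth) when time-major, (batch, time, depth)
      # otherwise; the descriptor always wants time first.
      time_axis, batch_axis = (0, 1) if time_major else (1, 0)
      return shape[time_axis], shape[batch_axis], shape[2]

    assert sequence_descriptor_args((5, 2, 8), time_major=True) == (5, 2, 8)
    assert sequence_descriptor_args((2, 5, 8), time_major=False) == (5, 2, 8)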
@ -687,7 +734,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
                  const bool is_training,
                  /* forward outputs, outputs of the function */
                  Tensor* output, Tensor* output_h, Tensor* output_c,
-                 const Tensor* sequence_lengths,
+                 const Tensor* sequence_lengths, bool time_major,
                  ScratchAllocator* reserve_space_allocator,
                  ScratchAllocator* workspace_allocator,
                  ProfileResult* output_profile_result) {
@ -702,7 +749,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
       context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths));
+      seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@ -750,7 +797,7 @@ Status DoBackward(
     const Tensor* output_c_backprop, const Tensor* reserve_space,
     /* backprop outputs, output of the function */
     Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
-    Tensor* params_backprop, const Tensor* sequence_lengths,
+    Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
     ScratchAllocator* workspace_allocator,
     ProfileResult* output_profile_result) {
   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
@ -764,7 +811,7 @@ Status DoBackward(
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
       context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths));
+      seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@ -1216,13 +1263,15 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig algo_config;
-    ComputeAndReturnAlgorithm(context, &algo_config, false);
+    ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
+                              /*time_major=*/true);
   }
 
  protected:
   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
                                          AlgorithmConfig* output_algo_config,
-                                         bool var_seq_lengths) {
+                                         bool var_seq_lengths,
+                                         bool time_major) {
     CHECK_NE(output_algo_config, nullptr);
 
     const Tensor* input = nullptr;
@ -1232,14 +1281,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
     if (var_seq_lengths) {
-      OP_REQUIRES_OK(
-          context, ExtractForwardInput(context, model_types(), &input, &input_h,
-                                       &input_c, &params, &model_shapes,
-                                       &sequence_lengths));
+      OP_REQUIRES_OK(context,
+                     ExtractForwardInput(context, model_types(), time_major,
+                                         &input, &input_h, &input_c, &params,
+                                         &sequence_lengths, &model_shapes));
     } else {
-      OP_REQUIRES_OK(
-          context, ExtractForwardInput(context, model_types(), &input, &input_h,
-                                       &input_c, &params, &model_shapes));
+      OP_REQUIRES_OK(context, ExtractForwardInput(
+                                  context, model_types(), time_major, &input,
+                                  &input_h, &input_c, &params, &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@ -1278,19 +1327,11 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
         context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
                                            *output_algo_config,
                                            &rnn_state_cache_, &rnn_desc_ptr));
-    if (var_seq_lengths) {
       launch_status = DoForward<T>(
           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
           input_c, params, is_training_, output, output_h, output_c,
-          sequence_lengths, &reserve_space_allocator, &workspace_allocator,
-          /*output_profile_result=*/nullptr);
-    } else {
-      launch_status = DoForward<T>(
-          context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
-          input_c, params, is_training_, output, output_h, output_c, nullptr,
-          &reserve_space_allocator, &workspace_allocator,
-          /*output_profile_result=*/nullptr);
-    }
+          sequence_lengths, time_major, &reserve_space_allocator,
+          &workspace_allocator, /*output_profile_result=*/nullptr);
     }
     OP_REQUIRES_OK(context, launch_status);
   }
@ -1372,7 +1413,8 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
-        context, &best_algo_config, false);
+        context, &best_algo_config, /*var_seq_lengths=*/false,
+        /*time_major=*/true);
     if (!context->status().ok()) {
       return;
     }
@ -1490,10 +1532,11 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
       // Again use temp scratch allocator during profiling.
       CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
       CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
-      status = DoForward<T>(
-          context, *rnn_desc, model_types(), model_shapes, input, input_h,
-          input_c, params, is_training(), output, output_h, output_c, nullptr,
-          &reserve_space_allocator, &workspace_allocator, &fwd_profile_result);
+      status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
+                            input, input_h, input_c, params, is_training(),
+                            output, output_h, output_c, nullptr, true,
+                            &reserve_space_allocator, &workspace_allocator,
+                            &fwd_profile_result);
       if (!status.ok()) {
         continue;
       }
@ -1506,7 +1549,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
           input_c, params, output, output_h, output_c, &output_backprop,
           &output_h_backprop, &output_c_backprop, &reserve_space,
           &input_backprop, &input_h_backprop, &input_c_backprop,
-          &params_backprop, nullptr, &workspace_allocator,
+          &params_backprop, nullptr, true, &workspace_allocator,
           &bak_profile_result);
       if (!status.ok()) {
         continue;
@ -1561,15 +1604,22 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
   using CudnnRNNKernelCommon::dropout;
   using CudnnRNNKernelCommon::HasInputC;
   using CudnnRNNKernelCommon::model_types;
+  bool time_major_;
+
+ protected:
+  bool time_major() { return time_major_; }
+
  public:
   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
-      : CudnnRNNForwardOp<GPUDevice, T>(context) {}
+      : CudnnRNNForwardOp<GPUDevice, T>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+  }
 
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
-        context, &best_algo_config, true);
+        context, &best_algo_config, /*var_seq_lengths=*/true,
+        /*time_major=*/time_major());
     if (!context->status().ok()) {
       return;
     }
@ -1604,11 +1654,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
       : CudnnRNNKernelCommon(context) {}
 
   void Compute(OpKernelContext* context) override {
-    ComputeImpl(context, false);
+    ComputeImpl(context, false, true);
   }
 
  protected:
-  virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths) {
+  virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
+                           bool time_major) {
     const Tensor* input = nullptr;
     const Tensor* input_h = nullptr;
     const Tensor* input_c = nullptr;
@ -1616,14 +1667,14 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
     if (var_seq_lengths) {
-      OP_REQUIRES_OK(
-          context, ExtractForwardInput(context, model_types(), &input, &input_h,
-                                       &input_c, &params, &model_shapes,
-                                       &sequence_lengths));
+      OP_REQUIRES_OK(context,
+                     ExtractForwardInput(context, model_types(), time_major,
+                                         &input, &input_h, &input_c, &params,
+                                         &sequence_lengths, &model_shapes));
     } else {
-      OP_REQUIRES_OK(
-          context, ExtractForwardInput(context, model_types(), &input, &input_h,
-                                       &input_c, &params, &model_shapes));
+      OP_REQUIRES_OK(context, ExtractForwardInput(
+                                  context, model_types(), time_major, &input,
+                                  &input_h, &input_c, &params, &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@ -1665,22 +1716,13 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
         context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
                                            algo_config, &rnn_state_cache_,
                                            &rnn_desc_ptr));
-    if (var_seq_lengths) {
       launch_status = DoBackward<T>(
           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
           input_c, params, output, output_h, output_c, output_backprop,
           output_h_backprop, output_c_backprop, reserve_space, input_backprop,
-          input_h_backprop, input_c_backprop, params_backprop,
-          sequence_lengths, &workspace_allocator,
+          input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
+          time_major, &workspace_allocator,
           /*output_profile_result=*/nullptr);
-    } else {
-      launch_status = DoBackward<T>(
-          context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
-          input_c, params, output, output_h, output_c, output_backprop,
-          output_h_backprop, output_c_backprop, reserve_space, input_backprop,
-          input_h_backprop, input_c_backprop, params_backprop, nullptr,
-          &workspace_allocator, /*output_profile_result=*/nullptr);
-    }
     }
     OP_REQUIRES_OK(context, launch_status);
   }
@ -1827,12 +1869,20 @@ TF_CALL_double(REGISTER_GPU);
 template <typename T>
 class CudnnRNNBackwardOpV3<GPUDevice, T>
     : public CudnnRNNBackwardOp<GPUDevice, T> {
+ private:
+  bool time_major_;
+
+ protected:
+  bool time_major() { return time_major_; }
+
  public:
   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
-      : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
+      : CudnnRNNBackwardOp<GPUDevice, T>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+  }
 
   void Compute(OpKernelContext* context) override {
-    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true);
+    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
   }
 };
 
@ -167,6 +167,7 @@ REGISTER_OP("CudnnRNNV3")
|
|||||||
.Attr("seed: int = 0")
|
.Attr("seed: int = 0")
|
||||||
.Attr("seed2: int = 0")
|
.Attr("seed2: int = 0")
|
||||||
.Attr("is_training: bool = true")
|
.Attr("is_training: bool = true")
|
||||||
|
.Attr("time_major: bool = true")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
auto input_shape = c->input(0);
|
auto input_shape = c->input(0);
|
||||||
auto input_h_shape = c->input(1);
|
auto input_h_shape = c->input(1);
|
||||||
@ -292,6 +293,7 @@ REGISTER_OP("CudnnRNNBackpropV3")
|
|||||||
.Attr("dropout: float = 0.0")
|
.Attr("dropout: float = 0.0")
|
||||||
.Attr("seed: int = 0")
|
.Attr("seed: int = 0")
|
||||||
.Attr("seed2: int = 0")
|
.Attr("seed2: int = 0")
|
||||||
|
.Attr("time_major: bool = true")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
auto input_shape = c->input(0);
|
auto input_shape = c->input(0);
|
||||||
auto input_h_shape = c->input(1);
|
auto input_h_shape = c->input(1);
|
||||||
|
@ -97,6 +97,7 @@ def _cudnn_rnn_backwardv3(op, *grads):
|
|||||||
dropout=op.get_attr("dropout"),
|
dropout=op.get_attr("dropout"),
|
||||||
seed=op.get_attr("seed"),
|
seed=op.get_attr("seed"),
|
||||||
seed2=op.get_attr("seed2"),
|
seed2=op.get_attr("seed2"),
|
||||||
|
time_major=op.get_attr("time_major"),
|
||||||
rnn_mode=op.get_attr("rnn_mode"),
|
rnn_mode=op.get_attr("rnn_mode"),
|
||||||
input_mode=op.get_attr("input_mode"),
|
input_mode=op.get_attr("input_mode"),
|
||||||
direction=op.get_attr("direction")) + (None,)
|
direction=op.get_attr("direction")) + (None,)
|
||||||
|
@ -1300,7 +1300,8 @@ class CudnnRnnSequenceTensorDescriptor
|
|||||||
|
|
||||||
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
|
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
|
||||||
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
|
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
|
||||||
const absl::Span<const int>& seq_lengths, cudnnDataType_t data_type) {
|
const absl::Span<const int>& seq_lengths, bool time_major,
|
||||||
|
cudnnDataType_t data_type) {
|
||||||
#if CUDNN_VERSION >= 7201
|
#if CUDNN_VERSION >= 7201
|
||||||
CHECK_GT(max_seq_length, 0);
|
CHECK_GT(max_seq_length, 0);
|
||||||
int dims[] = {batch_size, data_size, 1};
|
int dims[] = {batch_size, data_size, 1};
|
||||||
@ -1313,9 +1314,15 @@ class CudnnRnnSequenceTensorDescriptor
|
|||||||
const int* seq_lengths_array = seq_lengths.data();
|
const int* seq_lengths_array = seq_lengths.data();
|
||||||
RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
|
RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
|
||||||
float padding_fill = 0.0f;
|
float padding_fill = 0.0f;
|
||||||
|
cudnnRNNDataLayout_t layout;
|
||||||
|
if (time_major) {
|
||||||
|
layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
|
||||||
|
} else {
|
||||||
|
layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
|
||||||
|
}
|
||||||
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
|
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
|
||||||
/*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
|
/*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
|
||||||
/*layout=*/CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
|
/*layout=*/layout,
|
||||||
/*maxSeqLength=*/max_seq_length,
|
/*maxSeqLength=*/max_seq_length,
|
||||||
/*batchSize=*/batch_size, /*vectorSize=*/data_size,
|
/*batchSize=*/batch_size, /*vectorSize=*/data_size,
|
||||||
/*seqLengthArray=*/seq_lengths_array,
|
/*seqLengthArray=*/seq_lengths_array,
|
||||||
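At the cuDNN boundary the flag reduces to a choice of RNN data layout, where
the sequence-major variant was previously hard-coded. A minimal sketch of the
selection (the strings below are placeholders standing in for the real
cudnnRNNDataLayout_t enumerators):

    SEQ_MAJOR_UNPACKED = "CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED"
    BATCH_MAJOR_UNPACKED = "CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED"

    def rnn_data_layout(time_major):
      # time_major=True keeps the old behavior; False selects batch-major.
      return SEQ_MAJOR_UNPACKED if time_major else BATCH_MAJOR_UNPACKED

    assert rnn_data_layout(True) == SEQ_MAJOR_UNPACKED
    assert rnn_data_layout(False) == BATCH_MAJOR_UNPACKED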
@ -1849,11 +1856,12 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
|
|||||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||||
CudnnSupport::createRnnSequenceTensorDescriptor(
|
CudnnSupport::createRnnSequenceTensorDescriptor(
|
||||||
int max_seq_length, int batch_size, int data_size,
|
int max_seq_length, int batch_size, int data_size,
|
||||||
const absl::Span<const int>& seq_lengths, dnn::DataType data_type) {
|
const absl::Span<const int>& seq_lengths, bool time_major,
|
||||||
|
dnn::DataType data_type) {
|
||||||
SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
|
SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
|
||||||
CudnnRnnSequenceTensorDescriptor::Create(
|
CudnnRnnSequenceTensorDescriptor::Create(
|
||||||
parent_, max_seq_length, batch_size, data_size,
|
parent_, max_seq_length, batch_size, data_size,
|
||||||
seq_lengths, ToCudnnDataType(data_type)));
|
seq_lengths, time_major, ToCudnnDataType(data_type)));
|
||||||
return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
|
return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
|
||||||
new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
|
new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
|
||||||
}
|
}
|
||||||
|
@ -63,6 +63,7 @@ class CudnnSupport : public dnn::DnnSupport {
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                     int data_size,
                                     const absl::Span<const int>& seq_lengths,
+                                    bool time_major,
                                     dnn::DataType data_type) override;
 
   port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
@ -2070,7 +2070,7 @@ class DnnSupport {
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                     int data_size,
                                     const absl::Span<const int>& seq_lengths,
-                                    dnn::DataType data_type) {
+                                    bool time_major, dnn::DataType data_type) {
     return port::Status(port::error::UNIMPLEMENTED,
                         "createRnnSequenceTensorDescriptor is unimplemented");
   }
@ -411,14 +411,16 @@ StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
 StreamExecutor::createRnnSequenceTensorDescriptor(
     int max_seq_length, int batch_size, int data_size,
-    const absl::Span<const int> &seq_lengths, dnn::DataType data_type) {
+    const absl::Span<const int> &seq_lengths, bool time_major,
+    dnn::DataType data_type) {
   dnn::DnnSupport *dnn_support = AsDnn();
   if (!dnn_support) {
     return port::Status(port::error::UNKNOWN,
                         "Fail to find the dnn implementation.");
   }
   return dnn_support->createRnnSequenceTensorDescriptor(
-      max_seq_length, batch_size, data_size, seq_lengths, data_type);
+      max_seq_length, batch_size, data_size, seq_lengths, time_major,
+      data_type);
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
@ -421,7 +421,7 @@ class StreamExecutor {
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                     int data_size,
                                     const absl::Span<const int> &seq_lengths,
-                                    dnn::DataType data_type);
+                                    bool time_major, dnn::DataType data_type);
 
   // Create an RNN state descriptor that specifies the input or hidden state.
   // The caller retains the ownership of the returned descriptor.
 
@ -738,7 +738,7 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
@ -758,7 +758,7 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "Cumprod"
 
@ -738,7 +738,7 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
@ -758,7 +758,7 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "Cumprod"