From 8f0612c3a7a988ae693169b687dd8fc74e4045bf Mon Sep 17 00:00:00 2001 From: kaixih Date: Fri, 4 Jan 2019 16:44:53 -0800 Subject: [PATCH 01/11] Add the new param 'time_major' for the new cuDNN RNN API --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 185 +++++++++++++----- .../cudnn_rnn/python/layers/cudnn_rnn.py | 26 ++- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 167 ++++++++++++---- tensorflow/core/kernels/cudnn_rnn_ops.cc | 170 +++++++++++----- tensorflow/core/ops/cudnn_rnn_ops.cc | 2 + tensorflow/python/ops/cudnn_rnn_grad.py | 1 + tensorflow/stream_executor/cuda/cuda_dnn.cc | 18 +- tensorflow/stream_executor/cuda/cuda_dnn.h | 1 + tensorflow/stream_executor/dnn.h | 1 + .../stream_executor/stream_executor_pimpl.cc | 6 +- .../stream_executor/stream_executor_pimpl.h | 1 + 11 files changed, 425 insertions(+), 153 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index f5219eb134d..55aeb7682d2 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -69,6 +69,7 @@ def RunLSTM(sess, time, num_layers=1, variable_seq_lengths=False, + time_major=True, is_training=True, dropout=0., num_dirs=True, @@ -84,11 +85,18 @@ def RunLSTM(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + if time_major: + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + else: + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(batch_size, time, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -127,7 +135,7 @@ def RunLSTM(sess, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, - time_major=True, + time_major=time_major, scope=None) # Convert to cudnn opaque param. @@ -135,21 +143,31 @@ def RunLSTM(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque([w, b]) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) - cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + if time_major: + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + else: + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=1) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=1) cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, cu_initial_c_op, opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) # Remove the trivial 1st dimension. 
- cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( - c=array_ops.squeeze(cu_c_op, axis=0), - h=array_ops.squeeze(cu_h_op, axis=0)) + if time_major: + cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0), + h=array_ops.squeeze(cu_h_op, axis=0)) + else: + cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=1), + h=array_ops.squeeze(cu_h_op, axis=1)) if is_training: (inp_grad_op, hgrad_op, @@ -161,9 +179,15 @@ def RunLSTM(sess, cu_outputs_op, [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + if time_major: + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + else: + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=1) # Remove the trivial 1st dimension - cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + if time_major: + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + else: + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -336,6 +360,7 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -347,7 +372,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -361,13 +387,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -375,18 +404,22 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -396,18 +429,22 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + 
"time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -417,7 +454,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(outputs, cu_outputs) # h @@ -428,13 +466,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -445,7 +486,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -459,14 +501,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") # Hand-picked dropouts are used below (0. and 1.) 
with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -480,7 +525,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -493,7 +539,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -510,6 +557,7 @@ def RunGRU(sess, num_layers=1, is_training=True, variable_seq_lengths=False, + time_major=True, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -524,11 +572,18 @@ def RunGRU(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + if time_major: + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + else: + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(batch_size, time, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -577,7 +632,7 @@ def RunGRU(sess, sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, - time_major=True, + time_major=time_major, scope=None) ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] @@ -588,13 +643,17 @@ def RunGRU(sess, opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + if time_major: + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + else: + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=1) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -607,7 +666,10 @@ def RunGRU(sess, (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + if time_major: + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + else: + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -633,7 +695,10 @@ def RunGRU(sess, (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) ]) # Remove the trivial 1st dimension - cu_h = np.squeeze(cu_h, axis=0) + if time_major: + cu_h = np.squeeze(cu_h, axis=0) + else: + cu_h = np.squeeze(cu_h, axis=1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -653,7 +718,10 @@ def RunGRU(sess, outputs, h = sess.run([outputs_op, h_op]) cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) # Remove the trivial 1st dimension. 
- cu_h = np.squeeze(cu_h, axis=0) + if time_major: + cu_h = np.squeeze(cu_h, axis=0) + else: + cu_h = np.squeeze(cu_h, axis=1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -672,6 +740,7 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -683,7 +752,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -697,13 +767,16 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -711,18 +784,22 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -732,18 +809,22 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): + variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -753,20 +834,24 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) @parameterized.named_parameters( 
ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -777,7 +862,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -786,15 +872,18 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], + "time_major": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): self.skipTest("No GPUs found") + if not variable_seq_lengths and not time_major: + self.skipTest("Batch major not supported") with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: # 1st time w/o dropout. @@ -807,7 +896,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -820,7 +910,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 86ad8ae8073..6cf7db7e0af 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -378,20 +378,33 @@ class _CudnnRNN(base_layer.Layer): inputs, initial_state=None, sequence_lengths=None, + time_major=True, training=True): """Runs the forward step for the RNN model. Args: - inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]` if + 'time_major == True' (default) or `[batch_size, time_len, input_size]` + if 'time_major == False'. initial_state: a tuple of tensor(s) of shape - `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + `[num_layers * num_dirs, batch_size, num_units]` if + 'time_major == True' (default) or `[batch_size, num_layers * num_dirs, + num_units]` if 'time_major == False'. 
If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. training: whether this operation will be used in training or inference. Returns: - output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]` + if 'time_major == True' (default) or `[batch_size, time_len, + num_dirs * num_units]` if 'time_major == False'. It is a `concat([fwd_output, bak_output], axis=2)`. output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. @@ -418,7 +431,8 @@ class _CudnnRNN(base_layer.Layer): # For model that doesn't take input_c, replace with a dummy tensor. c = array_ops.constant([], dtype=dtype) outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - sequence_lengths, training) + sequence_lengths, time_major, + training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -482,7 +496,8 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, time_major, + training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -491,6 +506,7 @@ class _CudnnRNN(base_layer.Layer): training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 1facc83972f..e134b82593f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -956,6 +956,7 @@ def _cudnn_rnn(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -964,10 +965,12 @@ def _cudnn_rnn(inputs, """Cudnn RNN. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -977,6 +980,11 @@ def _cudnn_rnn(inputs, in a batch. The size of the array has to equal the batch_size. 
Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1017,6 +1025,7 @@ def _cudnn_rnn(inputs, } if sequence_lengths is not None: args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) @@ -1031,6 +1040,7 @@ def cudnn_lstm(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1039,15 +1049,26 @@ def cudnn_lstm(inputs, """Cudnn LSTM. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference - input_mode: indicate whether there is a linear projection between the + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. + input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) always applies a linear projection of input @@ -1060,17 +1081,13 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. - sequence_lengths: an int32 array representing the variable sequence lengths - in a batch. The size of the array has to equal the batch_size. Default to - None, in which case sequences in the batch are assumed to have the same - length, which is inferred from inputs. name: name of the operation. 
Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - sequence_lengths, input_mode, direction, dropout, seed, - name) + sequence_lengths, time_major, input_mode, direction, + dropout, seed, name) def _cudnn_rnn_no_input_c(inputs, @@ -1079,6 +1096,7 @@ def _cudnn_rnn_no_input_c(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1087,10 +1105,12 @@ def _cudnn_rnn_no_input_c(inputs, """Cudnn RNN w/o input_c. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1098,6 +1118,11 @@ def _cudnn_rnn_no_input_c(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1118,7 +1143,8 @@ def _cudnn_rnn_no_input_c(inputs, input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, - input_mode, direction, dropout, seed, name) + time_major, input_mode, direction, dropout, + seed, name) return outputs, output_h @@ -1127,6 +1153,7 @@ def cudnn_gru(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1135,10 +1162,12 @@ def cudnn_gru(inputs, """Cudnn GRU. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1153,6 +1182,11 @@ def cudnn_gru(inputs, in a batch. The size of the array has to equal the batch_size. 
Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1163,8 +1197,8 @@ def cudnn_gru(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - sequence_lengths, input_mode, direction, dropout, - seed, name) + sequence_lengths, time_major, input_mode, + direction, dropout, seed, name) def cudnn_rnn_relu(inputs, @@ -1176,14 +1210,17 @@ def cudnn_rnn_relu(inputs, dropout=0., seed=0, sequence_lengths=None, + time_major=True, name=None): """Cudnn RNN Relu. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1201,14 +1238,19 @@ def cudnn_rnn_relu(inputs, sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. name: name of the operation. Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_RELU, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, @@ -1216,6 +1258,7 @@ def cudnn_rnn_tanh(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1224,10 +1267,12 @@ def cudnn_rnn_tanh(inputs, """Cudnn RNN Tanh. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + inputs: the input sequence to the RNN model. A Tensor of shape [max_time, + batch_size, input_size] if 'time_major == True' (default) or a Tensor + of shape [batch_size, max_time, input_size] if 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. 
+ batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1242,6 +1287,11 @@ def cudnn_rnn_tanh(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1252,8 +1302,8 @@ def cudnn_rnn_tanh(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_TANH, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1537,14 +1587,18 @@ class _CudnnRNN(object): input_c, params, is_training=True, - sequence_lengths=None): + sequence_lengths=None, + time_major=True): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + input_data: the input sequence to the RNN model. A Tensor of shape + [max_time, batch_size, input_size] if 'time_major == True' (default) + or a Tensor of shape [batch_size, max_time, input_size] if + 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1553,6 +1607,12 @@ class _CudnnRNN(object): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. Returns: output: the output sequence. output_h: the final state for h. @@ -1566,6 +1626,7 @@ class _CudnnRNN(object): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1666,14 +1727,18 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. 
A Tensor of shape [?, - batch_size, input_size]. + input_data: the input sequence to the RNN model. A Tensor of shape + [max_time, batch_size, input_size] if 'time_major == True' (default) + or a Tensor of shape [batch_size, max_time, input_size] if + 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1681,6 +1746,12 @@ class CudnnLSTM(_CudnnRNN): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1693,6 +1764,7 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=sequence_lengths, + time_major=time_major, is_training=is_training) return (output, output_h, output_c) @@ -1752,19 +1824,29 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. + input_data: the input sequence to the RNN model. A Tensor of shape + [max_time, batch_size, input_size] if 'time_major == True' (default) + or a Tensor of shape [batch_size, max_time, input_size] if + 'time_major == False'. input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + batch_size, num_units] if 'time_major == True' (default) or a Tensor of + shape [batch_size, num_layers, num_units] params: the parameter buffer created for this model. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' + is used. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. 
@@ -1777,6 +1859,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 196494cbcf8..1a25d983ea6 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -562,7 +562,8 @@ Status ExtractForwardInput(OpKernelContext* context, const CudnnModelTypes& model_types, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, - CudnnRnnModelShapes* model_shapes) { + CudnnRnnModelShapes* model_shapes, + bool time_major) { TF_RETURN_IF_ERROR(context->input("input", input)); TF_RETURN_IF_ERROR(context->input("input_h", input_h)); if (model_types.HasInputC()) { @@ -573,8 +574,13 @@ Status ExtractForwardInput(OpKernelContext* context, if ((*input)->dims() != 3) { return errors::InvalidArgument("RNN input must be a 3-D vector."); } - model_shapes->max_seq_length = (*input)->dim_size(0); - model_shapes->batch_size = (*input)->dim_size(1); + if (time_major) { + model_shapes->max_seq_length = (*input)->dim_size(0); + model_shapes->batch_size = (*input)->dim_size(1); + } else { + model_shapes->max_seq_length = (*input)->dim_size(1); + model_shapes->batch_size = (*input)->dim_size(0); + } model_shapes->input_size = (*input)->dim_size(2); model_shapes->input_shape = (*input)->shape(); model_shapes->dir_count = @@ -585,12 +591,23 @@ Status ExtractForwardInput(OpKernelContext* context, if ((*input_h)->dims() != 3) { return errors::InvalidArgument("RNN input_h must be a 3-D vector."); } - model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count; + if (time_major) { + model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count; + } else { + model_shapes->num_layers = (*input_h)->dim_size(1) / model_shapes->dir_count; + } model_shapes->num_units = (*input_h)->dim_size(2); - model_shapes->hidden_state_shape = - TensorShape({model_shapes->dir_count * model_shapes->num_layers, - model_shapes->batch_size, model_shapes->num_units}); + if (time_major) { + model_shapes->hidden_state_shape = + TensorShape({model_shapes->dir_count * model_shapes->num_layers, + model_shapes->batch_size, model_shapes->num_units}); + } else { + model_shapes->hidden_state_shape = + TensorShape({model_shapes->batch_size, + model_shapes->dir_count * model_shapes->num_layers, + model_shapes->num_units}); + } if ((*input_h)->shape() != model_shapes->hidden_state_shape) { return errors::InvalidArgument( "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ", @@ -604,9 +621,15 @@ Status ExtractForwardInput(OpKernelContext* context, (*input_c)->shape().DebugString()); } } - model_shapes->output_shape = - TensorShape({model_shapes->max_seq_length, model_shapes->batch_size, - model_shapes->dir_count * model_shapes->num_units}); + if (time_major) { + model_shapes->output_shape = + TensorShape({model_shapes->max_seq_length, model_shapes->batch_size, + model_shapes->dir_count * model_shapes->num_units}); + } else { + model_shapes->output_shape = + TensorShape({model_shapes->batch_size, model_shapes->max_seq_length, + model_shapes->dir_count * model_shapes->num_units}); + } return Status::OK(); } @@ -617,10 +640,11 @@ Status ExtractForwardInput(OpKernelContext* context, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, 
CudnnRnnModelShapes* model_shapes, - const Tensor** sequence_lengths) { + const Tensor** sequence_lengths, + bool time_major) { TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); return ExtractForwardInput(context, model_types, input, input_h, input_c, - params, model_shapes); + params, model_shapes, time_major); } template @@ -629,7 +653,8 @@ Status CreateForwardAndBackwardIODescriptors( std::unique_ptr* input_desc, std::unique_ptr* state_desc, std::unique_ptr* output_desc, - const absl::Span& seq_lengths) { + const absl::Span& seq_lengths, + bool time_major) { StreamExecutor* executor = context->op_device_context()->stream()->parent(); se::dnn::DataType data_type = ToDataType::value; @@ -639,11 +664,19 @@ Status CreateForwardAndBackwardIODescriptors( DCHECK_EQ(input_shape.dims(), 3); if (seq_lengths.data() != nullptr) { - auto input_desc_s = executor->createRnnSequenceTensorDescriptor( - input_shape.dim_size(0), input_shape.dim_size(1), - input_shape.dim_size(2), seq_lengths, data_type); - TF_RETURN_IF_ERROR(input_desc_s.status()); - *input_desc = input_desc_s.ConsumeValueOrDie(); + if (time_major) { + auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + input_shape.dim_size(0), input_shape.dim_size(1), + input_shape.dim_size(2), seq_lengths, time_major, data_type); + TF_RETURN_IF_ERROR(input_desc_s.status()); + *input_desc = input_desc_s.ConsumeValueOrDie(); + } else { + auto input_desc_s = executor->createRnnSequenceTensorDescriptor( + input_shape.dim_size(1), input_shape.dim_size(0), + input_shape.dim_size(2), seq_lengths, time_major, data_type); + TF_RETURN_IF_ERROR(input_desc_s.status()); + *input_desc = input_desc_s.ConsumeValueOrDie(); + } } else { auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), @@ -653,19 +686,35 @@ Status CreateForwardAndBackwardIODescriptors( } DCHECK_EQ(hidden_state_shape.dims(), 3); - auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( - hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), - hidden_state_shape.dim_size(2), data_type); - TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); - *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + if (time_major) { + auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), + hidden_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); + *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + } else { + auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( + hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0), + hidden_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); + *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + } DCHECK_EQ(output_shape.dims(), 3); if (seq_lengths.data() != nullptr) { - auto output_desc_s = executor->createRnnSequenceTensorDescriptor( - output_shape.dim_size(0), output_shape.dim_size(1), - output_shape.dim_size(2), seq_lengths, data_type); - TF_RETURN_IF_ERROR(output_desc_s.status()); - *output_desc = output_desc_s.ConsumeValueOrDie(); + if (time_major) { + auto output_desc_s = executor->createRnnSequenceTensorDescriptor( + output_shape.dim_size(0), output_shape.dim_size(1), + output_shape.dim_size(2), seq_lengths, time_major, data_type); + TF_RETURN_IF_ERROR(output_desc_s.status()); + *output_desc = output_desc_s.ConsumeValueOrDie(); + } else { + auto output_desc_s = 
executor->createRnnSequenceTensorDescriptor( + output_shape.dim_size(1), output_shape.dim_size(0), + output_shape.dim_size(2), seq_lengths, time_major, data_type); + TF_RETURN_IF_ERROR(output_desc_s.status()); + *output_desc = output_desc_s.ConsumeValueOrDie(); + } } else { auto output_desc_s = executor->createRnnSequenceTensorDescriptor( output_shape.dim_size(0), output_shape.dim_size(1), @@ -688,6 +737,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, /* forward outputs, outputs of the function */ Tensor* output, Tensor* output_h, Tensor* output_c, const Tensor* sequence_lengths, + bool time_major, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { @@ -702,7 +752,7 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, } TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors( context, model_shapes, &input_desc, &state_desc, &output_desc, - seq_lengths)); + seq_lengths, time_major)); auto input_data = AsDeviceMemory(input); auto input_h_data = AsDeviceMemory(input_h); @@ -750,7 +800,7 @@ Status DoBackward( const Tensor* output_c_backprop, const Tensor* reserve_space, /* backprop outputs, output of the function */ Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop, - Tensor* params_backprop, const Tensor* sequence_lengths, + Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major, ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { std::unique_ptr input_desc; @@ -764,7 +814,7 @@ Status DoBackward( } TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors( context, model_shapes, &input_desc, &state_desc, &output_desc, - seq_lengths)); + seq_lengths, time_major)); auto input_data = AsDeviceMemory(input); auto input_h_data = AsDeviceMemory(input_h); @@ -1216,13 +1266,14 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { void Compute(OpKernelContext* context) override { AlgorithmConfig algo_config; - ComputeAndReturnAlgorithm(context, &algo_config, false); + ComputeAndReturnAlgorithm(context, &algo_config, false, true); } protected: virtual void ComputeAndReturnAlgorithm(OpKernelContext* context, AlgorithmConfig* output_algo_config, - bool var_seq_lengths) { + bool var_seq_lengths, + bool time_major) { CHECK_NE(output_algo_config, nullptr); const Tensor* input = nullptr; @@ -1235,11 +1286,12 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), &input, &input_h, &input_c, ¶ms, &model_shapes, - &sequence_lengths)); + &sequence_lengths, time_major)); } else { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes)); + &input_c, ¶ms, &model_shapes, + time_major)); } RnnInputMode input_mode; OP_REQUIRES_OK(context, @@ -1282,13 +1334,13 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { launch_status = DoForward( context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, input_c, params, is_training_, output, output_h, output_c, - sequence_lengths, &reserve_space_allocator, &workspace_allocator, - /*output_profile_result=*/nullptr); + sequence_lengths, time_major, &reserve_space_allocator, + &workspace_allocator, /*output_profile_result=*/nullptr); } else { launch_status = DoForward( context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, input_c, params, is_training_, output, output_h, output_c, nullptr, - &reserve_space_allocator, 
&workspace_allocator, + true, &reserve_space_allocator, &workspace_allocator, /*output_profile_result=*/nullptr); } } @@ -1372,7 +1424,7 @@ class CudnnRNNForwardOpV2 void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp::ComputeAndReturnAlgorithm( - context, &best_algo_config, false); + context, &best_algo_config, false, true); if (!context->status().ok()) { return; } @@ -1493,7 +1545,8 @@ class CudnnRNNForwardOpV2 status = DoForward( context, *rnn_desc, model_types(), model_shapes, input, input_h, input_c, params, is_training(), output, output_h, output_c, nullptr, - &reserve_space_allocator, &workspace_allocator, &fwd_profile_result); + true, &reserve_space_allocator, &workspace_allocator, + &fwd_profile_result); if (!status.ok()) { continue; } @@ -1506,7 +1559,7 @@ class CudnnRNNForwardOpV2 input_c, params, output, output_h, output_c, &output_backprop, &output_h_backprop, &output_c_backprop, &reserve_space, &input_backprop, &input_h_backprop, &input_c_backprop, - ¶ms_backprop, nullptr, &workspace_allocator, + ¶ms_backprop, nullptr, true, &workspace_allocator, &bak_profile_result); if (!status.ok()) { continue; @@ -1561,15 +1614,19 @@ class CudnnRNNForwardOpV3 using CudnnRNNKernelCommon::dropout; using CudnnRNNKernelCommon::HasInputC; using CudnnRNNKernelCommon::model_types; - + bool time_major_; + protected: + bool time_major() { return time_major_; } public: explicit CudnnRNNForwardOpV3(OpKernelConstruction* context) - : CudnnRNNForwardOp(context) {} + : CudnnRNNForwardOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); + } void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp::ComputeAndReturnAlgorithm( - context, &best_algo_config, true); + context, &best_algo_config, true, time_major()); if (!context->status().ok()) { return; } @@ -1604,11 +1661,12 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { : CudnnRNNKernelCommon(context) {} void Compute(OpKernelContext* context) override { - ComputeImpl(context, false); + ComputeImpl(context, false, true); } protected: - virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths) { + virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths, + bool time_major) { const Tensor* input = nullptr; const Tensor* input_h = nullptr; const Tensor* input_c = nullptr; @@ -1619,11 +1677,12 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), &input, &input_h, &input_c, ¶ms, &model_shapes, - &sequence_lengths)); + &sequence_lengths, time_major)); } else { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes)); + &input_c, ¶ms, &model_shapes, + time_major)); } RnnInputMode input_mode; OP_REQUIRES_OK(context, @@ -1671,7 +1730,7 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, input_backprop, input_h_backprop, input_c_backprop, params_backprop, - sequence_lengths, &workspace_allocator, + sequence_lengths, time_major, &workspace_allocator, /*output_profile_result=*/nullptr); } else { launch_status = DoBackward( @@ -1679,7 +1738,8 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, input_backprop, 
input_h_backprop, input_c_backprop, params_backprop, nullptr, - &workspace_allocator, /*output_profile_result=*/nullptr); + true, &workspace_allocator, + /*output_profile_result=*/nullptr); } } OP_REQUIRES_OK(context, launch_status); @@ -1827,12 +1887,18 @@ TF_CALL_double(REGISTER_GPU); template class CudnnRNNBackwardOpV3 : public CudnnRNNBackwardOp { + private: + bool time_major_; + protected: + bool time_major() { return time_major_; } public: explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context) - : CudnnRNNBackwardOp(context) {} + : CudnnRNNBackwardOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); + } void Compute(OpKernelContext* context) override { - CudnnRNNBackwardOp::ComputeImpl(context, true); + CudnnRNNBackwardOp::ComputeImpl(context, true, time_major()); } }; diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index cd2e5c9d340..9b22ccdeeec 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -167,6 +167,7 @@ REGISTER_OP("CudnnRNNV3") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("is_training: bool = true") + .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); auto input_h_shape = c->input(1); @@ -292,6 +293,7 @@ REGISTER_OP("CudnnRNNBackpropV3") .Attr("dropout: float = 0.0") .Attr("seed: int = 0") .Attr("seed2: int = 0") + .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); auto input_h_shape = c->input(1); diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py index d4c182a802a..9ce906121f2 100644 --- a/tensorflow/python/ops/cudnn_rnn_grad.py +++ b/tensorflow/python/ops/cudnn_rnn_grad.py @@ -97,6 +97,7 @@ def _cudnn_rnn_backwardv3(op, *grads): dropout=op.get_attr("dropout"), seed=op.get_attr("seed"), seed2=op.get_attr("seed2"), + time_major=op.get_attr("time_major"), rnn_mode=op.get_attr("rnn_mode"), input_mode=op.get_attr("input_mode"), direction=op.get_attr("direction")) + (None,) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 7c05701895b..2abe617dd92 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1426,7 +1426,8 @@ class CudnnRnnSequenceTensorDescriptor static port::StatusOr Create( CUDAExecutor* parent, int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, cudnnDataType_t data_type) { + const absl::Span& seq_lengths, bool time_major, + cudnnDataType_t data_type) { #if CUDNN_VERSION >= 7201 CHECK_GT(max_seq_length, 0); int dims[] = {batch_size, data_size, 1}; @@ -1439,9 +1440,15 @@ class CudnnRnnSequenceTensorDescriptor const int* seq_lengths_array = seq_lengths.data(); RNNDataDescriptor data_desc = CreateRNNDataDescriptor(); float padding_fill = 0.0f; + cudnnRNNDataLayout_t layout; + if (time_major) { + layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED; + } else { + layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; + } RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor( /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type, - /*layout=*/CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, + /*layout=*/layout, /*maxSeqLength=*/max_seq_length, /*batchSize=*/batch_size, /*vectorSize=*/data_size, /*seqLengthArray=*/seq_lengths_array, @@ -1560,7 +1567,7 @@ port::StatusOr ExtractAndCheckRnnForward( model_dims.input_size = input_desc.data_size(); model_dims.dir_count 
= (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; - + // check parameters if (!(input_h_desc.num_layers() == model_dims.num_layers * model_dims.dir_count && @@ -1978,11 +1985,12 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length, port::StatusOr> CudnnSupport::createRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, - const absl::Span& seq_lengths, dnn::DataType data_type) { + const absl::Span& seq_lengths, bool time_major, + dnn::DataType data_type) { SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, CudnnRnnSequenceTensorDescriptor::Create( parent_, max_seq_length, batch_size, data_size, - seq_lengths, ToCudnnDataType(data_type))); + seq_lengths, time_major, ToCudnnDataType(data_type))); return std::unique_ptr( new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 044ed545145..482098c5113 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -63,6 +63,7 @@ class CudnnSupport : public dnn::DnnSupport { createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, + bool time_major, dnn::DataType data_type) override; port::StatusOr> diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 1001824ed5f..2f784753dcc 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2107,6 +2107,7 @@ class DnnSupport { createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, + bool time_major, dnn::DataType data_type) { return port::Status(port::error::UNIMPLEMENTED, "createRnnSequenceTensorDescriptor is unimplemented"); diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 439c73ec8f6..e5a9cf03491 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -406,14 +406,16 @@ StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length, port::StatusOr> StreamExecutor::createRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, - const absl::Span &seq_lengths, dnn::DataType data_type) { + const absl::Span &seq_lengths, bool time_major, + dnn::DataType data_type) { dnn::DnnSupport *dnn_support = AsDnn(); if (!dnn_support) { return port::Status(port::error::UNKNOWN, "Fail to find the dnn implementation."); } return dnn_support->createRnnSequenceTensorDescriptor( - max_seq_length, batch_size, data_size, seq_lengths, data_type); + max_seq_length, batch_size, data_size, seq_lengths, time_major, + data_type); } port::StatusOr> diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index ad2bc3c733b..9f72648a2d1 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -420,6 +420,7 @@ class StreamExecutor { createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, const absl::Span &seq_lengths, + bool time_major, dnn::DataType data_type); // Create an RNN state descriptor that specifies the input or hidden state. 
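
Note on patch 01: with the plumbing above, the `time_major` attribute travels from the op attrs through CudnnRNNForwardOpV3/CudnnRNNBackwardOpV3 into CudnnRnnSequenceTensorDescriptor::Create, where it only selects between CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED and CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED. From Python the visible difference is the layout of `inputs` and of the initial states. Below is a minimal sketch of a batch-major call, mirroring the RunLSTM helper in the test file; it assumes a prebuilt opaque `params` buffer (e.g. from CudnnParamsFormatConverterLSTM, as in RunLSTM) and a GPU build with cuDNN >= 7.2.1. At this point in the series, batch-major input still requires explicit sequence_lengths so that the CudnnRNNV3 kernel is used.

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

    def batch_major_lstm(params, batch_size=4, time_len=8,
                         input_size=16, num_units=32):
      """`params` is assumed to be an opaque cuDNN LSTM parameter buffer,
      e.g. built with CudnnParamsFormatConverterLSTM as in RunLSTM."""
      # Batch-major layout: [batch_size, time_len, input_size].
      inputs = tf.constant(
          np.random.rand(batch_size, time_len, input_size).astype(np.float32))
      # Initial states carry the num_layers * num_dirs axis at position 1 when
      # time_major=False (it sits at position 0 in the time-major case).
      initial_h = tf.zeros([batch_size, 1, num_units])
      initial_c = tf.zeros([batch_size, 1, num_units])
      # Explicit per-example lengths, all equal to the full sequence length.
      sequence_lengths = tf.fill([batch_size], time_len)
      return cudnn_rnn_ops._cudnn_rnn(
          inputs, initial_h, initial_c, params,
          sequence_lengths=sequence_lengths,
          time_major=False,                  # the new attribute
          is_training=False,
          rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)

The returned outputs keep the batch-major layout [batch_size, time_len, num_units], and output_h/output_c keep the [batch_size, 1, num_units] layout of the initial states, which is why the tests above squeeze axis 1 instead of axis 0 in this mode.
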
From 7e5e2c8cdf496e7726c8508b580c4b61e40f5332 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 7 Feb 2019 18:37:36 -0800 Subject: [PATCH 02/11] Merged two callsites of ExtractForwardInput and other minor issues --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 88 ++++-------- .../cudnn_rnn/python/layers/cudnn_rnn.py | 14 +- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 129 +++++++++--------- tensorflow/core/framework/op_kernel.cc | 13 ++ tensorflow/core/framework/op_kernel.h | 4 + tensorflow/core/kernels/cudnn_rnn_ops.cc | 102 +++++--------- 6 files changed, 145 insertions(+), 205 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 55aeb7682d2..ac3bb72cb13 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -85,18 +85,12 @@ def RunLSTM(sess, random_seed.set_random_seed(0) np.random.seed(0) - if time_major: - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) - else: - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(batch_size, time, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] if time_major else + [batch_size, time, input_size]) + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(*shape).astype(dtype.as_numpy_dtype), + dtype=dtype) initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -143,12 +137,10 @@ def RunLSTM(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque([w, b]) - if time_major: - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) - cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) - else: - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=1) - cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=1) + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=(0 if time_major + else 1)) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=(0 if time_major + else 1)) cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, @@ -160,14 +152,9 @@ def RunLSTM(sess, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) # Remove the trivial 1st dimension. 
- if time_major: - cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( - c=array_ops.squeeze(cu_c_op, axis=0), - h=array_ops.squeeze(cu_h_op, axis=0)) - else: - cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( - c=array_ops.squeeze(cu_c_op, axis=1), - h=array_ops.squeeze(cu_h_op, axis=1)) + cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), + h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1)) if is_training: (inp_grad_op, hgrad_op, @@ -179,15 +166,9 @@ def RunLSTM(sess, cu_outputs_op, [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) # Remove the trivial 1st dimension - if time_major: - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) - else: - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=1) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) # Remove the trivial 1st dimension - if time_major: - cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) - else: - cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=1) + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -572,18 +553,12 @@ def RunGRU(sess, random_seed.set_random_seed(0) np.random.seed(0) - if time_major: - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) - else: - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(batch_size, time, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] if time_major else + [batch_size, time, input_size]) + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(*shape).astype(dtype.as_numpy_dtype), + dtype=dtype) initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -643,10 +618,8 @@ def RunGRU(sess, opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - if time_major: - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) - else: - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=1) + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=(0 if time_major + else 1)) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, @@ -666,10 +639,7 @@ def RunGRU(sess, (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) # Remove the trivial 1st dimension - if time_major: - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) - else: - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=1) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -695,10 +665,7 @@ def RunGRU(sess, (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) ]) # Remove the trivial 1st dimension - if time_major: - cu_h = np.squeeze(cu_h, axis=0) - else: - cu_h = np.squeeze(cu_h, axis=1) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -718,10 +685,7 @@ def RunGRU(sess, outputs, h = sess.run([outputs_op, h_op]) cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) # Remove the trivial 1st dimension. 
- if time_major: - cu_h = np.squeeze(cu_h, axis=0) - else: - cu_h = np.squeeze(cu_h, axis=1) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 6cf7db7e0af..fdd1c0b694a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -384,17 +384,17 @@ class _CudnnRNN(base_layer.Layer): Args: inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]` if - 'time_major == True' (default) or `[batch_size, time_len, input_size]` - if 'time_major == False'. + `time_major` is True (default) or `[batch_size, time_len, input_size]` + if `time_major` is False. initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, batch_size, num_units]` if - 'time_major == True' (default) or `[batch_size, num_layers * num_dirs, - num_units]` if 'time_major == False'. If not provided, use + `time_major` is True (default) or `[batch_size, num_layers * num_dirs, + num_units]` if `time_major` is False. If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in @@ -403,8 +403,8 @@ class _CudnnRNN(base_layer.Layer): training: whether this operation will be used in training or inference. Returns: output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]` - if 'time_major == True' (default) or `[batch_size, time_len, - num_dirs * num_units]` if 'time_major == False'. + if `time_major` is True (default) or `[batch_size, time_len, + num_dirs * num_units]` if `time_major` is False. It is a `concat([fwd_output, bak_output], axis=2)`. output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index e134b82593f..6ccba944de2 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -965,12 +965,12 @@ def _cudnn_rnn(inputs, """Cudnn RNN. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. 
If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -980,7 +980,7 @@ def _cudnn_rnn(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1049,12 +1049,12 @@ def cudnn_lstm(inputs, """Cudnn LSTM. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1063,7 +1063,7 @@ def cudnn_lstm(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1105,12 +1105,12 @@ def _cudnn_rnn_no_input_c(inputs, """Cudnn RNN w/o input_c. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. 
is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1118,7 +1118,7 @@ def _cudnn_rnn_no_input_c(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1162,12 +1162,12 @@ def cudnn_gru(inputs, """Cudnn GRU. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1182,7 +1182,7 @@ def cudnn_gru(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1215,12 +1215,12 @@ def cudnn_rnn_relu(inputs, """Cudnn RNN Relu. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. 
is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1238,7 +1238,7 @@ def cudnn_rnn_relu(inputs, sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1267,12 +1267,12 @@ def cudnn_rnn_tanh(inputs, """Cudnn RNN Tanh. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [max_time, - batch_size, input_size] if 'time_major == True' (default) or a Tensor - of shape [batch_size, max_time, input_size] if 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1287,7 +1287,7 @@ def cudnn_rnn_tanh(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If true, + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major @@ -1592,13 +1592,12 @@ class _CudnnRNN(object): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape - [max_time, batch_size, input_size] if 'time_major == True' (default) - or a Tensor of shape [batch_size, max_time, input_size] if - 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. This is only relevant for LSTM. A Tensor of the same shape as input_h. 
params: the parameter buffer created for this model. @@ -1607,7 +1606,7 @@ class _CudnnRNN(object): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in @@ -1732,13 +1731,12 @@ class CudnnLSTM(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape - [max_time, batch_size, input_size] if 'time_major == True' (default) - or a Tensor of shape [batch_size, max_time, input_size] if - 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1746,7 +1744,7 @@ class CudnnLSTM(_CudnnRNN): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in @@ -1829,19 +1827,18 @@ class _CudnnRNNNoInputC(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape - [max_time, batch_size, input_size] if 'time_major == True' (default) - or a Tensor of shape [batch_size, max_time, input_size] if - 'time_major == False'. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units] if 'time_major == True' (default) or a Tensor of - shape [batch_size, num_layers, num_units] + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. 
Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - time_major: The shape format of the 'inputs' and 'outputs' Tensors. If + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 789f0fda752..5fca492b131 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -133,6 +133,15 @@ const string& OpKernel::requested_input(int i) const { return def_->input(i); } return device->attributes().locality().numa_node(); } +bool OpKernel::HasInput(StringPiece input_name) const { + const auto result = input_name_map_.find(input_name); + if (result == input_name_map_.end()) { + return false; + } else { + return true; + } +} + Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { const auto result = input_name_map_.find(input_name); @@ -345,6 +354,10 @@ void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { referenced_tensors_->Add(tensor); } +bool OpKernelContext::has_input(StringPiece name) const { + return params_->op_kernel->HasInput(name); +} + Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index bccb2bf3c76..57a6e6c6059 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -169,6 +169,7 @@ class OpKernel { return output_memory_types_; } + bool HasInput(StringPiece input_name) const; Status InputRange(StringPiece input_name, int* start, int* stop) const; Status OutputRange(StringPiece output_name, int* start, int* stop) const; @@ -751,6 +752,9 @@ class OpKernelContext { // TODO(mrry): Convert this to return Status. bool has_input(int index) const; + // Return true if there is input under the given name. + bool has_input(StringPiece name) const; + // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. // Usage: if (!context->ValidateInputsAreSameShape(this)) return; diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 1a25d983ea6..17b10570c6a 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -560,10 +560,14 @@ struct RnnScratchSpace { // OpKernelContext. 
Status ExtractForwardInput(OpKernelContext* context, const CudnnModelTypes& model_types, + bool time_major, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, CudnnRnnModelShapes* model_shapes, - bool time_major) { + const Tensor** sequence_lengths) { + if (context->has_input("sequence_lengths")) { + TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); + } TF_RETURN_IF_ERROR(context->input("input", input)); TF_RETURN_IF_ERROR(context->input("input_h", input_h)); if (model_types.HasInputC()) { @@ -633,20 +637,6 @@ Status ExtractForwardInput(OpKernelContext* context, return Status::OK(); } -// Extract and checks the sequence_lengths, forward input tensors, -// parameters, and shapes from the OpKernelContext. -Status ExtractForwardInput(OpKernelContext* context, - const CudnnModelTypes& model_types, - const Tensor** input, const Tensor** input_h, - const Tensor** input_c, const Tensor** params, - CudnnRnnModelShapes* model_shapes, - const Tensor** sequence_lengths, - bool time_major) { - TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); - return ExtractForwardInput(context, model_types, input, input_h, input_c, - params, model_shapes, time_major); -} - template Status CreateForwardAndBackwardIODescriptors( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, @@ -1266,7 +1256,8 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { void Compute(OpKernelContext* context) override { AlgorithmConfig algo_config; - ComputeAndReturnAlgorithm(context, &algo_config, false, true); + ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false, + /*time_major=*/true); } protected: @@ -1282,17 +1273,10 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { const Tensor* params = nullptr; const Tensor* sequence_lengths = nullptr; CudnnRnnModelShapes model_shapes; - if (var_seq_lengths) { - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes, - &sequence_lengths, time_major)); - } else { - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes, - time_major)); - } + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &model_shapes, &sequence_lengths)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, @@ -1330,19 +1314,12 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { context, GetCachedRnnDescriptor(context, model_shapes, input_mode, *output_algo_config, &rnn_state_cache_, &rnn_desc_ptr)); - if (var_seq_lengths) { - launch_status = DoForward( - context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, - input_c, params, is_training_, output, output_h, output_c, - sequence_lengths, time_major, &reserve_space_allocator, - &workspace_allocator, /*output_profile_result=*/nullptr); - } else { - launch_status = DoForward( - context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, - input_c, params, is_training_, output, output_h, output_c, nullptr, - true, &reserve_space_allocator, &workspace_allocator, - /*output_profile_result=*/nullptr); - } + launch_status = DoForward( + context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, + input_c, params, is_training_, output, output_h, output_c, + sequence_lengths, time_major, &reserve_space_allocator, + &workspace_allocator, 
/*output_profile_result=*/nullptr); + } OP_REQUIRES_OK(context, launch_status); } @@ -1424,7 +1401,8 @@ class CudnnRNNForwardOpV2 void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp::ComputeAndReturnAlgorithm( - context, &best_algo_config, false, true); + context, &best_algo_config, /*var_seq_lengths=*/false, + /*time_major=*/true); if (!context->status().ok()) { return; } @@ -1626,7 +1604,8 @@ class CudnnRNNForwardOpV3 void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp::ComputeAndReturnAlgorithm( - context, &best_algo_config, true, time_major()); + context, &best_algo_config, /*var_seq_lengths=*/true, + /*time_major=*/time_major()); if (!context->status().ok()) { return; } @@ -1673,17 +1652,10 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { const Tensor* params = nullptr; const Tensor* sequence_lengths = nullptr; CudnnRnnModelShapes model_shapes; - if (var_seq_lengths) { - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes, - &sequence_lengths, time_major)); - } else { - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), &input, &input_h, - &input_c, ¶ms, &model_shapes, - time_major)); - } + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &model_shapes, &sequence_lengths)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, @@ -1724,23 +1696,13 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { context, GetCachedRnnDescriptor(context, model_shapes, input_mode, algo_config, &rnn_state_cache_, &rnn_desc_ptr)); - if (var_seq_lengths) { - launch_status = DoBackward( - context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, - input_c, params, output, output_h, output_c, output_backprop, - output_h_backprop, output_c_backprop, reserve_space, input_backprop, - input_h_backprop, input_c_backprop, params_backprop, - sequence_lengths, time_major, &workspace_allocator, - /*output_profile_result=*/nullptr); - } else { - launch_status = DoBackward( - context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, - input_c, params, output, output_h, output_c, output_backprop, - output_h_backprop, output_c_backprop, reserve_space, input_backprop, - input_h_backprop, input_c_backprop, params_backprop, nullptr, - true, &workspace_allocator, - /*output_profile_result=*/nullptr); - } + launch_status = DoBackward( + context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h, + input_c, params, output, output_h, output_c, output_backprop, + output_h_backprop, output_c_backprop, reserve_space, input_backprop, + input_h_backprop, input_c_backprop, params_backprop, + sequence_lengths, time_major, &workspace_allocator, + /*output_profile_result=*/nullptr); } OP_REQUIRES_OK(context, launch_status); } From 9c27326b98f975dd04908defd9c6925e9263aa78 Mon Sep 17 00:00:00 2001 From: kaixih Date: Fri, 8 Feb 2019 13:54:43 -0800 Subject: [PATCH 03/11] Add support for time_major=false and sequence_lengths=None for rnnV3 --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 20 ------------------- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 9 +++++++++ 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 
ac3bb72cb13..3776afb8f45 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -376,8 +376,6 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -399,8 +397,6 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -424,8 +420,6 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -455,8 +449,6 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -491,8 +483,6 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") # Hand-picked dropouts are used below (0. and 1.) 
with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -739,8 +729,6 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -762,8 +750,6 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") self._test_training_helper( num_units, input_size, @@ -787,8 +773,6 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -814,8 +798,6 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, variable_seq_lengths, time_major): if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -846,8 +828,6 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): self.skipTest("No GPUs found") - if not variable_seq_lengths and not time_major: - self.skipTest("Batch major not supported") with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: # 1st time w/o dropout. 
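
These "Batch major not supported" skips can be removed because the wrapper change in the next hunk (cudnn_rnn_ops.py) routes batch-major calls without sequence_lengths to CudnnRNNV3 anyway, synthesizing an all-max_time length vector from the input shape. A minimal sketch of the case that the removed guards used to exclude, again assuming a prebuilt opaque `params` buffer and a cuDNN-enabled GPU build:

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

    def batch_major_lstm_no_lengths(params, batch_size=4, time_len=8,
                                    input_size=16, num_units=32):
      """`params` is assumed to be an opaque cuDNN LSTM parameter buffer."""
      inputs = tf.constant(
          np.random.rand(batch_size, time_len, input_size).astype(np.float32))
      initial_h = tf.zeros([batch_size, 1, num_units])
      initial_c = tf.zeros([batch_size, 1, num_units])
      # No sequence_lengths here: after this patch the wrapper fills in a
      # [batch_size] vector of time_len before calling cudnn_rnnv3.
      return cudnn_rnn_ops._cudnn_rnn(
          inputs, initial_h, initial_c, params,
          time_major=False,
          is_training=False,
          rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
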
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 6ccba944de2..86b48435d11 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import tracking as checkpointable_lib +import numpy as np CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -1027,6 +1028,14 @@ def _cudnn_rnn(inputs, args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif time_major is False: + max_time = inputs.get_shape().as_list()[1] + batch_size = inputs.get_shape().as_list()[0] + lengths = np.repeat(max_time, batch_size) + sequence_lengths = ops.convert_to_tensor(lengths.astype(np.int32)) + args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: From 0ef5f3fd6528f932b8482052dd3687edc0352876 Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Tue, 19 Feb 2019 17:31:11 -0800 Subject: [PATCH 04/11] support dynamic shaped input for time_major rnn and other minor changes --- tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py | 6 +++--- tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 7 +++---- tensorflow/core/framework/op_kernel.cc | 6 +----- tensorflow/core/framework/op_kernel.h | 2 +- tensorflow/core/kernels/cudnn_rnn_ops.cc | 8 ++++---- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index fdd1c0b694a..c6ce631b792 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -383,9 +383,9 @@ class _CudnnRNN(base_layer.Layer): """Runs the forward step for the RNN model. Args: - inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]` if - `time_major` is True (default) or `[batch_size, time_len, input_size]` - if `time_major` is False. + inputs: `3-D` tensor. If `time_major` is True (default), the Tensor shape + is [time_len, batch_size, input_size]. If `time_major` is False, the + shape is [batch_size, time_len, input_size]. 
initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, batch_size, num_units]` if `time_major` is True (default) or `[batch_size, num_layers * num_dirs, diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index aa99ebfd6b4..92f569a1678 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -1029,10 +1029,9 @@ def _cudnn_rnn(inputs, args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif time_major is False: - max_time = inputs.get_shape().as_list()[1] - batch_size = inputs.get_shape().as_list()[0] - lengths = np.repeat(max_time, batch_size) - sequence_lengths = ops.convert_to_tensor(lengths.astype(np.int32)) + batch_size = array_ops.shape(inputs)[0] + max_time = array_ops.shape(inputs)[1] + sequence_lengths = array_ops.fill([batch_size], max_time) args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 5fca492b131..9883e2bbf73 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -135,11 +135,7 @@ const string& OpKernel::requested_input(int i) const { return def_->input(i); } bool OpKernel::HasInput(StringPiece input_name) const { const auto result = input_name_map_.find(input_name); - if (result == input_name_map_.end()) { - return false; - } else { - return true; - } + return result != input_name_map_.end(); } Status OpKernel::InputRange(StringPiece input_name, int* start, diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 62589c35039..b87da60fc90 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -759,7 +759,7 @@ class OpKernelContext { // TODO(mrry): Convert this to return Status. bool has_input(int index) const; - // Return true if there is input under the given name. + // Return true if there exists an input. 
bool has_input(StringPiece name) const; // Returns true if all inputs are the same shape, otherwise sets the diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 17b10570c6a..75ceb655017 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -563,8 +563,8 @@ Status ExtractForwardInput(OpKernelContext* context, bool time_major, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, - CudnnRnnModelShapes* model_shapes, - const Tensor** sequence_lengths) { + const Tensor** sequence_lengths, + CudnnRnnModelShapes* model_shapes) { if (context->has_input("sequence_lengths")) { TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); } @@ -1276,7 +1276,7 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), time_major, &input, &input_h, &input_c, ¶ms, - &model_shapes, &sequence_lengths)); + &sequence_lengths, &model_shapes)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, @@ -1655,7 +1655,7 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK( context, ExtractForwardInput(context, model_types(), time_major, &input, &input_h, &input_c, ¶ms, - &model_shapes, &sequence_lengths)); + &sequence_lengths, &model_shapes)); RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, From 2b06ba7d602472b801b2c51e0afdf6709aa7eb9f Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 20 Feb 2019 11:37:33 -0800 Subject: [PATCH 05/11] Added test cases for the dynamic inputs for cudnn-rnn + time_major --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 116 ++++++++++++------ 1 file changed, 80 insertions(+), 36 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 3776afb8f45..b6162927225 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -70,6 +70,7 @@ def RunLSTM(sess, num_layers=1, variable_seq_lengths=False, time_major=True, + dynamic_shape_input=False, is_training=True, dropout=0., num_dirs=True, @@ -87,10 +88,14 @@ def RunLSTM(sess, shape = ([time, batch_size, input_size] if time_major else [batch_size, time, input_size]) - inputs = variable_scope.get_variable( + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( "inputs", - initializer=np.random.rand(*shape).astype(dtype.as_numpy_dtype), + initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder(dtype, shape=[None, None, None], + name="inputs") + inputs = inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -124,7 +129,7 @@ def RunLSTM(sess, cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), @@ -159,7 +164,7 @@ def RunLSTM(sess, if is_training: (inp_grad_op, hgrad_op, cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + outputs_op, 
[inputs_static, initial_h_op, initial_c_op, w, b]) (cu_inp_grad_op, cu_hgrad_op, cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( @@ -191,7 +196,7 @@ def RunLSTM(sess, cu_bgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ]) + ], feed_dict={inputs:inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -210,7 +215,8 @@ def RunLSTM(sess, cu_bgrad) else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) - cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], + feed_dict={inputs:inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -342,6 +348,7 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): dtype, variable_seq_lengths, time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -354,7 +361,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -369,11 +377,12 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -384,17 +393,20 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float32, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -407,17 +419,19 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): rtol=5e-3, atol=5e-4, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major): + variable_seq_lengths, 
time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -430,7 +444,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) # h @@ -442,11 +457,13 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -460,7 +477,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dtype=dtypes.float16, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -475,11 +493,13 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") @@ -497,7 +517,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dropout=0., variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -511,7 +532,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dropout=1., variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -529,6 +551,7 @@ def RunGRU(sess, is_training=True, variable_seq_lengths=False, time_major=True, + dynamic_shape_input=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -545,10 +568,14 @@ def RunGRU(sess, shape = ([time, batch_size, input_size] if time_major else [batch_size, time, input_size]) - inputs = variable_scope.get_variable( + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( "inputs", - initializer=np.random.rand(*shape).astype(dtype.as_numpy_dtype), + initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder(dtype, shape=[None, None, None], + name="inputs") + inputs = inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", 
initializer=np.random.rand(batch_size, @@ -593,7 +620,7 @@ def RunGRU(sess, cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) outputs_op, h_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, @@ -624,7 +651,7 @@ def RunGRU(sess, if is_training: (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, cib_grad_op, chb_grad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op] + ws + bs) + outputs_op, [inputs_static, initial_h_op] + ws + bs) (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) @@ -653,7 +680,7 @@ def RunGRU(sess, cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) - ]) + ], feed_dict={inputs:inputs_np} if dynamic_shape_input else None) # Remove the trivial 1st dimension cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) @@ -673,7 +700,8 @@ def RunGRU(sess, cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) else: outputs, h = sess.run([outputs_op, h_op]) - cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op], + feed_dict={inputs:inputs_np} if dynamic_shape_input else None) # Remove the trivial 1st dimension. cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) @@ -695,6 +723,7 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): dtype, variable_seq_lengths, time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -707,7 +736,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -722,11 +752,12 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -737,17 +768,20 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float32, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -760,17 +794,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): rtol=5e-3, atol=5e-4, 
variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major): + variable_seq_lengths, time_major, dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -783,7 +819,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) @@ -791,11 +828,13 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -809,7 +848,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dtype=dtypes.float16, variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -819,11 +859,13 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ "variable_seq_lengths": [True, False], "time_major": [True, False], + "dynamic_shape_input": [True, False], })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths, time_major): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) 
if not context.context().num_gpus(): @@ -841,7 +883,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dropout=0., variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -855,7 +898,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): is_training=False, dropout=1., variable_seq_lengths=variable_seq_lengths, - time_major=time_major) + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) From a652026afa2f4452cd5306d95f815892a462a580 Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 20 Feb 2019 17:34:48 -0800 Subject: [PATCH 06/11] changed some python formats to align with the requirements --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 10 ++++++---- .../contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 1 - 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index b6162927225..082741873e9 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -196,7 +196,7 @@ def RunLSTM(sess, cu_bgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ], feed_dict={inputs:inputs_np} if dynamic_shape_input else None) + ], feed_dict={inputs: inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -216,7 +216,8 @@ def RunLSTM(sess, else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], - feed_dict={inputs:inputs_np} if dynamic_shape_input else None) + feed_dict=({inputs: inputs_np} if + dynamic_shape_input else None)) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -680,7 +681,7 @@ def RunGRU(sess, cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) - ], feed_dict={inputs:inputs_np} if dynamic_shape_input else None) + ], feed_dict={inputs: inputs_np} if dynamic_shape_input else None) # Remove the trivial 1st dimension cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) @@ -701,7 +702,8 @@ def RunGRU(sess, else: outputs, h = sess.run([outputs_op, h_op]) cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op], - feed_dict={inputs:inputs_np} if dynamic_shape_input else None) + feed_dict=({inputs: inputs_np} if + dynamic_shape_input else None)) # Remove the trivial 1st dimension. 
cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index ab0e75eb28b..59d93f49837 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -34,7 +34,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver from tensorflow.python.training.tracking import tracking as trackable_lib -import numpy as np CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" From c3b912aede7878a50ed9baab4101346cbf83fd1e Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Thu, 21 Feb 2019 15:46:19 -0800 Subject: [PATCH 07/11] minor format updates --- .../cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 082741873e9..93bfdf8f05a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -217,7 +217,8 @@ def RunLSTM(sess, outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], feed_dict=({inputs: inputs_np} if - dynamic_shape_input else None)) + dynamic_shape_input else + None)) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -703,7 +704,7 @@ def RunGRU(sess, outputs, h = sess.run([outputs_op, h_op]) cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op], feed_dict=({inputs: inputs_np} if - dynamic_shape_input else None)) + dynamic_shape_input else None)) # Remove the trivial 1st dimension. cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) From 1d71d4c357bd848e28b2ca7b1143bce634e5869d Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 27 Feb 2019 14:17:16 -0800 Subject: [PATCH 08/11] Update the api def for cuDNNRNN v3 --- .../base_api/api_def_CudnnRNNBackpropV3.pbtxt | 16 +++++++++++----- .../api_def/base_api/api_def_CudnnRNNV3.pbtxt | 16 +++++++++++----- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt index 7967ca7c5d1..03dc530fc58 100644 --- a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV3.pbtxt @@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be dropout: Dropout probability. When set to 0., dropout is disabled. seed: The 1st part of a seed to initialize dropout. seed2: The 2nd part of a seed to initialize dropout. -input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, - num_units]. +input: If time_major is true, this is a 3-D tensor with the shape of + [seq_length, batch_size, input_size]. If time_major is false, the shape is + [batch_size, seq_length, input_size]. +input_h: If time_major is true, this is a 3-D tensor with the shape of + [num_layer * dir, batch_size, num_units]. If time_major is false, the shape + is [batch_size, num_layer * dir, num_units]. 
input_c: For LSTM, a 3-D tensor with the shape of [num_layer * dir, batch, num_units]. For other models, it is ignored. params: A 1-D tensor that contains the weights and biases in an opaque layout. @@ -26,8 +29,9 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout. separately. Note that they might not be compatible across different generations. So it is a good idea to save and restore sequence_lengths: a vector of lengths of each input sequence. -output: A 3-D tensor with the shape of [seq_length, batch_size, - dir * num_units]. +output: If time_major is true, this is a 3-D tensor with the shape of + [seq_length, batch_size, dir * num_units]. If time_major is false, the + shape is [batch_size, seq_length, dir * num_units]. output_h: The same shape has input_h. output_c: The same shape as input_c for LSTM. An empty tensor for other models. output_backprop: A 3-D tensor with the same shape as output in the forward pass. @@ -35,6 +39,8 @@ output_h_backprop: A 3-D tensor with the same shape as output_h in the forward pass. output_c_backprop: A 3-D tensor with the same shape as output_c in the forward pass. +time_major: Indicates whether the input/output format is time major or batch + major. reserve_space: The same reserve_space produced in the forward operation. input_backprop: The backprop to input in the forward pass. Has the same shape as input. diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV3.pbtxt index 9cde53684d0..e076d3cda28 100644 --- a/tensorflow/core/api_def/base_api/api_def_CudnnRNNV3.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV3.pbtxt @@ -16,9 +16,12 @@ direction: Indicates whether a bidirectional model will be used. Should be dropout: Dropout probability. When set to 0., dropout is disabled. seed: The 1st part of a seed to initialize dropout. seed2: The 2nd part of a seed to initialize dropout. -input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, - num_units]. +input: If time_major is true, this is a 3-D tensor with the shape of + [seq_length, batch_size, input_size]. If time_major is false, the shape is + [batch_size, seq_length, input_size]. +input_h: If time_major is true, this is a 3-D tensor with the shape of + [num_layer * dir, batch_size, num_units]. If time_major is false, the shape + is [batch_size, num_layer * dir, num_units]. input_c: For LSTM, a 3-D tensor with the shape of [num_layer * dir, batch, num_units]. For other models, it is ignored. params: A 1-D tensor that contains the weights and biases in an opaque layout. @@ -26,12 +29,15 @@ params: A 1-D tensor that contains the weights and biases in an opaque layout. separately. Note that they might not be compatible across different generations. So it is a good idea to save and restore sequence_lengths: a vector of lengths of each input sequence. -output: A 3-D tensor with the shape of [seq_length, batch_size, - dir * num_units]. +output: If time_major is true, this is a 3-D tensor with the shape of + [seq_length, batch_size, dir * num_units]. If time_major is false, the + shape is [batch_size, seq_length, dir * num_units]. output_h: The same shape has input_h. output_c: The same shape as input_c for LSTM. An empty tensor for other models. is_training: Indicates whether this operation is used for inferenece or training. 
+time_major: Indicates whether the input/output format is time major or batch + major. reserve_space: An opaque tensor that can be used in backprop calculation. It is only produced if is_training is true. END From 75857977892c2ca4d4b4c56876b8af49e0e37dbb Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Thu, 28 Feb 2019 11:05:43 -0800 Subject: [PATCH 09/11] remove the changes on op_kernel --- tensorflow/core/framework/op_kernel.cc | 9 ----- tensorflow/core/framework/op_kernel.h | 4 -- tensorflow/core/kernels/cudnn_rnn_ops.cc | 48 ++++++++++++++++++------ 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 8775dfc6df6..c0b81e8538f 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -133,11 +133,6 @@ const string& OpKernel::requested_input(int i) const { return def_->input(i); } return device->attributes().locality().numa_node(); } -bool OpKernel::HasInput(StringPiece input_name) const { - const auto result = input_name_map_.find(input_name); - return result != input_name_map_.end(); -} - Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { const auto result = input_name_map_.find(input_name); @@ -350,10 +345,6 @@ void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { referenced_tensors_->Add(tensor); } -bool OpKernelContext::has_input(StringPiece name) const { - return params_->op_kernel->HasInput(name); -} - Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 243711ecf58..ff0b44650ce 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -169,7 +169,6 @@ class OpKernel { return output_memory_types_; } - bool HasInput(StringPiece input_name) const; Status InputRange(StringPiece input_name, int* start, int* stop) const; Status OutputRange(StringPiece output_name, int* start, int* stop) const; @@ -790,9 +789,6 @@ class OpKernelContext { // TODO(mrry): Convert this to return Status. bool has_input(int index) const; - // Return true if there exists an input. - bool has_input(StringPiece name) const; - // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. 
// Usage: if (!context->ValidateInputsAreSameShape(this)) return; diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 75ceb655017..55eef0e87fd 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -563,11 +563,7 @@ Status ExtractForwardInput(OpKernelContext* context, bool time_major, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, - const Tensor** sequence_lengths, CudnnRnnModelShapes* model_shapes) { - if (context->has_input("sequence_lengths")) { - TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); - } TF_RETURN_IF_ERROR(context->input("input", input)); TF_RETURN_IF_ERROR(context->input("input_h", input_h)); if (model_types.HasInputC()) { @@ -637,6 +633,20 @@ Status ExtractForwardInput(OpKernelContext* context, return Status::OK(); } +// Overloaded function to process the sequence_lengths +Status ExtractForwardInput(OpKernelContext* context, + const CudnnModelTypes& model_types, + bool time_major, + const Tensor** input, const Tensor** input_h, + const Tensor** input_c, const Tensor** params, + const Tensor** sequence_lengths, + CudnnRnnModelShapes* model_shapes) { + TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); + return ExtractForwardInput(context, model_types, time_major, input, input_h, + input_c, params, model_shapes); +} + + template Status CreateForwardAndBackwardIODescriptors( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, @@ -1273,10 +1283,17 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { const Tensor* params = nullptr; const Tensor* sequence_lengths = nullptr; CudnnRnnModelShapes model_shapes; - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), time_major, - &input, &input_h, &input_c, ¶ms, - &sequence_lengths, &model_shapes)); + if (var_seq_lengths) { + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &sequence_lengths, &model_shapes)); + } else { + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &model_shapes)); + } RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, @@ -1652,10 +1669,17 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { const Tensor* params = nullptr; const Tensor* sequence_lengths = nullptr; CudnnRnnModelShapes model_shapes; - OP_REQUIRES_OK( - context, ExtractForwardInput(context, model_types(), time_major, - &input, &input_h, &input_c, ¶ms, - &sequence_lengths, &model_shapes)); + if (var_seq_lengths) { + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &sequence_lengths, &model_shapes)); + } else { + OP_REQUIRES_OK( + context, ExtractForwardInput(context, model_types(), time_major, + &input, &input_h, &input_c, ¶ms, + &model_shapes)); + } RnnInputMode input_mode; OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, From c6f7244ce8b11e4e2fa07b2c6232d76219ce51aa Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Thu, 28 Feb 2019 17:15:08 -0800 Subject: [PATCH 10/11] remove an empty line --- tensorflow/core/kernels/cudnn_rnn_ops.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 55eef0e87fd..a8eff634b67 100644 --- 
a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -646,7 +646,6 @@ Status ExtractForwardInput(OpKernelContext* context, input_c, params, model_shapes); } - template Status CreateForwardAndBackwardIODescriptors( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, From 40df70df7450a0d83645455d9db7d5d332ea192b Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 6 Mar 2019 20:45:51 -0800 Subject: [PATCH 11/11] goldens are changed to include the new APIs --- tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 4 ++-- tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 86c62b3a95c..0787c920e75 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -722,7 +722,7 @@ tf_module { } member_method { name: "CudnnRNNBackpropV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None" } member_method { name: "CudnnRNNCanonicalToParams" @@ -742,7 +742,7 @@ tf_module { } member_method { name: "CudnnRNNV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None" } member_method { name: "Cumprod" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 86c62b3a95c..0787c920e75 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -722,7 +722,7 @@ tf_module { } member_method { name: "CudnnRNNBackpropV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\'], varargs=None, keywords=None, defaults=None" } member_method { name: 
"CudnnRNNCanonicalToParams" @@ -742,7 +742,7 @@ tf_module { } member_method { name: "CudnnRNNV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\'], varargs=None, keywords=None, defaults=None" } member_method { name: "Cumprod"