From 9ff27787893f76d6971dcd1552eb5270d254f31b Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 10 Apr 2019 21:43:35 -0700 Subject: [PATCH 01/17] add cudnn lstm projection (lstmp) --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 291 +++++++++++---- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 298 +++++++++++---- tensorflow/core/kernels/cudnn_rnn_ops.cc | 350 +++++++++++++----- tensorflow/core/ops/cudnn_rnn_ops.cc | 65 +++- tensorflow/python/ops/cudnn_rnn_grad.py | 1 + tensorflow/stream_executor/cuda/cuda_dnn.cc | 110 +++++- tensorflow/stream_executor/cuda/cuda_dnn.h | 10 +- tensorflow/stream_executor/dnn.h | 2 +- .../stream_executor/stream_executor_pimpl.cc | 4 +- .../stream_executor/stream_executor_pimpl.h | 10 +- .../tools/api/golden/v1/tensorflow.pbtxt | 8 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 14 +- .../tools/api/golden/v2/tensorflow.pbtxt | 8 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 14 +- .../api/golden/v2/tensorflow.summary.pbtxt | 10 +- 15 files changed, 914 insertions(+), 281 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 5c63ee7a97b..47ae2765822 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -71,7 +71,8 @@ def RunLSTM(sess, is_training=True, dropout=0., num_dirs=True, - dtype=dtypes.float32): + dtype=dtypes.float32, + num_proj=None): # TODO(jamesqin): add multi-layer tests. # TODO(jamesqin): add multi-dir tests assert num_layers == 1 @@ -94,7 +95,8 @@ def RunLSTM(sess, initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, - num_units).astype(dtype.as_numpy_dtype), + num_proj if num_proj else num_units) + .astype(dtype.as_numpy_dtype), dtype=dtype) initial_c_op = variable_scope.get_variable( "initial_c_op", @@ -115,13 +117,19 @@ def RunLSTM(sess, with variable_scope.variable_scope("test", initializer=initializer): w = variable_scope.get_variable( "rnn/lstm_cell/kernel", - shape=[input_size + num_units, num_units * 4], + shape=[input_size + (num_proj if num_proj else num_units), + num_units * 4], dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + if num_proj: + pw = variable_scope.get_variable( + "rnn/lstm_cell/projection/kernel", + shape=[num_units, num_proj], dtype=dtype) # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. - cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True, + num_proj=num_proj if num_proj else None) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs_static, @@ -134,8 +142,12 @@ def RunLSTM(sess, # Convert to cudnn opaque param. format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( - num_layers, num_units, input_size) - opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + num_layers, num_units, input_size, + num_proj=num_proj if num_proj else None) + if num_proj: + opaque_params = format_converter.tf_canonical_to_opaque([w, b], [pw,]) + else: + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) cu_initial_h_op = array_ops.expand_dims( initial_h_op, axis=(0 if time_major else 1)) @@ -150,16 +162,22 @@ def RunLSTM(sess, time_major=time_major, dropout=dropout, is_training=is_training, - rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM, + num_proj=num_proj if num_proj else None) # Remove the trivial 1st dimension. cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1)) if is_training: - (inp_grad_op, hgrad_op, - cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( - outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) + if num_proj: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op, pwgrad_op) = gradients_impl.gradients( + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b, pw]) + else: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) (cu_inp_grad_op, cu_hgrad_op, cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( @@ -170,10 +188,16 @@ def RunLSTM(sess, # Remove the trivial 1st dimension cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1) - cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( - opaque_grad_op) + if num_proj: + cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op = \ + format_converter.opaque_to_tf_canonical(opaque_grad_op) + else: + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) cu_wgrad_op = cu_wgrad_op[0] cu_bgrad_op = cu_bgrad_op[0] + if num_proj: + cu_pwgrad_op = cu_pwgrad_op[0] # cudnn lstm has 2 biases each gate. When converting to tf canonical format, # the two biases are summed into one. Thus here bias gradient should be # halved when comparing with tf lstm. @@ -183,18 +207,30 @@ def RunLSTM(sess, sess.run(init_op) if is_training: - outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ - outputs_op, state_tuple_op, inp_grad_op, - (hgrad_op, cgrad_op), wgrad_op, bgrad_op - ]) - (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, - cu_bgrad) = sess.run( - [ - cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) - + if num_proj: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad, pwgrad = \ + sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op, pwgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad, cu_pwgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + else: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) logging.vlog(1, "state_tuple: %s" % str(state_tuple)) @@ -205,11 +241,20 @@ def RunLSTM(sess, logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) logging.vlog(1, "wgrad: %s" % str(wgrad)) logging.vlog(1, "bgrad: %s" % str(bgrad)) + if num_proj: + logging.vlog(1, "pwgrad: %s" % str(bgrad)) logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) - return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, - cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, - cu_bgrad) + if num_proj: + logging.vlog(1, "cu_pwgrad: %s" % str(cu_bgrad)) + if num_proj: + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, + cu_wgrad, cu_bgrad, cu_pwgrad) + else: + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], @@ -226,35 +271,43 @@ def RunLSTM(sess, # Basic set of RNN configs to test. They can be further extended in relevant # test (e.g. adding num_dirs). +#NAMED_RNN_TESTCASES = ({ +# "testcase_name": "xsmall", +# "num_units": 1, +# "input_size": 1, +# "batch_size": 1, +# "time": 1, +# "num_layers": 1, +#}, { +# "testcase_name": "small", +# "num_units": 4, +# "input_size": 4, +# "batch_size": 4, +# "time": 4, +# "num_layers": 1, +#}, { +# "testcase_name": "medium", +# "num_units": 128, +# "input_size": 64, +# "batch_size": 8, +# "time": 16, +# "num_layers": 1, +#}, { +# "testcase_name": "large", +# "num_units": 128, +# "input_size": 128, +# "batch_size": 16, +# "time": 32, +# "num_layers": 1, +#}) NAMED_RNN_TESTCASES = ({ - "testcase_name": "xsmall", - "num_units": 1, - "input_size": 1, - "batch_size": 1, - "time": 1, - "num_layers": 1, -}, { "testcase_name": "small", "num_units": 4, "input_size": 4, "batch_size": 4, "time": 4, "num_layers": 1, -}, { - "testcase_name": "medium", - "num_units": 128, - "input_size": 64, - "batch_size": 8, - "time": 16, - "num_layers": 1, -}, { - "testcase_name": "large", - "num_units": 128, - "input_size": 128, - "batch_size": 16, - "time": 32, - "num_layers": 1, -}) +}, ) def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): @@ -349,19 +402,22 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): time_major, dynamic_shape_input=False, rtol=3e-6, - atol=3e-6): + atol=3e-6, + num_proj=None): with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, - state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + if num_proj is not None and num_proj != 0: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, cu_wgrad, + cu_bgrad, cu_pwgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths=variable_seq_lengths, + dynamic_shape_input=dynamic_shape_input, num_proj=num_proj) + else: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) = RunLSTM(sess, num_units, input_size, batch_size, time, + num_layers, variable_seq_lengths=variable_seq_lengths, + dynamic_shape_input=dynamic_shape_input, num_proj=num_proj) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -371,6 +427,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) + if num_proj is not None and num_proj != 0: + self.assertAllClose(pwgrad, cu_pwgrad, rtol=rtol, atol=atol) @parameterized.named_parameters( ExpandNamedTestCases( @@ -378,10 +436,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major, dynamic_shape_input): + variable_seq_lengths, time_major, dynamic_shape_input, + use_proj): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") self._test_training_helper( num_units, input_size, @@ -391,7 +454,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): dtypes.float32, variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -399,11 +463,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_training_fp16(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, - dynamic_shape_input): + dynamic_shape_input, use_proj): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") self._test_training_helper( num_units, input_size, @@ -415,7 +483,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): atol=5e-4, variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -423,10 +492,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major, dynamic_shape_input): + variable_seq_lengths, time_major, dynamic_shape_input, + use_proj): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -438,7 +512,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): is_training=False, variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) self.assertAllClose(outputs, cu_outputs) # h @@ -452,11 +527,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_inference_fp16(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, - dynamic_shape_input): + dynamic_shape_input, use_proj): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -469,7 +548,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): dtype=dtypes.float16, variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -486,12 +566,16 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_inference_with_dropout(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, - dynamic_shape_input): + dynamic_shape_input, use_proj): """Validates that dropout does not affect Cudnn Rnn inference.""" + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") # Hand-picked dropouts are used below (0. and 1.) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -507,7 +591,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): dropout=0., variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -522,7 +607,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): dropout=1., variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -890,40 +976,61 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, parameterized.TestCase): """Class for testing various format converters.""" - def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + def _test_lstm_helper(self, num_units, input_size, num_layers, direction, + num_proj=None): with self.session(use_gpu=True) as sess: random_seed.set_random_seed(0) np.random.seed(0) num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( - num_layers, num_units, input_size, direction=direction) + num_layers, num_units, input_size, direction=direction, + num_proj=num_proj if num_proj else None) - ws, bs = [], [] + ws, bs, pws = [], [], [] for _ in range(num_layers * num_dirs): w = constant_op.constant( - np.random.rand(input_size + num_units, 4 * num_units), + np.random.rand(input_size + (num_proj if num_proj else num_units), + 4 * num_units), dtype=dtypes.float32) b = constant_op.constant( np.random.rand(4 * num_units), dtype=dtypes.float32) ws.append(w) bs.append(b) + if num_proj: + pw = constant_op.constant( + np.random.rand(num_units, num_proj), dtype=dtypes.float32) + pws.append(pw) + + if num_proj: + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs, pws) + else: + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( cudnn_rnn_ops.CUDNN_LSTM, num_layers, num_units, input_size, - direction=direction) + direction=direction, + num_proj=num_proj if num_proj else None) - ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + if num_proj: + ws_r, bs_r, pws_r = format_converter.opaque_to_tf_canonical( + opaque_params) + ws, ws_r, pws, bs, bs_r, pws_r = sess.run([ws, ws_r, pws, bs, bs_r, + pws_r]) + else: + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() # returns the original input. - ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) for w, w_r in zip(ws, ws_r): self.assertAllClose(w, w_r) + if num_proj: + for pw, pw_r in zip(pws, pws_r): + self.assertAllClose(pw, pw_r) for b, b_r in zip(bs, bs_r): self.assertAllClose(b, b_r) @@ -942,6 +1049,18 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @test_util.run_gpu_only + def test_lstmp(self, num_units, input_size, num_layers): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, + num_proj=num_proj) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) @@ -950,6 +1069,18 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @test_util.run_gpu_only + def test_lstmp_bidi(self, num_units, input_size, num_layers): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION, + num_proj=num_proj) + def _test_gru_helper(self, num_units, input_size, num_layers, direction): with self.session(use_gpu=True) as sess: random_seed.set_random_seed(0) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index fd3dc975779..c3bcd02fdda 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -184,6 +184,7 @@ class CudnnParamsFormatConverter(object): num_layers, num_units, input_size, + num_proj=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION): """Constructor. @@ -193,6 +194,8 @@ class CudnnParamsFormatConverter(object): num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the num_units. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be one of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input' @@ -207,14 +210,16 @@ class CudnnParamsFormatConverter(object): self._input_size = input_size self._num_units = num_units self._input_mode = input_mode + self._num_proj = num_proj self._direction = direction self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - def tf_canonical_to_opaque(self, tf_canonicals): + def tf_canonical_to_opaque(self, tf_canonicals, weights_proj=None): r"""Converts tf canonical weights to cudnn opaque param.""" - cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals, + weights_proj) cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) return opaque_params @@ -222,8 +227,14 @@ class CudnnParamsFormatConverter(object): def opaque_to_tf_canonical(self, opaque_param): r"""Converts cudnn opaque param to tf canonical weights.""" cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) - weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) - return weights, biases + if self._num_proj: + weights, biases, weights_proj = self._cu_canonical_to_tf_canonical( + cu_weights, cu_biases) + return weights, biases, weights_proj + else: + weights, biases = self._cu_canonical_to_tf_canonical( + cu_weights, cu_biases) + return weights, biases def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. @@ -235,15 +246,31 @@ class CudnnParamsFormatConverter(object): 2 list for weights and biases respectively. """ with ops.device("/gpu:0"): - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if self._num_proj: + num_params_weights = (self._num_params + + 1 * self._num_layers * self._num_dirs) + num_params_biases = self._num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=self._num_proj) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) return (weights, biases) def _cu_canonical_to_opaque(self, cu_weights, cu_biases): @@ -256,16 +283,28 @@ class CudnnParamsFormatConverter(object): a single opaque tensor. """ with ops.device("/gpu:0"): - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) - + if self._num_proj: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + num_proj=self._num_proj, + direction=self._direction) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) + def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. @@ -289,9 +328,11 @@ class CudnnParamsFormatConverter(object): 1 tuple, tf canonical weights and biases. """ tf_weights, tf_biases = [], [] + tf_weights_proj = [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num + layer_weights_num += (1 * self._num_dirs) if self._num_proj else 0 for i in range(self._num_layers): layer_weights = cu_weights[i * layer_weights_num:(i + 1) * @@ -299,7 +340,7 @@ class CudnnParamsFormatConverter(object): layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: self._cu_canonical_to_tf_canonical_single_layer( - layer_weights, layer_biases, tf_weights, tf_biases) + layer_weights, layer_biases, tf_weights, tf_biases, tf_weights_proj) else: fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] @@ -311,6 +352,7 @@ class CudnnParamsFormatConverter(object): fw_biases, tf_weights, tf_biases, + tf_weights_proj, ) self._cu_canonical_to_tf_canonical_single_layer( @@ -318,11 +360,16 @@ class CudnnParamsFormatConverter(object): bw_biases, tf_weights, tf_biases, + tf_weights_proj, ) - return (tf_weights, tf_biases) + if self._num_proj: + return (tf_weights, tf_biases, tf_weights_proj) + else: + return (tf_weights, tf_biases) def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): + tf_weights, tf_biases, + tf_weigths_proj=None): r"""Transform single layer Cudnn canonicals to tf canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -337,7 +384,7 @@ class CudnnParamsFormatConverter(object): """ raise NotImplementedError("Abstract method") - def _tf_canonical_to_cu_canonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals, weights_proj): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). @@ -356,6 +403,7 @@ class CudnnParamsFormatConverter(object): --------------- |fwd |bak | --------------- + weights_proj: (optional) weights matrices for projection Returns: 2 lists: the recovered cudnn canonical weights and biases. """ @@ -370,6 +418,9 @@ class CudnnParamsFormatConverter(object): layer_biases = biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights)) + if weights_proj is not None: + pw = array_ops.transpose(weights_proj[i]) + cu_weights.append(pw) cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases)) else: fw_weights, bw_weights = layer_weights[:len( @@ -377,9 +428,15 @@ class CudnnParamsFormatConverter(object): fw_biases, bw_biases = layer_biases[:len( layer_biases) // 2], layer_biases[len(layer_biases) // 2:] cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights)) + if weights_proj is not None: + pw0 = array_ops.transpose(weights_proj[2*i+0]) + cu_weights.append(pw0) cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases)) cu_weights.extend(self._tf_to_cudnn_weights(i, *bw_weights)) + if weights_proj is not None: + pw1 = array_ops.transpose(weights_proj[2*i+1]) + cu_weights.append(pw1) cu_biases.extend(self._tf_to_cudnn_biases(*bw_biases)) return cu_weights, cu_biases @@ -415,7 +472,10 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" - w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights + if self._num_proj: + w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, pw = cu_weights + else: + w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights # pylint: disable=invalid-name W_i = array_ops.concat([w_i, r_i], axis=1) @@ -425,7 +485,11 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): # pylint: enable=invalid-name # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order. reordered = self._cudnn_to_tf_gate_params(* [W_i, W_f, W_c, W_o]) - return (array_ops.transpose(array_ops.concat(reordered, axis=0)),) + if self._num_proj: + return (array_ops.transpose(array_ops.concat(reordered, axis=0)), + array_ops.transpose(pw)) + else: + return (array_ops.transpose(array_ops.concat(reordered, axis=0)),) def _tf_to_cudnn_weights(self, layer, *tf_weights): r"""Reverse the operations in StitchWeights().""" @@ -434,7 +498,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): if layer == 0: input_weight_width = input_size else: - input_weight_width = num_units + input_weight_width = self._num_proj if self._num_proj else num_units if self._direction == CUDNN_RNN_BIDIRECTION: input_weight_width *= 2 @@ -444,10 +508,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(*array_ops.split( w, 4, axis=0)) - w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1) - w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1) - w_f, r_f = array_ops.split(W_f, [input_weight_width, num_units], axis=1) - w_o, r_o = array_ops.split(W_o, [input_weight_width, num_units], axis=1) + hidden_state_width = self._num_proj if self._num_proj else num_units + w_i, r_i = array_ops.split(W_i, [input_weight_width, hidden_state_width], + axis=1) + w_c, r_c = array_ops.split(W_c, [input_weight_width, hidden_state_width], + axis=1) + w_f, r_f = array_ops.split(W_f, [input_weight_width, hidden_state_width], + axis=1) + w_o, r_o = array_ops.split(W_o, [input_weight_width, hidden_state_width], + axis=1) return w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o # pylint: enable=invalid-name @@ -483,10 +552,16 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): - (w,) = self._cudnn_to_tf_weights(*cu_weights) + tf_weights, tf_biases, + tf_weights_proj=None): + if self._num_proj: + (w, pw) = self._cudnn_to_tf_weights(*cu_weights) + tf_weights.append(w) + tf_weights_proj.append(pw) + else: + (w,) = self._cudnn_to_tf_weights(*cu_weights) + tf_weights.append(w) (b,) = self._cudnn_to_tf_biases(*cu_biases) - tf_weights.append(w) tf_biases.append(b) @@ -554,7 +629,8 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): + tf_weights, tf_biases, + tf_weights_proj=None): # pylint: disable=invalid-name W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) @@ -727,8 +803,9 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): def format_converter(self): if self._format_converter is None: self._format_converter = self._format_converter_cls( - self._num_layers, self._num_units, self._input_size, self._input_mode, - self._direction) + self._num_layers, self._num_units, self._input_size, + input_mode=self._input_mode, + direction=self._direction) return self._format_converter def restore(self, restored_tensors, restored_shapes): @@ -962,6 +1039,7 @@ def _cudnn_rnn(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + num_proj=None, name=None): """Cudnn RNN. @@ -999,6 +1077,8 @@ def _cudnn_rnn(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: outputs, output_h, output_c @@ -1027,13 +1107,15 @@ def _cudnn_rnn(inputs, if sequence_lengths is not None: args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major + args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) - elif time_major is False: + elif time_major is False or num_proj: batch_size = array_ops.shape(inputs)[0] max_time = array_ops.shape(inputs)[1] sequence_lengths = array_ops.fill([batch_size], max_time) args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major + args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) @@ -1053,6 +1135,7 @@ def cudnn_lstm(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + num_proj=None, name=None): """Cudnn LSTM. @@ -1089,13 +1172,15 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, sequence_lengths, time_major, input_mode, direction, - dropout, seed, name) + dropout, seed, num_proj, name) def _cudnn_rnn_no_input_c(inputs, @@ -1151,7 +1236,7 @@ def _cudnn_rnn_no_input_c(inputs, input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn( inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, - time_major, input_mode, direction, dropout, seed, name) + time_major, input_mode, direction, dropout, seed, None, name) return outputs, output_h @@ -1322,6 +1407,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, direction=CUDNN_RNN_UNIDIRECTION, dropout=0, seed=0, + num_proj=None, name=None): """Convert cudnn opaque params to canonical. @@ -1346,6 +1432,8 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: weights list and bias list @@ -1358,19 +1446,39 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - rnn_mode=rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - params=params, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - num_params=num_params, - name=name) + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + if num_proj is not None and num_proj != 0: + num_params_weights = (num_params + 1 * num_layers * num_dirs) + num_params_biases = num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + params=params, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=num_proj, + name=name) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + params=params, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_params=num_params, + name=name) return weights, biases @@ -1384,6 +1492,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, direction=CUDNN_RNN_UNIDIRECTION, dropout=0, seed=0, + num_proj=None, name=None): """Converts params from the canonical format to a specific format of cuDNN. @@ -1409,6 +1518,8 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: an opaque Cudnn param. @@ -1419,20 +1530,35 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, check_direction(direction) check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - rnn_mode=rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - weights=weights, - biases=biases, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - name=name) - + if num_proj is not None and num_proj != 0: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + weights=weights, + biases=biases, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_proj=num_proj, + name=name) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + weights=weights, + biases=biases, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + name=name) def cudnn_rnn_opaque_params_size(rnn_mode, num_layers, @@ -1443,6 +1569,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, dtype=dtypes.float32, dropout=0, seed=0, + num_proj=None, name=None): """Returns opaque params size for specific Cudnn config. @@ -1467,6 +1594,8 @@ def cudnn_rnn_opaque_params_size(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: a int, size of Cudnn opaque params. @@ -1482,6 +1611,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, num_layers=num_layers, num_units=num_units, input_size=input_size, + num_proj=num_proj, T=dtype, S=dtypes.int32, dropout=dropout, @@ -1510,7 +1640,8 @@ class _CudnnRNN(object): direction=CUDNN_RNN_UNIDIRECTION, dtype=dtypes.float32, dropout=0., - seed=0): + seed=0, + num_proj=None): """Creates a CudnnRNN model from model spec. Args: @@ -1534,6 +1665,8 @@ class _CudnnRNN(object): dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. Raises: ValueError: if direction is invalid. """ @@ -1546,6 +1679,7 @@ class _CudnnRNN(object): self._dtype = dtype self._dropout = dropout self._seed = seed + self._num_proj = num_proj @property def input_mode(self): @@ -1571,6 +1705,10 @@ class _CudnnRNN(object): def direction(self): return self._direction + @property + def num_proj(self): + return self._num_proj + def params_size(self): """Calculates the size of the opaque parameter buffer needed for this model. @@ -1582,6 +1720,7 @@ class _CudnnRNN(object): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, + num_proj=self._num_proj, dtype=self._dtype, dropout=self._dropout, seed=self._seed, @@ -1637,7 +1776,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) def params_to_canonical(self, params): """Converts params from a specific format of cuDNN to the canonical format. @@ -1657,7 +1797,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) def canonical_to_params(self, weights, biases): """Converts params from the canonical format to a specific format of cuDNN. @@ -1679,7 +1820,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) class CudnnLSTM(_CudnnRNN): @@ -1697,7 +1839,8 @@ class CudnnLSTM(_CudnnRNN): direction=CUDNN_RNN_UNIDIRECTION, dtype=dtypes.float32, dropout=0., - seed=0): + seed=0, + num_proj=None): """Creates a Cudnn LSTM model from model spec. Args: @@ -1716,6 +1859,8 @@ class CudnnLSTM(_CudnnRNN): dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. """ super(CudnnLSTM, self).__init__( CUDNN_LSTM, @@ -1726,7 +1871,8 @@ class CudnnLSTM(_CudnnRNN): direction=direction, dtype=dtype, dropout=dropout, - seed=seed) + seed=seed, + num_proj=num_proj) def __call__(self, input_data, diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index d43fe747333..e70fa88b2cf 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -502,20 +502,23 @@ struct CudnnRnnModelShapes { int dir_count; int max_seq_length; int batch_size; + int c_num_units; TensorShape input_shape; TensorShape output_shape; TensorShape hidden_state_shape; + TensorShape c_state_shape; // At present only fields related to cached RnnDescriptor are concerned. bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { return num_layers == rhs.num_layers && input_size == rhs.input_size && - num_units == rhs.num_units && dir_count == rhs.dir_count; + num_units == rhs.num_units && dir_count == rhs.dir_count && + c_num_units == rhs.c_num_units; } string DebugString() const { return strings::Printf( "[num_layers, input_size, num_units, dir_count, max_seq_length, " - "batch_size]: [%d, %d, %d, %d, %d, %d] ", + "batch_size, c_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", num_layers, input_size, num_units, dir_count, max_seq_length, - batch_size); + batch_size, c_num_units); } }; @@ -562,6 +565,7 @@ Status ExtractForwardInput(OpKernelContext* context, const CudnnModelTypes& model_types, bool time_major, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, + const int num_proj, CudnnRnnModelShapes* model_shapes) { TF_RETURN_IF_ERROR(context->input("input", input)); TF_RETURN_IF_ERROR(context->input("input_h", input_h)); @@ -615,12 +619,48 @@ Status ExtractForwardInput(OpKernelContext* context, model_shapes->hidden_state_shape.DebugString()); } if (model_types.HasInputC()) { - if ((*input_h)->shape() != (*input_c)->shape()) { - return errors::InvalidArgument( - "input_h and input_c must have the same shape: ", - (*input_h)->shape().DebugString(), " ", - (*input_c)->shape().DebugString()); + model_shapes->c_num_units = (*input_c)->dim_size(2); + if (time_major) { + model_shapes->c_state_shape = + TensorShape({model_shapes->dir_count * model_shapes->num_layers, + model_shapes->batch_size, model_shapes->c_num_units}); + } else { + model_shapes->c_state_shape = + TensorShape({model_shapes->batch_size, + model_shapes->dir_count * model_shapes->num_layers, + model_shapes->c_num_units}); } + if (num_proj == 0) { + if ((*input_h)->shape() != (*input_c)->shape()) { + return errors::InvalidArgument( + "input_h and input_c must have the same shape w/o projection: ", + (*input_h)->shape().DebugString(), " ", + (*input_c)->shape().DebugString()); + } + } else { + if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) || + num_proj != (*input_h)->dim_size(2) || + (*input_h)->dim_size(0) != (*input_c)->dim_size(0) || + (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) { + return errors::InvalidArgument( + "Invalid input_h and input_c w/ projection size: ", num_proj, " ", + (*input_h)->shape().DebugString(), " ", + (*input_c)->shape().DebugString()); + } + } + } else { + // dummy c_state_shape TODO(kaixih): remove the time_major branch + if (time_major) { + model_shapes->c_state_shape = + TensorShape({model_shapes->dir_count * model_shapes->num_layers, + model_shapes->batch_size, model_shapes->num_units}); + } else { + model_shapes->c_state_shape = + TensorShape({model_shapes->batch_size, + model_shapes->dir_count * model_shapes->num_layers, + model_shapes->num_units}); + } + model_shapes->c_num_units = 0; } if (time_major) { model_shapes->output_shape = @@ -639,18 +679,19 @@ Status ExtractForwardInput(OpKernelContext* context, const CudnnModelTypes& model_types, bool time_major, const Tensor** input, const Tensor** input_h, const Tensor** input_c, const Tensor** params, - const Tensor** sequence_lengths, + const Tensor** sequence_lengths, const int num_proj, CudnnRnnModelShapes* model_shapes) { TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths)); return ExtractForwardInput(context, model_types, time_major, input, input_h, - input_c, params, model_shapes); + input_c, params, num_proj, model_shapes); } template <typename T> Status CreateForwardAndBackwardIODescriptors( OpKernelContext* context, const CudnnRnnModelShapes& model_shapes, std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc, - std::unique_ptr<RnnStateTensorDescriptor>* state_desc, + std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc, + std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc, std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc, const absl::Span<const int>& seq_lengths, bool time_major) { StreamExecutor* executor = context->op_device_context()->stream()->parent(); @@ -658,6 +699,7 @@ Status CreateForwardAndBackwardIODescriptors( const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; + const TensorShape& c_state_shape = model_shapes.c_state_shape; const TensorShape& output_shape = model_shapes.output_shape; DCHECK_EQ(input_shape.dims(), 3); @@ -689,13 +731,28 @@ Status CreateForwardAndBackwardIODescriptors( hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); - *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); } else { auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor( hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(2), data_type); TF_RETURN_IF_ERROR(hidden_state_desc_s.status()); - *state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); + } + + DCHECK_EQ(c_state_shape.dims(), 3); + if (time_major) { + auto c_state_desc_s = executor->createRnnStateTensorDescriptor( + c_state_shape.dim_size(0), c_state_shape.dim_size(1), + c_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(c_state_desc_s.status()); + *c_state_desc = c_state_desc_s.ConsumeValueOrDie(); + } else { + auto c_state_desc_s = executor->createRnnStateTensorDescriptor( + c_state_shape.dim_size(1), c_state_shape.dim_size(0), + c_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(c_state_desc_s.status()); + *c_state_desc = c_state_desc_s.ConsumeValueOrDie(); } DCHECK_EQ(output_shape.dims(), 3); @@ -739,7 +796,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { std::unique_ptr<RnnSequenceTensorDescriptor> input_desc; - std::unique_ptr<RnnStateTensorDescriptor> state_desc; + std::unique_ptr<RnnStateTensorDescriptor> h_state_desc; + std::unique_ptr<RnnStateTensorDescriptor> c_state_desc; std::unique_ptr<RnnSequenceTensorDescriptor> output_desc; absl::Span<const int> seq_lengths; @@ -748,8 +806,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, sequence_lengths->template flat<int>().data(), model_shapes.batch_size); } TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>( - context, model_shapes, &input_desc, &state_desc, &output_desc, - seq_lengths, time_major)); + context, model_shapes, &input_desc, &h_state_desc, &c_state_desc, + &output_desc, seq_lengths, time_major)); auto input_data = AsDeviceMemory<T>(input); auto input_h_data = AsDeviceMemory<T>(input_h); @@ -769,11 +827,11 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, Stream* stream = context->op_device_context()->stream(); bool launch_success = stream - ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc, - input_h_data, *state_desc, input_c_data, params_data, - *output_desc, &output_data, *state_desc, - &output_h_data, *state_desc, &output_c_data, - is_training, reserve_space_allocator, + ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, + params_data, *output_desc, &output_data, + *h_state_desc, &output_h_data, *c_state_desc, + &output_c_data, is_training, reserve_space_allocator, workspace_allocator, output_profile_result) .ok(); return launch_success @@ -801,7 +859,8 @@ Status DoBackward( ScratchAllocator* workspace_allocator, ProfileResult* output_profile_result) { std::unique_ptr<RnnSequenceTensorDescriptor> input_desc; - std::unique_ptr<RnnStateTensorDescriptor> state_desc; + std::unique_ptr<RnnStateTensorDescriptor> h_state_desc; + std::unique_ptr<RnnStateTensorDescriptor> c_state_desc; std::unique_ptr<RnnSequenceTensorDescriptor> output_desc; absl::Span<const int> seq_lengths; @@ -810,8 +869,8 @@ Status DoBackward( sequence_lengths->template flat<int>().data(), model_shapes.batch_size); } TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>( - context, model_shapes, &input_desc, &state_desc, &output_desc, - seq_lengths, time_major)); + context, model_shapes, &input_desc, &h_state_desc, &c_state_desc, + &output_desc, seq_lengths, time_major)); auto input_data = AsDeviceMemory<T>(input); auto input_h_data = AsDeviceMemory<T>(input_h); @@ -847,15 +906,16 @@ Status DoBackward( Stream* stream = context->op_device_context()->stream(); bool launch_success = stream - ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc, - input_h_data, *state_desc, input_c_data, - params_data, *output_desc, output_data, *state_desc, - output_h_data, *state_desc, output_c_data, - output_backprop_data, output_h_backprop_data, - output_c_backprop_data, &input_backprop_data, - &input_h_backprop_data, &input_c_backprop_data, - ¶ms_backprop_data, &reserve_space_uint8, - workspace_allocator, output_profile_result) + ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, + params_data, *output_desc, output_data, + *h_state_desc, output_h_data, *c_state_desc, + output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, ¶ms_backprop_data, + &reserve_space_uint8, workspace_allocator, + output_profile_result) .ok(); return launch_success ? Status::OK() @@ -932,7 +992,7 @@ class CudnnRNNKernelCommon : public OpKernel { bool ResetRndGenState() { return reset_rnd_gen_state_; } template <typename T> - Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, + Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj, std::unique_ptr<RnnDescriptor>* rnn_desc) { const Tensor* num_layers_t = nullptr; TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t)); @@ -953,6 +1013,9 @@ class CudnnRNNKernelCommon : public OpKernel { } int input_size = input_size_t->scalar<int>()(); + int h_num_units = (num_proj == 0 ? num_units : num_proj); + int c_num_units = (num_proj == 0 ? 0 : num_units); + RnnInputMode input_mode; TF_RETURN_IF_ERROR( ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode)); @@ -962,9 +1025,10 @@ class CudnnRNNKernelCommon : public OpKernel { // random number generator, therefore set state_allocator to nullptr. const AlgorithmConfig algo_config; auto rnn_desc_s = stream->parent()->createRnnDescriptor( - num_layers, num_units, input_size, /*batch_size=*/0, input_mode, - rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config, - dropout(), seed(), /* state_allocator=*/nullptr); + num_layers, h_num_units, input_size, /*c_size=*/c_num_units, + /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(), + ToDataType<T>::value, algo_config, dropout(), seed(), + /* state_allocator=*/nullptr); if (!rnn_desc_s.ok()) { return FromExecutorStatus(rnn_desc_s); } @@ -983,9 +1047,9 @@ class CudnnRNNKernelCommon : public OpKernel { se::dnn::DataType data_type = ToDataType<T>::value; auto rnn_desc_s = executor->createRnnDescriptor( model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, model_shapes.batch_size, input_mode, - rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(), - seed(), dropout_state_allocator); + model_shapes.input_size, model_shapes.c_num_units, + model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(), + data_type, algo_config, dropout(), seed(), dropout_state_allocator); TF_RETURN_IF_ERROR(rnn_desc_s.status()); *rnn_desc = rnn_desc_s.ConsumeValueOrDie(); @@ -1035,11 +1099,18 @@ template <typename T, typename Index> class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon { public: explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) {} + : CudnnRNNKernelCommon(context) { + if (context->HasAttr("num_proj")) { + OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); + } else { + num_proj_ = 0; + } + } void Compute(OpKernelContext* context) override { std::unique_ptr<RnnDescriptor> rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); + OP_REQUIRES_OK(context, + ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; @@ -1049,6 +1120,9 @@ class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t)); *output_t->template flat<Index>().data() = params_size; } + + private: + int num_proj_; }; #define REGISTER_GPU(T) \ @@ -1074,7 +1148,33 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { public: explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context) : CudnnRNNKernelCommon(context) { - OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); + if (context->HasAttr("num_params")) { + OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_)); + } else { + num_params_ = 0; + } + if (context->HasAttr("num_params_weights")) { + OP_REQUIRES_OK(context, context->GetAttr("num_params_weights", + &num_params_weights_)); + } else { + num_params_weights_ = 0; + } + if (context->HasAttr("num_params_biases")) { + OP_REQUIRES_OK(context, + context->GetAttr("num_params_biases", + &num_params_biases_)); + } else { + num_params_biases_ = 0; + } + if (context->HasAttr("num_proj")) { + OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); + } else { + num_proj_ = 0; + } + if (num_proj_ == 0) { + num_params_weights_ = num_params_; + num_params_biases_ = num_params_; + } } void Compute(OpKernelContext* context) override { @@ -1083,7 +1183,8 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { Stream* stream = context->op_device_context()->stream(); std::unique_ptr<RnnDescriptor> rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); + OP_REQUIRES_OK(context, + ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; @@ -1109,25 +1210,38 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) { num_dirs = 2; } - const int num_params_per_layer = num_params_ / num_layers / num_dirs; + const int num_params_weights_per_layer = + num_params_weights_ / num_layers / num_dirs; // Number of params applied on inputs. The rest are applied on recurrent // hidden states. - const int num_params_input_state = num_params_per_layer / 2; - CHECK(num_params_ % (num_layers * num_dirs) == 0) - << "Number of params is not a multiple of num_layers * num_dirs."; - CHECK(num_params_per_layer % 2 == 0) - << "Number of params per layer is not a even number."; + const int num_params_input_state = num_params_weights_per_layer / 2; + CHECK(num_params_weights_ % (num_layers * num_dirs) == 0) + << "Number of params (weights) is not a multiple of num_layers * " + "num_dirs."; + CHECK(num_params_biases_ % (num_layers * num_dirs) == 0) + << "Number of params (bias) is not a multiple of num_layers * " + "num_dirs."; + if (num_proj_ == 0) { + CHECK(num_params_weights_per_layer % 2 == 0) + << "Number of params per layer is not a even number w/o projection."; + } else { + CHECK(num_params_weights_per_layer % 2 != 0) + << "Number of params per layer is not a odd number w/ projection."; + } - CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size()) - << "Number of params mismatch. Expected " << num_params_ << ", got " - << rnn_desc->ParamsWeightRegions().size(); + CHECK(num_params_weights_ == rnn_desc->ParamsWeightRegions().size()) + << "C Number of params mismatch. Expected " << num_params_weights_ + << ", got " << rnn_desc->ParamsWeightRegions().size(); + int h_num_units = (num_proj_ == 0 ? num_units : num_proj_); + int c_num_units = (num_proj_ == 0 ? 0 : num_units); for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size; int64 size = size_in_bytes / sizeof(T); - const int layer_idx = i / num_params_per_layer; - const int index_within_layer = i % num_params_per_layer; - int width = 0, height = num_units; - // In CuDNN layout, each layer has num_params_per_layer params, with the + const int layer_idx = i / num_params_weights_per_layer; + const int index_within_layer = i % num_params_weights_per_layer; + int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units); + // In CuDNN layout, each layer has num_params_weights_per_layer params, + // with the // first half a.k.a num_params_input_state params applied on the inputs, // and the second half on the recurrent hidden states. bool apply_on_input_state = index_within_layer < num_params_input_state; @@ -1135,7 +1249,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { if (layer_idx == 0 && apply_on_input_state) { width = input_size; } else { - width = num_units; + width = h_num_units; } } else { if (apply_on_input_state) { @@ -1145,15 +1259,19 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { } else { // Following layers, cell inputs are concatenated outputs of // its prior layer. - width = 2 * num_units; + width = 2 * h_num_units; } } else { - width = num_units; + width = h_num_units; } } CHECK(size == width * height) << "Params size mismatch. Expected " << width * height << ", got " << size; Tensor* output = nullptr; + int id_in_layer = i % num_params_weights_per_layer; + if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer-1) { + std::swap(height, width); + } OP_REQUIRES_OK(context, context->allocate_output( i, TensorShape({height, width}), &output)); DeviceMemoryBase data_src_ptr = SliceDeviceMemory( @@ -1162,10 +1280,11 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes); } - OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(), - errors::InvalidArgument("Number of params mismatch. Expected ", - num_params_, ", got ", - rnn_desc->ParamsBiasRegions().size())); + OP_REQUIRES( + context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(), + errors::InvalidArgument("A Number of params mismatch. Expected ", + num_params_biases_, ", got ", + rnn_desc->ParamsBiasRegions().size())); for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) { int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size; int64 size = size_in_bytes / sizeof(T); @@ -1175,7 +1294,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { Tensor* output = nullptr; OP_REQUIRES_OK(context, - context->allocate_output(num_params_ + i, + context->allocate_output(num_params_weights_ + i, TensorShape({size}), &output)); DeviceMemoryBase data_src_ptr = SliceDeviceMemory( input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes); @@ -1186,6 +1305,9 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { private: int num_params_; + int num_params_weights_; + int num_params_biases_; + int num_proj_; }; #define REGISTER_GPU(T) \ @@ -1201,17 +1323,37 @@ TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_layers") \ + .HostMemory("num_units") \ + .HostMemory("input_size") \ + .TypeConstraint<T>("T"), \ + CudnnRNNParamsToCanonical<GPUDevice, T>); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +#undef REGISTER_GPU + // Convert weight and bias params from the canonical form to a // platform-specific layout. template <typename T> class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon { public: explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context) - : CudnnRNNKernelCommon(context) {} + : CudnnRNNKernelCommon(context) { + if (context->HasAttr("num_proj")) { + OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); + } else { + num_proj_ = 0; + } + } void Compute(OpKernelContext* context) override { std::unique_ptr<RnnDescriptor> rnn_desc; - OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc)); + OP_REQUIRES_OK(context, + ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc)); int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes(); CHECK(params_size_in_bytes % sizeof(T) == 0) << "params_size_in_bytes must be multiple of element size"; @@ -1232,6 +1374,9 @@ class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon { RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr, stream); } + + private: + int num_proj_; }; #define REGISTER_GPU(T) \ @@ -1247,6 +1392,19 @@ TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("num_layers") \ + .HostMemory("num_units") \ + .HostMemory("input_size") \ + .TypeConstraint<T>("T"), \ + CudnnRNNCanonicalToParams<GPUDevice, T>); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +#undef REGISTER_GPU + // Run the forward operation of the RNN model. template <typename T> class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { @@ -1264,14 +1422,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { void Compute(OpKernelContext* context) override { AlgorithmConfig algo_config; ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false, - /*time_major=*/true); + /*time_major=*/true, /*num_proj=*/0); } protected: virtual void ComputeAndReturnAlgorithm(OpKernelContext* context, AlgorithmConfig* output_algo_config, bool var_seq_lengths, - bool time_major) { + bool time_major, int num_proj) { CHECK_NE(output_algo_config, nullptr); const Tensor* input = nullptr; @@ -1284,11 +1442,13 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ExtractForwardInput(context, model_types(), time_major, &input, &input_h, &input_c, ¶ms, - &sequence_lengths, &model_shapes)); + &sequence_lengths, num_proj, + &model_shapes)); } else { OP_REQUIRES_OK(context, ExtractForwardInput( context, model_types(), time_major, &input, - &input_h, &input_c, ¶ms, &model_shapes)); + &input_h, &input_c, ¶ms, num_proj, + &model_shapes)); } RnnInputMode input_mode; OP_REQUIRES_OK(context, @@ -1362,13 +1522,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { Tensor** output_c) { const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; + const TensorShape& c_state_shape = model_shapes.c_state_shape; TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output)); TF_RETURN_IF_ERROR( context->allocate_output(1, hidden_state_shape, output_h)); if (HasInputC()) { TF_RETURN_IF_ERROR( - context->allocate_output(2, hidden_state_shape, output_c)); + context->allocate_output(2, c_state_shape, output_c)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. @@ -1414,7 +1575,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T> AlgorithmConfig best_algo_config; CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm( context, &best_algo_config, /*var_seq_lengths=*/false, - /*time_major=*/true); + /*time_major=*/true, /*num_proj=*/0); if (!context->status().ok()) { return; } @@ -1613,13 +1774,18 @@ class CudnnRNNForwardOpV3<GPUDevice, T> explicit CudnnRNNForwardOpV3(OpKernelConstruction* context) : CudnnRNNForwardOp<GPUDevice, T>(context) { OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); + if (context->HasAttr("num_proj")) { + OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); + } else { + num_proj_ = 0; + } } void Compute(OpKernelContext* context) override { AlgorithmConfig best_algo_config; CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm( context, &best_algo_config, /*var_seq_lengths=*/true, - /*time_major=*/time_major()); + /*time_major=*/time_major(), num_proj_); if (!context->status().ok()) { return; } @@ -1631,6 +1797,9 @@ class CudnnRNNForwardOpV3<GPUDevice, T> OP_REQUIRES_OK(context, context->allocate_output(4, {}, &output_host_reserved)); } + + private: + int num_proj_; }; #define REGISTER_GPU(T) \ @@ -1654,12 +1823,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { : CudnnRNNKernelCommon(context) {} void Compute(OpKernelContext* context) override { - ComputeImpl(context, false, true); + ComputeImpl(context, false, true, 0); } protected: virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths, - bool time_major) { + bool time_major, int num_proj) { const Tensor* input = nullptr; const Tensor* input_h = nullptr; const Tensor* input_c = nullptr; @@ -1670,11 +1839,13 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ExtractForwardInput(context, model_types(), time_major, &input, &input_h, &input_c, ¶ms, - &sequence_lengths, &model_shapes)); + &sequence_lengths, num_proj, + &model_shapes)); } else { OP_REQUIRES_OK(context, ExtractForwardInput( context, model_types(), time_major, &input, - &input_h, &input_c, ¶ms, &model_shapes)); + &input_h, &input_c, ¶ms, num_proj, + &model_shapes)); } RnnInputMode input_mode; OP_REQUIRES_OK(context, @@ -1757,6 +1928,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space)); const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; + const TensorShape& c_state_shape = model_shapes.c_state_shape; if (output_shape != (*output)->shape()) { return errors::InvalidArgument( @@ -1782,16 +1954,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { } if (model_types.HasInputC()) { - if (hidden_state_shape != (*output_c)->shape()) { + if (c_state_shape != (*output_c)->shape()) { return errors::InvalidArgument( "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ", - hidden_state_shape.DebugString()); + c_state_shape.DebugString()); } - if (hidden_state_shape != (*output_c_backprop)->shape()) { + if (c_state_shape != (*output_c_backprop)->shape()) { return errors::InvalidArgument( "Invalid output_c_backprop shape: ", (*output_c_backprop)->shape().DebugString(), " ", - hidden_state_shape.DebugString()); + c_state_shape.DebugString()); } } return Status::OK(); @@ -1804,6 +1976,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { Tensor** input_c_backprop, Tensor** params_backprop) { const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; + const TensorShape& c_state_shape = model_shapes.c_state_shape; TF_RETURN_IF_ERROR( context->allocate_output(0, input_shape, input_backprop)); @@ -1811,7 +1984,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { context->allocate_output(1, hidden_state_shape, input_h_backprop)); if (HasInputC()) { TF_RETURN_IF_ERROR( - context->allocate_output(2, hidden_state_shape, input_c_backprop)); + context->allocate_output(2, c_state_shape, input_c_backprop)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. @@ -1879,11 +2052,20 @@ class CudnnRNNBackwardOpV3<GPUDevice, T> explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context) : CudnnRNNBackwardOp<GPUDevice, T>(context) { OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_)); + if (context->HasAttr("num_proj")) { + OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_)); + } else { + num_proj_ = 0; + } } void Compute(OpKernelContext* context) override { - CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major()); + CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(), + num_proj_); } + + private: + int num_proj_; }; #define REGISTER_GPU(T) \ diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index 9b22ccdeeec..1dd7659e137 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -49,6 +49,7 @@ REGISTER_OP("CudnnRNNParamsSize") .Attr("dropout: float = 0.0") .Attr("seed: int = 0") .Attr("seed2: int = 0") + .Attr("num_proj: int = 0") .Output("params_size: S") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; @@ -166,11 +167,13 @@ REGISTER_OP("CudnnRNNV3") .Attr("dropout: float = 0.0") .Attr("seed: int = 0") .Attr("seed2: int = 0") + .Attr("num_proj: int = 0") .Attr("is_training: bool = true") .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); auto input_h_shape = c->input(1); + auto input_c_shape = c->input(2); auto max_seq_length = c->Dim(input_shape, 0); auto batch_size = c->Dim(input_shape, 1); auto num_units = c->Dim(input_h_shape, 2); @@ -185,7 +188,7 @@ REGISTER_OP("CudnnRNNV3") c->MakeShape({max_seq_length, batch_size, output_size}); auto output_h_shape = input_h_shape; auto output_c_shape TF_ATTRIBUTE_UNUSED = - (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({}); + (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({}); c->set_output(0, output_shape); c->set_output(1, output_h_shape); c->set_output(2, output_c_shape); @@ -293,6 +296,7 @@ REGISTER_OP("CudnnRNNBackpropV3") .Attr("dropout: float = 0.0") .Attr("seed: int = 0") .Attr("seed2: int = 0") + .Attr("num_proj: int = 0") .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); @@ -338,6 +342,43 @@ REGISTER_OP("CudnnRNNParamsToCanonical") return Status::OK(); }); +REGISTER_OP("CudnnRNNParamsToCanonicalV2") + .Input("num_layers: int32") + .Input("num_units: int32") + .Input("input_size: int32") + .Input("params: T") + .Output("weights: num_params_weights * T") + .Output("biases: num_params_biases * T") + .Attr("T: {float16, float32, float64}") + .Attr("num_params_weights: int") + .Attr("num_params_biases: int") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Attr("dropout: float = 0.0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("num_proj: int = 0") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused)); + int num_params_weights; + int num_params_biases; + TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights)); + TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases)); + // Set shape for weight matrices + for (int i = 0; i < num_params_weights; i++) { + c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + } + // Set shape for bias vectors + for (int i = 0; i < num_params_biases; i++) { + c->set_output(num_params_weights + i, + c->Vector(InferenceContext::kUnknownDim)); + } + return Status::OK(); + }); + REGISTER_OP("CudnnRNNCanonicalToParams") .Input("num_layers: int32") .Input("num_units: int32") @@ -358,4 +399,26 @@ REGISTER_OP("CudnnRNNCanonicalToParams") return Status::OK(); }); +REGISTER_OP("CudnnRNNCanonicalToParamsV2") + .Input("num_layers: int32") + .Input("num_units: int32") + .Input("input_size: int32") + .Input("weights: num_params_weights * T") + .Input("biases: num_params_biases * T") + .Output("params: T") + .Attr("T: {float16, float32, float64}") + .Attr("num_params_weights: int") + .Attr("num_params_biases: int") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Attr("dropout: float = 0.0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("num_proj: int = 0") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py index 9ce906121f2..3c93e0b8ec9 100644 --- a/tensorflow/python/ops/cudnn_rnn_grad.py +++ b/tensorflow/python/ops/cudnn_rnn_grad.py @@ -98,6 +98,7 @@ def _cudnn_rnn_backwardv3(op, *grads): seed=op.get_attr("seed"), seed2=op.get_attr("seed2"), time_major=op.get_attr("time_major"), + num_proj=op.get_attr("num_proj"), rnn_mode=op.get_attr("rnn_mode"), input_mode=op.get_attr("input_mode"), direction=op.get_attr("direction")) + (None,) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index c0cc00c7208..4e9d569f669 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1002,8 +1002,8 @@ class CudnnRnnParamsDescriptor { class CudnnRnnDescriptor : public dnn::RnnDescriptor { CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc, PersistentRnnPlan rnn_plan, int num_layers, - int hidden_size, int input_size, int batch_size, - cudnnRNNInputMode_t input_mode, + int hidden_size, int input_size, int c_size, + int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, @@ -1015,6 +1015,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { num_layers_(num_layers), hidden_size_(hidden_size), input_size_(input_size), + c_size_(c_size), batch_size_(batch_size), rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())), input_mode_(input_mode), @@ -1031,7 +1032,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { static port::StatusOr<CudnnRnnDescriptor> Create( const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size, - int batch_size, cudnnRNNInputMode_t input_mode, + int c_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, @@ -1044,12 +1045,29 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm()); // TODO: allow the user to choose an algorithm. - RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), - /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*algo=*/rnn_algo, - /*dataType=*/compute_type)); + if (c_size != 0 && hidden_size < c_size) { + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/c_size, + /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), + /*inputMode=*/input_mode, /*direction=*/direction_mode, + /*mode=*/rnn_mode, /*algo=*/rnn_algo, /*dataType=*/compute_type)); +#if CUDNN_VERSION >= 7101 + RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), + /*recProjSize=*/hidden_size, /*outProjSize=*/0)); +#else + return port::Status(port::error::INVALID_ARGUMENT, + "No supported cudnnSetRNNProjectionLayers when " + "CUDNN_VERSION < 7.1.1"); +#endif + } else { + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), + /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers, + /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, + /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, + /*dataType=*/compute_type)); + } // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs. // But in the future if these APIs are used to process full length arrays, @@ -1098,9 +1116,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { #endif return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), - num_layers, hidden_size, input_size, batch_size, - input_mode, direction_mode, rnn_mode, data_type, - compute_type, algorithm_config, + num_layers, hidden_size, input_size, c_size, + batch_size, input_mode, direction_mode, rnn_mode, + data_type, compute_type, algorithm_config, std::move(dropout_desc), std::move(params_desc)); } @@ -1108,6 +1126,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int num_layers() const { return num_layers_; } int hidden_size() const { return hidden_size_; } int input_size() const { return input_size_; } + int c_size() const { return c_size_; } int batch_size() const { return batch_size_; } cudnnRNNInputMode_t input_mode() const { return input_mode_; } cudnnDirectionMode_t direction_mode() const { return direction_mode_; } @@ -1136,6 +1155,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int num_layers_; int hidden_size_; int input_size_; + int c_size_; // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC // algorithm. int batch_size_; @@ -1240,6 +1260,62 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create( (type == 0 ? weights : biases).push_back(region); } } + int hidden_size_v; + int num_layers_v; + cudnnDropoutDescriptor_t dropout_desc; + cudnnRNNInputMode_t input_mode; + cudnnDirectionMode_t direction; + cudnnRNNMode_t mode; + cudnnRNNAlgo_t algo; + cudnnDataType_t data_dype; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*hiddenSize=*/&hidden_size_v, + /*numLayers=*/&num_layers_v, + /*dropoutDesc=*/&dropout_desc, + /*inputMode=*/&input_mode, + /*direction=*/&direction, + /*mode=*/&mode, + /*algo=*/&algo, + /*dataType=*/&data_type)); + int rec_proj_size_v; + int out_proj_size_v; +#if CUDNN_VERSION >= 7101 + RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc, + /*recProjSize*/ &rec_proj_size_v, + /*outProjSize*/ &out_proj_size_v)); +#else + return port::Status(port::error::INVALID_ARGUMENT, + "No supported cudnnGetRNNProjectionLayers when " + "CUDNN_VERSION < 7.1.1"); +#endif + if (rec_proj_size_v != hidden_size_v) { + void* offset = nullptr; + int region_id = 8; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*layer=*/layer, /*xDesc=*/input_desc.get(), + /*wDesc=*/filter_desc.get(), + /*w=*/nullptr, /*linLayerID=*/region_id, + /*linLayerMatDesc=*/region_desc_handle.get(), + /*linLayerMat or linLayerBias=*/&offset)); + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims)); + int64 size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + dnn::RnnDescriptor::ParamsRegion region = { + reinterpret_cast<int64>(offset), size}; + weights.push_back(region); + } } return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, @@ -1404,6 +1480,7 @@ struct RnnModelDims { int max_seq_length = 0; int hidden_size = 0; int input_size = 0; + int c_size = 0; int dir_count = 0; }; @@ -1429,6 +1506,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( model_dims.max_seq_length = input_desc.max_seq_length(); model_dims.hidden_size = rnn_desc.hidden_size(); model_dims.input_size = input_desc.data_size(); + model_dims.c_size = rnn_desc.c_size(); model_dims.dir_count = (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; @@ -1441,7 +1519,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( } if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && - input_h_desc.data_size() == input_c_desc.data_size())) { + input_h_desc.data_size() <= input_c_desc.data_size())) { return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape"); } if (!(output_desc.max_seq_length() == model_dims.max_seq_length && @@ -1458,7 +1536,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( } if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && input_h_desc.batch_size() == output_c_desc.batch_size() && - input_h_desc.data_size() == output_c_desc.data_size())) { + input_h_desc.data_size() <= output_c_desc.data_size())) { return port::Status(port::error::INVALID_ARGUMENT, "Invalid output_c shape"); } @@ -1814,7 +1892,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> CudnnSupport::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int batch_size, + int num_layers, int hidden_size, int input_size, int c_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, @@ -1825,7 +1903,7 @@ CudnnSupport::createRnnDescriptor( SE_ASSIGN_OR_RETURN( CudnnRnnDescriptor rnn_desc, CudnnRnnDescriptor::Create( - cudnn, num_layers, hidden_size, input_size, batch_size, + cudnn, num_layers, hidden_size, input_size, c_size, batch_size, ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), GetRnnComputeType(data_type), diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 3a49469651c..b0da5cbd3b9 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -48,11 +48,11 @@ class CudnnSupport : public dnn::DnnSupport { port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override; port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int batch_size, - dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, - dnn::RnnMode rnn_mode, dnn::DataType data_type, - const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, - ScratchAllocator* state_allocator) override; + int num_layers, int hidden_size, int input_size, int c_size, + int batch_size, dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, + dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, + float dropout, uint64 seed, ScratchAllocator* state_allocator) override; port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 34de1512ee5..295bedba233 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2052,7 +2052,7 @@ class DnnSupport { // is no longer in use. virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(int num_layers, int hidden_size, int input_size, - int batch_size, dnn::RnnInputMode input_mode, + int c_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 2870c3883e2..6541a138f73 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -379,7 +379,7 @@ bool StreamExecutor::GetBlasGemmAlgorithms( port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> StreamExecutor::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int batch_size, + int num_layers, int hidden_size, int input_size, int c_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed, @@ -390,7 +390,7 @@ StreamExecutor::createRnnDescriptor( "Fail to find the dnn implementation."); } return dnn_support->createRnnDescriptor( - num_layers, hidden_size, input_size, batch_size, input_mode, + num_layers, hidden_size, input_size, c_size, batch_size, input_mode, direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed, state_allocator); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 7ded071467f..533213bd194 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -405,11 +405,11 @@ class StreamExecutor { // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int batch_size, - dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, - dnn::RnnMode rnn_mode, dnn::DataType data_type, - const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed, - ScratchAllocator *state_allocator); + int num_layers, int hidden_size, int input_size, int c_size, + int batch_size, dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, + dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config, + float dropout, uint64 seed, ScratchAllocator *state_allocator); // Create a RNN sequence descriptor that specifies either the input or output // sequence. The caller retains the ownership of the returned descriptor. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index a809e92d3a5..c4f9210aae2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1036,6 +1036,14 @@ tf_module { name: "cross" argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "cudnn_rnn_canonical_to_params_v2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } + member_method { + name: "cudnn_rnn_params_to_canonical_v2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "cumprod" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index a2fa54f0214..2fdb508eda5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -766,27 +766,35 @@ tf_module { } member_method { name: "CudnnRNNBackpropV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], " + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], " } member_method { name: "CudnnRNNCanonicalToParams" argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " } + member_method { + name: "CudnnRNNCanonicalToParamsV2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "CudnnRNNParamsSize" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " } member_method { name: "CudnnRNNParamsToCanonical" argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " } + member_method { + name: "CudnnRNNParamsToCanonicalV2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "CudnnRNNV2" argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], " } member_method { name: "CudnnRNNV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], " + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], " } member_method { name: "Cumprod" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 6e7cc51a285..39d8cbaea81 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -540,6 +540,14 @@ tf_module { name: "cosh" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "cudnn_rnn_canonical_to_params_v2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } + member_method { + name: "cudnn_rnn_params_to_canonical_v2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "cumsum" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index a2fa54f0214..2fdb508eda5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -766,27 +766,35 @@ tf_module { } member_method { name: "CudnnRNNBackpropV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], " + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], " } member_method { name: "CudnnRNNCanonicalToParams" argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " } + member_method { + name: "CudnnRNNCanonicalToParamsV2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "CudnnRNNParamsSize" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " } member_method { name: "CudnnRNNParamsToCanonical" argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], " } + member_method { + name: "CudnnRNNParamsToCanonicalV2" + argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " + } member_method { name: "CudnnRNNV2" argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], " } member_method { name: "CudnnRNNV3" - argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], " + argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], " } member_method { name: "Cumprod" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt index a4d728c6976..1db8d32de3c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "audio" - argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " } member_method { name: "create_file_writer" @@ -26,11 +26,11 @@ tf_module { } member_method { name: "histogram" - argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "image" - argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], " } member_method { name: "import_event" @@ -42,7 +42,7 @@ tf_module { } member_method { name: "scalar" - argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "summary_scope" @@ -50,7 +50,7 @@ tf_module { } member_method { name: "text" - argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "trace_export" From dbe02b821930faeb1a6f3e5954f37104be9aa6c7 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Mon, 6 May 2019 17:36:22 -0700 Subject: [PATCH 02/17] fixed some conflicts --- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 421 +++++++++--------- 1 file changed, 214 insertions(+), 207 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index c3bcd02fdda..bb19935de6a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -68,15 +68,19 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, cell_clip=None, use_peephole=False, - reuse=reuse, name="cudnn_compatible_lstm_cell") + num_units, + forget_bias=0, + cell_clip=None, + use_peephole=False, + reuse=reuse, + name="cudnn_compatible_lstm_cell") self._names.update({"scope": "cudnn_compatible_lstm_cell"}) class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): r"""Cudnn Compatible GRUCell. - A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with + A GRU impl akin to `tf.compat.v1.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by it seamlessly. @@ -97,7 +101,8 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` - Other GRU (see `tf.nn.rnn_cell.GRUCell` and `tf.contrib.rnn.GRUBlockCell`): + Other GRU (see `tf.compat.v1.nn.rnn_cell.GRUCell` and + `tf.contrib.rnn.GRUBlockCell`): ```python # new memory gate \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) @@ -117,8 +122,8 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): def build(self, inputs_shape): if inputs_shape[1].value is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + inputs_shape) input_depth = inputs_shape[1].value self._gate_kernel = self.add_variable( @@ -128,10 +133,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): self._gate_bias = self.add_variable( "gates/%s" % _BIAS_VARIABLE_NAME, shape=[2 * self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.constant_initializer(1.0, dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) self._candidate_input_kernel = self.add_variable( "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, @@ -145,17 +149,15 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): self._candidate_input_bias = self.add_variable( "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, shape=[self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.zeros_initializer(dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) self._candidate_hidden_bias = self.add_variable( "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, shape=[self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.zeros_initializer(dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" @@ -173,7 +175,7 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): math_ops.matmul(state, self._candidate_hidden_kernel), self._candidate_hidden_bias) candidate = self._activation(candidate) - new_h = (1-u) * candidate + u * state + new_h = (1 - u) * candidate + u * state return new_h, new_h @@ -242,6 +244,7 @@ class CudnnParamsFormatConverter(object): Args: opaque_param: An opaque tensor storing cudnn rnn params (weights and biases). + Returns: 2 list for weights and biases respectively. """ @@ -279,6 +282,7 @@ class CudnnParamsFormatConverter(object): Args: cu_weights: a list of tensors, Cudnn canonical weights. cu_biases: a list of tensors, Cudnn canonical biases. + Returns: a single opaque tensor. """ @@ -324,6 +328,7 @@ class CudnnParamsFormatConverter(object): Args: cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. + Returns: 1 tuple, tf canonical weights and biases. """ @@ -339,8 +344,10 @@ class CudnnParamsFormatConverter(object): layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - self._cu_canonical_to_tf_canonical_single_layer( - layer_weights, layer_biases, tf_weights, tf_biases, tf_weights_proj) + self._cu_canonical_to_tf_canonical_single_layer(layer_weights, + layer_biases, + tf_weights, tf_biases, + tf_weights_proj) else: fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] @@ -423,10 +430,12 @@ class CudnnParamsFormatConverter(object): cu_weights.append(pw) cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases)) else: - fw_weights, bw_weights = layer_weights[:len( - layer_weights) // 2], layer_weights[len(layer_weights) // 2:] - fw_biases, bw_biases = layer_biases[:len( - layer_biases) // 2], layer_biases[len(layer_biases) // 2:] + fw_weights, bw_weights = layer_weights[:len(layer_weights) // + 2], layer_weights[ + len(layer_weights) // 2:] + fw_biases, bw_biases = layer_biases[:len(layer_biases) // + 2], layer_biases[len(layer_biases + ) // 2:] cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights)) if weights_proj is not None: pw0 = array_ops.transpose(weights_proj[2*i+0]) @@ -484,7 +493,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): W_o = array_ops.concat([w_o, r_o], axis=1) # pylint: enable=invalid-name # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order. - reordered = self._cudnn_to_tf_gate_params(* [W_i, W_f, W_c, W_o]) + reordered = self._cudnn_to_tf_gate_params(*[W_i, W_f, W_c, W_o]) if self._num_proj: return (array_ops.transpose(array_ops.concat(reordered, axis=0)), array_ops.transpose(pw)) @@ -505,8 +514,8 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): (tf_weight,) = tf_weights w = array_ops.transpose(tf_weight) # pylint: disable=invalid-name - W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(*array_ops.split( - w, 4, axis=0)) + W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params( + *array_ops.split(w, 4, axis=0)) hidden_state_width = self._num_proj if self._num_proj else num_units w_i, r_i = array_ops.split(W_i, [input_weight_width, hidden_state_width], @@ -532,15 +541,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): B_c = b_wc + b_rc B_o = b_wo + b_ro # pylint: enable=invalid-name - reordered = self._cudnn_to_tf_gate_params(* [B_i, B_f, B_c, B_o]) + reordered = self._cudnn_to_tf_gate_params(*[B_i, B_f, B_c, B_o]) return (array_ops.concat(reordered, axis=0),) def _tf_to_cudnn_biases(self, *tf_biases): r"""Reverse the operations in StitchBiases().""" (tf_bias,) = tf_biases # pylint: disable=invalid-name - B_i, B_f, B_c, B_o = self._tf_to_cudnn_gate_params(*array_ops.split( - tf_bias, 4, axis=0)) + B_i, B_f, B_c, B_o = self._tf_to_cudnn_gate_params( + *array_ops.split(tf_bias, 4, axis=0)) # pylint: enable=invalid-name # pylint: disable=unbalanced-tuple-unpacking b_wi, b_ri = (B_i * 0.5,) * 2 @@ -614,8 +623,8 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): # return two biases each with half the value. Since RNN does not # regularize by weight decay, it has no side effect in training or # inference. - array_ops.concat([b_wi, b_wr], axis=0) + array_ops.concat( - [b_ri, b_rr], axis=0), + array_ops.concat([b_wi, b_wr], axis=0) + + array_ops.concat([b_ri, b_rr], axis=0), b_wh, b_rh) @@ -796,8 +805,8 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) for param, param_name in zip(params, prefixed_param_names) ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) + super(CudnnOpaqueParamsSaveable, + self).__init__(array_ops.identity(self._variables), specs, name) @property def format_converter(self): @@ -837,8 +846,8 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cuDNN-compatible cells. Args: - trackable: An object inheriting from `Trackable` to add - dependencies too (typically the cuDNN `Layer`). + trackable: An object inheriting from `Trackable` to add dependencies too + (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. """ split_dependencies = split_dependency.split_dependency( @@ -982,9 +991,9 @@ _cudnn_rnn_common_doc_string = """ def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): - raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % - (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, - CUDNN_RNN_RELU)) + raise ValueError( + "Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % + (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU)) def _get_seed(seed): @@ -1045,13 +1054,13 @@ def _cudnn_rnn(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. - A Tensor of the same shape as input_h. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1064,22 +1073,22 @@ def _cudnn_rnn(inputs, false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major form. This param is only effective when 'sequence_lengths' is used. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. If None or 0, no projection is performed. name: name of the operation. + Returns: outputs, output_h, output_c """ @@ -1141,13 +1150,13 @@ def cudnn_lstm(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. - A Tensor of the same shape as input_h. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference sequence_lengths: an int32 array representing the variable sequence lengths @@ -1159,22 +1168,22 @@ def cudnn_lstm(inputs, false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major form. This param is only effective when 'sequence_lengths' is used. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. - If None or 0, no projection is performed. + If None or 0, no projection is performed. name: name of the operation. + Returns: outputs, output_h, output_c """ @@ -1199,11 +1208,11 @@ def _cudnn_rnn_no_input_c(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1216,27 +1225,28 @@ def _cudnn_rnn_no_input_c(inputs, false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in time-major form. This param is only effective when 'sequence_lengths' is used. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h """ input_c = array_ops.constant([], dtype=input_h.dtype) - outputs, output_h, _ = _cudnn_rnn( - inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, - time_major, input_mode, direction, dropout, seed, None, name) + outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, + is_training, rnn_mode, sequence_lengths, + time_major, input_mode, direction, dropout, + seed, None, name) return outputs, output_h @@ -1255,21 +1265,20 @@ def cudnn_gru(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same @@ -1280,11 +1289,12 @@ def cudnn_gru(inputs, By default this function accepts input and emits output in time-major form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h """ @@ -1308,11 +1318,11 @@ def cudnn_rnn_relu(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1325,8 +1335,8 @@ def cudnn_rnn_relu(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. @@ -1360,21 +1370,20 @@ def cudnn_rnn_tanh(inputs, Args: inputs: the input sequence to the RNN model. If `time_major` is True - (default), the Tensor shape is [max_time, batch_size, input_size]. If - `time_major` is False, the shape is [batch_size, max_time, input_size]. - input_h: the initial hidden state for h. If `time_major` is True - (default), the Tensor shape is [num_layers, batch_size, num_units]. If - `time_major` is False, the shape is [batch_size, num_layers, num_units]. + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same @@ -1385,11 +1394,12 @@ def cudnn_rnn_tanh(inputs, By default this function accepts input and emits output in time-major form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h """ @@ -1413,28 +1423,27 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. + input_size: the size of the input, it could be different from the num_units. params: opaque cudnn params var. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. - If None or 0, no projection is performed. + If None or 0, no projection is performed. name: name of the operation. + Returns: weights list and bias list Raises: @@ -1498,29 +1507,28 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. + input_size: the size of the input, it could be different from the num_units. weights: a Tensor for weight parameters. biases: a Tensor for bias parameters. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. - If None or 0, no projection is performed. + If None or 0, no projection is performed. name: name of the operation. + Returns: an opaque Cudnn param. Raises: @@ -1575,28 +1583,27 @@ def cudnn_rnn_opaque_params_size(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_size: the size of the input, it could be different from the num_units. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: one of tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. - If None or 0, no projection is performed. + If None or 0, no projection is performed. name: name of the operation. + Returns: a int, size of Cudnn opaque params. Raises: @@ -1646,27 +1653,27 @@ class _CudnnRNN(object): Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. num_proj: The output dimensionality for the projection matrices. - If None or 0, no projection is performed. + If None or 0, no projection is performed. + Raises: ValueError: if direction is invalid. """ @@ -1847,15 +1854,14 @@ class CudnnLSTM(_CudnnRNN): num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and The actual computation before the first layer. It could be - 'skip_input', 'linear_input' or 'auto_select'. - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. 'skip_input' is only + allowed when input_size == num_units; 'auto_select' implies 'skip_input' + when input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. @@ -1902,9 +1908,10 @@ class CudnnLSTM(_CudnnRNN): true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in - time-major form. This param is only effective when 'sequence_lengths' - is used. + time-major form. This param is only effective when 'sequence_lengths' is + used. is_training: whether this operation will be used in training or inference. + Returns: output: the output sequence. output_h: the final state for h. @@ -1940,15 +1947,14 @@ class _CudnnRNNNoInputC(_CudnnRNN): num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and The actual computation before the first layer. It could be - 'skip_input', 'linear_input' or 'auto_select'. - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. 'skip_input' is only + allowed when input_size == num_units; 'auto_select' implies 'skip_input' + when input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. @@ -1996,9 +2002,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. By default this function accepts input and emits output in - time-major form. This param is only effective when 'sequence_lengths' - is used. + time-major form. This param is only effective when 'sequence_lengths' is + used. is_training: whether this operation will be used in training or inference. + Returns: output: the output sequence. output_h: the final state for h. From e6904ebd23e1a3722bdf084489859f468467e69e Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 8 May 2019 15:01:47 -0700 Subject: [PATCH 03/17] removed some duplicated codes --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 144 +++++++++--------- tensorflow/stream_executor/dnn.h | 1 + .../api/golden/v2/tensorflow.summary.pbtxt | 10 +- 3 files changed, 82 insertions(+), 73 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index d35e8340013..067e83d7a47 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1045,12 +1045,17 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm()); // TODO: allow the user to choose an algorithm. + int unified_size = hidden_size; + if (c_size != 0 && hidden_size < c_size) { + unified_size = c_size; + } + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), + /*hiddenSize=*/unified_size, /*numLayers=*/num_layers, + /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, + /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, + /*dataType=*/compute_type)); if (c_size != 0 && hidden_size < c_size) { - RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/c_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), - /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*algo=*/rnn_algo, /*dataType=*/compute_type)); #if CUDNN_VERSION >= 7101 RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers( cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), @@ -1060,13 +1065,6 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { "No supported cudnnSetRNNProjectionLayers when " "CUDNN_VERSION < 7.1.1"); #endif - } else { - RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), - /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers, - /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, - /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, - /*dataType=*/compute_type)); } // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs. @@ -1155,6 +1153,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int num_layers_; int hidden_size_; int input_size_; + // c_size_ is the size of cell state, which will be different from + // hidden_size_ if the projection is used. int c_size_; // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC // algorithm. @@ -1173,6 +1173,66 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { namespace { +port::Status CheckAndFetchProjectionWeights( + const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, int layer, + TensorDescriptor& input_desc, FilterDescriptor& filter_desc, + FilterDescriptor& region_desc_handle, + dnn::RnnDescriptor::ParamsRegions& weights) { +#if CUDNN_VERSION >= 7101 + int hidden_size_v; + int num_layers_v; + cudnnDropoutDescriptor_t dropout_desc; + cudnnRNNInputMode_t input_mode; + cudnnDirectionMode_t direction; + cudnnRNNMode_t mode; + cudnnRNNAlgo_t algo; + cudnnDataType_t data_type; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*hiddenSize=*/&hidden_size_v, + /*numLayers=*/&num_layers_v, + /*dropoutDesc=*/&dropout_desc, + /*inputMode=*/&input_mode, + /*direction=*/&direction, + /*mode=*/&mode, + /*algo=*/&algo, + /*dataType=*/&data_type)); + int rec_proj_size_v; + int out_proj_size_v; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc, + /*recProjSize*/ &rec_proj_size_v, + /*outProjSize*/ &out_proj_size_v)); + if (rec_proj_size_v != hidden_size_v) { + void* offset = nullptr; + int region_id = 8; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*layer=*/layer, /*xDesc=*/input_desc.get(), + /*wDesc=*/filter_desc.get(), + /*w=*/nullptr, /*linLayerID=*/region_id, + /*linLayerMatDesc=*/region_desc_handle.get(), + /*linLayerMat or linLayerBias=*/&offset)); + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims)); + int64 size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + dnn::RnnDescriptor::ParamsRegion region = { + reinterpret_cast<int64>(offset), size}; + weights.push_back(region); + } +#endif // CUDNN_VERSION >= 7101 + return port::Status::OK(); +} + port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create( const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, @@ -1260,62 +1320,10 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create( (type == 0 ? weights : biases).push_back(region); } } - int hidden_size_v; - int num_layers_v; - cudnnDropoutDescriptor_t dropout_desc; - cudnnRNNInputMode_t input_mode; - cudnnDirectionMode_t direction; - cudnnRNNMode_t mode; - cudnnRNNAlgo_t algo; - cudnnDataType_t data_dype; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*hiddenSize=*/&hidden_size_v, - /*numLayers=*/&num_layers_v, - /*dropoutDesc=*/&dropout_desc, - /*inputMode=*/&input_mode, - /*direction=*/&direction, - /*mode=*/&mode, - /*algo=*/&algo, - /*dataType=*/&data_type)); - int rec_proj_size_v; - int out_proj_size_v; -#if CUDNN_VERSION >= 7101 - RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( - /*handle=*/cudnn.handle(), - /*rnnDesc=*/rnn_desc, - /*recProjSize*/ &rec_proj_size_v, - /*outProjSize*/ &out_proj_size_v)); -#else - return port::Status(port::error::INVALID_ARGUMENT, - "No supported cudnnGetRNNProjectionLayers when " - "CUDNN_VERSION < 7.1.1"); -#endif - if (rec_proj_size_v != hidden_size_v) { - void* offset = nullptr; - int region_id = 8; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*layer=*/layer, /*xDesc=*/input_desc.get(), - /*wDesc=*/filter_desc.get(), - /*w=*/nullptr, /*linLayerID=*/region_id, - /*linLayerMatDesc=*/region_desc_handle.get(), - /*linLayerMat or linLayerBias=*/&offset)); - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle.get(), - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims)); - int64 size = - dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast<int64>(offset), size}; - weights.push_back(region); - } + TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(cudnn, rnn_desc, layer, + input_desc, filter_desc, + region_desc_handle, + weights)); } return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index ec225fdd01c..70395d7e70c 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2052,6 +2052,7 @@ class DnnSupport { // num_layers: the number of layers for a RNN model. // hidden_size: the size of the hidden state. // input_size: the size of the input state. + // c_size: the size of the cell state // input_mode: an enum to specify whether a linear transformation is added // after the input state. If input_size is different from hidden_size, this // is required. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt index 9365fe9ebdd..a81480f5c38 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt @@ -10,7 +10,7 @@ tf_module { } member_method { name: "audio" - argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\', \'None\'], " } member_method { name: "create_file_writer" @@ -26,11 +26,11 @@ tf_module { } member_method { name: "histogram" - argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "image" - argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\'], " } member_method { name: "record_if" @@ -38,11 +38,11 @@ tf_module { } member_method { name: "scalar" - argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "text" - argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "trace_export" From bc109124a10286a59580ab7e790ad07e9d643d42 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Thu, 30 May 2019 18:07:53 -0700 Subject: [PATCH 04/17] Use cell_ instead of c_; Add back commented-out tests --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 78 +++++----- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 6 +- tensorflow/core/kernels/cudnn_rnn_ops.cc | 74 +++++----- tensorflow/stream_executor/cuda/cuda_dnn.cc | 139 +++++++++--------- tensorflow/stream_executor/dnn.h | 5 +- .../stream_executor/stream_executor_pimpl.cc | 7 +- .../stream_executor/stream_executor_pimpl.h | 2 +- 7 files changed, 156 insertions(+), 155 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 47ae2765822..e08b50f5fd4 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -92,10 +92,11 @@ def RunLSTM(sess, inputs_dynamic = array_ops.placeholder( dtype, shape=[None, None, None], name="inputs") inputs = inputs_dynamic if dynamic_shape_input else inputs_static + unified_num_units = num_proj if num_proj else num_units + unified_num_proj = num_proj if num_proj else None initial_h_op = variable_scope.get_variable( "initial_h_op", - initializer=np.random.rand(batch_size, - num_proj if num_proj else num_units) + initializer=np.random.rand(batch_size, unified_num_units) .astype(dtype.as_numpy_dtype), dtype=dtype) initial_c_op = variable_scope.get_variable( @@ -117,8 +118,7 @@ def RunLSTM(sess, with variable_scope.variable_scope("test", initializer=initializer): w = variable_scope.get_variable( "rnn/lstm_cell/kernel", - shape=[input_size + (num_proj if num_proj else num_units), - num_units * 4], + shape=[input_size + unified_num_units, num_units * 4], dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) @@ -129,7 +129,7 @@ def RunLSTM(sess, # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True, - num_proj=num_proj if num_proj else None) + num_proj=unified_num_proj) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs_static, @@ -142,8 +142,7 @@ def RunLSTM(sess, # Convert to cudnn opaque param. format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( - num_layers, num_units, input_size, - num_proj=num_proj if num_proj else None) + num_layers, num_units, input_size, num_proj=unified_num_proj) if num_proj: opaque_params = format_converter.tf_canonical_to_opaque([w, b], [pw,]) else: @@ -163,7 +162,7 @@ def RunLSTM(sess, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM, - num_proj=num_proj if num_proj else None) + num_proj=unified_num_proj) # Remove the trivial 1st dimension. cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), @@ -271,44 +270,35 @@ def RunLSTM(sess, # Basic set of RNN configs to test. They can be further extended in relevant # test (e.g. adding num_dirs). -#NAMED_RNN_TESTCASES = ({ -# "testcase_name": "xsmall", -# "num_units": 1, -# "input_size": 1, -# "batch_size": 1, -# "time": 1, -# "num_layers": 1, -#}, { -# "testcase_name": "small", -# "num_units": 4, -# "input_size": 4, -# "batch_size": 4, -# "time": 4, -# "num_layers": 1, -#}, { -# "testcase_name": "medium", -# "num_units": 128, -# "input_size": 64, -# "batch_size": 8, -# "time": 16, -# "num_layers": 1, -#}, { -# "testcase_name": "large", -# "num_units": 128, -# "input_size": 128, -# "batch_size": 16, -# "time": 32, -# "num_layers": 1, -#}) NAMED_RNN_TESTCASES = ({ + "testcase_name": "xsmall", + "num_units": 1, + "input_size": 1, + "batch_size": 1, + "time": 1, + "num_layers": 1, +}, { "testcase_name": "small", "num_units": 4, "input_size": 4, "batch_size": 4, "time": 4, "num_layers": 1, -}, ) - +}, { + "testcase_name": "medium", + "num_units": 128, + "input_size": 64, + "batch_size": 8, + "time": 16, + "num_layers": 1, +}, { + "testcase_name": "large", + "num_units": 128, + "input_size": 128, + "batch_size": 16, + "time": 32, + "num_layers": 1, +}) def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): """Expands testcase with new config dimensions. @@ -484,7 +474,7 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths=variable_seq_lengths, time_major=time_major, dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -513,7 +503,7 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths=variable_seq_lengths, time_major=time_major, dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + num_proj=num_proj if use_proj else None) self.assertAllClose(outputs, cu_outputs) # h @@ -549,7 +539,7 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths=variable_seq_lengths, time_major=time_major, dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + num_proj=num_proj if use_proj else None) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -592,7 +582,7 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths=variable_seq_lengths, time_major=time_major, dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + num_proj=num_proj if use_proj else None) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -608,7 +598,7 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): variable_seq_lengths=variable_seq_lengths, time_major=time_major, dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + num_proj=num_proj if use_proj else None) self.assertAllClose(cu_outputs, cu_outputs2) # h diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index bb19935de6a..6190ec31f70 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -1119,8 +1119,10 @@ def _cudnn_rnn(inputs, args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif time_major is False or num_proj: - batch_size = array_ops.shape(inputs)[0] - max_time = array_ops.shape(inputs)[1] + batch_id = 1 if time_major else 0 + time_id = 0 if time_major else 1 + batch_size = array_ops.shape(inputs)[batch_id] + max_time = array_ops.shape(inputs)[time_id] sequence_lengths = array_ops.fill([batch_size], max_time) args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index e70fa88b2cf..61201d01b72 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -502,23 +502,23 @@ struct CudnnRnnModelShapes { int dir_count; int max_seq_length; int batch_size; - int c_num_units; + int cell_num_units = 0; TensorShape input_shape; TensorShape output_shape; TensorShape hidden_state_shape; - TensorShape c_state_shape; + TensorShape cell_state_shape; // At present only fields related to cached RnnDescriptor are concerned. bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { return num_layers == rhs.num_layers && input_size == rhs.input_size && num_units == rhs.num_units && dir_count == rhs.dir_count && - c_num_units == rhs.c_num_units; + cell_num_units == rhs.cell_num_units; } string DebugString() const { return strings::Printf( "[num_layers, input_size, num_units, dir_count, max_seq_length, " - "batch_size, c_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", + "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", num_layers, input_size, num_units, dir_count, max_seq_length, - batch_size, c_num_units); + batch_size, cell_num_units); } }; @@ -619,16 +619,16 @@ Status ExtractForwardInput(OpKernelContext* context, model_shapes->hidden_state_shape.DebugString()); } if (model_types.HasInputC()) { - model_shapes->c_num_units = (*input_c)->dim_size(2); + model_shapes->cell_num_units = (*input_c)->dim_size(2); if (time_major) { - model_shapes->c_state_shape = + model_shapes->cell_state_shape = TensorShape({model_shapes->dir_count * model_shapes->num_layers, - model_shapes->batch_size, model_shapes->c_num_units}); + model_shapes->batch_size, model_shapes->cell_num_units}); } else { - model_shapes->c_state_shape = + model_shapes->cell_state_shape = TensorShape({model_shapes->batch_size, model_shapes->dir_count * model_shapes->num_layers, - model_shapes->c_num_units}); + model_shapes->cell_num_units}); } if (num_proj == 0) { if ((*input_h)->shape() != (*input_c)->shape()) { @@ -649,18 +649,18 @@ Status ExtractForwardInput(OpKernelContext* context, } } } else { - // dummy c_state_shape TODO(kaixih): remove the time_major branch + // dummy cell_state_shape TODO(kaixih): remove the time_major branch if (time_major) { - model_shapes->c_state_shape = + model_shapes->cell_state_shape = TensorShape({model_shapes->dir_count * model_shapes->num_layers, model_shapes->batch_size, model_shapes->num_units}); } else { - model_shapes->c_state_shape = + model_shapes->cell_state_shape = TensorShape({model_shapes->batch_size, model_shapes->dir_count * model_shapes->num_layers, model_shapes->num_units}); } - model_shapes->c_num_units = 0; + model_shapes->cell_num_units = 0; } if (time_major) { model_shapes->output_shape = @@ -699,7 +699,7 @@ Status CreateForwardAndBackwardIODescriptors( const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; - const TensorShape& c_state_shape = model_shapes.c_state_shape; + const TensorShape& cell_state_shape = model_shapes.cell_state_shape; const TensorShape& output_shape = model_shapes.output_shape; DCHECK_EQ(input_shape.dims(), 3); @@ -740,19 +740,19 @@ Status CreateForwardAndBackwardIODescriptors( *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie(); } - DCHECK_EQ(c_state_shape.dims(), 3); + DCHECK_EQ(cell_state_shape.dims(), 3); if (time_major) { - auto c_state_desc_s = executor->createRnnStateTensorDescriptor( - c_state_shape.dim_size(0), c_state_shape.dim_size(1), - c_state_shape.dim_size(2), data_type); - TF_RETURN_IF_ERROR(c_state_desc_s.status()); - *c_state_desc = c_state_desc_s.ConsumeValueOrDie(); + auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( + cell_state_shape.dim_size(0), cell_state_shape.dim_size(1), + cell_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(cell_state_desc_s.status()); + *c_state_desc = cell_state_desc_s.ConsumeValueOrDie(); } else { - auto c_state_desc_s = executor->createRnnStateTensorDescriptor( - c_state_shape.dim_size(1), c_state_shape.dim_size(0), - c_state_shape.dim_size(2), data_type); - TF_RETURN_IF_ERROR(c_state_desc_s.status()); - *c_state_desc = c_state_desc_s.ConsumeValueOrDie(); + auto cell_state_desc_s = executor->createRnnStateTensorDescriptor( + cell_state_shape.dim_size(1), cell_state_shape.dim_size(0), + cell_state_shape.dim_size(2), data_type); + TF_RETURN_IF_ERROR(cell_state_desc_s.status()); + *c_state_desc = cell_state_desc_s.ConsumeValueOrDie(); } DCHECK_EQ(output_shape.dims(), 3); @@ -1025,7 +1025,7 @@ class CudnnRNNKernelCommon : public OpKernel { // random number generator, therefore set state_allocator to nullptr. const AlgorithmConfig algo_config; auto rnn_desc_s = stream->parent()->createRnnDescriptor( - num_layers, h_num_units, input_size, /*c_size=*/c_num_units, + num_layers, h_num_units, input_size, /*cell_size=*/c_num_units, /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config, dropout(), seed(), /* state_allocator=*/nullptr); @@ -1047,7 +1047,7 @@ class CudnnRNNKernelCommon : public OpKernel { se::dnn::DataType data_type = ToDataType<T>::value; auto rnn_desc_s = executor->createRnnDescriptor( model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, model_shapes.c_num_units, + model_shapes.input_size, model_shapes.cell_num_units, model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(), seed(), dropout_state_allocator); TF_RETURN_IF_ERROR(rnn_desc_s.status()); @@ -1522,14 +1522,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { Tensor** output_c) { const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; - const TensorShape& c_state_shape = model_shapes.c_state_shape; + const TensorShape& cell_state_shape = model_shapes.cell_state_shape; TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output)); TF_RETURN_IF_ERROR( context->allocate_output(1, hidden_state_shape, output_h)); if (HasInputC()) { TF_RETURN_IF_ERROR( - context->allocate_output(2, c_state_shape, output_c)); + context->allocate_output(2, cell_state_shape, output_c)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. @@ -1928,7 +1928,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space)); const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; const TensorShape& output_shape = model_shapes.output_shape; - const TensorShape& c_state_shape = model_shapes.c_state_shape; + const TensorShape& cell_state_shape = model_shapes.cell_state_shape; if (output_shape != (*output)->shape()) { return errors::InvalidArgument( @@ -1954,16 +1954,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { } if (model_types.HasInputC()) { - if (c_state_shape != (*output_c)->shape()) { + if (cell_state_shape != (*output_c)->shape()) { return errors::InvalidArgument( "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ", - c_state_shape.DebugString()); + cell_state_shape.DebugString()); } - if (c_state_shape != (*output_c_backprop)->shape()) { + if (cell_state_shape != (*output_c_backprop)->shape()) { return errors::InvalidArgument( "Invalid output_c_backprop shape: ", (*output_c_backprop)->shape().DebugString(), " ", - c_state_shape.DebugString()); + cell_state_shape.DebugString()); } } return Status::OK(); @@ -1976,7 +1976,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { Tensor** input_c_backprop, Tensor** params_backprop) { const TensorShape& input_shape = model_shapes.input_shape; const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape; - const TensorShape& c_state_shape = model_shapes.c_state_shape; + const TensorShape& cell_state_shape = model_shapes.cell_state_shape; TF_RETURN_IF_ERROR( context->allocate_output(0, input_shape, input_backprop)); @@ -1984,7 +1984,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon { context->allocate_output(1, hidden_state_shape, input_h_backprop)); if (HasInputC()) { TF_RETURN_IF_ERROR( - context->allocate_output(2, c_state_shape, input_c_backprop)); + context->allocate_output(2, cell_state_shape, input_c_backprop)); } else { // Only LSTM uses input_c and output_c. So for all other models, we only // need to create dummy outputs. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 067e83d7a47..b18d246dc65 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1002,7 +1002,7 @@ class CudnnRnnParamsDescriptor { class CudnnRnnDescriptor : public dnn::RnnDescriptor { CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc, PersistentRnnPlan rnn_plan, int num_layers, - int hidden_size, int input_size, int c_size, + int hidden_size, int input_size, int cell_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, @@ -1015,7 +1015,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { num_layers_(num_layers), hidden_size_(hidden_size), input_size_(input_size), - c_size_(c_size), + cell_size_(cell_size), batch_size_(batch_size), rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())), input_mode_(input_mode), @@ -1032,7 +1032,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { static port::StatusOr<CudnnRnnDescriptor> Create( const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size, - int c_size, int batch_size, cudnnRNNInputMode_t input_mode, + int cell_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, @@ -1046,8 +1046,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { // TODO: allow the user to choose an algorithm. int unified_size = hidden_size; - if (c_size != 0 && hidden_size < c_size) { - unified_size = c_size; + bool use_projection = cell_size != 0 && hidden_size < cell_size; + if (use_projection) { + unified_size = cell_size; } RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), @@ -1055,7 +1056,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, /*dataType=*/compute_type)); - if (c_size != 0 && hidden_size < c_size) { + if (use_projection) { #if CUDNN_VERSION >= 7101 RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers( cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), @@ -1114,7 +1115,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { #endif return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), - num_layers, hidden_size, input_size, c_size, + num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, direction_mode, rnn_mode, data_type, compute_type, algorithm_config, std::move(dropout_desc), std::move(params_desc)); @@ -1124,7 +1125,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int num_layers() const { return num_layers_; } int hidden_size() const { return hidden_size_; } int input_size() const { return input_size_; } - int c_size() const { return c_size_; } + int cell_size() const { return cell_size_; } int batch_size() const { return batch_size_; } cudnnRNNInputMode_t input_mode() const { return input_mode_; } cudnnDirectionMode_t direction_mode() const { return direction_mode_; } @@ -1153,9 +1154,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int num_layers_; int hidden_size_; int input_size_; - // c_size_ is the size of cell state, which will be different from + // cell_size_ is the size of cell state, which will be different from // hidden_size_ if the projection is used. - int c_size_; + int cell_size_; // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC // algorithm. int batch_size_; @@ -1173,62 +1174,65 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { namespace { +// Check if the LSTM projection is used. If yes, an additional weigth matrix +// (projection matrix) will be fetched to the 'weights'. Otherwise, nothing will +// be done. port::Status CheckAndFetchProjectionWeights( const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, int layer, TensorDescriptor& input_desc, FilterDescriptor& filter_desc, FilterDescriptor& region_desc_handle, - dnn::RnnDescriptor::ParamsRegions& weights) { + dnn::RnnDescriptor::ParamsRegions* weights) { #if CUDNN_VERSION >= 7101 - int hidden_size_v; - int num_layers_v; - cudnnDropoutDescriptor_t dropout_desc; - cudnnRNNInputMode_t input_mode; - cudnnDirectionMode_t direction; - cudnnRNNMode_t mode; - cudnnRNNAlgo_t algo; - cudnnDataType_t data_type; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( + int hidden_size_v; + int num_layers_v; + cudnnDropoutDescriptor_t dropout_desc; + cudnnRNNInputMode_t input_mode; + cudnnDirectionMode_t direction; + cudnnRNNMode_t mode; + cudnnRNNAlgo_t algo; + cudnnDataType_t data_type; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*hiddenSize=*/&hidden_size_v, + /*numLayers=*/&num_layers_v, + /*dropoutDesc=*/&dropout_desc, + /*inputMode=*/&input_mode, + /*direction=*/&direction, + /*mode=*/&mode, + /*algo=*/&algo, + /*dataType=*/&data_type)); + int rec_proj_size_v; + int out_proj_size_v; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc, + /*recProjSize*/ &rec_proj_size_v, + /*outProjSize*/ &out_proj_size_v)); + if (rec_proj_size_v != hidden_size_v) { + void* offset = nullptr; + int region_id = 8; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*hiddenSize=*/&hidden_size_v, - /*numLayers=*/&num_layers_v, - /*dropoutDesc=*/&dropout_desc, - /*inputMode=*/&input_mode, - /*direction=*/&direction, - /*mode=*/&mode, - /*algo=*/&algo, - /*dataType=*/&data_type)); - int rec_proj_size_v; - int out_proj_size_v; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers( - /*handle=*/cudnn.handle(), - /*rnnDesc=*/rnn_desc, - /*recProjSize*/ &rec_proj_size_v, - /*outProjSize*/ &out_proj_size_v)); - if (rec_proj_size_v != hidden_size_v) { - void* offset = nullptr; - int region_id = 8; - RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, - /*layer=*/layer, /*xDesc=*/input_desc.get(), - /*wDesc=*/filter_desc.get(), - /*w=*/nullptr, /*linLayerID=*/region_id, - /*linLayerMatDesc=*/region_desc_handle.get(), - /*linLayerMat or linLayerBias=*/&offset)); - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle.get(), - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims)); - int64 size = - dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast<int64>(offset), size}; - weights.push_back(region); - } + /*layer=*/layer, /*xDesc=*/input_desc.get(), + /*wDesc=*/filter_desc.get(), + /*w=*/nullptr, /*linLayerID=*/region_id, + /*linLayerMatDesc=*/region_desc_handle.get(), + /*linLayerMat or linLayerBias=*/&offset)); + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims)); + int64 size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + dnn::RnnDescriptor::ParamsRegion region = { + reinterpret_cast<int64>(offset), size}; + (*weights).push_back(region); + } #endif // CUDNN_VERSION >= 7101 return port::Status::OK(); } @@ -1323,7 +1327,7 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create( TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle, - weights)); + &weights)); } return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, @@ -1488,7 +1492,7 @@ struct RnnModelDims { int max_seq_length = 0; int hidden_size = 0; int input_size = 0; - int c_size = 0; + int cell_size = 0; int dir_count = 0; }; @@ -1514,7 +1518,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( model_dims.max_seq_length = input_desc.max_seq_length(); model_dims.hidden_size = rnn_desc.hidden_size(); model_dims.input_size = input_desc.data_size(); - model_dims.c_size = rnn_desc.c_size(); + model_dims.cell_size = rnn_desc.cell_size(); model_dims.dir_count = (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; @@ -1525,6 +1529,8 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward( input_h_desc.data_size() == model_dims.hidden_size)) { return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape"); } + // The LSTM projection will be used if input_h_desc.data_size() < + // input_c_desc.data_size() if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && input_h_desc.data_size() <= input_c_desc.data_size())) { @@ -1900,8 +1906,9 @@ port::Status CudnnSupport::DoRnnBackwardImpl( port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> CudnnSupport::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int c_size, int batch_size, - dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, + int num_layers, int hidden_size, int input_size, int cell_size, + int batch_size, dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, ScratchAllocator* state_allocator) { @@ -1911,7 +1918,7 @@ CudnnSupport::createRnnDescriptor( SE_ASSIGN_OR_RETURN( CudnnRnnDescriptor rnn_desc, CudnnRnnDescriptor::Create( - cudnn, num_layers, hidden_size, input_size, c_size, batch_size, + cudnn, num_layers, hidden_size, input_size, cell_size, batch_size, ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), GetRnnComputeType(data_type), diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 70395d7e70c..11a514c1abd 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2052,7 +2052,7 @@ class DnnSupport { // num_layers: the number of layers for a RNN model. // hidden_size: the size of the hidden state. // input_size: the size of the input state. - // c_size: the size of the cell state + // cell_size: the size of the cell state // input_mode: an enum to specify whether a linear transformation is added // after the input state. If input_size is different from hidden_size, this // is required. @@ -2068,7 +2068,8 @@ class DnnSupport { // is no longer in use. virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(int num_layers, int hidden_size, int input_size, - int c_size, int batch_size, dnn::RnnInputMode input_mode, + int cell_size, int batch_size, + dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 88587ec635c..e71b199a358 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -336,8 +336,9 @@ bool StreamExecutor::GetBlasGemmAlgorithms( port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> StreamExecutor::createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int c_size, int batch_size, - dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, + int num_layers, int hidden_size, int input_size, int cell_size, + int batch_size, dnn::RnnInputMode input_mode, + dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed, ScratchAllocator *state_allocator) { @@ -347,7 +348,7 @@ StreamExecutor::createRnnDescriptor( "Fail to find the dnn implementation."); } return dnn_support->createRnnDescriptor( - num_layers, hidden_size, input_size, c_size, batch_size, input_mode, + num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed, state_allocator); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index ec02e305dc3..35e1e0778c3 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -392,7 +392,7 @@ class StreamExecutor { // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int c_size, + int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config, From 291d39f5e19b615d52ae1a1c0eeb7618d509c881 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Fri, 31 May 2019 10:08:51 -0700 Subject: [PATCH 05/17] Added apidef files for canon <-> cudnnparams V2 --- .../python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt | 4 ++++ .../python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt | 4 ++++ 2 files changed, 8 insertions(+) create mode 100644 tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt new file mode 100644 index 00000000000..d953e7ccbae --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "CudnnRNNCanonicalToParamsV2" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt new file mode 100644 index 00000000000..a3ca3dfbca3 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "CudnnRNNParamsToCanonicalV2" + visibility: HIDDEN +} From f38dd432f4300de2a34374caab2616d1f82e5ce6 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Fri, 31 May 2019 13:55:50 -0700 Subject: [PATCH 06/17] use a tuple for batch/time_id --- tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 6190ec31f70..5a6c2924a10 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -1119,8 +1119,7 @@ def _cudnn_rnn(inputs, args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif time_major is False or num_proj: - batch_id = 1 if time_major else 0 - time_id = 0 if time_major else 1 + batch_id, time_id = (1, 0) if time_major else (0, 1) batch_size = array_ops.shape(inputs)[batch_id] max_time = array_ops.shape(inputs)[time_id] sequence_lengths = array_ops.fill([batch_size], max_time) From f4e1bc69baf1e562eb74ecf272edd88768179ea8 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Tue, 4 Jun 2019 10:25:29 -0700 Subject: [PATCH 07/17] minor changes --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 2 +- tensorflow/stream_executor/cuda/cuda_dnn.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index b18d246dc65..6061dce2f68 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1231,7 +1231,7 @@ port::Status CheckAndFetchProjectionWeights( dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); dnn::RnnDescriptor::ParamsRegion region = { reinterpret_cast<int64>(offset), size}; - (*weights).push_back(region); + weights->push_back(region); } #endif // CUDNN_VERSION >= 7101 return port::Status::OK(); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 7041353aa2d..6432622a959 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -47,7 +47,7 @@ class CudnnSupport : public dnn::DnnSupport { port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override; port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( - int num_layers, int hidden_size, int input_size, int c_size, + int num_layers, int hidden_size, int input_size, int cell_size, int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, From f4c2a0c510d5b23d405877d6425aa7c25b18ee2a Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Tue, 4 Jun 2019 11:04:11 -0700 Subject: [PATCH 08/17] change input params as const --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 6061dce2f68..e3d048c8c13 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1178,9 +1178,9 @@ namespace { // (projection matrix) will be fetched to the 'weights'. Otherwise, nothing will // be done. port::Status CheckAndFetchProjectionWeights( - const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, int layer, - TensorDescriptor& input_desc, FilterDescriptor& filter_desc, - FilterDescriptor& region_desc_handle, + const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer, + const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc, + const FilterDescriptor& region_desc_handle, dnn::RnnDescriptor::ParamsRegions* weights) { #if CUDNN_VERSION >= 7101 int hidden_size_v; From f65f560adbe4e57585ff2e2029c4d631ea61cda5 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Tue, 4 Jun 2019 12:28:36 -0700 Subject: [PATCH 09/17] change CHECK to OP_REQUIRES --- tensorflow/core/kernels/cudnn_rnn_ops.cc | 30 ++++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 61201d01b72..7f1f952ca9f 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -1215,23 +1215,27 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon { // Number of params applied on inputs. The rest are applied on recurrent // hidden states. const int num_params_input_state = num_params_weights_per_layer / 2; - CHECK(num_params_weights_ % (num_layers * num_dirs) == 0) - << "Number of params (weights) is not a multiple of num_layers * " - "num_dirs."; - CHECK(num_params_biases_ % (num_layers * num_dirs) == 0) - << "Number of params (bias) is not a multiple of num_layers * " - "num_dirs."; + OP_REQUIRES(context, num_params_weights_ % (num_layers * num_dirs) == 0, + errors::InvalidArgument("Number of params (weights) is not a multiple" + "of num_layers * num_dirs.")); + OP_REQUIRES(context, num_params_biases_ % (num_layers * num_dirs) == 0, + errors::InvalidArgument("Number of params (biases) is not a multiple" + "of num_layers * num_dirs.")); if (num_proj_ == 0) { - CHECK(num_params_weights_per_layer % 2 == 0) - << "Number of params per layer is not a even number w/o projection."; + OP_REQUIRES(context, num_params_weights_per_layer % 2 == 0, + errors::InvalidArgument("Number of params (weights) per layer is not" + "an even number with no projection.")); } else { - CHECK(num_params_weights_per_layer % 2 != 0) - << "Number of params per layer is not a odd number w/ projection."; + OP_REQUIRES(context, num_params_weights_per_layer % 2 != 0, + errors::InvalidArgument("Number of params (weights) per layer is not" + "an odl number with projection.")); } - CHECK(num_params_weights_ == rnn_desc->ParamsWeightRegions().size()) - << "C Number of params mismatch. Expected " << num_params_weights_ - << ", got " << rnn_desc->ParamsWeightRegions().size(); + OP_REQUIRES(context, + num_params_weights_ == rnn_desc->ParamsWeightRegions().size(), + errors::InvalidArgument("C Number of params mismatch. Expected ", + num_params_weights_, ", got ", + rnn_desc->ParamsWeightRegions().size())); int h_num_units = (num_proj_ == 0 ? num_units : num_proj_); int c_num_units = (num_proj_ == 0 ? 0 : num_units); for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) { From 6a3d554a582a13e5f4ab977da332625f847beb7f Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 5 Jun 2019 10:29:35 -0700 Subject: [PATCH 10/17] fixed a shapeinference for cudnn rnn --- tensorflow/core/ops/cudnn_rnn_ops_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc index 25121c6484f..ed1e003c897 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc @@ -119,9 +119,9 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) { }; string input_shapes_desc = strings::StrCat( shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";", - shape_to_str(input_h_shape), ";", "[?]", ";", + shape_to_str(input_c_shape), ";", "[?]", ";", shape_to_str(seq_lengths_shape)); - string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?;?"; + string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in2;?;?"; ShapeInferenceTestOp op("CudnnRNNV3"); TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV3") From f7f31ea38bae737e0f802b63611280d3ebeb5ad5 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 5 Jun 2019 10:52:45 -0700 Subject: [PATCH 11/17] Format changes --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index e08b50f5fd4..143acd3212a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -215,20 +215,17 @@ def RunLSTM(sess, (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, cu_bgrad, cu_pwgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op - ], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) else: outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ outputs_op, state_tuple_op, inp_grad_op, - (hgrad_op, cgrad_op), wgrad_op, bgrad_op - ]) + (hgrad_op, cgrad_op), wgrad_op, bgrad_op]) (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, cu_bgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -405,8 +402,9 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): else: (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, - cu_bgrad) = RunLSTM(sess, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths=variable_seq_lengths, + cu_bgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths=variable_seq_lengths, dynamic_shape_input=dynamic_shape_input, num_proj=num_proj) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) From 9d7c1edba17dcb4dc1bbb078ba92c0724557568a Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 5 Jun 2019 13:12:41 -0700 Subject: [PATCH 12/17] declare of input_c_shape --- tensorflow/core/ops/cudnn_rnn_ops_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc index ed1e003c897..90e8c85917e 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc @@ -111,6 +111,8 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) { std::vector<int> input_shape = {max_seq_length, batch_size, num_units}; std::vector<int> input_h_shape = {num_layers * dir_count, batch_size, num_units}; + std::vector<int> input_c_shape = {num_layers * dir_count, batch_size, + num_units}; std::vector<int> output_shape = {max_seq_length, batch_size, num_units * dir_count}; std::vector<int> seq_lengths_shape = {batch_size}; From b64d3dd126ce13cdd03e5d1cf0f870e4d5a30648 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 5 Jun 2019 15:23:54 -0700 Subject: [PATCH 13/17] forward compat --- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 116 +++++++++++------- 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 5a6c2924a10..b6250d04c35 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -249,31 +250,43 @@ class CudnnParamsFormatConverter(object): 2 list for weights and biases respectively. """ with ops.device("/gpu:0"): - if self._num_proj: - num_params_weights = (self._num_params + - 1 * self._num_layers * self._num_dirs) - num_params_biases = self._num_params - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction, - num_params_weights=num_params_weights, - num_params_biases=num_params_biases, - num_proj=self._num_proj) - else: - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26): + if self._num_proj: + num_params_weights = (self._num_params + + 1 * self._num_layers * self._num_dirs) + num_params_biases = self._num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=self._num_proj) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) + return (weights, biases) + + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) return (weights, biases) def _cu_canonical_to_opaque(self, cu_weights, cu_biases): @@ -287,27 +300,38 @@ class CudnnParamsFormatConverter(object): a single opaque tensor. """ with ops.device("/gpu:0"): - if self._num_proj: - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - num_proj=self._num_proj, - direction=self._direction) - else: - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26): + if self._num_proj: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + num_proj=self._num_proj, + direction=self._direction) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) + + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. From f3bc96be5d00883a0e5f1af856d19a1eba6ff2dd Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Wed, 5 Jun 2019 16:17:25 -0700 Subject: [PATCH 14/17] forward compat test --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 268 +++++++++--------- 1 file changed, 138 insertions(+), 130 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 143acd3212a..40ea0e7f228 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -27,6 +27,7 @@ import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -430,20 +431,21 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_training(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, dynamic_shape_input, use_proj): - num_proj = num_units // 2 - if use_proj and num_proj == 0: - self.skipTest("num_proj cannot be 0") - self._test_training_helper( - num_units, - input_size, - batch_size, - time, - num_layers, - dtypes.float32, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -457,22 +459,23 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_training_fp16(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, dynamic_shape_input, use_proj): - num_proj = num_units // 2 - if use_proj and num_proj == 0: - self.skipTest("num_proj cannot be 0") - self._test_training_helper( - num_units, - input_size, - batch_size, - time, - num_layers, - dtypes.float16, - rtol=5e-3, - atol=5e-4, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -486,28 +489,29 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_inference(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, dynamic_shape_input, use_proj): - num_proj = num_units // 2 - if use_proj and num_proj == 0: - self.skipTest("num_proj cannot be 0") - with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - self.assertAllClose(outputs, cu_outputs) - # h - self.assertAllClose(state_tuple.h, cu_state_tuple.h) - # c - self.assertAllClose(state_tuple.c, cu_state_tuple.c) + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) @parameterized.named_parameters( ExpandNamedTestCases( @@ -521,32 +525,33 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): def test_inference_fp16(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, dynamic_shape_input, use_proj): - num_proj = num_units // 2 - if use_proj and num_proj == 0: - self.skipTest("num_proj cannot be 0") - with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - rtol, atol = 5e-3, 5e-4 - self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) - # h - self.assertAllClose( - state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) - # c - self.assertAllClose( - state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) @parameterized.named_parameters( ExpandNamedTestCases( @@ -561,48 +566,49 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): num_layers, variable_seq_lengths, time_major, dynamic_shape_input, use_proj): """Validates that dropout does not affect Cudnn Rnn inference.""" - num_proj = num_units // 2 - if use_proj and num_proj == 0: - self.skipTest("num_proj cannot be 0") - # Hand-picked dropouts are used below (0. and 1.) - with ops.Graph().as_default() as g: - with self.session(use_gpu=True, graph=g) as sess: - # 1st time w/o dropout. - (_, cu_outputs, _, cu_state_tuple) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dropout=0., - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0., + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - with ops.Graph().as_default() as g: - with self.session(use_gpu=True, graph=g) as sess: - (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dropout=1., - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input, - num_proj=num_proj if use_proj else None) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1., + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - self.assertAllClose(cu_outputs, cu_outputs2) - # h - self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) - # c - self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) def RunGRU(sess, @@ -1042,12 +1048,13 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, for c in NAMED_RNN_TESTCASES) @test_util.run_gpu_only def test_lstmp(self, num_units, input_size, num_layers): - num_proj = num_units // 2 - if num_proj == 0: - self.skipTest("num_proj cannot be 0") - self._test_lstm_helper(num_units, input_size, num_layers, - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - num_proj=num_proj) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, + num_proj=num_proj) @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) @@ -1062,12 +1069,13 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, for c in NAMED_RNN_TESTCASES) @test_util.run_gpu_only def test_lstmp_bidi(self, num_units, input_size, num_layers): - num_proj = num_units // 2 - if num_proj == 0: - self.skipTest("num_proj cannot be 0") - self._test_lstm_helper(num_units, input_size, num_layers, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION, - num_proj=num_proj) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION, + num_proj=num_proj) def _test_gru_helper(self, num_units, input_size, num_layers, direction): with self.session(use_gpu=True) as sess: @@ -1126,8 +1134,8 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, for c in NAMED_RNN_TESTCASES) @test_util.run_gpu_only def test_gru(self, num_units, input_size, num_layers): - self._test_gru_helper(num_units, input_size, num_layers, - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) From ef8517bf4bf28a8ffdbbbf67acc380092bd4d7b6 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Thu, 6 Jun 2019 10:07:58 -0700 Subject: [PATCH 15/17] format changes --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 4 +- .../cudnn_rnn/python/ops/cudnn_rnn_ops.py | 115 +++++++----------- 2 files changed, 48 insertions(+), 71 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 40ea0e7f228..2fc8268b8f4 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -1134,8 +1134,8 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, for c in NAMED_RNN_TESTCASES) @test_util.run_gpu_only def test_gru(self, num_units, input_size, num_layers): - self._test_gru_helper(num_units, input_size, num_layers, - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index b6250d04c35..b2d162cfce5 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -250,43 +250,31 @@ class CudnnParamsFormatConverter(object): 2 list for weights and biases respectively. """ with ops.device("/gpu:0"): - if compat.forward_compatible(2019, 6, 26): - if self._num_proj: - num_params_weights = (self._num_params + - 1 * self._num_layers * self._num_dirs) - num_params_biases = self._num_params - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction, - num_params_weights=num_params_weights, - num_params_biases=num_params_biases, - num_proj=self._num_proj) - else: - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) - return (weights, biases) - - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26) and self._num_proj: + num_params_weights = (self._num_params + + 1 * self._num_layers * self._num_dirs) + num_params_biases = self._num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=self._num_proj) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) return (weights, biases) def _cu_canonical_to_opaque(self, cu_weights, cu_biases): @@ -300,39 +288,28 @@ class CudnnParamsFormatConverter(object): a single opaque tensor. """ with ops.device("/gpu:0"): - if compat.forward_compatible(2019, 6, 26): - if self._num_proj: - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - num_proj=self._num_proj, - direction=self._direction) - else: - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26) and self._num_proj: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + num_proj=self._num_proj, + direction=self._direction) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) - def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. From df2b9b790e5269626daab10a3e3d22a045473c24 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Thu, 6 Jun 2019 20:06:29 -0700 Subject: [PATCH 16/17] updated the goldens --- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 8 -------- tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 8 -------- 2 files changed, 16 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index f2105d3f5db..a5c4fb03e26 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1032,14 +1032,6 @@ tf_module { name: "cross" argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "cudnn_rnn_canonical_to_params_v2" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " - } - member_method { - name: "cudnn_rnn_params_to_canonical_v2" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " - } member_method { name: "cumprod" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 78d9127ad33..fed938e67bb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -540,14 +540,6 @@ tf_module { name: "cosh" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "cudnn_rnn_canonical_to_params_v2" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " - } - member_method { - name: "cudnn_rnn_params_to_canonical_v2" - argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], " - } member_method { name: "cumsum" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " From c0ea11b15886b62ce7184bf9ab4a5e2502afe985 Mon Sep 17 00:00:00 2001 From: Kaixi Hou <kaixih@nvidia.com> Date: Fri, 7 Jun 2019 10:16:34 -0700 Subject: [PATCH 17/17] added two api def files and solved some python format issues --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 19 +++++----- .../api_def_CudnnRNNCanonicalToParamsV2.pbtxt | 36 +++++++++++++++++++ .../api_def_CudnnRNNParamsToCanonicalV2.pbtxt | 36 +++++++++++++++++++ 3 files changed, 81 insertions(+), 10 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 2fc8268b8f4..e516ec8ce9b 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -208,16 +208,15 @@ def RunLSTM(sess, if is_training: if num_proj: - outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad, pwgrad = \ - sess.run([ - outputs_op, state_tuple_op, inp_grad_op, - (hgrad_op, cgrad_op), wgrad_op, bgrad_op, pwgrad_op - ]) + (outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad, + pwgrad) = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op, pwgrad_op]) (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, cu_bgrad, cu_pwgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op + ], feed_dict={inputs: inputs_np} if dynamic_shape_input else None) else: outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ outputs_op, state_tuple_op, inp_grad_op, @@ -225,9 +224,9 @@ def RunLSTM(sess, (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, cu_bgrad) = sess.run([ cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) - + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ], feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) logging.vlog(1, "state_tuple: %s" % str(state_tuple)) diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt new file mode 100644 index 00000000000..cac3674c1ae --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "CudnnRNNCanonicalToParamsV2" + summary: "Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM." + description: <<END +Writes a set of weights into the opaque params buffer so they can be used in +upcoming training or inferences. + +Note that the params buffer may not be compatible across different GPUs. So any +save and restoration should be converted to and from the canonical weights and +biases. + +num_layers: Specifies the number of layers in the RNN model. +num_units: Specifies the size of the hidden state. +input_size: Specifies the size of the input state. +weights: the canonical form of weights that can be used for saving + and restoration. They are more likely to be compatible across different + generations. +biases: the canonical form of biases that can be used for saving + and restoration. They are more likely to be compatible across different + generations. +num_params_weigths: number of weight parameter matrix for all layers. +num_params_biases: number of bias parameter vector for all layers. +rnn_mode: Indicates the type of the RNN model. +input_mode: Indicate whether there is a linear projection between the input and + The actual computation before the first layer. 'skip_input' is only allowed + when input_size == num_units; 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. +direction: Indicates whether a bidirectional model will be used. + dir = (direction == bidirectional) ? 2 : 1 +dropout: dropout probability. When set to 0., dropout is disabled. +seed: the 1st part of a seed to initialize dropout. +seed2: the 2nd part of a seed to initialize dropout. +num_proj: The output dimensionality for the projection matrices. If None or 0, + no projection is performed. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt new file mode 100644 index 00000000000..aa51414ba23 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "CudnnRNNParamsToCanonicalV2" + summary: "Retrieves CudnnRNN params in canonical form. It supports the projection in LSTM." + description: <<END +Retrieves a set of weights from the opaque params buffer that can be saved and +restored in a way compatible with future runs. + +Note that the params buffer may not be compatible across different GPUs. So any +save and restoration should be converted to and from the canonical weights and +biases. + +num_layers: Specifies the number of layers in the RNN model. +num_units: Specifies the size of the hidden state. +input_size: Specifies the size of the input state. +num_params_weigths: number of weight parameter matrix for all layers. +num_params_biases: number of bias parameter vector for all layers. +weights: the canonical form of weights that can be used for saving + and restoration. They are more likely to be compatible across different + generations. +biases: the canonical form of biases that can be used for saving + and restoration. They are more likely to be compatible across different + generations. +rnn_mode: Indicates the type of the RNN model. +input_mode: Indicate whether there is a linear projection between the input and + The actual computation before the first layer. 'skip_input' is only allowed + when input_size == num_units; 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. +direction: Indicates whether a bidirectional model will be used. + dir = (direction == bidirectional) ? 2 : 1 +dropout: dropout probability. When set to 0., dropout is disabled. +seed: the 1st part of a seed to initialize dropout. +seed2: the 2nd part of a seed to initialize dropout. +num_proj: The output dimensionality for the projection matrices. If None or 0, + no projection is performed. +END +}