diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 5c63ee7a97b..7cba7937f1b 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -27,6 +27,7 @@ import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -71,7 +72,8 @@ def RunLSTM(sess, is_training=True, dropout=0., num_dirs=True, - dtype=dtypes.float32): + dtype=dtypes.float32, + num_proj=None): # TODO(jamesqin): add multi-layer tests. # TODO(jamesqin): add multi-dir tests assert num_layers == 1 @@ -91,10 +93,12 @@ def RunLSTM(sess, inputs_dynamic = array_ops.placeholder( dtype, shape=[None, None, None], name="inputs") inputs = inputs_dynamic if dynamic_shape_input else inputs_static + unified_num_units = num_proj if num_proj else num_units + unified_num_proj = num_proj if num_proj else None initial_h_op = variable_scope.get_variable( "initial_h_op", - initializer=np.random.rand(batch_size, - num_units).astype(dtype.as_numpy_dtype), + initializer=np.random.rand(batch_size, unified_num_units).astype( + dtype.as_numpy_dtype), dtype=dtype) initial_c_op = variable_scope.get_variable( "initial_c_op", @@ -115,13 +119,19 @@ def RunLSTM(sess, with variable_scope.variable_scope("test", initializer=initializer): w = variable_scope.get_variable( "rnn/lstm_cell/kernel", - shape=[input_size + num_units, num_units * 4], + shape=[input_size + unified_num_units, num_units * 4], dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + if num_proj: + pw = variable_scope.get_variable( + "rnn/lstm_cell/projection/kernel", + shape=[num_units, num_proj], + dtype=dtype) # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. - cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + cell = rnn_cell_impl.LSTMCell( + num_units, forget_bias=0., reuse=True, num_proj=unified_num_proj) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs_static, @@ -134,8 +144,13 @@ def RunLSTM(sess, # Convert to cudnn opaque param. format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( - num_layers, num_units, input_size) - opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + num_layers, num_units, input_size, num_proj=unified_num_proj) + if num_proj: + opaque_params = format_converter.tf_canonical_to_opaque([w, b], [ + pw, + ]) + else: + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) cu_initial_h_op = array_ops.expand_dims( initial_h_op, axis=(0 if time_major else 1)) @@ -150,16 +165,22 @@ def RunLSTM(sess, time_major=time_major, dropout=dropout, is_training=is_training, - rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM, + num_proj=unified_num_proj) # Remove the trivial 1st dimension. 
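For reference, a minimal NumPy sketch (illustrative only, not part of this patch; the function and variable names are invented) of a single LSTM-with-projection step. It shows why, once num_proj is set in the test above, the recurrent state and initial_h have width num_proj while the cell state keeps width num_units, and why the canonical kernel is [input_size + num_proj, 4 * num_units] with a separate [num_units, num_proj] projection matrix:

import numpy as np

def lstmp_step(x, h, c, kernel, bias, proj_kernel):
  # Shapes, matching the variables built in RunLSTM above:
  #   x: [batch, input_size], h: [batch, num_proj], c: [batch, num_units]
  #   kernel: [input_size + num_proj, 4 * num_units], bias: [4 * num_units]
  #   proj_kernel: [num_units, num_proj]
  # Gate order follows tf LSTMCell's i, j, f, o split; forget_bias assumed 0,
  # as in the canonical cell used by the test.
  gates = np.concatenate([x, h], axis=1).dot(kernel) + bias
  i, j, f, o = np.split(gates, 4, axis=1)
  sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))
  new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(j)   # [batch, num_units]
  new_h = sigmoid(o) * np.tanh(new_c)                # [batch, num_units]
  new_h = new_h.dot(proj_kernel)                     # projected: [batch, num_proj]
  return new_h, new_c

if __name__ == "__main__":
  batch, input_size, num_units, num_proj = 2, 3, 4, 2
  x = np.random.rand(batch, input_size)
  h = np.zeros((batch, num_proj))
  c = np.zeros((batch, num_units))
  kernel = np.random.rand(input_size + num_proj, 4 * num_units)
  bias = np.zeros(4 * num_units)
  proj_kernel = np.random.rand(num_units, num_proj)
  h, c = lstmp_step(x, h, c, kernel, bias, proj_kernel)
  assert h.shape == (batch, num_proj) and c.shape == (batch, num_units)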
cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1)) if is_training: - (inp_grad_op, hgrad_op, - cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( - outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) + if num_proj: + (inp_grad_op, hgrad_op, cgrad_op, + wgrad_op, bgrad_op, pwgrad_op) = gradients_impl.gradients( + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b, pw]) + else: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) (cu_inp_grad_op, cu_hgrad_op, cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( @@ -170,10 +191,16 @@ def RunLSTM(sess, # Remove the trivial 1st dimension cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1) - cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( - opaque_grad_op) + if num_proj: + cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op = \ + format_converter.opaque_to_tf_canonical(opaque_grad_op) + else: + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) cu_wgrad_op = cu_wgrad_op[0] cu_bgrad_op = cu_bgrad_op[0] + if num_proj: + cu_pwgrad_op = cu_pwgrad_op[0] # cudnn lstm has 2 biases each gate. When converting to tf canonical format, # the two biases are summed into one. Thus here bias gradient should be # halved when comparing with tf lstm. @@ -183,17 +210,32 @@ def RunLSTM(sess, sess.run(init_op) if is_training: - outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ - outputs_op, state_tuple_op, inp_grad_op, - (hgrad_op, cgrad_op), wgrad_op, bgrad_op - ]) - (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, - cu_bgrad) = sess.run( - [ - cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ], - feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + if num_proj: + (outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad, + pwgrad) = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op), + wgrad_op, bgrad_op, pwgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad, cu_pwgrad) = sess.run( + [ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, + cu_pwgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) + else: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op), + wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run( + [ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -205,11 +247,20 @@ def RunLSTM(sess, logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) logging.vlog(1, "wgrad: %s" % str(wgrad)) logging.vlog(1, "bgrad: %s" % str(bgrad)) + if num_proj: + logging.vlog(1, "pwgrad: %s" % str(bgrad)) logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) - return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, - cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, - cu_bgrad) + if 
num_proj: + logging.vlog(1, "cu_pwgrad: %s" % str(cu_bgrad)) + if num_proj: + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, + cu_wgrad, cu_bgrad, cu_pwgrad) + else: + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], @@ -256,7 +307,6 @@ NAMED_RNN_TESTCASES = ({ "num_layers": 1, }) - def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): """Expands testcase with new config dimensions. @@ -349,19 +399,35 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): time_major, dynamic_shape_input=False, rtol=3e-6, - atol=3e-6): + atol=3e-6, + num_proj=None): with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, - state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + if num_proj is not None and num_proj != 0: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, cu_wgrad, + cu_bgrad, cu_pwgrad) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj) + else: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -371,6 +437,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) + if num_proj is not None and num_proj != 0: + self.assertAllClose(pwgrad, cu_pwgrad, rtol=rtol, atol=atol) @parameterized.named_parameters( ExpandNamedTestCases( @@ -378,20 +446,27 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major, dynamic_shape_input): - self._test_training_helper( - num_units, - input_size, - batch_size, - time, - num_layers, - dtypes.float32, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + variable_seq_lengths, time_major, dynamic_shape_input, + use_proj): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + 
variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -399,52 +474,29 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_training_fp16(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, - dynamic_shape_input): - self._test_training_helper( - num_units, - input_size, - batch_size, - time, - num_layers, - dtypes.float16, - rtol=5e-3, - atol=5e-4, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) - - @parameterized.named_parameters( - ExpandNamedTestCases( - NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - "time_major": [True, False], - "dynamic_shape_input": [True, False], - })) - @test_util.run_gpu_only - def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths, time_major, dynamic_shape_input): - with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( - sess, + dynamic_shape_input, use_proj): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_training_helper( num_units, input_size, batch_size, time, num_layers, - is_training=False, + dtypes.float16, + rtol=5e-3, + atol=5e-4, variable_seq_lengths=variable_seq_lengths, time_major=time_major, - dynamic_shape_input=dynamic_shape_input) - - self.assertAllClose(outputs, cu_outputs) - # h - self.assertAllClose(state_tuple.h, cu_state_tuple.h) - # c - self.assertAllClose(state_tuple.c, cu_state_tuple.c) + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) @parameterized.named_parameters( ExpandNamedTestCases( @@ -452,33 +504,75 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], + })) + @test_util.run_gpu_only + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths, time_major, dynamic_shape_input, + use_proj): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) + + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_inference_fp16(self, num_units, input_size, batch_size, time, num_layers, 
variable_seq_lengths, time_major, - dynamic_shape_input): - with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + dynamic_shape_input, use_proj): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - rtol, atol = 5e-3, 5e-4 - self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) - # h - self.assertAllClose( - state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) - # c - self.assertAllClose( - state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) @parameterized.named_parameters( ExpandNamedTestCases( @@ -486,49 +580,56 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): "variable_seq_lengths": [True, False], "time_major": [True, False], "dynamic_shape_input": [True, False], + "use_proj": [True, False], })) @test_util.run_gpu_only def test_inference_with_dropout(self, num_units, input_size, batch_size, time, num_layers, variable_seq_lengths, time_major, - dynamic_shape_input): + dynamic_shape_input, use_proj): """Validates that dropout does not affect Cudnn Rnn inference.""" - # Hand-picked dropouts are used below (0. and 1.) - with ops.Graph().as_default() as g: - with self.session(use_gpu=True, graph=g) as sess: - # 1st time w/o dropout. - (_, cu_outputs, _, cu_state_tuple) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dropout=0., - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if use_proj and num_proj == 0: + self.skipTest("num_proj cannot be 0") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. 
+ (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0., + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - with ops.Graph().as_default() as g: - with self.session(use_gpu=True, graph=g) as sess: - (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( - sess, - num_units, - input_size, - batch_size, - time, - num_layers, - is_training=False, - dropout=1., - variable_seq_lengths=variable_seq_lengths, - time_major=time_major, - dynamic_shape_input=dynamic_shape_input) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1., + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input, + num_proj=num_proj if use_proj else None) - self.assertAllClose(cu_outputs, cu_outputs2) - # h - self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) - # c - self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) def RunGRU(sess, @@ -890,40 +991,68 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, parameterized.TestCase): """Class for testing various format converters.""" - def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + def _test_lstm_helper(self, + num_units, + input_size, + num_layers, + direction, + num_proj=None): with self.session(use_gpu=True) as sess: random_seed.set_random_seed(0) np.random.seed(0) num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( - num_layers, num_units, input_size, direction=direction) + num_layers, + num_units, + input_size, + direction=direction, + num_proj=num_proj if num_proj else None) - ws, bs = [], [] + ws, bs, pws = [], [], [] for _ in range(num_layers * num_dirs): w = constant_op.constant( - np.random.rand(input_size + num_units, 4 * num_units), + np.random.rand(input_size + (num_proj if num_proj else num_units), + 4 * num_units), dtype=dtypes.float32) b = constant_op.constant( np.random.rand(4 * num_units), dtype=dtypes.float32) ws.append(w) bs.append(b) + if num_proj: + pw = constant_op.constant( + np.random.rand(num_units, num_proj), dtype=dtypes.float32) + pws.append(pw) + + if num_proj: + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs, pws) + else: + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( cudnn_rnn_ops.CUDNN_LSTM, num_layers, num_units, input_size, - direction=direction) + direction=direction, + num_proj=num_proj if num_proj else None) - ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + if num_proj: + ws_r, bs_r, pws_r = format_converter.opaque_to_tf_canonical( + opaque_params) + ws, ws_r, pws, bs, bs_r, pws_r = sess.run( + [ws, ws_r, pws, bs, bs_r, pws_r]) + else: + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) # Test tf_canonical_to_opaque() followed by 
opaque_to_tf_canonical() # returns the original input. - ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) for w, w_r in zip(ws, ws_r): self.assertAllClose(w, w_r) + if num_proj: + for pw, pw_r in zip(pws, pws_r): + self.assertAllClose(pw, pw_r) for b, b_r in zip(bs, bs_r): self.assertAllClose(b, b_r) @@ -942,6 +1071,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + @parameterized.named_parameters( + (c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @test_util.run_gpu_only + def test_lstmp(self, num_units, input_size, num_layers): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper( + num_units, + input_size, + num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, + num_proj=num_proj) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) @@ -950,6 +1095,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + @parameterized.named_parameters( + (c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @test_util.run_gpu_only + def test_lstmp_bidi(self, num_units, input_size, num_layers): + with compat.forward_compatibility_horizon(2019, 6, 27): + num_proj = num_units // 2 + if num_proj == 0: + self.skipTest("num_proj cannot be 0") + self._test_lstm_helper( + num_units, + input_size, + num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION, + num_proj=num_proj) + def _test_gru_helper(self, num_units, input_size, num_layers, direction): with self.session(use_gpu=True) as sess: random_seed.set_random_seed(0) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 3694d112ce4..7ee00ade6d2 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -186,6 +187,7 @@ class CudnnParamsFormatConverter(object): num_layers, num_units, input_size, + num_proj=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION): """Constructor. @@ -195,6 +197,8 @@ class CudnnParamsFormatConverter(object): num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the num_units. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be one of 'linear_input', 'skip_input' or 'auto_select'. 
* 'linear_input' @@ -209,14 +213,16 @@ class CudnnParamsFormatConverter(object): self._input_size = input_size self._num_units = num_units self._input_mode = input_mode + self._num_proj = num_proj self._direction = direction self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - def tf_canonical_to_opaque(self, tf_canonicals): + def tf_canonical_to_opaque(self, tf_canonicals, weights_proj=None): r"""Converts tf canonical weights to cudnn opaque param.""" - cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical( + tf_canonicals, weights_proj) cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) return opaque_params @@ -224,8 +230,14 @@ class CudnnParamsFormatConverter(object): def opaque_to_tf_canonical(self, opaque_param): r"""Converts cudnn opaque param to tf canonical weights.""" cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) - weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) - return weights, biases + if self._num_proj: + weights, biases, weights_proj = self._cu_canonical_to_tf_canonical( + cu_weights, cu_biases) + return weights, biases, weights_proj + else: + weights, biases = self._cu_canonical_to_tf_canonical( + cu_weights, cu_biases) + return weights, biases def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. @@ -238,15 +250,31 @@ class CudnnParamsFormatConverter(object): 2 list for weights and biases respectively. """ with ops.device("/gpu:0"): - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=opaque_param, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26) and self._num_proj: + num_params_weights = ( + self._num_params + 1 * self._num_layers * self._num_dirs) + num_params_biases = self._num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=self._num_proj) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=opaque_param, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) return (weights, biases) def _cu_canonical_to_opaque(self, cu_weights, cu_biases): @@ -260,15 +288,27 @@ class CudnnParamsFormatConverter(object): a single opaque tensor. 
""" with ops.device("/gpu:0"): - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + if compat.forward_compatible(2019, 6, 26) and self._num_proj: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + num_proj=self._num_proj, + direction=self._direction) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. @@ -294,9 +334,11 @@ class CudnnParamsFormatConverter(object): 1 tuple, tf canonical weights and biases. """ tf_weights, tf_biases = [], [] + tf_weights_proj = [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num + layer_weights_num += (1 * self._num_dirs) if self._num_proj else 0 for i in range(self._num_layers): layer_weights = cu_weights[i * layer_weights_num:(i + 1) * @@ -305,7 +347,8 @@ class CudnnParamsFormatConverter(object): if self._direction == CUDNN_RNN_UNIDIRECTION: self._cu_canonical_to_tf_canonical_single_layer(layer_weights, layer_biases, - tf_weights, tf_biases) + tf_weights, tf_biases, + tf_weights_proj) else: fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] @@ -317,6 +360,7 @@ class CudnnParamsFormatConverter(object): fw_biases, tf_weights, tf_biases, + tf_weights_proj, ) self._cu_canonical_to_tf_canonical_single_layer( @@ -324,11 +368,19 @@ class CudnnParamsFormatConverter(object): bw_biases, tf_weights, tf_biases, + tf_weights_proj, ) - return (tf_weights, tf_biases) + if self._num_proj: + return (tf_weights, tf_biases, tf_weights_proj) + else: + return (tf_weights, tf_biases) - def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): + def _cu_canonical_to_tf_canonical_single_layer(self, + cu_weights, + cu_biases, + tf_weights, + tf_biases, + tf_weigths_proj=None): r"""Transform single layer Cudnn canonicals to tf canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -343,7 +395,7 @@ class CudnnParamsFormatConverter(object): """ raise NotImplementedError("Abstract method") - def _tf_canonical_to_cu_canonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals, weights_proj): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). @@ -362,6 +414,7 @@ class CudnnParamsFormatConverter(object): --------------- |fwd |bak | --------------- + weights_proj: (optional) weights matrices for projection Returns: 2 lists: the recovered cudnn canonical weights and biases. 
""" @@ -376,6 +429,9 @@ class CudnnParamsFormatConverter(object): layer_biases = biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights)) + if weights_proj is not None: + pw = array_ops.transpose(weights_proj[i]) + cu_weights.append(pw) cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases)) else: fw_weights, bw_weights = layer_weights[:len(layer_weights) // @@ -385,9 +441,15 @@ class CudnnParamsFormatConverter(object): 2], layer_biases[len(layer_biases ) // 2:] cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights)) + if weights_proj is not None: + pw0 = array_ops.transpose(weights_proj[2 * i + 0]) + cu_weights.append(pw0) cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases)) cu_weights.extend(self._tf_to_cudnn_weights(i, *bw_weights)) + if weights_proj is not None: + pw1 = array_ops.transpose(weights_proj[2 * i + 1]) + cu_weights.append(pw1) cu_biases.extend(self._tf_to_cudnn_biases(*bw_biases)) return cu_weights, cu_biases @@ -423,7 +485,10 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" - w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights + if self._num_proj: + w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, pw = cu_weights + else: + w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights # pylint: disable=invalid-name W_i = array_ops.concat([w_i, r_i], axis=1) @@ -433,7 +498,11 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): # pylint: enable=invalid-name # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order. reordered = self._cudnn_to_tf_gate_params(*[W_i, W_f, W_c, W_o]) - return (array_ops.transpose(array_ops.concat(reordered, axis=0)),) + if self._num_proj: + return (array_ops.transpose(array_ops.concat(reordered, axis=0)), + array_ops.transpose(pw)) + else: + return (array_ops.transpose(array_ops.concat(reordered, axis=0)),) def _tf_to_cudnn_weights(self, layer, *tf_weights): r"""Reverse the operations in StitchWeights().""" @@ -442,7 +511,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): if layer == 0: input_weight_width = input_size else: - input_weight_width = num_units + input_weight_width = self._num_proj if self._num_proj else num_units if self._direction == CUDNN_RNN_BIDIRECTION: input_weight_width *= 2 @@ -452,10 +521,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params( *array_ops.split(w, 4, axis=0)) - w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1) - w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1) - w_f, r_f = array_ops.split(W_f, [input_weight_width, num_units], axis=1) - w_o, r_o = array_ops.split(W_o, [input_weight_width, num_units], axis=1) + hidden_state_width = self._num_proj if self._num_proj else num_units + w_i, r_i = array_ops.split( + W_i, [input_weight_width, hidden_state_width], axis=1) + w_c, r_c = array_ops.split( + W_c, [input_weight_width, hidden_state_width], axis=1) + w_f, r_f = array_ops.split( + W_f, [input_weight_width, hidden_state_width], axis=1) + w_o, r_o = array_ops.split( + W_o, [input_weight_width, hidden_state_width], axis=1) return w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o # pylint: enable=invalid-name @@ -490,11 +564,20 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): # Return ifco order for Cudnn 
LSTM. return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro - def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): - (w,) = self._cudnn_to_tf_weights(*cu_weights) + def _cu_canonical_to_tf_canonical_single_layer(self, + cu_weights, + cu_biases, + tf_weights, + tf_biases, + tf_weights_proj=None): + if self._num_proj: + (w, pw) = self._cudnn_to_tf_weights(*cu_weights) + tf_weights.append(w) + tf_weights_proj.append(pw) + else: + (w,) = self._cudnn_to_tf_weights(*cu_weights) + tf_weights.append(w) (b,) = self._cudnn_to_tf_biases(*cu_biases) - tf_weights.append(w) tf_biases.append(b) @@ -561,8 +644,12 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): b_ri, b_rr = array_ops.split(br, 2, axis=0) return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh - def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, - tf_weights, tf_biases): + def _cu_canonical_to_tf_canonical_single_layer(self, + cu_weights, + cu_biases, + tf_weights, + tf_biases, + tf_weights_proj=None): # pylint: disable=invalid-name W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) @@ -735,8 +822,11 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): def format_converter(self): if self._format_converter is None: self._format_converter = self._format_converter_cls( - self._num_layers, self._num_units, self._input_size, self._input_mode, - self._direction) + self._num_layers, + self._num_units, + self._input_size, + input_mode=self._input_mode, + direction=self._direction) return self._format_converter def restore(self, restored_tensors, restored_shapes): @@ -970,6 +1060,7 @@ def _cudnn_rnn(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + num_proj=None, name=None): """Cudnn RNN. @@ -1006,6 +1097,8 @@ def _cudnn_rnn(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: @@ -1035,13 +1128,16 @@ def _cudnn_rnn(inputs, if sequence_lengths is not None: args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major + args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) - elif time_major is False: - batch_size = array_ops.shape(inputs)[0] - max_time = array_ops.shape(inputs)[1] + elif time_major is False or num_proj: + batch_id, time_id = (1, 0) if time_major else (0, 1) + batch_size = array_ops.shape(inputs)[batch_id] + max_time = array_ops.shape(inputs)[time_id] sequence_lengths = array_ops.fill([batch_size], max_time) args["sequence_lengths"] = sequence_lengths args["time_major"] = time_major + args["num_proj"] = 0 if num_proj is None else num_proj outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) @@ -1061,6 +1157,7 @@ def cudnn_lstm(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + num_proj=None, name=None): """Cudnn LSTM. @@ -1096,6 +1193,8 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. 
+ num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: @@ -1103,7 +1202,7 @@ def cudnn_lstm(inputs, """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, sequence_lengths, time_major, input_mode, direction, - dropout, seed, name) + dropout, seed, num_proj, name) def _cudnn_rnn_no_input_c(inputs, @@ -1160,7 +1259,7 @@ def _cudnn_rnn_no_input_c(inputs, outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, time_major, input_mode, direction, dropout, - seed, name) + seed, None, name) return outputs, output_h @@ -1331,6 +1430,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, direction=CUDNN_RNN_UNIDIRECTION, dropout=0, seed=0, + num_proj=None, name=None): """Convert cudnn opaque params to canonical. @@ -1353,6 +1453,8 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: @@ -1366,19 +1468,39 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - rnn_mode=rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - params=params, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - num_params=num_params, - name=name) + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + if num_proj is not None and num_proj != 0: + num_params_weights = (num_params + 1 * num_layers * num_dirs) + num_params_biases = num_params + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + params=params, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_params_weights=num_params_weights, + num_params_biases=num_params_biases, + num_proj=num_proj, + name=name) + else: + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + params=params, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_params=num_params, + name=name) return weights, biases @@ -1392,6 +1514,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, direction=CUDNN_RNN_UNIDIRECTION, dropout=0, seed=0, + num_proj=None, name=None): """Converts params from the canonical format to a specific format of cuDNN. @@ -1415,6 +1538,8 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. 
Returns: @@ -1426,19 +1551,35 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, check_direction(direction) check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - rnn_mode=rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - weights=weights, - biases=biases, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - name=name) + if num_proj is not None and num_proj != 0: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + weights=weights, + biases=biases, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_proj=num_proj, + name=name) + else: + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + weights=weights, + biases=biases, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + name=name) def cudnn_rnn_opaque_params_size(rnn_mode, @@ -1450,6 +1591,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, dtype=dtypes.float32, dropout=0, seed=0, + num_proj=None, name=None): """Returns opaque params size for specific Cudnn config. @@ -1472,6 +1614,8 @@ def cudnn_rnn_opaque_params_size(rnn_mode, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. name: name of the operation. Returns: @@ -1488,6 +1632,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, num_layers=num_layers, num_units=num_units, input_size=input_size, + num_proj=num_proj, T=dtype, S=dtypes.int32, dropout=dropout, @@ -1516,7 +1661,8 @@ class _CudnnRNN(object): direction=CUDNN_RNN_UNIDIRECTION, dtype=dtypes.float32, dropout=0., - seed=0): + seed=0, + num_proj=None): """Creates a CudnnRNN model from model spec. Args: @@ -1539,6 +1685,8 @@ class _CudnnRNN(object): dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.compat.v1.set_random_seed` for behavior. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. Raises: ValueError: if direction is invalid. @@ -1552,6 +1700,7 @@ class _CudnnRNN(object): self._dtype = dtype self._dropout = dropout self._seed = seed + self._num_proj = num_proj @property def input_mode(self): @@ -1577,6 +1726,10 @@ class _CudnnRNN(object): def direction(self): return self._direction + @property + def num_proj(self): + return self._num_proj + def params_size(self): """Calculates the size of the opaque parameter buffer needed for this model. @@ -1588,6 +1741,7 @@ class _CudnnRNN(object): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, + num_proj=self._num_proj, dtype=self._dtype, dropout=self._dropout, seed=self._seed, @@ -1643,7 +1797,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) def params_to_canonical(self, params): """Converts params from a specific format of cuDNN to the canonical format. 
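As a usage note, the projected variant is driven the same way as the existing wrappers; a hypothetical sketch (a GPU is required, and it is run under the same forward-compatibility horizon the tests above use) of building the CudnnLSTM wrapper extended below with num_proj and querying the opaque parameter buffer size:

from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.python.compat import compat

with compat.forward_compatibility_horizon(2019, 6, 27):
  model = cudnn_rnn_ops.CudnnLSTM(
      num_layers=1,
      num_units=8,
      input_size=4,
      num_proj=4)  # the tests above use num_units // 2 and skip num_proj == 0
  # Tensor holding the size of the opaque param buffer; num_proj is forwarded
  # to cudnn_rnn_opaque_params_size internally.
  params_size_op = model.params_size()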
@@ -1663,7 +1818,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) def canonical_to_params(self, weights, biases): """Converts params from the canonical format to a specific format of cuDNN. @@ -1685,7 +1841,8 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed) + seed=self._seed, + num_proj=self._num_proj) class CudnnLSTM(_CudnnRNN): @@ -1703,7 +1860,8 @@ class CudnnLSTM(_CudnnRNN): direction=CUDNN_RNN_UNIDIRECTION, dtype=dtypes.float32, dropout=0., - seed=0): + seed=0, + num_proj=None): """Creates a Cudnn LSTM model from model spec. Args: @@ -1721,6 +1879,8 @@ class CudnnLSTM(_CudnnRNN): dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. + num_proj: The output dimensionality for the projection matrices. + If None or 0, no projection is performed. """ super(CudnnLSTM, self).__init__( CUDNN_LSTM, @@ -1731,7 +1891,8 @@ class CudnnLSTM(_CudnnRNN): direction=direction, dtype=dtype, dropout=dropout, - seed=seed) + seed=seed, + num_proj=num_proj) def __call__(self, input_data, diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt new file mode 100644 index 00000000000..cac3674c1ae --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "CudnnRNNCanonicalToParamsV2" + summary: "Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM." + description: <