From 4b914607fb8384deb22e5b822dca2624e1758718 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 4 Dec 2018 15:48:05 -0800 Subject: [PATCH] Update the condition for using defun backend in LSTM. Currently defun backend is disabled when the cudnn implementation doesn't support certain behavior, which end up different results between different implementation. 1. The cudnn backend only support 'tanh' as activation. 2. Cudnn does not support recurrent dropout. 3. Cudnn LSTM cannot unroll. 4. Cudnn always use bias gate. 5. If a bias regularizer is specified, it will cause some mathematical difference when save/reload the weight. We disable defun in this case as well. PiperOrigin-RevId: 224060654 --- tensorflow/python/keras/layers/recurrent.py | 53 ++++++++++++++--- .../python/keras/layers/unified_lstm_test.py | 57 +++++++++++++++++++ 2 files changed, 102 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index d9502dfc5b7..189ad987942 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -2661,9 +2661,10 @@ class UnifiedLSTM(LSTM): ] self._num_constants = None self._num_inputs = None + self._dropout_mask = None self.could_use_cudnn = ( - activation == 'tanh' and dropout == 0 and not unroll and use_bias and - unit_forget_bias) + activation == 'tanh' and recurrent_dropout == 0 and + not unroll and use_bias and bias_regularizer is None) def build(self, input_shape): super(UnifiedLSTM, self).build(input_shape) @@ -2722,6 +2723,16 @@ class UnifiedLSTM(LSTM): combined_bias = array_ops.concat([self.cudnn_bias, self.cell.bias], 0) + if 0 < self.dropout < 1: + if self._dropout_mask is None: + self._dropout_mask = _generate_dropout_mask( + array_ops.ones_like(inputs), + self.dropout, + training=training, + count=4) + + inputs *= self._dropout_mask[0] + # Each time a defun function is called, we will give a unique identifiable # API name, so that the grappler won't get confused when it sees multiple # LSTM layer added into same graph, and it will be able to pair up the @@ -2835,9 +2846,33 @@ class UnifiedLSTM(LSTM): K.batch_set_value(tuples) -def _canonical_to_params(weights, biases, shape): - """Utility function convert variable to CuDNN compatible parameter.""" - weights = [array_ops.reshape(x, shape) for x in weights] +def _canonical_to_params(weights, biases, shape, transpose_weights=False): + """Utility function convert variable to CuDNN compatible parameter. + + Note that Keras weights for kernels are different from the CuDNN format. Eg.: + + ``` + Keras CuDNN + [[0, 1, 2], <---> [[0, 2, 4], + [3, 4, 5]] [1, 3, 5]] + ``` + + If the input weights need to be in a unified format, then set + `transpose_weights=True` to convert the weights. + + Args: + weights: list of weights for the individual kernels and recurrent kernels. + biases: list of biases for individual gate. + shape: the shape for the converted variables that will be feed to CuDNN. + transpose_weights: boolean, whether to transpose the weights. + + Returns: + The converted weights that can be feed to CuDNN ops as param. + """ + def convert(w): + return array_ops.transpose(w) if transpose_weights else w + + weights = [array_ops.reshape(convert(x), shape) for x in weights] biases = [array_ops.reshape(x, shape) for x in biases] return array_ops.concat(weights + biases, axis=0) @@ -2930,15 +2965,17 @@ def cudnn_lstm(inputs, input_h, input_c, kernel, recurrent_kernel, bias, params = _canonical_to_params( weights=weights, biases=array_ops.split(bias, 8), - shape=constant_op.constant([-1])) + shape=constant_op.constant([-1]), + transpose_weights=True) outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn( - inputs, input_h=input_h, input_c=input_c, params=params) + inputs, input_h=input_h, input_c=input_c, params=params, is_training=True) + last_output = outputs[-1] if not time_major: outputs = array_ops.transpose(outputs, perm=[1, 0, 2]) h = h[0] c = c[0] - last_output = outputs[:, -1, :] + return last_output, outputs, h, c, constant_op.constant( 'cudnn', dtype=dtypes.string, name='runtime') diff --git a/tensorflow/python/keras/layers/unified_lstm_test.py b/tensorflow/python/keras/layers/unified_lstm_test.py index d229d14312f..b004284140c 100644 --- a/tensorflow/python/keras/layers/unified_lstm_test.py +++ b/tensorflow/python/keras/layers/unified_lstm_test.py @@ -157,6 +157,63 @@ class UnifiedLSTMTest(test.TestCase, parameterized.TestCase): self.assertNotEqual(existing_loss, loss_value) existing_loss = loss_value + @parameterized.named_parameters( + ('_non_tan_activation', 'relu', 0, False, True, None), + ('_use_recurrent_dropout', 'tanh', 0.1, False, True, None), + ('_unroll', 'tanh', 0, True, True, None), + ('_not_use_bias', 'tanh', 0, False, False, None), + ('_use_bias_regularizer', 'tanh', 0, False, True, 'l2') + ) + @test_util.run_in_graph_and_eager_modes(config=_config) + def test_could_use_defun_backend(self, activation, recurrent_dropout, + unroll, use_bias, bias_regularizer): + layer = UnifiedLSTM(1, + activation=activation, + recurrent_dropout=recurrent_dropout, + unroll=unroll, + use_bias=use_bias, + bias_regularizer=bias_regularizer) + self.assertFalse(layer.could_use_cudnn) + + @test_util.run_in_graph_and_eager_modes(config=_config) + def test_unified_lstm_output_on_multiple_kernel(self): + input_shape = 10 + rnn_state_size = 8 + timestep = 4 + batch = 100 + + x_train = np.random.random((batch, timestep, input_shape)) + + inputs = keras.layers.Input( + shape=[timestep, input_shape], dtype=dtypes.float32) + with test_util.device(use_gpu=False): + # Note that CuDNN use 'sigmoid' as activation. Force the CPU + # implementation to use 'sigmoid' so that it will generate same output as + # CuDNN implementation. + layer = UnifiedLSTM(rnn_state_size, recurrent_activation='sigmoid') + output = layer(inputs) + cpu_model = keras.models.Model(inputs, output) + weights = cpu_model.get_weights() + y_1 = cpu_model.predict(x_train) + + with test_util.device(use_gpu=True): + layer = UnifiedLSTM(rnn_state_size, recurrent_activation='sigmoid') + output = layer(inputs) + gpu_model = keras.models.Model(inputs, output) + gpu_model.set_weights(weights) + y_2 = gpu_model.predict(x_train) + + with test_util.device(use_gpu=True): + layer = keras.layers.LSTM(rnn_state_size, recurrent_activation='sigmoid') + output = layer(inputs) + canonical_model = keras.models.Model(inputs, output) + # Remove the extra cudnn bias since canonical lstm will not use it. + canonical_model.set_weights(weights[:3]) + y_3 = canonical_model.predict(x_train) + + self.assertAllClose(y_1, y_2) + self.assertAllClose(y_2, y_3) + @test_util.run_in_graph_and_eager_modes(config=_config) def test_keras_model_with_lstm(self): input_shape = 10