diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index 54d9e32fb25..15ba5f78d96 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -551,10 +551,68 @@ def preprocess_weights_for_loading(layer,
   if layer.__class__.__name__ == 'ConvLSTM2D':
     weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
 
+  weights = _convert_unified_lstm_weights(layer, weights)
+
   # convert CuDNN layers
   return _convert_rnn_weights(layer, weights)
 
 
+def _convert_unified_lstm_weights(layer, weights):
+  """Converts weights for the UnifiedLSTM layer.
+
+  The input weights are expected to contain 2, 3 or 4 items:
+  1. kernel (i, f, c, o gates concatenated along axis 1).
+  2. recurrent_kernel (i, f, c, o concatenated along axis 1).
+  3. recurrent_bias (optional, only present when bias is used).
+  4. input_bias (optional, only present when bias is used with CuDNN).
+  The kernel and recurrent_kernel do not need any conversion. During load(),
+  the layer could be built with parameters that do not support the defun
+  approach, so the cudnn_bias variable may not be created, or may be created
+  but never used during the actual run. Because of that, we sum the values of
+  the two biases and give the result to recurrent_bias only. Mathematically,
+  the LSTM is computed with the following formulas:
+
+  i_t = sigmoid(w_i * x_t + r_i * h_(t-1) + b_wi + b_ri)
+  f_t = sigmoid(w_f * x_t + r_f * h_(t-1) + b_wf + b_rf)
+  o_t = sigmoid(w_o * x_t + r_o * h_(t-1) + b_wo + b_ro)
+  c'_t = tanh(w_c * x_t + r_c * h_(t-1) + b_wc + b_rc)
+  c_t = f_t . c_(t-1) + i_t . c'_t
+  h_t = o_t . tanh(c_t)
+
+  Note that b_w{x} is the input_bias and b_r{x} is the recurrent_bias. Since
+  the biases enter as a linear sum, it is fine to give b_r{x} 100% and b_w{x}
+  0%, as long as the sum stays the same.
+
+  Args:
+    layer: The Keras layer that will be loaded with weights.
+    weights: The list of numpy arrays holding the weights to be loaded.
+
+  Returns:
+    weights: The processed list of numpy arrays.
+  """
+  if layer.__class__.__name__ == 'UnifiedLSTM':
+    if len(weights) not in [3, 4]:
+      # This function only handles the bias conversion. When bias is not used
+      # or the weights have an unexpected length, do nothing and return.
+      return weights
+
+    if len(weights) == 3:
+      recurrent_bias = weights[2]
+    else:
+      # Add all the bias values into recurrent_bias.
+      recurrent_bias = weights[2] + weights[3]
+
+    if len(layer.weights) == 3:
+      weights = weights[:2] + [recurrent_bias]
+    elif len(layer.weights) == 4:
+      # Create a zero-filled input_bias, since all the bias values have been
+      # given to recurrent_bias.
+      input_bias = np.zeros_like(recurrent_bias)
+      weights = weights[:2] + [recurrent_bias, input_bias]
+
+  return weights
+
+
 def _convert_rnn_weights(layer, weights):
   """Converts weights for RNN layers between native and CuDNN format.
 
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 6d9d9a2fcae..8fcefce748f 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -221,6 +221,70 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
         for (x, y) in zip(weights1, weights2)
     ]
 
+  @parameterized.named_parameters(
+      # test_name, use_bias, bias_initializer, activation
+      ('normal', True, 'zeros', 'tanh'),
+      ('no_bias', False, 'zeros', 'tanh'),
+      # TODO(scottzhu): Reenable this test case when the approach is decided.
+      # ('random_bias', True, 'random_uniform', 'tanh'),
+      ('no_cudnn_bias', True, 'zeros', 'relu')
+  )
+  def test_process_weights_for_loading_unified_lstm(
+      self, use_bias, bias_initializer, activation):
+    if h5py is None:
+      return
+
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir)
+    h5_path = os.path.join(temp_dir, 'test.h5')
+
+    batch = 10
+    timestep = 3
+    input_dim = 5
+    units = 2
+
+    x = np.random.random((batch, timestep, input_dim))
+
+    def build_model():
+      inputs = keras.layers.Input(
+          shape=[timestep, input_dim], dtype=dtypes.float32)
+      layer = keras.layers.UnifiedLSTM(
+          units,
+          activation=activation,
+          use_bias=use_bias,
+          bias_initializer=bias_initializer)
+      output = layer(inputs)
+      return keras.models.Model(inputs, output), layer
+
+    with self.cached_session():
+      model, layer = build_model()
+      y_ref = model.predict(x)
+      model.save_weights(h5_path)
+
+      cloned_model, new_layer = build_model()
+      cloned_model.load_weights(h5_path)
+      y = cloned_model.predict(x)
+
+      self.assertAllClose(y, y_ref)
+
+      # Test the individual layer weights.
+      weights1 = layer.get_weights()
+      weights2 = new_layer.get_weights()
+      self.assertLen(weights1, len(weights2))
+      # The kernel and recurrent kernel should be the same.
+      self.assertAllClose(weights1[:2], weights2[:2])
+
+      if len(weights2) >= 3:
+        # Test the recurrent bias.
+        expected_recurrent_bias = weights1[2]
+        if len(weights1) == 4:
+          expected_recurrent_bias += weights1[3]
+        self.assertAllClose(weights2[2], expected_recurrent_bias)
+
+      if len(weights2) == 4:
+        # Test that the recovered input_bias is always zero.
+        self.assertAllClose(weights2[3], np.zeros_like(weights1[3]))
+
   def test_sequential_weight_loading(self):
     if h5py is None:
       return
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 49990b6bf4f..df7571e5d5f 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -149,6 +149,7 @@ from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
 from tensorflow.python.keras.layers.recurrent import SimpleRNN
 from tensorflow.python.keras.layers.recurrent import GRU
 from tensorflow.python.keras.layers.recurrent import LSTM
+from tensorflow.python.keras.layers.recurrent import UnifiedLSTM
 
 # Convolutional-recurrent layers.
 from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
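
Note (not part of the patch): the bias-folding step introduced in _convert_unified_lstm_weights can be sanity-checked in isolation with plain numpy. The snippet below is a minimal sketch; the shapes and the names units, input_dim, saved_weights and loaded_weights are illustrative only. It mimics the CuDNN-trained case, where both recurrent_bias and input_bias were saved and the target layer also keeps four weight variables:

    import numpy as np

    units = 2
    input_dim = 5

    # Saved CuDNN-style weights: kernel, recurrent_kernel, recurrent_bias,
    # input_bias. Values are random placeholders.
    saved_weights = [
        np.random.random((input_dim, 4 * units)),  # kernel
        np.random.random((units, 4 * units)),      # recurrent_kernel
        np.random.random((4 * units,)),            # recurrent_bias
        np.random.random((4 * units,)),            # input_bias
    ]

    # Fold both biases into recurrent_bias and zero out input_bias; the gate
    # equations only ever see the sum b_w + b_r, so the math is unchanged.
    merged_bias = saved_weights[2] + saved_weights[3]
    loaded_weights = saved_weights[:2] + [merged_bias,
                                          np.zeros_like(merged_bias)]

    # The total bias contribution to each gate is identical before and after.
    np.testing.assert_allclose(saved_weights[2] + saved_weights[3],
                               loaded_weights[2] + loaded_weights[3])

Because the two biases only enter the gate pre-activations as their sum, handing the entire sum to recurrent_bias and a zero-filled array to input_bias mirrors what _convert_unified_lstm_weights does in the four-weight case above.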