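Fix conv_lstm and its test for TF 2.0: split the gate kernels and biases inside call() instead of storing sliced attributes, and run the test in all Keras modes.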
PiperOrigin-RevId: 235622678
This commit is contained in:
		
							parent
							
								
									82cf60d9d7
								
							
						
					
					
						commit
						391bee7364
					
				| @ -563,7 +563,7 @@ tf_py_test( | ||||
| 
 | ||||
| tf_py_test( | ||||
|     name = "convolutional_recurrent_test", | ||||
|     size = "large", | ||||
|     size = "medium", | ||||
|     srcs = ["layers/convolutional_recurrent_test.py"], | ||||
|     additional_deps = [ | ||||
|         ":keras", | ||||
| @ -571,7 +571,7 @@ tf_py_test( | ||||
|         "//third_party/py/numpy", | ||||
|         "//tensorflow/python:client_testlib", | ||||
|     ], | ||||
|     shard_count = 2, | ||||
|     shard_count = 4, | ||||
| ) | ||||
| 
 | ||||
| cuda_py_test( | ||||
|  | ||||
| @ -34,6 +34,7 @@ from tensorflow.python.keras.layers.recurrent import RNN | ||||
| from tensorflow.python.keras.utils import conv_utils | ||||
| from tensorflow.python.keras.utils import generic_utils | ||||
| from tensorflow.python.keras.utils import tf_utils | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.util.tf_export import keras_export | ||||
| 
 | ||||
| 
 | ||||
| @ -272,7 +273,7 @@ class ConvRNN2D(RNN): | ||||
|     shape = list(self.cell.kernel_shape) | ||||
|     shape[-1] = self.cell.filters | ||||
|     initial_state = self.cell.input_conv(initial_state, | ||||
|                                          K.zeros(tuple(shape)), | ||||
|                                          array_ops.zeros(tuple(shape)), | ||||
|                                          padding=self.cell.padding) | ||||
| 
 | ||||
|     if hasattr(self.cell.state_size, '__len__'): | ||||
| @ -625,31 +626,8 @@ class ConvLSTM2DCell(Layer): | ||||
|           initializer=bias_initializer, | ||||
|           regularizer=self.bias_regularizer, | ||||
|           constraint=self.bias_constraint) | ||||
| 
 | ||||
|     else: | ||||
|       self.bias = None | ||||
| 
 | ||||
|     self.kernel_i = self.kernel[:, :, :, :self.filters] | ||||
|     self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters] | ||||
|     self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2] | ||||
|     self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters: | ||||
|                                                     self.filters * 2] | ||||
|     self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3] | ||||
|     self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2: | ||||
|                                                     self.filters * 3] | ||||
|     self.kernel_o = self.kernel[:, :, :, self.filters * 3:] | ||||
|     self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:] | ||||
| 
 | ||||
|     if self.use_bias: | ||||
|       self.bias_i = self.bias[:self.filters] | ||||
|       self.bias_f = self.bias[self.filters: self.filters * 2] | ||||
|       self.bias_c = self.bias[self.filters * 2: self.filters * 3] | ||||
|       self.bias_o = self.bias[self.filters * 3:] | ||||
|     else: | ||||
|       self.bias_i = None | ||||
|       self.bias_f = None | ||||
|       self.bias_c = None | ||||
|       self.bias_o = None | ||||
|     self.built = True | ||||
| 
 | ||||
|   def call(self, inputs, states, training=None): | ||||
| @ -697,22 +675,26 @@ class ConvLSTM2DCell(Layer): | ||||
|       h_tm1_c = h_tm1 | ||||
|       h_tm1_o = h_tm1 | ||||
| 
 | ||||
|     x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i, | ||||
|                           padding=self.padding) | ||||
|     x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f, | ||||
|                           padding=self.padding) | ||||
|     x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c, | ||||
|                           padding=self.padding) | ||||
|     x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o, | ||||
|                           padding=self.padding) | ||||
|     h_i = self.recurrent_conv(h_tm1_i, | ||||
|                               self.recurrent_kernel_i) | ||||
|     h_f = self.recurrent_conv(h_tm1_f, | ||||
|                               self.recurrent_kernel_f) | ||||
|     h_c = self.recurrent_conv(h_tm1_c, | ||||
|                               self.recurrent_kernel_c) | ||||
|     h_o = self.recurrent_conv(h_tm1_o, | ||||
|                               self.recurrent_kernel_o) | ||||
|     (kernel_i, kernel_f, | ||||
|      kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3) | ||||
|     (recurrent_kernel_i, | ||||
|      recurrent_kernel_f, | ||||
|      recurrent_kernel_c, | ||||
|      recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3) | ||||
| 
 | ||||
|     if self.use_bias: | ||||
|       bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4) | ||||
|     else: | ||||
|       bias_i, bias_f, bias_c, bias_o = None, None, None, None | ||||
| 
 | ||||
|     x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding) | ||||
|     x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding) | ||||
|     x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding) | ||||
|     x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding) | ||||
|     h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i) | ||||
|     h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f) | ||||
|     h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c) | ||||
|     h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o) | ||||
| 
 | ||||
|     i = self.recurrent_activation(x_i + h_i) | ||||
|     f = self.recurrent_activation(x_f + h_f) | ||||
|  | ||||
| @ -18,16 +18,24 @@ from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from absl.testing import parameterized | ||||
| import numpy as np | ||||
| 
 | ||||
| from tensorflow.python import keras | ||||
| from tensorflow.python.framework import test_util | ||||
| from tensorflow.python.keras import keras_parameterized | ||||
| from tensorflow.python.keras import testing_utils | ||||
| from tensorflow.python.platform import test | ||||
| 
 | ||||
| 
 | ||||
| class ConvLSTMTest(test.TestCase): | ||||
| @keras_parameterized.run_all_keras_modes | ||||
| class ConvLSTMTest(keras_parameterized.TestCase): | ||||
| 
 | ||||
|   def test_conv_lstm(self): | ||||
|   @parameterized.named_parameters( | ||||
|       *test_util.generate_combinations_with_testcase_name( | ||||
|           data_format=['channels_first', 'channels_last'], | ||||
|           return_sequences=[True, False])) | ||||
|   def test_conv_lstm(self, data_format, return_sequences): | ||||
|     num_row = 3 | ||||
|     num_col = 3 | ||||
|     filters = 2 | ||||
| @ -36,7 +44,6 @@ class ConvLSTMTest(test.TestCase): | ||||
|     input_num_row = 5 | ||||
|     input_num_col = 5 | ||||
|     sequence_len = 2 | ||||
|     for data_format in ['channels_first', 'channels_last']: | ||||
|     if data_format == 'channels_first': | ||||
|       inputs = np.random.rand(num_samples, sequence_len, | ||||
|                               input_channel, | ||||
| @ -46,8 +53,6 @@ class ConvLSTMTest(test.TestCase): | ||||
|                               input_num_row, input_num_col, | ||||
|                               input_channel) | ||||
| 
 | ||||
|       for return_sequences in [True, False]: | ||||
|         with self.cached_session(): | ||||
|     # test for return state: | ||||
|     x = keras.Input(batch_shape=inputs.shape) | ||||
|     kwargs = {'data_format': data_format, | ||||
| @ -167,6 +172,8 @@ class ConvLSTMTest(test.TestCase): | ||||
|       self.assertEqual(len(layer.losses), 4) | ||||
| 
 | ||||
|   def test_conv_lstm_dropout(self): | ||||
|     if testing_utils.should_run_eagerly(): | ||||
|       self.skipTest('Skip test due to b/126246383.') | ||||
|     # check dropout | ||||
|     with self.cached_session(): | ||||
|       testing_utils.layer_test( | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user