Fix conv_lstm and test for TF 2.0.

PiperOrigin-RevId: 235622678
Authored by Scott Zhu on 2019-02-25 16:21:01 -08:00; committed by TensorFlower Gardener
parent 82cf60d9d7
commit 391bee7364
3 changed files with 71 additions and 82 deletions

tensorflow/python/keras/BUILD

@@ -563,7 +563,7 @@ tf_py_test(
 tf_py_test(
     name = "convolutional_recurrent_test",
-    size = "large",
+    size = "medium",
     srcs = ["layers/convolutional_recurrent_test.py"],
     additional_deps = [
         ":keras",
@@ -571,7 +571,7 @@ tf_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 2,
+    shard_count = 4,
 )

 cuda_py_test(

tensorflow/python/keras/layers/convolutional_recurrent.py

@@ -34,6 +34,7 @@ from tensorflow.python.keras.layers.recurrent import RNN
 from tensorflow.python.keras.utils import conv_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import keras_export
@@ -272,7 +273,7 @@ class ConvRNN2D(RNN):
     shape = list(self.cell.kernel_shape)
     shape[-1] = self.cell.filters
     initial_state = self.cell.input_conv(initial_state,
-                                         K.zeros(tuple(shape)),
+                                         array_ops.zeros(tuple(shape)),
                                          padding=self.cell.padding)

     if hasattr(self.cell.state_size, '__len__'):
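Note: get_initial_state infers the state shape by pushing an all-zeros tensor through the cell's input convolution with an all-zeros kernel. K.zeros from the Keras backend instantiates a variable, while array_ops.zeros builds a plain zeros tensor, which is why the probe is switched for TF 2.0. A minimal standalone sketch of the same idea (shapes are made up, and this is not the layer's actual code path):

    import tensorflow as tf

    # Hypothetical shapes: a 5x5 input with 3 channels, a 3x3 zero kernel
    # producing 2 filters -- mirroring how the state shape is probed.
    inputs = tf.zeros((1, 5, 5, 3))     # (batch, rows, cols, channels)
    kernel = tf.zeros((3, 3, 3, 2))     # zero-filled kernel, 2 output filters
    state = tf.nn.conv2d(inputs, kernel, strides=1, padding='VALID')
    print(state.shape)                  # (1, 3, 3, 2): per-sample initial state shape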
@@ -625,31 +626,8 @@ class ConvLSTM2DCell(Layer):
           initializer=bias_initializer,
           regularizer=self.bias_regularizer,
           constraint=self.bias_constraint)
     else:
       self.bias = None
-
-    self.kernel_i = self.kernel[:, :, :, :self.filters]
-    self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
-    self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
-    self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
-                                                    self.filters * 2]
-    self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
-    self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
-                                                    self.filters * 3]
-    self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
-    self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]
-
-    if self.use_bias:
-      self.bias_i = self.bias[:self.filters]
-      self.bias_f = self.bias[self.filters: self.filters * 2]
-      self.bias_c = self.bias[self.filters * 2: self.filters * 3]
-      self.bias_o = self.bias[self.filters * 3:]
-    else:
-      self.bias_i = None
-      self.bias_f = None
-      self.bias_c = None
-      self.bias_o = None
     self.built = True

   def call(self, inputs, states, training=None):
@@ -697,22 +675,26 @@ class ConvLSTM2DCell(Layer):
       h_tm1_c = h_tm1
       h_tm1_o = h_tm1

-    x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
-                          padding=self.padding)
-    x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
-                          padding=self.padding)
-    x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
-                          padding=self.padding)
-    x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
-                          padding=self.padding)
-    h_i = self.recurrent_conv(h_tm1_i,
-                              self.recurrent_kernel_i)
-    h_f = self.recurrent_conv(h_tm1_f,
-                              self.recurrent_kernel_f)
-    h_c = self.recurrent_conv(h_tm1_c,
-                              self.recurrent_kernel_c)
-    h_o = self.recurrent_conv(h_tm1_o,
-                              self.recurrent_kernel_o)
+    (kernel_i, kernel_f,
+     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
+    (recurrent_kernel_i,
+     recurrent_kernel_f,
+     recurrent_kernel_c,
+     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)
+
+    if self.use_bias:
+      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
+    else:
+      bias_i, bias_f, bias_c, bias_o = None, None, None, None
+
+    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
+    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
+    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
+    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
+    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
+    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
+    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
+    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

     i = self.recurrent_activation(x_i + h_i)
     f = self.recurrent_activation(x_f + h_f)
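Note: the per-gate kernel and bias slices that used to be stored in build() are now recomputed inside call() by splitting the fused weights. Under eager execution, a slice taken once in build() is a snapshot of the variable rather than a live view, so splitting at call time keeps the gate kernels in sync with the trainable weights. A minimal standalone sketch (arbitrary shapes, not the cell's own code) showing that the split is equivalent to the old explicit slicing:

    import tensorflow as tf

    filters = 2
    # Fused ConvLSTM kernel: (rows, cols, in_channels, 4 * filters), one block per gate.
    kernel = tf.random.normal((3, 3, 4, 4 * filters))

    k_i, k_f, k_c, k_o = tf.split(kernel, 4, axis=3)
    print(k_i.shape)  # (3, 3, 4, 2)

    # Same tensors as the slices that used to live in build(), e.g.:
    tf.debugging.assert_equal(k_i, kernel[:, :, :, :filters])
    tf.debugging.assert_equal(k_f, kernel[:, :, :, filters:filters * 2])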

tensorflow/python/keras/layers/convolutional_recurrent_test.py

@@ -18,16 +18,24 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.python import keras
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.platform import test


-class ConvLSTMTest(test.TestCase):
+@keras_parameterized.run_all_keras_modes
+class ConvLSTMTest(keras_parameterized.TestCase):

-  def test_conv_lstm(self):
+  @parameterized.named_parameters(
+      *test_util.generate_combinations_with_testcase_name(
+          data_format=['channels_first', 'channels_last'],
+          return_sequences=[True, False]))
+  def test_conv_lstm(self, data_format, return_sequences):
     num_row = 3
     num_col = 3
     filters = 2
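Note: run_all_keras_modes reruns each test method under the supported Keras execution modes, and generate_combinations_with_testcase_name expands the keyword lists into one named test case per (data_format, return_sequences) pair, replacing the nested for-loops the old test body used. A hand-written sketch of roughly equivalent cases (the class, case names, and assertions below are made up for illustration, not what the generator produces):

    from absl.testing import parameterized

    from tensorflow.python.platform import test


    class ConvLSTMCombinationsTest(test.TestCase, parameterized.TestCase):

      @parameterized.named_parameters(
          ('_channels_first_seq', 'channels_first', True),
          ('_channels_first_noseq', 'channels_first', False),
          ('_channels_last_seq', 'channels_last', True),
          ('_channels_last_noseq', 'channels_last', False))
      def test_combination(self, data_format, return_sequences):
        # Each tuple becomes its own test case, e.g. test_combination_channels_first_seq.
        self.assertIn(data_format, ('channels_first', 'channels_last'))
        self.assertIsInstance(return_sequences, bool)


    if __name__ == '__main__':
      test.main()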
@@ -36,47 +44,44 @@ class ConvLSTMTest(test.TestCase):
     input_num_row = 5
     input_num_col = 5
     sequence_len = 2
-    for data_format in ['channels_first', 'channels_last']:
-      if data_format == 'channels_first':
-        inputs = np.random.rand(num_samples, sequence_len,
-                                input_channel,
-                                input_num_row, input_num_col)
-      else:
-        inputs = np.random.rand(num_samples, sequence_len,
-                                input_num_row, input_num_col,
-                                input_channel)
-
-      for return_sequences in [True, False]:
-        with self.cached_session():
-          # test for return state:
-          x = keras.Input(batch_shape=inputs.shape)
-          kwargs = {'data_format': data_format,
-                    'return_sequences': return_sequences,
-                    'return_state': True,
-                    'stateful': True,
-                    'filters': filters,
-                    'kernel_size': (num_row, num_col),
-                    'padding': 'valid'}
-          layer = keras.layers.ConvLSTM2D(**kwargs)
-          layer.build(inputs.shape)
-          outputs = layer(x)
-          _, states = outputs[0], outputs[1:]
-          self.assertEqual(len(states), 2)
-          model = keras.models.Model(x, states[0])
-          state = model.predict(inputs)
+    if data_format == 'channels_first':
+      inputs = np.random.rand(num_samples, sequence_len,
+                              input_channel,
+                              input_num_row, input_num_col)
+    else:
+      inputs = np.random.rand(num_samples, sequence_len,
+                              input_num_row, input_num_col,
+                              input_channel)
+
+    # test for return state:
+    x = keras.Input(batch_shape=inputs.shape)
+    kwargs = {'data_format': data_format,
+              'return_sequences': return_sequences,
+              'return_state': True,
+              'stateful': True,
+              'filters': filters,
+              'kernel_size': (num_row, num_col),
+              'padding': 'valid'}
+    layer = keras.layers.ConvLSTM2D(**kwargs)
+    layer.build(inputs.shape)
+    outputs = layer(x)
+    _, states = outputs[0], outputs[1:]
+    self.assertEqual(len(states), 2)
+    model = keras.models.Model(x, states[0])
+    state = model.predict(inputs)

     self.assertAllClose(
         keras.backend.eval(layer.states[0]), state, atol=1e-4)

     # test for output shape:
     testing_utils.layer_test(
         keras.layers.ConvLSTM2D,
         kwargs={'data_format': data_format,
                 'return_sequences': return_sequences,
                 'filters': filters,
                 'kernel_size': (num_row, num_col),
                 'padding': 'valid'},
         input_shape=inputs.shape)

   def test_conv_lstm_statefulness(self):
     # Tests for statefulness
@@ -167,6 +172,8 @@ class ConvLSTMTest(test.TestCase):
     self.assertEqual(len(layer.losses), 4)

   def test_conv_lstm_dropout(self):
+    if testing_utils.should_run_eagerly():
+      self.skipTest('Skip test due to b/126246383.')
     # check dropout
     with self.cached_session():
       testing_utils.layer_test(
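Note: as a quick sanity check of what this commit targets, here is a minimal standalone usage sketch (shapes and hyperparameters are arbitrary) that exercises ConvLSTM2D through the public API, which should behave the same whether TF 2.x runs it eagerly or inside a graph function:

    import numpy as np
    import tensorflow as tf

    # (batch, time, rows, cols, channels) input for the default channels_last format.
    x = np.random.rand(2, 4, 8, 8, 3).astype('float32')

    layer = tf.keras.layers.ConvLSTM2D(filters=5, kernel_size=(3, 3),
                                       padding='valid', return_sequences=False)
    y = layer(x)
    print(y.shape)  # (2, 6, 6, 5): 'valid' padding trims kernel_size - 1 rows/cols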