Fix conv_lstm and test for TF 2.0.

PiperOrigin-RevId: 235622678
This commit is contained in:
Scott Zhu 2019-02-25 16:21:01 -08:00 committed by TensorFlower Gardener
parent 82cf60d9d7
commit 391bee7364
3 changed files with 71 additions and 82 deletions

View File

@@ -563,7 +563,7 @@ tf_py_test(
tf_py_test(
name = "convolutional_recurrent_test",
size = "large",
size = "medium",
srcs = ["layers/convolutional_recurrent_test.py"],
additional_deps = [
":keras",
@@ -571,7 +571,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
shard_count = 2,
shard_count = 4,
)
cuda_py_test(

View File

@@ -34,6 +34,7 @@ from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import keras_export
@@ -272,7 +273,7 @@ class ConvRNN2D(RNN):
shape = list(self.cell.kernel_shape)
shape[-1] = self.cell.filters
initial_state = self.cell.input_conv(initial_state,
K.zeros(tuple(shape)),
array_ops.zeros(tuple(shape)),
padding=self.cell.padding)
if hasattr(self.cell.state_size, '__len__'):
@@ -625,31 +626,8 @@ class ConvLSTM2DCell(Layer):
initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
self.kernel_i = self.kernel[:, :, :, :self.filters]
self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
self.filters * 2]
self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
self.filters * 3]
self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]
if self.use_bias:
self.bias_i = self.bias[:self.filters]
self.bias_f = self.bias[self.filters: self.filters * 2]
self.bias_c = self.bias[self.filters * 2: self.filters * 3]
self.bias_o = self.bias[self.filters * 3:]
else:
self.bias_i = None
self.bias_f = None
self.bias_c = None
self.bias_o = None
self.built = True
def call(self, inputs, states, training=None):
@@ -697,22 +675,26 @@ class ConvLSTM2DCell(Layer):
h_tm1_c = h_tm1
h_tm1_o = h_tm1
x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
padding=self.padding)
x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
padding=self.padding)
x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
padding=self.padding)
x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
padding=self.padding)
h_i = self.recurrent_conv(h_tm1_i,
self.recurrent_kernel_i)
h_f = self.recurrent_conv(h_tm1_f,
self.recurrent_kernel_f)
h_c = self.recurrent_conv(h_tm1_c,
self.recurrent_kernel_c)
h_o = self.recurrent_conv(h_tm1_o,
self.recurrent_kernel_o)
(kernel_i, kernel_f,
kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
(recurrent_kernel_i,
recurrent_kernel_f,
recurrent_kernel_c,
recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)
if self.use_bias:
bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
else:
bias_i, bias_f, bias_c, bias_o = None, None, None, None
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
i = self.recurrent_activation(x_i + h_i)
f = self.recurrent_activation(x_f + h_f)

View File

@@ -18,16 +18,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
class ConvLSTMTest(test.TestCase):
@keras_parameterized.run_all_keras_modes
class ConvLSTMTest(keras_parameterized.TestCase):
def test_conv_lstm(self):
@parameterized.named_parameters(
*test_util.generate_combinations_with_testcase_name(
data_format=['channels_first', 'channels_last'],
return_sequences=[True, False]))
def test_conv_lstm(self, data_format, return_sequences):
num_row = 3
num_col = 3
filters = 2
@@ -36,47 +44,44 @@ class ConvLSTMTest(test.TestCase):
input_num_row = 5
input_num_col = 5
sequence_len = 2
for data_format in ['channels_first', 'channels_last']:
if data_format == 'channels_first':
inputs = np.random.rand(num_samples, sequence_len,
input_channel,
input_num_row, input_num_col)
else:
inputs = np.random.rand(num_samples, sequence_len,
input_num_row, input_num_col,
input_channel)
if data_format == 'channels_first':
inputs = np.random.rand(num_samples, sequence_len,
input_channel,
input_num_row, input_num_col)
else:
inputs = np.random.rand(num_samples, sequence_len,
input_num_row, input_num_col,
input_channel)
for return_sequences in [True, False]:
with self.cached_session():
# test for return state:
x = keras.Input(batch_shape=inputs.shape)
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
'return_state': True,
'stateful': True,
'filters': filters,
'kernel_size': (num_row, num_col),
'padding': 'valid'}
layer = keras.layers.ConvLSTM2D(**kwargs)
layer.build(inputs.shape)
outputs = layer(x)
_, states = outputs[0], outputs[1:]
self.assertEqual(len(states), 2)
model = keras.models.Model(x, states[0])
state = model.predict(inputs)
# test for return state:
x = keras.Input(batch_shape=inputs.shape)
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
'return_state': True,
'stateful': True,
'filters': filters,
'kernel_size': (num_row, num_col),
'padding': 'valid'}
layer = keras.layers.ConvLSTM2D(**kwargs)
layer.build(inputs.shape)
outputs = layer(x)
_, states = outputs[0], outputs[1:]
self.assertEqual(len(states), 2)
model = keras.models.Model(x, states[0])
state = model.predict(inputs)
self.assertAllClose(
keras.backend.eval(layer.states[0]), state, atol=1e-4)
self.assertAllClose(
keras.backend.eval(layer.states[0]), state, atol=1e-4)
# test for output shape:
testing_utils.layer_test(
keras.layers.ConvLSTM2D,
kwargs={'data_format': data_format,
'return_sequences': return_sequences,
'filters': filters,
'kernel_size': (num_row, num_col),
'padding': 'valid'},
input_shape=inputs.shape)
# test for output shape:
testing_utils.layer_test(
keras.layers.ConvLSTM2D,
kwargs={'data_format': data_format,
'return_sequences': return_sequences,
'filters': filters,
'kernel_size': (num_row, num_col),
'padding': 'valid'},
input_shape=inputs.shape)
def test_conv_lstm_statefulness(self):
# Tests for statefulness
@@ -167,6 +172,8 @@ class ConvLSTMTest(test.TestCase):
self.assertEqual(len(layer.losses), 4)
def test_conv_lstm_dropout(self):
if testing_utils.should_run_eagerly():
self.skipTest('Skip test due to b/126246383.')
# check dropout
with self.cached_session():
testing_utils.layer_test(