Fix conv_lstm and test for TF 2.0.
PiperOrigin-RevId: 235622678
This commit is contained in:
parent
82cf60d9d7
commit
391bee7364
@ -563,7 +563,7 @@ tf_py_test(
|
||||
|
||||
tf_py_test(
|
||||
name = "convolutional_recurrent_test",
|
||||
size = "large",
|
||||
size = "medium",
|
||||
srcs = ["layers/convolutional_recurrent_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
@ -571,7 +571,7 @@ tf_py_test(
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
shard_count = 2,
|
||||
shard_count = 4,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.keras.layers.recurrent import RNN
|
||||
from tensorflow.python.keras.utils import conv_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -272,7 +273,7 @@ class ConvRNN2D(RNN):
|
||||
shape = list(self.cell.kernel_shape)
|
||||
shape[-1] = self.cell.filters
|
||||
initial_state = self.cell.input_conv(initial_state,
|
||||
K.zeros(tuple(shape)),
|
||||
array_ops.zeros(tuple(shape)),
|
||||
padding=self.cell.padding)
|
||||
|
||||
if hasattr(self.cell.state_size, '__len__'):
|
||||
@ -625,31 +626,8 @@ class ConvLSTM2DCell(Layer):
|
||||
initializer=bias_initializer,
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.kernel_i = self.kernel[:, :, :, :self.filters]
|
||||
self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
|
||||
self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
|
||||
self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
|
||||
self.filters * 2]
|
||||
self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
|
||||
self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
|
||||
self.filters * 3]
|
||||
self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
|
||||
self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]
|
||||
|
||||
if self.use_bias:
|
||||
self.bias_i = self.bias[:self.filters]
|
||||
self.bias_f = self.bias[self.filters: self.filters * 2]
|
||||
self.bias_c = self.bias[self.filters * 2: self.filters * 3]
|
||||
self.bias_o = self.bias[self.filters * 3:]
|
||||
else:
|
||||
self.bias_i = None
|
||||
self.bias_f = None
|
||||
self.bias_c = None
|
||||
self.bias_o = None
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, states, training=None):
|
||||
@ -697,22 +675,26 @@ class ConvLSTM2DCell(Layer):
|
||||
h_tm1_c = h_tm1
|
||||
h_tm1_o = h_tm1
|
||||
|
||||
x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
|
||||
padding=self.padding)
|
||||
x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
|
||||
padding=self.padding)
|
||||
x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
|
||||
padding=self.padding)
|
||||
x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
|
||||
padding=self.padding)
|
||||
h_i = self.recurrent_conv(h_tm1_i,
|
||||
self.recurrent_kernel_i)
|
||||
h_f = self.recurrent_conv(h_tm1_f,
|
||||
self.recurrent_kernel_f)
|
||||
h_c = self.recurrent_conv(h_tm1_c,
|
||||
self.recurrent_kernel_c)
|
||||
h_o = self.recurrent_conv(h_tm1_o,
|
||||
self.recurrent_kernel_o)
|
||||
(kernel_i, kernel_f,
|
||||
kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
|
||||
(recurrent_kernel_i,
|
||||
recurrent_kernel_f,
|
||||
recurrent_kernel_c,
|
||||
recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)
|
||||
|
||||
if self.use_bias:
|
||||
bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
|
||||
else:
|
||||
bias_i, bias_f, bias_c, bias_o = None, None, None, None
|
||||
|
||||
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
|
||||
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
|
||||
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
|
||||
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
|
||||
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
|
||||
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
|
||||
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
|
||||
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
|
||||
|
||||
i = self.recurrent_activation(x_i + h_i)
|
||||
f = self.recurrent_activation(x_f + h_f)
|
||||
|
@ -18,16 +18,24 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ConvLSTMTest(test.TestCase):
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class ConvLSTMTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_conv_lstm(self):
|
||||
@parameterized.named_parameters(
|
||||
*test_util.generate_combinations_with_testcase_name(
|
||||
data_format=['channels_first', 'channels_last'],
|
||||
return_sequences=[True, False]))
|
||||
def test_conv_lstm(self, data_format, return_sequences):
|
||||
num_row = 3
|
||||
num_col = 3
|
||||
filters = 2
|
||||
@ -36,47 +44,44 @@ class ConvLSTMTest(test.TestCase):
|
||||
input_num_row = 5
|
||||
input_num_col = 5
|
||||
sequence_len = 2
|
||||
for data_format in ['channels_first', 'channels_last']:
|
||||
if data_format == 'channels_first':
|
||||
inputs = np.random.rand(num_samples, sequence_len,
|
||||
input_channel,
|
||||
input_num_row, input_num_col)
|
||||
else:
|
||||
inputs = np.random.rand(num_samples, sequence_len,
|
||||
input_num_row, input_num_col,
|
||||
input_channel)
|
||||
if data_format == 'channels_first':
|
||||
inputs = np.random.rand(num_samples, sequence_len,
|
||||
input_channel,
|
||||
input_num_row, input_num_col)
|
||||
else:
|
||||
inputs = np.random.rand(num_samples, sequence_len,
|
||||
input_num_row, input_num_col,
|
||||
input_channel)
|
||||
|
||||
for return_sequences in [True, False]:
|
||||
with self.cached_session():
|
||||
# test for return state:
|
||||
x = keras.Input(batch_shape=inputs.shape)
|
||||
kwargs = {'data_format': data_format,
|
||||
'return_sequences': return_sequences,
|
||||
'return_state': True,
|
||||
'stateful': True,
|
||||
'filters': filters,
|
||||
'kernel_size': (num_row, num_col),
|
||||
'padding': 'valid'}
|
||||
layer = keras.layers.ConvLSTM2D(**kwargs)
|
||||
layer.build(inputs.shape)
|
||||
outputs = layer(x)
|
||||
_, states = outputs[0], outputs[1:]
|
||||
self.assertEqual(len(states), 2)
|
||||
model = keras.models.Model(x, states[0])
|
||||
state = model.predict(inputs)
|
||||
# test for return state:
|
||||
x = keras.Input(batch_shape=inputs.shape)
|
||||
kwargs = {'data_format': data_format,
|
||||
'return_sequences': return_sequences,
|
||||
'return_state': True,
|
||||
'stateful': True,
|
||||
'filters': filters,
|
||||
'kernel_size': (num_row, num_col),
|
||||
'padding': 'valid'}
|
||||
layer = keras.layers.ConvLSTM2D(**kwargs)
|
||||
layer.build(inputs.shape)
|
||||
outputs = layer(x)
|
||||
_, states = outputs[0], outputs[1:]
|
||||
self.assertEqual(len(states), 2)
|
||||
model = keras.models.Model(x, states[0])
|
||||
state = model.predict(inputs)
|
||||
|
||||
self.assertAllClose(
|
||||
keras.backend.eval(layer.states[0]), state, atol=1e-4)
|
||||
self.assertAllClose(
|
||||
keras.backend.eval(layer.states[0]), state, atol=1e-4)
|
||||
|
||||
# test for output shape:
|
||||
testing_utils.layer_test(
|
||||
keras.layers.ConvLSTM2D,
|
||||
kwargs={'data_format': data_format,
|
||||
'return_sequences': return_sequences,
|
||||
'filters': filters,
|
||||
'kernel_size': (num_row, num_col),
|
||||
'padding': 'valid'},
|
||||
input_shape=inputs.shape)
|
||||
# test for output shape:
|
||||
testing_utils.layer_test(
|
||||
keras.layers.ConvLSTM2D,
|
||||
kwargs={'data_format': data_format,
|
||||
'return_sequences': return_sequences,
|
||||
'filters': filters,
|
||||
'kernel_size': (num_row, num_col),
|
||||
'padding': 'valid'},
|
||||
input_shape=inputs.shape)
|
||||
|
||||
def test_conv_lstm_statefulness(self):
|
||||
# Tests for statefulness
|
||||
@ -167,6 +172,8 @@ class ConvLSTMTest(test.TestCase):
|
||||
self.assertEqual(len(layer.losses), 4)
|
||||
|
||||
def test_conv_lstm_dropout(self):
|
||||
if testing_utils.should_run_eagerly():
|
||||
self.skipTest('Skip test due to b/126246383.')
|
||||
# check dropout
|
||||
with self.cached_session():
|
||||
testing_utils.layer_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user