Fix conv_lstm and test for TF 2.0.
PiperOrigin-RevId: 235622678
This commit is contained in:
parent
82cf60d9d7
commit
391bee7364
@ -563,7 +563,7 @@ tf_py_test(
|
|||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "convolutional_recurrent_test",
|
name = "convolutional_recurrent_test",
|
||||||
size = "large",
|
size = "medium",
|
||||||
srcs = ["layers/convolutional_recurrent_test.py"],
|
srcs = ["layers/convolutional_recurrent_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":keras",
|
":keras",
|
||||||
@ -571,7 +571,7 @@ tf_py_test(
|
|||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
shard_count = 2,
|
shard_count = 4,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from tensorflow.python.keras.layers.recurrent import RNN
|
|||||||
from tensorflow.python.keras.utils import conv_utils
|
from tensorflow.python.keras.utils import conv_utils
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
|
||||||
@ -272,7 +273,7 @@ class ConvRNN2D(RNN):
|
|||||||
shape = list(self.cell.kernel_shape)
|
shape = list(self.cell.kernel_shape)
|
||||||
shape[-1] = self.cell.filters
|
shape[-1] = self.cell.filters
|
||||||
initial_state = self.cell.input_conv(initial_state,
|
initial_state = self.cell.input_conv(initial_state,
|
||||||
K.zeros(tuple(shape)),
|
array_ops.zeros(tuple(shape)),
|
||||||
padding=self.cell.padding)
|
padding=self.cell.padding)
|
||||||
|
|
||||||
if hasattr(self.cell.state_size, '__len__'):
|
if hasattr(self.cell.state_size, '__len__'):
|
||||||
@ -625,31 +626,8 @@ class ConvLSTM2DCell(Layer):
|
|||||||
initializer=bias_initializer,
|
initializer=bias_initializer,
|
||||||
regularizer=self.bias_regularizer,
|
regularizer=self.bias_regularizer,
|
||||||
constraint=self.bias_constraint)
|
constraint=self.bias_constraint)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
self.kernel_i = self.kernel[:, :, :, :self.filters]
|
|
||||||
self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
|
|
||||||
self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
|
|
||||||
self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
|
|
||||||
self.filters * 2]
|
|
||||||
self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
|
|
||||||
self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
|
|
||||||
self.filters * 3]
|
|
||||||
self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
|
|
||||||
self.recurrent_kernel_o = self.recurrent_kernel[:, :, :, self.filters * 3:]
|
|
||||||
|
|
||||||
if self.use_bias:
|
|
||||||
self.bias_i = self.bias[:self.filters]
|
|
||||||
self.bias_f = self.bias[self.filters: self.filters * 2]
|
|
||||||
self.bias_c = self.bias[self.filters * 2: self.filters * 3]
|
|
||||||
self.bias_o = self.bias[self.filters * 3:]
|
|
||||||
else:
|
|
||||||
self.bias_i = None
|
|
||||||
self.bias_f = None
|
|
||||||
self.bias_c = None
|
|
||||||
self.bias_o = None
|
|
||||||
self.built = True
|
self.built = True
|
||||||
|
|
||||||
def call(self, inputs, states, training=None):
|
def call(self, inputs, states, training=None):
|
||||||
@ -697,22 +675,26 @@ class ConvLSTM2DCell(Layer):
|
|||||||
h_tm1_c = h_tm1
|
h_tm1_c = h_tm1
|
||||||
h_tm1_o = h_tm1
|
h_tm1_o = h_tm1
|
||||||
|
|
||||||
x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
|
(kernel_i, kernel_f,
|
||||||
padding=self.padding)
|
kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
|
||||||
x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
|
(recurrent_kernel_i,
|
||||||
padding=self.padding)
|
recurrent_kernel_f,
|
||||||
x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
|
recurrent_kernel_c,
|
||||||
padding=self.padding)
|
recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)
|
||||||
x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
|
|
||||||
padding=self.padding)
|
if self.use_bias:
|
||||||
h_i = self.recurrent_conv(h_tm1_i,
|
bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
|
||||||
self.recurrent_kernel_i)
|
else:
|
||||||
h_f = self.recurrent_conv(h_tm1_f,
|
bias_i, bias_f, bias_c, bias_o = None, None, None, None
|
||||||
self.recurrent_kernel_f)
|
|
||||||
h_c = self.recurrent_conv(h_tm1_c,
|
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
|
||||||
self.recurrent_kernel_c)
|
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
|
||||||
h_o = self.recurrent_conv(h_tm1_o,
|
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
|
||||||
self.recurrent_kernel_o)
|
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
|
||||||
|
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
|
||||||
|
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
|
||||||
|
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
|
||||||
|
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
|
||||||
|
|
||||||
i = self.recurrent_activation(x_i + h_i)
|
i = self.recurrent_activation(x_i + h_i)
|
||||||
f = self.recurrent_activation(x_f + h_f)
|
f = self.recurrent_activation(x_f + h_f)
|
||||||
|
|||||||
@ -18,16 +18,24 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class ConvLSTMTest(test.TestCase):
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
class ConvLSTMTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def test_conv_lstm(self):
|
@parameterized.named_parameters(
|
||||||
|
*test_util.generate_combinations_with_testcase_name(
|
||||||
|
data_format=['channels_first', 'channels_last'],
|
||||||
|
return_sequences=[True, False]))
|
||||||
|
def test_conv_lstm(self, data_format, return_sequences):
|
||||||
num_row = 3
|
num_row = 3
|
||||||
num_col = 3
|
num_col = 3
|
||||||
filters = 2
|
filters = 2
|
||||||
@ -36,47 +44,44 @@ class ConvLSTMTest(test.TestCase):
|
|||||||
input_num_row = 5
|
input_num_row = 5
|
||||||
input_num_col = 5
|
input_num_col = 5
|
||||||
sequence_len = 2
|
sequence_len = 2
|
||||||
for data_format in ['channels_first', 'channels_last']:
|
if data_format == 'channels_first':
|
||||||
if data_format == 'channels_first':
|
inputs = np.random.rand(num_samples, sequence_len,
|
||||||
inputs = np.random.rand(num_samples, sequence_len,
|
input_channel,
|
||||||
input_channel,
|
input_num_row, input_num_col)
|
||||||
input_num_row, input_num_col)
|
else:
|
||||||
else:
|
inputs = np.random.rand(num_samples, sequence_len,
|
||||||
inputs = np.random.rand(num_samples, sequence_len,
|
input_num_row, input_num_col,
|
||||||
input_num_row, input_num_col,
|
input_channel)
|
||||||
input_channel)
|
|
||||||
|
|
||||||
for return_sequences in [True, False]:
|
# test for return state:
|
||||||
with self.cached_session():
|
x = keras.Input(batch_shape=inputs.shape)
|
||||||
# test for return state:
|
kwargs = {'data_format': data_format,
|
||||||
x = keras.Input(batch_shape=inputs.shape)
|
'return_sequences': return_sequences,
|
||||||
kwargs = {'data_format': data_format,
|
'return_state': True,
|
||||||
'return_sequences': return_sequences,
|
'stateful': True,
|
||||||
'return_state': True,
|
'filters': filters,
|
||||||
'stateful': True,
|
'kernel_size': (num_row, num_col),
|
||||||
'filters': filters,
|
'padding': 'valid'}
|
||||||
'kernel_size': (num_row, num_col),
|
layer = keras.layers.ConvLSTM2D(**kwargs)
|
||||||
'padding': 'valid'}
|
layer.build(inputs.shape)
|
||||||
layer = keras.layers.ConvLSTM2D(**kwargs)
|
outputs = layer(x)
|
||||||
layer.build(inputs.shape)
|
_, states = outputs[0], outputs[1:]
|
||||||
outputs = layer(x)
|
self.assertEqual(len(states), 2)
|
||||||
_, states = outputs[0], outputs[1:]
|
model = keras.models.Model(x, states[0])
|
||||||
self.assertEqual(len(states), 2)
|
state = model.predict(inputs)
|
||||||
model = keras.models.Model(x, states[0])
|
|
||||||
state = model.predict(inputs)
|
|
||||||
|
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
keras.backend.eval(layer.states[0]), state, atol=1e-4)
|
keras.backend.eval(layer.states[0]), state, atol=1e-4)
|
||||||
|
|
||||||
# test for output shape:
|
# test for output shape:
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
keras.layers.ConvLSTM2D,
|
keras.layers.ConvLSTM2D,
|
||||||
kwargs={'data_format': data_format,
|
kwargs={'data_format': data_format,
|
||||||
'return_sequences': return_sequences,
|
'return_sequences': return_sequences,
|
||||||
'filters': filters,
|
'filters': filters,
|
||||||
'kernel_size': (num_row, num_col),
|
'kernel_size': (num_row, num_col),
|
||||||
'padding': 'valid'},
|
'padding': 'valid'},
|
||||||
input_shape=inputs.shape)
|
input_shape=inputs.shape)
|
||||||
|
|
||||||
def test_conv_lstm_statefulness(self):
|
def test_conv_lstm_statefulness(self):
|
||||||
# Tests for statefulness
|
# Tests for statefulness
|
||||||
@ -167,6 +172,8 @@ class ConvLSTMTest(test.TestCase):
|
|||||||
self.assertEqual(len(layer.losses), 4)
|
self.assertEqual(len(layer.losses), 4)
|
||||||
|
|
||||||
def test_conv_lstm_dropout(self):
|
def test_conv_lstm_dropout(self):
|
||||||
|
if testing_utils.should_run_eagerly():
|
||||||
|
self.skipTest('Skip test due to b/126246383.')
|
||||||
# check dropout
|
# check dropout
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user