Add eager support for unit tests for most Keras layers.

A few minor layers were left out:
- noise layers (apparent issue with tf.random_normal)
- bidirectional wrapper
- conv recurrent layers (impending refactor)

PiperOrigin-RevId: 186654795
This commit is contained in:
parent a4de23973d
commit e2a9276d48
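Most of this change applies one mechanical pattern: each layer test drops its `with self.test_session():` wrapper and gains the `@tf_test_util.run_in_graph_and_eager_modes()` decorator, which runs the test body once with a graph and once eagerly. A minimal sketch of the resulting test shape (the class name here is hypothetical; the imports and helper are the ones used throughout the diff):

from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.platform import test


class ExampleLayerTest(test.TestCase):  # hypothetical name

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_dense(self):
    # layer_test builds the layer, runs a forward pass, and checks the
    # output shape/dtype; it needs no session wrapper under the decorator.
    testing_utils.layer_test(
        keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))


if __name__ == '__main__':
  test.main()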
@@ -3087,7 +3087,8 @@ def rnn(step_function,
       outputs_shape[1] = inputs_shape[1]
     outputs.set_shape(outputs_shape)

-  last_output._uses_learning_phase = uses_learning_phase
+  if not context.in_eager_mode():
+    last_output._uses_learning_phase = uses_learning_phase
   return last_output, outputs, new_states
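The guard above reflects that `_uses_learning_phase` is graph-mode bookkeeping: an eager tensor is a concrete value, so the flag is only attached while a graph is being built. A hedged restatement of the guarded assignment (`annotate` is a hypothetical helper, not part of the change):

from tensorflow.python.eager import context


def annotate(last_output, uses_learning_phase):  # hypothetical helper
  # Attach the learning-phase flag only in graph mode; eager tensors
  # are plain values and do not carry this metadata.
  if not context.in_eager_mode():
    last_output._uses_learning_phase = uses_learning_phase
  return last_output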
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras import metrics as metrics_module
 from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
+from tensorflow.python.platform import tf_logging as logging


 def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):

@@ -196,8 +197,7 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
       output of the model, total loss and the loss associated with each output.

   Raises:
-      ValueError: If the model loss is 0 or if the trainable weights list is
-        empty when the trainable parameter is set to True.
+      ValueError: If the model has no loss to optimize.
   """
   K.set_learning_phase(training)
   with GradientTape() as tape:

@@ -209,12 +209,13 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
                        'because it has no loss to optimize.')
   if training:
     if not model._collected_trainable_weights:
-      raise ValueError('The list of trainable weights is empty. Make sure that '
-                       'you are not setting model.trainable to False before '
-                       'compiling the model.')
-    grads = tape.gradient(loss, model._collected_trainable_weights)
-    model.optimizer.apply_gradients(zip(grads,
-                                        model._collected_trainable_weights))
+      logging.warning('The list of trainable weights is empty. Make sure that '
+                      'you are not setting model.trainable to False before '
+                      'compiling the model.')
+    else:
+      grads = tape.gradient(loss, model._collected_trainable_weights)
+      model.optimizer.apply_gradients(zip(grads,
+                                          model._collected_trainable_weights))
   return outs, loss, loss_metrics
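The hunk above changes `_process_single_batch` to warn, rather than raise, when a model has no trainable weights. A minimal sketch of the training step it implements (the standalone function and its signature are assumptions; the real function also tracks per-output loss metrics):

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.platform import tf_logging as logging


def train_step(model, x, y, compute_loss):  # hypothetical signature
  with GradientTape() as tape:
    outs = model(x)
    loss = compute_loss(y, outs)
  weights = model._collected_trainable_weights
  if not weights:
    # Warn instead of raising: a fully frozen model is a valid state.
    logging.warning('The list of trainable weights is empty. Make sure that '
                    'you are not setting model.trainable to False before '
                    'compiling the model.')
  else:
    grads = tape.gradient(loss, weights)
    model.optimizer.apply_gradients(zip(grads, weights))
  return outs, loss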
@@ -26,6 +26,7 @@ from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer

@@ -397,17 +398,13 @@ class LossWeightingTest(test.TestCase):
         optimizer=RMSPropOptimizer(learning_rate=0.001))

     np.random.seed(43)
-    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+    (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=train_samples,
         test_samples=test_samples,
         input_shape=(input_dim,),
         num_classes=num_classes)
-    int_y_test = y_test.copy()
     int_y_train = y_train.copy()
     # convert class vectors to binary class matrices
     y_train = keras.utils.to_categorical(y_train, num_classes)
-    y_test = keras.utils.to_categorical(y_test, num_classes)
-    test_ids = np.where(int_y_test == np.array(weighted_class))[0]

     class_weight = dict([(i, 1.) for i in range(num_classes)])
     class_weight[weighted_class] = 2.

@@ -549,8 +546,10 @@ class TestDynamicTrainability(test.TestCase):
     model.trainable = False
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     model.trainable = True
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')

   def test_trainable_argument(self):
     x = np.random.random((5, 3))

@@ -560,8 +559,10 @@ class TestDynamicTrainability(test.TestCase):
     model.add(keras.layers.Dense(2, input_dim=3, trainable=False))
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     out = model.predict(x)
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')
     out_2 = model.predict(x)
     self.assertAllClose(out, out_2)

@@ -571,8 +572,10 @@ class TestDynamicTrainability(test.TestCase):
     model = keras.models.Model(inputs, output)
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     out = model.predict(x)
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')
     out_2 = model.predict(x)
     self.assertAllClose(out, out_2)
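The three test updates above follow the same recipe: instead of expecting a ValueError from train_on_batch on a fully frozen model, patch the logging module and assert on the warning. Pulled together as one self-contained test (the class and method names are hypothetical; the body mirrors the hunks above):

import numpy as np

from tensorflow.python.keras._impl import keras
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer


class FrozenModelWarningTest(test.TestCase):  # hypothetical name

  def test_warns_when_nothing_is_trainable(self):
    x = np.random.random((5, 3))
    y = np.random.random((5, 2))
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_dim=3, trainable=False))
    model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
    # The frozen model trains nothing, but no longer raises; it warns.
    with test.mock.patch.object(logging, 'warning') as mock_log:
      model.train_on_batch(x, y)
    self.assertRegexpMatches(str(mock_log.call_args),
                             'trainable weights is empty')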
@@ -22,6 +22,8 @@ import copy

 import numpy as np

+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test

@@ -43,6 +45,7 @@ class Convolution1DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, length, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv1d(self):
     kwargs = {
         'filters': 2,

@@ -114,6 +117,7 @@ class Conv2DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv2d(self):
     kwargs = {
         'filters': 2,

@@ -188,6 +192,7 @@ class Conv2DTransposeTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv2dtranspose(self):
     kwargs = {
         'filters': 2,

@@ -253,6 +258,7 @@ class Conv3DTransposeTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, depth, num_row, num_col, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv3dtranspose(self):
     kwargs = {
         'filters': 2,

@@ -316,6 +322,7 @@ class SeparableConv1DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, length, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_separable_conv1d(self):
     kwargs = {
         'filters': 2,

@@ -391,6 +398,7 @@ class SeparableConv2DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_separable_conv2d(self):
     kwargs = {
         'filters': 2,

@@ -469,6 +477,7 @@ class Conv3DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, depth, num_row, num_col, stack_size))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv3d(self):
     kwargs = {
         'filters': 2,

@@ -520,6 +529,7 @@ class Conv3DTest(test.TestCase):

 class ZeroPaddingTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_1d(self):
     num_samples = 2
     input_dim = 2

@@ -543,7 +553,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding1D(padding=2)
     layer.build(shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for offset in [0, 1, -1, -2]:
       np.testing.assert_allclose(np_output[:, offset, :], 0.)
     np.testing.assert_allclose(np_output[:, 2:-2, :], 1.)

@@ -551,7 +564,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding1D(padding=(1, 2))
     layer.build(shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for left_offset in [0]:
       np.testing.assert_allclose(np_output[:, left_offset, :], 0.)
     for right_offset in [-1, -2]:

@@ -565,6 +581,7 @@ class ZeroPaddingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.ZeroPadding1D(padding=None)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_2d(self):
     num_samples = 2
     stack_size = 2

@@ -593,7 +610,10 @@ class ZeroPaddingTest(test.TestCase):
           padding=(2, 2), data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       if data_format == 'channels_last':
         for offset in [0, 1, -1, -2]:
           np.testing.assert_allclose(np_output[:, offset, :, :], 0.)

@@ -609,7 +629,10 @@ class ZeroPaddingTest(test.TestCase):
           padding=((1, 2), (3, 4)), data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       if data_format == 'channels_last':
         for top_offset in [0]:
           np.testing.assert_allclose(np_output[:, top_offset, :, :], 0.)

@@ -637,6 +660,7 @@ class ZeroPaddingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.ZeroPadding2D(padding=None)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_3d(self):
     num_samples = 2
     stack_size = 2

@@ -659,7 +683,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
     layer.build(inputs.shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for offset in [0, 1, -1, -2]:
       np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
       np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)

@@ -675,11 +702,13 @@ class ZeroPaddingTest(test.TestCase):

 class UpSamplingTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_1d(self):
     with self.test_session(use_gpu=True):
       testing_utils.layer_test(
           keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_2d(self):
     num_samples = 2
     stack_size = 2

@@ -708,7 +737,10 @@ class UpSamplingTest(test.TestCase):
             size=(length_row, length_col), data_format=data_format)
         layer.build(inputs.shape)
         output = layer(keras.backend.variable(inputs))
-        np_output = keras.backend.eval(output)
+        if context.in_eager_mode():
+          np_output = output.numpy()
+        else:
+          np_output = keras.backend.eval(output)
         if data_format == 'channels_first':
           assert np_output.shape[2] == length_row * input_num_row
           assert np_output.shape[3] == length_col * input_num_col

@@ -726,6 +758,7 @@ class UpSamplingTest(test.TestCase):

           np.testing.assert_allclose(np_output, expected_out)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_3d(self):
     num_samples = 2
     stack_size = 2

@@ -757,7 +790,10 @@ class UpSamplingTest(test.TestCase):
             data_format=data_format)
         layer.build(inputs.shape)
         output = layer(keras.backend.variable(inputs))
-        np_output = keras.backend.eval(output)
+        if context.in_eager_mode():
+          np_output = output.numpy()
+        else:
+          np_output = keras.backend.eval(output)
         if data_format == 'channels_first':
           assert np_output.shape[2] == length_dim1 * input_len_dim1
           assert np_output.shape[3] == length_dim2 * input_len_dim2

@@ -782,6 +818,7 @@ class UpSamplingTest(test.TestCase):

 class CroppingTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_1d(self):
     num_samples = 2
     time_length = 4

@@ -800,6 +837,7 @@ class CroppingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.Cropping1D(cropping=None)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_2d(self):
     num_samples = 2
     stack_size = 2

@@ -827,7 +865,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with numpy
       if data_format == 'channels_first':
         expected_out = inputs[:, :, cropping[0][0]:-cropping[0][1], cropping[

@@ -851,7 +892,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with input
       np.testing.assert_allclose(np_output, inputs)

@@ -861,6 +905,7 @@ class CroppingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.Cropping2D(cropping=None)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_3d(self):
     num_samples = 2
     stack_size = 2

@@ -892,7 +937,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with numpy
       if data_format == 'channels_first':
         expected_out = inputs[:, :,
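The convolutional hunks above repeat one idiom wherever a test inspects actual output values: fetch the result as a NumPy array in either execution mode. A hedged helper sketch of that idiom (`to_numpy` is hypothetical; the tests inline the branch instead):

from tensorflow.python.eager import context
from tensorflow.python.keras._impl import keras


def to_numpy(output):  # hypothetical helper
  if context.in_eager_mode():
    return output.numpy()  # an EagerTensor already holds its value
  return keras.backend.eval(output)  # graph mode: evaluate in a session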
@@ -20,11 +20,9 @@ from __future__ import print_function

 import numpy as np

-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
-from tensorflow.python.ops import init_ops
 from tensorflow.python.platform import test

@@ -52,146 +50,134 @@ class CoreLayersTest(test.TestCase):
     dropout = keras.layers.Dropout(0.5)
     self.assertEqual(True, dropout.supports_masking)

-    with self.test_session():
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_spatial_dropout(self):
     testing_utils.layer_test(
         keras.layers.SpatialDropout1D,
         kwargs={'rate': 0.5},
         input_shape=(2, 3, 4))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.SpatialDropout2D,
         kwargs={'rate': 0.5},
         input_shape=(2, 3, 4, 5))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.SpatialDropout2D,
         kwargs={'rate': 0.5, 'data_format': 'channels_first'},
         input_shape=(2, 3, 4, 5))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.SpatialDropout3D,
         kwargs={'rate': 0.5},
         input_shape=(2, 3, 4, 4, 5))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.SpatialDropout3D,
         kwargs={'rate': 0.5, 'data_format': 'channels_first'},
         input_shape=(2, 3, 4, 4, 5))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_activation(self):
     # with string argument
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Activation,
         kwargs={'activation': 'relu'},
         input_shape=(3, 2))

     # with function argument
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Activation,
         kwargs={'activation': keras.backend.relu},
         input_shape=(3, 2))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_reshape(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Reshape,
         kwargs={'target_shape': (8, 1)},
         input_shape=(3, 2, 4))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Reshape,
         kwargs={'target_shape': (-1, 1)},
         input_shape=(3, 2, 4))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Reshape,
         kwargs={'target_shape': (1, -1)},
         input_shape=(3, 2, 4))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Reshape,
         kwargs={'target_shape': (-1, 1)},
         input_shape=(None, None, 2))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_permute(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_flatten(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_repeat_vector(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_lambda(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Lambda,
         kwargs={'function': lambda x: x + 1},
         input_shape=(3, 2))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Lambda,
         kwargs={
             'function': lambda x, a, b: x * a + b,
             'arguments': {
                 'a': 0.6,
                 'b': 0.4
             }
         },
         input_shape=(3, 2))

-    with self.test_session():
     # test serialization with function
     def f(x):
       return x + 1

     ld = keras.layers.Lambda(f)
     config = ld.get_config()
     ld = keras.layers.deserialize({
         'class_name': 'Lambda',
         'config': config
     })

     # test with lambda
     ld = keras.layers.Lambda(
         lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
     config = ld.get_config()
     ld = keras.layers.Lambda.from_config(config)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dense(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 2))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Dense, kwargs={'units': 3}, input_shape=(None, None, 2))

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))

   # Test regularization
   def test_dense_regularization(self):
     with self.test_session():
       layer = keras.layers.Dense(
           3,

@@ -202,7 +188,7 @@ class CoreLayersTest(test.TestCase):
       layer(keras.backend.variable(np.ones((2, 4))))
       self.assertEqual(3, len(layer.losses))

   # Test constraints
   def test_dense_constraints(self):
     with self.test_session():
       k_constraint = keras.constraints.max_norm(0.01)
       b_constraint = keras.constraints.max_norm(0.01)

@@ -212,12 +198,6 @@ class CoreLayersTest(test.TestCase):
       self.assertEqual(layer.kernel.constraint, k_constraint)
       self.assertEqual(layer.bias.constraint, b_constraint)

-  def test_eager_dense(self):
-    with context.eager_mode():
-      l = keras.layers.Dense(units=3,
-                             kernel_initializer=init_ops.zeros_initializer())
-      self.assertAllEqual(l(constant_op.constant([[1.0]])), [[0., 0., 0.]])
-
   def test_activity_regularization(self):
     with self.test_session():
       layer = keras.layers.ActivityRegularization(l1=0.1)
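Every test above funnels through `testing_utils.layer_test`. A rough, hedged sketch of the core contract it checks (the real helper does more, e.g. config round-trips and model integration; `minimal_layer_test` and `as_tuple` are hypothetical):

import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl import keras


def as_tuple(shape):  # hypothetical normalizer for shape comparisons
  return tuple(tensor_shape.TensorShape(shape).as_list())


def minimal_layer_test(layer_cls, kwargs, input_shape):
  layer = layer_cls(**kwargs)
  x = np.random.random(input_shape).astype('float32')
  output = layer(keras.backend.variable(x))
  # Static shape inference must agree with the shape actually produced.
  assert as_tuple(output.get_shape()) == as_tuple(
      layer.compute_output_shape(input_shape))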
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test

@@ -25,47 +26,44 @@ from tensorflow.python.platform import test

 class EmbeddingTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_embedding(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Embedding,
         kwargs={'output_dim': 4,
                 'input_dim': 10,
                 'input_length': 2},
         input_shape=(3, 2),
         input_dtype='int32',
         expected_output_dtype='float32')

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Embedding,
         kwargs={'output_dim': 4,
                 'input_dim': 10,
                 'mask_zero': True},
         input_shape=(3, 2),
         input_dtype='int32',
         expected_output_dtype='float32')

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Embedding,
         kwargs={'output_dim': 4,
                 'input_dim': 10,
                 'mask_zero': True},
         input_shape=(3, 4, 2),
         input_dtype='int32',
         expected_output_dtype='float32')

-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.Embedding,
         kwargs={'output_dim': 4,
                 'input_dim': 10,
                 'mask_zero': True,
                 'input_length': (None, 2)},
         input_shape=(3, 4, 2),
         input_dtype='int32',
         expected_output_dtype='float32')


 if __name__ == '__main__':
@@ -20,64 +20,66 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer


 class GRULayerTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_return_sequences_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.GRU,
         kwargs={'units': units,
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dynamic_behavior_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      layer = keras.layers.GRU(units, input_shape=(None, embedding_dim))
-      model = keras.models.Sequential()
-      model.add(layer)
-      model.compile('sgd', 'mse')
-      x = np.random.random((num_samples, timesteps, embedding_dim))
-      y = np.random.random((num_samples, units))
-      model.train_on_batch(x, y)
+    layer = keras.layers.GRU(units, input_shape=(None, embedding_dim))
+    model = keras.models.Sequential()
+    model.add(layer)
+    model.compile(RMSPropOptimizer(0.01), 'mse')
+    x = np.random.random((num_samples, timesteps, embedding_dim))
+    y = np.random.random((num_samples, units))
+    model.train_on_batch(x, y)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dropout_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.GRU,
         kwargs={'units': units,
                 'dropout': 0.1,
                 'recurrent_dropout': 0.1},
         input_shape=(num_samples, timesteps, embedding_dim))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_implementation_mode_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     for mode in [0, 1, 2]:
       testing_utils.layer_test(
           keras.layers.GRU,
           kwargs={'units': units,
                   'implementation': mode},
           input_shape=(num_samples, timesteps, embedding_dim))

   def test_statefulness_GRU(self):
     num_samples = 2
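Note the compile change inside test_dynamic_behavior_GRU: the 'sgd' string alias is replaced with a TensorFlow optimizer object, presumably because string aliases resolve to graph-only Keras optimizers. Usage as the updated test exercises it:

import numpy as np

from tensorflow.python.keras._impl import keras
from tensorflow.python.training.rmsprop import RMSPropOptimizer

model = keras.models.Sequential()
model.add(keras.layers.GRU(2, input_shape=(None, 4)))
# A TF optimizer object works under both graph and eager execution.
model.compile(RMSPropOptimizer(0.01), 'mse')
model.train_on_batch(np.random.random((2, 3, 4)),
                     np.random.random((2, 2)))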
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test

@@ -27,6 +28,7 @@ from tensorflow.python.platform import test

 class LocallyConnectedLayersTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_locallyconnected_1d(self):
     num_samples = 2
     num_steps = 8

@@ -39,16 +41,15 @@ class LocallyConnectedLayersTest(test.TestCase):
       if padding == 'same' and strides != 1:
         continue

-      with self.test_session():
       testing_utils.layer_test(
           keras.layers.LocallyConnected1D,
           kwargs={
               'filters': filters,
               'kernel_size': filter_length,
               'padding': padding,
               'strides': strides
           },
           input_shape=(num_samples, num_steps, input_dim))

   def test_locallyconnected_1d_regularization(self):
     num_samples = 2

@@ -86,6 +87,7 @@ class LocallyConnectedLayersTest(test.TestCase):
       self.assertEqual(layer.kernel.constraint, k_constraint)
       self.assertEqual(layer.bias.constraint, b_constraint)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_locallyconnected_2d(self):
     num_samples = 8
     filters = 3

@@ -98,20 +100,18 @@ class LocallyConnectedLayersTest(test.TestCase):
       if padding == 'same' and strides != (1, 1):
         continue

-      with self.test_session():
       testing_utils.layer_test(
           keras.layers.LocallyConnected2D,
           kwargs={
               'filters': filters,
               'kernel_size': 3,
               'padding': padding,
               'kernel_regularizer': 'l2',
               'bias_regularizer': 'l2',
-              'activity_regularizer': 'l2',
               'strides': strides,
               'data_format': 'channels_last'
           },
           input_shape=(num_samples, num_row, num_col, stack_size))

   def test_locallyconnected_2d_channels_first(self):
     num_samples = 8
@@ -20,28 +20,29 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer


 class LSTMLayerTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_return_sequences_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.LSTM,
         kwargs={'units': units,
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))

   def test_static_shape_inference_LSTM(self):
     # Github issue: 15165
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2

@@ -55,45 +56,45 @@ class LSTMLayerTest(test.TestCase):
     outputs = model.layers[-1].output
     self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units])

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dynamic_behavior_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim))
-      model = keras.models.Sequential()
-      model.add(layer)
-      model.compile('sgd', 'mse')
-      x = np.random.random((num_samples, timesteps, embedding_dim))
-      y = np.random.random((num_samples, units))
-      model.train_on_batch(x, y)
+    layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim))
+    model = keras.models.Sequential()
+    model.add(layer)
+    model.compile(RMSPropOptimizer(0.001), 'mse')
+    x = np.random.random((num_samples, timesteps, embedding_dim))
+    y = np.random.random((num_samples, units))
+    model.train_on_batch(x, y)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dropout_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.LSTM,
         kwargs={'units': units,
                 'dropout': 0.1,
                 'recurrent_dropout': 0.1},
         input_shape=(num_samples, timesteps, embedding_dim))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_implementation_mode_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
     for mode in [0, 1, 2]:
       testing_utils.layer_test(
           keras.layers.LSTM,
           kwargs={'units': units,
                   'implementation': mode},
           input_shape=(num_samples, timesteps, embedding_dim))

   def test_statefulness_LSTM(self):
     num_samples = 2
@@ -20,6 +20,7 @@ from __future__ import print_function

 import numpy as np

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test

@@ -27,24 +28,25 @@ from tensorflow.python.platform import test

 class MergeLayersTest(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_add(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     i3 = keras.layers.Input(shape=(4, 5))

     o = keras.layers.add([i1, i2, i3])
     self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2, i3], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     x3 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2, x3])
     self.assertEqual(out.shape, (2, 4, 5))
     self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)

+  def test_merge_add_masking(self):
+    with self.test_session():
       # test masking
       i1 = keras.layers.Input(shape=(4, 5))
       i2 = keras.layers.Input(shape=(4, 5))
       m1 = keras.layers.Masking()(i1)

@@ -54,11 +56,13 @@ class MergeLayersTest(test.TestCase):
       mask = layer.output_mask
       self.assertListEqual(mask.get_shape().as_list(), [None, 4])

-      # test missing shape
+  def test_merge_add_dynamic_shape(self):
+    with self.test_session():
       i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
       i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
       layer = keras.layers.Add()
       o = layer([i1, i2])
       self.assertListEqual(o.get_shape().as_list(), [4, 5])

   def test_merge_elementwise_errors(self):
     i1 = keras.layers.Input(shape=(4, 5))

@@ -72,79 +76,82 @@ class MergeLayersTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.add([i1])

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_multiply(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     i3 = keras.layers.Input(shape=(4, 5))
     o = keras.layers.multiply([i1, i2, i3])
     self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2, i3], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     x3 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2, x3])
     self.assertEqual(out.shape, (2, 4, 5))
     self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_average(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     o = keras.layers.average([i1, i2])
     self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 4, 5))
     self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_maximum(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     o = keras.layers.maximum([i1, i2])
     self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 4, 5))
     self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_minimum(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     o = keras.layers.minimum([i1, i2])
     self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
     model = keras.models.Model([i1, i2], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 4, 5))
     self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_concatenate(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
     o = keras.layers.concatenate([i1, i2], axis=1)
     self.assertListEqual(o.get_shape().as_list(), [None, 8, 5])
     model = keras.models.Model([i1, i2], o)

     x1 = np.random.random((2, 4, 5))
     x2 = np.random.random((2, 4, 5))
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 8, 5))
     self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)

+  def test_merge_concatenate_masking(self):
+    with self.test_session():
       # test masking
       m1 = keras.layers.Masking()(i1)
       layer = keras.layers.Concatenate()
       o = layer([m1, i2])

@@ -162,35 +169,35 @@ class MergeLayersTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'called on a list'):
       keras.layers.concatenate([i1], axis=-1)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_dot(self):
-    with self.test_session():
     i1 = keras.layers.Input(shape=(4,))
     i2 = keras.layers.Input(shape=(4,))
     o = keras.layers.dot([i1, i2], axes=1)
     self.assertListEqual(o.get_shape().as_list(), [None, 1])
     model = keras.models.Model([i1, i2], o)
     _ = keras.layers.Dot(axes=1).get_config()

     x1 = np.random.random((2, 4))
     x2 = np.random.random((2, 4))
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 1))
     expected = np.zeros((2, 1))
     expected[0, 0] = np.dot(x1[0], x2[0])
     expected[1, 0] = np.dot(x1[1], x2[1])
     self.assertAllClose(out, expected, atol=1e-4)

     # Test with negative tuple of axes.
     o = keras.layers.dot([i1, i2], axes=(-1, -1))
     self.assertListEqual(o.get_shape().as_list(), [None, 1])
     model = keras.models.Model([i1, i2], o)
     out = model.predict([x1, x2])
     self.assertEqual(out.shape, (2, 1))
     self.assertAllClose(out, expected, atol=1e-4)

     # test compute_output_shape
     layer = keras.layers.Dot(axes=-1)
     self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1))

   def test_dot_errors(self):
     i1 = keras.layers.Input(shape=(4, 5))

@@ -208,6 +215,7 @@ class MergeLayersTest(test.TestCase):
       dot = keras.layers.Dot(1)
       dot.compute_output_shape(1)

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_subtract(self):
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
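The merge hunks above leave the masking and dynamic-shape assertions in graph-only tests. A plausible reason, restated as a self-contained example: placeholders have no eager equivalent, so these checks keep the test_session() wrapper (class and method names are hypothetical; the body mirrors test_merge_add_dynamic_shape):

from tensorflow.python.keras._impl import keras
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class GraphOnlyMergeTest(test.TestCase):  # hypothetical name

  def test_add_with_placeholders(self):
    with self.test_session():
      i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
      i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
      o = keras.layers.Add()([i1, i2])
      # Shape inference fills the unknown dimension from the other input.
      self.assertListEqual(o.get_shape().as_list(), [4, 5])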
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test

@@ -39,12 +40,12 @@ class NoiseLayersTest(test.TestCase):
         kwargs={'rate': 0.5},
         input_shape=(3, 2, 3))

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_AlphaDropout(self):
-    with self.test_session():
     testing_utils.layer_test(
         keras.layers.AlphaDropout,
         kwargs={'rate': 0.2},
         input_shape=(3, 2, 3))


 if __name__ == '__main__':
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras._impl import keras
|
||||
from tensorflow.python.keras._impl.keras import testing_utils
|
||||
from tensorflow.python.platform import test
|
||||
@ -25,81 +27,85 @@ from tensorflow.python.platform import test
|
||||
|
||||
class GlobalPoolingTest(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_globalpooling_1d(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
|
||||
input_shape=(3, 4, 5))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
|
||||
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
|
||||
input_shape=(3, 4, 5))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_globalpooling_2d(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling2D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 5, 6))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling2D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling2D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 5, 6))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling2D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling2D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 5, 6))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling2D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling2D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 5, 6))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling2D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_globalpooling_3d(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling3D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling3D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling3D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling3D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling3D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalMaxPooling3D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling3D,
|
||||
kwargs={'data_format': 'channels_first'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.pooling.GlobalAveragePooling3D,
|
||||
kwargs={'data_format': 'channels_last'},
|
||||
input_shape=(3, 4, 3, 4, 3))
|
||||
|
||||
|
||||
class Pooling2DTest(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_maxpooling_2d(self):
|
||||
pool_size = (3, 3)
|
||||
with self.test_session(use_gpu=True):
|
||||
for strides in [(1, 1), (2, 2)]:
|
||||
testing_utils.layer_test(
|
||||
keras.layers.MaxPooling2D,
|
||||
kwargs={
|
||||
'strides': strides,
|
||||
'padding': 'valid',
|
||||
'pool_size': pool_size
|
||||
},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
for strides in [(1, 1), (2, 2)]:
|
||||
testing_utils.layer_test(
|
||||
keras.layers.MaxPooling2D,
|
||||
kwargs={
|
||||
'strides': strides,
|
||||
'padding': 'valid',
|
||||
'pool_size': pool_size
|
||||
},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def test_averagepooling_2d(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
testing_utils.layer_test(
|
||||
keras.layers.AveragePooling2D,
|
||||
kwargs={'strides': (2, 2),
|
||||
'padding': 'same',
|
||||
'pool_size': (2, 2)},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.AveragePooling2D,
|
||||
kwargs={'strides': (2, 2),
|
||||
'padding': 'valid',
|
||||
'pool_size': (3, 3)},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.AveragePooling2D,
|
||||
kwargs={'strides': (2, 2),
|
||||
'padding': 'same',
|
||||
'pool_size': (2, 2)},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.AveragePooling2D,
|
||||
kwargs={'strides': (2, 2),
|
||||
'padding': 'valid',
|
||||
'pool_size': (3, 3)},
|
||||
input_shape=(3, 5, 6, 4))
|
||||
|
||||
# This part of the test can only run on GPU but doesn't appear
|
||||
# to be properly assigned to a GPU when running in eager mode.
|
||||
if not context.in_eager_mode():
|
||||
# Only runs on GPU with CUDA, channels_first is not supported on CPU.
|
||||
# TODO(b/62340061): Support channels_first on CPU.
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
@ -116,66 +122,66 @@ class Pooling2DTest(test.TestCase):
|
||||
|
||||
class Pooling3DTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def test_maxpooling_3d(self):
    pool_size = (3, 3, 3)
    with self.test_session(use_gpu=True):
      testing_utils.layer_test(
          keras.layers.MaxPooling3D,
          kwargs={'strides': 2,
                  'padding': 'valid',
                  'pool_size': pool_size},
          input_shape=(3, 11, 12, 10, 4))
      testing_utils.layer_test(
          keras.layers.MaxPooling3D,
          kwargs={
              'strides': 3,
              'padding': 'valid',
              'data_format': 'channels_first',
              'pool_size': pool_size
          },
          input_shape=(3, 4, 11, 12, 10))

  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def test_averagepooling_3d(self):
    pool_size = (3, 3, 3)
    with self.test_session(use_gpu=True):
      testing_utils.layer_test(
          keras.layers.AveragePooling3D,
          kwargs={'strides': 2,
                  'padding': 'valid',
                  'pool_size': pool_size},
          input_shape=(3, 11, 12, 10, 4))
      testing_utils.layer_test(
          keras.layers.AveragePooling3D,
          kwargs={
              'strides': 3,
              'padding': 'valid',
              'data_format': 'channels_first',
              'pool_size': pool_size
          },
          input_shape=(3, 4, 11, 12, 10))

class Pooling1DTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def test_maxpooling_1d(self):
    with self.test_session(use_gpu=True):
      for padding in ['valid', 'same']:
        for stride in [1, 2]:
          testing_utils.layer_test(
              keras.layers.MaxPooling1D,
              kwargs={'strides': stride,
                      'padding': padding},
              input_shape=(3, 5, 4))

  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def test_averagepooling_1d(self):
    with self.test_session(use_gpu=True):
      for padding in ['valid', 'same']:
        for stride in [1, 2]:
          testing_utils.layer_test(
              keras.layers.AveragePooling1D,
              kwargs={'strides': stride,
                      'padding': padding},
              input_shape=(3, 5, 4))


if __name__ == '__main__':

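The structure repeated in each of these classes is what delivers the eager coverage: `tf_test_util.run_in_graph_and_eager_modes()` runs the decorated method once under a graph and once under eager execution, while `self.test_session()` supplies a session in the graph case (the same `with` block is left in place for the eager pass, as the tests above rely on). A minimal standalone test following that pattern (hypothetical test and layer choice):

import numpy as np

from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.platform import test


class MyLayerTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_dense(self):
    # The body executes twice: once with a graph and session, once
    # eagerly. layer_test itself is mode-agnostic.
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.Dense,
          kwargs={'units': 3},
          input_shape=(2, 4))


if __name__ == '__main__':
  test.main()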
@ -22,6 +22,7 @@ from __future__ import print_function

import numbers
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K

@ -935,7 +936,9 @@ class SimpleRNNCell(Layer):

    # Properly set learning phase on output tensor.
    if 0 < self.dropout + self.recurrent_dropout:
      if training is None:
      if training is None and not context.in_eager_mode():
        # This would be harmless to set in eager mode, but eager tensors
        # disallow setting arbitrary attributes.
        output._uses_learning_phase = True
    return output, [output]

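The guard exists because ad-hoc Python attributes can be attached to graph `Tensor` objects but not to eager tensors, exactly as the in-code comment says. A quick illustration of the failure mode being avoided (a sketch, assuming a TF 1.x build where `tf.enable_eager_execution` is available):

import tensorflow as tf

tf.enable_eager_execution()

t = tf.constant([1.0, 2.0])
try:
  # Graph-mode Tensors tolerate attributes like _uses_learning_phase;
  # eager tensors reject attribute assignment.
  t._uses_learning_phase = True
except AttributeError as e:
  print('eager tensors disallow arbitrary attributes:', e)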
@ -1299,23 +1302,6 @@ class GRUCell(Layer):
          constraint=self.bias_constraint)
    else:
      self.bias = None

    self.kernel_z = self.kernel[:, :self.units]
    self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
    self.kernel_r = self.kernel[:, self.units:self.units * 2]
    self.recurrent_kernel_r = self.recurrent_kernel[:, self.units:
                                                    self.units * 2]
    self.kernel_h = self.kernel[:, self.units * 2:]
    self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]

    if self.use_bias:
      self.bias_z = self.bias[:self.units]
      self.bias_r = self.bias[self.units:self.units * 2]
      self.bias_h = self.bias[self.units * 2:]
    else:
      self.bias_z = None
      self.bias_r = None
      self.bias_h = None
    self.built = True

  def call(self, inputs, states, training=None):

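All of these cached per-gate slices are deleted (the hunk header shows the block shrinking from 23 lines to 6), and the `call` methods below switch to slicing `self.kernel`, `self.recurrent_kernel`, and `self.bias` at use time. A plausible motivation, consistent with the eager theme of this commit: slicing a variable inside `build` materializes a concrete tensor under eager execution, so later updates to the variable would not be reflected in the cached slice. A sketch of that pitfall (assuming eager execution is enabled):

import tensorflow as tf

tf.enable_eager_execution()

v = tf.get_variable('kernel', initializer=[[1.0, 2.0], [3.0, 4.0]])
cached = v[:, :1]        # eager: a concrete tensor, computed right now
v.assign(tf.zeros([2, 2]))
print(cached.numpy())    # still [[1.], [3.]] -- stale snapshot
print(v[:, :1].numpy())  # [[0.], [0.]] -- slicing at use time stays fresh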
@ -1350,13 +1336,13 @@ class GRUCell(Layer):
        inputs_z = inputs
        inputs_r = inputs
        inputs_h = inputs
      x_z = K.dot(inputs_z, self.kernel_z)
      x_r = K.dot(inputs_r, self.kernel_r)
      x_h = K.dot(inputs_h, self.kernel_h)
      x_z = K.dot(inputs_z, self.kernel[:, :self.units])
      x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
      x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
      if self.use_bias:
        x_z = K.bias_add(x_z, self.bias_z)
        x_r = K.bias_add(x_r, self.bias_r)
        x_h = K.bias_add(x_h, self.bias_h)
        x_z = K.bias_add(x_z, self.bias[:self.units])
        x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2])
        x_h = K.bias_add(x_h, self.bias[self.units * 2:])

      if 0. < self.recurrent_dropout < 1.:
        h_tm1_z = h_tm1 * rec_dp_mask[0]
@ -1367,11 +1353,14 @@ class GRUCell(Layer):
        h_tm1_r = h_tm1
        h_tm1_h = h_tm1
      z = self.recurrent_activation(
          x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
          x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]))
      r = self.recurrent_activation(
          x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
          x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units:
                                                        self.units * 2]))

      hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
      hh = self.activation(x_h + K.dot(r * h_tm1_h,
                                       self.recurrent_kernel[:,
                                                             self.units * 2:]))
    else:
      if 0. < self.dropout < 1.:
        inputs *= dp_mask[0]
@ -1395,44 +1384,34 @@ class GRUCell(Layer):
      hh = self.activation(x_h + recurrent_h)
    h = z * h_tm1 + (1 - z) * hh
    if 0 < self.dropout + self.recurrent_dropout:
      if training is None:
      if training is None and not context.in_eager_mode():
        # This would be harmless to set in eager mode, but eager tensors
        # disallow setting arbitrary attributes.
        h._uses_learning_phase = True
    return h, [h]

  def get_config(self):
    config = {
        'units':
            self.units,
        'activation':
            activations.serialize(self.activation),
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'recurrent_activation':
            activations.serialize(self.recurrent_activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'dropout':
            self.dropout,
        'recurrent_dropout':
            self.recurrent_dropout,
        'implementation':
            self.implementation
        'bias_constraint': constraints.serialize(self.bias_constraint),
        'dropout': self.dropout,
        'recurrent_dropout': self.recurrent_dropout,
        'implementation': self.implementation
    }
    base_config = super(GRUCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

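For reference, the computation the GRUCell hunks rewrite is the standard GRU update; only the indexing into the weight matrices changes, not the math. With the default activations (sigmoid for the recurrent activation, tanh for the output activation), and writing W_* for the column blocks of kernel and U_* for those of recurrent_kernel, the cell computes:

\begin{aligned}
z_t &= \sigma(x_t W_z + h_{t-1} U_z + b_z) \\
r_t &= \sigma(x_t W_r + h_{t-1} U_r + b_r) \\
\tilde{h}_t &= \tanh\bigl(x_t W_h + (r_t \odot h_{t-1}) U_h + b_h\bigr) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}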
@ -1809,29 +1788,6 @@ class LSTMCell(Layer):
          constraint=self.bias_constraint)
    else:
      self.bias = None

    self.kernel_i = self.kernel[:, :self.units]
    self.kernel_f = self.kernel[:, self.units:self.units * 2]
    self.kernel_c = self.kernel[:, self.units * 2:self.units * 3]
    self.kernel_o = self.kernel[:, self.units * 3:]

    self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
    self.recurrent_kernel_f = self.recurrent_kernel[:, self.units:
                                                    self.units * 2]
    self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2:
                                                    self.units * 3]
    self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

    if self.use_bias:
      self.bias_i = self.bias[:self.units]
      self.bias_f = self.bias[self.units:self.units * 2]
      self.bias_c = self.bias[self.units * 2:self.units * 3]
      self.bias_o = self.bias[self.units * 3:]
    else:
      self.bias_i = None
      self.bias_f = None
      self.bias_c = None
      self.bias_o = None
    self.built = True

  def call(self, inputs, states, training=None):

@ -1869,15 +1825,15 @@ class LSTMCell(Layer):
        inputs_f = inputs
        inputs_c = inputs
        inputs_o = inputs
      x_i = K.dot(inputs_i, self.kernel_i)
      x_f = K.dot(inputs_f, self.kernel_f)
      x_c = K.dot(inputs_c, self.kernel_c)
      x_o = K.dot(inputs_o, self.kernel_o)
      x_i = K.dot(inputs_i, self.kernel[:, :self.units])
      x_f = K.dot(inputs_f, self.kernel[:, self.units:self.units * 2])
      x_c = K.dot(inputs_c, self.kernel[:, self.units * 2:self.units * 3])
      x_o = K.dot(inputs_o, self.kernel[:, self.units * 3:])
      if self.use_bias:
        x_i = K.bias_add(x_i, self.bias_i)
        x_f = K.bias_add(x_f, self.bias_f)
        x_c = K.bias_add(x_c, self.bias_c)
        x_o = K.bias_add(x_o, self.bias_o)
        x_i = K.bias_add(x_i, self.bias[:self.units])
        x_f = K.bias_add(x_f, self.bias[self.units:self.units * 2])
        x_c = K.bias_add(x_c, self.bias[self.units * 2:self.units * 3])
        x_o = K.bias_add(x_o, self.bias[self.units * 3:])

      if 0 < self.recurrent_dropout < 1.:
        h_tm1_i = h_tm1 * rec_dp_mask[0]
@ -1890,13 +1846,15 @@ class LSTMCell(Layer):
        h_tm1_c = h_tm1
        h_tm1_o = h_tm1
      i = self.recurrent_activation(
          x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
          x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
      f = self.recurrent_activation(
          x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
          x_f + K.dot(h_tm1_f,
                      self.recurrent_kernel[:, self.units: self.units * 2]))
      c = f * c_tm1 + i * self.activation(
          x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
          x_c + K.dot(h_tm1_c,
                      self.recurrent_kernel[:, self.units * 2: self.units * 3]))
      o = self.recurrent_activation(
          x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
          x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
    else:
      if 0. < self.dropout < 1.:
        inputs *= dp_mask[0]
@ -1919,7 +1877,9 @@ class LSTMCell(Layer):

    h = o * self.activation(c)
    if 0 < self.dropout + self.recurrent_dropout:
      if training is None:
      if training is None and not context.in_eager_mode():
        # This would be harmless to set in eager mode, but eager tensors
        # disallow setting arbitrary attributes.
        h._uses_learning_phase = True
    return h, [h, c]

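As with the GRU hunks above, only the weight indexing changes here; the LSTM recurrence itself is untouched. With the default activations it reads:

\begin{aligned}
i_t &= \sigma(x_t W_i + h_{t-1} U_i + b_i) \\
f_t &= \sigma(x_t W_f + h_{t-1} U_f + b_f) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(x_t W_c + h_{t-1} U_c + b_c) \\
o_t &= \sigma(x_t W_o + h_{t-1} U_o + b_o) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}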
@ -20,64 +20,66 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer


class SimpleRNNLayerTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_return_sequences_SimpleRNN(self):
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.SimpleRNN,
          kwargs={'units': units,
                  'return_sequences': True},
          input_shape=(num_samples, timesteps, embedding_dim))

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_dynamic_behavior_SimpleRNN(self):
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      layer = keras.layers.SimpleRNN(units, input_shape=(None, embedding_dim))
      model = keras.models.Sequential()
      model.add(layer)
      model.compile('sgd', 'mse')
      model.compile(RMSPropOptimizer(0.01), 'mse')
      x = np.random.random((num_samples, timesteps, embedding_dim))
      y = np.random.random((num_samples, units))
      model.train_on_batch(x, y)

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_dropout_SimpleRNN(self):
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      testing_utils.layer_test(
          keras.layers.SimpleRNN,
          kwargs={'units': units,
                  'dropout': 0.1,
                  'recurrent_dropout': 0.1},
          input_shape=(num_samples, timesteps, embedding_dim))

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_implementation_mode_SimpleRNN(self):
    num_samples = 2
    timesteps = 3
    embedding_dim = 4
    units = 2
    with self.test_session():
      for mode in [0, 1, 2]:
        testing_utils.layer_test(
            keras.layers.SimpleRNN,
            kwargs={'units': units,
                    'implementation': mode},
            input_shape=(num_samples, timesteps, embedding_dim))

  def test_statefulness_SimpleRNN(self):
    num_samples = 2

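A pattern worth noting across these test diffs: every string optimizer ('sgd', 'rmsprop') is replaced by a tf.train-style RMSPropOptimizer instance. The string forms resolve to Keras-side optimizers, which this commit appears to leave graph-only; a TF-native optimizer (wrapped internally as TFOptimizer, whose change appears further down) lets the same compile/train_on_batch path run in both modes. A minimal sketch of the eager-safe form:

import numpy as np

from tensorflow.python.keras._impl import keras
from tensorflow.python.training.rmsprop import RMSPropOptimizer

model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(4,)))
# A tf.train optimizer instance instead of the string 'rmsprop',
# so the training step also works under eager execution.
model.compile(RMSPropOptimizer(0.01), 'mse')
model.train_on_batch(np.random.random((8, 4)), np.random.random((8, 2)))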
@ -20,44 +20,43 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer


class TimeDistributedTest(test.TestCase):

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_timedistributed_dense(self):
    # first, test with Dense layer
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.TimeDistributed(
              keras.layers.Dense(2), input_shape=(3, 4)))
      model.compile(optimizer='rmsprop', loss='mse')
      model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
      model.fit(
          np.random.random((10, 3, 4)),
          np.random.random((10, 3, 2)),
          epochs=1,
          batch_size=10)

      # test config
      model.get_config()

  def test_timedistributed_static_batch_size(self):
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
          keras.layers.TimeDistributed(
              keras.layers.Dense(2), input_shape=(3, 4), batch_size=10))
      model.compile(optimizer='rmsprop', loss='mse')
      model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
      model.fit(
          np.random.random((10, 3, 4)),
          np.random.random((10, 3, 2)),
          epochs=1,
          batch_size=10)

  def test_timedistributed_conv2d(self):
    # test with Conv2D
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
@ -73,7 +72,6 @@ class TimeDistributedTest(test.TestCase):
      model.summary()

  def test_timedistributed_stacked(self):
    # test stacked layers
    with self.test_session():
      model = keras.models.Sequential()
      model.add(
@ -167,7 +165,7 @@ class BidirectionalTest(test.TestCase):
      model.add(
          keras.layers.Bidirectional(
              rnn(output_dim), merge_mode=mode, input_shape=(timesteps, dim)))
      model.compile(loss='mse', optimizer='sgd')
      model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
      model.fit(x, y, epochs=1, batch_size=1)

      # test compute output shape

@ -704,8 +704,10 @@ class TFOptimizer(Optimizer):
    return self.optimizer.compute_gradients(loss, params)

  def get_updates(self, loss, params):
    grads = self.optimizer.compute_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    if not params:
      return self.updates
    grads = self.optimizer.compute_gradients(loss, params)
    opt_update = self.optimizer.apply_gradients(
        grads, global_step=self.iterations)
    self.updates.append(opt_update)

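The reordering in get_updates adds an early return when there are no trainable parameters: the iteration counter still ticks, but no gradient computation or optimizer op is built, presumably to avoid handing apply_gradients an empty grads-and-vars list. The guard pattern in isolation (a sketch with hypothetical names, TF 1.x API):

import tensorflow as tf

def build_updates(optimizer, loss, params, iterations):
  # Always count the step, even for models with nothing to train.
  updates = [tf.assign_add(iterations, 1)]
  if not params:
    # Nothing to optimize: skip compute/apply entirely rather than
    # passing an empty list to apply_gradients.
    return updates
  grads = optimizer.compute_gradients(loss, params)
  updates.append(optimizer.apply_gradients(grads, global_step=iterations))
  return updates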
@ -22,6 +22,7 @@ import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl import keras
from tensorflow.python.training.rmsprop import RMSPropOptimizer
from tensorflow.python.util import tf_inspect


@ -145,7 +146,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
  np.testing.assert_allclose(output, actual_output, rtol=1e-3)

  # test training mode (e.g. useful for dropout tests)
  model.compile('rmsprop', 'mse')
  model.compile(RMSPropOptimizer(0.01), 'mse')
  model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
@ -181,9 +182,5 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
  output = recovered_model.predict(input_data)
  np.testing.assert_allclose(output, actual_output, rtol=1e-3)

  # test training mode (e.g. useful for dropout tests)
  model.compile('rmsprop', 'mse')
  model.train_on_batch(input_data, actual_output)

  # for further checks in the caller function
  return actual_output