Fix tokenization tests and update testing_utils to transfer state between layer creation.

PiperOrigin-RevId: 317379253
Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
This commit is contained in:
A. Unique TensorFlower 2020-06-19 14:21:31 -07:00 committed by TensorFlower Gardener
parent 23bb7b48a1
commit 9e7df609bc
1 changed files with 15 additions and 3 deletions

View File

@ -94,7 +94,8 @@ def layer_test(layer_cls,
expected_output_shape=None, expected_output_shape=None,
validate_training=True, validate_training=True,
adapt_data=None, adapt_data=None,
custom_objects=None): custom_objects=None,
test_harness=None):
"""Test routine for a layer with a single input and single output. """Test routine for a layer with a single input and single output.
Arguments: Arguments:
@ -114,6 +115,8 @@ def layer_test(layer_cls,
be tested for this layer. This is only relevant for PreprocessingLayers. be tested for this layer. This is only relevant for PreprocessingLayers.
custom_objects: Optional dictionary mapping name strings to custom objects custom_objects: Optional dictionary mapping name strings to custom objects
in the layer class. This is helpful for testing custom layers. in the layer class. This is helpful for testing custom layers.
test_harness: The Tensorflow test, if any, that this function is being
called in.
Returns: Returns:
The output data (Numpy array) returned by the layer, for additional The output data (Numpy array) returned by the layer, for additional
@ -143,9 +146,15 @@ def layer_test(layer_cls,
expected_output_dtype = input_dtype expected_output_dtype = input_dtype
if dtypes.as_dtype(expected_output_dtype) == dtypes.string: if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
assert_equal = string_test if test_harness:
assert_equal = test_harness.assertAllEqual
else:
assert_equal = string_test
else: else:
assert_equal = numeric_test if test_harness:
assert_equal = test_harness.assertAllClose
else:
assert_equal = numeric_test
# instantiation # instantiation
kwargs = kwargs or {} kwargs = kwargs or {}
@ -228,6 +237,7 @@ def layer_test(layer_cls,
# test training mode (e.g. useful for dropout tests) # test training mode (e.g. useful for dropout tests)
# Rebuild the model to avoid the graph being reused between predict() and # Rebuild the model to avoid the graph being reused between predict() and
# See b/120160788 for more details. This should be mitigated after 2.0. # See b/120160788 for more details. This should be mitigated after 2.0.
layer_weights = layer.get_weights() # Get the layer weights BEFORE training.
if validate_training: if validate_training:
model = models.Model(x, layer(x)) model = models.Model(x, layer(x))
if _thread_local_data.run_eagerly is not None: if _thread_local_data.run_eagerly is not None:
@ -252,6 +262,8 @@ def layer_test(layer_cls,
model = models.Sequential() model = models.Sequential()
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype)) model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
model.add(layer) model.add(layer)
layer.set_weights(layer_weights)
actual_output = model.predict(input_data) actual_output = model.predict(input_data)
actual_output_shape = actual_output.shape actual_output_shape = actual_output.shape
for expected_dim, actual_dim in zip(computed_output_shape, for expected_dim, actual_dim in zip(computed_output_shape,