Fix tokenization tests and update testing_utils to transfer state between layer creation.

PiperOrigin-RevId: 317379253
Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
This commit is contained in:
A. Unique TensorFlower 2020-06-19 14:21:31 -07:00 committed by TensorFlower Gardener
parent 23bb7b48a1
commit 9e7df609bc
1 changed files with 15 additions and 3 deletions

View File

@ -94,7 +94,8 @@ def layer_test(layer_cls,
expected_output_shape=None,
validate_training=True,
adapt_data=None,
custom_objects=None):
custom_objects=None,
test_harness=None):
"""Test routine for a layer with a single input and single output.
Arguments:
@ -114,6 +115,8 @@ def layer_test(layer_cls,
be tested for this layer. This is only relevant for PreprocessingLayers.
custom_objects: Optional dictionary mapping name strings to custom objects
in the layer class. This is helpful for testing custom layers.
test_harness: The Tensorflow test, if any, that this function is being
called in.
Returns:
The output data (Numpy array) returned by the layer, for additional
@ -143,7 +146,13 @@ def layer_test(layer_cls,
expected_output_dtype = input_dtype
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
if test_harness:
assert_equal = test_harness.assertAllEqual
else:
assert_equal = string_test
else:
if test_harness:
assert_equal = test_harness.assertAllClose
else:
assert_equal = numeric_test
@ -228,6 +237,7 @@ def layer_test(layer_cls,
# test training mode (e.g. useful for dropout tests)
# Rebuild the model to avoid the graph being reused between predict() and
# See b/120160788 for more details. This should be mitigated after 2.0.
layer_weights = layer.get_weights() # Get the layer weights BEFORE training.
if validate_training:
model = models.Model(x, layer(x))
if _thread_local_data.run_eagerly is not None:
@ -252,6 +262,8 @@ def layer_test(layer_cls,
model = models.Sequential()
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
model.add(layer)
layer.set_weights(layer_weights)
actual_output = model.predict(input_data)
actual_output_shape = actual_output.shape
for expected_dim, actual_dim in zip(computed_output_shape,