Fix tokenization tests and update testing_utils to transfer state between layer creation.
PiperOrigin-RevId: 317379253 Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
This commit is contained in:
parent
23bb7b48a1
commit
9e7df609bc
|
@ -94,7 +94,8 @@ def layer_test(layer_cls,
|
||||||
expected_output_shape=None,
|
expected_output_shape=None,
|
||||||
validate_training=True,
|
validate_training=True,
|
||||||
adapt_data=None,
|
adapt_data=None,
|
||||||
custom_objects=None):
|
custom_objects=None,
|
||||||
|
test_harness=None):
|
||||||
"""Test routine for a layer with a single input and single output.
|
"""Test routine for a layer with a single input and single output.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
@ -114,6 +115,8 @@ def layer_test(layer_cls,
|
||||||
be tested for this layer. This is only relevant for PreprocessingLayers.
|
be tested for this layer. This is only relevant for PreprocessingLayers.
|
||||||
custom_objects: Optional dictionary mapping name strings to custom objects
|
custom_objects: Optional dictionary mapping name strings to custom objects
|
||||||
in the layer class. This is helpful for testing custom layers.
|
in the layer class. This is helpful for testing custom layers.
|
||||||
|
test_harness: The Tensorflow test, if any, that this function is being
|
||||||
|
called in.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output data (Numpy array) returned by the layer, for additional
|
The output data (Numpy array) returned by the layer, for additional
|
||||||
|
@ -143,7 +146,13 @@ def layer_test(layer_cls,
|
||||||
expected_output_dtype = input_dtype
|
expected_output_dtype = input_dtype
|
||||||
|
|
||||||
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
|
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
|
||||||
|
if test_harness:
|
||||||
|
assert_equal = test_harness.assertAllEqual
|
||||||
|
else:
|
||||||
assert_equal = string_test
|
assert_equal = string_test
|
||||||
|
else:
|
||||||
|
if test_harness:
|
||||||
|
assert_equal = test_harness.assertAllClose
|
||||||
else:
|
else:
|
||||||
assert_equal = numeric_test
|
assert_equal = numeric_test
|
||||||
|
|
||||||
|
@ -228,6 +237,7 @@ def layer_test(layer_cls,
|
||||||
# test training mode (e.g. useful for dropout tests)
|
# test training mode (e.g. useful for dropout tests)
|
||||||
# Rebuild the model to avoid the graph being reused between predict() and
|
# Rebuild the model to avoid the graph being reused between predict() and
|
||||||
# See b/120160788 for more details. This should be mitigated after 2.0.
|
# See b/120160788 for more details. This should be mitigated after 2.0.
|
||||||
|
layer_weights = layer.get_weights() # Get the layer weights BEFORE training.
|
||||||
if validate_training:
|
if validate_training:
|
||||||
model = models.Model(x, layer(x))
|
model = models.Model(x, layer(x))
|
||||||
if _thread_local_data.run_eagerly is not None:
|
if _thread_local_data.run_eagerly is not None:
|
||||||
|
@ -252,6 +262,8 @@ def layer_test(layer_cls,
|
||||||
model = models.Sequential()
|
model = models.Sequential()
|
||||||
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
|
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
|
||||||
model.add(layer)
|
model.add(layer)
|
||||||
|
|
||||||
|
layer.set_weights(layer_weights)
|
||||||
actual_output = model.predict(input_data)
|
actual_output = model.predict(input_data)
|
||||||
actual_output_shape = actual_output.shape
|
actual_output_shape = actual_output.shape
|
||||||
for expected_dim, actual_dim in zip(computed_output_shape,
|
for expected_dim, actual_dim in zip(computed_output_shape,
|
||||||
|
|
Loading…
Reference in New Issue