Fix tokenization tests and update testing_utils to transfer state between layer creation.
PiperOrigin-RevId: 317379253 Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
This commit is contained in:
parent
23bb7b48a1
commit
9e7df609bc
|
@ -94,7 +94,8 @@ def layer_test(layer_cls,
|
|||
expected_output_shape=None,
|
||||
validate_training=True,
|
||||
adapt_data=None,
|
||||
custom_objects=None):
|
||||
custom_objects=None,
|
||||
test_harness=None):
|
||||
"""Test routine for a layer with a single input and single output.
|
||||
|
||||
Arguments:
|
||||
|
@ -114,6 +115,8 @@ def layer_test(layer_cls,
|
|||
be tested for this layer. This is only relevant for PreprocessingLayers.
|
||||
custom_objects: Optional dictionary mapping name strings to custom objects
|
||||
in the layer class. This is helpful for testing custom layers.
|
||||
test_harness: The Tensorflow test, if any, that this function is being
|
||||
called in.
|
||||
|
||||
Returns:
|
||||
The output data (Numpy array) returned by the layer, for additional
|
||||
|
@ -143,9 +146,15 @@ def layer_test(layer_cls,
|
|||
expected_output_dtype = input_dtype
|
||||
|
||||
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
|
||||
assert_equal = string_test
|
||||
if test_harness:
|
||||
assert_equal = test_harness.assertAllEqual
|
||||
else:
|
||||
assert_equal = string_test
|
||||
else:
|
||||
assert_equal = numeric_test
|
||||
if test_harness:
|
||||
assert_equal = test_harness.assertAllClose
|
||||
else:
|
||||
assert_equal = numeric_test
|
||||
|
||||
# instantiation
|
||||
kwargs = kwargs or {}
|
||||
|
@ -228,6 +237,7 @@ def layer_test(layer_cls,
|
|||
# test training mode (e.g. useful for dropout tests)
|
||||
# Rebuild the model to avoid the graph being reused between predict() and
|
||||
# See b/120160788 for more details. This should be mitigated after 2.0.
|
||||
layer_weights = layer.get_weights() # Get the layer weights BEFORE training.
|
||||
if validate_training:
|
||||
model = models.Model(x, layer(x))
|
||||
if _thread_local_data.run_eagerly is not None:
|
||||
|
@ -252,6 +262,8 @@ def layer_test(layer_cls,
|
|||
model = models.Sequential()
|
||||
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
|
||||
model.add(layer)
|
||||
|
||||
layer.set_weights(layer_weights)
|
||||
actual_output = model.predict(input_data)
|
||||
actual_output_shape = actual_output.shape
|
||||
for expected_dim, actual_dim in zip(computed_output_shape,
|
||||
|
|
Loading…
Reference in New Issue