diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 1928588fea1..cceaabe37a5 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -94,7 +94,8 @@ def layer_test(layer_cls, expected_output_shape=None, validate_training=True, adapt_data=None, - custom_objects=None): + custom_objects=None, + test_harness=None): """Test routine for a layer with a single input and single output. Arguments: @@ -114,6 +115,8 @@ def layer_test(layer_cls, be tested for this layer. This is only relevant for PreprocessingLayers. custom_objects: Optional dictionary mapping name strings to custom objects in the layer class. This is helpful for testing custom layers. + test_harness: The Tensorflow test, if any, that this function is being + called in. Returns: The output data (Numpy array) returned by the layer, for additional @@ -143,9 +146,15 @@ def layer_test(layer_cls, expected_output_dtype = input_dtype if dtypes.as_dtype(expected_output_dtype) == dtypes.string: - assert_equal = string_test + if test_harness: + assert_equal = test_harness.assertAllEqual + else: + assert_equal = string_test else: - assert_equal = numeric_test + if test_harness: + assert_equal = test_harness.assertAllClose + else: + assert_equal = numeric_test # instantiation kwargs = kwargs or {} @@ -228,6 +237,7 @@ def layer_test(layer_cls, # test training mode (e.g. useful for dropout tests) # Rebuild the model to avoid the graph being reused between predict() and # See b/120160788 for more details. This should be mitigated after 2.0. + layer_weights = layer.get_weights() # Get the layer weights BEFORE training. if validate_training: model = models.Model(x, layer(x)) if _thread_local_data.run_eagerly is not None: @@ -252,6 +262,8 @@ def layer_test(layer_cls, model = models.Sequential() model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype)) model.add(layer) + + layer.set_weights(layer_weights) actual_output = model.predict(input_data) actual_output_shape = actual_output.shape for expected_dim, actual_dim in zip(computed_output_shape,