From 43a1df9ac3df076e92655152500c9e9f5ba399be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Dec 2018 15:54:54 -0800 Subject: [PATCH] Removes a should_use decorator from tf.Assert to prevent extraneous warning inside functions. PiperOrigin-RevId: 224061693 --- tensorflow/python/BUILD | 1 + tensorflow/python/ops/control_flow_ops.py | 2 -- tensorflow/python/ops/control_flow_ops_test.py | 13 +++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a558045e4af..1384a2ea738 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3250,6 +3250,7 @@ cuda_py_test( ":util", ":variable_scope", ":variables", + "//tensorflow/python/eager:def_function", ], ) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index b7e50c1dae5..75c081801a4 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -56,7 +56,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest -from tensorflow.python.util import tf_should_use from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export @@ -113,7 +112,6 @@ def _summarize_eager(tensor, summarize=None): # Assert and Print are special symbols in python, so we must # use an upper-case version of them. @tf_export("debugging.Assert", "Assert") -@tf_should_use.should_use_result def Assert(condition, data, summarize=None, name=None): """Asserts that the given condition is true. diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index c020189ad63..b19ec4bd61b 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -1014,6 +1015,18 @@ class AssertTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(c) + @test_util.run_in_graph_and_eager_modes + def testAssertInFunction(self): + + @def_function.function + def whiny(value): + control_flow_ops.Assert(value, ["Raised false"]) + return constant_op.constant(5) + + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(whiny(False)) + + self.assertAllEqual(whiny(True), 5) if __name__ == "__main__": googletest.main()