Removes a should_use decorator from tf.Assert to prevent extraneous warning inside functions.

PiperOrigin-RevId: 224061693
This commit is contained in:
A. Unique TensorFlower 2018-12-04 15:54:54 -08:00 committed by TensorFlower Gardener
parent 0918ff3de1
commit 43a1df9ac3
3 changed files with 14 additions and 2 deletions

View File

@ -3250,6 +3250,7 @@ cuda_py_test(
":util", ":util",
":variable_scope", ":variable_scope",
":variables", ":variables",
"//tensorflow/python/eager:def_function",
], ],
) )

View File

@ -56,7 +56,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -113,7 +112,6 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must # Assert and Print are special symbols in python, so we must
# use an upper-case version of them. # use an upper-case version of them.
@tf_export("debugging.Assert", "Assert") @tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None): def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true. """Asserts that the given condition is true.

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -1014,6 +1015,18 @@ class AssertTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(c) self.evaluate(c)
@test_util.run_in_graph_and_eager_modes
def testAssertInFunction(self):
@def_function.function
def whiny(value):
control_flow_ops.Assert(value, ["Raised false"])
return constant_op.constant(5)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(whiny(False))
self.assertAllEqual(whiny(True), 5)
if __name__ == "__main__": if __name__ == "__main__":
googletest.main() googletest.main()