Removes a should_use decorator from tf.Assert to prevent extraneous warning inside functions.

PiperOrigin-RevId: 224061693
This commit is contained in:
A. Unique TensorFlower 2018-12-04 15:54:54 -08:00 committed by TensorFlower Gardener
parent 0918ff3de1
commit 43a1df9ac3
3 changed files with 14 additions and 2 deletions

View File

@ -3250,6 +3250,7 @@ cuda_py_test(
":util",
":variable_scope",
":variables",
"//tensorflow/python/eager:def_function",
],
)

View File

@ -56,7 +56,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
@ -113,7 +112,6 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -1014,6 +1015,18 @@ class AssertTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(c)
@test_util.run_in_graph_and_eager_modes
def testAssertInFunction(self):
@def_function.function
def whiny(value):
control_flow_ops.Assert(value, ["Raised false"])
return constant_op.constant(5)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(whiny(False))
self.assertAllEqual(whiny(True), 5)
if __name__ == "__main__":
googletest.main()