Removes a should_use decorator from tf.Assert to prevent extraneous warning inside functions.
PiperOrigin-RevId: 224061693
This commit is contained in:
parent
0918ff3de1
commit
43a1df9ac3
@ -3250,6 +3250,7 @@ cuda_py_test(
|
|||||||
":util",
|
":util",
|
||||||
":variable_scope",
|
":variable_scope",
|
||||||
":variables",
|
":variables",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,7 +56,6 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_should_use
|
|
||||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -113,7 +112,6 @@ def _summarize_eager(tensor, summarize=None):
|
|||||||
# Assert and Print are special symbols in python, so we must
|
# Assert and Print are special symbols in python, so we must
|
||||||
# use an upper-case version of them.
|
# use an upper-case version of them.
|
||||||
@tf_export("debugging.Assert", "Assert")
|
@tf_export("debugging.Assert", "Assert")
|
||||||
@tf_should_use.should_use_result
|
|
||||||
def Assert(condition, data, summarize=None, name=None):
|
def Assert(condition, data, summarize=None, name=None):
|
||||||
"""Asserts that the given condition is true.
|
"""Asserts that the given condition is true.
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ import numpy as np
|
|||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -1014,6 +1015,18 @@ class AssertTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.evaluate(c)
|
self.evaluate(c)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testAssertInFunction(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def whiny(value):
|
||||||
|
control_flow_ops.Assert(value, ["Raised false"])
|
||||||
|
return constant_op.constant(5)
|
||||||
|
|
||||||
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
|
self.evaluate(whiny(False))
|
||||||
|
|
||||||
|
self.assertAllEqual(whiny(True), 5)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user