Improvements to TFShouldUse
(1) Add an option to raise an exception if an object is not used by the end
of tf.function tracing.
(2) Clean up support in eager mode (by improving introspection).
This requires adding a slow path for checking if an object is an instance
of a subclass of EagerTensor created by TFShouldUseWrapper.
In this case, pull out the wrapped EagerTensor and operate on that.
I also had to move a method from func_graph.py to ops.py to avoid a circular
dependency between func_graph and ops/tensor_array_ops (ick!!)
PiperOrigin-RevId: 278627366
Change-Id: Idb4bb9bb0783f5d3ae66b568f101bd4b0be27b57
This commit is contained in:
parent
87c5874443
commit
f333c7affd
@ -1909,6 +1909,7 @@ class ConcreteFunction(object):
|
||||
|
||||
|
||||
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
|
||||
_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def g(x):
|
||||
old_values = list(values)
|
||||
func_graph.add_exit_callback_to_default_func_graph(append_1)
|
||||
ops.add_exit_callback_to_default_func_graph(append_1)
|
||||
self.assertEqual(old_values, values)
|
||||
return x + 1
|
||||
|
||||
@ -166,7 +166,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def f(x):
|
||||
old_values = list(values)
|
||||
func_graph.add_exit_callback_to_default_func_graph(append_2)
|
||||
ops.add_exit_callback_to_default_func_graph(append_2)
|
||||
self.assertEqual(old_values, values)
|
||||
return tf_g(x)
|
||||
|
||||
@ -179,7 +179,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testCannotAddExitCallbackWhenNotInFunctionScope(self):
|
||||
with self.assertRaisesRegexp(RuntimeError, 'when not building a function.'):
|
||||
func_graph.add_exit_callback_to_default_func_graph(lambda: None)
|
||||
ops.add_exit_callback_to_default_func_graph(lambda: None)
|
||||
|
||||
def testVariable(self):
|
||||
v1 = variables.Variable(1.0)
|
||||
|
||||
@ -1274,36 +1274,3 @@ def dismantle_func_graph(func_graph):
|
||||
"""
|
||||
func_graph.clear_captures()
|
||||
ops.dismantle_graph(func_graph)
|
||||
|
||||
|
||||
def add_exit_callback_to_default_func_graph(fn):
|
||||
"""Add a callback to run when the default function graph goes out of scope.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
@tf.function
|
||||
def fn(x, v):
|
||||
expensive = expensive_object(v)
|
||||
add_exit_callback_to_default_func_graph(lambda: expensive.release())
|
||||
return g(x, expensive)
|
||||
|
||||
fn(x=tf.constant(...), v=...)
|
||||
# `expensive` has been released.
|
||||
```
|
||||
|
||||
Args:
|
||||
fn: A callable that takes no arguments and whose output is ignored.
|
||||
To be executed when exiting func graph scope.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If executed when the current defualt graph is not a FuncGraph,
|
||||
or not currently executing in function creation mode (e.g., if inside
|
||||
an init_scope).
|
||||
"""
|
||||
default_graph = ops.get_default_graph()
|
||||
if not default_graph._building_function: # pylint: disable=protected-access
|
||||
raise RuntimeError(
|
||||
"Cannot add scope exit callbacks when not building a function. "
|
||||
"Default graph: {}".format(default_graph))
|
||||
default_graph._add_scope_exit_callback(fn) # pylint: disable=protected-access
|
||||
|
||||
@ -6597,3 +6597,36 @@ def raise_from_not_ok_status(e, name):
|
||||
# pylint: disable=protected-access
|
||||
six.raise_from(core._status_to_exception(e.code, message), None)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def add_exit_callback_to_default_func_graph(fn):
|
||||
"""Add a callback to run when the default function graph goes out of scope.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
@tf.function
|
||||
def fn(x, v):
|
||||
expensive = expensive_object(v)
|
||||
add_exit_callback_to_default_func_graph(lambda: expensive.release())
|
||||
return g(x, expensive)
|
||||
|
||||
fn(x=tf.constant(...), v=...)
|
||||
# `expensive` has been released.
|
||||
```
|
||||
|
||||
Args:
|
||||
fn: A callable that takes no arguments and whose output is ignored.
|
||||
To be executed when exiting func graph scope.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If executed when the current defualt graph is not a FuncGraph,
|
||||
or not currently executing in function creation mode (e.g., if inside
|
||||
an init_scope).
|
||||
"""
|
||||
default_graph = get_default_graph()
|
||||
if not default_graph._building_function: # pylint: disable=protected-access
|
||||
raise RuntimeError(
|
||||
"Cannot add scope exit callbacks when not building a function. "
|
||||
"Default graph: {}".format(default_graph))
|
||||
default_graph._add_scope_exit_callback(fn) # pylint: disable=protected-access
|
||||
|
||||
@ -376,6 +376,32 @@ static PyObject* TFE_ClearScalarCache();
|
||||
}
|
||||
if (EagerTensor_CheckExact(elem)) {
|
||||
(*$1)[i] = EagerTensor_Handle(elem);
|
||||
} else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
|
||||
// Use equivalent of object.__getattribute__ to get the underlying
|
||||
// tf wrapped EagerTensor (if there is one).
|
||||
tensorflow::Safe_PyObjectPtr tf_should_use_attr(
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyString_InternFromString("_tf_should_use_wrapped_value")
|
||||
#else
|
||||
PyUnicode_InternFromString("_tf_should_use_wrapped_value")
|
||||
#endif
|
||||
);
|
||||
tensorflow::Safe_PyObjectPtr value_attr(
|
||||
PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
|
||||
if (value_attr) {
|
||||
// This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
|
||||
(*$1)[i] = EagerTensor_Handle(value_attr.get());
|
||||
} else {
|
||||
// This is a subclass of EagerTensor that we don't support.
|
||||
PyErr_Clear();
|
||||
SWIG_exception_fail(
|
||||
SWIG_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Saw an object that is an instance of a strict subclass of "
|
||||
"EagerTensor, which is not supported. Item ",
|
||||
i, " is type: ", elem->ob_type->tp_name)
|
||||
.c_str());
|
||||
}
|
||||
} else if (tensorflow::swig::IsTensor(elem)) {
|
||||
// If it isnt an EagerTensor, but is still a Tensor, it must be a graph
|
||||
// tensor.
|
||||
@ -395,7 +421,7 @@ static PyObject* TFE_ClearScalarCache();
|
||||
" with tf.init_scope():\n",
|
||||
" added = my_constant * 2\n",
|
||||
"The graph tensor has name: ",
|
||||
TFE_GetPythonString(name_attr.get())
|
||||
name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>"
|
||||
).c_str());
|
||||
} else {
|
||||
SWIG_exception_fail(
|
||||
@ -403,7 +429,7 @@ static PyObject* TFE_ClearScalarCache();
|
||||
tensorflow::strings::StrCat(
|
||||
"provided list of inputs contains objects other "
|
||||
"than 'EagerTensor'. Item ",
|
||||
i, " is ", elem->ob_type->tp_name).c_str());
|
||||
i, " is type: ", elem->ob_type->tp_name).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,19 +40,25 @@ class _TFShouldUseHelper(object):
|
||||
main issues this wrapper warns about).
|
||||
"""
|
||||
|
||||
def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated,
|
||||
def __init__(self, type_, repr_, stack_frame, error_in_function,
|
||||
warn_in_eager):
|
||||
self._type = type_
|
||||
self._repr = repr_
|
||||
self._stack_frame = stack_frame
|
||||
self._fatal_error_if_unsated = fatal_error_if_unsated
|
||||
if warn_in_eager:
|
||||
self._sated = False
|
||||
self._error_in_function = error_in_function
|
||||
if context.executing_eagerly():
|
||||
# If warn_in_eager, sated == False. Otherwise true.
|
||||
self._sated = not warn_in_eager
|
||||
elif ops.get_default_graph()._building_function: # pylint: disable=protected-access
|
||||
if error_in_function:
|
||||
self._sated = False
|
||||
ops.add_exit_callback_to_default_func_graph(
|
||||
lambda: self._check_sated(raise_error=True))
|
||||
else:
|
||||
self._sated = True
|
||||
else:
|
||||
# If in eager mode or building a function with autodeps, we generally do
|
||||
# not need these warnings since behavior is eager-like.
|
||||
self._sated = (context.executing_eagerly()
|
||||
or ops.get_default_graph()._building_function) # pylint: disable=protected-access
|
||||
# TF1 graph building mode
|
||||
self._sated = False
|
||||
|
||||
def sate(self):
|
||||
self._sated = True
|
||||
@ -61,52 +67,62 @@ class _TFShouldUseHelper(object):
|
||||
self._stack_frame = None
|
||||
self._logging_module = None
|
||||
|
||||
def __del__(self):
|
||||
def _check_sated(self, raise_error):
|
||||
"""Check if the object has been sated."""
|
||||
if self._sated:
|
||||
return
|
||||
if self._fatal_error_if_unsated:
|
||||
logger = tf_logging.fatal
|
||||
else:
|
||||
logger = tf_logging.error
|
||||
creation_stack = ''.join(
|
||||
[line.rstrip()
|
||||
for line in traceback.format_stack(self._stack_frame, limit=5)])
|
||||
logger(
|
||||
'==================================\n'
|
||||
'Object was never used (type %s):\n%s\nIf you want to mark it as '
|
||||
'used call its "mark_used()" method.\nIt was originally created '
|
||||
'here:\n%s\n'
|
||||
'==================================' %
|
||||
(self._type, self._repr, creation_stack))
|
||||
if raise_error:
|
||||
try:
|
||||
raise RuntimeError(
|
||||
'Object was never used (type {}): {}. If you want to mark it as '
|
||||
'used call its "mark_used()" method. It was originally created '
|
||||
'here:\n{}'.format(self._type, self._repr, creation_stack))
|
||||
finally:
|
||||
self.sate()
|
||||
else:
|
||||
tf_logging.error(
|
||||
'==================================\n'
|
||||
'Object was never used (type {}):\n{}\nIf you want to mark it as '
|
||||
'used call its "mark_used()" method.\nIt was originally created '
|
||||
'here:\n{}\n'
|
||||
'=================================='
|
||||
.format(self._type, self._repr, creation_stack))
|
||||
|
||||
def __del__(self):
|
||||
self._check_sated(raise_error=False)
|
||||
|
||||
|
||||
def _new__init__(self, true_value, tf_should_use_helper):
|
||||
def _new__init__(self, wrapped_value, tf_should_use_helper):
|
||||
# pylint: disable=protected-access
|
||||
self._tf_should_use_helper = tf_should_use_helper
|
||||
self._true_value = true_value
|
||||
self._tf_should_use_wrapped_value = wrapped_value
|
||||
|
||||
|
||||
def _new__setattr__(self, key, value):
|
||||
if key in ('_tf_should_use_helper', '_true_value'):
|
||||
if key in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'):
|
||||
return object.__setattr__(self, key, value)
|
||||
return setattr(
|
||||
object.__getattribute__(self, '_true_value'),
|
||||
object.__getattribute__(self, '_tf_should_use_wrapped_value'),
|
||||
key, value)
|
||||
|
||||
|
||||
def _new__getattribute__(self, key):
|
||||
if key not in ('_tf_should_use_helper', '_true_value'):
|
||||
if key not in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'):
|
||||
object.__getattribute__(self, '_tf_should_use_helper').sate()
|
||||
if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'):
|
||||
return object.__getattribute__(self, key)
|
||||
return getattr(object.__getattribute__(self, '_true_value'), key)
|
||||
return getattr(
|
||||
object.__getattribute__(self, '_tf_should_use_wrapped_value'), key)
|
||||
|
||||
|
||||
def _new_mark_used(self, *args, **kwargs):
|
||||
object.__getattribute__(self, '_tf_should_use_helper').sate()
|
||||
try:
|
||||
mu = object.__getattribute__(
|
||||
object.__getattribute__(self, '_true_value'),
|
||||
object.__getattribute__(self, '_tf_should_use_wrapped_value'),
|
||||
'mark_used')
|
||||
return mu(*args, **kwargs)
|
||||
except AttributeError:
|
||||
@ -145,21 +161,29 @@ def _get_wrapper(x, tf_should_use_helper):
|
||||
return copy_tx(x, tf_should_use_helper)
|
||||
|
||||
|
||||
def _add_should_use_warning(x, fatal_error=False, warn_in_eager=False):
|
||||
def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False):
|
||||
"""Wraps object x so that if it is never used, a warning is logged.
|
||||
|
||||
Args:
|
||||
x: Python object.
|
||||
fatal_error: Python bool. If `True`, tf.compat.v1.logging.fatal is raised
|
||||
if the returned value is never used.
|
||||
error_in_function: Python bool. If `True`, a `RuntimeError` is raised
|
||||
if the returned value is never used when created during `tf.function`
|
||||
tracing.
|
||||
warn_in_eager: Python bool. If `True` raise warning if in Eager mode as well
|
||||
as graph.
|
||||
as graph mode.
|
||||
|
||||
Returns:
|
||||
An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)`
|
||||
and is a very shallow wrapper for `x` which logs access into `x`.
|
||||
"""
|
||||
if x is None or x == []: # pylint: disable=g-explicit-bool-comparison
|
||||
if x is None or (isinstance(x, list) and not x):
|
||||
return x
|
||||
|
||||
if context.executing_eagerly() and not warn_in_eager:
|
||||
return x
|
||||
|
||||
if ops.get_default_graph()._building_function and not error_in_function: # pylint: disable=protected-access
|
||||
# We don't currently log warnings in tf.function calls, so just skip it.
|
||||
return x
|
||||
|
||||
# Extract the current frame for later use by traceback printing.
|
||||
@ -172,22 +196,25 @@ def _add_should_use_warning(x, fatal_error=False, warn_in_eager=False):
|
||||
type_=type(x),
|
||||
repr_=repr(x),
|
||||
stack_frame=stack_frame,
|
||||
fatal_error_if_unsated=fatal_error,
|
||||
error_in_function=error_in_function,
|
||||
warn_in_eager=warn_in_eager)
|
||||
|
||||
return _get_wrapper(x, tf_should_use_helper)
|
||||
|
||||
|
||||
def should_use_result(fn=None, warn_in_eager=False):
|
||||
def should_use_result(fn=None, warn_in_eager=False, error_in_function=False):
|
||||
"""Function wrapper that ensures the function's output is used.
|
||||
|
||||
If the output is not used, a `tf.compat.v1.logging.error` is logged.
|
||||
If the output is not used, a `logging.error` is logged. If
|
||||
`error_in_function` is set, then a `RuntimeError` will be raised at the
|
||||
end of function tracing if the output is not used by that point.
|
||||
|
||||
An output is marked as used if any of its attributes are read, modified, or
|
||||
updated. Examples when the output is a `Tensor` include:
|
||||
|
||||
- Using it in any capacity (e.g. `y = t + 0`, `sess.run(t)`)
|
||||
- Accessing a property (e.g. getting `t.name` or `t.op`).
|
||||
- Calling `t.mark_used()`.
|
||||
|
||||
Note, certain behaviors cannot be tracked - for these the object may not
|
||||
be marked as used. Examples include:
|
||||
@ -198,6 +225,7 @@ def should_use_result(fn=None, warn_in_eager=False):
|
||||
Args:
|
||||
fn: The function to wrap.
|
||||
warn_in_eager: Whether to create warnings in Eager as well.
|
||||
error_in_function: Whether to raise an error when creating a tf.function.
|
||||
|
||||
Returns:
|
||||
The wrapped function.
|
||||
@ -205,7 +233,8 @@ def should_use_result(fn=None, warn_in_eager=False):
|
||||
def decorated(fn):
|
||||
def wrapped(*args, **kwargs):
|
||||
return _add_should_use_warning(fn(*args, **kwargs),
|
||||
warn_in_eager=warn_in_eager)
|
||||
warn_in_eager=warn_in_eager,
|
||||
error_in_function=error_in_function)
|
||||
return tf_decorator.make_decorator(
|
||||
target=fn,
|
||||
decorator_func=wrapped,
|
||||
@ -214,45 +243,10 @@ def should_use_result(fn=None, warn_in_eager=False):
|
||||
(fn.__doc__ or '') +
|
||||
('\n\n '
|
||||
'**NOTE** The output of this function should be used. If it is '
|
||||
'not, a warning will be logged. To mark the output as used, '
|
||||
'call its .mark_used() method.')))
|
||||
'not, a warning will be logged or an error may be raised. '
|
||||
'To mark the output as used, call its .mark_used() method.')))
|
||||
|
||||
if fn is not None:
|
||||
return decorated(fn)
|
||||
|
||||
else:
|
||||
return decorated
|
||||
|
||||
|
||||
def must_use_result_or_fatal(fn):
|
||||
"""Function wrapper that ensures the function's output is used.
|
||||
|
||||
If the output is not used, a `tf.compat.v1.logging.fatal` error is raised.
|
||||
|
||||
An output is marked as used if any of its attributes are read, modified, or
|
||||
updated. Examples when the output is a `Tensor` include:
|
||||
|
||||
- Using it in any capacity (e.g. `y = t + 0`, `sess.run(t)`)
|
||||
- Accessing a property (e.g. getting `t.name` or `t.op`).
|
||||
|
||||
Note, certain behaviors cannot be tracked - for these the object may not
|
||||
be marked as used. Examples include:
|
||||
|
||||
- `t != 0`. In this case, comparison is done on types / ids.
|
||||
- `isinstance(t, tf.Tensor)`. Similar to above.
|
||||
|
||||
Args:
|
||||
fn: The function to wrap.
|
||||
|
||||
Returns:
|
||||
The wrapped function.
|
||||
"""
|
||||
def wrapped(*args, **kwargs):
|
||||
return _add_should_use_warning(fn(*args, **kwargs), fatal_error=True)
|
||||
return tf_decorator.make_decorator(
|
||||
fn, wrapped, 'must_use_result_or_fatal',
|
||||
((fn.__doc__ or '') +
|
||||
('\n\n '
|
||||
'**NOTE** The output of this function must be used. If it is not, '
|
||||
'a fatal error will be raised. To mark the output as used, '
|
||||
'call its .mark_used() method.')))
|
||||
|
||||
@ -23,6 +23,8 @@ import contextlib
|
||||
import gc
|
||||
import sys
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
@ -34,50 +36,53 @@ from tensorflow.python.util import tf_should_use
|
||||
def reroute_error():
|
||||
"""Temporarily reroute errors written to tf_logging.error into `captured`."""
|
||||
with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
|
||||
with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal:
|
||||
yield error, fatal
|
||||
yield error
|
||||
|
||||
|
||||
class TfShouldUseTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAddShouldUseWarningWhenNotUsed(self):
|
||||
c = constant_op.constant(0, name='blah0')
|
||||
def in_this_function():
|
||||
h = tf_should_use._add_should_use_warning(c)
|
||||
h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
|
||||
del h
|
||||
with reroute_error() as (error, _):
|
||||
with reroute_error() as error:
|
||||
in_this_function()
|
||||
msg = '\n'.join(error.call_args[0])
|
||||
self.assertIn('Object was never used', msg)
|
||||
self.assertIn('blah0:0', msg)
|
||||
if not context.executing_eagerly():
|
||||
self.assertIn('blah0:0', msg)
|
||||
self.assertIn('in_this_function', msg)
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAddShouldUseFatalWhenNotUsed(self):
|
||||
c = constant_op.constant(0, name='blah0')
|
||||
def testAddShouldUseExceptionInEagerAndFunction(self):
|
||||
def in_this_function():
|
||||
h = tf_should_use._add_should_use_warning(c, fatal_error=True)
|
||||
c = constant_op.constant(0, name='blah0')
|
||||
h = tf_should_use._add_should_use_warning(
|
||||
c, warn_in_eager=True, error_in_function=True)
|
||||
del h
|
||||
with reroute_error() as (_, fatal):
|
||||
in_this_function()
|
||||
msg = '\n'.join(fatal.call_args[0])
|
||||
self.assertIn('Object was never used', msg)
|
||||
self.assertIn('blah0:0', msg)
|
||||
self.assertIn('in_this_function', msg)
|
||||
if context.executing_eagerly():
|
||||
with reroute_error() as error:
|
||||
in_this_function()
|
||||
msg = '\n'.join(error.call_args[0])
|
||||
self.assertIn('Object was never used', msg)
|
||||
self.assertIn('in_this_function', msg)
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
tf_fn_in_this_function = def_function.function(in_this_function)
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, r'Object was never used.*blah0:0'):
|
||||
tf_fn_in_this_function()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
def _testAddShouldUseWarningWhenUsed(self, fn, name):
|
||||
c = constant_op.constant(0, name=name)
|
||||
with reroute_error() as (error, fatal):
|
||||
h = tf_should_use._add_should_use_warning(c)
|
||||
with reroute_error() as error:
|
||||
h = tf_should_use._add_should_use_warning(c, warn_in_eager=True)
|
||||
fn(h)
|
||||
del h
|
||||
error.assert_not_called()
|
||||
fatal.assert_not_called()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAddShouldUseWarningWhenUsedWithAdd(self):
|
||||
def add(h):
|
||||
_ = h + 1
|
||||
@ -85,34 +90,32 @@ class TfShouldUseTest(test.TestCase):
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAddShouldUseWarningWhenUsedWithGetName(self):
|
||||
def get_name(h):
|
||||
_ = h.name
|
||||
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
|
||||
def testAddShouldUseWarningWhenUsedWithGetShape(self):
|
||||
def get_shape(h):
|
||||
_ = h.shape
|
||||
self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name')
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShouldUseResult(self):
|
||||
@tf_should_use.should_use_result
|
||||
@tf_should_use.should_use_result(warn_in_eager=True)
|
||||
def return_const(value):
|
||||
return constant_op.constant(value, name='blah2')
|
||||
with reroute_error() as (error, _):
|
||||
with reroute_error() as error:
|
||||
return_const(0.0)
|
||||
msg = '\n'.join(error.call_args[0])
|
||||
self.assertIn('Object was never used', msg)
|
||||
self.assertIn('blah2:0', msg)
|
||||
if not context.executing_eagerly():
|
||||
self.assertIn('blah2:0', msg)
|
||||
self.assertIn('return_const', msg)
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShouldUseResultWhenNotReallyUsed(self):
|
||||
@tf_should_use.should_use_result
|
||||
@tf_should_use.should_use_result(warn_in_eager=True)
|
||||
def return_const(value):
|
||||
return constant_op.constant(value, name='blah3')
|
||||
with reroute_error() as (error, _):
|
||||
with reroute_error() as error:
|
||||
with self.cached_session():
|
||||
return_const(0.0)
|
||||
# Creating another op and executing it does not mark the
|
||||
@ -121,14 +124,15 @@ class TfShouldUseTest(test.TestCase):
|
||||
self.evaluate(v)
|
||||
msg = '\n'.join(error.call_args[0])
|
||||
self.assertIn('Object was never used', msg)
|
||||
self.assertIn('blah3:0', msg)
|
||||
if not context.executing_eagerly():
|
||||
self.assertIn('blah3:0', msg)
|
||||
self.assertIn('return_const', msg)
|
||||
gc.collect()
|
||||
self.assertFalse(gc.garbage)
|
||||
|
||||
# Tests that mark_used is available in the API.
|
||||
def testMarkUsed(self):
|
||||
@tf_should_use.should_use_result
|
||||
@tf_should_use.should_use_result(warn_in_eager=True)
|
||||
def return_const(value):
|
||||
return constant_op.constant(value, name='blah3')
|
||||
|
||||
|
||||
@ -277,6 +277,16 @@ int IsTensorHelper(PyObject* o) {
|
||||
return check_cache->CachedLookup(o);
|
||||
}
|
||||
|
||||
// Returns 1 if `o` is an EagerTensor.
|
||||
// Returns 0 otherwise.
|
||||
// Returns -1 if an error occurred.
|
||||
int IsEagerTensorHelper(PyObject* o) {
|
||||
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
|
||||
return IsInstanceOfRegisteredType(to_check, "EagerTensor");
|
||||
});
|
||||
return check_cache->CachedLookup(o);
|
||||
}
|
||||
|
||||
// Returns 1 if `o` is a ResourceVariable.
|
||||
// Returns 0 otherwise.
|
||||
// Returns -1 if an error occurred.
|
||||
@ -870,6 +880,7 @@ bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
|
||||
bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
|
||||
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
|
||||
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
|
||||
bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
|
||||
bool IsResourceVariable(PyObject* o) {
|
||||
return IsResourceVariableHelper(o) == 1;
|
||||
}
|
||||
|
||||
@ -131,6 +131,15 @@ bool IsAttrs(PyObject* o);
|
||||
// True if the object is a tensor.
|
||||
bool IsTensor(PyObject* o);
|
||||
|
||||
// Returns a true if its input is an eager.EagerTensor.
|
||||
//
|
||||
// Args:
|
||||
// o: the input to be checked.
|
||||
//
|
||||
// Returns:
|
||||
// True if the object is an eager tensor (or mimicking as one).
|
||||
bool IsEagerTensorSlow(PyObject* o);
|
||||
|
||||
// Returns a true if its input is a ResourceVariable.
|
||||
//
|
||||
// Args:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user