diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index c87260e1c23..6c807e61746 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1909,6 +1909,7 @@ class ConcreteFunction(object): _pywrap_utils.RegisterType("Tensor", ops.Tensor) +_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) _pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index ca41f833625..20b21a478e4 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -158,7 +158,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def g(x): old_values = list(values) - func_graph.add_exit_callback_to_default_func_graph(append_1) + ops.add_exit_callback_to_default_func_graph(append_1) self.assertEqual(old_values, values) return x + 1 @@ -166,7 +166,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def f(x): old_values = list(values) - func_graph.add_exit_callback_to_default_func_graph(append_2) + ops.add_exit_callback_to_default_func_graph(append_2) self.assertEqual(old_values, values) return tf_g(x) @@ -179,7 +179,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def testCannotAddExitCallbackWhenNotInFunctionScope(self): with self.assertRaisesRegexp(RuntimeError, 'when not building a function.'): - func_graph.add_exit_callback_to_default_func_graph(lambda: None) + ops.add_exit_callback_to_default_func_graph(lambda: None) def testVariable(self): v1 = variables.Variable(1.0) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 8bca72c0823..57ee6d19cce 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -1274,36 +1274,3 @@ def dismantle_func_graph(func_graph): """ func_graph.clear_captures() ops.dismantle_graph(func_graph) - - -def 
add_exit_callback_to_default_func_graph(fn): - """Add a callback to run when the default function graph goes out of scope. - - Usage: - - ```python - @tf.function - def fn(x, v): - expensive = expensive_object(v) - add_exit_callback_to_default_func_graph(lambda: expensive.release()) - return g(x, expensive) - - fn(x=tf.constant(...), v=...) - # `expensive` has been released. - ``` - - Args: - fn: A callable that takes no arguments and whose output is ignored. - To be executed when exiting func graph scope. - - Raises: - RuntimeError: If executed when the current defualt graph is not a FuncGraph, - or not currently executing in function creation mode (e.g., if inside - an init_scope). - """ - default_graph = ops.get_default_graph() - if not default_graph._building_function: # pylint: disable=protected-access - raise RuntimeError( - "Cannot add scope exit callbacks when not building a function. " - "Default graph: {}".format(default_graph)) - default_graph._add_scope_exit_callback(fn) # pylint: disable=protected-access diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 9f0ddef75f5..8a273e834be 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -6597,3 +6597,36 @@ def raise_from_not_ok_status(e, name): # pylint: disable=protected-access six.raise_from(core._status_to_exception(e.code, message), None) # pylint: enable=protected-access + + +def add_exit_callback_to_default_func_graph(fn): + """Add a callback to run when the default function graph goes out of scope. + + Usage: + + ```python + @tf.function + def fn(x, v): + expensive = expensive_object(v) + add_exit_callback_to_default_func_graph(lambda: expensive.release()) + return g(x, expensive) + + fn(x=tf.constant(...), v=...) + # `expensive` has been released. + ``` + + Args: + fn: A callable that takes no arguments and whose output is ignored. + To be executed when exiting func graph scope. 
+
+  Raises:
+    RuntimeError: If executed when the current default graph is not a FuncGraph,
+    or not currently executing in function creation mode (e.g., if inside
+    an init_scope).
+  """
+  default_graph = get_default_graph()
+  if not default_graph._building_function:  # pylint: disable=protected-access
+    raise RuntimeError(
+        "Cannot add scope exit callbacks when not building a function. "
+        "Default graph: {}".format(default_graph))
+  default_graph._add_scope_exit_callback(fn)  # pylint: disable=protected-access
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 7a36452396a..0473c342c26 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -376,6 +376,32 @@ static PyObject* TFE_ClearScalarCache();
       }
       if (EagerTensor_CheckExact(elem)) {
         (*$1)[i] = EagerTensor_Handle(elem);
+      } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
+        // Use equivalent of object.__getattribute__ to get the underlying
+        // tf wrapped EagerTensor (if there is one).
+        tensorflow::Safe_PyObjectPtr tf_should_use_attr(
+#if PY_MAJOR_VERSION < 3
+            PyString_InternFromString("_tf_should_use_wrapped_value")
+#else
+            PyUnicode_InternFromString("_tf_should_use_wrapped_value")
+#endif
+        );
+        tensorflow::Safe_PyObjectPtr value_attr(
+            PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
+        if (value_attr) {
+          // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
+          (*$1)[i] = EagerTensor_Handle(value_attr.get());
+        } else {
+          // This is a subclass of EagerTensor that we don't support.
+          PyErr_Clear();
+          SWIG_exception_fail(
+              SWIG_TypeError,
+              tensorflow::strings::StrCat(
+                  "Saw an object that is an instance of a strict subclass of "
+                  "EagerTensor, which is not supported. Item ",
+                  i, " is type: ", elem->ob_type->tp_name)
+                  .c_str());
+        }
       } else if (tensorflow::swig::IsTensor(elem)) {
         // If it isnt an EagerTensor, but is still a Tensor, it must be a graph
         // tensor.
@@ -395,7 +421,7 @@ static PyObject* TFE_ClearScalarCache(); " with tf.init_scope():\n", " added = my_constant * 2\n", "The graph tensor has name: ", - TFE_GetPythonString(name_attr.get()) + name_attr ? TFE_GetPythonString(name_attr.get()) : "" ).c_str()); } else { SWIG_exception_fail( @@ -403,7 +429,7 @@ static PyObject* TFE_ClearScalarCache(); tensorflow::strings::StrCat( "provided list of inputs contains objects other " "than 'EagerTensor'. Item ", - i, " is ", elem->ob_type->tp_name).c_str()); + i, " is type: ", elem->ob_type->tp_name).c_str()); } } } diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index f923b3403d0..f8c480ed1ac 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -40,19 +40,25 @@ class _TFShouldUseHelper(object): main issues this wrapper warns about). """ - def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated, + def __init__(self, type_, repr_, stack_frame, error_in_function, warn_in_eager): self._type = type_ self._repr = repr_ self._stack_frame = stack_frame - self._fatal_error_if_unsated = fatal_error_if_unsated - if warn_in_eager: - self._sated = False + self._error_in_function = error_in_function + if context.executing_eagerly(): + # If warn_in_eager, sated == False. Otherwise true. + self._sated = not warn_in_eager + elif ops.get_default_graph()._building_function: # pylint: disable=protected-access + if error_in_function: + self._sated = False + ops.add_exit_callback_to_default_func_graph( + lambda: self._check_sated(raise_error=True)) + else: + self._sated = True else: - # If in eager mode or building a function with autodeps, we generally do - # not need these warnings since behavior is eager-like. 
- self._sated = (context.executing_eagerly() - or ops.get_default_graph()._building_function) # pylint: disable=protected-access + # TF1 graph building mode + self._sated = False def sate(self): self._sated = True @@ -61,52 +67,62 @@ class _TFShouldUseHelper(object): self._stack_frame = None self._logging_module = None - def __del__(self): + def _check_sated(self, raise_error): + """Check if the object has been sated.""" if self._sated: return - if self._fatal_error_if_unsated: - logger = tf_logging.fatal - else: - logger = tf_logging.error creation_stack = ''.join( [line.rstrip() for line in traceback.format_stack(self._stack_frame, limit=5)]) - logger( - '==================================\n' - 'Object was never used (type %s):\n%s\nIf you want to mark it as ' - 'used call its "mark_used()" method.\nIt was originally created ' - 'here:\n%s\n' - '==================================' % - (self._type, self._repr, creation_stack)) + if raise_error: + try: + raise RuntimeError( + 'Object was never used (type {}): {}. If you want to mark it as ' + 'used call its "mark_used()" method. 
It was originally created ' + 'here:\n{}'.format(self._type, self._repr, creation_stack)) + finally: + self.sate() + else: + tf_logging.error( + '==================================\n' + 'Object was never used (type {}):\n{}\nIf you want to mark it as ' + 'used call its "mark_used()" method.\nIt was originally created ' + 'here:\n{}\n' + '==================================' + .format(self._type, self._repr, creation_stack)) + + def __del__(self): + self._check_sated(raise_error=False) -def _new__init__(self, true_value, tf_should_use_helper): +def _new__init__(self, wrapped_value, tf_should_use_helper): # pylint: disable=protected-access self._tf_should_use_helper = tf_should_use_helper - self._true_value = true_value + self._tf_should_use_wrapped_value = wrapped_value def _new__setattr__(self, key, value): - if key in ('_tf_should_use_helper', '_true_value'): + if key in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'): return object.__setattr__(self, key, value) return setattr( - object.__getattribute__(self, '_true_value'), + object.__getattribute__(self, '_tf_should_use_wrapped_value'), key, value) def _new__getattribute__(self, key): - if key not in ('_tf_should_use_helper', '_true_value'): + if key not in ('_tf_should_use_helper', '_tf_should_use_wrapped_value'): object.__getattribute__(self, '_tf_should_use_helper').sate() if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'): return object.__getattribute__(self, key) - return getattr(object.__getattribute__(self, '_true_value'), key) + return getattr( + object.__getattribute__(self, '_tf_should_use_wrapped_value'), key) def _new_mark_used(self, *args, **kwargs): object.__getattribute__(self, '_tf_should_use_helper').sate() try: mu = object.__getattribute__( - object.__getattribute__(self, '_true_value'), + object.__getattribute__(self, '_tf_should_use_wrapped_value'), 'mark_used') return mu(*args, **kwargs) except AttributeError: @@ -145,21 +161,29 @@ def _get_wrapper(x, 
tf_should_use_helper): return copy_tx(x, tf_should_use_helper) -def _add_should_use_warning(x, fatal_error=False, warn_in_eager=False): +def _add_should_use_warning(x, error_in_function=False, warn_in_eager=False): """Wraps object x so that if it is never used, a warning is logged. Args: x: Python object. - fatal_error: Python bool. If `True`, tf.compat.v1.logging.fatal is raised - if the returned value is never used. + error_in_function: Python bool. If `True`, a `RuntimeError` is raised + if the returned value is never used when created during `tf.function` + tracing. warn_in_eager: Python bool. If `True` raise warning if in Eager mode as well - as graph. + as graph mode. Returns: An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ - if x is None or x == []: # pylint: disable=g-explicit-bool-comparison + if x is None or (isinstance(x, list) and not x): + return x + + if context.executing_eagerly() and not warn_in_eager: + return x + + if ops.get_default_graph()._building_function and not error_in_function: # pylint: disable=protected-access + # We don't currently log warnings in tf.function calls, so just skip it. return x # Extract the current frame for later use by traceback printing. @@ -172,22 +196,25 @@ def _add_should_use_warning(x, fatal_error=False, warn_in_eager=False): type_=type(x), repr_=repr(x), stack_frame=stack_frame, - fatal_error_if_unsated=fatal_error, + error_in_function=error_in_function, warn_in_eager=warn_in_eager) return _get_wrapper(x, tf_should_use_helper) -def should_use_result(fn=None, warn_in_eager=False): +def should_use_result(fn=None, warn_in_eager=False, error_in_function=False): """Function wrapper that ensures the function's output is used. - If the output is not used, a `tf.compat.v1.logging.error` is logged. + If the output is not used, a `logging.error` is logged. 
If + `error_in_function` is set, then a `RuntimeError` will be raised at the + end of function tracing if the output is not used by that point. An output is marked as used if any of its attributes are read, modified, or updated. Examples when the output is a `Tensor` include: - Using it in any capacity (e.g. `y = t + 0`, `sess.run(t)`) - Accessing a property (e.g. getting `t.name` or `t.op`). + - Calling `t.mark_used()`. Note, certain behaviors cannot be tracked - for these the object may not be marked as used. Examples include: @@ -198,6 +225,7 @@ def should_use_result(fn=None, warn_in_eager=False): Args: fn: The function to wrap. warn_in_eager: Whether to create warnings in Eager as well. + error_in_function: Whether to raise an error when creating a tf.function. Returns: The wrapped function. @@ -205,7 +233,8 @@ def should_use_result(fn=None, warn_in_eager=False): def decorated(fn): def wrapped(*args, **kwargs): return _add_should_use_warning(fn(*args, **kwargs), - warn_in_eager=warn_in_eager) + warn_in_eager=warn_in_eager, + error_in_function=error_in_function) return tf_decorator.make_decorator( target=fn, decorator_func=wrapped, @@ -214,45 +243,10 @@ def should_use_result(fn=None, warn_in_eager=False): (fn.__doc__ or '') + ('\n\n ' '**NOTE** The output of this function should be used. If it is ' - 'not, a warning will be logged. To mark the output as used, ' - 'call its .mark_used() method.'))) + 'not, a warning will be logged or an error may be raised. ' + 'To mark the output as used, call its .mark_used() method.'))) if fn is not None: return decorated(fn) - else: return decorated - - -def must_use_result_or_fatal(fn): - """Function wrapper that ensures the function's output is used. - - If the output is not used, a `tf.compat.v1.logging.fatal` error is raised. - - An output is marked as used if any of its attributes are read, modified, or - updated. Examples when the output is a `Tensor` include: - - - Using it in any capacity (e.g. 
`y = t + 0`, `sess.run(t)`) - - Accessing a property (e.g. getting `t.name` or `t.op`). - - Note, certain behaviors cannot be tracked - for these the object may not - be marked as used. Examples include: - - - `t != 0`. In this case, comparison is done on types / ids. - - `isinstance(t, tf.Tensor)`. Similar to above. - - Args: - fn: The function to wrap. - - Returns: - The wrapped function. - """ - def wrapped(*args, **kwargs): - return _add_should_use_warning(fn(*args, **kwargs), fatal_error=True) - return tf_decorator.make_decorator( - fn, wrapped, 'must_use_result_or_fatal', - ((fn.__doc__ or '') + - ('\n\n ' - '**NOTE** The output of this function must be used. If it is not, ' - 'a fatal error will be raised. To mark the output as used, ' - 'call its .mark_used() method.'))) diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 65d848cf2a5..bb50edfa857 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -23,6 +23,8 @@ import contextlib import gc import sys +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -34,50 +36,53 @@ from tensorflow.python.util import tf_should_use def reroute_error(): """Temporarily reroute errors written to tf_logging.error into `captured`.""" with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: - with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal: - yield error, fatal + yield error class TfShouldUseTest(test.TestCase): - @test_util.run_deprecated_v1 def testAddShouldUseWarningWhenNotUsed(self): c = constant_op.constant(0, name='blah0') def in_this_function(): - h = tf_should_use._add_should_use_warning(c) + h = tf_should_use._add_should_use_warning(c, warn_in_eager=True) del h - with 
reroute_error() as (error, _): + with reroute_error() as error: in_this_function() msg = '\n'.join(error.call_args[0]) self.assertIn('Object was never used', msg) - self.assertIn('blah0:0', msg) + if not context.executing_eagerly(): + self.assertIn('blah0:0', msg) self.assertIn('in_this_function', msg) self.assertFalse(gc.garbage) - @test_util.run_deprecated_v1 - def testAddShouldUseFatalWhenNotUsed(self): - c = constant_op.constant(0, name='blah0') + def testAddShouldUseExceptionInEagerAndFunction(self): def in_this_function(): - h = tf_should_use._add_should_use_warning(c, fatal_error=True) + c = constant_op.constant(0, name='blah0') + h = tf_should_use._add_should_use_warning( + c, warn_in_eager=True, error_in_function=True) del h - with reroute_error() as (_, fatal): - in_this_function() - msg = '\n'.join(fatal.call_args[0]) - self.assertIn('Object was never used', msg) - self.assertIn('blah0:0', msg) - self.assertIn('in_this_function', msg) + if context.executing_eagerly(): + with reroute_error() as error: + in_this_function() + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('in_this_function', msg) + self.assertFalse(gc.garbage) + + tf_fn_in_this_function = def_function.function(in_this_function) + with self.assertRaisesRegexp( + RuntimeError, r'Object was never used.*blah0:0'): + tf_fn_in_this_function() self.assertFalse(gc.garbage) def _testAddShouldUseWarningWhenUsed(self, fn, name): c = constant_op.constant(0, name=name) - with reroute_error() as (error, fatal): - h = tf_should_use._add_should_use_warning(c) + with reroute_error() as error: + h = tf_should_use._add_should_use_warning(c, warn_in_eager=True) fn(h) del h error.assert_not_called() - fatal.assert_not_called() - @test_util.run_deprecated_v1 def testAddShouldUseWarningWhenUsedWithAdd(self): def add(h): _ = h + 1 @@ -85,34 +90,32 @@ class TfShouldUseTest(test.TestCase): gc.collect() self.assertFalse(gc.garbage) - @test_util.run_deprecated_v1 - 
def testAddShouldUseWarningWhenUsedWithGetName(self): - def get_name(h): - _ = h.name - self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') + def testAddShouldUseWarningWhenUsedWithGetShape(self): + def get_shape(h): + _ = h.shape + self._testAddShouldUseWarningWhenUsed(get_shape, name='blah_get_name') gc.collect() self.assertFalse(gc.garbage) - @test_util.run_deprecated_v1 def testShouldUseResult(self): - @tf_should_use.should_use_result + @tf_should_use.should_use_result(warn_in_eager=True) def return_const(value): return constant_op.constant(value, name='blah2') - with reroute_error() as (error, _): + with reroute_error() as error: return_const(0.0) msg = '\n'.join(error.call_args[0]) self.assertIn('Object was never used', msg) - self.assertIn('blah2:0', msg) + if not context.executing_eagerly(): + self.assertIn('blah2:0', msg) self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) - @test_util.run_deprecated_v1 def testShouldUseResultWhenNotReallyUsed(self): - @tf_should_use.should_use_result + @tf_should_use.should_use_result(warn_in_eager=True) def return_const(value): return constant_op.constant(value, name='blah3') - with reroute_error() as (error, _): + with reroute_error() as error: with self.cached_session(): return_const(0.0) # Creating another op and executing it does not mark the @@ -121,14 +124,15 @@ class TfShouldUseTest(test.TestCase): self.evaluate(v) msg = '\n'.join(error.call_args[0]) self.assertIn('Object was never used', msg) - self.assertIn('blah3:0', msg) + if not context.executing_eagerly(): + self.assertIn('blah3:0', msg) self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) # Tests that mark_used is available in the API. 
def testMarkUsed(self): - @tf_should_use.should_use_result + @tf_should_use.should_use_result(warn_in_eager=True) def return_const(value): return constant_op.constant(value, name='blah3') diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 270a582783e..d1e43c92164 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -277,6 +277,16 @@ int IsTensorHelper(PyObject* o) { return check_cache->CachedLookup(o); } +// Returns 1 if `o` is an EagerTensor. +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsEagerTensorHelper(PyObject* o) { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + return IsInstanceOfRegisteredType(to_check, "EagerTensor"); + }); + return check_cache->CachedLookup(o); +} + // Returns 1 if `o` is a ResourceVariable. // Returns 0 otherwise. // Returns -1 if an error occurred. @@ -870,6 +880,7 @@ bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; } bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; } +bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; } bool IsResourceVariable(PyObject* o) { return IsResourceVariableHelper(o) == 1; } diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index fc5036a779c..7cd4b0cb495 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -131,6 +131,15 @@ bool IsAttrs(PyObject* o); // True if the object is a tensor. bool IsTensor(PyObject* o); +// Returns a true if its input is an eager.EagerTensor. +// +// Args: +// o: the input to be checked. +// +// Returns: +// True if the object is an eager tensor (or mimicking as one). +bool IsEagerTensorSlow(PyObject* o); + // Returns a true if its input is a ResourceVariable. // // Args: