Fix bug inside _WeakrefSelf that was effectively suppressing autograph for methods.
PiperOrigin-RevId: 225412615
parent bd1aaf9f95
commit d501a62aae
@@ -1223,19 +1223,18 @@ def validate_signature(signature):
 def defun(func=None, input_signature=None, autograph=True):
   """Compiles a Python function into a callable TensorFlow graph.

-  `defun` (short for "define function") trace-compiles a Python function
+  `defun` (short for "define function") compiles a Python function
   composed of TensorFlow operations into a callable that executes a `tf.Graph`
   containing those operations. The callable produced by `defun` contains only
   the subgraph of TensorFlow operations that were executed when the Python
   function was called with a particular input signature, defined as a list
   of the shapes and dtypes of the Python function's Tensor-valued arguments and
-  the values of its non-Tensor Python objects. In particular, `defun` is _not_ a
-  compiler for arbitrary Python code.
+  the values of its non-Tensor Python objects.

   When eager execution is enabled, the ability to create graphs from Python
   functions makes it possible to incrementally trade off debugability and
   interactivity for performance. Functions compiled with `defun` cannot be
-  inspected with `pdb` and `print` statements; however, executing a graph
+  inspected with `pdb`; however, executing a graph
   generated by `defun` sometimes takes less time and memory than eagerly
   executing the corresponding Python function, since specifying computations as
   graphs allows for optimizations like automatic buffer reuse and
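As an illustration of the trade-off described above, here is a minimal sketch of wrapping an eager function with `tf.contrib.eager.defun` (assuming TensorFlow 1.x with eager execution enabled; `dense_layer` and the shapes are hypothetical):

```python
import tensorflow as tf

tf.enable_eager_execution()

def dense_layer(x, w, b):
  # Plain TensorFlow ops; defun traces them into a tf.Graph.
  return tf.nn.relu(tf.matmul(x, w) + b)

# Compile the Python function into a graph-backed callable.
dense_layer_graph = tf.contrib.eager.defun(dense_layer)

x = tf.random_normal((8, 16))
w = tf.random_normal((16, 4))
b = tf.zeros((4,))

# The first call traces the function for this input signature; later calls
# with the same shapes and dtypes reuse the cached graph.
y = dense_layer_graph(x, w, b)
print(y.shape)  # (8, 4)
```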
@@ -1326,6 +1325,7 @@ def defun(func=None, input_signature=None, autograph=True):
   outer graph otherwise.

   _Input Signatures_
+
   By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
   for every unique sequence of the shapes and dtypes of Tensor arguments and
   the values of Python objects it is invoked with. For example, calling
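To make the "separate graph per input signature" behavior concrete, a small sketch (hedged: the exact retracing rules depend on the TensorFlow version) that counts traces with a Python side-effect:

```python
import tensorflow as tf

tf.enable_eager_execution()

trace_count = [0]

def square_sum(x):
  # Python side-effect: runs only when defun traces a new graph.
  trace_count[0] += 1
  return tf.reduce_sum(x * x)

square_sum_graph = tf.contrib.eager.defun(square_sum)

square_sum_graph(tf.ones((2, 2)))  # new shape -> trace
square_sum_graph(tf.ones((2, 2)))  # same signature -> cached graph
square_sum_graph(tf.ones((3, 3)))  # new shape -> another trace
print(trace_count[0])  # expected: 2
```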
@@ -1384,6 +1384,7 @@ def defun(func=None, input_signature=None, autograph=True):
   Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

   _Tracing_
+
   Be aware that because `F` only logs TensorFlow operations, all the other
   Python code that `f` executes will only shape the _construction_ of the graphs
   that `F` executes: the Python code won't be executed when the graphs
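A sketch of the tracing behavior this section describes (TF 1.x eager mode assumed): the Python `print` runs while the graph is being constructed, not each time the compiled callable executes:

```python
import tensorflow as tf

tf.enable_eager_execution()

def scale(x):
  print("tracing now")  # executes at graph-construction time only
  return 2.0 * x

scale_graph = tf.contrib.eager.defun(scale)

scale_graph(tf.constant(1.0))  # prints "tracing now" once, during tracing
scale_graph(tf.constant(2.0))  # same signature: cached graph, nothing printed
```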
@@ -1409,6 +1410,7 @@ def defun(func=None, input_signature=None, autograph=True):
   replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.

   _Python Side-Effects_
+
   A corollary of the previous discussion on tracing is the following: If a
   Python function `f` has Python side-effects, then executing `f` multiple times
   will not necessarily be semantically equivalent to executing `F =
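To illustrate the `np.random.randn` versus `tf.random_normal` point, a hedged sketch: the NumPy draw is evaluated once at trace time and baked into the graph, while the TensorFlow op samples afresh on every execution:

```python
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

def noisy_numpy(x):
  # Evaluated at trace time; the same values are reused on later calls.
  return x + np.random.randn(5, 5)

def noisy_tf(x):
  # Staged as a graph op; a fresh sample is drawn on every call.
  return x + tf.random_normal((5, 5))

noisy_numpy_graph = tf.contrib.eager.defun(noisy_numpy)
noisy_tf_graph = tf.contrib.eager.defun(noisy_tf)

x = tf.zeros((5, 5))
a, b = noisy_numpy_graph(x), noisy_numpy_graph(x)
print(bool(tf.reduce_all(tf.equal(a, b)).numpy()))  # True: constant folded in

c, d = noisy_tf_graph(x), noisy_tf_graph(x)
print(bool(tf.reduce_all(tf.equal(c, d)).numpy()))  # False (almost surely)
```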
@@ -1416,7 +1418,8 @@ def defun(func=None, input_signature=None, autograph=True):
   that `defun` only captures the subgraph of TensorFlow operations that is
   constructed when `f` is called in a graph-building context.

-  _Python Control Flow_.
+  _Python Control Flow_
+
   The structure of many machine learning computations depend upon whether one is
   training or validating, and it is common to nest specialized logic under `if
   training:` blocks. By mapping each input signature to a unique graph, `defun`
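A sketch of the `if training:` specialization just described (TF 1.x eager mode assumed; `lossy_identity` is a hypothetical stand-in in the spirit of the docstring's `lossy_matmul` example): because `training` is a plain Python bool, each value selects its own traced graph:

```python
import tensorflow as tf

tf.enable_eager_execution()

def lossy_identity(x, training=True):
  # A Python `if`: resolved at trace time, once per Python value of `training`.
  if training:
    return tf.nn.dropout(x, keep_prob=0.5)
  return x

lossy_identity_graph = tf.contrib.eager.defun(lossy_identity)

x = tf.ones((4, 4))
train_out = lossy_identity_graph(x, training=True)   # graph with dropout
exact_out = lossy_identity_graph(x, training=False)  # separate graph, no dropout
```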
@@ -1445,27 +1448,26 @@ def defun(func=None, input_signature=None, autograph=True):
   exact_outputs = lossy_matmul(W, x, training=False)
   ```

-  On the other hand, because `defun` generates graphs by tracing and not by
-  source code analysis, it fully unrolls Python `for` and `while` loops,
-  potentially creating large graphs. If your Python function has native loops
-  that run for many iterations, consider replacing them with `tf.while_loop`
-  operations.
+  _TensorFlow Control Flow_

-  When constructing graphs, `tf.Tensor` objects cannot be used as Python
-  `bool` objects. This means, for example, that you should replace code in `f`
-  resembling
+  When `autograph` is `True`, data-dependent control flow is allowed as well.
+  Control flow statements that depend on `Tensor` values are staged into
+  corresponding TensorFlow ops. For example, the following code will work as
+  expected:

   ```python
-  if tensor < 10:
-    true_fn()
-  else:
-    false_fn()
+  @tf.contrib.eager.defun
+  def dynamic_rnn_loop(cell, seq):
+    state, output = cell.zero_state()
+    for input in seq:
+      state, output = cell(input, state)
+    return output
   ```

-  with `tf.cond(tensor < 10, true_fn, false_fn)`.
+  For more information see `tf.autograph`.

   _Variables_

   TensorFlow operations related to variable creation and initialization are
   automatically lifted out of the graphs generated by `defun`. In practice, this
   implies that variable creation and initialization only happen the first time
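Complementing the `dynamic_rnn_loop` snippet above, a hedged sketch of a Tensor-dependent conditional that AutoGraph can stage into graph ops when `autograph=True` (the default); exact conversion behavior depends on the AutoGraph version:

```python
import tensorflow as tf

tf.enable_eager_execution()

@tf.contrib.eager.defun  # autograph=True by default
def clip_or_scale(x):
  # The condition depends on a Tensor value, so AutoGraph stages this `if`
  # as TensorFlow conditional ops rather than resolving it at trace time.
  if tf.reduce_sum(x) > 10.0:
    y = x / 2.0
  else:
    y = x * 2.0
  return y

print(clip_or_scale(tf.constant([1.0, 2.0])))  # sum <= 10 -> doubled
print(clip_or_scale(tf.constant([8.0, 9.0])))  # sum > 10  -> halved
```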
@@ -1638,12 +1640,19 @@ def class_method_to_instance_method(original_function, instance):
   assert hasattr(original_function, "python_function")

   def bound_method_wrapper(*args, **kwargs):
     """Wraps either a dummy MethodType or a converted AutoGraph function."""
     # __wrapped__ allows AutoGraph to swap in a converted function.
     wrapped_fn = bound_method_wrapper.__wrapped__
-    # If __wrapped__ was not replaced, then call original_function.
     # TODO(b/119246461): This needs to be simplified.
-    if tf_inspect.ismethod(wrapped_fn):
+    if wrapped_fn is bound_method_wrapper.__original_wrapped__:
+      # If __wrapped__ was not replaced, then call original_function.
       wrapped_fn = original_function.python_function
       if tf_inspect.ismethod(wrapped_fn):
         wrapped_fn = six.get_unbound_function(wrapped_fn)
+      return wrapped_fn(weak_instance(), *args, **kwargs)
+
+    # If __wrapped__ was replaced, then it is always an unbound function
+    # that takes self as first argument.
     return wrapped_fn(weak_instance(), *args, **kwargs)

   # pylint: disable=protected-access
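The fix itself replaces a duck-typing check (`tf_inspect.ismethod`) with an identity check against the originally installed `__wrapped__`, so a converted function swapped in by AutoGraph is actually called instead of silently falling back to the original method. A standalone sketch of that pattern with hypothetical names (not TensorFlow's real helpers), for illustration:

```python
import types

def make_wrapper(original_fn, instance):
  """Calls original_fn unless a replacement is swapped in via __wrapped__."""
  def wrapper(*args, **kwargs):
    fn = wrapper.__wrapped__
    if fn is wrapper.__original_wrapped__:
      # __wrapped__ was never replaced: fall back to the original function.
      return original_fn(instance, *args, **kwargs)
    # __wrapped__ was replaced (e.g. by a converted function): use it.
    return fn(instance, *args, **kwargs)

  # A dummy bound method stands in until a converted function is installed.
  dummy = types.MethodType(original_fn, instance)
  wrapper.__wrapped__ = dummy
  wrapper.__original_wrapped__ = dummy
  return wrapper

def greet(self, name):
  return "hello " + name

obj = object()
bound = make_wrapper(greet, obj)
print(bound("alice"))                                # "hello alice"
bound.__wrapped__ = lambda self, name: "hi " + name  # simulate AutoGraph swap
print(bound("alice"))                                # "hi alice"
```

Checking `ismethod(fn)` here instead would misclassify any replacement that happens to be a bound method, which is the failure mode the commit message describes.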