eager: Treat eager tensors as constants during graph construction.
Unless capturing is explicitly enabled. PiperOrigin-RevId: 168052675
This commit is contained in:
parent
6e402d0d2c
commit
a263ea626e
@ -31,6 +31,7 @@ from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager import tensor
|
||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import graph_to_function_def
|
||||
from tensorflow.python.framework import ops
|
||||
@ -57,10 +58,8 @@ def capture_tensors(captures):
|
||||
_scoped_captures.tensors = old
|
||||
|
||||
|
||||
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
||||
"""Captures a tfe Tensor while building a graph mode function.
|
||||
|
||||
Creates a placeholder to pass the tensor as an argument.
|
||||
def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
|
||||
"""Captures a Tensor while building a graph mode function.
|
||||
|
||||
Arguments:
|
||||
value: A tfe.Tensor object
|
||||
@ -69,19 +68,17 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
||||
as_ref: Ignored (required by register_tensor_conversion_function).
|
||||
|
||||
Returns:
|
||||
A placeholder which will, at runtime, have the value of this tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: if called outside a defun context.
|
||||
Returns a constant (the current value of the tensor) if capturing
|
||||
is not enabled. A placeholder which will have the value of the
|
||||
tensor at runtime otherwise.
|
||||
"""
|
||||
if context.in_eager_mode():
|
||||
return value
|
||||
_ = as_ref
|
||||
tensor_map = _scoped_captures.tensors
|
||||
if tensor_map is None:
|
||||
raise ValueError(
|
||||
"Trying to use tfe.Tensor objects in a graph outside graph mode. "
|
||||
"To build a graph use tfe.defun or tfe.make_template.")
|
||||
# Capturing is not enabled.
|
||||
return constant_op.constant(value.numpy())
|
||||
captured_value = tensor_map.get(ops.tensor_id(value), None)
|
||||
if captured_value is None:
|
||||
captured_value = graph_placeholder(
|
||||
@ -98,7 +95,7 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
||||
# Note that we register this at a higher priority than ops.Tensor since we want
|
||||
# to handle subclass specific conversion before a superclass conversion.
|
||||
ops.register_tensor_conversion_function(
|
||||
tensor.Tensor, _convert_to_graph_constant, priority=-1)
|
||||
tensor.Tensor, _convert_to_graph_tensor, priority=-1)
|
||||
|
||||
|
||||
class _CapturingContext(object):
|
||||
|
Loading…
Reference in New Issue
Block a user