eager: Treat eager tensors as constants during graph construction.
Unless capturing is explicitly enabled. PiperOrigin-RevId: 168052675
This commit is contained in:
parent
6e402d0d2c
commit
a263ea626e
@ -31,6 +31,7 @@ from tensorflow.python.eager import execute
|
|||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.eager import tensor
|
from tensorflow.python.eager import tensor
|
||||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import graph_to_function_def
|
from tensorflow.python.framework import graph_to_function_def
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -57,10 +58,8 @@ def capture_tensors(captures):
|
|||||||
_scoped_captures.tensors = old
|
_scoped_captures.tensors = old
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
|
||||||
"""Captures a tfe Tensor while building a graph mode function.
|
"""Captures a Tensor while building a graph mode function.
|
||||||
|
|
||||||
Creates a placeholder to pass the tensor as an argument.
|
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
value: A tfe.Tensor object
|
value: A tfe.Tensor object
|
||||||
@ -69,19 +68,17 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
|||||||
as_ref: Ignored (required by register_tensor_conversion_function).
|
as_ref: Ignored (required by register_tensor_conversion_function).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A placeholder which will, at runtime, have the value of this tensor.
|
Returns a constant (the current value of the tensor) if capturing
|
||||||
|
is not enabled. A placeholder which will have the value of the
|
||||||
Raises:
|
tensor at runtime otherwise.
|
||||||
ValueError: if called outside a defun context.
|
|
||||||
"""
|
"""
|
||||||
if context.in_eager_mode():
|
if context.in_eager_mode():
|
||||||
return value
|
return value
|
||||||
_ = as_ref
|
_ = as_ref
|
||||||
tensor_map = _scoped_captures.tensors
|
tensor_map = _scoped_captures.tensors
|
||||||
if tensor_map is None:
|
if tensor_map is None:
|
||||||
raise ValueError(
|
# Capturing is not enabled.
|
||||||
"Trying to use tfe.Tensor objects in a graph outside graph mode. "
|
return constant_op.constant(value.numpy())
|
||||||
"To build a graph use tfe.defun or tfe.make_template.")
|
|
||||||
captured_value = tensor_map.get(ops.tensor_id(value), None)
|
captured_value = tensor_map.get(ops.tensor_id(value), None)
|
||||||
if captured_value is None:
|
if captured_value is None:
|
||||||
captured_value = graph_placeholder(
|
captured_value = graph_placeholder(
|
||||||
@ -98,7 +95,7 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
|||||||
# Note that we register this at a higher priority than ops.Tensor since we want
|
# Note that we register this at a higher priority than ops.Tensor since we want
|
||||||
# to handle subclass specific conversion before a superclass conversion.
|
# to handle subclass specific conversion before a superclass conversion.
|
||||||
ops.register_tensor_conversion_function(
|
ops.register_tensor_conversion_function(
|
||||||
tensor.Tensor, _convert_to_graph_constant, priority=-1)
|
tensor.Tensor, _convert_to_graph_tensor, priority=-1)
|
||||||
|
|
||||||
|
|
||||||
class _CapturingContext(object):
|
class _CapturingContext(object):
|
||||||
|
Loading…
Reference in New Issue
Block a user