eager: Treat eager tensors as constants during graph construction.

Unless capturing is explicitly enabled.

PiperOrigin-RevId: 168052675
This commit is contained in:
Asim Shankar 2017-09-08 15:12:30 -07:00 committed by TensorFlower Gardener
parent 6e402d0d2c
commit a263ea626e

View File

@ -31,6 +31,7 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
@ -57,10 +58,8 @@ def capture_tensors(captures):
_scoped_captures.tensors = old
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
"""Captures a tfe Tensor while building a graph mode function.
Creates a placeholder to pass the tensor as an argument.
def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
"""Captures a Tensor while building a graph mode function.
Arguments:
value: A tfe.Tensor object
@ -69,19 +68,17 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
as_ref: Ignored (required by register_tensor_conversion_function).
Returns:
A placeholder which will, at runtime, have the value of this tensor.
Raises:
ValueError: if called outside a defun context.
Returns a constant (the current value of the tensor) if capturing
is not enabled. A placeholder which will have the value of the
tensor at runtime otherwise.
"""
if context.in_eager_mode():
return value
_ = as_ref
tensor_map = _scoped_captures.tensors
if tensor_map is None:
raise ValueError(
"Trying to use tfe.Tensor objects in a graph outside graph mode. "
"To build a graph use tfe.defun or tfe.make_template.")
# Capturing is not enabled.
return constant_op.constant(value.numpy())
captured_value = tensor_map.get(ops.tensor_id(value), None)
if captured_value is None:
captured_value = graph_placeholder(
@ -98,7 +95,7 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
# Note that we register this at a higher priority than ops.Tensor since we want
# to handle subclass specific conversion before a superclass conversion.
ops.register_tensor_conversion_function(
tensor.Tensor, _convert_to_graph_constant, priority=-1)
tensor.Tensor, _convert_to_graph_tensor, priority=-1)
class _CapturingContext(object):