eager: Treat eager tensors as constants during graph construction.

Unless capturing is explicitly enabled.

PiperOrigin-RevId: 168052675
This commit is contained in:
Asim Shankar 2017-09-08 15:12:30 -07:00 committed by TensorFlower Gardener
parent 6e402d0d2c
commit a263ea626e

View File

@ -31,6 +31,7 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -57,10 +58,8 @@ def capture_tensors(captures):
_scoped_captures.tensors = old _scoped_captures.tensors = old
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
"""Captures a tfe Tensor while building a graph mode function. """Captures a Tensor while building a graph mode function.
Creates a placeholder to pass the tensor as an argument.
Arguments: Arguments:
value: A tfe.Tensor object value: A tfe.Tensor object
@ -69,19 +68,17 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
as_ref: Ignored (required by register_tensor_conversion_function). as_ref: Ignored (required by register_tensor_conversion_function).
Returns: Returns:
A placeholder which will, at runtime, have the value of this tensor. Returns a constant (the current value of the tensor) if capturing
is not enabled. A placeholder which will have the value of the
Raises: tensor at runtime otherwise.
ValueError: if called outside a defun context.
""" """
if context.in_eager_mode(): if context.in_eager_mode():
return value return value
_ = as_ref _ = as_ref
tensor_map = _scoped_captures.tensors tensor_map = _scoped_captures.tensors
if tensor_map is None: if tensor_map is None:
raise ValueError( # Capturing is not enabled.
"Trying to use tfe.Tensor objects in a graph outside graph mode. " return constant_op.constant(value.numpy())
"To build a graph use tfe.defun or tfe.make_template.")
captured_value = tensor_map.get(ops.tensor_id(value), None) captured_value = tensor_map.get(ops.tensor_id(value), None)
if captured_value is None: if captured_value is None:
captured_value = graph_placeholder( captured_value = graph_placeholder(
@ -98,7 +95,7 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
# Note that we register this at a higher priority than ops.Tensor since we want # Note that we register this at a higher priority than ops.Tensor since we want
# to handle subclass specific conversion before a superclass conversion. # to handle subclass specific conversion before a superclass conversion.
ops.register_tensor_conversion_function( ops.register_tensor_conversion_function(
tensor.Tensor, _convert_to_graph_constant, priority=-1) tensor.Tensor, _convert_to_graph_tensor, priority=-1)
class _CapturingContext(object): class _CapturingContext(object):