Internal change

PiperOrigin-RevId: 164916465
This commit is contained in:
A. Unique TensorFlower 2017-08-10 16:03:18 -07:00 committed by TensorFlower Gardener
parent b8d13d218f
commit d9ca2d86de
4 changed files with 297 additions and 233 deletions

View File

@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import threading
-
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.eager import memory_trace
@@ -28,17 +26,6 @@ from tensorflow.python.framework import errors
 # Trace of execution and memory usage.
 _active_trace = None

-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def uid():
-  """A unique (within this program execution) integer."""
-  with _uid_lock:
-    global _uid_counter
-    _uid_counter += 1
-    return _uid_counter
-

 def _status_to_exception(code, message):
   try:
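Note on the deletion above: the uid() helper does not disappear; it moves into the framework ops module (added in the last file of this commit), so eager and graph tensors draw ids from a single counter. A minimal, self-contained sketch of the pattern for reviewers who want to poke at it in isolation (the module-level names mirror the diff; the assert at the bottom is illustrative only):

import threading

_uid_counter = 0
_uid_lock = threading.Lock()

def uid():
  """A unique (within this program execution) integer."""
  global _uid_counter
  with _uid_lock:  # serialize the increment so two threads never share an id
    _uid_counter += 1
    return _uid_counter

assert uid() != uid()  # every call returns a fresh integer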

View File

@@ -28,7 +28,6 @@ import numpy as np
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
-from tensorflow.python.eager import core
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import tensor
@@ -40,7 +39,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.util import nest

 # Thread-local storage for tfe Tensors which are referenced while evaluating a
 # graph-mode function.
 _scoped_captures = threading.local()
@@ -86,9 +84,8 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
         "To build a graph use tfe.defun or tfe.func_to_object.")
   captured_value = tensor_map.get(tape.tensor_id(value), None)
   if captured_value is None:
-    captured_value = graph_placeholder(dtype=dtype or value.dtype,
-                                       shape=value.shape,
-                                       name=name)
+    captured_value = graph_placeholder(
+        dtype=dtype or value.dtype, shape=value.shape, name=name)
     if captured_value.dtype == dtypes.resource:
       captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
     tensor_map[tape.tensor_id(value)] = (value, captured_value)
@@ -98,8 +95,10 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
 # TODO(apassos): it'd be really nice if we could scope this registration.
-ops.register_tensor_conversion_function(tensor.Tensor,
-                                        _convert_to_graph_constant)
+# Note that we register this at a higher priority than ops.Tensor since we want
+# to handle subclass specific conversion before a superclass conversion.
+ops.register_tensor_conversion_function(
+    tensor.Tensor, _convert_to_graph_constant, priority=-1)


 class _CapturingContext(object):
@@ -133,17 +132,17 @@ class _CapturingContext(object):
 def _forward_name(n):
   """The name of a generated forward defun named n."""
-  return "__forward_%s_%s" % (n, core.uid())
+  return "__forward_%s_%s" % (n, ops.uid())


 def _backward_name(n):
   """The name of a generated backward defun named n."""
-  return "__backward_%s_%s" % (n, core.uid())
+  return "__backward_%s_%s" % (n, ops.uid())


 def _inference_name(n):
   """The name of a forward-but-no-gradient defun named n."""
-  return "__inference_%s_%s" % (n, core.uid())
+  return "__inference_%s_%s" % (n, ops.uid())


 class _DefinedFunction(object):
@@ -184,15 +183,8 @@ class _GraphModeFunction(object):
     internal function.
   """

-  def __init__(self,
-               input_placeholders,
-               extra_inputs,
-               fdef,
-               graph,
-               operations,
-               func_outputs,
-               func_outputs_to_fdef_outputs,
-               output_shapes):
+  def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
+               func_outputs, func_outputs_to_fdef_outputs, output_shapes):
     assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
         len(input_placeholders), len(fdef.signature.input_arg))
     self._input_placeholders = input_placeholders
@@ -204,8 +196,8 @@ class _GraphModeFunction(object):
     self._num_outputs = len(fdef.signature.output_arg)
     self._ops = operations
     self._func_outputs = func_outputs
-    if (isinstance(func_outputs, (ops.Tensor, type(None)))
-        or ag_core.isnode(func_outputs)):
+    if (isinstance(func_outputs, (ops.Tensor, type(None))) or
+        ag_core.isnode(func_outputs)):
       self._returns = [func_outputs]
     else:
       self._returns = list(func_outputs)
@@ -218,11 +210,11 @@ class _GraphModeFunction(object):
     with self._graph.as_default(), context.graph_mode():
       c = _CapturingContext()
       with c:
-        filtered_outputs = [ag_core.getval(x)
-                            for x in self._returns if x is not None]
+        filtered_outputs = [
+            ag_core.getval(x) for x in self._returns if x is not None
+        ]
         self._out_grad_placeholders = [
-            graph_placeholder(x.dtype, x.shape)
-            for x in filtered_outputs
+            graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
         ]
         in_gradients = gradients_impl.gradients(
             filtered_outputs,
@@ -231,20 +223,16 @@ class _GraphModeFunction(object):
     shapes = [x.shape for x in in_gradients if x is not None]
     captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
     forward_function_def = graph_to_function_def.graph_to_function_def(
-        self._graph, self._ops,
-        self._input_placeholders,
+        self._graph, self._ops, self._input_placeholders,
         filtered_outputs + captures)
     self._forward_fdef = _DefinedFunction(forward_function_def)
-    _register_with_name(_forward_name(self._func_name),
-                        forward_function_def)
+    _register_with_name(_forward_name(self._func_name), forward_function_def)
     backward_outputs = [x for x in in_gradients if x is not None]
     all_inputs = self._out_grad_placeholders + captures
     backward_function_def = graph_to_function_def.graph_to_function_def(
-        self._graph,
-        [x.op for x in self._out_grad_placeholders] +
-        list(sorted(c.known_ops, key=lambda x: x.name)),
-        all_inputs,
-        backward_outputs)
+        self._graph, [x.op for x in self._out_grad_placeholders
+                     ] + list(sorted(c.known_ops, key=lambda x: x.name)),
+        all_inputs, backward_outputs)
     _register_with_name(_backward_name(self._func_name), backward_function_def)
     self._backward_function = _GraphModeFunction(
         all_inputs, [], backward_function_def, self._graph, c.known_ops,
@@ -258,12 +246,12 @@ class _GraphModeFunction(object):
     g = ops.get_default_graph()
     g._add_function(self._forward_fdef)  # pylint: disable=protected-access
     unwrapped_args = [ag_core.getval(x) for x in all_args]
-    op = g.create_op(signature.name,
-                     [ops.convert_to_tensor(x) for x in unwrapped_args],
-                     [dtypes.DType(x.type) for x in signature.output_arg],
-                     op_def=signature,
-                     name="FunctionCall",
-                     compute_shapes=False)
+    op = g.create_op(
+        signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
+        [dtypes.DType(x.type) for x in signature.output_arg],
+        op_def=signature,
+        name="FunctionCall",
+        compute_shapes=False)
     outputs = op.outputs
     outputs = [outputs] if isinstance(
         outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
@@ -288,17 +276,17 @@ class _GraphModeFunction(object):
         watched_extra_inputs.append(t)
     real_outputs = tape.record_operation(real_outputs,
                                          (args + watched_extra_inputs),
-                                         side_outputs,
-                                         self._backward_function)
+                                         side_outputs, self._backward_function)
     return self._build_call_outputs(self._returns, real_outputs)

   def __call__(self, *args):
     """Executes the passed function in eager mode."""
-    tensor_inputs = [x for x in nest.flatten(args)
-                     if isinstance(x, (tensor.Tensor, ops.Tensor,
-                                       tensor.LazyZero))
-                     or ag_core.isnode(x)]
+    tensor_inputs = [
+        x for x in nest.flatten(args)
+        if isinstance(x, (tensor.Tensor, ops.Tensor,
+                          tensor.LazyZero)) or ag_core.isnode(x)
+    ]
     if tape.should_record(tensor_inputs) or any(
         tape.any_tape_has(t) for t in self._extra_inputs):
       if not self._has_backprop:
@@ -310,18 +298,20 @@
       g._add_function(self._fdef)  # pylint: disable=protected-access
       signature = self._fdef.definition.signature
       args = list(tensor_inputs) + self._extra_inputs
-      op = g.create_op(signature.name,
-                       [ops.convert_to_tensor(x) for x in args],
-                       [dtypes.DType(x.type) for x in signature.output_arg],
-                       op_def=signature,
-                       name="FunctionCall",
-                       compute_shapes=False)
+      op = g.create_op(
+          signature.name, [ops.convert_to_tensor(x) for x in args],
+          [dtypes.DType(x.type) for x in signature.output_arg],
+          op_def=signature,
+          name="FunctionCall",
+          compute_shapes=False)
       result = op.outputs
       for i, s in enumerate(self._output_shapes):
         result[i].set_shape(s)
     else:
-      tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
-                       for x in tensor_inputs]
+      tensor_inputs = [
+          x.tensor() if isinstance(x, tensor.LazyZero) else x
+          for x in tensor_inputs
+      ]
       result = execute.execute(
           self._func_name,
           num_outputs=self._num_outputs,
@@ -383,22 +373,21 @@ def _defun_internal(name, func, args, kwds):
       func_outputs = func(*func_inputs, **kwds)
   ids = list(sorted(captures.keys()))
   if ids:
-    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
+    extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
   else:
     extra_inputs = []
     extra_placeholders = []
   outputs_list = nest.flatten(func_outputs)
   output_shapes = [x.shape for x in outputs_list if x is not None]
-  flat_inputs = [x for x in nest.flatten(func_inputs)
-                 if isinstance(x, ops.Tensor)]
+  flat_inputs = [
+      x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
+  ]
   all_inputs = flat_inputs + list(extra_placeholders)
   func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
   inference_function_def = graph_to_function_def.graph_to_function_def(
-      tmp_graph, tmp_graph.get_operations(),
-      all_inputs,
-      func_def_outputs)
+      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
   # Register any other functions defined in the graph
   # TODO(ashankar): Oh lord, forgive me for this lint travesty.
   for f in tmp_graph._functions.values():  # pylint: disable=protected-access
@@ -407,14 +396,9 @@ def _defun_internal(name, func, args, kwds):
   _register_with_name(_inference_name(name), inference_function_def)

   return _GraphModeFunction(
-      all_inputs,
-      extra_inputs,
-      inference_function_def,
-      tmp_graph,
-      tmp_graph.get_operations(),
-      func_outputs,
-      _map_sequence_obj_to_idx(func_def_outputs),
-      output_shapes)
+      all_inputs, extra_inputs, inference_function_def, tmp_graph,
+      tmp_graph.get_operations(), func_outputs,
+      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)


 # Defun uses this instead of Tensor as a cache key. Using dtype because
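Note on the priority=-1 registration above: tensor conversion functions are consulted in ascending priority order, so a negative priority lets the tfe.Tensor converter run before the base ops.Tensor converter registered at priority 0. A self-contained sketch of that dispatch rule, with illustrative names only (not the TensorFlow implementation):

_conversion_registry = []  # entries of (priority, type, converter), kept sorted

def register(base_type, converter, priority=100):
  _conversion_registry.append((priority, base_type, converter))
  _conversion_registry.sort(key=lambda entry: entry[0])  # ascending priority

def convert(value):
  # First matching isinstance check wins, so lower priority == earlier look-up.
  for _, base_type, converter in _conversion_registry:
    if isinstance(value, base_type):
      return converter(value)
  raise TypeError("no converter registered for %r" % (value,))

class GraphTensor(object):
  pass

class EagerTensor(GraphTensor):
  pass

register(GraphTensor, lambda v: "graph conversion", priority=0)
register(EagerTensor, lambda v: "eager conversion", priority=-1)
assert convert(EagerTensor()) == "eager conversion"  # subclass handled first
assert convert(GraphTensor()) == "graph conversion"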

View File

@@ -27,11 +27,13 @@ from tensorflow.python.eager import core
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.framework import tensor_shape


-class Tensor(object):
-  """A TensorFlow Tensor."""
+# TODO(agarwal): rename to TensorHandle.
+class Tensor(tf_ops.Tensor):
+  """A TensorFlow Eager Tensor."""

   def __init__(self, value, dtype=None):
     """Creates a Tensor object from a Python object or numpy array.
@@ -48,7 +50,7 @@ class Tensor(object):
     # TODO(ashankar): Evaluate if we can and perhaps share code with
     # tf.constant defined in
     # https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
-    self._id = core.uid()
+    self._id = tf_ops.uid()
     if not isinstance(value, np.ndarray):
       npt = None if dtype is None else dtype.as_numpy_dtype
       value = np.array(value, dtype=npt)
@@ -111,7 +113,7 @@ class Tensor(object):
     if core.active_trace() is not None:
       core.active_trace().record_tensor("MANUAL",
                                         tape.tensor_id(self),
-                                        self._device_name(),
+                                        self.device,
                                         self.shape.num_elements())

   def __del__(self):
@@ -184,12 +186,13 @@ class Tensor(object):
     if core.active_trace() is not None:
       core.active_trace().record_tensor("COPY",
                                         tape.tensor_id(new_tensor),
-                                        new_tensor._device_name(),
+                                        new_tensor.device,
                                         new_tensor.shape.num_elements())
     return new_tensor
   # pylint: enable=protected-access

-  def _device_name(self):
+  @property
+  def device(self):
     return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)

   @property
@@ -237,6 +240,10 @@ class Tensor(object):
         pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
         for x in range(n))

+  def _shape_as_list(self):
+    """The shape of the tensor as a list."""
+    return list(self._shape_tuple())
+
   def as_cpu_tensor(self):
     """A copy of this Tensor with contents backed by host memory."""
     return self._copy(context.get_default_context(), "CPU:0")
@@ -266,6 +273,42 @@ class Tensor(object):
   def __nonzero__(self):
     return self.__bool__()

+  # Methods not supported / implemented for Eager Tensors.
+  @property
+  def op(self):
+    raise NotImplementedError("op not supported for Eager Tensors.")
+
+  @property
+  def graph(self):
+    raise NotImplementedError("graph not supported for Eager Tensors.")
+
+  @property
+  def name(self):
+    raise NotImplementedError("name not supported for Eager Tensors.")
+
+  def set_shape(self, shape):
+    raise NotImplementedError("set_shape not supported for Eager Tensors.")
+
+  @property
+  def value_index(self):
+    raise NotImplementedError("value_index not supported for Eager Tensors.")
+
+  def consumers(self):
+    raise NotImplementedError("consumers not supported for Eager Tensors.")
+
+  def _add_consumer(self, consumer):
+    raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
+
+  def _as_node_def_input(self):
+    raise NotImplementedError(
+        "_as_node_def_input not supported for Eager Tensors.")
+
+  def _as_tf_output(self):
+    raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
+
+  def eval(self, feed_dict=None, session=None):
+    raise NotImplementedError("eval not supported for Eager Tensors.")
+

 class IndexedSlices(object):
   """A sparse representation of a set of tensor slices at given indices.
@@ -374,7 +417,7 @@ def _tensor_from_handle(handle):
   """
   # pylint: disable=protected-access
   t = Tensor.__new__(Tensor)
-  t._id = core.uid()
+  t._id = tf_ops.uid()
   t._handle = handle
   t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
   t._handle_data = None
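Note on the new base class: deriving the eager Tensor from tf_ops.Tensor means isinstance checks written against the graph Tensor type now also accept eager tensors, while the graph-only surface (op, graph, name, eval, ...) fails fast with NotImplementedError instead of misbehaving. A minimal sketch of that shape, with hypothetical class names standing in for the real ones:

class GraphTensor(object):
  @property
  def op(self):
    return "<producing op>"

class EagerTensor(GraphTensor):
  @property
  def op(self):
    raise NotImplementedError("op not supported for Eager Tensors.")

t = EagerTensor()
assert isinstance(t, GraphTensor)  # passes graph-era type checks
try:
  t.op  # graph-only accessors raise instead of returning stale graph state
except NotImplementedError:
  pass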

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Classes and functions used to construct graphs."""
 # pylint: disable=g-bad-name
 from __future__ import absolute_import
@@ -47,7 +46,6 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import decorator_utils
 from tensorflow.python.util import tf_contextlib

-
 # Temporary global switch determining if we should enable the work-in-progress
 # calls to the C API. Currently disabled by default but can be manually enabled
 # e.g. in tests. This will be removed once all functionality is supported and
@@ -153,6 +151,18 @@ def register_dense_tensor_like_type(tensor_type):
   _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])


+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def uid():
+  """A unique (within this program execution) integer."""
+  with _uid_lock:
+    global _uid_counter
+    _uid_counter += 1
+    return _uid_counter
+
+
 # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
 class _TensorLike(object):
   """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
@@ -261,6 +271,7 @@ class Tensor(_TensorLike):
     # Attributes used for C++ shape inference. Not inspected, only forwarded.
     # If set, will be a HandleData object from cpp_shape_inference.proto.
     self._handle_data = None
+    self._id = uid()

   @property
   def op(self):
@@ -284,11 +295,6 @@ class Tensor(_TensorLike):
       raise ValueError("Operation was not named: %s" % self._op)
     return "%s:%d" % (self._op.name, self._value_index)

-  @property
-  def _id(self):
-    """An alias for the string name of this tensor."""
-    return self.name
-
   @property
   def device(self):
     """The name of the device on which this tensor will be produced, or None."""
@@ -437,15 +443,15 @@ class Tensor(_TensorLike):
   def __str__(self):
     return "Tensor(\"%s\"%s%s%s)" % (
-        self.name,
-        (", shape=%s" % self.get_shape())
+        self.name, (", shape=%s" % self.get_shape())
         if self.get_shape().ndims is not None else "",
-        (", dtype=%s" % self._dtype.name) if self._dtype else "",
-        (", device=%s" % self.device) if self.device else "")
+        (", dtype=%s" % self._dtype.name)
+        if self._dtype else "", (", device=%s" % self.device)
+        if self.device else "")

   def __repr__(self):
-    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (
-        self.name, self.get_shape(), self._dtype.name)
+    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
+                                                   self._dtype.name)

   def __hash__(self):
     # Necessary to support Python's collection membership operators
@@ -551,20 +557,18 @@ def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
   _ = name, as_ref
   if dtype and not dtype.is_compatible_with(t.dtype):
     raise ValueError(
-        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
-        % (dtype.name, t.dtype.name, str(t)))
+        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
+        (dtype.name, t.dtype.name, str(t)))
   return t


 _tensor_conversion_func_registry = {
-    0: [(Tensor, _TensorTensorConversionFunction)]}
+    0: [(Tensor, _TensorTensorConversionFunction)]
+}
 register_dense_tensor_like_type(Tensor)


-def convert_to_tensor(value,
-                      dtype=None,
-                      name=None,
-                      preferred_dtype=None):
+def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
   """Converts the given `value` to a `Tensor`.

   This function converts Python objects of various types to `Tensor`
@@ -686,18 +690,18 @@ def internal_convert_to_tensor(value,
       if not isinstance(ret, Tensor):
         raise RuntimeError(
-            "%sConversion function %r for type %s returned non-Tensor: %r"
-            % (error_prefix, conversion_func, base_type, ret))
+            "%sConversion function %r for type %s returned non-Tensor: %r" %
+            (error_prefix, conversion_func, base_type, ret))
       if dtype and not dtype.is_compatible_with(ret.dtype):
         raise RuntimeError(
             "%sConversion function %r for type %s returned incompatible "
-            "dtype: requested = %s, actual = %s"
-            % (error_prefix, conversion_func, base_type,
-               dtype.name, ret.dtype.name))
+            "dtype: requested = %s, actual = %s" %
+            (error_prefix, conversion_func, base_type, dtype.name,
+             ret.dtype.name))
       return ret

   raise TypeError("%sCannot convert %r with type %s to Tensor: "
-                  "no conversion function registered."
-                  % (error_prefix, value, type(value)))
+                  "no conversion function registered." % (error_prefix, value,
+                                                          type(value)))


 def internal_convert_n_to_tensor(values,
@@ -744,10 +748,7 @@ def internal_convert_n_to_tensor(values,
   return ret


-def convert_n_to_tensor(values,
-                        dtype=None,
-                        name=None,
-                        preferred_dtype=None):
+def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
   """Converts `values` to a list of `Tensor` objects.

   Args:
@@ -771,11 +772,12 @@
     RuntimeError: If a registered conversion function returns an invalid
       value.
   """
-  return internal_convert_n_to_tensor(values=values,
-                                      dtype=dtype,
-                                      name=name,
-                                      preferred_dtype=preferred_dtype,
-                                      as_ref=False)
+  return internal_convert_n_to_tensor(
+      values=values,
+      dtype=dtype,
+      name=name,
+      preferred_dtype=preferred_dtype,
+      as_ref=False)


 def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
@@ -802,7 +804,9 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
       value=value, dtype=dtype, name=name, as_ref=False)


-def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
+def internal_convert_to_tensor_or_indexed_slices(value,
+                                                 dtype=None,
+                                                 name=None,
                                                  as_ref=False):
   """Converts the given object to an `Tensor` or an `IndexedSlices`.
@@ -827,18 +831,18 @@ def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
   if isinstance(value, _TensorLike):
     if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
       raise ValueError(
-          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
-          % (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
+          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
+          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
     return value
   else:
-    return internal_convert_to_tensor(value,
-                                      dtype=dtype,
-                                      name=name,
-                                      as_ref=as_ref)
+    return internal_convert_to_tensor(
+        value, dtype=dtype, name=name, as_ref=as_ref)


-def internal_convert_n_to_tensor_or_indexed_slices(values, dtype=None,
-                                                   name=None, as_ref=False):
+def internal_convert_n_to_tensor_or_indexed_slices(values,
+                                                   dtype=None,
+                                                   name=None,
+                                                   as_ref=False):
   """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

   Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
@@ -905,7 +909,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
       values=values, dtype=dtype, name=name, as_ref=False)


-def register_tensor_conversion_function(base_type, conversion_func,
+def register_tensor_conversion_function(base_type,
+                                        conversion_func,
                                         priority=100):
   """Registers a function for converting objects of `base_type` to `Tensor`.
@@ -947,8 +952,8 @@ def register_tensor_conversion_function(base_type, conversion_func,
   """
   if not (isinstance(base_type, type) or
-          (isinstance(base_type, tuple)
-           and all(isinstance(x, type) for x in base_type))):
+          (isinstance(base_type, tuple) and
+           all(isinstance(x, type) for x in base_type))):
     raise TypeError("base_type must be a type or a tuple of types.")
   if not callable(conversion_func):
     raise TypeError("conversion_func must be callable.")
@@ -1038,8 +1043,7 @@ class IndexedSlices(_TensorLike):
   def __str__(self):
     return "IndexedSlices(indices=%s, values=%s%s)" % (
-        self._indices, self._values,
-        (", dense_shape=%s" % self._dense_shape)
+        self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
         if self._dense_shape is not None else "")

   def __neg__(self):
@@ -1112,8 +1116,14 @@ class Operation(object):
   `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
   """

-  def __init__(self, node_def, g, inputs=None, output_types=None,
-               control_inputs=None, input_types=None, original_op=None,
+  def __init__(self,
+               node_def,
+               g,
+               inputs=None,
+               output_types=None,
+               control_inputs=None,
+               input_types=None,
+               original_op=None,
                op_def=None):
     r"""Creates an `Operation`.
@@ -1177,18 +1187,20 @@
     if output_types is None:
       output_types = []
     self._output_types_val = output_types
-    self._outputs = [Tensor(self, i, output_type)
-                     for i, output_type in enumerate(output_types)]
+    self._outputs = [
+        Tensor(self, i, output_type)
+        for i, output_type in enumerate(output_types)
+    ]

     if input_types is None:
       input_types = [i.dtype.base_dtype for i in self._inputs]
     else:
-      if not all(x.is_compatible_with(i.dtype)
-                 for i, x in zip(self._inputs, input_types)):
+      if not all(
+          x.is_compatible_with(i.dtype)
+          for i, x in zip(self._inputs, input_types)):
         raise TypeError("In op '%s', input types (%s) are not compatible "
-                        "with expected types (%s)" % (
-                            self.node_def.name,
-                            [i.dtype for i in self._inputs],
-                            input_types))
+                        "with expected types (%s)" %
+                        (self.node_def.name, [i.dtype for i in self._inputs],
+                         input_types))
     self._input_types_val = input_types

     # Build the list of control inputs.
@@ -1251,7 +1263,8 @@
       A wrapped TF_Operation*.
     """
     # pylint: disable=protected-access
-    op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op),
+    op_desc = c_api.TF_NewOperation(graph._c_graph,
+                                    compat.as_str(node_def.op),
                                     compat.as_str(node_def.name))
     # Add inputs
     for op_input in inputs:
@@ -1271,8 +1284,8 @@
       # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
       # It might be worth creating a convenient way to re-use the same status.
       with errors.raise_exception_on_not_ok_status() as status:
-        c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized,
-                                   status)
+        c_api.TF_SetAttrValueProto(op_desc,
+                                   compat.as_str(name), serialized, status)

     with errors.raise_exception_on_not_ok_status() as status:
       c_op = c_api.TF_FinishOperation(op_desc, status)
@@ -1316,16 +1329,18 @@
   def colocation_groups(self):
     """Returns the list of colocation groups of the op."""
-    default_colocation_group = [compat.as_bytes("loc:@%s" %
-                                                self._node_def.name)]
+    default_colocation_group = [
+        compat.as_bytes("loc:@%s" % self._node_def.name)
+    ]
     if "_class" not in self._node_def.attr:
       # This op has no explicit colocation group, so it is itself its
       # own root of a colocation group.
       return default_colocation_group

-    attr_groups = [class_name
-                   for class_name in self.get_attr("_class")
-                   if class_name.startswith(b"loc:@")]
+    attr_groups = [
+        class_name for class_name in self.get_attr("_class")
+        if class_name.startswith(b"loc:@")
+    ]

     # If there are no colocation groups in the explicit _class field,
     # return the default colocation group.
@@ -1397,8 +1412,10 @@
     """
     if self._graph._c_graph:  # pylint: disable=protected-access
       num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
-      output_types = [c_api.TF_OperationOutputType(self._tf_output(i)) for
-                      i in xrange(num_outputs)]
+      output_types = [
+          c_api.TF_OperationOutputType(self._tf_output(i))
+          for i in xrange(num_outputs)
+      ]
       # TODO(iga): Remove this assert after converting to C API by default.
       # Just being a bit paranoid here.
       assert self._output_types_val == output_types
@@ -1433,8 +1450,8 @@
       device: string or device..  The device to set.
     """
     if _USE_C_API:
-      c_api.SetRequestedDevice(
-          self._graph._c_graph, self._c_op, _device_string(device))  # pylint: disable=protected-access
+      c_api.SetRequestedDevice(self._graph._c_graph, self._c_op,  # pylint: disable=protected-access
                               _device_string(device))
     # TODO(nolivia): remove this line when switch to C api
     self._node_def.device = _device_string(device)
@@ -1462,8 +1479,8 @@
     dtype = dtypes.as_dtype(dtype)
     if not dtype.is_compatible_with(tensor.dtype):
       raise TypeError(
-          "Cannot convert a tensor of type %s to an input of type %s"
-          % (tensor.dtype.name, dtype.name))
+          "Cannot convert a tensor of type %s to an input of type %s" %
+          (tensor.dtype.name, dtype.name))
     self._inputs.append(tensor)
     self._input_types_val.append(dtype)
     tensor._add_consumer(self)  # pylint: disable=protected-access
@@ -1496,8 +1513,8 @@
     dtype = dtypes.as_dtype(dtype)
     if not dtype.is_compatible_with(tensor.dtype):
       raise TypeError(
-          "Cannot convert a tensor of type %s to an input of type %s"
-          % (tensor.dtype.name, dtype.name))
+          "Cannot convert a tensor of type %s to an input of type %s" %
+          (tensor.dtype.name, dtype.name))

     self._inputs[index].consumers().remove(self)
     self._inputs[index] = tensor
@@ -1547,8 +1564,8 @@
     self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
     # pylint: enable=protected-access
     if self._control_inputs:
-      self._node_def.input.extend(["^%s" % op.name for op in
-                                   self._control_inputs])
+      self._node_def.input.extend(
+          ["^%s" % op.name for op in self._control_inputs])

   def __str__(self):
     return str(self._node_def)
@@ -1562,6 +1579,7 @@
     return self._outputs

+
   # pylint: disable=protected-access
   class _InputList(object):
     """Immutable input list wrapper."""
@@ -1582,6 +1600,7 @@
     def __getitem__(self, i):
       return self._op._inputs[i]

+
   # pylint: enable=protected-access

   @property
@@ -1597,9 +1616,10 @@
   def _input_types(self):
     if self._graph._c_graph:  # pylint: disable=protected-access
       num_inputs = c_api.TF_OperationNumInputs(self._c_op)
-      input_types = [dtypes.as_dtype(
-          c_api.TF_OperationInputType(self._tf_input(i)))
-                     for i in xrange(num_inputs)]
+      input_types = [
+          dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
+          for i in xrange(num_inputs)
+      ]
       # TODO(iga): Remove this assert after converting to C API by default.
       # Just being a bit paranoid here.
       assert self._input_types_val == input_types
@@ -1624,8 +1644,10 @@
     if self._graph._c_graph:  # pylint: disable=protected-access
       control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
       # pylint: disable=protected-access
-      return [self.graph._get_operation_by_name_unsafe(
-          c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
+      return [
+          self.graph._get_operation_by_name_unsafe(
+              c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+      ]
       # pylint: enable=protected-access
     else:
       return self._control_inputs
@@ -1691,7 +1713,8 @@
       A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
     """
     return self._graph._convert_stack(  # pylint: disable=protected-access
-        self._traceback, include_func_start_lineno=True)
+        self._traceback,
+        include_func_start_lineno=True)

   def get_attr(self, name):
     """Returns the value of the attr of this op with the given `name`.
@@ -1707,8 +1730,7 @@
     """
     fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
     if name not in self._node_def.attr:
-      raise ValueError("No attr named '" + name + "' in " +
-                       str(self._node_def))
+      raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
     x = self._node_def.attr[name]
     # Treat an empty oneof value as an empty list.
     if not x.WhichOneof("value"):
@@ -1749,7 +1771,6 @@
     """
     _run_using_default_session(self, feed_dict, self.graph, session)

-
 _gradient_registry = registry.Registry("gradient")
@@ -1834,7 +1855,8 @@ NoGradient = NotDifferentiable

 def get_gradient_function(op):
   """Returns the function that computes gradients for "op"."""
-  if not op.inputs: return None
+  if not op.inputs:
+    return None
   try:
     op_type = op.get_attr("_gradient_op_type")
   except ValueError:
@@ -1982,6 +2004,7 @@
       self._value += other.value
     return self

+
 _stats_registry = registry.Registry("statistical functions")
@@ -2433,8 +2456,8 @@
         graph.node.extend([op.node_def])
         if op.outputs and add_shapes:
           assert "_output_shapes" not in graph.node[-1].attr
-          graph.node[-1].attr["_output_shapes"].list.shape.extend([
-              output.get_shape().as_proto() for output in op.outputs])
+          graph.node[-1].attr["_output_shapes"].list.shape.extend(
+              [output.get_shape().as_proto() for output in op.outputs])
         bytesize += op.node_def.ByteSize()
         if bytesize >= (1 << 31) or bytesize < 0:
           raise ValueError("GraphDef cannot be larger than 2GB.")
@@ -2469,7 +2492,8 @@
       node with the inferred shapes of each of its outputs.

     Returns:
-      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
+      A
+      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
       protocol buffer.

     Raises:
@@ -2518,8 +2542,8 @@
     if previous:
       raise ValueError("Another function is already defined with that name")
     # Sanity checks on gradient definition.
-    if (function.grad_func_name is not None) and (
-        function.python_grad_func is not None):
+    if (function.grad_func_name is not None) and (function.python_grad_func is
+                                                  not None):
       raise ValueError("Gradient defined twice for function %s" % name)
     # Need a new-enough consumer to support the functions we add to the graph.
     if self._graph_def_versions.min_consumer < 12:
@@ -2532,9 +2556,17 @@
     return self._building_function

   # Helper functions to create operations.
-  def create_op(self, op_type, inputs, dtypes,  # pylint: disable=redefined-outer-name
-                input_types=None, name=None, attrs=None, op_def=None,
-                compute_shapes=True, compute_device=True):
+  def create_op(
+      self,
+      op_type,
+      inputs,
+      dtypes,  # pylint: disable=redefined-outer-name
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_shapes=True,
+      compute_device=True):
     """Creates an `Operation` in this graph.

     This is a low-level interface for creating an `Operation`. Most
@@ -2597,8 +2629,8 @@
           if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
             raise TypeError(
                 "Callable for scope map key '%s' must return either None or "
-                "an AttrValue protocol buffer; but it returned: %s" %
-                (key, value))
+                "an AttrValue protocol buffer; but it returned: %s" % (key,
+                                                                       value))
           node_def.attr[key].CopyFrom(value)

     # Apply a kernel label if one has been specified for this op_type.
@@ -2619,9 +2651,15 @@
       pass

     control_inputs = self._control_dependencies_for_inputs(inputs)
-    ret = Operation(node_def, self, inputs=inputs, output_types=dtypes,
-                    control_inputs=control_inputs, input_types=input_types,
-                    original_op=self._default_original_op, op_def=op_def)
+    ret = Operation(
+        node_def,
+        self,
+        inputs=inputs,
+        output_types=dtypes,
+        control_inputs=control_inputs,
+        input_types=input_types,
+        original_op=self._default_original_op,
+        op_def=op_def)
     if compute_shapes:
       set_shapes_for_outputs(ret)
     self._add_op(ret)
@@ -2638,28 +2676,27 @@
           # Make this device match the device of the colocated op, to
           # provide consistency between the device and the colocation
           # property.
-          if (ret.device and
-              pydev.canonical_name(ret.device) !=
+          if (ret.device and pydev.canonical_name(ret.device) !=
              pydev.canonical_name(colocation_op.device)):
            logging.warning("Tried to colocate %s with an op %s that had "
                            "a different device: %s vs %s. "
-                            "Ignoring colocation property.",
-                            name, colocation_op.name,
-                            ret.device, colocation_op.device)
+                            "Ignoring colocation property.", name,
+                            colocation_op.name, ret.device,
+                            colocation_op.device)
          else:
            ret._set_device(colocation_op.device)  # pylint: disable=protected-access

      all_colocation_groups = sorted(set(all_colocation_groups))
-      ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
-          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
+      ret.node_def.attr["_class"].CopyFrom(
+          attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
+              s=all_colocation_groups)))

     # Sets "container" attribute if
     # (1) self._container is not None
     # (2) "is_stateful" is set in OpDef
     # (3) "container" attribute is in OpDef
     # (4) "container" attribute is None
-    if (self._container and
-        op_type in self._registered_ops and
+    if (self._container and op_type in self._registered_ops and
         self._registered_ops[op_type].is_stateful and
         "container" in ret.node_def.attr and
         not ret.node_def.attr["container"].s):
@@ -2749,13 +2786,13 @@
        except:
          raise KeyError("The name %s refers to a Tensor which does not "
                         "exist. The operation, %s, exists but only has "
-                        "%s outputs."
-                        % (repr(name), repr(op_name), len(op.outputs)))
+                        "%s outputs." % (repr(name), repr(op_name),
+                                         len(op.outputs)))

      elif ":" in name and not allow_tensor:
        # Looks like a Tensor name but can't be a Tensor.
-        raise ValueError("Name %s appears to refer to a Tensor, not a %s."
-                         % (repr(name), types_str))
+        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
+                         (repr(name), types_str))

      elif ":" not in name and allow_operation:
        # Looks like an Operation name and can be an Operation.
@@ -2768,8 +2805,8 @@
        # Looks like an Operation name but can't be an Operation.
        if name in self._nodes_by_name:
          # Yep, it's an Operation name
-          err_msg = ("The name %s refers to an Operation, not a %s."
-                     % (repr(name), types_str))
+          err_msg = ("The name %s refers to an Operation, not a %s." %
+                     (repr(name), types_str))
        else:
          err_msg = ("The name %s looks like an (invalid) Operation name, "
                     "not a %s." % (repr(name), types_str))
@@ -2789,8 +2826,8 @@
        return obj
      else:
        # We give up!
-        raise TypeError("Can not convert a %s into a %s."
-                        % (type(obj).__name__, types_str))
+        raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
+                                                             types_str))

   def get_operations(self):
     """Return the list of operations in the graph.
@@ -2827,8 +2864,8 @@
     """
     if not isinstance(name, six.string_types):
-      raise TypeError("Operation names are strings (or similar), not %s."
-                      % type(name).__name__)
+      raise TypeError("Operation names are strings (or similar), not %s." %
+                      type(name).__name__)
     return self.as_graph_element(name, allow_tensor=False, allow_operation=True)

   def _get_operation_by_name_unsafe(self, name):
@@ -2871,8 +2908,8 @@
     """
     # Names should be strings.
     if not isinstance(name, six.string_types):
-      raise TypeError("Tensor names are strings (or similar), not %s."
-                      % type(name).__name__)
+      raise TypeError("Tensor names are strings (or similar), not %s." %
+                      type(name).__name__)
     return self.as_graph_element(name, allow_tensor=True, allow_operation=False)

   def _next_id(self):
@@ -3178,6 +3215,7 @@
         yield "" if new_stack is None else new_stack + "/"
       finally:
         self._name_stack = old_stack
+
   # pylint: enable=g-doc-return-or-yield,line-too-long

   def unique_name(self, name, mark_as_used=True):
@@ -3280,9 +3318,8 @@
     """
     if op is None and not ignore_existing:
-      raise ValueError(
-          "Trying to reset colocation (op is None) but "
-          "ignore_existing is not True")
+      raise ValueError("Trying to reset colocation (op is None) but "
+                       "ignore_existing is not True")

     if op is not None and not isinstance(op, Operation):
       # We always want to colocate with the reference op.
@@ -3377,8 +3414,8 @@
     """
     # pylint: enable=line-too-long
-    if (device_name_or_function is not None
-        and not callable(device_name_or_function)):
+    if (device_name_or_function is not None and
+        not callable(device_name_or_function)):
       device_function = pydev.merge_device(device_name_or_function)
     else:
       device_function = device_name_or_function
@@ -3452,6 +3489,7 @@
       yield self._container
     finally:
       self._container = original_container
+
   # pylint: enable=g-doc-return-or-yield

   class _ControlDependenciesController(object):
@@ -3491,6 +3529,7 @@
       self._old_control_flow_context = None

+
     # pylint: disable=protected-access
     def __enter__(self):
       if self._new_stack:
         # Clear the control_dependencies graph.
@@ -3506,6 +3545,7 @@
       if self._new_stack:
         self._graph._control_dependencies_stack = self._old_stack
         self._graph._set_control_flow_context(self._old_control_flow_context)
+
     # pylint: enable=protected-access

     @property
@@ -3733,6 +3773,7 @@
             self._attr_scope_map[name] = saved_attrs[name]
           except KeyError:
             del self._attr_scope_map[name]
+
   # pylint: enable=g-doc-return-or-yield

   # pylint: disable=g-doc-return-or-yield
@@ -3777,8 +3818,8 @@
     saved_labels = {}
     # Install the given label
     for op_type, label in op_to_kernel_label_map.items():
-      if not (isinstance(op_type, six.string_types)
-              and isinstance(label, six.string_types)):
+      if not (isinstance(op_type, six.string_types) and
+              isinstance(label, six.string_types)):
        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                        "strings to strings")
      try:
@@ -3795,6 +3836,7 @@
           self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
         except KeyError:
           del self._op_to_kernel_label_map[op_type]
+
   # pylint: enable=g-doc-return-or-yield

   # pylint: disable=g-doc-return-or-yield
@@ -3840,8 +3882,8 @@
     saved_mappings = {}
     # Install the given label
     for op_type, mapped_op_type in op_type_map.items():
-      if not (isinstance(op_type, six.string_types)
-              and isinstance(mapped_op_type, six.string_types)):
+      if not (isinstance(op_type, six.string_types) and
+              isinstance(mapped_op_type, six.string_types)):
         raise TypeError("op_type_map must be a dictionary mapping "
                         "strings to strings")
       try:
@@ -3858,6 +3900,7 @@
           self._gradient_override_map[op_type] = saved_mappings[op_type]
         except KeyError:
           del self._gradient_override_map[op_type]
+
   # pylint: enable=g-doc-return-or-yield

   def prevent_feeding(self, tensor):
@@ -3969,12 +4012,13 @@
       if self._enforce_nesting:
         if self.stack[-1] is not default:
           raise AssertionError(
-              "Nesting violated for default stack of %s objects"
-              % type(default))
+              "Nesting violated for default stack of %s objects" %
+              type(default))
         self.stack.pop()
       else:
         self.stack.remove(default)

+
 _default_session_stack = _DefaultStack()
@@ -4145,6 +4189,7 @@
     super(_DefaultGraphStack, self).reset()
     self._global_default_graph = None

+
 _default_graph_stack = _DefaultGraphStack()
@@ -4195,8 +4240,8 @@
     ValueError: if graphs do not match.
   """
   if original_item.graph is not item.graph:
-    raise ValueError(
-        "%s must be from the same graph as %s." % (item, original_item))
+    raise ValueError("%s must be from the same graph as %s." % (item,
+                                                                original_item))


 def _get_graph_from_inputs(op_input_list, graph=None):
@@ -4246,8 +4291,11 @@
   original_graph_element = None
   for op_input in op_input_list:
     # Determine if this is a valid graph_element.
+    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
+    # up.
     graph_element = None
-    if isinstance(op_input, (Operation, _TensorLike)):
+    if (isinstance(op_input, (Operation, _TensorLike)) and
+        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
       graph_element = op_input
     else:
       graph_element = _as_graph_element(op_input)
@@ -4259,8 +4307,7 @@
     elif original_graph_element is not None:
       _assert_same_graph(original_graph_element, graph_element)
     elif graph_element.graph is not graph:
-      raise ValueError(
-          "%s is not from the passed-in graph." % graph_element)
+      raise ValueError("%s is not from the passed-in graph." % graph_element)

   # 2. If all else fails, we use the default graph, which is always there.
   return graph or get_default_graph()
@@ -4512,13 +4559,15 @@
     # tf.name_scope(None) (values=None then) is sometimes used as an idiom
     # to reset to top scope.
     raise ValueError(
-        "At least one of name (%s) and default_name (%s) must be provided." % (
-            name, default_name))
+        "At least one of name (%s) and default_name (%s) must be provided." %
+        (name, default_name))
   if values is None:
     values = []
   g = _get_graph_from_inputs(values)
   with g.as_default(), g.name_scope(n) as scope:
     yield scope

+
 # pylint: enable=g-doc-return-or-yield
@@ -4585,7 +4634,9 @@ def op_scope(values, name, default_name=None):
 _proto_function_registry = registry.Registry("proto functions")


-def register_proto_function(collection_name, proto_type=None, to_proto=None,
+def register_proto_function(collection_name,
+                            proto_type=None,
+                            to_proto=None,
                             from_proto=None):
   """Registers `to_proto` and `from_proto` functions for collection_name.
@@ -4637,10 +4688,9 @@ def get_from_proto_function(collection_name):

 def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
   """Produce a nice error if someone converts an Operation to a Tensor."""
-  raise TypeError(
-      ("Can't convert Operation '%s' to Tensor "
-       "(target dtype=%r, name=%r, as_ref=%r)") %
-      (op.name, dtype, name, as_ref))
+  raise TypeError(("Can't convert Operation '%s' to Tensor "
+                   "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
+                                                               name, as_ref))


 register_tensor_conversion_function(Operation, _operation_conversion_error)
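Note on the _get_graph_from_inputs change above: since eager tensors are now a strict subclass of Tensor but carry no graph, the graph-inference walk accepts exact Tensor instances and skips subclasses, per the TODO(josh11b) comment. A tiny sketch of just that check, with illustrative stand-in types:

class Tensor(object):
  pass

class EagerTensor(Tensor):  # stand-in for the eager subclass; has no .graph
  pass

def is_graph_element(op_input):
  # exact Tensor instances qualify; strict subclasses are excluded
  return isinstance(op_input, Tensor) and type(op_input) == Tensor  # pylint: disable=unidiomatic-typecheck

assert is_graph_element(Tensor())
assert not is_graph_element(EagerTensor())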