Internal change

PiperOrigin-RevId: 164916465
A. Unique TensorFlower 2017-08-10 16:03:18 -07:00 committed by TensorFlower Gardener
parent b8d13d218f
commit d9ca2d86de
4 changed files with 297 additions and 233 deletions
tensorflow/python


@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import memory_trace
@@ -28,17 +26,6 @@ from tensorflow.python.framework import errors
# Trace of execution and memory usage.
_active_trace = None
_uid_counter = 0
_uid_lock = threading.Lock()
def uid():
"""A unique (within this program execution) integer."""
with _uid_lock:
global _uid_counter
_uid_counter += 1
return _uid_counter
def _status_to_exception(code, message):
try:


@@ -28,7 +28,6 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
@@ -40,7 +39,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest
# Thread-local storage for tfe Tensors which are referenced while evaluating a
# graph-mode function.
_scoped_captures = threading.local()
@@ -86,9 +84,8 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
"To build a graph use tfe.defun or tfe.func_to_object.")
captured_value = tensor_map.get(tape.tensor_id(value), None)
if captured_value is None:
captured_value = graph_placeholder(dtype=dtype or value.dtype,
shape=value.shape,
name=name)
captured_value = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
if captured_value.dtype == dtypes.resource:
captured_value._handle_data = value._handle_data # pylint: disable=protected-access
tensor_map[tape.tensor_id(value)] = (value, captured_value)
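The capture logic above memoizes one placeholder per eager value, keyed by tensor id, so a value referenced several times inside a graph function body is captured exactly once. A minimal standalone sketch of that memoization (capture and make_placeholder are illustrative names, not from this diff):

captures = {}  # tensor id -> (eager value, graph placeholder)

def capture(value, make_placeholder):
  entry = captures.get(id(value))
  if entry is None:
    # First reference: create the placeholder and remember the pair.
    entry = (value, make_placeholder(value))
    captures[id(value)] = entry
  return entry[1]

value = 3.14
ph = capture(value, lambda v: "placeholder:%r" % v)
assert capture(value, lambda v: 1 / 0) == ph  # memoized; factory not called again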
@@ -98,8 +95,10 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
# TODO(apassos): it'd be really nice if we could scope this registration.
ops.register_tensor_conversion_function(tensor.Tensor,
_convert_to_graph_constant)
# Note that we register this at a higher priority than ops.Tensor since we want
# to handle subclass specific conversion before a superclass conversion.
ops.register_tensor_conversion_function(
tensor.Tensor, _convert_to_graph_constant, priority=-1)
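Conversion functions are consulted in ascending priority order, so registering the eager converter at priority=-1 puts it ahead of the base ops.Tensor converter (registered at priority 0, visible in the ops.py changes below). A simplified sketch of such a priority-ordered registry — an illustration, not TensorFlow's actual implementation:

_registry = {}  # priority -> list of (base_type, conversion_func)

def register(base_type, conversion_func, priority=100):
  _registry.setdefault(priority, []).append((base_type, conversion_func))

def convert(value):
  # Lower priorities are consulted first, so a subclass converter registered
  # at -1 shadows a superclass converter registered at 0.
  for priority in sorted(_registry):
    for base_type, func in _registry[priority]:
      if isinstance(value, base_type):
        return func(value)
  raise TypeError("no conversion function registered for %r" % type(value))

register(int, lambda v: ("int", v), priority=0)
register(bool, lambda v: ("bool", v), priority=-1)  # bool subclasses int
assert convert(True) == ("bool", True)
assert convert(7) == ("int", 7)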
class _CapturingContext(object):
@@ -133,17 +132,17 @@ class _CapturingContext(object):
def _forward_name(n):
"""The name of a generated forward defun named n."""
return "__forward_%s_%s" % (n, core.uid())
return "__forward_%s_%s" % (n, ops.uid())
def _backward_name(n):
"""The name of a generated backward defun named n."""
return "__backward_%s_%s" % (n, core.uid())
return "__backward_%s_%s" % (n, ops.uid())
def _inference_name(n):
"""The name of a forward-but-no-gradient defun named n."""
return "__inference_%s_%s" % (n, core.uid())
return "__inference_%s_%s" % (n, ops.uid())
class _DefinedFunction(object):
@@ -184,15 +183,8 @@ class _GraphModeFunction(object):
internal function.
"""
def __init__(self,
input_placeholders,
extra_inputs,
fdef,
graph,
operations,
func_outputs,
func_outputs_to_fdef_outputs,
output_shapes):
def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
func_outputs, func_outputs_to_fdef_outputs, output_shapes):
assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
len(input_placeholders), len(fdef.signature.input_arg))
self._input_placeholders = input_placeholders
@@ -204,8 +196,8 @@ class _GraphModeFunction(object):
self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
if (isinstance(func_outputs, (ops.Tensor, type(None)))
or ag_core.isnode(func_outputs)):
if (isinstance(func_outputs, (ops.Tensor, type(None))) or
ag_core.isnode(func_outputs)):
self._returns = [func_outputs]
else:
self._returns = list(func_outputs)
@@ -218,11 +210,11 @@ class _GraphModeFunction(object):
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
filtered_outputs = [ag_core.getval(x)
for x in self._returns if x is not None]
filtered_outputs = [
ag_core.getval(x) for x in self._returns if x is not None
]
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape)
for x in filtered_outputs
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
]
in_gradients = gradients_impl.gradients(
filtered_outputs,
@@ -231,20 +223,16 @@ class _GraphModeFunction(object):
shapes = [x.shape for x in in_gradients if x is not None]
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
forward_function_def = graph_to_function_def.graph_to_function_def(
self._graph, self._ops,
self._input_placeholders,
self._graph, self._ops, self._input_placeholders,
filtered_outputs + captures)
self._forward_fdef = _DefinedFunction(forward_function_def)
_register_with_name(_forward_name(self._func_name),
forward_function_def)
_register_with_name(_forward_name(self._func_name), forward_function_def)
backward_outputs = [x for x in in_gradients if x is not None]
all_inputs = self._out_grad_placeholders + captures
backward_function_def = graph_to_function_def.graph_to_function_def(
self._graph,
[x.op for x in self._out_grad_placeholders] +
list(sorted(c.known_ops, key=lambda x: x.name)),
all_inputs,
backward_outputs)
self._graph, [x.op for x in self._out_grad_placeholders
] + list(sorted(c.known_ops, key=lambda x: x.name)),
all_inputs, backward_outputs)
_register_with_name(_backward_name(self._func_name), backward_function_def)
self._backward_function = _GraphModeFunction(
all_inputs, [], backward_function_def, self._graph, c.known_ops,
@@ -258,12 +246,12 @@ class _GraphModeFunction(object):
g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access
unwrapped_args = [ag_core.getval(x) for x in all_args]
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in unwrapped_args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
op = g.create_op(
signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
outputs = op.outputs
outputs = [outputs] if isinstance(
outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
@@ -288,17 +276,17 @@ class _GraphModeFunction(object):
watched_extra_inputs.append(t)
real_outputs = tape.record_operation(real_outputs,
(args + watched_extra_inputs),
side_outputs,
self._backward_function)
side_outputs, self._backward_function)
return self._build_call_outputs(self._returns, real_outputs)
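tape.record_operation gives the tape everything needed to replay the computation in reverse: the real outputs, the watched inputs, the side outputs, and the generated backward function. A toy sketch of that contract (MiniTape is hypothetical, not TensorFlow's tape module):

class MiniTape(object):
  """Toy tape: remembers each op's backward function for reverse replay."""

  def __init__(self):
    self._ops = []

  def record_operation(self, output, inputs, backward_fn):
    self._ops.append((id(output), [id(x) for x in inputs], backward_fn))
    return output

  def gradient(self, target, source):
    grads = {id(target): 1.0}
    for out_id, in_ids, backward_fn in reversed(self._ops):
      if out_id not in grads:
        continue
      for in_id, g in zip(in_ids, backward_fn(grads[out_id])):
        grads[in_id] = grads.get(in_id, 0.0) + g
    return grads.get(id(source), 0.0)

tape = MiniTape()
x = 3.0
y = tape.record_operation(x * x, [x], lambda dy: [dy * 2 * x])  # y = x**2
assert tape.gradient(y, x) == 6.0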
def __call__(self, *args):
"""Executes the passed function in eager mode."""
tensor_inputs = [x for x in nest.flatten(args)
if isinstance(x, (tensor.Tensor, ops.Tensor,
tensor.LazyZero))
or ag_core.isnode(x)]
tensor_inputs = [
x for x in nest.flatten(args)
if isinstance(x, (tensor.Tensor, ops.Tensor,
tensor.LazyZero)) or ag_core.isnode(x)
]
if tape.should_record(tensor_inputs) or any(
tape.any_tape_has(t) for t in self._extra_inputs):
if not self._has_backprop:
@@ -310,18 +298,20 @@ class _GraphModeFunction(object):
g._add_function(self._fdef) # pylint: disable=protected-access
signature = self._fdef.definition.signature
args = list(tensor_inputs) + self._extra_inputs
op = g.create_op(signature.name,
[ops.convert_to_tensor(x) for x in args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
op = g.create_op(
signature.name, [ops.convert_to_tensor(x) for x in args],
[dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature,
name="FunctionCall",
compute_shapes=False)
result = op.outputs
for i, s in enumerate(self._output_shapes):
result[i].set_shape(s)
else:
tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
for x in tensor_inputs]
tensor_inputs = [
x.tensor() if isinstance(x, tensor.LazyZero) else x
for x in tensor_inputs
]
result = execute.execute(
self._func_name,
num_outputs=self._num_outputs,
@@ -383,22 +373,21 @@ def _defun_internal(name, func, args, kwds):
func_outputs = func(*func_inputs, **kwds)
ids = list(sorted(captures.keys()))
if ids:
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
else:
extra_inputs = []
extra_placeholders = []
outputs_list = nest.flatten(func_outputs)
output_shapes = [x.shape for x in outputs_list if x is not None]
flat_inputs = [x for x in nest.flatten(func_inputs)
if isinstance(x, ops.Tensor)]
flat_inputs = [
x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
]
all_inputs = flat_inputs + list(extra_placeholders)
func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
inference_function_def = graph_to_function_def.graph_to_function_def(
tmp_graph, tmp_graph.get_operations(),
all_inputs,
func_def_outputs)
tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
@@ -407,14 +396,9 @@ def _defun_internal(name, func, args, kwds):
_register_with_name(_inference_name(name), inference_function_def)
return _GraphModeFunction(
all_inputs,
extra_inputs,
inference_function_def,
tmp_graph,
tmp_graph.get_operations(),
func_outputs,
_map_sequence_obj_to_idx(func_def_outputs),
output_shapes)
all_inputs, extra_inputs, inference_function_def, tmp_graph,
tmp_graph.get_operations(), func_outputs,
_map_sequence_obj_to_idx(func_def_outputs), output_shapes)
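One idiom worth calling out in _defun_internal: zip(* [captures[x] for x in ids]) is the standard unzip pattern, splitting a list of (eager tensor, placeholder) pairs into two parallel tuples. A quick illustration:

pairs = [("t1", "p1"), ("t2", "p2")]  # (captured eager value, graph placeholder)
extra_inputs, extra_placeholders = zip(*pairs)
assert extra_inputs == ("t1", "t2")
assert extra_placeholders == ("p1", "p2")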
# Defun uses this instead of Tensor as a cache key. Using dtype because


@@ -27,11 +27,13 @@ from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.framework import tensor_shape
class Tensor(object):
"""A TensorFlow Tensor."""
# TODO(agarwal): rename to TensorHandle.
class Tensor(tf_ops.Tensor):
"""A TensorFlow Eager Tensor."""
def __init__(self, value, dtype=None):
"""Creates a Tensor object from a Python object or numpy array.
@@ -48,7 +50,7 @@ class Tensor(object):
# TODO(ashankar): Evaluate if we can and perhaps share code with
# tf.constant defined in
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
self._id = core.uid()
self._id = tf_ops.uid()
if not isinstance(value, np.ndarray):
npt = None if dtype is None else dtype.as_numpy_dtype
value = np.array(value, dtype=npt)
@@ -111,7 +113,7 @@ class Tensor(object):
if core.active_trace() is not None:
core.active_trace().record_tensor("MANUAL",
tape.tensor_id(self),
self._device_name(),
self.device,
self.shape.num_elements())
def __del__(self):
@@ -184,12 +186,13 @@ class Tensor(object):
if core.active_trace() is not None:
core.active_trace().record_tensor("COPY",
tape.tensor_id(new_tensor),
new_tensor._device_name(),
new_tensor.device,
new_tensor.shape.num_elements())
return new_tensor
# pylint: enable=protected-access
def _device_name(self):
@property
def device(self):
return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)
@property
@@ -237,6 +240,10 @@ class Tensor(object):
pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
for x in range(n))
def _shape_as_list(self):
"""The shape of the tensor as a list."""
return list(self._shape_tuple())
def as_cpu_tensor(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.get_default_context(), "CPU:0")
@@ -266,6 +273,42 @@ class Tensor(object):
def __nonzero__(self):
return self.__bool__()
# Methods not supported / implemented for Eager Tensors.
@property
def op(self):
raise NotImplementedError("op not supported for Eager Tensors.")
@property
def graph(self):
raise NotImplementedError("graph not supported for Eager Tensors.")
@property
def name(self):
raise NotImplementedError("name not supported for Eager Tensors.")
def set_shape(self, shape):
raise NotImplementedError("set_shape not supported for Eager Tensors.")
@property
def value_index(self):
raise NotImplementedError("value_index not supported for Eager Tensors.")
def consumers(self):
raise NotImplementedError("consumers not supported for Eager Tensors.")
def _add_consumer(self, consumer):
raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
def _as_node_def_input(self):
raise NotImplementedError(
"_as_node_def_input not supported for Eager Tensors.")
def _as_tf_output(self):
raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
def eval(self, feed_dict=None, session=None):
raise NotImplementedError("eval not supported for Eager Tensors.")
class IndexedSlices(object):
"""A sparse representation of a set of tensor slices at given indices.
@@ -374,7 +417,7 @@ def _tensor_from_handle(handle):
"""
# pylint: disable=protected-access
t = Tensor.__new__(Tensor)
t._id = core.uid()
t._id = tf_ops.uid()
t._handle = handle
t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
t._handle_data = None
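_tensor_from_handle wraps an existing TFE_TensorHandle by calling Tensor.__new__, which allocates the instance without running __init__ and its Python-value conversion. The same pattern in isolation (Point is a hypothetical stand-in):

class Point(object):
  def __init__(self, x, y):
    # Imagine this does expensive conversion we want to skip.
    self.x, self.y = float(x), float(y)

p = Point.__new__(Point)  # allocate without calling __init__
p.x, p.y = 1.0, 2.0       # fill the fields directly, as _tensor_from_handle does
assert isinstance(p, Point)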


@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions used to construct graphs."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -47,7 +46,6 @@ from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import tf_contextlib
# Temporary global switch determining if we should enable the work-in-progress
# calls to the C API. Currently disabled by default but can be manually enabled
# e.g. in tests. This will be removed once all functionality is supported and
@@ -153,6 +151,18 @@ def register_dense_tensor_like_type(tensor_type):
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
_uid_counter = 0
_uid_lock = threading.Lock()
def uid():
"""A unique (within this program execution) integer."""
with _uid_lock:
global _uid_counter
_uid_counter += 1
return _uid_counter
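uid() moves here from the eager module (it is deleted in the first file of this commit) so that graph Tensors can stamp themselves with self._id = uid(), replacing the name-based _id property removed below. A quick demonstration that the lock keeps ids unique across threads (next_id and _counter are illustrative names, not from this diff):

import threading

_counter = {"n": 0}
_lock = threading.Lock()

def next_id():
  with _lock:
    _counter["n"] += 1
    return _counter["n"]

ids = []
threads = [threading.Thread(target=lambda: ids.append(next_id()))
           for _ in range(8)]
for t in threads:
  t.start()
for t in threads:
  t.join()
assert sorted(ids) == list(range(1, 9))  # unique; no lost increments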
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
class _TensorLike(object):
"""Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
@@ -261,6 +271,7 @@ class Tensor(_TensorLike):
# Attributes used for C++ shape inference. Not inspected, only forwarded.
# If set, will be a HandleData object from cpp_shape_inference.proto.
self._handle_data = None
self._id = uid()
@property
def op(self):
@@ -284,11 +295,6 @@ class Tensor(_TensorLike):
raise ValueError("Operation was not named: %s" % self._op)
return "%s:%d" % (self._op.name, self._value_index)
@property
def _id(self):
"""An alias for the string name of this tensor."""
return self.name
@property
def device(self):
"""The name of the device on which this tensor will be produced, or None."""
@@ -437,15 +443,15 @@ class Tensor(_TensorLike):
def __str__(self):
return "Tensor(\"%s\"%s%s%s)" % (
self.name,
(", shape=%s" % self.get_shape())
self.name, (", shape=%s" % self.get_shape())
if self.get_shape().ndims is not None else "",
(", dtype=%s" % self._dtype.name) if self._dtype else "",
(", device=%s" % self.device) if self.device else "")
(", dtype=%s" % self._dtype.name)
if self._dtype else "", (", device=%s" % self.device)
if self.device else "")
def __repr__(self):
return "<tf.Tensor '%s' shape=%s dtype=%s>" % (
self.name, self.get_shape(), self._dtype.name)
return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
self._dtype.name)
def __hash__(self):
# Necessary to support Python's collection membership operators
@@ -551,20 +557,18 @@ def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
_ = name, as_ref
if dtype and not dtype.is_compatible_with(t.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
% (dtype.name, t.dtype.name, str(t)))
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
(dtype.name, t.dtype.name, str(t)))
return t
_tensor_conversion_func_registry = {
0: [(Tensor, _TensorTensorConversionFunction)]}
0: [(Tensor, _TensorTensorConversionFunction)]
}
register_dense_tensor_like_type(Tensor)
def convert_to_tensor(value,
dtype=None,
name=None,
preferred_dtype=None):
def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
"""Converts the given `value` to a `Tensor`.
This function converts Python objects of various types to `Tensor`
@@ -686,18 +690,18 @@ def internal_convert_to_tensor(value,
if not isinstance(ret, Tensor):
raise RuntimeError(
"%sConversion function %r for type %s returned non-Tensor: %r"
% (error_prefix, conversion_func, base_type, ret))
"%sConversion function %r for type %s returned non-Tensor: %r" %
(error_prefix, conversion_func, base_type, ret))
if dtype and not dtype.is_compatible_with(ret.dtype):
raise RuntimeError(
"%sConversion function %r for type %s returned incompatible "
"dtype: requested = %s, actual = %s"
% (error_prefix, conversion_func, base_type,
dtype.name, ret.dtype.name))
"dtype: requested = %s, actual = %s" %
(error_prefix, conversion_func, base_type, dtype.name,
ret.dtype.name))
return ret
raise TypeError("%sCannot convert %r with type %s to Tensor: "
"no conversion function registered."
% (error_prefix, value, type(value)))
"no conversion function registered." % (error_prefix, value,
type(value)))
def internal_convert_n_to_tensor(values,
@@ -744,10 +748,7 @@ def internal_convert_n_to_tensor(values,
return ret
def convert_n_to_tensor(values,
dtype=None,
name=None,
preferred_dtype=None):
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
"""Converts `values` to a list of `Tensor` objects.
Args:
@@ -771,11 +772,12 @@ def convert_n_to_tensor(values,
RuntimeError: If a registered conversion function returns an invalid
value.
"""
return internal_convert_n_to_tensor(values=values,
dtype=dtype,
name=name,
preferred_dtype=preferred_dtype,
as_ref=False)
return internal_convert_n_to_tensor(
values=values,
dtype=dtype,
name=name,
preferred_dtype=preferred_dtype,
as_ref=False)
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
@@ -802,7 +804,9 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
value=value, dtype=dtype, name=name, as_ref=False)
def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
def internal_convert_to_tensor_or_indexed_slices(value,
dtype=None,
name=None,
as_ref=False):
"""Converts the given object to an `Tensor` or an `IndexedSlices`.
@@ -827,18 +831,18 @@ def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
if isinstance(value, _TensorLike):
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
% (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
(dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
return value
else:
return internal_convert_to_tensor(value,
dtype=dtype,
name=name,
as_ref=as_ref)
return internal_convert_to_tensor(
value, dtype=dtype, name=name, as_ref=as_ref)
def internal_convert_n_to_tensor_or_indexed_slices(values, dtype=None,
name=None, as_ref=False):
def internal_convert_n_to_tensor_or_indexed_slices(values,
dtype=None,
name=None,
as_ref=False):
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
@@ -905,7 +909,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
values=values, dtype=dtype, name=name, as_ref=False)
def register_tensor_conversion_function(base_type, conversion_func,
def register_tensor_conversion_function(base_type,
conversion_func,
priority=100):
"""Registers a function for converting objects of `base_type` to `Tensor`.
@@ -947,8 +952,8 @@ def register_tensor_conversion_function(base_type, conversion_func,
"""
if not (isinstance(base_type, type) or
(isinstance(base_type, tuple)
and all(isinstance(x, type) for x in base_type))):
(isinstance(base_type, tuple) and
all(isinstance(x, type) for x in base_type))):
raise TypeError("base_type must be a type or a tuple of types.")
if not callable(conversion_func):
raise TypeError("conversion_func must be callable.")
@@ -1038,8 +1043,7 @@ class IndexedSlices(_TensorLike):
def __str__(self):
return "IndexedSlices(indices=%s, values=%s%s)" % (
self._indices, self._values,
(", dense_shape=%s" % self._dense_shape)
self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
if self._dense_shape is not None else "")
def __neg__(self):
@@ -1112,8 +1116,14 @@ class Operation(object):
`op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
"""
def __init__(self, node_def, g, inputs=None, output_types=None,
control_inputs=None, input_types=None, original_op=None,
def __init__(self,
node_def,
g,
inputs=None,
output_types=None,
control_inputs=None,
input_types=None,
original_op=None,
op_def=None):
r"""Creates an `Operation`.
@@ -1177,18 +1187,20 @@
if output_types is None:
output_types = []
self._output_types_val = output_types
self._outputs = [Tensor(self, i, output_type)
for i, output_type in enumerate(output_types)]
self._outputs = [
Tensor(self, i, output_type)
for i, output_type in enumerate(output_types)
]
if input_types is None:
input_types = [i.dtype.base_dtype for i in self._inputs]
else:
if not all(x.is_compatible_with(i.dtype)
for i, x in zip(self._inputs, input_types)):
if not all(
x.is_compatible_with(i.dtype)
for i, x in zip(self._inputs, input_types)):
raise TypeError("In op '%s', input types (%s) are not compatible "
"with expected types (%s)" % (
self.node_def.name,
[i.dtype for i in self._inputs],
input_types))
"with expected types (%s)" %
(self.node_def.name, [i.dtype for i in self._inputs],
input_types))
self._input_types_val = input_types
# Build the list of control inputs.
@@ -1251,7 +1263,8 @@
A wrapped TF_Operation*.
"""
# pylint: disable=protected-access
op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op),
op_desc = c_api.TF_NewOperation(graph._c_graph,
compat.as_str(node_def.op),
compat.as_str(node_def.name))
# Add inputs
for op_input in inputs:
@@ -1271,8 +1284,8 @@
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized,
status)
c_api.TF_SetAttrValueProto(op_desc,
compat.as_str(name), serialized, status)
with errors.raise_exception_on_not_ok_status() as status:
c_op = c_api.TF_FinishOperation(op_desc, status)
@@ -1316,16 +1329,18 @@
def colocation_groups(self):
"""Returns the list of colocation groups of the op."""
default_colocation_group = [compat.as_bytes("loc:@%s" %
self._node_def.name)]
default_colocation_group = [
compat.as_bytes("loc:@%s" % self._node_def.name)
]
if "_class" not in self._node_def.attr:
# This op has no explicit colocation group, so it is itself its
# own root of a colocation group.
return default_colocation_group
attr_groups = [class_name
for class_name in self.get_attr("_class")
if class_name.startswith(b"loc:@")]
attr_groups = [
class_name for class_name in self.get_attr("_class")
if class_name.startswith(b"loc:@")
]
# If there are no colocation groups in the explicit _class field,
# return the default colocation group.
@@ -1397,8 +1412,10 @@
"""
if self._graph._c_graph: # pylint: disable=protected-access
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
output_types = [c_api.TF_OperationOutputType(self._tf_output(i)) for
i in xrange(num_outputs)]
output_types = [
c_api.TF_OperationOutputType(self._tf_output(i))
for i in xrange(num_outputs)
]
# TODO(iga): Remove this assert after converting to C API by default.
# Just being a bit paranoid here.
assert self._output_types_val == output_types
@@ -1433,8 +1450,8 @@
device: string or device. The device to set.
"""
if _USE_C_API:
c_api.SetRequestedDevice(
self._graph._c_graph, self._c_op, _device_string(device)) # pylint: disable=protected-access
c_api.SetRequestedDevice(self._graph._c_graph, self._c_op, # pylint: disable=protected-access
_device_string(device))
# TODO(nolivia): remove this line when switch to C api
self._node_def.device = _device_string(device)
@@ -1462,8 +1479,8 @@
dtype = dtypes.as_dtype(dtype)
if not dtype.is_compatible_with(tensor.dtype):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s"
% (tensor.dtype.name, dtype.name))
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
self._inputs.append(tensor)
self._input_types_val.append(dtype)
tensor._add_consumer(self) # pylint: disable=protected-access
@@ -1496,8 +1513,8 @@
dtype = dtypes.as_dtype(dtype)
if not dtype.is_compatible_with(tensor.dtype):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s"
% (tensor.dtype.name, dtype.name))
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
self._inputs[index].consumers().remove(self)
self._inputs[index] = tensor
@@ -1547,8 +1564,8 @@
self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
# pylint: enable=protected-access
if self._control_inputs:
self._node_def.input.extend(["^%s" % op.name for op in
self._control_inputs])
self._node_def.input.extend(
["^%s" % op.name for op in self._control_inputs])
def __str__(self):
return str(self._node_def)
@@ -1562,6 +1579,7 @@
return self._outputs
# pylint: disable=protected-access
class _InputList(object):
"""Immutable input list wrapper."""
@@ -1582,6 +1600,7 @@
def __getitem__(self, i):
return self._op._inputs[i]
# pylint: enable=protected-access
@property
@@ -1597,9 +1616,10 @@
def _input_types(self):
if self._graph._c_graph: # pylint: disable=protected-access
num_inputs = c_api.TF_OperationNumInputs(self._c_op)
input_types = [dtypes.as_dtype(
c_api.TF_OperationInputType(self._tf_input(i)))
for i in xrange(num_inputs)]
input_types = [
dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
for i in xrange(num_inputs)
]
# TODO(iga): Remove this assert after converting to C API by default.
# Just being a bit paranoid here.
assert self._input_types_val == input_types
@@ -1624,8 +1644,10 @@
if self._graph._c_graph: # pylint: disable=protected-access
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
# pylint: disable=protected-access
return [self.graph._get_operation_by_name_unsafe(
c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
return [
self.graph._get_operation_by_name_unsafe(
c_api.TF_OperationName(c_op)) for c_op in control_c_ops
]
# pylint: enable=protected-access
else:
return self._control_inputs
@@ -1691,7 +1713,8 @@
A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
"""
return self._graph._convert_stack( # pylint: disable=protected-access
self._traceback, include_func_start_lineno=True)
self._traceback,
include_func_start_lineno=True)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -1707,8 +1730,7 @@
"""
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
if name not in self._node_def.attr:
raise ValueError("No attr named '" + name + "' in " +
str(self._node_def))
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
x = self._node_def.attr[name]
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
@@ -1749,7 +1771,6 @@
"""
_run_using_default_session(self, feed_dict, self.graph, session)
_gradient_registry = registry.Registry("gradient")
@@ -1834,7 +1855,8 @@ NoGradient = NotDifferentiable
def get_gradient_function(op):
"""Returns the function that computes gradients for "op"."""
if not op.inputs: return None
if not op.inputs:
return None
try:
op_type = op.get_attr("_gradient_op_type")
except ValueError:
@@ -1982,6 +2004,7 @@ class OpStats(object):
self._value += other.value
return self
_stats_registry = registry.Registry("statistical functions")
@@ -2433,8 +2456,8 @@ class Graph(object):
graph.node.extend([op.node_def])
if op.outputs and add_shapes:
assert "_output_shapes" not in graph.node[-1].attr
graph.node[-1].attr["_output_shapes"].list.shape.extend([
output.get_shape().as_proto() for output in op.outputs])
graph.node[-1].attr["_output_shapes"].list.shape.extend(
[output.get_shape().as_proto() for output in op.outputs])
bytesize += op.node_def.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
@@ -2469,7 +2492,8 @@ class Graph(object):
node with the inferred shapes of each of its outputs.
Returns:
A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
A
[`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
protocol buffer.
Raises:
@@ -2518,8 +2542,8 @@ class Graph(object):
if previous:
raise ValueError("Another function is already defined with that name")
# Sanity checks on gradient definition.
if (function.grad_func_name is not None) and (
function.python_grad_func is not None):
if (function.grad_func_name is not None) and (function.python_grad_func is
not None):
raise ValueError("Gradient defined twice for function %s" % name)
# Need a new-enough consumer to support the functions we add to the graph.
if self._graph_def_versions.min_consumer < 12:
@@ -2532,9 +2556,17 @@ class Graph(object):
return self._building_function
# Helper functions to create operations.
def create_op(self, op_type, inputs, dtypes, # pylint: disable=redefined-outer-name
input_types=None, name=None, attrs=None, op_def=None,
compute_shapes=True, compute_device=True):
def create_op(
self,
op_type,
inputs,
dtypes, # pylint: disable=redefined-outer-name
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_shapes=True,
compute_device=True):
"""Creates an `Operation` in this graph.
This is a low-level interface for creating an `Operation`. Most
@@ -2597,8 +2629,8 @@ class Graph(object):
if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
raise TypeError(
"Callable for scope map key '%s' must return either None or "
"an AttrValue protocol buffer; but it returned: %s" %
(key, value))
"an AttrValue protocol buffer; but it returned: %s" % (key,
value))
node_def.attr[key].CopyFrom(value)
# Apply a kernel label if one has been specified for this op_type.
@@ -2619,9 +2651,15 @@ class Graph(object):
pass
control_inputs = self._control_dependencies_for_inputs(inputs)
ret = Operation(node_def, self, inputs=inputs, output_types=dtypes,
control_inputs=control_inputs, input_types=input_types,
original_op=self._default_original_op, op_def=op_def)
ret = Operation(
node_def,
self,
inputs=inputs,
output_types=dtypes,
control_inputs=control_inputs,
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
if compute_shapes:
set_shapes_for_outputs(ret)
self._add_op(ret)
@@ -2638,28 +2676,27 @@ class Graph(object):
# Make this device match the device of the colocated op, to
# provide consistency between the device and the colocation
# property.
if (ret.device and
pydev.canonical_name(ret.device) !=
if (ret.device and pydev.canonical_name(ret.device) !=
pydev.canonical_name(colocation_op.device)):
logging.warning("Tried to colocate %s with an op %s that had "
"a different device: %s vs %s. "
"Ignoring colocation property.",
name, colocation_op.name,
ret.device, colocation_op.device)
"Ignoring colocation property.", name,
colocation_op.name, ret.device,
colocation_op.device)
else:
ret._set_device(colocation_op.device) # pylint: disable=protected-access
all_colocation_groups = sorted(set(all_colocation_groups))
ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
ret.node_def.attr["_class"].CopyFrom(
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
s=all_colocation_groups)))
# Sets "container" attribute if
# (1) self._container is not None
# (2) "is_stateful" is set in OpDef
# (3) "container" attribute is in OpDef
# (4) "container" attribute is None
if (self._container and
op_type in self._registered_ops and
if (self._container and op_type in self._registered_ops and
self._registered_ops[op_type].is_stateful and
"container" in ret.node_def.attr and
not ret.node_def.attr["container"].s):
@@ -2749,13 +2786,13 @@ class Graph(object):
except:
raise KeyError("The name %s refers to a Tensor which does not "
"exist. The operation, %s, exists but only has "
"%s outputs."
% (repr(name), repr(op_name), len(op.outputs)))
"%s outputs." % (repr(name), repr(op_name),
len(op.outputs)))
elif ":" in name and not allow_tensor:
# Looks like a Tensor name but can't be a Tensor.
raise ValueError("Name %s appears to refer to a Tensor, not a %s."
% (repr(name), types_str))
raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
(repr(name), types_str))
elif ":" not in name and allow_operation:
# Looks like an Operation name and can be an Operation.
@@ -2768,8 +2805,8 @@ class Graph(object):
# Looks like an Operation name but can't be an Operation.
if name in self._nodes_by_name:
# Yep, it's an Operation name
err_msg = ("The name %s refers to an Operation, not a %s."
% (repr(name), types_str))
err_msg = ("The name %s refers to an Operation, not a %s." %
(repr(name), types_str))
else:
err_msg = ("The name %s looks like an (invalid) Operation name, "
"not a %s." % (repr(name), types_str))
@@ -2789,8 +2826,8 @@ class Graph(object):
return obj
else:
# We give up!
raise TypeError("Can not convert a %s into a %s."
% (type(obj).__name__, types_str))
raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
types_str))
def get_operations(self):
"""Return the list of operations in the graph.
@@ -2827,8 +2864,8 @@ class Graph(object):
"""
if not isinstance(name, six.string_types):
raise TypeError("Operation names are strings (or similar), not %s."
% type(name).__name__)
raise TypeError("Operation names are strings (or similar), not %s." %
type(name).__name__)
return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
def _get_operation_by_name_unsafe(self, name):
@@ -2871,8 +2908,8 @@ class Graph(object):
"""
# Names should be strings.
if not isinstance(name, six.string_types):
raise TypeError("Tensor names are strings (or similar), not %s."
% type(name).__name__)
raise TypeError("Tensor names are strings (or similar), not %s." %
type(name).__name__)
return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
def _next_id(self):
@@ -3178,6 +3215,7 @@ class Graph(object):
yield "" if new_stack is None else new_stack + "/"
finally:
self._name_stack = old_stack
# pylint: enable=g-doc-return-or-yield,line-too-long
def unique_name(self, name, mark_as_used=True):
@@ -3280,9 +3318,8 @@ class Graph(object):
"""
if op is None and not ignore_existing:
raise ValueError(
"Trying to reset colocation (op is None) but "
"ignore_existing is not True")
raise ValueError("Trying to reset colocation (op is None) but "
"ignore_existing is not True")
if op is not None and not isinstance(op, Operation):
# We always want to colocate with the reference op.
@@ -3377,8 +3414,8 @@ class Graph(object):
"""
# pylint: enable=line-too-long
if (device_name_or_function is not None
and not callable(device_name_or_function)):
if (device_name_or_function is not None and
not callable(device_name_or_function)):
device_function = pydev.merge_device(device_name_or_function)
else:
device_function = device_name_or_function
@@ -3452,6 +3489,7 @@ class Graph(object):
yield self._container
finally:
self._container = original_container
# pylint: enable=g-doc-return-or-yield
class _ControlDependenciesController(object):
@@ -3491,6 +3529,7 @@ class Graph(object):
self._old_control_flow_context = None
# pylint: disable=protected-access
def __enter__(self):
if self._new_stack:
# Clear the control_dependencies graph.
@@ -3506,6 +3545,7 @@ class Graph(object):
if self._new_stack:
self._graph._control_dependencies_stack = self._old_stack
self._graph._set_control_flow_context(self._old_control_flow_context)
# pylint: enable=protected-access
@property
@@ -3733,6 +3773,7 @@ class Graph(object):
self._attr_scope_map[name] = saved_attrs[name]
except KeyError:
del self._attr_scope_map[name]
# pylint: enable=g-doc-return-or-yield
# pylint: disable=g-doc-return-or-yield
@@ -3777,8 +3818,8 @@ class Graph(object):
saved_labels = {}
# Install the given label
for op_type, label in op_to_kernel_label_map.items():
if not (isinstance(op_type, six.string_types)
and isinstance(label, six.string_types)):
if not (isinstance(op_type, six.string_types) and
isinstance(label, six.string_types)):
raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
"strings to strings")
try:
@@ -3795,6 +3836,7 @@ class Graph(object):
self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
except KeyError:
del self._op_to_kernel_label_map[op_type]
# pylint: enable=g-doc-return-or-yield
# pylint: disable=g-doc-return-or-yield
@@ -3840,8 +3882,8 @@ class Graph(object):
saved_mappings = {}
# Install the given label
for op_type, mapped_op_type in op_type_map.items():
if not (isinstance(op_type, six.string_types)
and isinstance(mapped_op_type, six.string_types)):
if not (isinstance(op_type, six.string_types) and
isinstance(mapped_op_type, six.string_types)):
raise TypeError("op_type_map must be a dictionary mapping "
"strings to strings")
try:
@@ -3858,6 +3900,7 @@ class Graph(object):
self._gradient_override_map[op_type] = saved_mappings[op_type]
except KeyError:
del self._gradient_override_map[op_type]
# pylint: enable=g-doc-return-or-yield
def prevent_feeding(self, tensor):
@@ -3969,12 +4012,13 @@ class _DefaultStack(threading.local):
if self._enforce_nesting:
if self.stack[-1] is not default:
raise AssertionError(
"Nesting violated for default stack of %s objects"
% type(default))
"Nesting violated for default stack of %s objects" %
type(default))
self.stack.pop()
else:
self.stack.remove(default)
_default_session_stack = _DefaultStack()
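_DefaultStack is the machinery behind default-graph and default-session semantics: a thread-local stack whose top element is the current default, pushed and popped by a context manager. A simplified standalone sketch (the real class also supports non-nested exit via the remove() branch above):

import contextlib
import threading

class DefaultStack(threading.local):
  def __init__(self):
    super(DefaultStack, self).__init__()
    self.stack = []

  def get_default(self):
    return self.stack[-1] if self.stack else None

  @contextlib.contextmanager
  def get_controller(self, default):
    self.stack.append(default)
    try:
      yield default
    finally:
      self.stack.pop()  # simplified: assumes strictly nested usage

stack = DefaultStack()
with stack.get_controller("g1"):
  assert stack.get_default() == "g1"
  with stack.get_controller("g2"):
    assert stack.get_default() == "g2"
assert stack.get_default() is None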
@@ -4145,6 +4189,7 @@ class _DefaultGraphStack(_DefaultStack):
super(_DefaultGraphStack, self).reset()
self._global_default_graph = None
_default_graph_stack = _DefaultGraphStack()
@@ -4195,8 +4240,8 @@ def _assert_same_graph(original_item, item):
ValueError: if graphs do not match.
"""
if original_item.graph is not item.graph:
raise ValueError(
"%s must be from the same graph as %s." % (item, original_item))
raise ValueError("%s must be from the same graph as %s." % (item,
original_item))
def _get_graph_from_inputs(op_input_list, graph=None):
@@ -4246,8 +4291,11 @@ def _get_graph_from_inputs(op_input_list, graph=None):
original_graph_element = None
for op_input in op_input_list:
# Determine if this is a valid graph_element.
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
# up.
graph_element = None
if isinstance(op_input, (Operation, _TensorLike)):
if (isinstance(op_input, (Operation, _TensorLike)) and
((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck
graph_element = op_input
else:
graph_element = _as_graph_element(op_input)
@@ -4259,8 +4307,7 @@ def _get_graph_from_inputs(op_input_list, graph=None):
elif original_graph_element is not None:
_assert_same_graph(original_graph_element, graph_element)
elif graph_element.graph is not graph:
raise ValueError(
"%s is not from the passed-in graph." % graph_element)
raise ValueError("%s is not from the passed-in graph." % graph_element)
# 2. If all else fails, we use the default graph, which is always there.
return graph or get_default_graph()
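The new guard in _get_graph_from_inputs deliberately pairs isinstance with an exact type check: eager Tensors are now subclasses of ops.Tensor but have no graph, so they must not be used to infer one. The difference in one snippet (Base and Sub are illustrative stand-ins):

class Base(object):
  pass

class Sub(Base):
  pass

x = Sub()
assert isinstance(x, Base)   # isinstance accepts subclasses (eager Tensors)
assert type(x) is not Base   # the exact-type check rejects them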
@@ -4512,13 +4559,15 @@ def name_scope(name, default_name=None, values=None):
# tf.name_scope(None) (values=None then) is sometimes used as an idiom
# to reset to top scope.
raise ValueError(
"At least one of name (%s) and default_name (%s) must be provided." % (
name, default_name))
"At least one of name (%s) and default_name (%s) must be provided." %
(name, default_name))
if values is None:
values = []
g = _get_graph_from_inputs(values)
with g.as_default(), g.name_scope(n) as scope:
yield scope
# pylint: enable=g-doc-return-or-yield
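name_scope composes the graph's scope stack into hierarchical op-name prefixes. A minimal sketch of that stack discipline (Scoper is hypothetical, not the TensorFlow implementation):

import contextlib

class Scoper(object):
  def __init__(self):
    self._prefix = ""

  @contextlib.contextmanager
  def name_scope(self, name):
    old = self._prefix
    self._prefix = old + name + "/"
    try:
      yield self._prefix
    finally:
      self._prefix = old

s = Scoper()
with s.name_scope("outer"):
  with s.name_scope("inner") as scope:
    assert scope == "outer/inner/"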
@@ -4585,7 +4634,9 @@ def op_scope(values, name, default_name=None):
_proto_function_registry = registry.Registry("proto functions")
def register_proto_function(collection_name, proto_type=None, to_proto=None,
def register_proto_function(collection_name,
proto_type=None,
to_proto=None,
from_proto=None):
"""Registers `to_proto` and `from_proto` functions for collection_name.
@@ -4637,10 +4688,9 @@ def get_from_proto_function(collection_name):
def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
"""Produce a nice error if someone converts an Operation to a Tensor."""
raise TypeError(
("Can't convert Operation '%s' to Tensor "
"(target dtype=%r, name=%r, as_ref=%r)") %
(op.name, dtype, name, as_ref))
raise TypeError(("Can't convert Operation '%s' to Tensor "
"(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
name, as_ref))
register_tensor_conversion_function(Operation, _operation_conversion_error)