From d9ca2d86de0d480db5a6959285f2e1b6493e2ae4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 Aug 2017 16:03:18 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 164916465 --- tensorflow/python/eager/core.py | 13 -- tensorflow/python/eager/function.py | 120 +++++----- tensorflow/python/eager/tensor.py | 57 ++++- tensorflow/python/framework/ops.py | 340 ++++++++++++++++------------ 4 files changed, 297 insertions(+), 233 deletions(-) diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py index 36c89e39638..64c615fb63b 100644 --- a/tensorflow/python/eager/core.py +++ b/tensorflow/python/eager/core.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import memory_trace @@ -28,17 +26,6 @@ from tensorflow.python.framework import errors # Trace of execution and memory usage. _active_trace = None -_uid_counter = 0 -_uid_lock = threading.Lock() - - -def uid(): - """A unique (within this program execution) integer.""" - with _uid_lock: - global _uid_counter - _uid_counter += 1 - return _uid_counter - def _status_to_exception(code, message): try: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 540d19bb3d8..e4866b61056 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -28,7 +28,6 @@ import numpy as np from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context -from tensorflow.python.eager import core from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager import tensor @@ -40,7 +39,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import nest - # Thread-local storage for tfe Tensors which are referenced while evaluating a # graph-mode function. _scoped_captures = threading.local() @@ -86,9 +84,8 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): "To build a graph use tfe.defun or tfe.func_to_object.") captured_value = tensor_map.get(tape.tensor_id(value), None) if captured_value is None: - captured_value = graph_placeholder(dtype=dtype or value.dtype, - shape=value.shape, - name=name) + captured_value = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) if captured_value.dtype == dtypes.resource: captured_value._handle_data = value._handle_data # pylint: disable=protected-access tensor_map[tape.tensor_id(value)] = (value, captured_value) @@ -98,8 +95,10 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): # TODO(apassos): it'd be really nice if we could scope this registration. -ops.register_tensor_conversion_function(tensor.Tensor, - _convert_to_graph_constant) +# Note that we register this at a higher priority than ops.Tensor since we want +# to handle subclass specific conversion before a superclass conversion. +ops.register_tensor_conversion_function( + tensor.Tensor, _convert_to_graph_constant, priority=-1) class _CapturingContext(object): @@ -133,17 +132,17 @@ class _CapturingContext(object): def _forward_name(n): """The name of a generated forward defun named n.""" - return "__forward_%s_%s" % (n, core.uid()) + return "__forward_%s_%s" % (n, ops.uid()) def _backward_name(n): """The name of a generated backward defun named n.""" - return "__backward_%s_%s" % (n, core.uid()) + return "__backward_%s_%s" % (n, ops.uid()) def _inference_name(n): """The name of a forward-but-no-gradient defun named n.""" - return "__inference_%s_%s" % (n, core.uid()) + return "__inference_%s_%s" % (n, ops.uid()) class _DefinedFunction(object): @@ -184,15 +183,8 @@ class _GraphModeFunction(object): internal function. """ - def __init__(self, - input_placeholders, - extra_inputs, - fdef, - graph, - operations, - func_outputs, - func_outputs_to_fdef_outputs, - output_shapes): + def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations, + func_outputs, func_outputs_to_fdef_outputs, output_shapes): assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % ( len(input_placeholders), len(fdef.signature.input_arg)) self._input_placeholders = input_placeholders @@ -204,8 +196,8 @@ class _GraphModeFunction(object): self._num_outputs = len(fdef.signature.output_arg) self._ops = operations self._func_outputs = func_outputs - if (isinstance(func_outputs, (ops.Tensor, type(None))) - or ag_core.isnode(func_outputs)): + if (isinstance(func_outputs, (ops.Tensor, type(None))) or + ag_core.isnode(func_outputs)): self._returns = [func_outputs] else: self._returns = list(func_outputs) @@ -218,11 +210,11 @@ class _GraphModeFunction(object): with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: - filtered_outputs = [ag_core.getval(x) - for x in self._returns if x is not None] + filtered_outputs = [ + ag_core.getval(x) for x in self._returns if x is not None + ] self._out_grad_placeholders = [ - graph_placeholder(x.dtype, x.shape) - for x in filtered_outputs + graph_placeholder(x.dtype, x.shape) for x in filtered_outputs ] in_gradients = gradients_impl.gradients( filtered_outputs, @@ -231,20 +223,16 @@ class _GraphModeFunction(object): shapes = [x.shape for x in in_gradients if x is not None] captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) forward_function_def = graph_to_function_def.graph_to_function_def( - self._graph, self._ops, - self._input_placeholders, + self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) self._forward_fdef = _DefinedFunction(forward_function_def) - _register_with_name(_forward_name(self._func_name), - forward_function_def) + _register_with_name(_forward_name(self._func_name), forward_function_def) backward_outputs = [x for x in in_gradients if x is not None] all_inputs = self._out_grad_placeholders + captures backward_function_def = graph_to_function_def.graph_to_function_def( - self._graph, - [x.op for x in self._out_grad_placeholders] + - list(sorted(c.known_ops, key=lambda x: x.name)), - all_inputs, - backward_outputs) + self._graph, [x.op for x in self._out_grad_placeholders + ] + list(sorted(c.known_ops, key=lambda x: x.name)), + all_inputs, backward_outputs) _register_with_name(_backward_name(self._func_name), backward_function_def) self._backward_function = _GraphModeFunction( all_inputs, [], backward_function_def, self._graph, c.known_ops, @@ -258,12 +246,12 @@ class _GraphModeFunction(object): g = ops.get_default_graph() g._add_function(self._forward_fdef) # pylint: disable=protected-access unwrapped_args = [ag_core.getval(x) for x in all_args] - op = g.create_op(signature.name, - [ops.convert_to_tensor(x) for x in unwrapped_args], - [dtypes.DType(x.type) for x in signature.output_arg], - op_def=signature, - name="FunctionCall", - compute_shapes=False) + op = g.create_op( + signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args], + [dtypes.DType(x.type) for x in signature.output_arg], + op_def=signature, + name="FunctionCall", + compute_shapes=False) outputs = op.outputs outputs = [outputs] if isinstance( outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs) @@ -288,17 +276,17 @@ class _GraphModeFunction(object): watched_extra_inputs.append(t) real_outputs = tape.record_operation(real_outputs, (args + watched_extra_inputs), - side_outputs, - self._backward_function) + side_outputs, self._backward_function) return self._build_call_outputs(self._returns, real_outputs) def __call__(self, *args): """Executes the passed function in eager mode.""" - tensor_inputs = [x for x in nest.flatten(args) - if isinstance(x, (tensor.Tensor, ops.Tensor, - tensor.LazyZero)) - or ag_core.isnode(x)] + tensor_inputs = [ + x for x in nest.flatten(args) + if isinstance(x, (tensor.Tensor, ops.Tensor, + tensor.LazyZero)) or ag_core.isnode(x) + ] if tape.should_record(tensor_inputs) or any( tape.any_tape_has(t) for t in self._extra_inputs): if not self._has_backprop: @@ -310,18 +298,20 @@ class _GraphModeFunction(object): g._add_function(self._fdef) # pylint: disable=protected-access signature = self._fdef.definition.signature args = list(tensor_inputs) + self._extra_inputs - op = g.create_op(signature.name, - [ops.convert_to_tensor(x) for x in args], - [dtypes.DType(x.type) for x in signature.output_arg], - op_def=signature, - name="FunctionCall", - compute_shapes=False) + op = g.create_op( + signature.name, [ops.convert_to_tensor(x) for x in args], + [dtypes.DType(x.type) for x in signature.output_arg], + op_def=signature, + name="FunctionCall", + compute_shapes=False) result = op.outputs for i, s in enumerate(self._output_shapes): result[i].set_shape(s) else: - tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x - for x in tensor_inputs] + tensor_inputs = [ + x.tensor() if isinstance(x, tensor.LazyZero) else x + for x in tensor_inputs + ] result = execute.execute( self._func_name, num_outputs=self._num_outputs, @@ -383,22 +373,21 @@ def _defun_internal(name, func, args, kwds): func_outputs = func(*func_inputs, **kwds) ids = list(sorted(captures.keys())) if ids: - extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids]) + extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) else: extra_inputs = [] extra_placeholders = [] outputs_list = nest.flatten(func_outputs) output_shapes = [x.shape for x in outputs_list if x is not None] - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, ops.Tensor)] + flat_inputs = [ + x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor) + ] all_inputs = flat_inputs + list(extra_placeholders) func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None] inference_function_def = graph_to_function_def.graph_to_function_def( - tmp_graph, tmp_graph.get_operations(), - all_inputs, - func_def_outputs) + tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. for f in tmp_graph._functions.values(): # pylint: disable=protected-access @@ -407,14 +396,9 @@ def _defun_internal(name, func, args, kwds): _register_with_name(_inference_name(name), inference_function_def) return _GraphModeFunction( - all_inputs, - extra_inputs, - inference_function_def, - tmp_graph, - tmp_graph.get_operations(), - func_outputs, - _map_sequence_obj_to_idx(func_def_outputs), - output_shapes) + all_inputs, extra_inputs, inference_function_def, tmp_graph, + tmp_graph.get_operations(), func_outputs, + _map_sequence_obj_to_idx(func_def_outputs), output_shapes) # Defun uses this instead of Tensor as a cache key. Using dtype because diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py index 64441717308..86ac243ae37 100644 --- a/tensorflow/python/eager/tensor.py +++ b/tensorflow/python/eager/tensor.py @@ -27,11 +27,13 @@ from tensorflow.python.eager import core from tensorflow.python.eager import tape from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import tensor_shape -class Tensor(object): - """A TensorFlow Tensor.""" +# TODO(agarwal): rename to TensorHandle. +class Tensor(tf_ops.Tensor): + """A TensorFlow Eager Tensor.""" def __init__(self, value, dtype=None): """Creates a Tensor object from a Python object or numpy array. @@ -48,7 +50,7 @@ class Tensor(object): # TODO(ashankar): Evaluate if we can and perhaps share code with # tf.constant defined in # https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py - self._id = core.uid() + self._id = tf_ops.uid() if not isinstance(value, np.ndarray): npt = None if dtype is None else dtype.as_numpy_dtype value = np.array(value, dtype=npt) @@ -111,7 +113,7 @@ class Tensor(object): if core.active_trace() is not None: core.active_trace().record_tensor("MANUAL", tape.tensor_id(self), - self._device_name(), + self.device, self.shape.num_elements()) def __del__(self): @@ -184,12 +186,13 @@ class Tensor(object): if core.active_trace() is not None: core.active_trace().record_tensor("COPY", tape.tensor_id(new_tensor), - new_tensor._device_name(), + new_tensor.device, new_tensor.shape.num_elements()) return new_tensor # pylint: enable=protected-access - def _device_name(self): + @property + def device(self): return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle) @property @@ -237,6 +240,10 @@ class Tensor(object): pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x) for x in range(n)) + def _shape_as_list(self): + """The shape of the tensor as a list.""" + return list(self._shape_tuple()) + def as_cpu_tensor(self): """A copy of this Tensor with contents backed by host memory.""" return self._copy(context.get_default_context(), "CPU:0") @@ -266,6 +273,42 @@ class Tensor(object): def __nonzero__(self): return self.__bool__() + # Methods not supported / implemented for Eager Tensors. + @property + def op(self): + raise NotImplementedError("op not supported for Eager Tensors.") + + @property + def graph(self): + raise NotImplementedError("graph not supported for Eager Tensors.") + + @property + def name(self): + raise NotImplementedError("name not supported for Eager Tensors.") + + def set_shape(self, shape): + raise NotImplementedError("set_shape not supported for Eager Tensors.") + + @property + def value_index(self): + raise NotImplementedError("value_index not supported for Eager Tensors.") + + def consumers(self): + raise NotImplementedError("consumers not supported for Eager Tensors.") + + def _add_consumer(self, consumer): + raise NotImplementedError("_add_consumer not supported for Eager Tensors.") + + def _as_node_def_input(self): + raise NotImplementedError( + "_as_node_def_input not supported for Eager Tensors.") + + def _as_tf_output(self): + raise NotImplementedError("_as_tf_output not supported for Eager Tensors.") + + def eval(self, feed_dict=None, session=None): + raise NotImplementedError("eval not supported for Eager Tensors.") + class IndexedSlices(object): """A sparse representation of a set of tensor slices at given indices. @@ -374,7 +417,7 @@ def _tensor_from_handle(handle): """ # pylint: disable=protected-access t = Tensor.__new__(Tensor) - t._id = core.uid() + t._id = tf_ops.uid() t._handle = handle t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle)) t._handle_data = None diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 5d3ac45020a..35acd053391 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Classes and functions used to construct graphs.""" # pylint: disable=g-bad-name from __future__ import absolute_import @@ -47,7 +46,6 @@ from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils from tensorflow.python.util import tf_contextlib - # Temporary global switch determining if we should enable the work-in-progress # calls to the C API. Currently disabled by default but can be manually enabled # e.g. in tests. This will be removed once all functionality is supported and @@ -153,6 +151,18 @@ def register_dense_tensor_like_type(tensor_type): _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type]) +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def uid(): + """A unique (within this program execution) integer.""" + with _uid_lock: + global _uid_counter + _uid_counter += 1 + return _uid_counter + + # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose. class _TensorLike(object): """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance.""" @@ -261,6 +271,7 @@ class Tensor(_TensorLike): # Attributes used for C++ shape inference. Not inspected, only forwarded. # If set, will be a HandleData object from cpp_shape_inference.proto. self._handle_data = None + self._id = uid() @property def op(self): @@ -284,11 +295,6 @@ class Tensor(_TensorLike): raise ValueError("Operation was not named: %s" % self._op) return "%s:%d" % (self._op.name, self._value_index) - @property - def _id(self): - """An alias for the string name of this tensor.""" - return self.name - @property def device(self): """The name of the device on which this tensor will be produced, or None.""" @@ -437,15 +443,15 @@ class Tensor(_TensorLike): def __str__(self): return "Tensor(\"%s\"%s%s%s)" % ( - self.name, - (", shape=%s" % self.get_shape()) + self.name, (", shape=%s" % self.get_shape()) if self.get_shape().ndims is not None else "", - (", dtype=%s" % self._dtype.name) if self._dtype else "", - (", device=%s" % self.device) if self.device else "") + (", dtype=%s" % self._dtype.name) + if self._dtype else "", (", device=%s" % self.device) + if self.device else "") def __repr__(self): - return "" % ( - self.name, self.get_shape(), self._dtype.name) + return "" % (self.name, self.get_shape(), + self._dtype.name) def __hash__(self): # Necessary to support Python's collection membership operators @@ -551,20 +557,18 @@ def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False): _ = name, as_ref if dtype and not dtype.is_compatible_with(t.dtype): raise ValueError( - "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" - % (dtype.name, t.dtype.name, str(t))) + "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % + (dtype.name, t.dtype.name, str(t))) return t _tensor_conversion_func_registry = { - 0: [(Tensor, _TensorTensorConversionFunction)]} + 0: [(Tensor, _TensorTensorConversionFunction)] +} register_dense_tensor_like_type(Tensor) -def convert_to_tensor(value, - dtype=None, - name=None, - preferred_dtype=None): +def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None): """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -686,18 +690,18 @@ def internal_convert_to_tensor(value, if not isinstance(ret, Tensor): raise RuntimeError( - "%sConversion function %r for type %s returned non-Tensor: %r" - % (error_prefix, conversion_func, base_type, ret)) + "%sConversion function %r for type %s returned non-Tensor: %r" % + (error_prefix, conversion_func, base_type, ret)) if dtype and not dtype.is_compatible_with(ret.dtype): raise RuntimeError( "%sConversion function %r for type %s returned incompatible " - "dtype: requested = %s, actual = %s" - % (error_prefix, conversion_func, base_type, - dtype.name, ret.dtype.name)) + "dtype: requested = %s, actual = %s" % + (error_prefix, conversion_func, base_type, dtype.name, + ret.dtype.name)) return ret raise TypeError("%sCannot convert %r with type %s to Tensor: " - "no conversion function registered." - % (error_prefix, value, type(value))) + "no conversion function registered." % (error_prefix, value, + type(value))) def internal_convert_n_to_tensor(values, @@ -744,10 +748,7 @@ def internal_convert_n_to_tensor(values, return ret -def convert_n_to_tensor(values, - dtype=None, - name=None, - preferred_dtype=None): +def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): """Converts `values` to a list of `Tensor` objects. Args: @@ -771,11 +772,12 @@ def convert_n_to_tensor(values, RuntimeError: If a registered conversion function returns an invalid value. """ - return internal_convert_n_to_tensor(values=values, - dtype=dtype, - name=name, - preferred_dtype=preferred_dtype, - as_ref=False) + return internal_convert_n_to_tensor( + values=values, + dtype=dtype, + name=name, + preferred_dtype=preferred_dtype, + as_ref=False) def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): @@ -802,7 +804,9 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): value=value, dtype=dtype, name=name, as_ref=False) -def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None, +def internal_convert_to_tensor_or_indexed_slices(value, + dtype=None, + name=None, as_ref=False): """Converts the given object to an `Tensor` or an `IndexedSlices`. @@ -827,18 +831,18 @@ def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None, if isinstance(value, _TensorLike): if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): raise ValueError( - "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" - % (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) + "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % + (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) return value else: - return internal_convert_to_tensor(value, - dtype=dtype, - name=name, - as_ref=as_ref) + return internal_convert_to_tensor( + value, dtype=dtype, name=name, as_ref=as_ref) -def internal_convert_n_to_tensor_or_indexed_slices(values, dtype=None, - name=None, as_ref=False): +def internal_convert_n_to_tensor_or_indexed_slices(values, + dtype=None, + name=None, + as_ref=False): """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. Any `IndexedSlices` or `SparseTensor` objects in `values` are returned @@ -905,7 +909,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): values=values, dtype=dtype, name=name, as_ref=False) -def register_tensor_conversion_function(base_type, conversion_func, +def register_tensor_conversion_function(base_type, + conversion_func, priority=100): """Registers a function for converting objects of `base_type` to `Tensor`. @@ -947,8 +952,8 @@ def register_tensor_conversion_function(base_type, conversion_func, """ if not (isinstance(base_type, type) or - (isinstance(base_type, tuple) - and all(isinstance(x, type) for x in base_type))): + (isinstance(base_type, tuple) and + all(isinstance(x, type) for x in base_type))): raise TypeError("base_type must be a type or a tuple of types.") if not callable(conversion_func): raise TypeError("conversion_func must be callable.") @@ -1038,8 +1043,7 @@ class IndexedSlices(_TensorLike): def __str__(self): return "IndexedSlices(indices=%s, values=%s%s)" % ( - self._indices, self._values, - (", dense_shape=%s" % self._dense_shape) + self._indices, self._values, (", dense_shape=%s" % self._dense_shape) if self._dense_shape is not None else "") def __neg__(self): @@ -1112,8 +1116,14 @@ class Operation(object): `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. """ - def __init__(self, node_def, g, inputs=None, output_types=None, - control_inputs=None, input_types=None, original_op=None, + def __init__(self, + node_def, + g, + inputs=None, + output_types=None, + control_inputs=None, + input_types=None, + original_op=None, op_def=None): r"""Creates an `Operation`. @@ -1177,18 +1187,20 @@ class Operation(object): if output_types is None: output_types = [] self._output_types_val = output_types - self._outputs = [Tensor(self, i, output_type) - for i, output_type in enumerate(output_types)] + self._outputs = [ + Tensor(self, i, output_type) + for i, output_type in enumerate(output_types) + ] if input_types is None: input_types = [i.dtype.base_dtype for i in self._inputs] else: - if not all(x.is_compatible_with(i.dtype) - for i, x in zip(self._inputs, input_types)): + if not all( + x.is_compatible_with(i.dtype) + for i, x in zip(self._inputs, input_types)): raise TypeError("In op '%s', input types (%s) are not compatible " - "with expected types (%s)" % ( - self.node_def.name, - [i.dtype for i in self._inputs], - input_types)) + "with expected types (%s)" % + (self.node_def.name, [i.dtype for i in self._inputs], + input_types)) self._input_types_val = input_types # Build the list of control inputs. @@ -1251,7 +1263,8 @@ class Operation(object): A wrapped TF_Operation*. """ # pylint: disable=protected-access - op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op), + op_desc = c_api.TF_NewOperation(graph._c_graph, + compat.as_str(node_def.op), compat.as_str(node_def.name)) # Add inputs for op_input in inputs: @@ -1271,8 +1284,8 @@ class Operation(object): # TODO(skyewm): this creates and deletes a new TF_Status for every attr. # It might be worth creating a convenient way to re-use the same status. with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized, - status) + c_api.TF_SetAttrValueProto(op_desc, + compat.as_str(name), serialized, status) with errors.raise_exception_on_not_ok_status() as status: c_op = c_api.TF_FinishOperation(op_desc, status) @@ -1316,16 +1329,18 @@ class Operation(object): def colocation_groups(self): """Returns the list of colocation groups of the op.""" - default_colocation_group = [compat.as_bytes("loc:@%s" % - self._node_def.name)] + default_colocation_group = [ + compat.as_bytes("loc:@%s" % self._node_def.name) + ] if "_class" not in self._node_def.attr: # This op has no explicit colocation group, so it is itself its # own root of a colocation group. return default_colocation_group - attr_groups = [class_name - for class_name in self.get_attr("_class") - if class_name.startswith(b"loc:@")] + attr_groups = [ + class_name for class_name in self.get_attr("_class") + if class_name.startswith(b"loc:@") + ] # If there are no colocation groups in the explicit _class field, # return the default colocation group. @@ -1397,8 +1412,10 @@ class Operation(object): """ if self._graph._c_graph: # pylint: disable=protected-access num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [c_api.TF_OperationOutputType(self._tf_output(i)) for - i in xrange(num_outputs)] + output_types = [ + c_api.TF_OperationOutputType(self._tf_output(i)) + for i in xrange(num_outputs) + ] # TODO(iga): Remove this assert after converting to C API by default. # Just being a bit paranoid here. assert self._output_types_val == output_types @@ -1433,8 +1450,8 @@ class Operation(object): device: string or device.. The device to set. """ if _USE_C_API: - c_api.SetRequestedDevice( - self._graph._c_graph, self._c_op, _device_string(device)) # pylint: disable=protected-access + c_api.SetRequestedDevice(self._graph._c_graph, self._c_op, # pylint: disable=protected-access + _device_string(device)) # TODO(nolivia): remove this line when switch to C api self._node_def.device = _device_string(device) @@ -1462,8 +1479,8 @@ class Operation(object): dtype = dtypes.as_dtype(dtype) if not dtype.is_compatible_with(tensor.dtype): raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" - % (tensor.dtype.name, dtype.name)) + "Cannot convert a tensor of type %s to an input of type %s" % + (tensor.dtype.name, dtype.name)) self._inputs.append(tensor) self._input_types_val.append(dtype) tensor._add_consumer(self) # pylint: disable=protected-access @@ -1496,8 +1513,8 @@ class Operation(object): dtype = dtypes.as_dtype(dtype) if not dtype.is_compatible_with(tensor.dtype): raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" - % (tensor.dtype.name, dtype.name)) + "Cannot convert a tensor of type %s to an input of type %s" % + (tensor.dtype.name, dtype.name)) self._inputs[index].consumers().remove(self) self._inputs[index] = tensor @@ -1547,8 +1564,8 @@ class Operation(object): self._node_def.input.extend([t._as_node_def_input() for t in self._inputs]) # pylint: enable=protected-access if self._control_inputs: - self._node_def.input.extend(["^%s" % op.name for op in - self._control_inputs]) + self._node_def.input.extend( + ["^%s" % op.name for op in self._control_inputs]) def __str__(self): return str(self._node_def) @@ -1562,6 +1579,7 @@ class Operation(object): return self._outputs # pylint: disable=protected-access + class _InputList(object): """Immutable input list wrapper.""" @@ -1582,6 +1600,7 @@ class Operation(object): def __getitem__(self, i): return self._op._inputs[i] + # pylint: enable=protected-access @property @@ -1597,9 +1616,10 @@ class Operation(object): def _input_types(self): if self._graph._c_graph: # pylint: disable=protected-access num_inputs = c_api.TF_OperationNumInputs(self._c_op) - input_types = [dtypes.as_dtype( - c_api.TF_OperationInputType(self._tf_input(i))) - for i in xrange(num_inputs)] + input_types = [ + dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) + for i in xrange(num_inputs) + ] # TODO(iga): Remove this assert after converting to C API by default. # Just being a bit paranoid here. assert self._input_types_val == input_types @@ -1624,8 +1644,10 @@ class Operation(object): if self._graph._c_graph: # pylint: disable=protected-access control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) # pylint: disable=protected-access - return [self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops] + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] # pylint: enable=protected-access else: return self._control_inputs @@ -1691,7 +1713,8 @@ class Operation(object): A list of 5-tuples (filename, lineno, name, code, func_start_lineno). """ return self._graph._convert_stack( # pylint: disable=protected-access - self._traceback, include_func_start_lineno=True) + self._traceback, + include_func_start_lineno=True) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -1707,8 +1730,7 @@ class Operation(object): """ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] if name not in self._node_def.attr: - raise ValueError("No attr named '" + name + "' in " + - str(self._node_def)) + raise ValueError("No attr named '" + name + "' in " + str(self._node_def)) x = self._node_def.attr[name] # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): @@ -1749,7 +1771,6 @@ class Operation(object): """ _run_using_default_session(self, feed_dict, self.graph, session) - _gradient_registry = registry.Registry("gradient") @@ -1834,7 +1855,8 @@ NoGradient = NotDifferentiable def get_gradient_function(op): """Returns the function that computes gradients for "op".""" - if not op.inputs: return None + if not op.inputs: + return None try: op_type = op.get_attr("_gradient_op_type") except ValueError: @@ -1982,6 +2004,7 @@ class OpStats(object): self._value += other.value return self + _stats_registry = registry.Registry("statistical functions") @@ -2433,8 +2456,8 @@ class Graph(object): graph.node.extend([op.node_def]) if op.outputs and add_shapes: assert "_output_shapes" not in graph.node[-1].attr - graph.node[-1].attr["_output_shapes"].list.shape.extend([ - output.get_shape().as_proto() for output in op.outputs]) + graph.node[-1].attr["_output_shapes"].list.shape.extend( + [output.get_shape().as_proto() for output in op.outputs]) bytesize += op.node_def.ByteSize() if bytesize >= (1 << 31) or bytesize < 0: raise ValueError("GraphDef cannot be larger than 2GB.") @@ -2469,7 +2492,8 @@ class Graph(object): node with the inferred shapes of each of its outputs. Returns: - A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) + A + [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) protocol buffer. Raises: @@ -2518,8 +2542,8 @@ class Graph(object): if previous: raise ValueError("Another function is already defined with that name") # Sanity checks on gradient definition. - if (function.grad_func_name is not None) and ( - function.python_grad_func is not None): + if (function.grad_func_name is not None) and (function.python_grad_func is + not None): raise ValueError("Gradient defined twice for function %s" % name) # Need a new-enough consumer to support the functions we add to the graph. if self._graph_def_versions.min_consumer < 12: @@ -2532,9 +2556,17 @@ class Graph(object): return self._building_function # Helper functions to create operations. - def create_op(self, op_type, inputs, dtypes, # pylint: disable=redefined-outer-name - input_types=None, name=None, attrs=None, op_def=None, - compute_shapes=True, compute_device=True): + def create_op( + self, + op_type, + inputs, + dtypes, # pylint: disable=redefined-outer-name + input_types=None, + name=None, + attrs=None, + op_def=None, + compute_shapes=True, + compute_device=True): """Creates an `Operation` in this graph. This is a low-level interface for creating an `Operation`. Most @@ -2597,8 +2629,8 @@ class Graph(object): if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): raise TypeError( "Callable for scope map key '%s' must return either None or " - "an AttrValue protocol buffer; but it returned: %s" % - (key, value)) + "an AttrValue protocol buffer; but it returned: %s" % (key, + value)) node_def.attr[key].CopyFrom(value) # Apply a kernel label if one has been specified for this op_type. @@ -2619,9 +2651,15 @@ class Graph(object): pass control_inputs = self._control_dependencies_for_inputs(inputs) - ret = Operation(node_def, self, inputs=inputs, output_types=dtypes, - control_inputs=control_inputs, input_types=input_types, - original_op=self._default_original_op, op_def=op_def) + ret = Operation( + node_def, + self, + inputs=inputs, + output_types=dtypes, + control_inputs=control_inputs, + input_types=input_types, + original_op=self._default_original_op, + op_def=op_def) if compute_shapes: set_shapes_for_outputs(ret) self._add_op(ret) @@ -2638,28 +2676,27 @@ class Graph(object): # Make this device match the device of the colocated op, to # provide consistency between the device and the colocation # property. - if (ret.device and - pydev.canonical_name(ret.device) != + if (ret.device and pydev.canonical_name(ret.device) != pydev.canonical_name(colocation_op.device)): logging.warning("Tried to colocate %s with an op %s that had " "a different device: %s vs %s. " - "Ignoring colocation property.", - name, colocation_op.name, - ret.device, colocation_op.device) + "Ignoring colocation property.", name, + colocation_op.name, ret.device, + colocation_op.device) else: ret._set_device(colocation_op.device) # pylint: disable=protected-access all_colocation_groups = sorted(set(all_colocation_groups)) - ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue( - list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) + ret.node_def.attr["_class"].CopyFrom( + attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue( + s=all_colocation_groups))) # Sets "container" attribute if # (1) self._container is not None # (2) "is_stateful" is set in OpDef # (3) "container" attribute is in OpDef # (4) "container" attribute is None - if (self._container and - op_type in self._registered_ops and + if (self._container and op_type in self._registered_ops and self._registered_ops[op_type].is_stateful and "container" in ret.node_def.attr and not ret.node_def.attr["container"].s): @@ -2749,13 +2786,13 @@ class Graph(object): except: raise KeyError("The name %s refers to a Tensor which does not " "exist. The operation, %s, exists but only has " - "%s outputs." - % (repr(name), repr(op_name), len(op.outputs))) + "%s outputs." % (repr(name), repr(op_name), + len(op.outputs))) elif ":" in name and not allow_tensor: # Looks like a Tensor name but can't be a Tensor. - raise ValueError("Name %s appears to refer to a Tensor, not a %s." - % (repr(name), types_str)) + raise ValueError("Name %s appears to refer to a Tensor, not a %s." % + (repr(name), types_str)) elif ":" not in name and allow_operation: # Looks like an Operation name and can be an Operation. @@ -2768,8 +2805,8 @@ class Graph(object): # Looks like an Operation name but can't be an Operation. if name in self._nodes_by_name: # Yep, it's an Operation name - err_msg = ("The name %s refers to an Operation, not a %s." - % (repr(name), types_str)) + err_msg = ("The name %s refers to an Operation, not a %s." % + (repr(name), types_str)) else: err_msg = ("The name %s looks like an (invalid) Operation name, " "not a %s." % (repr(name), types_str)) @@ -2789,8 +2826,8 @@ class Graph(object): return obj else: # We give up! - raise TypeError("Can not convert a %s into a %s." - % (type(obj).__name__, types_str)) + raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__, + types_str)) def get_operations(self): """Return the list of operations in the graph. @@ -2827,8 +2864,8 @@ class Graph(object): """ if not isinstance(name, six.string_types): - raise TypeError("Operation names are strings (or similar), not %s." - % type(name).__name__) + raise TypeError("Operation names are strings (or similar), not %s." % + type(name).__name__) return self.as_graph_element(name, allow_tensor=False, allow_operation=True) def _get_operation_by_name_unsafe(self, name): @@ -2871,8 +2908,8 @@ class Graph(object): """ # Names should be strings. if not isinstance(name, six.string_types): - raise TypeError("Tensor names are strings (or similar), not %s." - % type(name).__name__) + raise TypeError("Tensor names are strings (or similar), not %s." % + type(name).__name__) return self.as_graph_element(name, allow_tensor=True, allow_operation=False) def _next_id(self): @@ -3178,6 +3215,7 @@ class Graph(object): yield "" if new_stack is None else new_stack + "/" finally: self._name_stack = old_stack + # pylint: enable=g-doc-return-or-yield,line-too-long def unique_name(self, name, mark_as_used=True): @@ -3280,9 +3318,8 @@ class Graph(object): """ if op is None and not ignore_existing: - raise ValueError( - "Trying to reset colocation (op is None) but " - "ignore_existing is not True") + raise ValueError("Trying to reset colocation (op is None) but " + "ignore_existing is not True") if op is not None and not isinstance(op, Operation): # We always want to colocate with the reference op. @@ -3377,8 +3414,8 @@ class Graph(object): """ # pylint: enable=line-too-long - if (device_name_or_function is not None - and not callable(device_name_or_function)): + if (device_name_or_function is not None and + not callable(device_name_or_function)): device_function = pydev.merge_device(device_name_or_function) else: device_function = device_name_or_function @@ -3452,6 +3489,7 @@ class Graph(object): yield self._container finally: self._container = original_container + # pylint: enable=g-doc-return-or-yield class _ControlDependenciesController(object): @@ -3491,6 +3529,7 @@ class Graph(object): self._old_control_flow_context = None # pylint: disable=protected-access + def __enter__(self): if self._new_stack: # Clear the control_dependencies graph. @@ -3506,6 +3545,7 @@ class Graph(object): if self._new_stack: self._graph._control_dependencies_stack = self._old_stack self._graph._set_control_flow_context(self._old_control_flow_context) + # pylint: enable=protected-access @property @@ -3733,6 +3773,7 @@ class Graph(object): self._attr_scope_map[name] = saved_attrs[name] except KeyError: del self._attr_scope_map[name] + # pylint: enable=g-doc-return-or-yield # pylint: disable=g-doc-return-or-yield @@ -3777,8 +3818,8 @@ class Graph(object): saved_labels = {} # Install the given label for op_type, label in op_to_kernel_label_map.items(): - if not (isinstance(op_type, six.string_types) - and isinstance(label, six.string_types)): + if not (isinstance(op_type, six.string_types) and + isinstance(label, six.string_types)): raise TypeError("op_to_kernel_label_map must be a dictionary mapping " "strings to strings") try: @@ -3795,6 +3836,7 @@ class Graph(object): self._op_to_kernel_label_map[op_type] = saved_labels[op_type] except KeyError: del self._op_to_kernel_label_map[op_type] + # pylint: enable=g-doc-return-or-yield # pylint: disable=g-doc-return-or-yield @@ -3840,8 +3882,8 @@ class Graph(object): saved_mappings = {} # Install the given label for op_type, mapped_op_type in op_type_map.items(): - if not (isinstance(op_type, six.string_types) - and isinstance(mapped_op_type, six.string_types)): + if not (isinstance(op_type, six.string_types) and + isinstance(mapped_op_type, six.string_types)): raise TypeError("op_type_map must be a dictionary mapping " "strings to strings") try: @@ -3858,6 +3900,7 @@ class Graph(object): self._gradient_override_map[op_type] = saved_mappings[op_type] except KeyError: del self._gradient_override_map[op_type] + # pylint: enable=g-doc-return-or-yield def prevent_feeding(self, tensor): @@ -3969,12 +4012,13 @@ class _DefaultStack(threading.local): if self._enforce_nesting: if self.stack[-1] is not default: raise AssertionError( - "Nesting violated for default stack of %s objects" - % type(default)) + "Nesting violated for default stack of %s objects" % + type(default)) self.stack.pop() else: self.stack.remove(default) + _default_session_stack = _DefaultStack() @@ -4145,6 +4189,7 @@ class _DefaultGraphStack(_DefaultStack): super(_DefaultGraphStack, self).reset() self._global_default_graph = None + _default_graph_stack = _DefaultGraphStack() @@ -4195,8 +4240,8 @@ def _assert_same_graph(original_item, item): ValueError: if graphs do not match. """ if original_item.graph is not item.graph: - raise ValueError( - "%s must be from the same graph as %s." % (item, original_item)) + raise ValueError("%s must be from the same graph as %s." % (item, + original_item)) def _get_graph_from_inputs(op_input_list, graph=None): @@ -4246,8 +4291,11 @@ def _get_graph_from_inputs(op_input_list, graph=None): original_graph_element = None for op_input in op_input_list: # Determine if this is a valid graph_element. + # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this + # up. graph_element = None - if isinstance(op_input, (Operation, _TensorLike)): + if (isinstance(op_input, (Operation, _TensorLike)) and + ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck graph_element = op_input else: graph_element = _as_graph_element(op_input) @@ -4259,8 +4307,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): elif original_graph_element is not None: _assert_same_graph(original_graph_element, graph_element) elif graph_element.graph is not graph: - raise ValueError( - "%s is not from the passed-in graph." % graph_element) + raise ValueError("%s is not from the passed-in graph." % graph_element) # 2. If all else fails, we use the default graph, which is always there. return graph or get_default_graph() @@ -4512,13 +4559,15 @@ def name_scope(name, default_name=None, values=None): # tf.name_scope(None) (values=None then) is sometimes used as an idiom # to reset to top scope. raise ValueError( - "At least one of name (%s) and default_name (%s) must be provided." % ( - name, default_name)) + "At least one of name (%s) and default_name (%s) must be provided." % + (name, default_name)) if values is None: values = [] g = _get_graph_from_inputs(values) with g.as_default(), g.name_scope(n) as scope: yield scope + + # pylint: enable=g-doc-return-or-yield @@ -4585,7 +4634,9 @@ def op_scope(values, name, default_name=None): _proto_function_registry = registry.Registry("proto functions") -def register_proto_function(collection_name, proto_type=None, to_proto=None, +def register_proto_function(collection_name, + proto_type=None, + to_proto=None, from_proto=None): """Registers `to_proto` and `from_proto` functions for collection_name. @@ -4637,10 +4688,9 @@ def get_from_proto_function(collection_name): def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): """Produce a nice error if someone converts an Operation to a Tensor.""" - raise TypeError( - ("Can't convert Operation '%s' to Tensor " - "(target dtype=%r, name=%r, as_ref=%r)") % - (op.name, dtype, name, as_ref)) + raise TypeError(("Can't convert Operation '%s' to Tensor " + "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype, + name, as_ref)) register_tensor_conversion_function(Operation, _operation_conversion_error)