Internal change
PiperOrigin-RevId: 164916465
This commit is contained in:
parent
b8d13d218f
commit
d9ca2d86de
tensorflow/python
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import memory_trace
|
||||
@ -28,17 +26,6 @@ from tensorflow.python.framework import errors
|
||||
# Trace of execution and memory usage.
|
||||
_active_trace = None
|
||||
|
||||
_uid_counter = 0
|
||||
_uid_lock = threading.Lock()
|
||||
|
||||
|
||||
def uid():
|
||||
"""A unique (within this program execution) integer."""
|
||||
with _uid_lock:
|
||||
global _uid_counter
|
||||
_uid_counter += 1
|
||||
return _uid_counter
|
||||
|
||||
|
||||
def _status_to_exception(code, message):
|
||||
try:
|
||||
|
@ -28,7 +28,6 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager import tensor
|
||||
@ -40,7 +39,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
# Thread-local storage for tfe Tensors which are referenced while evaluating a
|
||||
# graph-mode function.
|
||||
_scoped_captures = threading.local()
|
||||
@ -86,9 +84,8 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
||||
"To build a graph use tfe.defun or tfe.func_to_object.")
|
||||
captured_value = tensor_map.get(tape.tensor_id(value), None)
|
||||
if captured_value is None:
|
||||
captured_value = graph_placeholder(dtype=dtype or value.dtype,
|
||||
shape=value.shape,
|
||||
name=name)
|
||||
captured_value = graph_placeholder(
|
||||
dtype=dtype or value.dtype, shape=value.shape, name=name)
|
||||
if captured_value.dtype == dtypes.resource:
|
||||
captured_value._handle_data = value._handle_data # pylint: disable=protected-access
|
||||
tensor_map[tape.tensor_id(value)] = (value, captured_value)
|
||||
@ -98,8 +95,10 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
|
||||
|
||||
|
||||
# TODO(apassos): it'd be really nice if we could scope this registration.
|
||||
ops.register_tensor_conversion_function(tensor.Tensor,
|
||||
_convert_to_graph_constant)
|
||||
# Note that we register this at a higher priority than ops.Tensor since we want
|
||||
# to handle subclass specific conversion before a superclass conversion.
|
||||
ops.register_tensor_conversion_function(
|
||||
tensor.Tensor, _convert_to_graph_constant, priority=-1)
|
||||
|
||||
|
||||
class _CapturingContext(object):
|
||||
@ -133,17 +132,17 @@ class _CapturingContext(object):
|
||||
|
||||
def _forward_name(n):
|
||||
"""The name of a generated forward defun named n."""
|
||||
return "__forward_%s_%s" % (n, core.uid())
|
||||
return "__forward_%s_%s" % (n, ops.uid())
|
||||
|
||||
|
||||
def _backward_name(n):
|
||||
"""The name of a generated backward defun named n."""
|
||||
return "__backward_%s_%s" % (n, core.uid())
|
||||
return "__backward_%s_%s" % (n, ops.uid())
|
||||
|
||||
|
||||
def _inference_name(n):
|
||||
"""The name of a forward-but-no-gradient defun named n."""
|
||||
return "__inference_%s_%s" % (n, core.uid())
|
||||
return "__inference_%s_%s" % (n, ops.uid())
|
||||
|
||||
|
||||
class _DefinedFunction(object):
|
||||
@ -184,15 +183,8 @@ class _GraphModeFunction(object):
|
||||
internal function.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_placeholders,
|
||||
extra_inputs,
|
||||
fdef,
|
||||
graph,
|
||||
operations,
|
||||
func_outputs,
|
||||
func_outputs_to_fdef_outputs,
|
||||
output_shapes):
|
||||
def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations,
|
||||
func_outputs, func_outputs_to_fdef_outputs, output_shapes):
|
||||
assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % (
|
||||
len(input_placeholders), len(fdef.signature.input_arg))
|
||||
self._input_placeholders = input_placeholders
|
||||
@ -204,8 +196,8 @@ class _GraphModeFunction(object):
|
||||
self._num_outputs = len(fdef.signature.output_arg)
|
||||
self._ops = operations
|
||||
self._func_outputs = func_outputs
|
||||
if (isinstance(func_outputs, (ops.Tensor, type(None)))
|
||||
or ag_core.isnode(func_outputs)):
|
||||
if (isinstance(func_outputs, (ops.Tensor, type(None))) or
|
||||
ag_core.isnode(func_outputs)):
|
||||
self._returns = [func_outputs]
|
||||
else:
|
||||
self._returns = list(func_outputs)
|
||||
@ -218,11 +210,11 @@ class _GraphModeFunction(object):
|
||||
with self._graph.as_default(), context.graph_mode():
|
||||
c = _CapturingContext()
|
||||
with c:
|
||||
filtered_outputs = [ag_core.getval(x)
|
||||
for x in self._returns if x is not None]
|
||||
filtered_outputs = [
|
||||
ag_core.getval(x) for x in self._returns if x is not None
|
||||
]
|
||||
self._out_grad_placeholders = [
|
||||
graph_placeholder(x.dtype, x.shape)
|
||||
for x in filtered_outputs
|
||||
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
|
||||
]
|
||||
in_gradients = gradients_impl.gradients(
|
||||
filtered_outputs,
|
||||
@ -231,20 +223,16 @@ class _GraphModeFunction(object):
|
||||
shapes = [x.shape for x in in_gradients if x is not None]
|
||||
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
|
||||
forward_function_def = graph_to_function_def.graph_to_function_def(
|
||||
self._graph, self._ops,
|
||||
self._input_placeholders,
|
||||
self._graph, self._ops, self._input_placeholders,
|
||||
filtered_outputs + captures)
|
||||
self._forward_fdef = _DefinedFunction(forward_function_def)
|
||||
_register_with_name(_forward_name(self._func_name),
|
||||
forward_function_def)
|
||||
_register_with_name(_forward_name(self._func_name), forward_function_def)
|
||||
backward_outputs = [x for x in in_gradients if x is not None]
|
||||
all_inputs = self._out_grad_placeholders + captures
|
||||
backward_function_def = graph_to_function_def.graph_to_function_def(
|
||||
self._graph,
|
||||
[x.op for x in self._out_grad_placeholders] +
|
||||
list(sorted(c.known_ops, key=lambda x: x.name)),
|
||||
all_inputs,
|
||||
backward_outputs)
|
||||
self._graph, [x.op for x in self._out_grad_placeholders
|
||||
] + list(sorted(c.known_ops, key=lambda x: x.name)),
|
||||
all_inputs, backward_outputs)
|
||||
_register_with_name(_backward_name(self._func_name), backward_function_def)
|
||||
self._backward_function = _GraphModeFunction(
|
||||
all_inputs, [], backward_function_def, self._graph, c.known_ops,
|
||||
@ -258,12 +246,12 @@ class _GraphModeFunction(object):
|
||||
g = ops.get_default_graph()
|
||||
g._add_function(self._forward_fdef) # pylint: disable=protected-access
|
||||
unwrapped_args = [ag_core.getval(x) for x in all_args]
|
||||
op = g.create_op(signature.name,
|
||||
[ops.convert_to_tensor(x) for x in unwrapped_args],
|
||||
[dtypes.DType(x.type) for x in signature.output_arg],
|
||||
op_def=signature,
|
||||
name="FunctionCall",
|
||||
compute_shapes=False)
|
||||
op = g.create_op(
|
||||
signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
|
||||
[dtypes.DType(x.type) for x in signature.output_arg],
|
||||
op_def=signature,
|
||||
name="FunctionCall",
|
||||
compute_shapes=False)
|
||||
outputs = op.outputs
|
||||
outputs = [outputs] if isinstance(
|
||||
outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
|
||||
@ -288,17 +276,17 @@ class _GraphModeFunction(object):
|
||||
watched_extra_inputs.append(t)
|
||||
real_outputs = tape.record_operation(real_outputs,
|
||||
(args + watched_extra_inputs),
|
||||
side_outputs,
|
||||
self._backward_function)
|
||||
side_outputs, self._backward_function)
|
||||
|
||||
return self._build_call_outputs(self._returns, real_outputs)
|
||||
|
||||
def __call__(self, *args):
|
||||
"""Executes the passed function in eager mode."""
|
||||
tensor_inputs = [x for x in nest.flatten(args)
|
||||
if isinstance(x, (tensor.Tensor, ops.Tensor,
|
||||
tensor.LazyZero))
|
||||
or ag_core.isnode(x)]
|
||||
tensor_inputs = [
|
||||
x for x in nest.flatten(args)
|
||||
if isinstance(x, (tensor.Tensor, ops.Tensor,
|
||||
tensor.LazyZero)) or ag_core.isnode(x)
|
||||
]
|
||||
if tape.should_record(tensor_inputs) or any(
|
||||
tape.any_tape_has(t) for t in self._extra_inputs):
|
||||
if not self._has_backprop:
|
||||
@ -310,18 +298,20 @@ class _GraphModeFunction(object):
|
||||
g._add_function(self._fdef) # pylint: disable=protected-access
|
||||
signature = self._fdef.definition.signature
|
||||
args = list(tensor_inputs) + self._extra_inputs
|
||||
op = g.create_op(signature.name,
|
||||
[ops.convert_to_tensor(x) for x in args],
|
||||
[dtypes.DType(x.type) for x in signature.output_arg],
|
||||
op_def=signature,
|
||||
name="FunctionCall",
|
||||
compute_shapes=False)
|
||||
op = g.create_op(
|
||||
signature.name, [ops.convert_to_tensor(x) for x in args],
|
||||
[dtypes.DType(x.type) for x in signature.output_arg],
|
||||
op_def=signature,
|
||||
name="FunctionCall",
|
||||
compute_shapes=False)
|
||||
result = op.outputs
|
||||
for i, s in enumerate(self._output_shapes):
|
||||
result[i].set_shape(s)
|
||||
else:
|
||||
tensor_inputs = [x.tensor() if isinstance(x, tensor.LazyZero) else x
|
||||
for x in tensor_inputs]
|
||||
tensor_inputs = [
|
||||
x.tensor() if isinstance(x, tensor.LazyZero) else x
|
||||
for x in tensor_inputs
|
||||
]
|
||||
result = execute.execute(
|
||||
self._func_name,
|
||||
num_outputs=self._num_outputs,
|
||||
@ -383,22 +373,21 @@ def _defun_internal(name, func, args, kwds):
|
||||
func_outputs = func(*func_inputs, **kwds)
|
||||
ids = list(sorted(captures.keys()))
|
||||
if ids:
|
||||
extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
|
||||
extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
|
||||
else:
|
||||
extra_inputs = []
|
||||
extra_placeholders = []
|
||||
outputs_list = nest.flatten(func_outputs)
|
||||
output_shapes = [x.shape for x in outputs_list if x is not None]
|
||||
|
||||
flat_inputs = [x for x in nest.flatten(func_inputs)
|
||||
if isinstance(x, ops.Tensor)]
|
||||
flat_inputs = [
|
||||
x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
|
||||
]
|
||||
all_inputs = flat_inputs + list(extra_placeholders)
|
||||
|
||||
func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
|
||||
inference_function_def = graph_to_function_def.graph_to_function_def(
|
||||
tmp_graph, tmp_graph.get_operations(),
|
||||
all_inputs,
|
||||
func_def_outputs)
|
||||
tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
|
||||
# Register any other functions defined in the graph
|
||||
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
|
||||
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
|
||||
@ -407,14 +396,9 @@ def _defun_internal(name, func, args, kwds):
|
||||
_register_with_name(_inference_name(name), inference_function_def)
|
||||
|
||||
return _GraphModeFunction(
|
||||
all_inputs,
|
||||
extra_inputs,
|
||||
inference_function_def,
|
||||
tmp_graph,
|
||||
tmp_graph.get_operations(),
|
||||
func_outputs,
|
||||
_map_sequence_obj_to_idx(func_def_outputs),
|
||||
output_shapes)
|
||||
all_inputs, extra_inputs, inference_function_def, tmp_graph,
|
||||
tmp_graph.get_operations(), func_outputs,
|
||||
_map_sequence_obj_to_idx(func_def_outputs), output_shapes)
|
||||
|
||||
|
||||
# Defun uses this instead of Tensor as a cache key. Using dtype because
|
||||
|
@ -27,11 +27,13 @@ from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
|
||||
class Tensor(object):
|
||||
"""A TensorFlow Tensor."""
|
||||
# TODO(agarwal): rename to TensorHandle.
|
||||
class Tensor(tf_ops.Tensor):
|
||||
"""A TensorFlow Eager Tensor."""
|
||||
|
||||
def __init__(self, value, dtype=None):
|
||||
"""Creates a Tensor object from a Python object or numpy array.
|
||||
@ -48,7 +50,7 @@ class Tensor(object):
|
||||
# TODO(ashankar): Evaluate if we can and perhaps share code with
|
||||
# tf.constant defined in
|
||||
# https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
|
||||
self._id = core.uid()
|
||||
self._id = tf_ops.uid()
|
||||
if not isinstance(value, np.ndarray):
|
||||
npt = None if dtype is None else dtype.as_numpy_dtype
|
||||
value = np.array(value, dtype=npt)
|
||||
@ -111,7 +113,7 @@ class Tensor(object):
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().record_tensor("MANUAL",
|
||||
tape.tensor_id(self),
|
||||
self._device_name(),
|
||||
self.device,
|
||||
self.shape.num_elements())
|
||||
|
||||
def __del__(self):
|
||||
@ -184,12 +186,13 @@ class Tensor(object):
|
||||
if core.active_trace() is not None:
|
||||
core.active_trace().record_tensor("COPY",
|
||||
tape.tensor_id(new_tensor),
|
||||
new_tensor._device_name(),
|
||||
new_tensor.device,
|
||||
new_tensor.shape.num_elements())
|
||||
return new_tensor
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _device_name(self):
|
||||
@property
|
||||
def device(self):
|
||||
return pywrap_tensorflow.TFE_TensorHandleDeviceName(self._handle)
|
||||
|
||||
@property
|
||||
@ -237,6 +240,10 @@ class Tensor(object):
|
||||
pywrap_tensorflow.TFE_TensorHandleDim(self._handle, x)
|
||||
for x in range(n))
|
||||
|
||||
def _shape_as_list(self):
|
||||
"""The shape of the tensor as a list."""
|
||||
return list(self._shape_tuple())
|
||||
|
||||
def as_cpu_tensor(self):
|
||||
"""A copy of this Tensor with contents backed by host memory."""
|
||||
return self._copy(context.get_default_context(), "CPU:0")
|
||||
@ -266,6 +273,42 @@ class Tensor(object):
|
||||
def __nonzero__(self):
|
||||
return self.__bool__()
|
||||
|
||||
# Methods not supported / implemented for Eager Tensors.
|
||||
@property
|
||||
def op(self):
|
||||
raise NotImplementedError("op not supported for Eager Tensors.")
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
raise NotImplementedError("graph not supported for Eager Tensors.")
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
raise NotImplementedError("name not supported for Eager Tensors.")
|
||||
|
||||
def set_shape(self, shape):
|
||||
raise NotImplementedError("set_shape not supported for Eager Tensors.")
|
||||
|
||||
@property
|
||||
def value_index(self):
|
||||
raise NotImplementedError("value_index not supported for Eager Tensors.")
|
||||
|
||||
def consumers(self):
|
||||
raise NotImplementedError("consumers not supported for Eager Tensors.")
|
||||
|
||||
def _add_consumer(self, consumer):
|
||||
raise NotImplementedError("_add_consumer not supported for Eager Tensors.")
|
||||
|
||||
def _as_node_def_input(self):
|
||||
raise NotImplementedError(
|
||||
"_as_node_def_input not supported for Eager Tensors.")
|
||||
|
||||
def _as_tf_output(self):
|
||||
raise NotImplementedError("_as_tf_output not supported for Eager Tensors.")
|
||||
|
||||
def eval(self, feed_dict=None, session=None):
|
||||
raise NotImplementedError("eval not supported for Eager Tensors.")
|
||||
|
||||
|
||||
class IndexedSlices(object):
|
||||
"""A sparse representation of a set of tensor slices at given indices.
|
||||
@ -374,7 +417,7 @@ def _tensor_from_handle(handle):
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
t = Tensor.__new__(Tensor)
|
||||
t._id = core.uid()
|
||||
t._id = tf_ops.uid()
|
||||
t._handle = handle
|
||||
t._dtype = dtypes.as_dtype(pywrap_tensorflow.TFE_TensorHandleDataType(handle))
|
||||
t._handle_data = None
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Classes and functions used to construct graphs."""
|
||||
# pylint: disable=g-bad-name
|
||||
from __future__ import absolute_import
|
||||
@ -47,7 +46,6 @@ from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import decorator_utils
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
|
||||
# Temporary global switch determining if we should enable the work-in-progress
|
||||
# calls to the C API. Currently disabled by default but can be manually enabled
|
||||
# e.g. in tests. This will be removed once all functionality is supported and
|
||||
@ -153,6 +151,18 @@ def register_dense_tensor_like_type(tensor_type):
|
||||
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
|
||||
|
||||
|
||||
_uid_counter = 0
|
||||
_uid_lock = threading.Lock()
|
||||
|
||||
|
||||
def uid():
|
||||
"""A unique (within this program execution) integer."""
|
||||
with _uid_lock:
|
||||
global _uid_counter
|
||||
_uid_counter += 1
|
||||
return _uid_counter
|
||||
|
||||
|
||||
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
|
||||
class _TensorLike(object):
|
||||
"""Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
|
||||
@ -261,6 +271,7 @@ class Tensor(_TensorLike):
|
||||
# Attributes used for C++ shape inference. Not inspected, only forwarded.
|
||||
# If set, will be a HandleData object from cpp_shape_inference.proto.
|
||||
self._handle_data = None
|
||||
self._id = uid()
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
@ -284,11 +295,6 @@ class Tensor(_TensorLike):
|
||||
raise ValueError("Operation was not named: %s" % self._op)
|
||||
return "%s:%d" % (self._op.name, self._value_index)
|
||||
|
||||
@property
|
||||
def _id(self):
|
||||
"""An alias for the string name of this tensor."""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""The name of the device on which this tensor will be produced, or None."""
|
||||
@ -437,15 +443,15 @@ class Tensor(_TensorLike):
|
||||
|
||||
def __str__(self):
|
||||
return "Tensor(\"%s\"%s%s%s)" % (
|
||||
self.name,
|
||||
(", shape=%s" % self.get_shape())
|
||||
self.name, (", shape=%s" % self.get_shape())
|
||||
if self.get_shape().ndims is not None else "",
|
||||
(", dtype=%s" % self._dtype.name) if self._dtype else "",
|
||||
(", device=%s" % self.device) if self.device else "")
|
||||
(", dtype=%s" % self._dtype.name)
|
||||
if self._dtype else "", (", device=%s" % self.device)
|
||||
if self.device else "")
|
||||
|
||||
def __repr__(self):
|
||||
return "<tf.Tensor '%s' shape=%s dtype=%s>" % (
|
||||
self.name, self.get_shape(), self._dtype.name)
|
||||
return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
|
||||
self._dtype.name)
|
||||
|
||||
def __hash__(self):
|
||||
# Necessary to support Python's collection membership operators
|
||||
@ -551,20 +557,18 @@ def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
||||
_ = name, as_ref
|
||||
if dtype and not dtype.is_compatible_with(t.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
|
||||
% (dtype.name, t.dtype.name, str(t)))
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||
(dtype.name, t.dtype.name, str(t)))
|
||||
return t
|
||||
|
||||
|
||||
_tensor_conversion_func_registry = {
|
||||
0: [(Tensor, _TensorTensorConversionFunction)]}
|
||||
0: [(Tensor, _TensorTensorConversionFunction)]
|
||||
}
|
||||
register_dense_tensor_like_type(Tensor)
|
||||
|
||||
|
||||
def convert_to_tensor(value,
|
||||
dtype=None,
|
||||
name=None,
|
||||
preferred_dtype=None):
|
||||
def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
|
||||
"""Converts the given `value` to a `Tensor`.
|
||||
|
||||
This function converts Python objects of various types to `Tensor`
|
||||
@ -686,18 +690,18 @@ def internal_convert_to_tensor(value,
|
||||
|
||||
if not isinstance(ret, Tensor):
|
||||
raise RuntimeError(
|
||||
"%sConversion function %r for type %s returned non-Tensor: %r"
|
||||
% (error_prefix, conversion_func, base_type, ret))
|
||||
"%sConversion function %r for type %s returned non-Tensor: %r" %
|
||||
(error_prefix, conversion_func, base_type, ret))
|
||||
if dtype and not dtype.is_compatible_with(ret.dtype):
|
||||
raise RuntimeError(
|
||||
"%sConversion function %r for type %s returned incompatible "
|
||||
"dtype: requested = %s, actual = %s"
|
||||
% (error_prefix, conversion_func, base_type,
|
||||
dtype.name, ret.dtype.name))
|
||||
"dtype: requested = %s, actual = %s" %
|
||||
(error_prefix, conversion_func, base_type, dtype.name,
|
||||
ret.dtype.name))
|
||||
return ret
|
||||
raise TypeError("%sCannot convert %r with type %s to Tensor: "
|
||||
"no conversion function registered."
|
||||
% (error_prefix, value, type(value)))
|
||||
"no conversion function registered." % (error_prefix, value,
|
||||
type(value)))
|
||||
|
||||
|
||||
def internal_convert_n_to_tensor(values,
|
||||
@ -744,10 +748,7 @@ def internal_convert_n_to_tensor(values,
|
||||
return ret
|
||||
|
||||
|
||||
def convert_n_to_tensor(values,
|
||||
dtype=None,
|
||||
name=None,
|
||||
preferred_dtype=None):
|
||||
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
|
||||
"""Converts `values` to a list of `Tensor` objects.
|
||||
|
||||
Args:
|
||||
@ -771,11 +772,12 @@ def convert_n_to_tensor(values,
|
||||
RuntimeError: If a registered conversion function returns an invalid
|
||||
value.
|
||||
"""
|
||||
return internal_convert_n_to_tensor(values=values,
|
||||
dtype=dtype,
|
||||
name=name,
|
||||
preferred_dtype=preferred_dtype,
|
||||
as_ref=False)
|
||||
return internal_convert_n_to_tensor(
|
||||
values=values,
|
||||
dtype=dtype,
|
||||
name=name,
|
||||
preferred_dtype=preferred_dtype,
|
||||
as_ref=False)
|
||||
|
||||
|
||||
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
|
||||
@ -802,7 +804,9 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
|
||||
value=value, dtype=dtype, name=name, as_ref=False)
|
||||
|
||||
|
||||
def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
|
||||
def internal_convert_to_tensor_or_indexed_slices(value,
|
||||
dtype=None,
|
||||
name=None,
|
||||
as_ref=False):
|
||||
"""Converts the given object to an `Tensor` or an `IndexedSlices`.
|
||||
|
||||
@ -827,18 +831,18 @@ def internal_convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
|
||||
if isinstance(value, _TensorLike):
|
||||
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
|
||||
% (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||
(dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
|
||||
return value
|
||||
else:
|
||||
return internal_convert_to_tensor(value,
|
||||
dtype=dtype,
|
||||
name=name,
|
||||
as_ref=as_ref)
|
||||
return internal_convert_to_tensor(
|
||||
value, dtype=dtype, name=name, as_ref=as_ref)
|
||||
|
||||
|
||||
def internal_convert_n_to_tensor_or_indexed_slices(values, dtype=None,
|
||||
name=None, as_ref=False):
|
||||
def internal_convert_n_to_tensor_or_indexed_slices(values,
|
||||
dtype=None,
|
||||
name=None,
|
||||
as_ref=False):
|
||||
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
|
||||
|
||||
Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
|
||||
@ -905,7 +909,8 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
|
||||
values=values, dtype=dtype, name=name, as_ref=False)
|
||||
|
||||
|
||||
def register_tensor_conversion_function(base_type, conversion_func,
|
||||
def register_tensor_conversion_function(base_type,
|
||||
conversion_func,
|
||||
priority=100):
|
||||
"""Registers a function for converting objects of `base_type` to `Tensor`.
|
||||
|
||||
@ -947,8 +952,8 @@ def register_tensor_conversion_function(base_type, conversion_func,
|
||||
|
||||
"""
|
||||
if not (isinstance(base_type, type) or
|
||||
(isinstance(base_type, tuple)
|
||||
and all(isinstance(x, type) for x in base_type))):
|
||||
(isinstance(base_type, tuple) and
|
||||
all(isinstance(x, type) for x in base_type))):
|
||||
raise TypeError("base_type must be a type or a tuple of types.")
|
||||
if not callable(conversion_func):
|
||||
raise TypeError("conversion_func must be callable.")
|
||||
@ -1038,8 +1043,7 @@ class IndexedSlices(_TensorLike):
|
||||
|
||||
def __str__(self):
|
||||
return "IndexedSlices(indices=%s, values=%s%s)" % (
|
||||
self._indices, self._values,
|
||||
(", dense_shape=%s" % self._dense_shape)
|
||||
self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
|
||||
if self._dense_shape is not None else "")
|
||||
|
||||
def __neg__(self):
|
||||
@ -1112,8 +1116,14 @@ class Operation(object):
|
||||
`op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
|
||||
"""
|
||||
|
||||
def __init__(self, node_def, g, inputs=None, output_types=None,
|
||||
control_inputs=None, input_types=None, original_op=None,
|
||||
def __init__(self,
|
||||
node_def,
|
||||
g,
|
||||
inputs=None,
|
||||
output_types=None,
|
||||
control_inputs=None,
|
||||
input_types=None,
|
||||
original_op=None,
|
||||
op_def=None):
|
||||
r"""Creates an `Operation`.
|
||||
|
||||
@ -1177,18 +1187,20 @@ class Operation(object):
|
||||
if output_types is None:
|
||||
output_types = []
|
||||
self._output_types_val = output_types
|
||||
self._outputs = [Tensor(self, i, output_type)
|
||||
for i, output_type in enumerate(output_types)]
|
||||
self._outputs = [
|
||||
Tensor(self, i, output_type)
|
||||
for i, output_type in enumerate(output_types)
|
||||
]
|
||||
if input_types is None:
|
||||
input_types = [i.dtype.base_dtype for i in self._inputs]
|
||||
else:
|
||||
if not all(x.is_compatible_with(i.dtype)
|
||||
for i, x in zip(self._inputs, input_types)):
|
||||
if not all(
|
||||
x.is_compatible_with(i.dtype)
|
||||
for i, x in zip(self._inputs, input_types)):
|
||||
raise TypeError("In op '%s', input types (%s) are not compatible "
|
||||
"with expected types (%s)" % (
|
||||
self.node_def.name,
|
||||
[i.dtype for i in self._inputs],
|
||||
input_types))
|
||||
"with expected types (%s)" %
|
||||
(self.node_def.name, [i.dtype for i in self._inputs],
|
||||
input_types))
|
||||
self._input_types_val = input_types
|
||||
|
||||
# Build the list of control inputs.
|
||||
@ -1251,7 +1263,8 @@ class Operation(object):
|
||||
A wrapped TF_Operation*.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
op_desc = c_api.TF_NewOperation(graph._c_graph, compat.as_str(node_def.op),
|
||||
op_desc = c_api.TF_NewOperation(graph._c_graph,
|
||||
compat.as_str(node_def.op),
|
||||
compat.as_str(node_def.name))
|
||||
# Add inputs
|
||||
for op_input in inputs:
|
||||
@ -1271,8 +1284,8 @@ class Operation(object):
|
||||
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
|
||||
# It might be worth creating a convenient way to re-use the same status.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized,
|
||||
status)
|
||||
c_api.TF_SetAttrValueProto(op_desc,
|
||||
compat.as_str(name), serialized, status)
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
c_op = c_api.TF_FinishOperation(op_desc, status)
|
||||
@ -1316,16 +1329,18 @@ class Operation(object):
|
||||
|
||||
def colocation_groups(self):
|
||||
"""Returns the list of colocation groups of the op."""
|
||||
default_colocation_group = [compat.as_bytes("loc:@%s" %
|
||||
self._node_def.name)]
|
||||
default_colocation_group = [
|
||||
compat.as_bytes("loc:@%s" % self._node_def.name)
|
||||
]
|
||||
if "_class" not in self._node_def.attr:
|
||||
# This op has no explicit colocation group, so it is itself its
|
||||
# own root of a colocation group.
|
||||
return default_colocation_group
|
||||
|
||||
attr_groups = [class_name
|
||||
for class_name in self.get_attr("_class")
|
||||
if class_name.startswith(b"loc:@")]
|
||||
attr_groups = [
|
||||
class_name for class_name in self.get_attr("_class")
|
||||
if class_name.startswith(b"loc:@")
|
||||
]
|
||||
|
||||
# If there are no colocation groups in the explicit _class field,
|
||||
# return the default colocation group.
|
||||
@ -1397,8 +1412,10 @@ class Operation(object):
|
||||
"""
|
||||
if self._graph._c_graph: # pylint: disable=protected-access
|
||||
num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
|
||||
output_types = [c_api.TF_OperationOutputType(self._tf_output(i)) for
|
||||
i in xrange(num_outputs)]
|
||||
output_types = [
|
||||
c_api.TF_OperationOutputType(self._tf_output(i))
|
||||
for i in xrange(num_outputs)
|
||||
]
|
||||
# TODO(iga): Remove this assert after converting to C API by default.
|
||||
# Just being a bit paranoid here.
|
||||
assert self._output_types_val == output_types
|
||||
@ -1433,8 +1450,8 @@ class Operation(object):
|
||||
device: string or device.. The device to set.
|
||||
"""
|
||||
if _USE_C_API:
|
||||
c_api.SetRequestedDevice(
|
||||
self._graph._c_graph, self._c_op, _device_string(device)) # pylint: disable=protected-access
|
||||
c_api.SetRequestedDevice(self._graph._c_graph, self._c_op, # pylint: disable=protected-access
|
||||
_device_string(device))
|
||||
# TODO(nolivia): remove this line when switch to C api
|
||||
self._node_def.device = _device_string(device)
|
||||
|
||||
@ -1462,8 +1479,8 @@ class Operation(object):
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if not dtype.is_compatible_with(tensor.dtype):
|
||||
raise TypeError(
|
||||
"Cannot convert a tensor of type %s to an input of type %s"
|
||||
% (tensor.dtype.name, dtype.name))
|
||||
"Cannot convert a tensor of type %s to an input of type %s" %
|
||||
(tensor.dtype.name, dtype.name))
|
||||
self._inputs.append(tensor)
|
||||
self._input_types_val.append(dtype)
|
||||
tensor._add_consumer(self) # pylint: disable=protected-access
|
||||
@ -1496,8 +1513,8 @@ class Operation(object):
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if not dtype.is_compatible_with(tensor.dtype):
|
||||
raise TypeError(
|
||||
"Cannot convert a tensor of type %s to an input of type %s"
|
||||
% (tensor.dtype.name, dtype.name))
|
||||
"Cannot convert a tensor of type %s to an input of type %s" %
|
||||
(tensor.dtype.name, dtype.name))
|
||||
|
||||
self._inputs[index].consumers().remove(self)
|
||||
self._inputs[index] = tensor
|
||||
@ -1547,8 +1564,8 @@ class Operation(object):
|
||||
self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
|
||||
# pylint: enable=protected-access
|
||||
if self._control_inputs:
|
||||
self._node_def.input.extend(["^%s" % op.name for op in
|
||||
self._control_inputs])
|
||||
self._node_def.input.extend(
|
||||
["^%s" % op.name for op in self._control_inputs])
|
||||
|
||||
def __str__(self):
|
||||
return str(self._node_def)
|
||||
@ -1562,6 +1579,7 @@ class Operation(object):
|
||||
return self._outputs
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
class _InputList(object):
|
||||
"""Immutable input list wrapper."""
|
||||
|
||||
@ -1582,6 +1600,7 @@ class Operation(object):
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self._op._inputs[i]
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@property
|
||||
@ -1597,9 +1616,10 @@ class Operation(object):
|
||||
def _input_types(self):
|
||||
if self._graph._c_graph: # pylint: disable=protected-access
|
||||
num_inputs = c_api.TF_OperationNumInputs(self._c_op)
|
||||
input_types = [dtypes.as_dtype(
|
||||
c_api.TF_OperationInputType(self._tf_input(i)))
|
||||
for i in xrange(num_inputs)]
|
||||
input_types = [
|
||||
dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
|
||||
for i in xrange(num_inputs)
|
||||
]
|
||||
# TODO(iga): Remove this assert after converting to C API by default.
|
||||
# Just being a bit paranoid here.
|
||||
assert self._input_types_val == input_types
|
||||
@ -1624,8 +1644,10 @@ class Operation(object):
|
||||
if self._graph._c_graph: # pylint: disable=protected-access
|
||||
control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
|
||||
# pylint: disable=protected-access
|
||||
return [self.graph._get_operation_by_name_unsafe(
|
||||
c_api.TF_OperationName(c_op)) for c_op in control_c_ops]
|
||||
return [
|
||||
self.graph._get_operation_by_name_unsafe(
|
||||
c_api.TF_OperationName(c_op)) for c_op in control_c_ops
|
||||
]
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
return self._control_inputs
|
||||
@ -1691,7 +1713,8 @@ class Operation(object):
|
||||
A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
|
||||
"""
|
||||
return self._graph._convert_stack( # pylint: disable=protected-access
|
||||
self._traceback, include_func_start_lineno=True)
|
||||
self._traceback,
|
||||
include_func_start_lineno=True)
|
||||
|
||||
def get_attr(self, name):
|
||||
"""Returns the value of the attr of this op with the given `name`.
|
||||
@ -1707,8 +1730,7 @@ class Operation(object):
|
||||
"""
|
||||
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
|
||||
if name not in self._node_def.attr:
|
||||
raise ValueError("No attr named '" + name + "' in " +
|
||||
str(self._node_def))
|
||||
raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
|
||||
x = self._node_def.attr[name]
|
||||
# Treat an empty oneof value as an empty list.
|
||||
if not x.WhichOneof("value"):
|
||||
@ -1749,7 +1771,6 @@ class Operation(object):
|
||||
"""
|
||||
_run_using_default_session(self, feed_dict, self.graph, session)
|
||||
|
||||
|
||||
_gradient_registry = registry.Registry("gradient")
|
||||
|
||||
|
||||
@ -1834,7 +1855,8 @@ NoGradient = NotDifferentiable
|
||||
|
||||
def get_gradient_function(op):
|
||||
"""Returns the function that computes gradients for "op"."""
|
||||
if not op.inputs: return None
|
||||
if not op.inputs:
|
||||
return None
|
||||
try:
|
||||
op_type = op.get_attr("_gradient_op_type")
|
||||
except ValueError:
|
||||
@ -1982,6 +2004,7 @@ class OpStats(object):
|
||||
self._value += other.value
|
||||
return self
|
||||
|
||||
|
||||
_stats_registry = registry.Registry("statistical functions")
|
||||
|
||||
|
||||
@ -2433,8 +2456,8 @@ class Graph(object):
|
||||
graph.node.extend([op.node_def])
|
||||
if op.outputs and add_shapes:
|
||||
assert "_output_shapes" not in graph.node[-1].attr
|
||||
graph.node[-1].attr["_output_shapes"].list.shape.extend([
|
||||
output.get_shape().as_proto() for output in op.outputs])
|
||||
graph.node[-1].attr["_output_shapes"].list.shape.extend(
|
||||
[output.get_shape().as_proto() for output in op.outputs])
|
||||
bytesize += op.node_def.ByteSize()
|
||||
if bytesize >= (1 << 31) or bytesize < 0:
|
||||
raise ValueError("GraphDef cannot be larger than 2GB.")
|
||||
@ -2469,7 +2492,8 @@ class Graph(object):
|
||||
node with the inferred shapes of each of its outputs.
|
||||
|
||||
Returns:
|
||||
A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
|
||||
A
|
||||
[`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
|
||||
protocol buffer.
|
||||
|
||||
Raises:
|
||||
@ -2518,8 +2542,8 @@ class Graph(object):
|
||||
if previous:
|
||||
raise ValueError("Another function is already defined with that name")
|
||||
# Sanity checks on gradient definition.
|
||||
if (function.grad_func_name is not None) and (
|
||||
function.python_grad_func is not None):
|
||||
if (function.grad_func_name is not None) and (function.python_grad_func is
|
||||
not None):
|
||||
raise ValueError("Gradient defined twice for function %s" % name)
|
||||
# Need a new-enough consumer to support the functions we add to the graph.
|
||||
if self._graph_def_versions.min_consumer < 12:
|
||||
@ -2532,9 +2556,17 @@ class Graph(object):
|
||||
return self._building_function
|
||||
|
||||
# Helper functions to create operations.
|
||||
def create_op(self, op_type, inputs, dtypes, # pylint: disable=redefined-outer-name
|
||||
input_types=None, name=None, attrs=None, op_def=None,
|
||||
compute_shapes=True, compute_device=True):
|
||||
def create_op(
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes, # pylint: disable=redefined-outer-name
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
op_def=None,
|
||||
compute_shapes=True,
|
||||
compute_device=True):
|
||||
"""Creates an `Operation` in this graph.
|
||||
|
||||
This is a low-level interface for creating an `Operation`. Most
|
||||
@ -2597,8 +2629,8 @@ class Graph(object):
|
||||
if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
|
||||
raise TypeError(
|
||||
"Callable for scope map key '%s' must return either None or "
|
||||
"an AttrValue protocol buffer; but it returned: %s" %
|
||||
(key, value))
|
||||
"an AttrValue protocol buffer; but it returned: %s" % (key,
|
||||
value))
|
||||
node_def.attr[key].CopyFrom(value)
|
||||
|
||||
# Apply a kernel label if one has been specified for this op_type.
|
||||
@ -2619,9 +2651,15 @@ class Graph(object):
|
||||
pass
|
||||
|
||||
control_inputs = self._control_dependencies_for_inputs(inputs)
|
||||
ret = Operation(node_def, self, inputs=inputs, output_types=dtypes,
|
||||
control_inputs=control_inputs, input_types=input_types,
|
||||
original_op=self._default_original_op, op_def=op_def)
|
||||
ret = Operation(
|
||||
node_def,
|
||||
self,
|
||||
inputs=inputs,
|
||||
output_types=dtypes,
|
||||
control_inputs=control_inputs,
|
||||
input_types=input_types,
|
||||
original_op=self._default_original_op,
|
||||
op_def=op_def)
|
||||
if compute_shapes:
|
||||
set_shapes_for_outputs(ret)
|
||||
self._add_op(ret)
|
||||
@ -2638,28 +2676,27 @@ class Graph(object):
|
||||
# Make this device match the device of the colocated op, to
|
||||
# provide consistency between the device and the colocation
|
||||
# property.
|
||||
if (ret.device and
|
||||
pydev.canonical_name(ret.device) !=
|
||||
if (ret.device and pydev.canonical_name(ret.device) !=
|
||||
pydev.canonical_name(colocation_op.device)):
|
||||
logging.warning("Tried to colocate %s with an op %s that had "
|
||||
"a different device: %s vs %s. "
|
||||
"Ignoring colocation property.",
|
||||
name, colocation_op.name,
|
||||
ret.device, colocation_op.device)
|
||||
"Ignoring colocation property.", name,
|
||||
colocation_op.name, ret.device,
|
||||
colocation_op.device)
|
||||
else:
|
||||
ret._set_device(colocation_op.device) # pylint: disable=protected-access
|
||||
|
||||
all_colocation_groups = sorted(set(all_colocation_groups))
|
||||
ret.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue(
|
||||
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
|
||||
ret.node_def.attr["_class"].CopyFrom(
|
||||
attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
|
||||
s=all_colocation_groups)))
|
||||
|
||||
# Sets "container" attribute if
|
||||
# (1) self._container is not None
|
||||
# (2) "is_stateful" is set in OpDef
|
||||
# (3) "container" attribute is in OpDef
|
||||
# (4) "container" attribute is None
|
||||
if (self._container and
|
||||
op_type in self._registered_ops and
|
||||
if (self._container and op_type in self._registered_ops and
|
||||
self._registered_ops[op_type].is_stateful and
|
||||
"container" in ret.node_def.attr and
|
||||
not ret.node_def.attr["container"].s):
|
||||
@ -2749,13 +2786,13 @@ class Graph(object):
|
||||
except:
|
||||
raise KeyError("The name %s refers to a Tensor which does not "
|
||||
"exist. The operation, %s, exists but only has "
|
||||
"%s outputs."
|
||||
% (repr(name), repr(op_name), len(op.outputs)))
|
||||
"%s outputs." % (repr(name), repr(op_name),
|
||||
len(op.outputs)))
|
||||
|
||||
elif ":" in name and not allow_tensor:
|
||||
# Looks like a Tensor name but can't be a Tensor.
|
||||
raise ValueError("Name %s appears to refer to a Tensor, not a %s."
|
||||
% (repr(name), types_str))
|
||||
raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
|
||||
(repr(name), types_str))
|
||||
|
||||
elif ":" not in name and allow_operation:
|
||||
# Looks like an Operation name and can be an Operation.
|
||||
@ -2768,8 +2805,8 @@ class Graph(object):
|
||||
# Looks like an Operation name but can't be an Operation.
|
||||
if name in self._nodes_by_name:
|
||||
# Yep, it's an Operation name
|
||||
err_msg = ("The name %s refers to an Operation, not a %s."
|
||||
% (repr(name), types_str))
|
||||
err_msg = ("The name %s refers to an Operation, not a %s." %
|
||||
(repr(name), types_str))
|
||||
else:
|
||||
err_msg = ("The name %s looks like an (invalid) Operation name, "
|
||||
"not a %s." % (repr(name), types_str))
|
||||
@ -2789,8 +2826,8 @@ class Graph(object):
|
||||
return obj
|
||||
else:
|
||||
# We give up!
|
||||
raise TypeError("Can not convert a %s into a %s."
|
||||
% (type(obj).__name__, types_str))
|
||||
raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
|
||||
types_str))
|
||||
|
||||
def get_operations(self):
|
||||
"""Return the list of operations in the graph.
|
||||
@ -2827,8 +2864,8 @@ class Graph(object):
|
||||
"""
|
||||
|
||||
if not isinstance(name, six.string_types):
|
||||
raise TypeError("Operation names are strings (or similar), not %s."
|
||||
% type(name).__name__)
|
||||
raise TypeError("Operation names are strings (or similar), not %s." %
|
||||
type(name).__name__)
|
||||
return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
|
||||
|
||||
def _get_operation_by_name_unsafe(self, name):
|
||||
@ -2871,8 +2908,8 @@ class Graph(object):
|
||||
"""
|
||||
# Names should be strings.
|
||||
if not isinstance(name, six.string_types):
|
||||
raise TypeError("Tensor names are strings (or similar), not %s."
|
||||
% type(name).__name__)
|
||||
raise TypeError("Tensor names are strings (or similar), not %s." %
|
||||
type(name).__name__)
|
||||
return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
|
||||
|
||||
def _next_id(self):
|
||||
@ -3178,6 +3215,7 @@ class Graph(object):
|
||||
yield "" if new_stack is None else new_stack + "/"
|
||||
finally:
|
||||
self._name_stack = old_stack
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield,line-too-long
|
||||
|
||||
def unique_name(self, name, mark_as_used=True):
|
||||
@ -3280,9 +3318,8 @@ class Graph(object):
|
||||
|
||||
"""
|
||||
if op is None and not ignore_existing:
|
||||
raise ValueError(
|
||||
"Trying to reset colocation (op is None) but "
|
||||
"ignore_existing is not True")
|
||||
raise ValueError("Trying to reset colocation (op is None) but "
|
||||
"ignore_existing is not True")
|
||||
|
||||
if op is not None and not isinstance(op, Operation):
|
||||
# We always want to colocate with the reference op.
|
||||
@ -3377,8 +3414,8 @@ class Graph(object):
|
||||
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if (device_name_or_function is not None
|
||||
and not callable(device_name_or_function)):
|
||||
if (device_name_or_function is not None and
|
||||
not callable(device_name_or_function)):
|
||||
device_function = pydev.merge_device(device_name_or_function)
|
||||
else:
|
||||
device_function = device_name_or_function
|
||||
@ -3452,6 +3489,7 @@ class Graph(object):
|
||||
yield self._container
|
||||
finally:
|
||||
self._container = original_container
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
class _ControlDependenciesController(object):
|
||||
@ -3491,6 +3529,7 @@ class Graph(object):
|
||||
self._old_control_flow_context = None
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
def __enter__(self):
|
||||
if self._new_stack:
|
||||
# Clear the control_dependencies graph.
|
||||
@ -3506,6 +3545,7 @@ class Graph(object):
|
||||
if self._new_stack:
|
||||
self._graph._control_dependencies_stack = self._old_stack
|
||||
self._graph._set_control_flow_context(self._old_control_flow_context)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@property
|
||||
@ -3733,6 +3773,7 @@ class Graph(object):
|
||||
self._attr_scope_map[name] = saved_attrs[name]
|
||||
except KeyError:
|
||||
del self._attr_scope_map[name]
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
@ -3777,8 +3818,8 @@ class Graph(object):
|
||||
saved_labels = {}
|
||||
# Install the given label
|
||||
for op_type, label in op_to_kernel_label_map.items():
|
||||
if not (isinstance(op_type, six.string_types)
|
||||
and isinstance(label, six.string_types)):
|
||||
if not (isinstance(op_type, six.string_types) and
|
||||
isinstance(label, six.string_types)):
|
||||
raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
|
||||
"strings to strings")
|
||||
try:
|
||||
@ -3795,6 +3836,7 @@ class Graph(object):
|
||||
self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
|
||||
except KeyError:
|
||||
del self._op_to_kernel_label_map[op_type]
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
@ -3840,8 +3882,8 @@ class Graph(object):
|
||||
saved_mappings = {}
|
||||
# Install the given label
|
||||
for op_type, mapped_op_type in op_type_map.items():
|
||||
if not (isinstance(op_type, six.string_types)
|
||||
and isinstance(mapped_op_type, six.string_types)):
|
||||
if not (isinstance(op_type, six.string_types) and
|
||||
isinstance(mapped_op_type, six.string_types)):
|
||||
raise TypeError("op_type_map must be a dictionary mapping "
|
||||
"strings to strings")
|
||||
try:
|
||||
@ -3858,6 +3900,7 @@ class Graph(object):
|
||||
self._gradient_override_map[op_type] = saved_mappings[op_type]
|
||||
except KeyError:
|
||||
del self._gradient_override_map[op_type]
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
def prevent_feeding(self, tensor):
|
||||
@ -3969,12 +4012,13 @@ class _DefaultStack(threading.local):
|
||||
if self._enforce_nesting:
|
||||
if self.stack[-1] is not default:
|
||||
raise AssertionError(
|
||||
"Nesting violated for default stack of %s objects"
|
||||
% type(default))
|
||||
"Nesting violated for default stack of %s objects" %
|
||||
type(default))
|
||||
self.stack.pop()
|
||||
else:
|
||||
self.stack.remove(default)
|
||||
|
||||
|
||||
_default_session_stack = _DefaultStack()
|
||||
|
||||
|
||||
@ -4145,6 +4189,7 @@ class _DefaultGraphStack(_DefaultStack):
|
||||
super(_DefaultGraphStack, self).reset()
|
||||
self._global_default_graph = None
|
||||
|
||||
|
||||
_default_graph_stack = _DefaultGraphStack()
|
||||
|
||||
|
||||
@ -4195,8 +4240,8 @@ def _assert_same_graph(original_item, item):
|
||||
ValueError: if graphs do not match.
|
||||
"""
|
||||
if original_item.graph is not item.graph:
|
||||
raise ValueError(
|
||||
"%s must be from the same graph as %s." % (item, original_item))
|
||||
raise ValueError("%s must be from the same graph as %s." % (item,
|
||||
original_item))
|
||||
|
||||
|
||||
def _get_graph_from_inputs(op_input_list, graph=None):
|
||||
@ -4246,8 +4291,11 @@ def _get_graph_from_inputs(op_input_list, graph=None):
|
||||
original_graph_element = None
|
||||
for op_input in op_input_list:
|
||||
# Determine if this is a valid graph_element.
|
||||
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
|
||||
# up.
|
||||
graph_element = None
|
||||
if isinstance(op_input, (Operation, _TensorLike)):
|
||||
if (isinstance(op_input, (Operation, _TensorLike)) and
|
||||
((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck
|
||||
graph_element = op_input
|
||||
else:
|
||||
graph_element = _as_graph_element(op_input)
|
||||
@ -4259,8 +4307,7 @@ def _get_graph_from_inputs(op_input_list, graph=None):
|
||||
elif original_graph_element is not None:
|
||||
_assert_same_graph(original_graph_element, graph_element)
|
||||
elif graph_element.graph is not graph:
|
||||
raise ValueError(
|
||||
"%s is not from the passed-in graph." % graph_element)
|
||||
raise ValueError("%s is not from the passed-in graph." % graph_element)
|
||||
|
||||
# 2. If all else fails, we use the default graph, which is always there.
|
||||
return graph or get_default_graph()
|
||||
@ -4512,13 +4559,15 @@ def name_scope(name, default_name=None, values=None):
|
||||
# tf.name_scope(None) (values=None then) is sometimes used as an idiom
|
||||
# to reset to top scope.
|
||||
raise ValueError(
|
||||
"At least one of name (%s) and default_name (%s) must be provided." % (
|
||||
name, default_name))
|
||||
"At least one of name (%s) and default_name (%s) must be provided." %
|
||||
(name, default_name))
|
||||
if values is None:
|
||||
values = []
|
||||
g = _get_graph_from_inputs(values)
|
||||
with g.as_default(), g.name_scope(n) as scope:
|
||||
yield scope
|
||||
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
|
||||
@ -4585,7 +4634,9 @@ def op_scope(values, name, default_name=None):
|
||||
_proto_function_registry = registry.Registry("proto functions")
|
||||
|
||||
|
||||
def register_proto_function(collection_name, proto_type=None, to_proto=None,
|
||||
def register_proto_function(collection_name,
|
||||
proto_type=None,
|
||||
to_proto=None,
|
||||
from_proto=None):
|
||||
"""Registers `to_proto` and `from_proto` functions for collection_name.
|
||||
|
||||
@ -4637,10 +4688,9 @@ def get_from_proto_function(collection_name):
|
||||
|
||||
def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
|
||||
"""Produce a nice error if someone converts an Operation to a Tensor."""
|
||||
raise TypeError(
|
||||
("Can't convert Operation '%s' to Tensor "
|
||||
"(target dtype=%r, name=%r, as_ref=%r)") %
|
||||
(op.name, dtype, name, as_ref))
|
||||
raise TypeError(("Can't convert Operation '%s' to Tensor "
|
||||
"(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
|
||||
name, as_ref))
|
||||
|
||||
|
||||
register_tensor_conversion_function(Operation, _operation_conversion_error)
|
||||
|
Loading…
Reference in New Issue
Block a user