TensorFlow: upstream changes to git.
Change 109418220: Update WORKSPACE to use gmock.BUILD from google/protobuf instead of a duplicate. Update google/protobuf's commit hash to include damieng@'s commit.

Change 109417314: TensorFlow: add .gitignore to ignore some in-tree modified files.

Change 109400051: Optionally build full TensorFlow for Android.
1. Use --define ANDROID_TYPES=__ANDROID_TYPES_FULL__ to register ops for all types, not just float. Today this increases code size by ~700K when compiled for ARM, though only for clients who request full type support.
2. Add more ops to android_extended_ops, sufficient to train on the linear regression baseball codelab.

Change 109388118: Fix the option changed in templatize. Oops.

Change 109382553: Allow setting a function name in an op's attr in the Python frontend.

Change 109380896: Remove assert_same_graph in favor of op_scope. Change the latter to handle tensor-like objects such as SparseTensor, IndexedSlices, and Variable.

Base CL: 109418322
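To make the op_scope change concrete, here is a rough sketch (not part of the commit; scale_and_shift and its argument names are invented) of a Python op definition that now works with tensor-like values such as Variable:

import tensorflow as tf

def scale_and_shift(x, scale, shift, name=None):
  # Hypothetical example: `scale` and `shift` may be tf.Variable objects.
  # op_scope now infers and validates the graph from tensor-like values
  # (SparseTensor, IndexedSlices, Variable), so no assert_same_graph call
  # is needed.
  with tf.op_scope([x, scale, shift], name, "ScaleAndShift") as scope:
    return tf.add(tf.mul(x, scale), shift, name=scope)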
parent 54a644f33f
commit 3dfd14421d

.gitignore (vendored, new file, 11 lines)
@@ -0,0 +1,11 @@
+node_modules
+/bazel-bin
+/bazel-genfiles
+/bazel-out
+/bazel-tensorflow
+/bazel-testlogs
+/bazel-tf
+/third_party/py/numpy/numpy_include
+/tools/bazel.rc
+/util/python/python_include
+/util/python/python_lib
WORKSPACE
@@ -16,7 +16,7 @@ new_http_archive(
     name = "gmock_archive",
     url = "https://googlemock.googlecode.com/files/gmock-1.7.0.zip",
     sha256 = "26fcbb5925b74ad5fc8c26b0495dfc96353f4d553492eb97e85a8a6d2f43095b",
-    build_file = "gmock.BUILD",
+    build_file = "google/protobuf/gmock.BUILD",
 )

 bind(
@@ -632,6 +632,7 @@ filegroup(
     srcs = [
         "//tensorflow/core:kernels/avgpooling_op.cc",
         "//tensorflow/core:kernels/avgpooling_op.h",
+        "//tensorflow/core:kernels/bcast_ops.cc",
         "//tensorflow/core:kernels/control_flow_ops.cc",
         "//tensorflow/core:kernels/control_flow_ops.h",
         "//tensorflow/core:kernels/conv_2d.h",
@@ -642,19 +643,23 @@ filegroup(
         "//tensorflow/core:kernels/cwise_op_less.cc",
         "//tensorflow/core:kernels/cwise_op_log.cc",
         "//tensorflow/core:kernels/cwise_op_mul.cc",
+        "//tensorflow/core:kernels/cwise_op_neg.cc",
         "//tensorflow/core:kernels/cwise_op_sigmoid.cc",
         "//tensorflow/core:kernels/cwise_op_sqrt.cc",
         "//tensorflow/core:kernels/cwise_op_square.cc",
         "//tensorflow/core:kernels/cwise_op_sub.cc",
         "//tensorflow/core:kernels/cwise_op_tanh.cc",
         "//tensorflow/core:kernels/dynamic_partition_op.cc",
+        "//tensorflow/core:kernels/dynamic_stitch_op.cc",
         "//tensorflow/core:kernels/lrn_op.cc",
         "//tensorflow/core:kernels/maxpooling_op.cc",
         "//tensorflow/core:kernels/maxpooling_op.h",
         "//tensorflow/core:kernels/reduction_ops.h",
         "//tensorflow/core:kernels/reduction_ops_common.h",
         "//tensorflow/core:kernels/reduction_ops_max.cc",
+        "//tensorflow/core:kernels/reduction_ops_mean.cc",
         "//tensorflow/core:kernels/reduction_ops_min.cc",
+        "//tensorflow/core:kernels/reduction_ops_prod.cc",
         "//tensorflow/core:kernels/reduction_ops_sum.cc",
         "//tensorflow/core:kernels/relu_op.cc",
         "//tensorflow/core:kernels/relu_op.h",
@@ -663,6 +668,8 @@ filegroup(
         "//tensorflow/core:kernels/softsign_op.cc",
         "//tensorflow/core:kernels/softsign_op.h",
         "//tensorflow/core:kernels/stack_ops.cc",
+        "//tensorflow/core:kernels/tile_ops.cc",
+        "//tensorflow/core:kernels/tile_ops.h",
         "//tensorflow/core:kernels/transpose_op.cc",
         "//tensorflow/core:kernels/transpose_op.h",
         "//tensorflow/core:kernels/transpose_op_functor.h",
@@ -367,11 +367,14 @@ struct SelectFunctor<CPUDevice, T> {
                           OP<D##Device, F<T>>);

 // Macros to register kernels for multiple types (T0, T1, etc.) on
-// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using
+// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
 // the functor "F" (e.g., functor:sqrt).

-#if defined(__ANDROID__)
-// On Android, only register the first type (float)
+#if defined(__ANDROID_TYPES_SLIM__)
+// Normally Android TensorFlow is built with a reduced number of types (float).
+// Override on the command-line "--define ANDROID_TYPES=__ANDROID_TYPES_FULL__"
+// to generate a library with full type support with a consequent increase in
+// code size.
 #define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
 #define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
 #define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
@@ -381,7 +384,7 @@ struct SelectFunctor<CPUDevice, T> {
   REGISTER(OP, D, N, F, T0)
 #define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
   REGISTER(OP, D, N, F, T0)
-#else  // !defined(__ANDROID__)
+#else  // !defined(__ANDROID_TYPES_SLIM__)
 #define REGISTER2(OP, D, N, F, T0, T1) \
   REGISTER(OP, D, N, F, T0)            \
   REGISTER(OP, D, N, F, T1)
@@ -403,7 +406,7 @@ struct SelectFunctor<CPUDevice, T> {
 #define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
   REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
   REGISTER4(OP, D, N, F, T4, T5, T6, T7)
-#endif  // defined(__ANDROID__)
+#endif  // defined(__ANDROID_TYPES_SLIM__)

 }  // end namespace tensorflow

@@ -39,6 +39,7 @@ from tensorflow.python.framework import registry
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import versions
 from tensorflow.python.util import compat
+from tensorflow.python.platform import logging


 def _convert_stack(stack):
@@ -95,6 +96,22 @@ def _extract_stack():
   return ret


+def _as_graph_element(obj):
+  """Convert `obj` to a graph element if possible, otherwise return `None`.
+
+  Args:
+    obj: Object to convert.
+
+  Returns:
+    The result of `obj._as_graph_element()` if that method is available;
+    otherwise `None`.
+  """
+  conv_fn = getattr(obj, "_as_graph_element", None)
+  if conv_fn and callable(conv_fn):
+    return conv_fn()
+  return None
+
+
 class Tensor(object):
   """Represents a value produced by an `Operation`.

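As a hedged aside (this wrapper class is invented for illustration), any object that exposes an _as_graph_element hook can now participate wherever a graph element is expected:

class TensorWrapper(object):
  """Hypothetical wrapper: duck-typed into a graph element via the hook."""

  def __init__(self, tensor):
    self._tensor = tensor

  def _as_graph_element(self):
    # Called by the _as_graph_element() helper above and, through it, by
    # Graph.as_graph_element() and _get_graph_from_inputs().
    return self._tensor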
@@ -680,6 +697,7 @@ class IndexedSlices(object):

   def __init__(self, values, indices, dense_shape=None):
     """Creates an `IndexedSlices`."""
+    _get_graph_from_inputs([values, indices, dense_shape])
     self._values = values
     self._indices = indices
     self._dense_shape = dense_shape
@@ -719,30 +737,15 @@ class IndexedSlices(object):
     """The `DType` of elements in this tensor."""
     return self.values.dtype

+  @property
+  def graph(self):
+    """The `Graph` that contains the values, indices, and shape tensors."""
+    return self._values.graph
+
   def __str__(self):
-    return "IndexedSlices(indices=%s, values=%s)" % (
-        self._indices, self._values)
-
-
-def assert_same_graph(items, expected_graph=None):
-  """Asserts all items are from the same graph.
-
-  Args:
-    items: List of graph items (e.g., Variable, Tensor, SparseTensor,
-      Operation, or IndexedSlices).
-    expected_graph: Expected graph. If not specified, assert all tensors are
-      from the same graph.
-  Returns:
-    items, for chaining.
-  Raises:
-    ValueError: If any graphs do not match.
-  """
-  for item in items:
-    if not expected_graph:
-      expected_graph = item.graph
-    elif expected_graph != item.graph:
-      raise ValueError("Items must be from the same graph.")
-  return items
+    return "IndexedSlices(indices=%s, values=%s%s)" % (
+        self._indices, self._values,
+        (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "")


 class SparseTensor(object):
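A small usage sketch of the IndexedSlices changes (assuming the 2015-era tf.* API): the constructor now checks that its components are from one graph, the new graph property exposes that graph, and str() includes the dense shape when present:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  values = tf.constant([[1.0, 2.0]])
  indices = tf.constant([0], dtype=tf.int64)
  dense_shape = tf.constant([3, 2], dtype=tf.int64)
  slices = tf.IndexedSlices(values, indices, dense_shape)
  assert slices.graph is g  # new `graph` property
  print(slices)             # str() now ends with ", dense_shape=..."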
@@ -1106,7 +1109,7 @@ class Operation(object):
     """
     if not isinstance(tensor, Tensor):
       raise TypeError("tensor must be a Tensor: %s" % tensor)
-    assert_same_graph([self, tensor])
+    _assert_same_graph(self, tensor)
     if dtype is None:
       dtype = tensor.dtype
     else:
@@ -1138,7 +1141,7 @@ class Operation(object):
     """
     if not isinstance(tensor, Tensor):
       raise TypeError("tensor must be a Tensor: %s" % tensor)
-    assert_same_graph([self, tensor])
+    _assert_same_graph(self, tensor)
     if dtype is None:
       dtype = tensor.dtype
     else:
@@ -1166,7 +1169,7 @@ class Operation(object):
     """
     if not isinstance(op, Operation):
       raise TypeError("op must be an Operation: %s" % op)
-    assert_same_graph([self, op])
+    _assert_same_graph(self, op)
     self._control_inputs.append(op)
     self._recompute_node_def()

@@ -1887,9 +1890,7 @@ class Graph(object):
     else:
       raise ValueError("allow_tensor and allow_operation can't both be False.")

-    conv_fn = getattr(obj, "_as_graph_element", None)
-    if conv_fn and callable(conv_fn):
-      obj = conv_fn()
+    obj = _as_graph_element(obj) or obj

     # If obj appears to be a name...
     if isinstance(obj, compat.bytes_or_text_types):
@@ -2971,6 +2972,21 @@ def get_default_graph():
   return _default_graph_stack.get_default()


+def _assert_same_graph(original_item, item):
+  """Fail if the 2 items are from different graphs.
+
+  Args:
+    original_item: Original item to check against.
+    item: Item to check.
+
+  Raises:
+    ValueError: if graphs do not match.
+  """
+  if original_item.graph is not item.graph:
+    raise ValueError(
+        "%s must be from the same graph as %s." % (item, original_item))
+
+
 def _get_graph_from_inputs(op_input_list, graph=None):
   """Returns the appropriate graph to use for the given inputs.

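For illustration only (a sketch that calls the private helper directly, not public API), the new error message names both mismatched items instead of the old generic "Items must be from the same graph.":

import tensorflow as tf
from tensorflow.python.framework import ops

g0, g1 = tf.Graph(), tf.Graph()
with g0.as_default():
  a = tf.constant(1.0)
with g1.as_default():
  b = tf.constant(2.0)

try:
  ops._assert_same_graph(a, b)  # private helper, shown for illustration
except ValueError as e:
  print(e)  # e.g. "Tensor(...) must be from the same graph as Tensor(...)"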
@@ -2986,8 +3002,8 @@ def _get_graph_from_inputs(op_input_list, graph=None):
   "op_input_list", we attempt to use the default graph.

   Args:
-    op_input_list: A list of inputs to an operation, which may include Tensor
-      and Operation objects.
+    op_input_list: A list of inputs to an operation, which may include `Tensor`,
+      `Operation`, and other objects that may be converted to a graph element.
     graph: (Optional) The explicit graph to use.

   Raises:
@@ -3001,37 +3017,35 @@ def _get_graph_from_inputs(op_input_list, graph=None):
     The appropriate graph to use for the given inputs.
   """
   op_input_list = tuple(op_input_list)  # Handle generators correctly
+  if graph and not isinstance(graph, Graph):
+    raise TypeError("Input graph needs to be a Graph: %s" % graph)

-  # 1. If the graph is specified explicitly, we validate that all of the inputs
-  # are compatible with that graph.
-  if graph is not None:
-    if not isinstance(graph, Graph):
-      raise TypeError("Input graph needs to be a Graph: %s" % graph)
-    for op_input in op_input_list:
-      if isinstance(op_input, Operation):
-        if op_input.graph is not graph:
-          raise ValueError("Operation %s is not from the passed-in graph"
-                           % op_input)
-      elif isinstance(op_input, Tensor):
-        if op_input.graph is not graph:
-          raise ValueError("Tensor %s is not from the passed-in graph"
-                           % op_input)
-    return graph
-
-  # 2. Otherwise, we attempt to select a graph from one of the Operation-
-  # or Tensor-valued inputs.
-  original_input = None
+  # 1. We validate that all of the inputs are from the same graph. This is
+  #    either the supplied graph parameter, or the first one selected from
+  #    one of the graph-element-valued inputs. In the latter case, we hold
+  #    onto that input in original_graph_element so we can provide a more
+  #    informative error if a mismatch is found.
+  original_graph_element = None
   for op_input in op_input_list:
-    if isinstance(op_input, (Operation, Tensor)):
-      if original_input is None:
-        original_input = op_input
-      else:
-        assert_same_graph([original_input, op_input])
-  if original_input is not None:
-    return original_input.graph
+    # Determine if this is a valid graph_element.
+    graph_element = None
+    if isinstance(op_input, (Operation, Tensor, SparseTensor, IndexedSlices)):
+      graph_element = op_input
+    else:
+      graph_element = _as_graph_element(op_input)

-  # 3. If all else fails, we use the default graph, which is always there.
-  return get_default_graph()
+    if graph_element:
+      if not graph:
+        original_graph_element = graph_element
+        graph = graph_element.graph
+      elif original_graph_element:
+        _assert_same_graph(original_graph_element, graph_element)
+      elif graph_element.graph is not graph:
+        raise ValueError(
+            "%s is not from the passed-in graph." % graph_element)

+  # 2. If all else fails, we use the default graph, which is always there.
+  return graph or get_default_graph()


 class GraphKeys(object):
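A hedged summary of the selection order implemented above, as a runnable sketch (it exercises the private helper, for illustration only):

import tensorflow as tf
from tensorflow.python.framework import ops

g = tf.Graph()
with g.as_default():
  t = tf.constant(1.0)

assert ops._get_graph_from_inputs([t]) is g           # inferred from an input
assert ops._get_graph_from_inputs([t], graph=g) is g  # explicit graph checked
assert ops._get_graph_from_inputs([]) is ops.get_default_graph()  # fallback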
@@ -3115,7 +3129,7 @@ def get_collection(key, scope=None):

 # pylint: disable=g-doc-return-or-yield
 @contextlib.contextmanager
-def op_scope(values, name, default_name):
+def op_scope(values, name, default_name=None):
   """Returns a context manager for use when defining a Python op.

   This context manager validates that the given `values` are from the
@@ -3140,10 +3154,17 @@ def op_scope(values, name, default_name=None):
     default_name: The default name to use if the `name` argument is `None`.

   Returns:
-    A context manager for use in defining a Python op.
+    A context manager for use in defining Python ops. Yields the name scope.
+
+  Raises:
+    ValueError: if neither `name` nor `default_name` is provided.
   """
   g = _get_graph_from_inputs(values)
   n = default_name if name is None else name
+  if n is None:
+    raise ValueError(
+        "At least one of name (%s) and default_name (%s) must be provided." % (
+            name, default_name))
   with g.as_default(), g.name_scope(n) as scope:
     yield scope
 # pylint: enable=g-doc-return-or-yield
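And the new failure mode, sketched as a hypothetical snippet: with default_name now optional, omitting both names raises ValueError:

import tensorflow as tf

a = tf.constant(1.0)
b = tf.constant(2.0)

try:
  with tf.op_scope([a, b], None):  # no name and no default_name
    pass
except ValueError as e:
  print(e)  # "At least one of name (None) and default_name (None) ..."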
@@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_kernel_label_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest


@@ -356,19 +357,19 @@ class NameTest(test_util.TensorFlowTestCase):
     self.assertEqual("my_op", op2.name)
     self.assertEqual("my_op:0", op2.outputs[0].name)

-  def testname_scope(self):
+  def testNameScope(self):
     g = ops.Graph()

     with g.name_scope("foo") as foo:
-      self.assertEqual(foo, "foo/")
+      self.assertEqual("foo/", foo)
       with g.name_scope("foo2") as foo2:
-        self.assertEqual(foo2, "foo/foo2/")
+        self.assertEqual("foo/foo2/", foo2)
       with g.name_scope(None) as empty1:
-        self.assertEqual(empty1, "")
+        self.assertEqual("", empty1)
         with g.name_scope("foo3") as foo3:
-          self.assertEqual(foo3, "foo3/")
+          self.assertEqual("foo3/", foo3)
       with g.name_scope("") as empty2:
-        self.assertEqual(empty2, "")
+        self.assertEqual("", empty2)

     self.assertEqual("const",
                      g.create_op("const", [], [dtypes.float32]).name)
@@ -792,6 +793,80 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
     self.assertEqual(b.op.control_inputs, [])


+class OpScopeTest(test_util.TensorFlowTestCase):
+
+  def testNoScopeName(self):
+    g0 = ops.Graph()
+    values = [
+        g0.create_op("a", [], [dtypes.float32]),
+        g0.create_op("b", [], [dtypes.float32])]
+    with self.assertRaises(ValueError):
+      with ops.op_scope(values, None):
+        pass
+    with self.assertRaises(ValueError):
+      with ops.op_scope(values, None, None):
+        pass
+
+  def testEmptyScopeName(self):
+    g0 = ops.Graph()
+    a = g0.create_op("a", [], [dtypes.float32])
+    b = g0.create_op("b", [], [dtypes.float32])
+    with ops.op_scope([a, b], "") as scope:
+      self.assertEqual("", scope)
+      self.assertEqual(g0, ops.get_default_graph())
+    with ops.op_scope([a, b], "", "my_default_scope") as scope:
+      self.assertEqual("", scope)
+      self.assertEqual(g0, ops.get_default_graph())
+
+  def testDefaultScopeName(self):
+    g0 = ops.Graph()
+    a = g0.create_op("a", [], [dtypes.float32])
+    b = g0.create_op("b", [], [dtypes.float32])
+    scope_name = "my_scope"
+    default_scope_name = "my_default_scope"
+    with ops.op_scope([a, b], scope_name, default_scope_name) as scope:
+      self.assertEqual("%s/" % scope_name, scope)
+      self.assertEqual(g0, ops.get_default_graph())
+    with ops.op_scope([a, b], None, default_scope_name) as scope:
+      self.assertEqual("%s/" % default_scope_name, scope)
+      self.assertEqual(g0, ops.get_default_graph())
+
+  def _testGraphElements(self, graph_elements):
+    scope_name = "my_scope"
+    with ops.op_scope(graph_elements, scope_name) as scope:
+      self.assertEqual("%s/" % scope_name, scope)
+      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
+    g1 = ops.Graph()
+    c = g1.create_op("c", [], [dtypes.float32])
+    with self.assertRaises(ValueError):
+      with ops.op_scope(graph_elements + [c], scope_name):
+        pass
+
+  def testTensor(self):
+    g0 = ops.Graph()
+    a = g0.create_op("a", [], [dtypes.float32])
+    b = g0.create_op("b", [], [dtypes.float32])
+    self._testGraphElements([a, b])
+
+  def testSparseTensor(self):
+    g0 = ops.Graph()
+    a = g0.create_op("a", [], [dtypes.float32])
+    b = g0.create_op("b", [], [dtypes.float32])
+    sparse = ops.SparseTensor(
+        _apply_op(g0, "const", [], [dtypes.int64]),
+        _apply_op(g0, "const", [], [dtypes.float32]),
+        _apply_op(g0, "const", [], [dtypes.int64]))
+    self._testGraphElements([a, sparse, b])
+
+  def testVariable(self):
+    g0 = ops.Graph()
+    with g0.as_default():
+      variable = variables.Variable([1.0])
+    a = g0.create_op("a", [], [dtypes.float32])
+    b = g0.create_op("b", [], [dtypes.float32])
+    self._testGraphElements([a, variable, b])
+
+
 class GraphTest(test_util.TensorFlowTestCase):

   def setUp(self):
@@ -835,27 +910,6 @@ class GraphTest(test_util.TensorFlowTestCase):
     with self.assertRaises(TypeError):
       g.as_graph_element(NonConvertibleObj())

-  def testAssertSameGraph(self):
-    g0 = ops.Graph()
-    a = g0.create_op("a", [], [dtypes.float32])
-    b = g0.create_op("b", [], [dtypes.float32])
-    ops.assert_same_graph([a, b])
-    ops.assert_same_graph([a, b], g0)
-    g1 = ops.Graph()
-    c = g1.create_op("c", [], [dtypes.float32])
-    self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
-    self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
-    self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
-
-    sparse = ops.SparseTensor(
-        _apply_op(g0, "const", [], [dtypes.int64]),
-        _apply_op(g0, "const", [], [dtypes.float32]),
-        _apply_op(g0, "const", [], [dtypes.int64]))
-    ops.assert_same_graph([sparse, a, b])
-    ops.assert_same_graph([sparse, a, b], g0)
-    self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
-    self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
-

 ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
@@ -616,6 +616,10 @@ class OpDefLibrary(object):
       elif attr_def.type == "list(tensor)":
         attr_value.list.tensor.extend(
             [_MakeTensor(x, key) for x in value])
+      elif attr_def.type == "func":
+        if not isinstance(value, compat.bytes_or_text_types):
+          raise TypeError("Expects a string for the func name")
+        attr_value.func.name = value
       else:
         raise TypeError("Unrecognized Attr type " + attr_def.type)
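To illustrate the new func attr handling (a minimal sketch; the function name is invented), the string supplied by the Python frontend ends up as the name field of the AttrValue.func message:

from tensorflow.core.framework import attr_value_pb2

# A "func" attr carries only the function's name, as a string.
attr_value = attr_value_pb2.AttrValue()
attr_value.func.name = "my_function_name"
print(attr_value)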