Export the utils functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA and MLIR are using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 264935549
This commit is contained in:
parent
9a631a1103
commit
da3f7b14ff
@ -20,6 +20,7 @@ visibility = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||||
@ -99,6 +100,7 @@ py_library(
|
|||||||
"//third_party/py/tensorflow_core:__subpackages__",
|
"//third_party/py/tensorflow_core:__subpackages__",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_utils",
|
||||||
":array_ops",
|
":array_ops",
|
||||||
":audio_ops_gen",
|
":audio_ops_gen",
|
||||||
":bitwise_ops",
|
":bitwise_ops",
|
||||||
@ -377,6 +379,22 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_utils",
|
||||||
|
srcs = ["util/util_wrapper.cc"],
|
||||||
|
hdrs = ["util/util.h"],
|
||||||
|
copts = [
|
||||||
|
"-fexceptions",
|
||||||
|
"-fno-strict-aliasing",
|
||||||
|
],
|
||||||
|
features = ["-use_header_modules"],
|
||||||
|
module_name = "_pywrap_utils",
|
||||||
|
deps = [
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cpp_python_util",
|
name = "cpp_python_util",
|
||||||
srcs = ["util/util.cc"],
|
srcs = ["util/util.cc"],
|
||||||
@ -685,6 +703,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_utils",
|
||||||
":common_shapes",
|
":common_shapes",
|
||||||
":composite_tensor",
|
":composite_tensor",
|
||||||
":convert_to_constants",
|
":convert_to_constants",
|
||||||
@ -4987,7 +5006,6 @@ tf_py_wrap_cc(
|
|||||||
"util/tfprof.i",
|
"util/tfprof.i",
|
||||||
"util/traceme.i",
|
"util/traceme.i",
|
||||||
"util/transform_graph.i",
|
"util/transform_graph.i",
|
||||||
"util/util.i",
|
|
||||||
"//tensorflow/lite/toco/python:toco.i",
|
"//tensorflow/lite/toco/python:toco.i",
|
||||||
],
|
],
|
||||||
# add win_def_file for pywrap_tensorflow
|
# add win_def_file for pywrap_tensorflow
|
||||||
@ -5056,14 +5074,20 @@ tf_py_wrap_cc(
|
|||||||
# the dynamic libraries of custom ops can find it at runtime.
|
# the dynamic libraries of custom ops can find it at runtime.
|
||||||
genrule(
|
genrule(
|
||||||
name = "pywrap_tensorflow_filtered_def_file",
|
name = "pywrap_tensorflow_filtered_def_file",
|
||||||
srcs = ["//tensorflow:tensorflow_def_file"],
|
srcs = [
|
||||||
|
"//tensorflow:tensorflow_def_file",
|
||||||
|
"//tensorflow/tools/def_file_filter:symbols_pybind",
|
||||||
|
":cpp_python_util",
|
||||||
|
],
|
||||||
outs = ["pywrap_tensorflow_filtered_def_file.def"],
|
outs = ["pywrap_tensorflow_filtered_def_file.def"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
"//tensorflow:windows": """
|
"//tensorflow:windows": """
|
||||||
$(location @local_config_def_file_filter//:def_file_filter) \\
|
$(location @local_config_def_file_filter//:def_file_filter) \\
|
||||||
--input $(location //tensorflow:tensorflow_def_file) \\
|
--input $(location //tensorflow:tensorflow_def_file) \\
|
||||||
--output $@ \\
|
--output $@ \\
|
||||||
--target _pywrap_tensorflow_internal.pyd
|
--target _pywrap_tensorflow_internal.pyd \\
|
||||||
|
--lib_paths $(execpath :cpp_python_util) \\
|
||||||
|
--symbols $(location //tensorflow/tools/def_file_filter:symbols_pybind)
|
||||||
""",
|
""",
|
||||||
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
|
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
|
||||||
}),
|
}),
|
||||||
|
@ -47,6 +47,7 @@ import traceback
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python import _pywrap_utils
|
||||||
|
|
||||||
# Protocol buffers
|
# Protocol buffers
|
||||||
from tensorflow.core.framework.graph_pb2 import *
|
from tensorflow.core.framework.graph_pb2 import *
|
||||||
|
@ -37,7 +37,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import six as _six
|
import six as _six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.framework import sparse_tensor as _sparse_tensor
|
from tensorflow.python.framework import sparse_tensor as _sparse_tensor
|
||||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||||
|
|
||||||
@ -95,10 +95,10 @@ def _yield_value(iterable):
|
|||||||
|
|
||||||
|
|
||||||
# See the swig file (../../util/util.i) for documentation.
|
# See the swig file (../../util/util.i) for documentation.
|
||||||
is_sequence = _pywrap_tensorflow.IsSequenceForData
|
is_sequence = _pywrap_utils.IsSequenceForData
|
||||||
|
|
||||||
# See the swig file (../../util/util.i) for documentation.
|
# See the swig file (../../util/util.i) for documentation.
|
||||||
flatten = _pywrap_tensorflow.FlattenForData
|
flatten = _pywrap_utils.FlattenForData
|
||||||
|
|
||||||
|
|
||||||
def assert_same_structure(nest1, nest2, check_types=True):
|
def assert_same_structure(nest1, nest2, check_types=True):
|
||||||
@ -120,7 +120,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
|
|||||||
TypeError: If the two structures differ in the type of sequence in any of
|
TypeError: If the two structures differ in the type of sequence in any of
|
||||||
their substructures. Only possible if `check_types` is `True`.
|
their substructures. Only possible if `check_types` is `True`.
|
||||||
"""
|
"""
|
||||||
_pywrap_tensorflow.AssertSameStructureForData(nest1, nest2, check_types)
|
_pywrap_utils.AssertSameStructureForData(nest1, nest2, check_types)
|
||||||
|
|
||||||
|
|
||||||
def _packed_nest_with_indices(structure, flat, index):
|
def _packed_nest_with_indices(structure, flat, index):
|
||||||
|
@ -24,6 +24,7 @@ import sys
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
@ -844,8 +845,7 @@ class GradientTape(object):
|
|||||||
ValueError: if it encounters something that is not a tensor.
|
ValueError: if it encounters something that is not a tensor.
|
||||||
"""
|
"""
|
||||||
for t in nest.flatten(tensor):
|
for t in nest.flatten(tensor):
|
||||||
if not (pywrap_tensorflow.IsTensor(t) or
|
if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
|
||||||
pywrap_tensorflow.IsVariable(t)):
|
|
||||||
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
|
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
|
||||||
type(t)))
|
type(t)))
|
||||||
if not t.dtype.is_floating:
|
if not t.dtype.is_floating:
|
||||||
|
@ -32,6 +32,7 @@ import six
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import function_pb2
|
from tensorflow.core.framework import function_pb2
|
||||||
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
@ -1416,8 +1417,8 @@ class ConcreteFunction(object):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
|
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||||
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
|
_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||||
|
|
||||||
|
|
||||||
def _deterministic_dict_values(dictionary):
|
def _deterministic_dict_values(dictionary):
|
||||||
@ -1698,7 +1699,7 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
|
|||||||
need_packing = False
|
need_packing = False
|
||||||
for index, (value, spec) in enumerate(zip(flatten_inputs,
|
for index, (value, spec) in enumerate(zip(flatten_inputs,
|
||||||
flat_input_signature)):
|
flat_input_signature)):
|
||||||
if not pywrap_tensorflow.IsTensor(value):
|
if not _pywrap_utils.IsTensor(value):
|
||||||
try:
|
try:
|
||||||
flatten_inputs[index] = ops.convert_to_tensor(
|
flatten_inputs[index] = ops.convert_to_tensor(
|
||||||
value, dtype_hint=spec.dtype)
|
value, dtype_hint=spec.dtype)
|
||||||
|
@ -22,7 +22,7 @@ import abc
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +137,7 @@ class CompositeTensor(object):
|
|||||||
return list(set(consumers))
|
return list(set(consumers))
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("CompositeTensor", CompositeTensor)
|
_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)
|
||||||
|
|
||||||
|
|
||||||
def replace_composites_with_components(structure):
|
def replace_composites_with_components(structure):
|
||||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -255,7 +255,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
SparseTensorValue = collections.namedtuple("SparseTensorValue",
|
SparseTensorValue = collections.namedtuple("SparseTensorValue",
|
||||||
["indices", "values", "dense_shape"])
|
["indices", "values", "dense_shape"])
|
||||||
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
|
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
|
||||||
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
|
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("SparseTensorSpec")
|
@tf_export("SparseTensorSpec")
|
||||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.framework import common_shapes
|
from tensorflow.python.framework import common_shapes
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -287,7 +287,7 @@ class BoundedTensorSpec(TensorSpec):
|
|||||||
return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
|
return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("TensorSpec", TensorSpec)
|
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)
|
||||||
|
|
||||||
|
|
||||||
# Note: we do not include Tensor names when constructing TypeSpecs.
|
# Note: we do not include Tensor names when constructing TypeSpecs.
|
||||||
|
@ -22,7 +22,7 @@ import abc
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -546,4 +546,4 @@ def register_type_spec_from_value_converter(type_object, converter_fn,
|
|||||||
(type_object, converter_fn, allow_subclass))
|
(type_object, converter_fn, allow_subclass))
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("TypeSpec", TypeSpec)
|
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)
|
||||||
|
@ -24,6 +24,7 @@ import functools
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import variable_pb2
|
from tensorflow.core.framework import variable_pb2
|
||||||
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
@ -1781,7 +1782,7 @@ class UninitializedVariable(BaseResourceVariable):
|
|||||||
synchronization=synchronization, aggregation=aggregation)
|
synchronization=synchronization, aggregation=aggregation)
|
||||||
|
|
||||||
|
|
||||||
pywrap_tensorflow.RegisterType("ResourceVariable", ResourceVariable)
|
_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
|
||||||
math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access
|
math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ import six
|
|||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import variable_pb2
|
from tensorflow.core.framework import variable_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -1351,7 +1351,7 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
|||||||
|
|
||||||
|
|
||||||
Variable._OverloadAllOperators() # pylint: disable=protected-access
|
Variable._OverloadAllOperators() # pylint: disable=protected-access
|
||||||
pywrap_tensorflow.RegisterType("Variable", Variable)
|
_pywrap_utils.RegisterType("Variable", Variable)
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["Variable"])
|
@tf_export(v1=["Variable"])
|
||||||
|
@ -172,6 +172,7 @@ limitations under the License.
|
|||||||
|
|
||||||
%{
|
%{
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
#include "tensorflow/python/util/util.h"
|
||||||
#include "tensorflow/c/c_api_experimental.h"
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
@ -49,8 +49,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/util/transform_graph.i"
|
%include "tensorflow/python/util/transform_graph.i"
|
||||||
|
|
||||||
%include "tensorflow/python/util/util.i"
|
|
||||||
|
|
||||||
%include "tensorflow/python/grappler/cluster.i"
|
%include "tensorflow/python/grappler/cluster.i"
|
||||||
%include "tensorflow/python/grappler/item.i"
|
%include "tensorflow/python/grappler/item.i"
|
||||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||||
|
@ -38,7 +38,7 @@ import collections as _collections
|
|||||||
|
|
||||||
import six as _six
|
import six as _six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||||
|
|
||||||
@ -104,15 +104,15 @@ def _is_namedtuple(instance, strict=False):
|
|||||||
Returns:
|
Returns:
|
||||||
True if `instance` is a `namedtuple`.
|
True if `instance` is a `namedtuple`.
|
||||||
"""
|
"""
|
||||||
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
|
return _pywrap_utils.IsNamedtuple(instance, strict)
|
||||||
|
|
||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
_is_mapping = _pywrap_tensorflow.IsMapping
|
_is_mapping = _pywrap_utils.IsMapping
|
||||||
_is_mapping_view = _pywrap_tensorflow.IsMappingView
|
_is_mapping_view = _pywrap_utils.IsMappingView
|
||||||
_is_attrs = _pywrap_tensorflow.IsAttrs
|
_is_attrs = _pywrap_utils.IsAttrs
|
||||||
_is_composite_tensor = _pywrap_tensorflow.IsCompositeTensor
|
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
|
||||||
_is_type_spec = _pywrap_tensorflow.IsTypeSpec
|
_is_type_spec = _pywrap_utils.IsTypeSpec
|
||||||
|
|
||||||
|
|
||||||
def _sequence_like(instance, args):
|
def _sequence_like(instance, args):
|
||||||
@ -208,11 +208,11 @@ def _yield_sorted_items(iterable):
|
|||||||
|
|
||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
is_sequence = _pywrap_tensorflow.IsSequence
|
is_sequence = _pywrap_utils.IsSequence
|
||||||
|
|
||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
is_sequence_or_composite = _pywrap_tensorflow.IsSequenceOrComposite
|
is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite
|
||||||
|
|
||||||
|
|
||||||
@tf_export("nest.is_nested")
|
@tf_export("nest.is_nested")
|
||||||
@ -260,11 +260,11 @@ def flatten(structure, expand_composites=False):
|
|||||||
Raises:
|
Raises:
|
||||||
TypeError: The nest is or contains a dict with non-sortable keys.
|
TypeError: The nest is or contains a dict with non-sortable keys.
|
||||||
"""
|
"""
|
||||||
return _pywrap_tensorflow.Flatten(structure, expand_composites)
|
return _pywrap_utils.Flatten(structure, expand_composites)
|
||||||
|
|
||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
|
_same_namedtuples = _pywrap_utils.SameNamedtuples
|
||||||
|
|
||||||
|
|
||||||
class _DotString(object):
|
class _DotString(object):
|
||||||
@ -315,8 +315,8 @@ def assert_same_structure(nest1, nest2, check_types=True,
|
|||||||
their substructures. Only possible if `check_types` is `True`.
|
their substructures. Only possible if `check_types` is `True`.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
_pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types,
|
_pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
|
||||||
expand_composites)
|
expand_composites)
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
str1 = str(map_structure(lambda _: _DOT, nest1))
|
str1 = str(map_structure(lambda _: _DOT, nest1))
|
||||||
str2 = str(map_structure(lambda _: _DOT, nest2))
|
str2 = str(map_structure(lambda _: _DOT, nest2))
|
||||||
@ -1327,6 +1327,6 @@ def flatten_with_tuple_paths(structure, expand_composites=False):
|
|||||||
flatten(structure, expand_composites=expand_composites)))
|
flatten(structure, expand_composites=expand_composites)))
|
||||||
|
|
||||||
|
|
||||||
_pywrap_tensorflow.RegisterType("Mapping", _collections_abc.Mapping)
|
_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
|
||||||
_pywrap_tensorflow.RegisterType("Sequence", _collections_abc.Sequence)
|
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
|
||||||
_pywrap_tensorflow.RegisterType("MappingView", _collections_abc.MappingView)
|
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
|
||||||
|
@ -1,212 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/python/util/util.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
|
|
||||||
%unignore tensorflow;
|
|
||||||
%unignore tensorflow::swig;
|
|
||||||
// The %exception block defined in tf_session.i releases the Python GIL for
|
|
||||||
// the length of each wrapped method. This file is included in tensorflow.i
|
|
||||||
// after tf_session.i and inherits this definition. We disable this behavior
|
|
||||||
// for functions in this module because they use python methods that need GIL.
|
|
||||||
// TODO(iga): Find a way not to leak such definitions across files.
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::RegisterType;
|
|
||||||
%noexception tensorflow::swig::RegisterType;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::IsTensor;
|
|
||||||
%noexception tensorflow::swig::IsTensor;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::IsResourceVariable;
|
|
||||||
%noexception tensorflow::swig::IsResourceVariable;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::IsVariable;
|
|
||||||
%noexception tensorflow::swig::IsVariable;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsSequence
|
|
||||||
"""Returns true if its input is a collections.Sequence (except strings).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq: an input sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the sequence is a not a string and is a collections.Sequence or a
|
|
||||||
dict.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsSequence;
|
|
||||||
%noexception tensorflow::swig::IsSequence;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsSequenceOrComposite
|
|
||||||
"""Returns true if its input is a sequence or a `CompositeTensor`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq: an input sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the sequence is a not a string and is a collections.Sequence or a
|
|
||||||
dict or a CompositeTensor or a TypeSpec (except string and TensorSpec).
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsSequenceOrComposite;
|
|
||||||
%noexception tensorflow::swig::IsSequenceOrComposite;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsCompositeTensor
|
|
||||||
"""Returns true if its input is a `CompositeTensor`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq: an input sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the sequence is a CompositeTensor.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsCompositeTensor;
|
|
||||||
%noexception tensorflow::swig::IsCompositeTensor;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsTypeSpec
|
|
||||||
"""Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq: an input sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the sequence is a `TypeSpec`, but is not a `TensorSpec`.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsTypeSpec;
|
|
||||||
%noexception tensorflow::swig::IsTypeSpec;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::IsNamedtuple;
|
|
||||||
%noexception tensorflow::swig::IsNamedtuple;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsMapping
|
|
||||||
"""Returns True iff `instance` is a `collections.Mapping`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instance: An instance of a Python object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if `instance` is a `collections.Mapping`.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsMapping;
|
|
||||||
%noexception tensorflow::swig::IsMapping;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsMappingView
|
|
||||||
"""Returns True iff `instance` is a `collections.MappingView`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instance: An instance of a Python object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if `instance` is a `collections.MappingView`.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsMappingView;
|
|
||||||
%noexception tensorflow::swig::IsMappingView;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsAttrs
|
|
||||||
"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
instance: An instance of a Python object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if `instance` is an instance of an `attr.s` decorated class.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsAttrs;
|
|
||||||
%noexception tensorflow::swig::IsAttrs;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::SameNamedtuples
|
|
||||||
"Returns True if the two namedtuples have the same name and fields."
|
|
||||||
%unignore tensorflow::swig::SameNamedtuples;
|
|
||||||
%noexception tensorflow::swig::SameNamedtuples;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::AssertSameStructure;
|
|
||||||
%noexception tensorflow::swig::AssertSameStructure;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::Flatten
|
|
||||||
"""Returns a flat list from a given nested structure.
|
|
||||||
|
|
||||||
If `nest` is not a sequence, tuple, or dict, then returns a single-element
|
|
||||||
list: `[nest]`.
|
|
||||||
|
|
||||||
In the case of dict instances, the sequence consists of the values, sorted by
|
|
||||||
key to ensure deterministic behavior. This is true also for `OrderedDict`
|
|
||||||
instances: their sequence order is ignored, the sorting order of keys is
|
|
||||||
used instead. The same convention is followed in `pack_sequence_as`. This
|
|
||||||
correctly repacks dicts and `OrderedDict`s after they have been flattened,
|
|
||||||
and also allows flattening an `OrderedDict` and then repacking it back using
|
|
||||||
a corresponding plain dict, or vice-versa.
|
|
||||||
Dictionaries with non-sortable keys cannot be flattened.
|
|
||||||
|
|
||||||
Users must not modify any collections used in `nest` while this function is
|
|
||||||
running.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nest: an arbitrarily nested structure or a scalar object. Note, numpy
|
|
||||||
arrays are considered scalars.
|
|
||||||
expand_composites: If true, then composite tensors such as `tf.SparseTensor`
|
|
||||||
and `tf.RaggedTensor` are expanded into their component tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Python list, the flattened version of the input.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: The nest is or contains a dict with non-sortable keys.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::Flatten;
|
|
||||||
%noexception tensorflow::swig::Flatten;
|
|
||||||
%feature("kwargs") tensorflow::swig::Flatten;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::IsSequenceForData
|
|
||||||
"""Returns a true if `seq` is a Sequence or dict (except strings/lists).
|
|
||||||
|
|
||||||
NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
|
|
||||||
which *does* treat a Python list as a sequence. For ergonomic
|
|
||||||
reasons, `tf.data` users would prefer to treat lists as
|
|
||||||
implicit `tf.Tensor` objects, and dicts as (nested) sequences.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq: an input sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the sequence is a not a string or list and is a
|
|
||||||
collections.Sequence.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::IsSequenceForData;
|
|
||||||
%noexception tensorflow::swig::IsSequenceForData;
|
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::FlattenForData
|
|
||||||
"""Returns a flat sequence from a given nested structure.
|
|
||||||
|
|
||||||
If `nest` is not a sequence, this returns a single-element list: `[nest]`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nest: an arbitrarily nested structure or a scalar object.
|
|
||||||
Note, numpy arrays are considered scalars.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Python list, the flattened version of the input.
|
|
||||||
"""
|
|
||||||
%unignore tensorflow::swig::FlattenForData;
|
|
||||||
%noexception tensorflow::swig::FlattenForData;
|
|
||||||
|
|
||||||
%unignore tensorflow::swig::AssertSameStructureForData;
|
|
||||||
%noexception tensorflow::swig::AssertSameStructureForData;
|
|
||||||
|
|
||||||
%include "tensorflow/python/util/util.h"
|
|
||||||
|
|
||||||
%unignoreall
|
|
333
tensorflow/python/util/util_wrapper.cc
Normal file
333
tensorflow/python/util/util_wrapper.cc
Normal file
@ -0,0 +1,333 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "include/pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/python/util/util.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
inline py::object pyo_or_throw(PyObject* ptr) {
|
||||||
|
if (PyErr_Occurred() || ptr == nullptr) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return py::reinterpret_steal<py::object>(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_utils, m) {
|
||||||
|
m.doc() = R"pbdoc(
|
||||||
|
_pywrap_utils
|
||||||
|
-----
|
||||||
|
)pbdoc";
|
||||||
|
m.def("RegisterType",
|
||||||
|
[](const py::handle& type_name, const py::handle& type) {
|
||||||
|
return pyo_or_throw(
|
||||||
|
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
|
||||||
|
});
|
||||||
|
m.def(
|
||||||
|
"IsTensor",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsTensor(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Check if an object is a Tensor.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsSequence",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsSequence(o.ptr());
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns true if its input is a collections.Sequence (except strings).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: an input sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sequence is a not a string and is a collections.Sequence or a
|
||||||
|
dict.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsSequenceOrComposite",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsSequenceOrComposite(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns true if its input is a sequence or a `CompositeTensor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: an input sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sequence is a not a string and is a collections.Sequence or a
|
||||||
|
dict or a CompositeTensor or a TypeSpec (except string and TensorSpec).
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsCompositeTensor",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsCompositeTensor(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns true if its input is a `CompositeTensor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: an input sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sequence is a CompositeTensor.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsTypeSpec",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsTypeSpec(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: an input sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sequence is a `TypeSpec`, but is not a `TensorSpec`.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsNamedtuple",
|
||||||
|
[](const py::handle& o, bool strict) {
|
||||||
|
return pyo_or_throw(tensorflow::swig::IsNamedtuple(o.ptr(), strict));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Check if an object is a NamedTuple.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsMapping",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsMapping(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if `instance` is a `collections.Mapping`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `instance` is a `collections.Mapping`.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsMappingView",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsMappingView(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if considered a mapping view for the purposes of Flatten()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if considered a mapping view for the purposes of Flatten().
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsAttrs",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsAttrs(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if `instance` is an instance of an `attr.s` decorated class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `instance` is an instance of an `attr.s` decorated class.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"SameNamedtuples",
|
||||||
|
[](const py::handle& o1, const py::handle& o2) {
|
||||||
|
return pyo_or_throw(
|
||||||
|
tensorflow::swig::SameNamedtuples(o1.ptr(), o2.ptr()));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if the two namedtuples have the same name and fields.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"AssertSameStructure",
|
||||||
|
[](const py::handle& o1, const py::handle& o2, bool check_types,
|
||||||
|
bool expand_composites) {
|
||||||
|
bool result = tensorflow::swig::AssertSameStructure(
|
||||||
|
o1.ptr(), o2.ptr(), check_types, expand_composites);
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if the two structures are nested in the same way.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"Flatten",
|
||||||
|
[](const py::handle& o, bool expand_composites) {
|
||||||
|
return pyo_or_throw(
|
||||||
|
tensorflow::swig::Flatten(o.ptr(), expand_composites));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a flat list from a given nested structure.
|
||||||
|
|
||||||
|
If `nest` is not a sequence, tuple, or dict, then returns a single-element
|
||||||
|
list: `[nest]`.
|
||||||
|
|
||||||
|
In the case of dict instances, the sequence consists of the values, sorted by
|
||||||
|
key to ensure deterministic behavior. This is true also for `OrderedDict`
|
||||||
|
instances: their sequence order is ignored, the sorting order of keys is
|
||||||
|
used instead. The same convention is followed in `pack_sequence_as`. This
|
||||||
|
correctly repacks dicts and `OrderedDict`s after they have been flattened,
|
||||||
|
and also allows flattening an `OrderedDict` and then repacking it back using
|
||||||
|
a corresponding plain dict, or vice-versa.
|
||||||
|
Dictionaries with non-sortable keys cannot be flattened.
|
||||||
|
|
||||||
|
Users must not modify any collections used in `nest` while this function is
|
||||||
|
running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nest: an arbitrarily nested structure or a scalar object. Note, numpy
|
||||||
|
arrays are considered scalars.
|
||||||
|
expand_composites: If true, then composite tensors such as `tf.SparseTensor`
|
||||||
|
and `tf.RaggedTensor` are expanded into their component tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Python list, the flattened version of the input.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: The nest is or contains a dict with non-sortable keys.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsSequenceForData",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsSequenceForData(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a true if `seq` is a Sequence or dict (except strings/lists).
|
||||||
|
|
||||||
|
NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
|
||||||
|
which *does* treat a Python list as a sequence. For ergonomic
|
||||||
|
reasons, `tf.data` users would prefer to treat lists as
|
||||||
|
implicit `tf.Tensor` objects, and dicts as (nested) sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: an input sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sequence is a not a string or list and is a
|
||||||
|
collections.Sequence.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"FlattenForData",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
return pyo_or_throw(tensorflow::swig::FlattenForData(o.ptr()));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns a flat sequence from a given nested structure.
|
||||||
|
|
||||||
|
If `nest` is not a sequence, this returns a single-element list: `[nest]`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nest: an arbitrarily nested structure or a scalar object.
|
||||||
|
Note, numpy arrays are considered scalars.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Python list, the flattened version of the input.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"AssertSameStructureForData",
|
||||||
|
[](const py::handle& o1, const py::handle& o2, bool check_types) {
|
||||||
|
bool result = tensorflow::swig::AssertSameStructureForData(
|
||||||
|
o1.ptr(), o2.ptr(), check_types);
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns True if the two structures are nested in the same way in particular tf.data.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsResourceVariable",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsResourceVariable(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns 1 if `o` is a ResourceVariable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `instance` is a `ResourceVariable`.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"IsVariable",
|
||||||
|
[](const py::handle& o) {
|
||||||
|
bool result = tensorflow::swig::IsVariable(o.ptr());
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
Returns 1 if `o` is a Variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `instance` is a `Variable`.
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -1 +1,19 @@
|
|||||||
|
[cpp_python_util]
|
||||||
|
tensorflow::swig::IsSequence
|
||||||
|
tensorflow::swig::IsSequenceOrComposite
|
||||||
|
tensorflow::swig::IsCompositeTensor
|
||||||
|
tensorflow::swig::IsTypeSpec
|
||||||
|
tensorflow::swig::IsNamedtuple
|
||||||
|
tensorflow::swig::IsMapping
|
||||||
|
tensorflow::swig::IsMappingView
|
||||||
|
tensorflow::swig::IsAttrs
|
||||||
|
tensorflow::swig::IsTensor
|
||||||
|
tensorflow::swig::IsResourceVariable
|
||||||
|
tensorflow::swig::IsVariable
|
||||||
|
tensorflow::swig::SameNamedtuples
|
||||||
|
tensorflow::swig::AssertSameStructure
|
||||||
|
tensorflow::swig::Flatten
|
||||||
|
tensorflow::swig::IsSequenceForData
|
||||||
|
tensorflow::swig::FlattenForData
|
||||||
|
tensorflow::swig::AssertSameStructureForData
|
||||||
|
tensorflow::swig::RegisterType
|
||||||
|
Loading…
x
Reference in New Issue
Block a user