Export the utils functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA and MLIR are using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 264935549
This commit is contained in:
parent
9a631a1103
commit
da3f7b14ff
@ -20,6 +20,7 @@ visibility = [
|
||||
]
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
@ -99,6 +100,7 @@ py_library(
|
||||
"//third_party/py/tensorflow_core:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":_pywrap_utils",
|
||||
":array_ops",
|
||||
":audio_ops_gen",
|
||||
":bitwise_ops",
|
||||
@ -377,6 +379,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_utils",
|
||||
srcs = ["util/util_wrapper.cc"],
|
||||
hdrs = ["util/util.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_utils",
|
||||
deps = [
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpp_python_util",
|
||||
srcs = ["util/util.cc"],
|
||||
@ -685,6 +703,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_utils",
|
||||
":common_shapes",
|
||||
":composite_tensor",
|
||||
":convert_to_constants",
|
||||
@ -4987,7 +5006,6 @@ tf_py_wrap_cc(
|
||||
"util/tfprof.i",
|
||||
"util/traceme.i",
|
||||
"util/transform_graph.i",
|
||||
"util/util.i",
|
||||
"//tensorflow/lite/toco/python:toco.i",
|
||||
],
|
||||
# add win_def_file for pywrap_tensorflow
|
||||
@ -5056,14 +5074,20 @@ tf_py_wrap_cc(
|
||||
# the dynamic libraries of custom ops can find it at runtime.
|
||||
genrule(
|
||||
name = "pywrap_tensorflow_filtered_def_file",
|
||||
srcs = ["//tensorflow:tensorflow_def_file"],
|
||||
srcs = [
|
||||
"//tensorflow:tensorflow_def_file",
|
||||
"//tensorflow/tools/def_file_filter:symbols_pybind",
|
||||
":cpp_python_util",
|
||||
],
|
||||
outs = ["pywrap_tensorflow_filtered_def_file.def"],
|
||||
cmd = select({
|
||||
"//tensorflow:windows": """
|
||||
$(location @local_config_def_file_filter//:def_file_filter) \\
|
||||
--input $(location //tensorflow:tensorflow_def_file) \\
|
||||
--output $@ \\
|
||||
--target _pywrap_tensorflow_internal.pyd
|
||||
--target _pywrap_tensorflow_internal.pyd \\
|
||||
--lib_paths $(execpath :cpp_python_util) \\
|
||||
--symbols $(location //tensorflow/tools/def_file_filter:symbols_pybind)
|
||||
""",
|
||||
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
|
||||
}),
|
||||
|
@ -47,6 +47,7 @@ import traceback
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
|
||||
# Protocol buffers
|
||||
from tensorflow.core.framework.graph_pb2 import *
|
||||
|
@ -37,7 +37,7 @@ from __future__ import print_function
|
||||
|
||||
import six as _six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.framework import sparse_tensor as _sparse_tensor
|
||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||
|
||||
@ -95,10 +95,10 @@ def _yield_value(iterable):
|
||||
|
||||
|
||||
# See the swig file (../../util/util.i) for documentation.
|
||||
is_sequence = _pywrap_tensorflow.IsSequenceForData
|
||||
is_sequence = _pywrap_utils.IsSequenceForData
|
||||
|
||||
# See the swig file (../../util/util.i) for documentation.
|
||||
flatten = _pywrap_tensorflow.FlattenForData
|
||||
flatten = _pywrap_utils.FlattenForData
|
||||
|
||||
|
||||
def assert_same_structure(nest1, nest2, check_types=True):
|
||||
@ -120,7 +120,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
|
||||
TypeError: If the two structures differ in the type of sequence in any of
|
||||
their substructures. Only possible if `check_types` is `True`.
|
||||
"""
|
||||
_pywrap_tensorflow.AssertSameStructureForData(nest1, nest2, check_types)
|
||||
_pywrap_utils.AssertSameStructureForData(nest1, nest2, check_types)
|
||||
|
||||
|
||||
def _packed_nest_with_indices(structure, flat, index):
|
||||
|
@ -24,6 +24,7 @@ import sys
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
@ -844,8 +845,7 @@ class GradientTape(object):
|
||||
ValueError: if it encounters something that is not a tensor.
|
||||
"""
|
||||
for t in nest.flatten(tensor):
|
||||
if not (pywrap_tensorflow.IsTensor(t) or
|
||||
pywrap_tensorflow.IsVariable(t)):
|
||||
if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
|
||||
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
|
||||
type(t)))
|
||||
if not t.dtype.is_floating:
|
||||
|
@ -32,6 +32,7 @@ import six
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
@ -1416,8 +1417,8 @@ class ConcreteFunction(object):
|
||||
return ret
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
|
||||
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||
_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
|
||||
|
||||
|
||||
def _deterministic_dict_values(dictionary):
|
||||
@ -1698,7 +1699,7 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
|
||||
need_packing = False
|
||||
for index, (value, spec) in enumerate(zip(flatten_inputs,
|
||||
flat_input_signature)):
|
||||
if not pywrap_tensorflow.IsTensor(value):
|
||||
if not _pywrap_utils.IsTensor(value):
|
||||
try:
|
||||
flatten_inputs[index] = ops.convert_to_tensor(
|
||||
value, dtype_hint=spec.dtype)
|
||||
|
@ -22,7 +22,7 @@ import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -137,7 +137,7 @@ class CompositeTensor(object):
|
||||
return list(set(consumers))
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("CompositeTensor", CompositeTensor)
|
||||
_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)
|
||||
|
||||
|
||||
def replace_composites_with_components(structure):
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -255,7 +255,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
||||
SparseTensorValue = collections.namedtuple("SparseTensorValue",
|
||||
["indices", "values", "dense_shape"])
|
||||
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
|
||||
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
|
||||
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
|
||||
|
||||
|
||||
@tf_export("SparseTensorSpec")
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -287,7 +287,7 @@ class BoundedTensorSpec(TensorSpec):
|
||||
return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("TensorSpec", TensorSpec)
|
||||
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)
|
||||
|
||||
|
||||
# Note: we do not include Tensor names when constructing TypeSpecs.
|
||||
|
@ -22,7 +22,7 @@ import abc
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -546,4 +546,4 @@ def register_type_spec_from_value_converter(type_object, converter_fn,
|
||||
(type_object, converter_fn, allow_subclass))
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("TypeSpec", TypeSpec)
|
||||
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)
|
||||
|
@ -24,6 +24,7 @@ import functools
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape
|
||||
@ -1781,7 +1782,7 @@ class UninitializedVariable(BaseResourceVariable):
|
||||
synchronization=synchronization, aggregation=aggregation)
|
||||
|
||||
|
||||
pywrap_tensorflow.RegisterType("ResourceVariable", ResourceVariable)
|
||||
_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable)
|
||||
math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ import six
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -1351,7 +1351,7 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||
|
||||
|
||||
Variable._OverloadAllOperators() # pylint: disable=protected-access
|
||||
pywrap_tensorflow.RegisterType("Variable", Variable)
|
||||
_pywrap_utils.RegisterType("Variable", Variable)
|
||||
|
||||
|
||||
@tf_export(v1=["Variable"])
|
||||
|
@ -172,6 +172,7 @@ limitations under the License.
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
@ -49,8 +49,6 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/util/transform_graph.i"
|
||||
|
||||
%include "tensorflow/python/util/util.i"
|
||||
|
||||
%include "tensorflow/python/grappler/cluster.i"
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
|
@ -38,7 +38,7 @@ import collections as _collections
|
||||
|
||||
import six as _six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||
|
||||
@ -104,15 +104,15 @@ def _is_namedtuple(instance, strict=False):
|
||||
Returns:
|
||||
True if `instance` is a `namedtuple`.
|
||||
"""
|
||||
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
|
||||
return _pywrap_utils.IsNamedtuple(instance, strict)
|
||||
|
||||
|
||||
# See the swig file (util.i) for documentation.
|
||||
_is_mapping = _pywrap_tensorflow.IsMapping
|
||||
_is_mapping_view = _pywrap_tensorflow.IsMappingView
|
||||
_is_attrs = _pywrap_tensorflow.IsAttrs
|
||||
_is_composite_tensor = _pywrap_tensorflow.IsCompositeTensor
|
||||
_is_type_spec = _pywrap_tensorflow.IsTypeSpec
|
||||
_is_mapping = _pywrap_utils.IsMapping
|
||||
_is_mapping_view = _pywrap_utils.IsMappingView
|
||||
_is_attrs = _pywrap_utils.IsAttrs
|
||||
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
|
||||
_is_type_spec = _pywrap_utils.IsTypeSpec
|
||||
|
||||
|
||||
def _sequence_like(instance, args):
|
||||
@ -208,11 +208,11 @@ def _yield_sorted_items(iterable):
|
||||
|
||||
|
||||
# See the swig file (util.i) for documentation.
|
||||
is_sequence = _pywrap_tensorflow.IsSequence
|
||||
is_sequence = _pywrap_utils.IsSequence
|
||||
|
||||
|
||||
# See the swig file (util.i) for documentation.
|
||||
is_sequence_or_composite = _pywrap_tensorflow.IsSequenceOrComposite
|
||||
is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite
|
||||
|
||||
|
||||
@tf_export("nest.is_nested")
|
||||
@ -260,11 +260,11 @@ def flatten(structure, expand_composites=False):
|
||||
Raises:
|
||||
TypeError: The nest is or contains a dict with non-sortable keys.
|
||||
"""
|
||||
return _pywrap_tensorflow.Flatten(structure, expand_composites)
|
||||
return _pywrap_utils.Flatten(structure, expand_composites)
|
||||
|
||||
|
||||
# See the swig file (util.i) for documentation.
|
||||
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
|
||||
_same_namedtuples = _pywrap_utils.SameNamedtuples
|
||||
|
||||
|
||||
class _DotString(object):
|
||||
@ -315,8 +315,8 @@ def assert_same_structure(nest1, nest2, check_types=True,
|
||||
their substructures. Only possible if `check_types` is `True`.
|
||||
"""
|
||||
try:
|
||||
_pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types,
|
||||
expand_composites)
|
||||
_pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
|
||||
expand_composites)
|
||||
except (ValueError, TypeError) as e:
|
||||
str1 = str(map_structure(lambda _: _DOT, nest1))
|
||||
str2 = str(map_structure(lambda _: _DOT, nest2))
|
||||
@ -1327,6 +1327,6 @@ def flatten_with_tuple_paths(structure, expand_composites=False):
|
||||
flatten(structure, expand_composites=expand_composites)))
|
||||
|
||||
|
||||
_pywrap_tensorflow.RegisterType("Mapping", _collections_abc.Mapping)
|
||||
_pywrap_tensorflow.RegisterType("Sequence", _collections_abc.Sequence)
|
||||
_pywrap_tensorflow.RegisterType("MappingView", _collections_abc.MappingView)
|
||||
_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
|
||||
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
|
||||
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
|
||||
|
@ -1,212 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/python/util/util.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::swig;
|
||||
// The %exception block defined in tf_session.i releases the Python GIL for
|
||||
// the length of each wrapped method. This file is included in tensorflow.i
|
||||
// after tf_session.i and inherits this definition. We disable this behavior
|
||||
// for functions in this module because they use python methods that need GIL.
|
||||
// TODO(iga): Find a way not to leak such definitions across files.
|
||||
|
||||
%unignore tensorflow::swig::RegisterType;
|
||||
%noexception tensorflow::swig::RegisterType;
|
||||
|
||||
%unignore tensorflow::swig::IsTensor;
|
||||
%noexception tensorflow::swig::IsTensor;
|
||||
|
||||
%unignore tensorflow::swig::IsResourceVariable;
|
||||
%noexception tensorflow::swig::IsResourceVariable;
|
||||
|
||||
%unignore tensorflow::swig::IsVariable;
|
||||
%noexception tensorflow::swig::IsVariable;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsSequence
|
||||
"""Returns true if its input is a collections.Sequence (except strings).
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string and is a collections.Sequence or a
|
||||
dict.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsSequence;
|
||||
%noexception tensorflow::swig::IsSequence;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsSequenceOrComposite
|
||||
"""Returns true if its input is a sequence or a `CompositeTensor`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string and is a collections.Sequence or a
|
||||
dict or a CompositeTensor or a TypeSpec (except string and TensorSpec).
|
||||
"""
|
||||
%unignore tensorflow::swig::IsSequenceOrComposite;
|
||||
%noexception tensorflow::swig::IsSequenceOrComposite;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsCompositeTensor
|
||||
"""Returns true if its input is a `CompositeTensor`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a CompositeTensor.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsCompositeTensor;
|
||||
%noexception tensorflow::swig::IsCompositeTensor;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsTypeSpec
|
||||
"""Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a `TypeSpec`, but is not a `TensorSpec`.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsTypeSpec;
|
||||
%noexception tensorflow::swig::IsTypeSpec;
|
||||
|
||||
%unignore tensorflow::swig::IsNamedtuple;
|
||||
%noexception tensorflow::swig::IsNamedtuple;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsMapping
|
||||
"""Returns True iff `instance` is a `collections.Mapping`.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is a `collections.Mapping`.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsMapping;
|
||||
%noexception tensorflow::swig::IsMapping;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsMappingView
|
||||
"""Returns True iff `instance` is a `collections.MappingView`.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is a `collections.MappingView`.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsMappingView;
|
||||
%noexception tensorflow::swig::IsMappingView;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsAttrs
|
||||
"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is an instance of an `attr.s` decorated class.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsAttrs;
|
||||
%noexception tensorflow::swig::IsAttrs;
|
||||
|
||||
%feature("docstring") tensorflow::swig::SameNamedtuples
|
||||
"Returns True if the two namedtuples have the same name and fields."
|
||||
%unignore tensorflow::swig::SameNamedtuples;
|
||||
%noexception tensorflow::swig::SameNamedtuples;
|
||||
|
||||
%unignore tensorflow::swig::AssertSameStructure;
|
||||
%noexception tensorflow::swig::AssertSameStructure;
|
||||
|
||||
%feature("docstring") tensorflow::swig::Flatten
|
||||
"""Returns a flat list from a given nested structure.
|
||||
|
||||
If `nest` is not a sequence, tuple, or dict, then returns a single-element
|
||||
list: `[nest]`.
|
||||
|
||||
In the case of dict instances, the sequence consists of the values, sorted by
|
||||
key to ensure deterministic behavior. This is true also for `OrderedDict`
|
||||
instances: their sequence order is ignored, the sorting order of keys is
|
||||
used instead. The same convention is followed in `pack_sequence_as`. This
|
||||
correctly repacks dicts and `OrderedDict`s after they have been flattened,
|
||||
and also allows flattening an `OrderedDict` and then repacking it back using
|
||||
a corresponding plain dict, or vice-versa.
|
||||
Dictionaries with non-sortable keys cannot be flattened.
|
||||
|
||||
Users must not modify any collections used in `nest` while this function is
|
||||
running.
|
||||
|
||||
Args:
|
||||
nest: an arbitrarily nested structure or a scalar object. Note, numpy
|
||||
arrays are considered scalars.
|
||||
expand_composites: If true, then composite tensors such as `tf.SparseTensor`
|
||||
and `tf.RaggedTensor` are expanded into their component tensors.
|
||||
|
||||
Returns:
|
||||
A Python list, the flattened version of the input.
|
||||
|
||||
Raises:
|
||||
TypeError: The nest is or contains a dict with non-sortable keys.
|
||||
"""
|
||||
%unignore tensorflow::swig::Flatten;
|
||||
%noexception tensorflow::swig::Flatten;
|
||||
%feature("kwargs") tensorflow::swig::Flatten;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsSequenceForData
|
||||
"""Returns a true if `seq` is a Sequence or dict (except strings/lists).
|
||||
|
||||
NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
|
||||
which *does* treat a Python list as a sequence. For ergonomic
|
||||
reasons, `tf.data` users would prefer to treat lists as
|
||||
implicit `tf.Tensor` objects, and dicts as (nested) sequences.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string or list and is a
|
||||
collections.Sequence.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsSequenceForData;
|
||||
%noexception tensorflow::swig::IsSequenceForData;
|
||||
|
||||
%feature("docstring") tensorflow::swig::FlattenForData
|
||||
"""Returns a flat sequence from a given nested structure.
|
||||
|
||||
If `nest` is not a sequence, this returns a single-element list: `[nest]`.
|
||||
|
||||
Args:
|
||||
nest: an arbitrarily nested structure or a scalar object.
|
||||
Note, numpy arrays are considered scalars.
|
||||
|
||||
Returns:
|
||||
A Python list, the flattened version of the input.
|
||||
"""
|
||||
%unignore tensorflow::swig::FlattenForData;
|
||||
%noexception tensorflow::swig::FlattenForData;
|
||||
|
||||
%unignore tensorflow::swig::AssertSameStructureForData;
|
||||
%noexception tensorflow::swig::AssertSameStructureForData;
|
||||
|
||||
%include "tensorflow/python/util/util.h"
|
||||
|
||||
%unignoreall
|
333
tensorflow/python/util/util_wrapper.cc
Normal file
333
tensorflow/python/util/util_wrapper.cc
Normal file
@ -0,0 +1,333 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/pytypes.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
inline py::object pyo_or_throw(PyObject* ptr) {
|
||||
if (PyErr_Occurred() || ptr == nullptr) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return py::reinterpret_steal<py::object>(ptr);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pywrap_utils, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_utils
|
||||
-----
|
||||
)pbdoc";
|
||||
m.def("RegisterType",
|
||||
[](const py::handle& type_name, const py::handle& type) {
|
||||
return pyo_or_throw(
|
||||
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
|
||||
});
|
||||
m.def(
|
||||
"IsTensor",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsTensor(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Check if an object is a Tensor.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsSequence",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsSequence(o.ptr());
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns true if its input is a collections.Sequence (except strings).
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string and is a collections.Sequence or a
|
||||
dict.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsSequenceOrComposite",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsSequenceOrComposite(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns true if its input is a sequence or a `CompositeTensor`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string and is a collections.Sequence or a
|
||||
dict or a CompositeTensor or a TypeSpec (except string and TensorSpec).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsCompositeTensor",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsCompositeTensor(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns true if its input is a `CompositeTensor`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a CompositeTensor.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsTypeSpec",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsTypeSpec(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a `TypeSpec`, but is not a `TensorSpec`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsNamedtuple",
|
||||
[](const py::handle& o, bool strict) {
|
||||
return pyo_or_throw(tensorflow::swig::IsNamedtuple(o.ptr(), strict));
|
||||
},
|
||||
R"pbdoc(
|
||||
Check if an object is a NamedTuple.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsMapping",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsMapping(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if `instance` is a `collections.Mapping`.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is a `collections.Mapping`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsMappingView",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsMappingView(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if considered a mapping view for the purposes of Flatten()`.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if considered a mapping view for the purposes of Flatten().
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsAttrs",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsAttrs(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if `instance` is an instance of an `attr.s` decorated class.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is an instance of an `attr.s` decorated class.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"SameNamedtuples",
|
||||
[](const py::handle& o1, const py::handle& o2) {
|
||||
return pyo_or_throw(
|
||||
tensorflow::swig::SameNamedtuples(o1.ptr(), o2.ptr()));
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if the two namedtuples have the same name and fields.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"AssertSameStructure",
|
||||
[](const py::handle& o1, const py::handle& o2, bool check_types,
|
||||
bool expand_composites) {
|
||||
bool result = tensorflow::swig::AssertSameStructure(
|
||||
o1.ptr(), o2.ptr(), check_types, expand_composites);
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if the two structures are nested in the same way.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"Flatten",
|
||||
[](const py::handle& o, bool expand_composites) {
|
||||
return pyo_or_throw(
|
||||
tensorflow::swig::Flatten(o.ptr(), expand_composites));
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns a flat list from a given nested structure.
|
||||
|
||||
If `nest` is not a sequence, tuple, or dict, then returns a single-element
|
||||
list: `[nest]`.
|
||||
|
||||
In the case of dict instances, the sequence consists of the values, sorted by
|
||||
key to ensure deterministic behavior. This is true also for `OrderedDict`
|
||||
instances: their sequence order is ignored, the sorting order of keys is
|
||||
used instead. The same convention is followed in `pack_sequence_as`. This
|
||||
correctly repacks dicts and `OrderedDict`s after they have been flattened,
|
||||
and also allows flattening an `OrderedDict` and then repacking it back using
|
||||
a corresponding plain dict, or vice-versa.
|
||||
Dictionaries with non-sortable keys cannot be flattened.
|
||||
|
||||
Users must not modify any collections used in `nest` while this function is
|
||||
running.
|
||||
|
||||
Args:
|
||||
nest: an arbitrarily nested structure or a scalar object. Note, numpy
|
||||
arrays are considered scalars.
|
||||
expand_composites: If true, then composite tensors such as `tf.SparseTensor`
|
||||
and `tf.RaggedTensor` are expanded into their component tensors.
|
||||
|
||||
Returns:
|
||||
A Python list, the flattened version of the input.
|
||||
|
||||
Raises:
|
||||
TypeError: The nest is or contains a dict with non-sortable keys.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsSequenceForData",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsSequenceForData(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns a true if `seq` is a Sequence or dict (except strings/lists).
|
||||
|
||||
NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
|
||||
which *does* treat a Python list as a sequence. For ergonomic
|
||||
reasons, `tf.data` users would prefer to treat lists as
|
||||
implicit `tf.Tensor` objects, and dicts as (nested) sequences.
|
||||
|
||||
Args:
|
||||
seq: an input sequence.
|
||||
|
||||
Returns:
|
||||
True if the sequence is a not a string or list and is a
|
||||
collections.Sequence.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"FlattenForData",
|
||||
[](const py::handle& o) {
|
||||
return pyo_or_throw(tensorflow::swig::FlattenForData(o.ptr()));
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns a flat sequence from a given nested structure.
|
||||
|
||||
If `nest` is not a sequence, this returns a single-element list: `[nest]`.
|
||||
|
||||
Args:
|
||||
nest: an arbitrarily nested structure or a scalar object.
|
||||
Note, numpy arrays are considered scalars.
|
||||
|
||||
Returns:
|
||||
A Python list, the flattened version of the input.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"AssertSameStructureForData",
|
||||
[](const py::handle& o1, const py::handle& o2, bool check_types) {
|
||||
bool result = tensorflow::swig::AssertSameStructureForData(
|
||||
o1.ptr(), o2.ptr(), check_types);
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns True if the two structures are nested in the same way in particular tf.data.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsResourceVariable",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsResourceVariable(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns 1 if `o` is a ResourceVariable.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is a `ResourceVariable`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"IsVariable",
|
||||
[](const py::handle& o) {
|
||||
bool result = tensorflow::swig::IsVariable(o.ptr());
|
||||
if (PyErr_Occurred()) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
R"pbdoc(
|
||||
Returns 1 if `o` is a Variable.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is a `Variable`.
|
||||
)pbdoc");
|
||||
}
|
@ -1 +1,19 @@
|
||||
|
||||
[cpp_python_util]
|
||||
tensorflow::swig::IsSequence
|
||||
tensorflow::swig::IsSequenceOrComposite
|
||||
tensorflow::swig::IsCompositeTensor
|
||||
tensorflow::swig::IsTypeSpec
|
||||
tensorflow::swig::IsNamedtuple
|
||||
tensorflow::swig::IsMapping
|
||||
tensorflow::swig::IsMappingView
|
||||
tensorflow::swig::IsAttrs
|
||||
tensorflow::swig::IsTensor
|
||||
tensorflow::swig::IsResourceVariable
|
||||
tensorflow::swig::IsVariable
|
||||
tensorflow::swig::SameNamedtuples
|
||||
tensorflow::swig::AssertSameStructure
|
||||
tensorflow::swig::Flatten
|
||||
tensorflow::swig::IsSequenceForData
|
||||
tensorflow::swig::FlattenForData
|
||||
tensorflow::swig::AssertSameStructureForData
|
||||
tensorflow::swig::RegisterType
|
||||
|
Loading…
x
Reference in New Issue
Block a user