diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a5071d18da6..dcd99a0f3d6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -20,6 +20,7 @@ visibility = [ ] load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow:tensorflow.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow:tensorflow.bzl", "cuda_py_test") @@ -99,6 +100,7 @@ py_library( "//third_party/py/tensorflow_core:__subpackages__", ], deps = [ + ":_pywrap_utils", ":array_ops", ":audio_ops_gen", ":bitwise_ops", @@ -377,6 +379,22 @@ cc_library( ], ) +tf_python_pybind_extension( + name = "_pywrap_utils", + srcs = ["util/util_wrapper.cc"], + hdrs = ["util/util.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_utils", + deps = [ + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + cc_library( name = "cpp_python_util", srcs = ["util/util.cc"], @@ -685,6 +703,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":_pywrap_utils", ":common_shapes", ":composite_tensor", ":convert_to_constants", @@ -4987,7 +5006,6 @@ tf_py_wrap_cc( "util/tfprof.i", "util/traceme.i", "util/transform_graph.i", - "util/util.i", "//tensorflow/lite/toco/python:toco.i", ], # add win_def_file for pywrap_tensorflow @@ -5056,14 +5074,20 @@ tf_py_wrap_cc( # the dynamic libraries of custom ops can find it at runtime. genrule( name = "pywrap_tensorflow_filtered_def_file", - srcs = ["//tensorflow:tensorflow_def_file"], + srcs = [ + "//tensorflow:tensorflow_def_file", + "//tensorflow/tools/def_file_filter:symbols_pybind", + ":cpp_python_util", + ], outs = ["pywrap_tensorflow_filtered_def_file.def"], cmd = select({ "//tensorflow:windows": """ $(location @local_config_def_file_filter//:def_file_filter) \\ --input $(location //tensorflow:tensorflow_def_file) \\ --output $@ \\ - --target _pywrap_tensorflow_internal.pyd + --target _pywrap_tensorflow_internal.pyd \\ + --lib_paths $(execpath :cpp_python_util) \\ + --symbols $(location //tensorflow/tools/def_file_filter:symbols_pybind) """, "//conditions:default": "touch $@", # Just a placeholder for Unix platforms }), diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 4e5477d17b2..06216f47b85 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -47,6 +47,7 @@ import traceback import numpy as np from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils # Protocol buffers from tensorflow.core.framework.graph_pb2 import * diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 24cdc97d006..245f578826b 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -37,7 +37,7 @@ from __future__ import print_function import six as _six -from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.framework import sparse_tensor as _sparse_tensor from tensorflow.python.util.compat import collections_abc as _collections_abc @@ -95,10 +95,10 @@ def _yield_value(iterable): # See the swig file (../../util/util.i) for documentation. -is_sequence = _pywrap_tensorflow.IsSequenceForData +is_sequence = _pywrap_utils.IsSequenceForData # See the swig file (../../util/util.i) for documentation. -flatten = _pywrap_tensorflow.FlattenForData +flatten = _pywrap_utils.FlattenForData def assert_same_structure(nest1, nest2, check_types=True): @@ -120,7 +120,7 @@ def assert_same_structure(nest1, nest2, check_types=True): TypeError: If the two structures differ in the type of sequence in any of their substructures. Only possible if `check_types` is `True`. """ - _pywrap_tensorflow.AssertSameStructureForData(nest1, nest2, check_types) + _pywrap_utils.AssertSameStructureForData(nest1, nest2, check_types) def _packed_nest_with_indices(structure, flat, index): diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d047bc5c455..37632d183ec 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -24,6 +24,7 @@ import sys import six +from tensorflow.python import _pywrap_utils from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute @@ -844,8 +845,7 @@ class GradientTape(object): ValueError: if it encounters something that is not a tensor. """ for t in nest.flatten(tensor): - if not (pywrap_tensorflow.IsTensor(t) or - pywrap_tensorflow.IsVariable(t)): + if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)): raise ValueError("Passed in object of type {}, not tf.Tensor".format( type(t))) if not t.dtype.is_floating: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 9db064a8030..524a4af289f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -32,6 +32,7 @@ import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 +from tensorflow.python import _pywrap_utils from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute @@ -1416,8 +1417,8 @@ class ConcreteFunction(object): return ret -pywrap_tensorflow.RegisterType("Tensor", ops.Tensor) -pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices) +_pywrap_utils.RegisterType("Tensor", ops.Tensor) +_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices) def _deterministic_dict_values(dictionary): @@ -1698,7 +1699,7 @@ def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): need_packing = False for index, (value, spec) in enumerate(zip(flatten_inputs, flat_input_signature)): - if not pywrap_tensorflow.IsTensor(value): + if not _pywrap_utils.IsTensor(value): try: flatten_inputs[index] = ops.convert_to_tensor( value, dtype_hint=spec.dtype) diff --git a/tensorflow/python/framework/composite_tensor.py b/tensorflow/python/framework/composite_tensor.py index e44e3a83d38..b475685b779 100644 --- a/tensorflow/python/framework/composite_tensor.py +++ b/tensorflow/python/framework/composite_tensor.py @@ -22,7 +22,7 @@ import abc import six -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.util import nest @@ -137,7 +137,7 @@ class CompositeTensor(object): return list(set(consumers)) -pywrap_tensorflow.RegisterType("CompositeTensor", CompositeTensor) +_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor) def replace_composites_with_components(structure): diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index fe0c42ffde1..ec60b675226 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -21,7 +21,7 @@ from __future__ import print_function import collections import numpy as np -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python import tf2 from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op @@ -255,7 +255,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor): SparseTensorValue = collections.namedtuple("SparseTensorValue", ["indices", "values", "dense_shape"]) tf_export(v1=["SparseTensorValue"])(SparseTensorValue) -pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue) +_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue) @tf_export("SparseTensorSpec") diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index 1e224e628c2..7240f288686 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -287,7 +287,7 @@ class BoundedTensorSpec(TensorSpec): return (self._shape, self._dtype, self._minimum, self._maximum, self._name) -pywrap_tensorflow.RegisterType("TensorSpec", TensorSpec) +_pywrap_utils.RegisterType("TensorSpec", TensorSpec) # Note: we do not include Tensor names when constructing TypeSpecs. diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index ffc93b06c67..c724f5d8100 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -22,7 +22,7 @@ import abc import numpy as np import six -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape @@ -546,4 +546,4 @@ def register_type_spec_from_value_converter(type_object, converter_fn, (type_object, converter_fn, allow_subclass)) -pywrap_tensorflow.RegisterType("TypeSpec", TypeSpec) +_pywrap_utils.RegisterType("TypeSpec", TypeSpec) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 608176b4200..dc6ebd0f64f 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -24,6 +24,7 @@ import functools from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 +from tensorflow.python import _pywrap_utils from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape @@ -1781,7 +1782,7 @@ class UninitializedVariable(BaseResourceVariable): synchronization=synchronization, aggregation=aggregation) -pywrap_tensorflow.RegisterType("ResourceVariable", ResourceVariable) +_pywrap_utils.RegisterType("ResourceVariable", ResourceVariable) math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 8805a719aee..7ff361b9db4 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -26,7 +26,7 @@ import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -1351,7 +1351,7 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)): Variable._OverloadAllOperators() # pylint: disable=protected-access -pywrap_tensorflow.RegisterType("Variable", Variable) +_pywrap_utils.RegisterType("Variable", Variable) @tf_export(v1=["Variable"]) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index ee75e04f8c9..2c76f50aa5d 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -172,6 +172,7 @@ limitations under the License. %{ #include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/util/util.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/eager/c_api_experimental.h" diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index fef77ce2432..c331601bebd 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -49,8 +49,6 @@ limitations under the License. %include "tensorflow/python/util/transform_graph.i" -%include "tensorflow/python/util/util.i" - %include "tensorflow/python/grappler/cluster.i" %include "tensorflow/python/grappler/item.i" %include "tensorflow/python/grappler/tf_optimizer.i" diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 97a587e734c..5cff541c5c6 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -38,7 +38,7 @@ import collections as _collections import six as _six -from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow +from tensorflow.python import _pywrap_utils from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.compat import collections_abc as _collections_abc @@ -104,15 +104,15 @@ def _is_namedtuple(instance, strict=False): Returns: True if `instance` is a `namedtuple`. """ - return _pywrap_tensorflow.IsNamedtuple(instance, strict) + return _pywrap_utils.IsNamedtuple(instance, strict) # See the swig file (util.i) for documentation. -_is_mapping = _pywrap_tensorflow.IsMapping -_is_mapping_view = _pywrap_tensorflow.IsMappingView -_is_attrs = _pywrap_tensorflow.IsAttrs -_is_composite_tensor = _pywrap_tensorflow.IsCompositeTensor -_is_type_spec = _pywrap_tensorflow.IsTypeSpec +_is_mapping = _pywrap_utils.IsMapping +_is_mapping_view = _pywrap_utils.IsMappingView +_is_attrs = _pywrap_utils.IsAttrs +_is_composite_tensor = _pywrap_utils.IsCompositeTensor +_is_type_spec = _pywrap_utils.IsTypeSpec def _sequence_like(instance, args): @@ -208,11 +208,11 @@ def _yield_sorted_items(iterable): # See the swig file (util.i) for documentation. -is_sequence = _pywrap_tensorflow.IsSequence +is_sequence = _pywrap_utils.IsSequence # See the swig file (util.i) for documentation. -is_sequence_or_composite = _pywrap_tensorflow.IsSequenceOrComposite +is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite @tf_export("nest.is_nested") @@ -260,11 +260,11 @@ def flatten(structure, expand_composites=False): Raises: TypeError: The nest is or contains a dict with non-sortable keys. """ - return _pywrap_tensorflow.Flatten(structure, expand_composites) + return _pywrap_utils.Flatten(structure, expand_composites) # See the swig file (util.i) for documentation. -_same_namedtuples = _pywrap_tensorflow.SameNamedtuples +_same_namedtuples = _pywrap_utils.SameNamedtuples class _DotString(object): @@ -315,8 +315,8 @@ def assert_same_structure(nest1, nest2, check_types=True, their substructures. Only possible if `check_types` is `True`. """ try: - _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types, - expand_composites) + _pywrap_utils.AssertSameStructure(nest1, nest2, check_types, + expand_composites) except (ValueError, TypeError) as e: str1 = str(map_structure(lambda _: _DOT, nest1)) str2 = str(map_structure(lambda _: _DOT, nest2)) @@ -1327,6 +1327,6 @@ def flatten_with_tuple_paths(structure, expand_composites=False): flatten(structure, expand_composites=expand_composites))) -_pywrap_tensorflow.RegisterType("Mapping", _collections_abc.Mapping) -_pywrap_tensorflow.RegisterType("Sequence", _collections_abc.Sequence) -_pywrap_tensorflow.RegisterType("MappingView", _collections_abc.MappingView) +_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) +_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) +_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView) diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i deleted file mode 100644 index f9a08cc3d23..00000000000 --- a/tensorflow/python/util/util.i +++ /dev/null @@ -1,212 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/python/util/util.h" -%} - -%ignoreall - -%unignore tensorflow; -%unignore tensorflow::swig; -// The %exception block defined in tf_session.i releases the Python GIL for -// the length of each wrapped method. This file is included in tensorflow.i -// after tf_session.i and inherits this definition. We disable this behavior -// for functions in this module because they use python methods that need GIL. -// TODO(iga): Find a way not to leak such definitions across files. - -%unignore tensorflow::swig::RegisterType; -%noexception tensorflow::swig::RegisterType; - -%unignore tensorflow::swig::IsTensor; -%noexception tensorflow::swig::IsTensor; - -%unignore tensorflow::swig::IsResourceVariable; -%noexception tensorflow::swig::IsResourceVariable; - -%unignore tensorflow::swig::IsVariable; -%noexception tensorflow::swig::IsVariable; - -%feature("docstring") tensorflow::swig::IsSequence -"""Returns true if its input is a collections.Sequence (except strings). - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a not a string and is a collections.Sequence or a - dict. -""" -%unignore tensorflow::swig::IsSequence; -%noexception tensorflow::swig::IsSequence; - -%feature("docstring") tensorflow::swig::IsSequenceOrComposite -"""Returns true if its input is a sequence or a `CompositeTensor`. - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a not a string and is a collections.Sequence or a - dict or a CompositeTensor or a TypeSpec (except string and TensorSpec). -""" -%unignore tensorflow::swig::IsSequenceOrComposite; -%noexception tensorflow::swig::IsSequenceOrComposite; - -%feature("docstring") tensorflow::swig::IsCompositeTensor -"""Returns true if its input is a `CompositeTensor`. - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a CompositeTensor. -""" -%unignore tensorflow::swig::IsCompositeTensor; -%noexception tensorflow::swig::IsCompositeTensor; - -%feature("docstring") tensorflow::swig::IsTypeSpec -"""Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`. - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a `TypeSpec`, but is not a `TensorSpec`. -""" -%unignore tensorflow::swig::IsTypeSpec; -%noexception tensorflow::swig::IsTypeSpec; - -%unignore tensorflow::swig::IsNamedtuple; -%noexception tensorflow::swig::IsNamedtuple; - -%feature("docstring") tensorflow::swig::IsMapping -"""Returns True iff `instance` is a `collections.Mapping`. - -Args: - instance: An instance of a Python object. - -Returns: - True if `instance` is a `collections.Mapping`. -""" -%unignore tensorflow::swig::IsMapping; -%noexception tensorflow::swig::IsMapping; - -%feature("docstring") tensorflow::swig::IsMappingView -"""Returns True iff `instance` is a `collections.MappingView`. - -Args: - instance: An instance of a Python object. - -Returns: - True if `instance` is a `collections.MappingView`. -""" -%unignore tensorflow::swig::IsMappingView; -%noexception tensorflow::swig::IsMappingView; - -%feature("docstring") tensorflow::swig::IsAttrs -"""Returns True iff `instance` is an instance of an `attr.s` decorated class. - -Args: - instance: An instance of a Python object. - -Returns: - True if `instance` is an instance of an `attr.s` decorated class. -""" -%unignore tensorflow::swig::IsAttrs; -%noexception tensorflow::swig::IsAttrs; - -%feature("docstring") tensorflow::swig::SameNamedtuples -"Returns True if the two namedtuples have the same name and fields." -%unignore tensorflow::swig::SameNamedtuples; -%noexception tensorflow::swig::SameNamedtuples; - -%unignore tensorflow::swig::AssertSameStructure; -%noexception tensorflow::swig::AssertSameStructure; - -%feature("docstring") tensorflow::swig::Flatten -"""Returns a flat list from a given nested structure. - -If `nest` is not a sequence, tuple, or dict, then returns a single-element -list: `[nest]`. - -In the case of dict instances, the sequence consists of the values, sorted by -key to ensure deterministic behavior. This is true also for `OrderedDict` -instances: their sequence order is ignored, the sorting order of keys is -used instead. The same convention is followed in `pack_sequence_as`. This -correctly repacks dicts and `OrderedDict`s after they have been flattened, -and also allows flattening an `OrderedDict` and then repacking it back using -a corresponding plain dict, or vice-versa. -Dictionaries with non-sortable keys cannot be flattened. - -Users must not modify any collections used in `nest` while this function is -running. - -Args: - nest: an arbitrarily nested structure or a scalar object. Note, numpy - arrays are considered scalars. - expand_composites: If true, then composite tensors such as `tf.SparseTensor` - and `tf.RaggedTensor` are expanded into their component tensors. - -Returns: - A Python list, the flattened version of the input. - -Raises: - TypeError: The nest is or contains a dict with non-sortable keys. -""" -%unignore tensorflow::swig::Flatten; -%noexception tensorflow::swig::Flatten; -%feature("kwargs") tensorflow::swig::Flatten; - -%feature("docstring") tensorflow::swig::IsSequenceForData -"""Returns a true if `seq` is a Sequence or dict (except strings/lists). - -NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, -which *does* treat a Python list as a sequence. For ergonomic -reasons, `tf.data` users would prefer to treat lists as -implicit `tf.Tensor` objects, and dicts as (nested) sequences. - -Args: - seq: an input sequence. - -Returns: - True if the sequence is a not a string or list and is a - collections.Sequence. -""" -%unignore tensorflow::swig::IsSequenceForData; -%noexception tensorflow::swig::IsSequenceForData; - -%feature("docstring") tensorflow::swig::FlattenForData -"""Returns a flat sequence from a given nested structure. - -If `nest` is not a sequence, this returns a single-element list: `[nest]`. - -Args: - nest: an arbitrarily nested structure or a scalar object. - Note, numpy arrays are considered scalars. - -Returns: - A Python list, the flattened version of the input. -""" -%unignore tensorflow::swig::FlattenForData; -%noexception tensorflow::swig::FlattenForData; - -%unignore tensorflow::swig::AssertSameStructureForData; -%noexception tensorflow::swig::AssertSameStructureForData; - -%include "tensorflow/python/util/util.h" - -%unignoreall diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc new file mode 100644 index 00000000000..835ba070b01 --- /dev/null +++ b/tensorflow/python/util/util_wrapper.cc @@ -0,0 +1,333 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "tensorflow/python/util/util.h" + +namespace py = pybind11; + +inline py::object pyo_or_throw(PyObject* ptr) { + if (PyErr_Occurred() || ptr == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(ptr); +} + +PYBIND11_MODULE(_pywrap_utils, m) { + m.doc() = R"pbdoc( + _pywrap_utils + ----- + )pbdoc"; + m.def("RegisterType", + [](const py::handle& type_name, const py::handle& type) { + return pyo_or_throw( + tensorflow::swig::RegisterType(type_name.ptr(), type.ptr())); + }); + m.def( + "IsTensor", + [](const py::handle& o) { + bool result = tensorflow::swig::IsTensor(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Check if an object is a Tensor. + )pbdoc"); + m.def( + "IsSequence", + [](const py::handle& o) { + bool result = tensorflow::swig::IsSequence(o.ptr()); + return result; + }, + R"pbdoc( + Returns true if its input is a collections.Sequence (except strings). + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a not a string and is a collections.Sequence or a + dict. + )pbdoc"); + m.def( + "IsSequenceOrComposite", + [](const py::handle& o) { + bool result = tensorflow::swig::IsSequenceOrComposite(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns true if its input is a sequence or a `CompositeTensor`. + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a not a string and is a collections.Sequence or a + dict or a CompositeTensor or a TypeSpec (except string and TensorSpec). + )pbdoc"); + m.def( + "IsCompositeTensor", + [](const py::handle& o) { + bool result = tensorflow::swig::IsCompositeTensor(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns true if its input is a `CompositeTensor`. + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a CompositeTensor. + )pbdoc"); + m.def( + "IsTypeSpec", + [](const py::handle& o) { + bool result = tensorflow::swig::IsTypeSpec(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns true if its input is a `TypeSpec`, but is not a `TensorSpec`. + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a `TypeSpec`, but is not a `TensorSpec`. + )pbdoc"); + m.def( + "IsNamedtuple", + [](const py::handle& o, bool strict) { + return pyo_or_throw(tensorflow::swig::IsNamedtuple(o.ptr(), strict)); + }, + R"pbdoc( + Check if an object is a NamedTuple. + )pbdoc"); + m.def( + "IsMapping", + [](const py::handle& o) { + bool result = tensorflow::swig::IsMapping(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns True if `instance` is a `collections.Mapping`. + + Args: + instance: An instance of a Python object. + + Returns: + True if `instance` is a `collections.Mapping`. + )pbdoc"); + m.def( + "IsMappingView", + [](const py::handle& o) { + bool result = tensorflow::swig::IsMappingView(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns True if considered a mapping view for the purposes of Flatten()`. + + Args: + instance: An instance of a Python object. + + Returns: + True if considered a mapping view for the purposes of Flatten(). + )pbdoc"); + m.def( + "IsAttrs", + [](const py::handle& o) { + bool result = tensorflow::swig::IsAttrs(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns True if `instance` is an instance of an `attr.s` decorated class. + + Args: + instance: An instance of a Python object. + + Returns: + True if `instance` is an instance of an `attr.s` decorated class. + )pbdoc"); + m.def( + "SameNamedtuples", + [](const py::handle& o1, const py::handle& o2) { + return pyo_or_throw( + tensorflow::swig::SameNamedtuples(o1.ptr(), o2.ptr())); + }, + R"pbdoc( + Returns True if the two namedtuples have the same name and fields. + )pbdoc"); + m.def( + "AssertSameStructure", + [](const py::handle& o1, const py::handle& o2, bool check_types, + bool expand_composites) { + bool result = tensorflow::swig::AssertSameStructure( + o1.ptr(), o2.ptr(), check_types, expand_composites); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns True if the two structures are nested in the same way. + )pbdoc"); + m.def( + "Flatten", + [](const py::handle& o, bool expand_composites) { + return pyo_or_throw( + tensorflow::swig::Flatten(o.ptr(), expand_composites)); + }, + R"pbdoc( + Returns a flat list from a given nested structure. + + If `nest` is not a sequence, tuple, or dict, then returns a single-element + list: `[nest]`. + + In the case of dict instances, the sequence consists of the values, sorted by + key to ensure deterministic behavior. This is true also for `OrderedDict` + instances: their sequence order is ignored, the sorting order of keys is + used instead. The same convention is followed in `pack_sequence_as`. This + correctly repacks dicts and `OrderedDict`s after they have been flattened, + and also allows flattening an `OrderedDict` and then repacking it back using + a corresponding plain dict, or vice-versa. + Dictionaries with non-sortable keys cannot be flattened. + + Users must not modify any collections used in `nest` while this function is + running. + + Args: + nest: an arbitrarily nested structure or a scalar object. Note, numpy + arrays are considered scalars. + expand_composites: If true, then composite tensors such as `tf.SparseTensor` + and `tf.RaggedTensor` are expanded into their component tensors. + + Returns: + A Python list, the flattened version of the input. + + Raises: + TypeError: The nest is or contains a dict with non-sortable keys. + )pbdoc"); + m.def( + "IsSequenceForData", + [](const py::handle& o) { + bool result = tensorflow::swig::IsSequenceForData(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns a true if `seq` is a Sequence or dict (except strings/lists). + + NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, + which *does* treat a Python list as a sequence. For ergonomic + reasons, `tf.data` users would prefer to treat lists as + implicit `tf.Tensor` objects, and dicts as (nested) sequences. + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a not a string or list and is a + collections.Sequence. + )pbdoc"); + m.def( + "FlattenForData", + [](const py::handle& o) { + return pyo_or_throw(tensorflow::swig::FlattenForData(o.ptr())); + }, + R"pbdoc( + Returns a flat sequence from a given nested structure. + + If `nest` is not a sequence, this returns a single-element list: `[nest]`. + + Args: + nest: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + + Returns: + A Python list, the flattened version of the input. + )pbdoc"); + m.def( + "AssertSameStructureForData", + [](const py::handle& o1, const py::handle& o2, bool check_types) { + bool result = tensorflow::swig::AssertSameStructureForData( + o1.ptr(), o2.ptr(), check_types); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns True if the two structures are nested in the same way in particular tf.data. + )pbdoc"); + m.def( + "IsResourceVariable", + [](const py::handle& o) { + bool result = tensorflow::swig::IsResourceVariable(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns 1 if `o` is a ResourceVariable. + + Args: + instance: An instance of a Python object. + + Returns: + True if `instance` is a `ResourceVariable`. + )pbdoc"); + m.def( + "IsVariable", + [](const py::handle& o) { + bool result = tensorflow::swig::IsVariable(o.ptr()); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return result; + }, + R"pbdoc( + Returns 1 if `o` is a Variable. + + Args: + instance: An instance of a Python object. + + Returns: + True if `instance` is a `Variable`. + )pbdoc"); +} diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 8b137891791..4dcc8abaa8d 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -1 +1,19 @@ - +[cpp_python_util] +tensorflow::swig::IsSequence +tensorflow::swig::IsSequenceOrComposite +tensorflow::swig::IsCompositeTensor +tensorflow::swig::IsTypeSpec +tensorflow::swig::IsNamedtuple +tensorflow::swig::IsMapping +tensorflow::swig::IsMappingView +tensorflow::swig::IsAttrs +tensorflow::swig::IsTensor +tensorflow::swig::IsResourceVariable +tensorflow::swig::IsVariable +tensorflow::swig::SameNamedtuples +tensorflow::swig::AssertSameStructure +tensorflow::swig::Flatten +tensorflow::swig::IsSequenceForData +tensorflow::swig::FlattenForData +tensorflow::swig::AssertSameStructureForData +tensorflow::swig::RegisterType