diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3234aa9a64d..2f0241146b8 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1140,11 +1140,24 @@ py_library( srcs_version = "PY2AND3", ) +tf_python_pybind_extension( + name = "_dtypes", + srcs = ["framework/dtypes.cc"], + module_name = "_dtypes", + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", + "@pybind11", + ], +) + py_library( name = "dtypes", srcs = ["framework/dtypes.py"], srcs_version = "PY2AND3", deps = [ + ":_dtypes", ":pywrap_tensorflow", "//tensorflow/core:protos_all_py", ], diff --git a/tensorflow/python/framework/dtypes.cc b/tensorflow/python/framework/dtypes.cc new file mode 100644 index 00000000000..7c8521bd2d0 --- /dev/null +++ b/tensorflow/python/framework/dtypes.cc @@ -0,0 +1,157 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/detail/common.h" +#include "include/pybind11/pybind11.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace { + +inline int DataTypeId(tensorflow::DataType dt) { return static_cast<int>(dt); } + +// A variant of tensorflow::DataTypeString which uses fixed-width names +// for floating point data types. This behavior is compatible with that of +// existing pure Python DType. +const std::string DataTypeStringCompat(tensorflow::DataType dt) { + switch (dt) { + case tensorflow::DataType::DT_HALF: + return "float16"; + case tensorflow::DataType::DT_HALF_REF: + return "float16_ref"; + case tensorflow::DataType::DT_FLOAT: + return "float32"; + case tensorflow::DataType::DT_FLOAT_REF: + return "float32_ref"; + case tensorflow::DataType::DT_DOUBLE: + return "float64"; + case tensorflow::DataType::DT_DOUBLE_REF: + return "float64_ref"; + default: + return tensorflow::DataTypeString(dt); + } +} + +} // namespace + +namespace tensorflow { + +constexpr DataTypeSet kNumPyIncompatibleTypes = + ToSet(DataType::DT_RESOURCE) | ToSet(DataType::DT_VARIANT); + +inline bool DataTypeIsNumPyCompatible(DataType dt) { + return !kNumPyIncompatibleTypes.Contains(dt); +} + +} // namespace tensorflow + +namespace py = pybind11; + +PYBIND11_MODULE(_dtypes, m) { + py::class_<tensorflow::DataType>(m, "DType") + .def(py::init([](py::object obj) { + auto id = static_cast<int>(py::int_(obj)); + if (tensorflow::DataType_IsValid(id) && + id != static_cast<int>(tensorflow::DT_INVALID)) { + return static_cast<tensorflow::DataType>(id); + } + throw py::type_error( + py::str("%d does not correspond to a valid tensorflow::DataType") + .format(id)); + })) + // For compatibility with pure-Python DType. + .def_property_readonly("_type_enum", &DataTypeId) + .def_property_readonly( + "as_datatype_enum", &DataTypeId, + "Returns a `types_pb2.DataType` enum value based on this data type.") + + .def_property_readonly("name", + [](tensorflow::DataType self) { +#if PY_MAJOR_VERSION < 3 + return py::bytes(DataTypeStringCompat(self)); +#else + return DataTypeStringCompat(self); +#endif + }) + .def_property_readonly( + "size", + [](tensorflow::DataType self) { + return tensorflow::DataTypeSize(tensorflow::BaseType(self)); + }) + + .def("__repr__", + [](tensorflow::DataType self) { + return py::str("tf.{}").format(DataTypeStringCompat(self)); + }) + .def("__str__", + [](tensorflow::DataType self) { + return py::str("<dtype: {!r}>") +#if PY_MAJOR_VERSION < 3 + .format(py::bytes(DataTypeStringCompat(self))); +#else + .format(DataTypeStringCompat(self)); +#endif + }) + .def("__hash__", &DataTypeId) + + .def_property_readonly( + "is_numpy_compatible", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsNumPyCompatible( + tensorflow::BaseType(self)); + }, + "Returns whether this data type has a compatible NumPy data type.") + + .def_property_readonly( + "is_bool", + [](tensorflow::DataType self) { + return tensorflow::BaseType(self) == tensorflow::DT_BOOL; + }, + "Returns whether this is a boolean data type.") + .def_property_readonly( + "is_complex", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsComplex(tensorflow::BaseType(self)); + }, + "Returns whether this is a complex floating point type.") + .def_property_readonly( + "is_floating", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsFloating(tensorflow::BaseType(self)); + }, + "Returns whether this is a (non-quantized, real) floating point " + "type.") + .def_property_readonly( + "is_integer", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsInteger(tensorflow::BaseType(self)); + }, + "Returns whether this is a (non-quantized) integer type.") + .def_property_readonly( + "is_quantized", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsQuantized(tensorflow::BaseType(self)); + }, + "Returns whether this is a quantized data type.") + .def_property_readonly( + "is_unsigned", + [](tensorflow::DataType self) { + return tensorflow::DataTypeIsUnsigned(tensorflow::BaseType(self)); + }, + R"doc(Returns whether this type is unsigned. + +Non-numeric, unordered, and quantized types are not considered unsigned, and +this function returns `False`.)doc"); +} diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 423cd14c803..a9a8ac0518a 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -21,14 +21,16 @@ import numpy as np from six.moves import builtins from tensorflow.core.framework import types_pb2 +from tensorflow.python import _dtypes from tensorflow.python import pywrap_tensorflow from tensorflow.python.util.tf_export import tf_export _np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type() +# pylint: disable=slots-on-old-class @tf_export("dtypes.DType", "DType") -class DType(object): +class DType(_dtypes.DType): """Represents the type of the elements in a `Tensor`. The following `DType` objects are defined: @@ -60,30 +62,7 @@ class DType(object): The `tf.as_dtype()` function converts numpy types and string type names to a `DType` object. """ - __slots__ = ["_type_enum"] - - def __init__(self, type_enum): - """Creates a new `DataType`. - - NOTE(mrry): In normal circumstances, you should not need to - construct a `DataType` object directly. Instead, use the - `tf.as_dtype()` function. - - Args: - type_enum: A `types_pb2.DataType` enum value. - - Raises: - TypeError: If `type_enum` is not a value `types_pb2.DataType`. - - """ - # TODO(mrry): Make the necessary changes (using __new__) to ensure - # that calling this returns one of the interned values. - type_enum = int(type_enum) - if (type_enum not in types_pb2.DataType.values() or - type_enum == types_pb2.DT_INVALID): - raise TypeError("type_enum is not a valid types_pb2.DataType: %s" % - type_enum) - self._type_enum = type_enum + __slots__ = () @property def _is_ref_dtype(self): @@ -117,63 +96,11 @@ class DType(object): else: return self - @property - def is_numpy_compatible(self): - return self._type_enum not in _NUMPY_INCOMPATIBLE - @property def as_numpy_dtype(self): """Returns a `numpy.dtype` based on this `DType`.""" return _TF_TO_NP[self._type_enum] - @property - def as_datatype_enum(self): - """Returns a `types_pb2.DataType` enum value based on this `DType`.""" - return self._type_enum - - @property - def is_bool(self): - """Returns whether this is a boolean data type.""" - return self.base_dtype == bool - - @property - def is_integer(self): - """Returns whether this is a (non-quantized) integer type.""" - return (self.is_numpy_compatible and not self.is_quantized and - np.issubdtype(self.as_numpy_dtype, np.integer)) - - @property - def is_floating(self): - """Returns whether this is a (non-quantized, real) floating point type.""" - return ((self.is_numpy_compatible and - np.issubdtype(self.as_numpy_dtype, np.floating)) or - self.base_dtype == bfloat16) - - @property - def is_complex(self): - """Returns whether this is a complex floating point type.""" - return self.base_dtype in (complex64, complex128) - - @property - def is_quantized(self): - """Returns whether this is a quantized data type.""" - return self.base_dtype in _QUANTIZED_DTYPES_NO_REF - - @property - def is_unsigned(self): - """Returns whether this type is unsigned. - - Non-numeric, unordered, and quantized types are not considered unsigned, and - this function returns `False`. - - Returns: - Whether a `DType` is unsigned. - """ - try: - return self.min == 0 - except TypeError: - return False - @property def min(self): """Returns the minimum representable value in this data type. @@ -275,29 +202,15 @@ class DType(object): """Returns True iff self != other.""" return not self.__eq__(other) - @property - def name(self): - """Returns the string name for this `DType`.""" - return _TYPE_TO_STRING[self._type_enum] - - def __str__(self): - return "<dtype: %r>" % self.name - - def __repr__(self): - return "tf." + self.name - - def __hash__(self): - return self._type_enum + # "If a class that overrides __eq__() needs to retain the implementation + # of __hash__() from a parent class, the interpreter must be told this + # explicitly by setting __hash__ = <ParentClass>.__hash__." + # TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++. + __hash__ = _dtypes.DType.__hash__ def __reduce__(self): return as_dtype, (self.name,) - - @property - def size(self): - if (self._type_enum == types_pb2.DT_VARIANT or - self._type_enum == types_pb2.DT_RESOURCE): - return 1 - return np.dtype(self.as_numpy_dtype).itemsize +# pylint: enable=slots-on-old-class # Define data type range of numpy dtype @@ -395,11 +308,6 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF) qint32_ref = DType(types_pb2.DT_QINT32_REF) bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) -_NUMPY_INCOMPATIBLE = frozenset([ - types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE, - types_pb2.DT_RESOURCE_REF -]) - # Maintain an intern table so that we don't have to create a large # number of small objects. _INTERN_TABLE = { diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt index 0b5b88bba80..ed2c44bf772 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.DType" tf_class { is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>" - is_instance: "<type \'object\'>" + is_instance: "<class \'tensorflow.python._dtypes.DType\'>" + is_instance: "<class \'pybind11_builtins.pybind11_object\'>" member { name: "as_datatype_enum" mtype: "<type \'property\'>" @@ -68,7 +69,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" } member_method { name: "is_compatible_with" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt index 423eca32a2c..2bdef02f1a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.dtypes.DType" tf_class { is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>" - is_instance: "<type \'object\'>" + is_instance: "<class \'tensorflow.python._dtypes.DType\'>" + is_instance: "<class \'pybind11_builtins.pybind11_object\'>" member { name: "as_datatype_enum" mtype: "<type \'property\'>" @@ -68,7 +69,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" } member_method { name: "is_compatible_with" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt index 01b870a8163..a7ba1d4b946 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.dtypes" tf_module { member { name: "DType" - mtype: "<type \'type\'>" + mtype: "<class \'pybind11_builtins.pybind11_type\'>" } member { name: "QUANTIZED_DTYPES" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 75fcf744dd9..9abecf88b18 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -38,7 +38,7 @@ tf_module { } member { name: "DType" - mtype: "<type \'type\'>" + mtype: "<class \'pybind11_builtins.pybind11_type\'>" } member { name: "DeviceSpec" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt index 0b5b88bba80..ed2c44bf772 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.DType" tf_class { is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>" - is_instance: "<type \'object\'>" + is_instance: "<class \'tensorflow.python._dtypes.DType\'>" + is_instance: "<class \'pybind11_builtins.pybind11_object\'>" member { name: "as_datatype_enum" mtype: "<type \'property\'>" @@ -68,7 +69,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" } member_method { name: "is_compatible_with" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt index 423eca32a2c..2bdef02f1a2 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt @@ -1,7 +1,8 @@ path: "tensorflow.dtypes.DType" tf_class { is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>" - is_instance: "<type \'object\'>" + is_instance: "<class \'tensorflow.python._dtypes.DType\'>" + is_instance: "<class \'pybind11_builtins.pybind11_object\'>" member { name: "as_datatype_enum" mtype: "<type \'property\'>" @@ -68,7 +69,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" } member_method { name: "is_compatible_with" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt index 956e4d93e57..7501c36e0a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.dtypes" tf_module { member { name: "DType" - mtype: "<type \'type\'>" + mtype: "<class \'pybind11_builtins.pybind11_type\'>" } member { name: "QUANTIZED_DTYPES" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 7cf14d69e49..514addea995 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -10,7 +10,7 @@ tf_module { } member { name: "DType" - mtype: "<type \'type\'>" + mtype: "<class \'pybind11_builtins.pybind11_type\'>" } member { name: "DeviceSpec" diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 283f53882c3..e1b2902332f 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -79,8 +79,16 @@ if sys.version_info.major == 3: return (member == 'with_traceback' or member in ('name', 'value') and isinstance(cls, type) and issubclass(cls, enum.Enum)) else: - _NORMALIZE_TYPE = {"<class 'abc.ABCMeta'>": "<type 'type'>"} - _NORMALIZE_ISINSTANCE = {} + _NORMALIZE_TYPE = { + "<class 'abc.ABCMeta'>": + "<type 'type'>", + "<class 'pybind11_type'>": + "<class 'pybind11_builtins.pybind11_type'>", + } + _NORMALIZE_ISINSTANCE = { + "<class 'pybind11_object'>": + "<class 'pybind11_builtins.pybind11_object'>", + } def _SkipMember(cls, member): # pylint: disable=unused-argument return False diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 30a844d93eb..052e757441f 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -56,7 +56,15 @@ tensorflow::EventsWriter::Close [py_func_lib] # py_func tensorflow::InitializePyTrampoline -[framework_internal_impl] # op_def_registry +[framework_internal_impl] # op_def_registry, dtypes +tensorflow::BaseType +tensorflow::DataTypeString +tensorflow::DataTypeIsComplex +tensorflow::DataTypeIsFloating +tensorflow::DataTypeIsInteger +tensorflow::DataTypeIsQuantized +tensorflow::DataTypeIsUnsigned +tensorflow::DataTypeSize tensorflow::OpRegistry::Global tensorflow::OpRegistry::LookUpOpDef tensorflow::RemoveNonDeprecationDescriptionsFromOpDef @@ -72,7 +80,8 @@ tensorflow::DeviceFactory::AddDevices tensorflow::SessionOptions::SessionOptions tensorflow::DoQuantizeTrainingOnSerializedGraphDef -[protos_all] # device_lib +[protos_all] # device_lib, dtypes +tensorflow::DataType_IsValid tensorflow::ConfigProto::ConfigProto tensorflow::ConfigProto::ParseFromString tensorflow::DeviceAttributes::SerializeToString