dtypes.DType is now a wrapper around tensorflow::DataType

Note that some DType methods cannot yet be pushed to C++ and will be
addressed in a followup.

Quick benchmark results:

>>> import tensorflow as tf
>>> dtype = tf.int32

Before:

>>> %timeit dtype.is_integer
1000000 loops, best of 3: 1.22 µs per loop
>>> %timeit dtype.is_floating
1000000 loops, best of 3: 1.2 µs per loop

After:

>>> %timeit dtype.is_integer
10000000 loops, best of 3: 179 ns per loop
>>> %timeit dtype.is_floating
10000000 loops, best of 3: 185 ns per loop

PiperOrigin-RevId: 282730366
Change-Id: Ib1bc960de647cee0f97ab522215ace88c9050962
This commit is contained in:
Sergei Lebedev 2019-11-27 02:32:22 -08:00 committed by TensorFlower Gardener
parent 481366eab2
commit 6fa83d70b7
13 changed files with 213 additions and 118 deletions

View File

@ -1140,11 +1140,24 @@ py_library(
srcs_version = "PY2AND3",
)
# Pybind11 extension module built from framework/dtypes.cc; its Python module
# name is `_dtypes`.
tf_python_pybind_extension(
name = "_dtypes",
srcs = ["framework/dtypes.cc"],
module_name = "_dtypes",
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
"@pybind11",
],
)
py_library(
name = "dtypes",
srcs = ["framework/dtypes.py"],
srcs_version = "PY2AND3",
deps = [
":_dtypes",
":pywrap_tensorflow",
"//tensorflow/core:protos_all_py",
],

View File

@ -0,0 +1,157 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/detail/common.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
namespace {

// Returns the integer value of the given DataType enum member.
inline int DataTypeId(tensorflow::DataType dt) { return static_cast<int>(dt); }

// A variant of tensorflow::DataTypeString which uses fixed-width names
// for floating point data types. This behavior is compatible with that of
// existing pure Python DType.
//
// Every type other than half/float/double (and their `_REF` counterparts)
// is delegated to tensorflow::DataTypeString.
//
// The result is returned as a non-const value: a `const std::string` return
// type would prevent callers from moving out of the temporary.
std::string DataTypeStringCompat(tensorflow::DataType dt) {
  switch (dt) {
    case tensorflow::DataType::DT_HALF:
      return "float16";
    case tensorflow::DataType::DT_HALF_REF:
      return "float16_ref";
    case tensorflow::DataType::DT_FLOAT:
      return "float32";
    case tensorflow::DataType::DT_FLOAT_REF:
      return "float32_ref";
    case tensorflow::DataType::DT_DOUBLE:
      return "float64";
    case tensorflow::DataType::DT_DOUBLE_REF:
      return "float64_ref";
    default:
      return tensorflow::DataTypeString(dt);
  }
}

}  // namespace
namespace tensorflow {

// The set of data types treated as having no compatible NumPy data type.
constexpr DataTypeSet kNumPyIncompatibleTypes =
    ToSet(DataType::DT_VARIANT) | ToSet(DataType::DT_RESOURCE);

// Returns true unless `dt` is a member of kNumPyIncompatibleTypes.
inline bool DataTypeIsNumPyCompatible(DataType dt) {
  const bool is_incompatible = kNumPyIncompatibleTypes.Contains(dt);
  return !is_incompatible;
}

}  // namespace tensorflow
namespace py = pybind11;
// Defines the `_dtypes` extension module, which exposes tensorflow::DataType
// to Python as the class `DType`.
//
// Most predicates below apply tensorflow::BaseType() to the value first;
// presumably this maps `*_REF` enum members onto their non-reference
// counterparts so both are classified identically (confirm against
// tensorflow/core/framework/types.h).
PYBIND11_MODULE(_dtypes, m) {
  py::class_<tensorflow::DataType>(m, "DType")
      // Constructor: accepts any object convertible to a Python int. The
      // value must be a valid DataType enum member other than DT_INVALID,
      // otherwise a TypeError is raised.
      .def(py::init([](py::object obj) {
        auto id = static_cast<int>(py::int_(obj));
        if (tensorflow::DataType_IsValid(id) &&
            id != static_cast<int>(tensorflow::DT_INVALID)) {
          return static_cast<tensorflow::DataType>(id);
        }
        // Python's str.format() substitutes "{}" fields only; a printf-style
        // "%d" would be emitted verbatim and the offending value dropped
        // from the error message.
        throw py::type_error(
            py::str("{} does not correspond to a valid tensorflow::DataType")
                .format(id));
      }))
      // For compatibility with pure-Python DType.
      .def_property_readonly("_type_enum", &DataTypeId)
      .def_property_readonly(
          "as_datatype_enum", &DataTypeId,
          "Returns a `types_pb2.DataType` enum value based on this data type.")
      // The name is returned as `bytes` when built against Python 2 and as a
      // C++ std::string (converted by pybind11) otherwise.
      .def_property_readonly("name",
                             [](tensorflow::DataType self) {
#if PY_MAJOR_VERSION < 3
                               return py::bytes(DataTypeStringCompat(self));
#else
                               return DataTypeStringCompat(self);
#endif
                             })
      // Size as reported by tensorflow::DataTypeSize for the base type.
      .def_property_readonly(
          "size",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeSize(tensorflow::BaseType(self));
          })
      .def("__repr__",
           [](tensorflow::DataType self) {
             return py::str("tf.{}").format(DataTypeStringCompat(self));
           })
      // Formats the name with `!r`; the Python 2 branch passes `bytes` so the
      // repr is taken of a byte string rather than a unicode string.
      .def("__str__",
           [](tensorflow::DataType self) {
             return py::str("<dtype: {!r}>")
#if PY_MAJOR_VERSION < 3
                 .format(py::bytes(DataTypeStringCompat(self)));
#else
                 .format(DataTypeStringCompat(self));
#endif
           })
      // The hash is the integer enum value.
      .def("__hash__", &DataTypeId)
      .def_property_readonly(
          "is_numpy_compatible",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsNumPyCompatible(
                tensorflow::BaseType(self));
          },
          "Returns whether this data type has a compatible NumPy data type.")
      .def_property_readonly(
          "is_bool",
          [](tensorflow::DataType self) {
            return tensorflow::BaseType(self) == tensorflow::DT_BOOL;
          },
          "Returns whether this is a boolean data type.")
      .def_property_readonly(
          "is_complex",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsComplex(tensorflow::BaseType(self));
          },
          "Returns whether this is a complex floating point type.")
      .def_property_readonly(
          "is_floating",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsFloating(tensorflow::BaseType(self));
          },
          "Returns whether this is a (non-quantized, real) floating point "
          "type.")
      .def_property_readonly(
          "is_integer",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsInteger(tensorflow::BaseType(self));
          },
          "Returns whether this is a (non-quantized) integer type.")
      .def_property_readonly(
          "is_quantized",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsQuantized(tensorflow::BaseType(self));
          },
          "Returns whether this is a quantized data type.")
      .def_property_readonly(
          "is_unsigned",
          [](tensorflow::DataType self) {
            return tensorflow::DataTypeIsUnsigned(tensorflow::BaseType(self));
          },
          R"doc(Returns whether this type is unsigned.
Non-numeric, unordered, and quantized types are not considered unsigned, and
this function returns `False`.)doc");
}

View File

@ -21,14 +21,16 @@ import numpy as np
from six.moves import builtins
from tensorflow.core.framework import types_pb2
from tensorflow.python import _dtypes
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util.tf_export import tf_export
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
# pylint: disable=slots-on-old-class
@tf_export("dtypes.DType", "DType")
class DType(object):
class DType(_dtypes.DType):
"""Represents the type of the elements in a `Tensor`.
The following `DType` objects are defined:
@ -60,30 +62,7 @@ class DType(object):
The `tf.as_dtype()` function converts numpy types and string type
names to a `DType` object.
"""
__slots__ = ["_type_enum"]
def __init__(self, type_enum):
"""Creates a new `DataType`.
NOTE(mrry): In normal circumstances, you should not need to
construct a `DataType` object directly. Instead, use the
`tf.as_dtype()` function.
Args:
type_enum: A `types_pb2.DataType` enum value.
Raises:
TypeError: If `type_enum` is not a value `types_pb2.DataType`.
"""
# TODO(mrry): Make the necessary changes (using __new__) to ensure
# that calling this returns one of the interned values.
type_enum = int(type_enum)
if (type_enum not in types_pb2.DataType.values() or
type_enum == types_pb2.DT_INVALID):
raise TypeError("type_enum is not a valid types_pb2.DataType: %s" %
type_enum)
self._type_enum = type_enum
__slots__ = ()
@property
def _is_ref_dtype(self):
@ -117,63 +96,11 @@ class DType(object):
else:
return self
@property
def is_numpy_compatible(self):
return self._type_enum not in _NUMPY_INCOMPATIBLE
@property
def as_numpy_dtype(self):
"""Returns a `numpy.dtype` based on this `DType`."""
return _TF_TO_NP[self._type_enum]
@property
def as_datatype_enum(self):
"""Returns a `types_pb2.DataType` enum value based on this `DType`."""
return self._type_enum
@property
def is_bool(self):
"""Returns whether this is a boolean data type."""
return self.base_dtype == bool
@property
def is_integer(self):
"""Returns whether this is a (non-quantized) integer type."""
return (self.is_numpy_compatible and not self.is_quantized and
np.issubdtype(self.as_numpy_dtype, np.integer))
@property
def is_floating(self):
"""Returns whether this is a (non-quantized, real) floating point type."""
return ((self.is_numpy_compatible and
np.issubdtype(self.as_numpy_dtype, np.floating)) or
self.base_dtype == bfloat16)
@property
def is_complex(self):
"""Returns whether this is a complex floating point type."""
return self.base_dtype in (complex64, complex128)
@property
def is_quantized(self):
"""Returns whether this is a quantized data type."""
return self.base_dtype in _QUANTIZED_DTYPES_NO_REF
@property
def is_unsigned(self):
"""Returns whether this type is unsigned.
Non-numeric, unordered, and quantized types are not considered unsigned, and
this function returns `False`.
Returns:
Whether a `DType` is unsigned.
"""
try:
return self.min == 0
except TypeError:
return False
@property
def min(self):
"""Returns the minimum representable value in this data type.
@ -275,29 +202,15 @@ class DType(object):
"""Returns True iff self != other."""
return not self.__eq__(other)
@property
def name(self):
"""Returns the string name for this `DType`."""
return _TYPE_TO_STRING[self._type_enum]
def __str__(self):
return "<dtype: %r>" % self.name
def __repr__(self):
return "tf." + self.name
def __hash__(self):
return self._type_enum
# "If a class that overrides __eq__() needs to retain the implementation
# of __hash__() from a parent class, the interpreter must be told this
# explicitly by setting __hash__ = <ParentClass>.__hash__."
# TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++.
__hash__ = _dtypes.DType.__hash__
def __reduce__(self):
return as_dtype, (self.name,)
@property
def size(self):
if (self._type_enum == types_pb2.DT_VARIANT or
self._type_enum == types_pb2.DT_RESOURCE):
return 1
return np.dtype(self.as_numpy_dtype).itemsize
# pylint: enable=slots-on-old-class
# Define data type range of numpy dtype
@ -395,11 +308,6 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
_NUMPY_INCOMPATIBLE = frozenset([
types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
types_pb2.DT_RESOURCE_REF
])
# Maintain an intern table so that we don't have to create a large
# number of small objects.
_INTERN_TABLE = {

View File

@ -1,7 +1,8 @@
path: "tensorflow.DType"
tf_class {
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "as_datatype_enum"
mtype: "<type \'property\'>"
@ -68,7 +69,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"

View File

@ -1,7 +1,8 @@
path: "tensorflow.dtypes.DType"
tf_class {
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "as_datatype_enum"
mtype: "<type \'property\'>"
@ -68,7 +69,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"

View File

@ -2,7 +2,7 @@ path: "tensorflow.dtypes"
tf_module {
member {
name: "DType"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "QUANTIZED_DTYPES"

View File

@ -38,7 +38,7 @@ tf_module {
}
member {
name: "DType"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "DeviceSpec"

View File

@ -1,7 +1,8 @@
path: "tensorflow.DType"
tf_class {
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "as_datatype_enum"
mtype: "<type \'property\'>"
@ -68,7 +69,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"

View File

@ -1,7 +1,8 @@
path: "tensorflow.dtypes.DType"
tf_class {
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
member {
name: "as_datatype_enum"
mtype: "<type \'property\'>"
@ -68,7 +69,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"

View File

@ -2,7 +2,7 @@ path: "tensorflow.dtypes"
tf_module {
member {
name: "DType"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "QUANTIZED_DTYPES"

View File

@ -10,7 +10,7 @@ tf_module {
}
member {
name: "DType"
mtype: "<type \'type\'>"
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
}
member {
name: "DeviceSpec"

View File

@ -79,8 +79,16 @@ if sys.version_info.major == 3:
return (member == 'with_traceback' or member in ('name', 'value') and
isinstance(cls, type) and issubclass(cls, enum.Enum))
else:
_NORMALIZE_TYPE = {"<class 'abc.ABCMeta'>": "<type 'type'>"}
_NORMALIZE_ISINSTANCE = {}
_NORMALIZE_TYPE = {
"<class 'abc.ABCMeta'>":
"<type 'type'>",
"<class 'pybind11_type'>":
"<class 'pybind11_builtins.pybind11_type'>",
}
_NORMALIZE_ISINSTANCE = {
"<class 'pybind11_object'>":
"<class 'pybind11_builtins.pybind11_object'>",
}
def _SkipMember(cls, member): # pylint: disable=unused-argument
return False

View File

@ -56,7 +56,15 @@ tensorflow::EventsWriter::Close
[py_func_lib] # py_func
tensorflow::InitializePyTrampoline
[framework_internal_impl] # op_def_registry
[framework_internal_impl] # op_def_registry, dtypes
tensorflow::BaseType
tensorflow::DataTypeString
tensorflow::DataTypeIsComplex
tensorflow::DataTypeIsFloating
tensorflow::DataTypeIsInteger
tensorflow::DataTypeIsQuantized
tensorflow::DataTypeIsUnsigned
tensorflow::DataTypeSize
tensorflow::OpRegistry::Global
tensorflow::OpRegistry::LookUpOpDef
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
@ -72,7 +80,8 @@ tensorflow::DeviceFactory::AddDevices
tensorflow::SessionOptions::SessionOptions
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
[protos_all] # device_lib
[protos_all] # device_lib, dtypes
tensorflow::DataType_IsValid
tensorflow::ConfigProto::ConfigProto
tensorflow::ConfigProto::ParseFromString
tensorflow::DeviceAttributes::SerializeToString