dtypes.DType is now a wrapper around tensorflow::DataType
Note that some DType methods cannot yet be pushed to C++ and will be addressed in a followup. Quick benchmark results: >>> import tensorflow as tf >>> dtype = tf.int32 Before: >>> %timeit dtype.is_integer 1000000 loops, best of 3: 1.22 ?s per loop >>> %timeit dtype.is_floating 1000000 loops, best of 3: 1.2 ?s per loop After: >>> %timeit dtype.is_integer 10000000 loops, best of 3: 179 ns per loop >>> %timeit dtype.is_floating 10000000 loops, best of 3: 185 ns per loop PiperOrigin-RevId: 282730366 Change-Id: Ib1bc960de647cee0f97ab522215ace88c9050962
This commit is contained in:
parent
481366eab2
commit
6fa83d70b7
@ -1140,11 +1140,24 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_dtypes",
|
||||
srcs = ["framework/dtypes.cc"],
|
||||
module_name = "_dtypes",
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/eigen3",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dtypes",
|
||||
srcs = ["framework/dtypes.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_dtypes",
|
||||
":pywrap_tensorflow",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
|
157
tensorflow/python/framework/dtypes.cc
Normal file
157
tensorflow/python/framework/dtypes.cc
Normal file
@ -0,0 +1,157 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/detail/common.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
inline int DataTypeId(tensorflow::DataType dt) { return static_cast<int>(dt); }
|
||||
|
||||
// A variant of tensorflow::DataTypeString which uses fixed-width names
|
||||
// for floating point data types. This behavior is compatible with that of
|
||||
// existing pure Python DType.
|
||||
const std::string DataTypeStringCompat(tensorflow::DataType dt) {
|
||||
switch (dt) {
|
||||
case tensorflow::DataType::DT_HALF:
|
||||
return "float16";
|
||||
case tensorflow::DataType::DT_HALF_REF:
|
||||
return "float16_ref";
|
||||
case tensorflow::DataType::DT_FLOAT:
|
||||
return "float32";
|
||||
case tensorflow::DataType::DT_FLOAT_REF:
|
||||
return "float32_ref";
|
||||
case tensorflow::DataType::DT_DOUBLE:
|
||||
return "float64";
|
||||
case tensorflow::DataType::DT_DOUBLE_REF:
|
||||
return "float64_ref";
|
||||
default:
|
||||
return tensorflow::DataTypeString(dt);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr DataTypeSet kNumPyIncompatibleTypes =
|
||||
ToSet(DataType::DT_RESOURCE) | ToSet(DataType::DT_VARIANT);
|
||||
|
||||
inline bool DataTypeIsNumPyCompatible(DataType dt) {
|
||||
return !kNumPyIncompatibleTypes.Contains(dt);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_dtypes, m) {
|
||||
py::class_<tensorflow::DataType>(m, "DType")
|
||||
.def(py::init([](py::object obj) {
|
||||
auto id = static_cast<int>(py::int_(obj));
|
||||
if (tensorflow::DataType_IsValid(id) &&
|
||||
id != static_cast<int>(tensorflow::DT_INVALID)) {
|
||||
return static_cast<tensorflow::DataType>(id);
|
||||
}
|
||||
throw py::type_error(
|
||||
py::str("%d does not correspond to a valid tensorflow::DataType")
|
||||
.format(id));
|
||||
}))
|
||||
// For compatibility with pure-Python DType.
|
||||
.def_property_readonly("_type_enum", &DataTypeId)
|
||||
.def_property_readonly(
|
||||
"as_datatype_enum", &DataTypeId,
|
||||
"Returns a `types_pb2.DataType` enum value based on this data type.")
|
||||
|
||||
.def_property_readonly("name",
|
||||
[](tensorflow::DataType self) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
return py::bytes(DataTypeStringCompat(self));
|
||||
#else
|
||||
return DataTypeStringCompat(self);
|
||||
#endif
|
||||
})
|
||||
.def_property_readonly(
|
||||
"size",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeSize(tensorflow::BaseType(self));
|
||||
})
|
||||
|
||||
.def("__repr__",
|
||||
[](tensorflow::DataType self) {
|
||||
return py::str("tf.{}").format(DataTypeStringCompat(self));
|
||||
})
|
||||
.def("__str__",
|
||||
[](tensorflow::DataType self) {
|
||||
return py::str("<dtype: {!r}>")
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
.format(py::bytes(DataTypeStringCompat(self)));
|
||||
#else
|
||||
.format(DataTypeStringCompat(self));
|
||||
#endif
|
||||
})
|
||||
.def("__hash__", &DataTypeId)
|
||||
|
||||
.def_property_readonly(
|
||||
"is_numpy_compatible",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsNumPyCompatible(
|
||||
tensorflow::BaseType(self));
|
||||
},
|
||||
"Returns whether this data type has a compatible NumPy data type.")
|
||||
|
||||
.def_property_readonly(
|
||||
"is_bool",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::BaseType(self) == tensorflow::DT_BOOL;
|
||||
},
|
||||
"Returns whether this is a boolean data type.")
|
||||
.def_property_readonly(
|
||||
"is_complex",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsComplex(tensorflow::BaseType(self));
|
||||
},
|
||||
"Returns whether this is a complex floating point type.")
|
||||
.def_property_readonly(
|
||||
"is_floating",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsFloating(tensorflow::BaseType(self));
|
||||
},
|
||||
"Returns whether this is a (non-quantized, real) floating point "
|
||||
"type.")
|
||||
.def_property_readonly(
|
||||
"is_integer",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsInteger(tensorflow::BaseType(self));
|
||||
},
|
||||
"Returns whether this is a (non-quantized) integer type.")
|
||||
.def_property_readonly(
|
||||
"is_quantized",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsQuantized(tensorflow::BaseType(self));
|
||||
},
|
||||
"Returns whether this is a quantized data type.")
|
||||
.def_property_readonly(
|
||||
"is_unsigned",
|
||||
[](tensorflow::DataType self) {
|
||||
return tensorflow::DataTypeIsUnsigned(tensorflow::BaseType(self));
|
||||
},
|
||||
R"doc(Returns whether this type is unsigned.
|
||||
|
||||
Non-numeric, unordered, and quantized types are not considered unsigned, and
|
||||
this function returns `False`.)doc");
|
||||
}
|
@ -21,14 +21,16 @@ import numpy as np
|
||||
from six.moves import builtins
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python import _dtypes
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
|
||||
|
||||
|
||||
# pylint: disable=slots-on-old-class
|
||||
@tf_export("dtypes.DType", "DType")
|
||||
class DType(object):
|
||||
class DType(_dtypes.DType):
|
||||
"""Represents the type of the elements in a `Tensor`.
|
||||
|
||||
The following `DType` objects are defined:
|
||||
@ -60,30 +62,7 @@ class DType(object):
|
||||
The `tf.as_dtype()` function converts numpy types and string type
|
||||
names to a `DType` object.
|
||||
"""
|
||||
__slots__ = ["_type_enum"]
|
||||
|
||||
def __init__(self, type_enum):
|
||||
"""Creates a new `DataType`.
|
||||
|
||||
NOTE(mrry): In normal circumstances, you should not need to
|
||||
construct a `DataType` object directly. Instead, use the
|
||||
`tf.as_dtype()` function.
|
||||
|
||||
Args:
|
||||
type_enum: A `types_pb2.DataType` enum value.
|
||||
|
||||
Raises:
|
||||
TypeError: If `type_enum` is not a value `types_pb2.DataType`.
|
||||
|
||||
"""
|
||||
# TODO(mrry): Make the necessary changes (using __new__) to ensure
|
||||
# that calling this returns one of the interned values.
|
||||
type_enum = int(type_enum)
|
||||
if (type_enum not in types_pb2.DataType.values() or
|
||||
type_enum == types_pb2.DT_INVALID):
|
||||
raise TypeError("type_enum is not a valid types_pb2.DataType: %s" %
|
||||
type_enum)
|
||||
self._type_enum = type_enum
|
||||
__slots__ = ()
|
||||
|
||||
@property
|
||||
def _is_ref_dtype(self):
|
||||
@ -117,63 +96,11 @@ class DType(object):
|
||||
else:
|
||||
return self
|
||||
|
||||
@property
|
||||
def is_numpy_compatible(self):
|
||||
return self._type_enum not in _NUMPY_INCOMPATIBLE
|
||||
|
||||
@property
|
||||
def as_numpy_dtype(self):
|
||||
"""Returns a `numpy.dtype` based on this `DType`."""
|
||||
return _TF_TO_NP[self._type_enum]
|
||||
|
||||
@property
|
||||
def as_datatype_enum(self):
|
||||
"""Returns a `types_pb2.DataType` enum value based on this `DType`."""
|
||||
return self._type_enum
|
||||
|
||||
@property
|
||||
def is_bool(self):
|
||||
"""Returns whether this is a boolean data type."""
|
||||
return self.base_dtype == bool
|
||||
|
||||
@property
|
||||
def is_integer(self):
|
||||
"""Returns whether this is a (non-quantized) integer type."""
|
||||
return (self.is_numpy_compatible and not self.is_quantized and
|
||||
np.issubdtype(self.as_numpy_dtype, np.integer))
|
||||
|
||||
@property
|
||||
def is_floating(self):
|
||||
"""Returns whether this is a (non-quantized, real) floating point type."""
|
||||
return ((self.is_numpy_compatible and
|
||||
np.issubdtype(self.as_numpy_dtype, np.floating)) or
|
||||
self.base_dtype == bfloat16)
|
||||
|
||||
@property
|
||||
def is_complex(self):
|
||||
"""Returns whether this is a complex floating point type."""
|
||||
return self.base_dtype in (complex64, complex128)
|
||||
|
||||
@property
|
||||
def is_quantized(self):
|
||||
"""Returns whether this is a quantized data type."""
|
||||
return self.base_dtype in _QUANTIZED_DTYPES_NO_REF
|
||||
|
||||
@property
|
||||
def is_unsigned(self):
|
||||
"""Returns whether this type is unsigned.
|
||||
|
||||
Non-numeric, unordered, and quantized types are not considered unsigned, and
|
||||
this function returns `False`.
|
||||
|
||||
Returns:
|
||||
Whether a `DType` is unsigned.
|
||||
"""
|
||||
try:
|
||||
return self.min == 0
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def min(self):
|
||||
"""Returns the minimum representable value in this data type.
|
||||
@ -275,29 +202,15 @@ class DType(object):
|
||||
"""Returns True iff self != other."""
|
||||
return not self.__eq__(other)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Returns the string name for this `DType`."""
|
||||
return _TYPE_TO_STRING[self._type_enum]
|
||||
|
||||
def __str__(self):
|
||||
return "<dtype: %r>" % self.name
|
||||
|
||||
def __repr__(self):
|
||||
return "tf." + self.name
|
||||
|
||||
def __hash__(self):
|
||||
return self._type_enum
|
||||
# "If a class that overrides __eq__() needs to retain the implementation
|
||||
# of __hash__() from a parent class, the interpreter must be told this
|
||||
# explicitly by setting __hash__ = <ParentClass>.__hash__."
|
||||
# TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++.
|
||||
__hash__ = _dtypes.DType.__hash__
|
||||
|
||||
def __reduce__(self):
|
||||
return as_dtype, (self.name,)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
if (self._type_enum == types_pb2.DT_VARIANT or
|
||||
self._type_enum == types_pb2.DT_RESOURCE):
|
||||
return 1
|
||||
return np.dtype(self.as_numpy_dtype).itemsize
|
||||
# pylint: enable=slots-on-old-class
|
||||
|
||||
|
||||
# Define data type range of numpy dtype
|
||||
@ -395,11 +308,6 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF)
|
||||
qint32_ref = DType(types_pb2.DT_QINT32_REF)
|
||||
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
|
||||
|
||||
_NUMPY_INCOMPATIBLE = frozenset([
|
||||
types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
|
||||
types_pb2.DT_RESOURCE_REF
|
||||
])
|
||||
|
||||
# Maintain an intern table so that we don't have to create a large
|
||||
# number of small objects.
|
||||
_INTERN_TABLE = {
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.DType"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member {
|
||||
name: "as_datatype_enum"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -68,7 +69,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.dtypes.DType"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member {
|
||||
name: "as_datatype_enum"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -68,7 +69,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.dtypes"
|
||||
tf_module {
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "QUANTIZED_DTYPES"
|
||||
|
@ -38,7 +38,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DeviceSpec"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.DType"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member {
|
||||
name: "as_datatype_enum"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -68,7 +69,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.dtypes.DType"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._dtypes.DType\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member {
|
||||
name: "as_datatype_enum"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -68,7 +69,6 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.dtypes"
|
||||
tf_module {
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "QUANTIZED_DTYPES"
|
||||
|
@ -10,7 +10,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "DType"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "DeviceSpec"
|
||||
|
@ -79,8 +79,16 @@ if sys.version_info.major == 3:
|
||||
return (member == 'with_traceback' or member in ('name', 'value') and
|
||||
isinstance(cls, type) and issubclass(cls, enum.Enum))
|
||||
else:
|
||||
_NORMALIZE_TYPE = {"<class 'abc.ABCMeta'>": "<type 'type'>"}
|
||||
_NORMALIZE_ISINSTANCE = {}
|
||||
_NORMALIZE_TYPE = {
|
||||
"<class 'abc.ABCMeta'>":
|
||||
"<type 'type'>",
|
||||
"<class 'pybind11_type'>":
|
||||
"<class 'pybind11_builtins.pybind11_type'>",
|
||||
}
|
||||
_NORMALIZE_ISINSTANCE = {
|
||||
"<class 'pybind11_object'>":
|
||||
"<class 'pybind11_builtins.pybind11_object'>",
|
||||
}
|
||||
|
||||
def _SkipMember(cls, member): # pylint: disable=unused-argument
|
||||
return False
|
||||
|
@ -56,7 +56,15 @@ tensorflow::EventsWriter::Close
|
||||
[py_func_lib] # py_func
|
||||
tensorflow::InitializePyTrampoline
|
||||
|
||||
[framework_internal_impl] # op_def_registry
|
||||
[framework_internal_impl] # op_def_registry, dtypes
|
||||
tensorflow::BaseType
|
||||
tensorflow::DataTypeString
|
||||
tensorflow::DataTypeIsComplex
|
||||
tensorflow::DataTypeIsFloating
|
||||
tensorflow::DataTypeIsInteger
|
||||
tensorflow::DataTypeIsQuantized
|
||||
tensorflow::DataTypeIsUnsigned
|
||||
tensorflow::DataTypeSize
|
||||
tensorflow::OpRegistry::Global
|
||||
tensorflow::OpRegistry::LookUpOpDef
|
||||
tensorflow::RemoveNonDeprecationDescriptionsFromOpDef
|
||||
@ -72,7 +80,8 @@ tensorflow::DeviceFactory::AddDevices
|
||||
tensorflow::SessionOptions::SessionOptions
|
||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef
|
||||
|
||||
[protos_all] # device_lib
|
||||
[protos_all] # device_lib, dtypes
|
||||
tensorflow::DataType_IsValid
|
||||
tensorflow::ConfigProto::ConfigProto
|
||||
tensorflow::ConfigProto::ParseFromString
|
||||
tensorflow::DeviceAttributes::SerializeToString
|
||||
|
Loading…
Reference in New Issue
Block a user