Removed a linear scan in dtypes.as_dtype
PiperOrigin-RevId: 229152423
This commit is contained in:
parent
267263629f
commit
74a6cca5d8
@ -535,29 +535,31 @@ _np_qint32 = np.dtype([("qint32", np.int32, 1)])
|
||||
np_resource = np.dtype([("resource", np.ubyte, 1)])
|
||||
|
||||
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
||||
_NP_TO_TF = frozenset([
|
||||
(np.float16, float16),
|
||||
(np.float32, float32),
|
||||
(np.float64, float64),
|
||||
(np.int32, int32),
|
||||
(np.int64, int64),
|
||||
(np.uint8, uint8),
|
||||
(np.uint16, uint16),
|
||||
(np.uint32, uint32),
|
||||
(np.uint64, uint64),
|
||||
(np.int16, int16),
|
||||
(np.int8, int8),
|
||||
(np.complex64, complex64),
|
||||
(np.complex128, complex128),
|
||||
(np.object_, string),
|
||||
(np.bool_, bool),
|
||||
(_np_qint8, qint8),
|
||||
(_np_quint8, quint8),
|
||||
(_np_qint16, qint16),
|
||||
(_np_quint16, quint16),
|
||||
(_np_qint32, qint32),
|
||||
(_np_bfloat16, bfloat16),
|
||||
])
|
||||
_NP_TO_TF = {
|
||||
np.float16: float16,
|
||||
np.float32: float32,
|
||||
np.float64: float64,
|
||||
np.int32: int32,
|
||||
np.int64: int64,
|
||||
np.uint8: uint8,
|
||||
np.uint16: uint16,
|
||||
np.uint32: uint32,
|
||||
np.uint64: uint64,
|
||||
np.int16: int16,
|
||||
np.int8: int8,
|
||||
np.complex64: complex64,
|
||||
np.complex128: complex128,
|
||||
np.object_: string,
|
||||
np.string_: string,
|
||||
np.unicode_: string,
|
||||
np.bool_: bool,
|
||||
_np_qint8: qint8,
|
||||
_np_quint8: quint8,
|
||||
_np_qint16: qint16,
|
||||
_np_quint16: quint16,
|
||||
_np_qint32: qint32,
|
||||
_np_bfloat16: bfloat16,
|
||||
}
|
||||
_TF_TO_NP = {
|
||||
types_pb2.DT_HALF:
|
||||
np.float16,
|
||||
@ -664,6 +666,20 @@ _PYTHON_TO_TF = {
|
||||
builtins.object: string
|
||||
}
|
||||
|
||||
_ANY_TO_TF = {}
|
||||
_ANY_TO_TF.update(_INTERN_TABLE)
|
||||
_ANY_TO_TF.update(_STRING_TO_TF)
|
||||
_ANY_TO_TF.update(_PYTHON_TO_TF)
|
||||
_ANY_TO_TF.update(_NP_TO_TF)
|
||||
|
||||
# Ensure no collisions.
|
||||
assert len(_ANY_TO_TF) == sum(len(d) for d in [
|
||||
_INTERN_TABLE,
|
||||
_STRING_TO_TF,
|
||||
_PYTHON_TO_TF,
|
||||
_NP_TO_TF
|
||||
])
|
||||
|
||||
|
||||
@tf_export("dtypes.as_dtype", "as_dtype")
|
||||
def as_dtype(type_value):
|
||||
@ -684,36 +700,16 @@ def as_dtype(type_value):
|
||||
if isinstance(type_value, DType):
|
||||
return type_value
|
||||
|
||||
try:
|
||||
return _INTERN_TABLE[type_value]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return _STRING_TO_TF[type_value]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return _PYTHON_TO_TF[type_value]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if isinstance(type_value, np.dtype):
|
||||
# The numpy dtype for strings is variable length. We can not compare
|
||||
# dtype with a single constant (np.string does not exist) to decide
|
||||
# dtype is a "string" type. We need to compare the dtype.type to be
|
||||
# sure it's a string type.
|
||||
if type_value.type == np.string_ or type_value.type == np.unicode_:
|
||||
return string
|
||||
|
||||
if isinstance(type_value, (type, np.dtype)):
|
||||
for key, val in _NP_TO_TF:
|
||||
try:
|
||||
if key == type_value:
|
||||
return val
|
||||
except TypeError as e:
|
||||
raise TypeError("Cannot convert {} to a dtype. {}".format(
|
||||
type_value, e))
|
||||
return _NP_TO_TF[type_value.type]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
|
||||
try:
|
||||
return _ANY_TO_TF[type_value]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
raise TypeError(
|
||||
"Cannot convert value %r to a TensorFlow DType." % type_value)
|
||||
|
Loading…
Reference in New Issue
Block a user