Automated rollback of commit 74a6cca5d8

PiperOrigin-RevId: 230873931
This commit is contained in:
Sergei Lebedev 2019-01-25 02:52:49 -08:00 committed by TensorFlower Gardener
parent da2c471248
commit bdca83b71c

View File

@ -535,29 +535,47 @@ _np_qint32 = np.dtype([("qint32", np.int32, 1)])
np_resource = np.dtype([("resource", np.ubyte, 1)])
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = frozenset([
(np.float16, float16),
(np.float32, float32),
(np.float64, float64),
(np.int32, int32),
(np.int64, int64),
(np.uint8, uint8),
(np.uint16, uint16),
(np.uint32, uint32),
(np.uint64, uint64),
(np.int16, int16),
(np.int8, int8),
(np.complex64, complex64),
(np.complex128, complex128),
(np.object_, string),
(np.bool_, bool),
(_np_qint8, qint8),
(_np_quint8, quint8),
(_np_qint16, qint16),
(_np_quint16, quint16),
(_np_qint32, qint32),
(_np_bfloat16, bfloat16),
])
_NP_TO_TF = {
np.float16: float16,
np.float32: float32,
np.float64: float64,
np.int32: int32,
np.int64: int64,
np.uint8: uint8,
np.uint16: uint16,
np.uint32: uint32,
np.uint64: uint64,
np.int16: int16,
np.int8: int8,
np.complex64: complex64,
np.complex128: complex128,
np.object_: string,
np.string_: string,
np.unicode_: string,
np.bool_: bool,
_np_qint8: qint8,
_np_quint8: quint8,
_np_qint16: qint16,
_np_quint16: quint16,
_np_qint32: qint32,
_np_bfloat16: bfloat16,
}
# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simples aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
for pdt in [
np.intc,
np.uintc,
np.int_,
np.uint,
np.longlong,
np.ulonglong,
]:
if pdt not in _NP_TO_TF:
_NP_TO_TF[pdt] = next(
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
_TF_TO_NP = {
types_pb2.DT_HALF:
np.float16,
@ -664,6 +682,20 @@ _PYTHON_TO_TF = {
builtins.object: string
}
_ANY_TO_TF = {}
_ANY_TO_TF.update(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)
# Ensure no collisions.
assert len(_ANY_TO_TF) == sum(len(d) for d in [
_INTERN_TABLE,
_STRING_TO_TF,
_PYTHON_TO_TF,
_NP_TO_TF
])
@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
@ -684,37 +716,16 @@ def as_dtype(type_value):
if isinstance(type_value, DType):
return type_value
try:
return _INTERN_TABLE[type_value]
except KeyError:
pass
try:
return _STRING_TO_TF[type_value]
except KeyError:
pass
try:
return _PYTHON_TO_TF[type_value]
except KeyError:
pass
if isinstance(type_value, np.dtype):
# The numpy dtype for strings is variable length. We can not compare
# dtype with a single constant (np.string does not exist) to decide
# dtype is a "string" type. We need to compare the dtype.type to be
# sure it's a string type.
if type_value.type == np.string_ or type_value.type == np.unicode_:
return string
try:
return _NP_TO_TF[type_value.type]
except KeyError:
pass
if isinstance(type_value, (type, np.dtype)):
for key, val in _NP_TO_TF:
try:
if key == type_value:
return val
except TypeError as e:
raise TypeError("Cannot convert {} to a dtype. {}".format(
type_value, e))
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
try:
return _ANY_TO_TF[type_value]
except KeyError:
pass
raise TypeError(
"Cannot convert value %r to a TensorFlow DType." % type_value)