Automated rollback of commit 74a6cca5d8

PiperOrigin-RevId: 230873931
This commit is contained in:
Sergei Lebedev 2019-01-25 02:52:49 -08:00 committed by TensorFlower Gardener
parent da2c471248
commit bdca83b71c

View File

@ -535,29 +535,47 @@ _np_qint32 = np.dtype([("qint32", np.int32, 1)])
np_resource = np.dtype([("resource", np.ubyte, 1)]) np_resource = np.dtype([("resource", np.ubyte, 1)])
# Standard mappings between types_pb2.DataType values and numpy.dtypes. # Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = frozenset([ _NP_TO_TF = {
(np.float16, float16), np.float16: float16,
(np.float32, float32), np.float32: float32,
(np.float64, float64), np.float64: float64,
(np.int32, int32), np.int32: int32,
(np.int64, int64), np.int64: int64,
(np.uint8, uint8), np.uint8: uint8,
(np.uint16, uint16), np.uint16: uint16,
(np.uint32, uint32), np.uint32: uint32,
(np.uint64, uint64), np.uint64: uint64,
(np.int16, int16), np.int16: int16,
(np.int8, int8), np.int8: int8,
(np.complex64, complex64), np.complex64: complex64,
(np.complex128, complex128), np.complex128: complex128,
(np.object_, string), np.object_: string,
(np.bool_, bool), np.string_: string,
(_np_qint8, qint8), np.unicode_: string,
(_np_quint8, quint8), np.bool_: bool,
(_np_qint16, qint16), _np_qint8: qint8,
(_np_quint16, quint16), _np_quint8: quint8,
(_np_qint32, qint32), _np_qint16: qint16,
(_np_bfloat16, bfloat16), _np_quint16: quint16,
]) _np_qint32: qint32,
_np_bfloat16: bfloat16,
}
# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simples aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
for pdt in [
np.intc,
np.uintc,
np.int_,
np.uint,
np.longlong,
np.ulonglong,
]:
if pdt not in _NP_TO_TF:
_NP_TO_TF[pdt] = next(
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
_TF_TO_NP = { _TF_TO_NP = {
types_pb2.DT_HALF: types_pb2.DT_HALF:
np.float16, np.float16,
@ -664,6 +682,20 @@ _PYTHON_TO_TF = {
builtins.object: string builtins.object: string
} }
_ANY_TO_TF = {}
_ANY_TO_TF.update(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)
# Ensure no collisions.
assert len(_ANY_TO_TF) == sum(len(d) for d in [
_INTERN_TABLE,
_STRING_TO_TF,
_PYTHON_TO_TF,
_NP_TO_TF
])
@tf_export("dtypes.as_dtype", "as_dtype") @tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value): def as_dtype(type_value):
@ -684,37 +716,16 @@ def as_dtype(type_value):
if isinstance(type_value, DType): if isinstance(type_value, DType):
return type_value return type_value
try:
return _INTERN_TABLE[type_value]
except KeyError:
pass
try:
return _STRING_TO_TF[type_value]
except KeyError:
pass
try:
return _PYTHON_TO_TF[type_value]
except KeyError:
pass
if isinstance(type_value, np.dtype): if isinstance(type_value, np.dtype):
# The numpy dtype for strings is variable length. We can not compare try:
# dtype with a single constant (np.string does not exist) to decide return _NP_TO_TF[type_value.type]
# dtype is a "string" type. We need to compare the dtype.type to be except KeyError:
# sure it's a string type. pass
if type_value.type == np.string_ or type_value.type == np.unicode_:
return string
if isinstance(type_value, (type, np.dtype)): try:
for key, val in _NP_TO_TF: return _ANY_TO_TF[type_value]
try: except KeyError:
if key == type_value: pass
return val
except TypeError as e:
raise TypeError("Cannot convert {} to a dtype. {}".format(
type_value, e))
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
raise TypeError(
"Cannot convert value %r to a TensorFlow DType." % type_value)