From bdca83b71c0f47628321bafc29056fb2467144db Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 25 Jan 2019 02:52:49 -0800 Subject: [PATCH] Automated rollback of commit 74a6cca5d867d37e79ec9d780f2c57b926f07a80 PiperOrigin-RevId: 230873931 --- tensorflow/python/framework/dtypes.py | 119 ++++++++++++++------------ 1 file changed, 65 insertions(+), 54 deletions(-) diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 574d82ea93e..9d643e041c6 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -535,29 +535,47 @@ _np_qint32 = np.dtype([("qint32", np.int32, 1)]) np_resource = np.dtype([("resource", np.ubyte, 1)]) # Standard mappings between types_pb2.DataType values and numpy.dtypes. -_NP_TO_TF = frozenset([ - (np.float16, float16), - (np.float32, float32), - (np.float64, float64), - (np.int32, int32), - (np.int64, int64), - (np.uint8, uint8), - (np.uint16, uint16), - (np.uint32, uint32), - (np.uint64, uint64), - (np.int16, int16), - (np.int8, int8), - (np.complex64, complex64), - (np.complex128, complex128), - (np.object_, string), - (np.bool_, bool), - (_np_qint8, qint8), - (_np_quint8, quint8), - (_np_qint16, qint16), - (_np_quint16, quint16), - (_np_qint32, qint32), - (_np_bfloat16, bfloat16), -]) +_NP_TO_TF = { + np.float16: float16, + np.float32: float32, + np.float64: float64, + np.int32: int32, + np.int64: int64, + np.uint8: uint8, + np.uint16: uint16, + np.uint32: uint32, + np.uint64: uint64, + np.int16: int16, + np.int8: int8, + np.complex64: complex64, + np.complex128: complex128, + np.object_: string, + np.string_: string, + np.unicode_: string, + np.bool_: bool, + _np_qint8: qint8, + _np_quint8: quint8, + _np_qint16: qint16, + _np_quint16: quint16, + _np_qint32: qint32, + _np_bfloat16: bfloat16, +} + +# Map (some) NumPy platform dtypes to TF ones using their fixed-width +# synonyms. Note that platform dtypes are not always simples aliases, +# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799. +for pdt in [ + np.intc, + np.uintc, + np.int_, + np.uint, + np.longlong, + np.ulonglong, +]: + if pdt not in _NP_TO_TF: + _NP_TO_TF[pdt] = next( + _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype) + _TF_TO_NP = { types_pb2.DT_HALF: np.float16, @@ -664,6 +682,20 @@ _PYTHON_TO_TF = { builtins.object: string } +_ANY_TO_TF = {} +_ANY_TO_TF.update(_INTERN_TABLE) +_ANY_TO_TF.update(_STRING_TO_TF) +_ANY_TO_TF.update(_PYTHON_TO_TF) +_ANY_TO_TF.update(_NP_TO_TF) + +# Ensure no collisions. +assert len(_ANY_TO_TF) == sum(len(d) for d in [ + _INTERN_TABLE, + _STRING_TO_TF, + _PYTHON_TO_TF, + _NP_TO_TF +]) + @tf_export("dtypes.as_dtype", "as_dtype") def as_dtype(type_value): @@ -684,37 +716,16 @@ def as_dtype(type_value): if isinstance(type_value, DType): return type_value - try: - return _INTERN_TABLE[type_value] - except KeyError: - pass - - try: - return _STRING_TO_TF[type_value] - except KeyError: - pass - - try: - return _PYTHON_TO_TF[type_value] - except KeyError: - pass - if isinstance(type_value, np.dtype): - # The numpy dtype for strings is variable length. We can not compare - # dtype with a single constant (np.string does not exist) to decide - # dtype is a "string" type. We need to compare the dtype.type to be - # sure it's a string type. - if type_value.type == np.string_ or type_value.type == np.unicode_: - return string + try: + return _NP_TO_TF[type_value.type] + except KeyError: + pass - if isinstance(type_value, (type, np.dtype)): - for key, val in _NP_TO_TF: - try: - if key == type_value: - return val - except TypeError as e: - raise TypeError("Cannot convert {} to a dtype. {}".format( - type_value, e)) - - raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value) + try: + return _ANY_TO_TF[type_value] + except KeyError: + pass + raise TypeError( + "Cannot convert value %r to a TensorFlow DType." % type_value)