parent
da2c471248
commit
bdca83b71c
@ -535,29 +535,47 @@ _np_qint32 = np.dtype([("qint32", np.int32, 1)])
|
|||||||
np_resource = np.dtype([("resource", np.ubyte, 1)])
|
np_resource = np.dtype([("resource", np.ubyte, 1)])
|
||||||
|
|
||||||
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
||||||
_NP_TO_TF = frozenset([
|
_NP_TO_TF = {
|
||||||
(np.float16, float16),
|
np.float16: float16,
|
||||||
(np.float32, float32),
|
np.float32: float32,
|
||||||
(np.float64, float64),
|
np.float64: float64,
|
||||||
(np.int32, int32),
|
np.int32: int32,
|
||||||
(np.int64, int64),
|
np.int64: int64,
|
||||||
(np.uint8, uint8),
|
np.uint8: uint8,
|
||||||
(np.uint16, uint16),
|
np.uint16: uint16,
|
||||||
(np.uint32, uint32),
|
np.uint32: uint32,
|
||||||
(np.uint64, uint64),
|
np.uint64: uint64,
|
||||||
(np.int16, int16),
|
np.int16: int16,
|
||||||
(np.int8, int8),
|
np.int8: int8,
|
||||||
(np.complex64, complex64),
|
np.complex64: complex64,
|
||||||
(np.complex128, complex128),
|
np.complex128: complex128,
|
||||||
(np.object_, string),
|
np.object_: string,
|
||||||
(np.bool_, bool),
|
np.string_: string,
|
||||||
(_np_qint8, qint8),
|
np.unicode_: string,
|
||||||
(_np_quint8, quint8),
|
np.bool_: bool,
|
||||||
(_np_qint16, qint16),
|
_np_qint8: qint8,
|
||||||
(_np_quint16, quint16),
|
_np_quint8: quint8,
|
||||||
(_np_qint32, qint32),
|
_np_qint16: qint16,
|
||||||
(_np_bfloat16, bfloat16),
|
_np_quint16: quint16,
|
||||||
])
|
_np_qint32: qint32,
|
||||||
|
_np_bfloat16: bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Map (some) NumPy platform dtypes to TF ones using their fixed-width
|
||||||
|
# synonyms. Note that platform dtypes are not always simples aliases,
|
||||||
|
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
|
||||||
|
for pdt in [
|
||||||
|
np.intc,
|
||||||
|
np.uintc,
|
||||||
|
np.int_,
|
||||||
|
np.uint,
|
||||||
|
np.longlong,
|
||||||
|
np.ulonglong,
|
||||||
|
]:
|
||||||
|
if pdt not in _NP_TO_TF:
|
||||||
|
_NP_TO_TF[pdt] = next(
|
||||||
|
_NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
|
||||||
|
|
||||||
_TF_TO_NP = {
|
_TF_TO_NP = {
|
||||||
types_pb2.DT_HALF:
|
types_pb2.DT_HALF:
|
||||||
np.float16,
|
np.float16,
|
||||||
@ -664,6 +682,20 @@ _PYTHON_TO_TF = {
|
|||||||
builtins.object: string
|
builtins.object: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_ANY_TO_TF = {}
|
||||||
|
_ANY_TO_TF.update(_INTERN_TABLE)
|
||||||
|
_ANY_TO_TF.update(_STRING_TO_TF)
|
||||||
|
_ANY_TO_TF.update(_PYTHON_TO_TF)
|
||||||
|
_ANY_TO_TF.update(_NP_TO_TF)
|
||||||
|
|
||||||
|
# Ensure no collisions.
|
||||||
|
assert len(_ANY_TO_TF) == sum(len(d) for d in [
|
||||||
|
_INTERN_TABLE,
|
||||||
|
_STRING_TO_TF,
|
||||||
|
_PYTHON_TO_TF,
|
||||||
|
_NP_TO_TF
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
@tf_export("dtypes.as_dtype", "as_dtype")
|
@tf_export("dtypes.as_dtype", "as_dtype")
|
||||||
def as_dtype(type_value):
|
def as_dtype(type_value):
|
||||||
@ -684,37 +716,16 @@ def as_dtype(type_value):
|
|||||||
if isinstance(type_value, DType):
|
if isinstance(type_value, DType):
|
||||||
return type_value
|
return type_value
|
||||||
|
|
||||||
try:
|
|
||||||
return _INTERN_TABLE[type_value]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
return _STRING_TO_TF[type_value]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
return _PYTHON_TO_TF[type_value]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if isinstance(type_value, np.dtype):
|
if isinstance(type_value, np.dtype):
|
||||||
# The numpy dtype for strings is variable length. We can not compare
|
try:
|
||||||
# dtype with a single constant (np.string does not exist) to decide
|
return _NP_TO_TF[type_value.type]
|
||||||
# dtype is a "string" type. We need to compare the dtype.type to be
|
except KeyError:
|
||||||
# sure it's a string type.
|
pass
|
||||||
if type_value.type == np.string_ or type_value.type == np.unicode_:
|
|
||||||
return string
|
|
||||||
|
|
||||||
if isinstance(type_value, (type, np.dtype)):
|
try:
|
||||||
for key, val in _NP_TO_TF:
|
return _ANY_TO_TF[type_value]
|
||||||
try:
|
except KeyError:
|
||||||
if key == type_value:
|
pass
|
||||||
return val
|
|
||||||
except TypeError as e:
|
|
||||||
raise TypeError("Cannot convert {} to a dtype. {}".format(
|
|
||||||
type_value, e))
|
|
||||||
|
|
||||||
raise TypeError("Cannot convert value %r to a TensorFlow DType." % type_value)
|
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
"Cannot convert value %r to a TensorFlow DType." % type_value)
|
||||||
|
Loading…
Reference in New Issue
Block a user