[TF-numpy] Adds @np_doc to finfo
, result_type
and promote_types
.
PiperOrigin-RevId: 318944304 Change-Id: I7c37f744725c8da1e5ca58fd79640354bee05302
This commit is contained in:
parent
8562a13cca
commit
aed5e3ea00
@ -82,22 +82,6 @@ def _to_numpy_type(dtype):
|
||||
return np.dtype(dtype)
|
||||
|
||||
|
||||
def finfo(dtype):
|
||||
"""Returns properties of floating point types.
|
||||
|
||||
Note that currently it just forwards to the numpy namesake, while tensorflow
|
||||
and numpy dtypes may have different properties.
|
||||
|
||||
Args:
|
||||
dtype: Could be a python type, a numpy type or a TF DType.
|
||||
|
||||
Returns:
|
||||
A class describing properties of `dtype`, as described by
|
||||
https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html
|
||||
"""
|
||||
return np.finfo(_to_numpy_type(dtype))
|
||||
|
||||
|
||||
def isscalar(val):
|
||||
"""Returns whether `val` is a scalar value or scalar Tensor."""
|
||||
if isinstance(val, np_arrays.ndarray):
|
||||
@ -112,51 +96,6 @@ def isscalar(val):
|
||||
return np.isscalar(val)
|
||||
|
||||
|
||||
# Can't use np_doc because np.result_type is a builtin function.
|
||||
def result_type(*arrays_and_dtypes):
|
||||
"""Returns the type resulting from applying NumPy type promotion to arguments.
|
||||
|
||||
Args:
|
||||
*arrays_and_dtypes: A list of array_like objects or dtypes.
|
||||
|
||||
Returns:
|
||||
A numpy dtype.
|
||||
"""
|
||||
|
||||
def maybe_get_dtype(x):
|
||||
# Don't put np.ndarray in this list, because np.result_type looks at the
|
||||
# value (not just dtype) of np.ndarray to decide the result type.
|
||||
if isinstance(
|
||||
x, (np_arrays.ndarray, core.Tensor, indexed_slices.IndexedSlices)):
|
||||
return _to_numpy_type(x.dtype)
|
||||
elif isinstance(x, dtypes.DType):
|
||||
return _to_numpy_type(x)
|
||||
return x
|
||||
|
||||
arrays_and_dtypes = [
|
||||
maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
|
||||
]
|
||||
if not arrays_and_dtypes:
|
||||
# If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
|
||||
arrays_and_dtypes = [np.asarray([])]
|
||||
return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def promote_types(type1, type2):
|
||||
"""Returns the type resulting from applying NumPy type promotion.
|
||||
|
||||
Args:
|
||||
type1: A numpy type.
|
||||
type2: A numpy type.
|
||||
|
||||
Returns:
|
||||
A numpy type.
|
||||
"""
|
||||
type1 = _to_numpy_type(type1)
|
||||
type2 = _to_numpy_type(type2)
|
||||
return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))
|
||||
|
||||
|
||||
def _has_docstring(f):
|
||||
return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
|
||||
f.__doc__)
|
||||
@ -360,6 +299,44 @@ def np_doc_only(np_fun_name, np_fun=None):
|
||||
return decorator
|
||||
|
||||
|
||||
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
|
||||
@np_doc('finfo')
|
||||
def finfo(dtype):
|
||||
"""Note that currently it just forwards to the numpy namesake, while
|
||||
tensorflow and numpy dtypes may have different properties."""
|
||||
return np.finfo(_to_numpy_type(dtype))
|
||||
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
|
||||
|
||||
|
||||
# Can't use np_doc because np.result_type is a builtin function.
|
||||
@np_doc_only('result_type')
|
||||
def result_type(*arrays_and_dtypes): # pylint: disable=missing-function-docstring
|
||||
def maybe_get_dtype(x):
|
||||
# Don't put np.ndarray in this list, because np.result_type looks at the
|
||||
# value (not just dtype) of np.ndarray to decide the result type.
|
||||
if isinstance(
|
||||
x, (np_arrays.ndarray, core.Tensor, indexed_slices.IndexedSlices)):
|
||||
return _to_numpy_type(x.dtype)
|
||||
elif isinstance(x, dtypes.DType):
|
||||
return _to_numpy_type(x)
|
||||
return x
|
||||
|
||||
arrays_and_dtypes = [
|
||||
maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
|
||||
]
|
||||
if not arrays_and_dtypes:
|
||||
# If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
|
||||
arrays_and_dtypes = [np.asarray([])]
|
||||
return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@np_doc('promote_types')
|
||||
def promote_types(type1, type2): # pylint: disable=missing-function-docstring
|
||||
type1 = _to_numpy_type(type1)
|
||||
type2 = _to_numpy_type(type2)
|
||||
return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))
|
||||
|
||||
|
||||
def tf_broadcast(*args):
|
||||
"""Broadcast tensors.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user