tf.numpy: Numpy API on TF.

PiperOrigin-RevId: 314792337
Change-Id: I92dc879492af15c6160e3f671e185ea4f87e8ed4
This commit is contained in:
A. Unique TensorFlower 2020-06-04 13:26:01 -07:00 committed by TensorFlower Gardener
parent bab81497b0
commit 84c796966b
16 changed files with 5701 additions and 2 deletions

View File

@ -223,6 +223,7 @@ py_library(
"//tensorflow/python/ops/linalg",
"//tensorflow/python/ops/linalg/sparse",
"//tensorflow/python/ops/losses",
"//tensorflow/python/ops/numpy_ops:numpy",
"//tensorflow/python/ops/parallel_for",
"//tensorflow/python/ops/ragged",
"//tensorflow/python/ops/signal",

View File

@ -1,4 +1,4 @@
# TF numpy API
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
package(
default_visibility = [
@ -8,9 +8,109 @@ package(
)
py_library(
name = "numpy_ops",
name = "numpy",
srcs = [
"__init__.py",
"np_array_ops.py",
"np_arrays.py",
"np_dtypes.py",
"np_math_ops.py",
"np_random.py",
"np_utils.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:indexed_slices",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sort_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "np_arrays_test",
srcs = ["np_arrays_test.py"],
deps = [
":numpy",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "np_array_ops_test",
srcs = ["np_array_ops_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
cuda_py_test(
name = "np_backprop_test",
srcs = ["np_backprop_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "np_logic_test",
srcs = ["np_logic_test.py"],
deps = [
":numpy",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "np_math_ops_test",
srcs = ["np_math_ops_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "np_random_test",
srcs = ["np_random_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
cuda_py_test(
name = "np_utils_test",
srcs = ["np_utils_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -17,3 +17,21 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.numpy_ops import np_random as random
# pylint: disable=wildcard-import
from tensorflow.python.ops.numpy_ops.np_array_ops import *
from tensorflow.python.ops.numpy_ops.np_arrays import ndarray
from tensorflow.python.ops.numpy_ops.np_dtypes import *
from tensorflow.python.ops.numpy_ops.np_math_ops import *
from tensorflow.python.ops.numpy_ops.np_utils import finfo
from tensorflow.python.ops.numpy_ops.np_utils import promote_types
from tensorflow.python.ops.numpy_ops.np_utils import result_type
# pylint: enable=wildcard-import
# pylint: disable=redefined-builtin,undefined-variable
max = amax
min = amin
round = around
# pylint: enable=redefined-builtin,undefined-variable

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,292 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ndarray class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
def convert_to_tensor(value, dtype=None):
# A safer version of `tf.convert_to_tensor` to work around b/149876037.
# TODO(wangpeng): Remove this function once the bug is fixed.
if (dtype is None and isinstance(value, six.integer_types) and
value >= 2**63):
dtype = dtypes.uint64
elif (dtype is None and isinstance(value, float)):
dtype = np_dtypes.default_float_type()
return ops.convert_to_tensor(value, dtype=dtype)
class ndarray(object): # pylint: disable=invalid-name
"""Equivalent of numpy.ndarray backed by TensorFlow tensors.
This does not support all features of NumPy ndarrays e.g. strides and
memory order since, unlike NumPy, the backing storage is not a raw memory
buffer.
TODO(srbs): Clearly specify which attributes and methods are not supported
or if there are any differences in behavior.
"""
def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin
"""Initializes an ndarray.
This is a low level interface for building ndarrays and should be avoided.
Users should instead use methods in array_creation.py.
This class provides a numpy.ndarray like interface for a TF Tensor with a
fully-defined shape. Note that, unlike the backing buffer of np.ndarray,
Tensors are immutable. So, operations like `__setitem__` are performed by
replacing the Tensor. This restricts the ability to implement NumPy `view`
semantics.
Compared to numpy.ndarray, this does not support `offset`, `strides`
and `order` arguments.
Args:
shape: The shape of the array. Must be a scalar, an iterable of integers
or a `TensorShape` object.
dtype: Optional. The dtype of the array. Must be a python type, a numpy
type or a tensorflow `DType` object.
buffer: Optional. The backing buffer of the array. Must have shape
`shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`.
Raises:
ValueError: If `buffer` is specified and its shape does not match
`shape`.
"""
if dtype and not isinstance(dtype, dtypes.DType):
dtype = dtypes.as_dtype(np.dtype(dtype))
if buffer is None:
buffer = array_ops.zeros(shape, dtype=dtype)
else:
if isinstance(buffer, ndarray):
buffer = buffer.data
elif isinstance(buffer, np.ndarray):
# If `buffer` is a np.ndarray, the Tensor will share the underlying
# storage of the array.
buffer = convert_to_tensor(value=buffer, dtype=dtype)
elif not isinstance(buffer, ops.Tensor):
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
' Tensor or np.ndarray.'.format(type(buffer)))
if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access
# TODO(srbs): NumPy allows this. Investigate if/how to support this.
raise ValueError('shape arg must match buffer.shape.')
assert isinstance(buffer, ops.Tensor)
if dtype and dtype != buffer.dtype:
buffer = array_ops.bitcast(buffer, dtype)
self._data = buffer
self.base = None
@property
def data(self):
"""Tensor object containing the array data.
This has a few key differences from the Python buffer object used in
NumPy arrays.
1. Tensors are immutable. So operations requiring in-place edit, e.g.
__setitem__, are performed by replacing the underlying buffer with a new
one.
2. Tensors do not provide access to their raw buffer.
Returns:
A Tensor.
"""
return self._data
@property
def shape(self):
"""Returns a tuple of array dimensions."""
return self.data._shape_tuple() # pylint: disable=protected-access
@property
def dtype(self):
return np.dtype(self.data.dtype.as_numpy_dtype)
@property
def ndim(self):
return self.data.shape.ndims
@property
def size(self):
"""Returns the number of elements in the array."""
return np.prod(self.shape)
@property
def T(self): # pylint: disable=invalid-name
return self.transpose()
def __len__(self):
if self.shape:
return self.shape[0]
else:
raise TypeError('len() of unsized object.')
def astype(self, dtype):
if self.dtype == dtype:
return self
else:
return tensor_to_ndarray(math_ops.cast(self.data, dtype))
# Unary operations
def __neg__(self):
return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type
def __pos__(self):
return self
__hash__ = None
def __int__(self):
return int(self.data)
def __float__(self):
return float(self.data)
def __nonzero__(self):
return bool(self.data)
def __bool__(self):
return self.__nonzero__()
def __getitem__(self, slice_spec):
# TODO(srbs): Need to support better indexing.
result_t = self.data.__getitem__(slice_spec)
return tensor_to_ndarray(result_t)
def __iter__(self):
for i in range(self.shape[0]):
result_t = self.data[i]
yield tensor_to_ndarray(result_t)
return
def __array__(self, dtype=None):
"""Returns a NumPy ndarray.
This allows instances of this class to be directly used in NumPy routines.
However, doing that may force a copy to CPU.
Args:
dtype: A NumPy compatible type.
Returns:
A NumPy ndarray.
"""
return np.asarray(self.data, dtype)
__array_priority__ = 110
def __index__(self):
"""Returns a python scalar.
This allows using an instance of this class as an array index.
Note that only arrays of integer types with size 1 can be used as array
indices.
Returns:
A Python scalar.
Raises:
TypeError: If the array is not of an integer type.
ValueError: If the array does not have size 1.
"""
# TODO(wangpeng): Handle graph mode
return np.asscalar(self.data.numpy())
def tolist(self):
return self.data.numpy().tolist()
def __str__(self):
return 'ndarray<{}>'.format(self.data.__str__())
def __repr__(self):
return 'ndarray<{}>'.format(self.data.__repr__())
def tensor_to_ndarray(tensor):
return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
if as_ref:
raise ValueError('as_ref is not supported.')
if dtype and dtypes.as_dtype(arr.dtype) != dtype:
return math_ops.cast(arr.data, dtype)
result_t = arr.data
if name:
result_t = array_ops.identity(result_t, name=name)
return result_t
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
# Don't use a namedtuple since nest considers that a tuple and unflattens and
# flattens it.
class ShardedNdArray(object):
"""Wrapper over ndarray that can contain tensors on multiple devices.
This is returned by extensions.pmap, and contains the individual tensors on
different devices.
"""
def __init__(self, tensors):
"""Initializes the ShardedNdArray.
Note that the tensors should be ordered in the way the pmap producing these
tensors is run.
Args:
tensors: list or tuple of eager tensors, one for each device.
"""
if not isinstance(tensors, (list, tuple)) or not tensors:
raise ValueError(
'Unable to create a ShardedNdArray without a list of tensors.')
self.tensors = tensors
self.n_devices = len(tensors)
def __getitem__(self, i):
return self.tensors[i]
@property
def shape(self):
return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access
@property
def dtype(self):
return self.tensors[0].dtype
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
del args, kwargs
# TODO(nareshmodi): Consider a collective op to gather the tensors from the
# various devices for performance reasons.
return array_ops.stack(value.tensors)
ops.register_tensor_conversion_function(ShardedNdArray,
convert_sharded_tensor_to_eager_tensor)

View File

@ -0,0 +1,189 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ndarray."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
# Required for operator overloads
from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unused-import
from tensorflow.python.platform import test
t2a = np_arrays.tensor_to_ndarray
class ArrayTest(test.TestCase):
def testDtype(self):
a = t2a(array_ops.zeros(shape=[1, 2], dtype=dtypes.int64))
self.assertIs(a.dtype.type, np.int64)
self.assertAllEqual(0, a.dtype.type(0))
def testAstype(self):
a = t2a(ops.convert_to_tensor(value=1.1,
dtype=dtypes.float32)).astype(np.int32)
self.assertIs(a.dtype.type, np.int32)
self.assertAllEqual(1, a)
a = t2a(ops.convert_to_tensor(value=[0.0, 1.1],
dtype=dtypes.float32)).astype(np.bool_)
self.assertIs(a.dtype.type, np.bool_)
self.assertAllEqual([False, True], a)
def testNeg(self):
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
self.assertAllEqual([-1.0, -2.0], -a)
def _testBinOp(self, a, b, out, f, types=None):
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
if not isinstance(out, np_arrays.ndarray):
out = t2a(ops.convert_to_tensor(value=out, dtype=np.int32))
if types is None:
types = [[np.int32, np.int32, np.int32], [np.int64, np.int32, np.int64],
[np.int32, np.int64, np.int64],
[np.float32, np.int32, np.float64],
[np.int32, np.float32, np.float64],
[np.float32, np.float32, np.float32],
[np.float64, np.float32, np.float64],
[np.float32, np.float64, np.float64]]
for a_type, b_type, out_type in types:
o = f(a.astype(a_type), b.astype(b_type))
self.assertIs(o.dtype.type, out_type)
out = out.astype(out_type)
if np.issubdtype(out_type, np.inexact):
self.assertAllClose(out, o)
else:
self.assertAllEqual(out, o)
def testAdd(self):
self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: a.__add__(b))
def testRadd(self):
self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: b.__radd__(a))
def testSub(self):
self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: a.__sub__(b))
def testRsub(self):
self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: b.__rsub__(a))
def testMul(self):
self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: a.__mul__(b))
def testRmul(self):
self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: b.__rmul__(a))
def testPow(self):
self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: a.__pow__(b))
def testRpow(self):
self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: b.__rpow__(a))
_truediv_types = [[np.int32, np.int32, np.float64],
[np.int64, np.int32, np.float64],
[np.int32, np.int64, np.float64],
[np.float32, np.int32, np.float64],
[np.int32, np.float32, np.float64],
[np.float32, np.float32, np.float32],
[np.float64, np.float32, np.float64],
[np.float32, np.float64, np.float64]]
def testTruediv(self):
self._testBinOp([3, 5], [2, 4],
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
lambda a, b: a.__truediv__(b),
types=self._truediv_types)
def testRtruediv(self):
self._testBinOp([3, 5], [2, 4],
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
lambda a, b: b.__rtruediv__(a),
types=self._truediv_types)
def _testCmp(self, a, b, out, f):
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
types = [[np.int32, np.int32], [np.int64, np.int32], [np.int32, np.int64],
[np.float32, np.int32], [np.int32, np.float32],
[np.float32, np.float32], [np.float64, np.float32],
[np.float32, np.float64]]
for a_type, b_type in types:
o = f(a.astype(a_type), b.astype(b_type))
self.assertAllEqual(out, o)
def testLt(self):
self._testCmp([1, 2, 3], [3, 2, 1], [True, False, False],
lambda a, b: a.__lt__(b))
def testLe(self):
self._testCmp([1, 2, 3], [3, 2, 1], [True, True, False],
lambda a, b: a.__le__(b))
def testGt(self):
self._testCmp([1, 2, 3], [3, 2, 1], [False, False, True],
lambda a, b: a.__gt__(b))
def testGe(self):
self._testCmp([1, 2, 3], [3, 2, 1], [False, True, True],
lambda a, b: a.__ge__(b))
def testEq(self):
self._testCmp([1, 2, 3], [3, 2, 1], [False, True, False],
lambda a, b: a.__eq__(b))
def testNe(self):
self._testCmp([1, 2, 3], [3, 2, 1], [True, False, True],
lambda a, b: a.__ne__(b))
def testInt(self):
v = 10
u = int(t2a(ops.convert_to_tensor(value=v)))
self.assertIsInstance(u, int)
self.assertAllEqual(v, u)
def testFloat(self):
v = 21.32
u = float(t2a(ops.convert_to_tensor(value=v)))
self.assertIsInstance(u, float)
self.assertAllClose(v, u)
def testBool(self):
b = bool(t2a(ops.convert_to_tensor(value=10)))
self.assertIsInstance(b, bool)
self.assertTrue(b)
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0))))
self.assertTrue(bool(t2a(ops.convert_to_tensor(value=0.1))))
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0.0))))
def testHash(self):
a = t2a(ops.convert_to_tensor(value=10))
self.assertNotIsInstance(a, collections.Hashable)
with self.assertRaisesWithPredicateMatch(TypeError, r'unhashable type'):
hash(a)
if __name__ == '__main__':
# TODO(wangpeng): Test in graph mode as well.
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,69 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for backpropgration on tf-numpy functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_array_ops
# Required for operator overloads
from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unused-import
from tensorflow.python.platform import test
class BackpropTest(test.TestCase):
def test_setitem(self):
# Single integer index.
a = np_array_ops.array([1., 2., 3.])
b = np_array_ops.array(5.)
c = np_array_ops.array(10.)
tensors = [arr.data for arr in [a, b, c]]
with backprop.GradientTape() as g:
g.watch(tensors)
a[1] = b + c
loss = np_array_ops.sum(a)
gradients = g.gradient(loss.data, tensors)
self.assertSequenceEqual(
np_array_ops.array(gradients[0]).tolist(), [1., 0., 1.])
self.assertEqual(np_array_ops.array(gradients[1]).tolist(), 1.)
self.assertEqual(np_array_ops.array(gradients[2]).tolist(), 1.)
# Tuple index.
a = np_array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.],
[7., 8.]]]) # 2x2x2 array.
b = np_array_ops.array([10., 11.])
tensors = [arr.data for arr in [a, b]]
with backprop.GradientTape() as g:
g.watch(tensors)
a[(1, 0)] = b
loss = np_array_ops.sum(a)
gradients = g.gradient(loss.data, tensors)
self.assertSequenceEqual(
np_array_ops.array(gradients[0]).tolist(),
[[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]])
self.assertEqual(np_array_ops.array(gradients[1]).tolist(), [1., 1.])
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,96 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dtypes and dtype utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
# We use numpy's dtypes instead of TF's, because the user expects to use them
# with numpy facilities such as `np.dtype(np.int64)` and
# `if x.dtype.type is np.int64`.
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
from numpy import bool_
from numpy import int_
from numpy import int16
from numpy import int32
from numpy import int64
from numpy import int8
from numpy import uint16
from numpy import uint32
from numpy import uint64
from numpy import uint8
from numpy import float_
from numpy import float16
from numpy import float32
from numpy import float64
from numpy import complex_
from numpy import complex64
from numpy import complex128
from numpy import inexact
from numpy import iinfo
from numpy import issubdtype
from numpy import inf
# TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's
from tensorflow.python.framework.dtypes import bfloat16
# pylint: enable=g-bad-import-order
# pylint: enable=unused-import
_to_float32 = {
np.dtype('float64'): np.dtype('float32'),
np.dtype('complex128'): np.dtype('complex64'),
}
_allow_float64 = True
def is_allow_float64():
return _allow_float64
def set_allow_float64(b):
global _allow_float64
_allow_float64 = b
def canonicalize_dtype(dtype):
if not is_allow_float64():
return _to_float32.get(dtype, dtype)
else:
return dtype
def _result_type(*arrays_and_dtypes):
dtype = np.result_type(*arrays_and_dtypes)
return canonicalize_dtype(dtype)
def default_float_type():
"""Gets the default float type.
Returns:
If `is_allow_float64()` is true, returns float64; otherwise returns float32.
"""
if is_allow_float64():
return float64
else:
return float32

View File

@ -0,0 +1,110 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf numpy random number methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_math_ops
from tensorflow.python.platform import test
class LogicTest(test.TestCase):
def setUp(self):
super(LogicTest, self).setUp()
self.array_transforms = [
lambda x: x, # Identity,
ops.convert_to_tensor,
np.array,
lambda x: np.array(x, dtype=np.int32),
lambda x: np.array(x, dtype=np.int64),
lambda x: np.array(x, dtype=np.float32),
lambda x: np.array(x, dtype=np.float64),
np_array_ops.array,
lambda x: np_array_ops.array(x, dtype=dtypes.int32),
lambda x: np_array_ops.array(x, dtype=dtypes.int64),
lambda x: np_array_ops.array(x, dtype=dtypes.float32),
lambda x: np_array_ops.array(x, dtype=dtypes.float64),
]
def testEqual(self):
def run_test(x1, x2=None):
if x2 is None:
x2 = x1
for fn1 in self.array_transforms:
for fn2 in self.array_transforms:
arg1 = fn1(x1)
arg2 = fn2(x2)
self.match(
np_math_ops.equal(arg1, arg2),
np.equal(
make_numpy_compatible(arg1), make_numpy_compatible(arg2)))
run_test(1)
run_test(1, 2)
run_test([1, 2])
run_test([1, 2, 3], [2])
run_test([[1, 2], [3, 4]], [1, 2])
run_test([[1, 2], [1, 4]], [1, 2])
run_test([1, 2], [[1, 2], [1, 4]])
run_test([[1, 2], [3, 4]], [[1, 2], [3, 4]])
run_test([[1, 2], [3, 4]], [[1, 3], [3, 4]])
def match_shape(self, actual, expected, msg=None):
if msg:
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
msg, expected.shape, actual.shape)
self.assertEqual(actual.shape, expected.shape, msg=msg)
if msg:
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
self.assertIsInstance(actual.shape, tuple, msg=msg)
def match_dtype(self, actual, expected, msg=None):
if msg:
msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format(
msg, expected.dtype, actual.dtype)
self.assertEqual(actual.dtype, expected.dtype, msg=msg)
def match(self, actual, expected, msg=None):
msg_ = 'Expected: {} Actual: {}'.format(expected, actual)
if msg:
msg = '{} {}'.format(msg_, msg)
else:
msg = msg_
self.assertIsInstance(actual, np_arrays.ndarray)
self.match_dtype(actual, expected, msg)
self.match_shape(actual, expected, msg)
if not actual.shape:
self.assertEqual(actual.tolist(), expected.tolist())
else:
self.assertSequenceEqual(actual.tolist(), expected.tolist())
def make_numpy_compatible(s):
return s if not isinstance(s, np_arrays.ndarray) else s.data.numpy()
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,332 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf numpy mathematical methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
from six.moves import range
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_math_ops
from tensorflow.python.platform import test
class MathTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(MathTest, self).setUp()
self.array_transforms = [
lambda x: x, # Identity,
ops.convert_to_tensor,
np.array,
lambda x: np.array(x, dtype=np.float32),
lambda x: np.array(x, dtype=np.float64),
np_array_ops.array,
lambda x: np_array_ops.array(x, dtype=np.float32),
lambda x: np_array_ops.array(x, dtype=np.float64),
]
self.types = [np.int32, np.int64, np.float32, np.float64]
def _testBinaryOp(self,
math_fun,
np_fun,
name,
operands=None,
extra_operands=None,
check_promotion=True,
check_promotion_result_type=True):
def run_test(a, b):
for fn in self.array_transforms:
arg1 = fn(a)
arg2 = fn(b)
self.match(
math_fun(arg1, arg2),
np_fun(arg1, arg2),
msg='{}({}, {})'.format(name, arg1, arg2))
# Tests type promotion
for type_a in self.types:
for type_b in self.types:
if not check_promotion and type_a != type_b:
continue
arg1 = np_array_ops.array(a, dtype=type_a)
arg2 = np_array_ops.array(b, dtype=type_b)
self.match(
math_fun(arg1, arg2),
np_fun(arg1, arg2),
msg='{}({}, {})'.format(name, arg1, arg2),
check_dtype=check_promotion_result_type)
if operands is None:
operands = [(5, 2), (5, [2, 3]), (5, [[2, 3], [6, 7]]), ([1, 2, 3], 7),
([1, 2, 3], [5, 6, 7])]
for operand1, operand2 in operands:
run_test(operand1, operand2)
if extra_operands is not None:
for operand1, operand2 in extra_operands:
run_test(operand1, operand2)
def testDot(self):
extra_operands = [([1, 2], [[5, 6, 7], [8, 9, 10]]),
(np.arange(2 * 3 * 5).reshape([2, 3, 5]).tolist(),
np.arange(5 * 7 * 11).reshape([7, 5, 11]).tolist())]
return self._testBinaryOp(
np_math_ops.dot, np.dot, 'dot', extra_operands=extra_operands)
def testMinimum(self):
# The numpy version has strange result type when promotion happens,
# so set check_promotion_result_type to False.
return self._testBinaryOp(
np_math_ops.minimum,
np.minimum,
'minimum',
check_promotion_result_type=False)
def testMaximum(self):
# The numpy version has strange result type when promotion happens,
# so set check_promotion_result_type to False.
return self._testBinaryOp(
np_math_ops.maximum,
np.maximum,
'maximum',
check_promotion_result_type=False)
def testMatmul(self):
operands = [([[1, 2]], [[3, 4, 5], [6, 7, 8]])]
return self._testBinaryOp(
np_math_ops.matmul, np.matmul, 'matmul', operands=operands)
def testMatmulError(self):
with self.assertRaisesRegex(ValueError, r''):
np_math_ops.matmul(
np_array_ops.ones([], np.int32), np_array_ops.ones([2, 3], np.int32))
with self.assertRaisesRegex(ValueError, r''):
np_math_ops.matmul(
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
def _testUnaryOp(self, math_fun, np_fun, name):
def run_test(a):
for fn in self.array_transforms:
arg1 = fn(a)
self.match(
math_fun(arg1), np_fun(arg1), msg='{}({})'.format(name, arg1))
run_test(5)
run_test([2, 3])
run_test([[2, -3], [-6, 7]])
def testLog(self):
self._testUnaryOp(np_math_ops.log, np.log, 'log')
def testExp(self):
self._testUnaryOp(np_math_ops.exp, np.exp, 'exp')
def testTanh(self):
self._testUnaryOp(np_math_ops.tanh, np.tanh, 'tanh')
def testSqrt(self):
self._testUnaryOp(np_math_ops.sqrt, np.sqrt, 'sqrt')
def match(self, actual, expected, msg='', check_dtype=True):
self.assertIsInstance(actual, np_arrays.ndarray)
if check_dtype:
self.assertEqual(
actual.dtype, expected.dtype,
'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format(
actual.dtype, expected.dtype, msg))
self.assertEqual(
actual.shape, expected.shape,
'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format(
actual.shape, expected.shape, msg))
np.testing.assert_almost_equal(actual.tolist(), expected.tolist())
def testArgsort(self):
self._testUnaryOp(np_math_ops.argsort, np.argsort, 'argsort')
# Test stability
r = np.arange(100)
a = np.zeros(100)
np.testing.assert_equal(np_math_ops.argsort(a, kind='stable'), r)
def testArgMaxArgMin(self):
data = [
0,
5,
[1],
[1, 2, 3],
[[1, 2, 3]],
[[4, 6], [7, 8]],
[[[4, 6], [9, 10]], [[7, 8], [12, 34]]],
]
for fn, d in itertools.product(self.array_transforms, data):
arr = fn(d)
self.match(np_math_ops.argmax(arr), np.argmax(arr))
self.match(np_math_ops.argmin(arr), np.argmin(arr))
if hasattr(arr, 'shape'):
ndims = len(arr.shape)
else:
ndims = np_array_ops.array(arr, copy=False).ndim
if ndims == 0:
# Numpy flattens the scalar ndarray and treats it as a 1-d array of
# size 1.
ndims = 1
for axis in range(-ndims, ndims):
self.match(
np_math_ops.argmax(arr, axis=axis), np.argmax(arr, axis=axis))
self.match(
np_math_ops.argmin(arr, axis=axis), np.argmin(arr, axis=axis))
@parameterized.parameters([False, True])
def testIsCloseEqualNan(self, equal_nan):
a = np.asarray([1, 1, np.nan, 1, np.nan], np.float32)
b = np.asarray([1, 2, 1, np.nan, np.nan], np.float32)
self.match(
np_math_ops.isclose(a, b, equal_nan=equal_nan),
np.isclose(a, b, equal_nan=equal_nan))
def testAverageWrongShape(self):
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
np_math_ops.average(np.ones([2, 3]), weights=np.ones([2, 4]))
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([2, 4]))
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([]))
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([5]))
def testClip(self):
def run_test(arr, *args, **kwargs):
check_dtype = kwargs.pop('check_dtype', True)
for fn in self.array_transforms:
arr = fn(arr)
self.match(
np_math_ops.clip(arr, *args, **kwargs),
np.clip(arr, *args, **kwargs),
check_dtype=check_dtype)
# NumPy exhibits weird typing behavior when a/a_min/a_max are scalars v/s
# lists, e.g.,
#
# np.clip(np.array(0, dtype=np.int32), -5, 5).dtype == np.int64
# np.clip(np.array([0], dtype=np.int32), -5, 5).dtype == np.int32
# np.clip(np.array([0], dtype=np.int32), [-5], [5]).dtype == np.int64
#
# So we skip matching type. In tf-numpy the type of the output array is
# always the same as the input array.
run_test(0, -1, 5, check_dtype=False)
run_test(-1, -1, 5, check_dtype=False)
run_test(5, -1, 5, check_dtype=False)
run_test(-10, -1, 5, check_dtype=False)
run_test(10, -1, 5, check_dtype=False)
run_test(10, None, 5, check_dtype=False)
run_test(10, -1, None, check_dtype=False)
run_test([0, 20, -5, 4], -1, 5, check_dtype=False)
run_test([0, 20, -5, 4], None, 5, check_dtype=False)
run_test([0, 20, -5, 4], -1, None, check_dtype=False)
run_test([0.5, 20.2, -5.7, 4.4], -1.5, 5.1, check_dtype=False)
run_test([0, 20, -5, 4], [-5, 0, -5, 0], [0, 5, 0, 5], check_dtype=False)
run_test([[1, 2, 3], [4, 5, 6]], [2, 0, 2], 5, check_dtype=False)
run_test([[1, 2, 3], [4, 5, 6]], 0, [5, 3, 1], check_dtype=False)
def testPtp(self):
def run_test(arr, *args, **kwargs):
for fn in self.array_transforms:
arg = fn(arr)
self.match(
np_math_ops.ptp(arg, *args, **kwargs), np.ptp(arg, *args, **kwargs))
run_test([1, 2, 3])
run_test([1., 2., 3.])
run_test([[1, 2], [3, 4]], axis=1)
run_test([[1, 2], [3, 4]], axis=0)
run_test([[1, 2], [3, 4]], axis=-1)
run_test([[1, 2], [3, 4]], axis=-2)
def testLinSpace(self):
array_transforms = [
lambda x: x, # Identity,
ops.convert_to_tensor,
np.array,
lambda x: np.array(x, dtype=np.float32),
lambda x: np.array(x, dtype=np.float64),
np_array_ops.array,
lambda x: np_array_ops.array(x, dtype=np.float32),
lambda x: np_array_ops.array(x, dtype=np.float64)
]
def run_test(start, stop, **kwargs):
for fn1 in array_transforms:
for fn2 in array_transforms:
arg1 = fn1(start)
arg2 = fn2(stop)
self.match(
np_math_ops.linspace(arg1, arg2, **kwargs),
np.linspace(arg1, arg2, **kwargs),
msg='linspace({}, {})'.format(arg1, arg2))
run_test(0, 1)
run_test(0, 1, num=10)
run_test(0, 1, endpoint=False)
run_test(0, -1)
run_test(0, -1, num=10)
run_test(0, -1, endpoint=False)
def testLogSpace(self):
array_transforms = [
lambda x: x, # Identity,
ops.convert_to_tensor,
np.array,
lambda x: np.array(x, dtype=np.float32),
lambda x: np.array(x, dtype=np.float64),
np_array_ops.array,
lambda x: np_array_ops.array(x, dtype=np.float32),
lambda x: np_array_ops.array(x, dtype=np.float64)
]
def run_test(start, stop, **kwargs):
for fn1 in array_transforms:
for fn2 in array_transforms:
arg1 = fn1(start)
arg2 = fn2(stop)
self.match(
np_math_ops.logspace(arg1, arg2, **kwargs),
np.logspace(arg1, arg2, **kwargs),
msg='logspace({}, {})'.format(arg1, arg2))
run_test(0, 5)
run_test(0, 5, num=10)
run_test(0, 5, endpoint=False)
run_test(0, 5, base=2.0)
run_test(0, -5)
run_test(0, -5, num=10)
run_test(0, -5, endpoint=False)
run_test(0, -5, base=2.0)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,56 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Random functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.numpy_ops import np_utils
DEFAULT_RANDN_DTYPE = np.float32
def randn(*args):
"""Returns samples from a normal distribution.
Uses `tf.random_normal`.
Args:
*args: The shape of the output array.
Returns:
An ndarray with shape `args` and dtype `float64`.
"""
# TODO(wangpeng): Use new stateful RNG
if np_utils.isscalar(args):
args = (args,)
return np_utils.tensor_to_ndarray(
random_ops.random_normal(args, dtype=DEFAULT_RANDN_DTYPE))
def seed(s):
"""Sets the seed for the random number generator.
Uses `tf.set_random_seed`.
Args:
s: an integer.
"""
# TODO(wangpeng): make the signature the same as numpy
random_seed.set_seed(s)

View File

@ -0,0 +1,91 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf numpy random number methods."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range
from tensorflow.python.framework import ops
# Needed for ndarray.reshape.
from tensorflow.python.ops.numpy_ops import np_array_ops # pylint: disable=unused-import
from tensorflow.python.ops.numpy_ops import np_random
from tensorflow.python.platform import test
class RandomTest(test.TestCase):
def assertNotAllClose(self, a, b, **kwargs):
try:
self.assertAllClose(a, b, **kwargs)
except AssertionError:
return
raise AssertionError('The two values are close at all %d elements' %
np.size(a))
def testRandN(self):
def run_test(*args):
num_samples = 1000
tol = 0.1 # High tolerance to keep the # of samples low else the test
# takes a long time to run.
np_random.seed(10)
outputs = [np_random.randn(*args) for _ in range(num_samples)]
# Test output shape.
for output in outputs:
self.assertEqual(output.shape, tuple(args))
self.assertEqual(output.dtype.type, np_random.DEFAULT_RANDN_DTYPE)
if np.prod(args): # Don't bother with empty arrays.
outputs = [output.tolist() for output in outputs]
# Test that the properties of normal distribution are satisfied.
mean = np.mean(outputs, axis=0)
stddev = np.std(outputs, axis=0)
self.assertAllClose(mean, np.zeros(args), atol=tol)
self.assertAllClose(stddev, np.ones(args), atol=tol)
# Test that outputs are different with different seeds.
np_random.seed(20)
diff_seed_outputs = [
np_random.randn(*args).tolist() for _ in range(num_samples)
]
self.assertNotAllClose(outputs, diff_seed_outputs)
# Test that outputs are the same with the same seed.
np_random.seed(10)
same_seed_outputs = [
np_random.randn(*args).tolist() for _ in range(num_samples)
]
self.assertAllClose(outputs, same_seed_outputs)
run_test()
run_test(0)
run_test(1)
run_test(5)
run_test(2, 3)
run_test(0, 2, 3)
run_test(2, 0, 3)
run_test(2, 3, 0)
run_test(2, 3, 5)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -0,0 +1,448 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for internal use."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.util import nest
tensor_to_ndarray = np_arrays.tensor_to_ndarray
def _supports_signature():
return hasattr(inspect, 'signature')
def _to_tf_type(dtype):
"""Converts a native python or numpy type to TF DType.
Args:
dtype: Could be a python type, a numpy type or a TF DType.
Returns:
A tensorflow `DType`.
"""
return dtypes.as_dtype(dtype)
def _to_numpy_type(dtype):
"""Converts a native python or TF DType to numpy type.
Args:
dtype: Could be a python type, a numpy type or a TF DType.
Returns:
A NumPy `dtype`.
"""
if isinstance(dtype, dtypes.DType):
return dtype.as_numpy_dtype
return np.dtype(dtype)
def finfo(dtype):
"""Returns properties of floating point types.
Note that currently it just forwards to the numpy namesake, while tensorflow
and numpy dtypes may have different properties.
Args:
dtype: Could be a python type, a numpy type or a TF DType.
Returns:
A class describing properties of `dtype`, as described by
https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html
"""
return np.finfo(_to_numpy_type(dtype))
def isscalar(val):
"""Returns whether `val` is a scalar value or scalar Tensor."""
if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
return len(val.shape) == 0 # pylint: disable=g-explicit-length-test
return np.isscalar(val)
# Can't use np_doc because np.result_type is a builtin function.
def result_type(*arrays_and_dtypes):
"""Returns the type resulting from applying NumPy type promotion to arguments.
Args:
*arrays_and_dtypes: A list of array_like objects or dtypes.
Returns:
A numpy dtype.
"""
def maybe_get_dtype(x):
# Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, (np_arrays.ndarray, np_arrays.ShardedNdArray, ops.Tensor,
indexed_slices.IndexedSlices)):
return _to_numpy_type(x.dtype)
elif isinstance(x, dtypes.DType):
return _to_numpy_type(x)
return x
arrays_and_dtypes = [
maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
]
if not arrays_and_dtypes:
# If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
arrays_and_dtypes = [np.asarray([])]
return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access
def promote_types(type1, type2):
"""Returns the type resulting from applying NumPy type promotion.
Args:
type1: A numpy type.
type2: A numpy type.
Returns:
A numpy type.
"""
type1 = _to_numpy_type(type1)
type2 = _to_numpy_type(type2)
return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))
def _has_docstring(f):
return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
f.__doc__)
def _add_blank_line(s):
if s.endswith('\n'):
return s + '\n'
else:
return s + '\n\n'
def _np_signature(f):
"""An enhanced inspect.signature that can handle numpy.ufunc."""
# TODO(wangpeng): consider migrating away from inspect.signature.
# inspect.signature is supported in Python 3.3.
if not hasattr(inspect, 'signature'):
return None
if f is None:
return None
if not isinstance(f, np.ufunc):
try:
return inspect.signature(f)
except ValueError:
return None
def names_from_num(prefix, n):
if n <= 0:
return []
elif n == 1:
return [prefix]
else:
return [prefix + str(i + 1) for i in range(n)]
input_names = names_from_num('x', f.nin)
output_names = names_from_num('out', f.nout)
keyword_only_params = [('where', True), ('casting', 'same_kind'),
('order', 'K'), ('dtype', None), ('subok', True),
('signature', None), ('extobj', None)]
params = []
params += [
inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
for name in input_names
]
if f.nout > 1:
params += [
inspect.Parameter(
name, inspect.Parameter.POSITIONAL_ONLY, default=None)
for name in output_names
]
params += [
inspect.Parameter(
'out',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=None if f.nout == 1 else (None,) * f.nout)
]
params += [
inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
for name, default in keyword_only_params
]
return inspect.Signature(params)
# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
# allow positional-only argument. So we conflate positional-only, keyword-only
# and positional-or-keyword arguments here.
def _is_compatible_param_kind(a, b):
def relax(k):
if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
return inspect.Parameter.POSITIONAL_OR_KEYWORD
return k
return relax(a) == relax(b)
def np_doc(np_fun, np_fun_name=None):
"""Attachs numpy docstring to a function.
Args:
np_fun: the numpy function whose docstring will be used.
np_fun_name: optional name for the np_fun symbol. At least one of np_fun or
np_fun_name shoud be set.
Returns:
A function decorator that attaches the docstring from `np_fun` to the
decorated function.
"""
if np_fun is None:
assert np_fun_name is not None
try:
np_fun = getattr(np, str(np_fun_name))
except AttributeError:
np_fun = None
np_sig = _np_signature(np_fun)
if np_fun_name is None:
assert np_fun is not None
np_fun_name = np_fun.__name__
def decorator(f):
"""The decorator."""
unsupported_params = []
if hasattr(inspect, 'signature') and np_sig is not None:
try:
sig = inspect.signature(f)
except ValueError:
sig = None
# TODO(wangpeng): Enable this.
# Looks like this may not work with different versions of numpy.
# if sig is not None:
# for name, param in sig.parameters.items():
# np_param = np_sig.parameters.get(name)
# if np_param is None:
# raise TypeError('Cannot find parameter "%s" in the numpy
# function\'s ' 'signature' % name)
# if not _is_compatible_param_kind(param.kind, np_param.kind):
# raise TypeError(
# 'Parameter "%s" is of kind %s while in numpy it is of '
# 'kind %s' % (name, param.kind, np_param.kind))
# has_default = (param.default != inspect.Parameter.empty)
# np_has_default = (np_param.default != inspect.Parameter.empty)
# if has_default != np_has_default:
# raise TypeError('Parameter "%s" should%s have a default value' %
# (name, '' if np_has_default else ' not'))
# for name in np_sig.parameters:
# if name not in sig.parameters:
# unsupported_params.append(name)
f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
unsupported_params=unsupported_params)
return f
return decorator
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
"""Helper to get docs."""
if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
# correctly with doctest. Fix that and remove the reformatting of the np_f
# comment, here and below.
return np_f.__doc__.replace('>>>', '>')
assert np_f or np_fun_name
if not np_fun_name:
np_fun_name = np_f.__name__
doc = 'TensorFlow variant of `numpy.%s`.\n\n' % np_fun_name
if unsupported_params:
doc += 'Unsupported arguments: ' + ', '.join(
'`' + name + '`' for name in unsupported_params) + '.\n\n'
if _has_docstring(f):
doc += f.__doc__
doc = _add_blank_line(doc)
if _has_docstring(np_f):
doc += 'Documentation for `numpy.%s`:\n\n' % np_f.__name__
doc += np_f.__doc__.replace('>>>', '>')
return doc
def np_doc_only(np_f):
"""Attachs numpy docstring to a function.
This differs from np_doc in that it doesn't check for a match in signature.
Args:
np_f: the numpy function whose docstring will be used.
Returns:
A function decorator that attaches the docstring from `np_f` to the
decorated function.
"""
def decorator(f):
f.__doc__ = _np_doc_helper(f, np_f)
return f
return decorator
def tf_broadcast(*args):
"""Broadcast tensors.
Args:
*args: a list of tensors whose shapes are broadcastable against each other.
Returns:
Tensors broadcasted to the common shape.
"""
if len(args) <= 1:
return args
sh = array_ops.shape(args[0])
for arg in args[1:]:
sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg))
return [array_ops.broadcast_to(arg, sh) for arg in args]
# TODO(wangpeng): Move the following functions to a separate file and check for
# float dtypes in each of them.
def get_static_value(x):
"""A version of tf.get_static_value that returns None on float dtypes.
It returns None on float dtypes in order to avoid breaking gradients.
Args:
x: a tensor.
Returns:
Same as `tf.get_static_value`, except that it returns None when `x` has a
float dtype.
"""
if isinstance(x, ops.Tensor) and (x.dtype.is_floating or x.dtype.is_complex):
return None
return tensor_util.constant_value(x)
def _maybe_static(x):
value = get_static_value(x)
if value is None:
return x
else:
return value
# All the following functions exist becaues get_static_value can't handle
# their TF counterparts.
def cond(pred, true_fn, false_fn):
"""A version of tf.cond that tries to evaluate the condition."""
v = get_static_value(pred)
if v is None:
return control_flow_ops.cond(pred, true_fn, false_fn)
if v:
return true_fn()
else:
return false_fn()
def add(a, b):
"""A version of tf.add that eagerly evaluates if possible."""
return _maybe_static(a) + _maybe_static(b)
def subtract(a, b):
"""A version of tf.subtract that eagerly evaluates if possible."""
return _maybe_static(a) - _maybe_static(b)
def greater(a, b):
"""A version of tf.greater that eagerly evaluates if possible."""
return _maybe_static(a) > _maybe_static(b)
def greater_equal(a, b):
"""A version of tf.greater_equal that eagerly evaluates if possible."""
return _maybe_static(a) >= _maybe_static(b)
def less_equal(a, b):
"""A version of tf.less_equal that eagerly evaluates if possible."""
return _maybe_static(a) <= _maybe_static(b)
def logical_and(a, b):
"""A version of tf.logical_and that eagerly evaluates if possible."""
a_value = get_static_value(a)
if a_value is not None:
if np.isscalar(a_value):
if a_value:
return _maybe_static(b)
else:
return a_value
else:
return a_value & _maybe_static(b)
else:
return a & _maybe_static(b)
def logical_or(a, b):
"""A version of tf.logical_or that eagerly evaluates if possible."""
a_value = get_static_value(a)
if a_value is not None:
if np.isscalar(a_value):
if a_value:
return a_value
else:
return _maybe_static(b)
else:
return a_value | _maybe_static(b)
else:
return a | _maybe_static(b)
def getitem(a, slice_spec):
"""A version of __getitem__ that eagerly evaluates if possible."""
return _maybe_static(a)[slice_spec]
def reduce_all(input_tensor, axis=None, keepdims=False):
"""A version of tf.reduce_all that eagerly evaluates if possible."""
v = get_static_value(input_tensor)
if v is None:
return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
else:
return v.all(axis=axis, keepdims=keepdims)
def reduce_any(input_tensor, axis=None, keepdims=False):
"""A version of tf.reduce_any that eagerly evaluates if possible."""
v = get_static_value(input_tensor)
if v is None:
return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
else:
return v.any(axis=axis, keepdims=keepdims)

View File

@ -0,0 +1,92 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utils.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.platform import test
class UtilsTest(test.TestCase):
# pylint: disable=unused-argument
def testNpDoc(self):
def np_fun(x):
"""np_fun docstring."""
return
@np_utils.np_doc(np_fun)
def f():
"""f docstring."""
return
expected = """TensorFlow variant of `numpy.np_fun`.
f docstring.
Documentation for `numpy.np_fun`:
np_fun docstring."""
self.assertEqual(expected, f.__doc__)
def testNpDocName(self):
@np_utils.np_doc(None, np_fun_name='foo')
def f():
"""f docstring."""
return
expected = """TensorFlow variant of `numpy.foo`.
f docstring.
"""
self.assertEqual(expected, f.__doc__)
def testNpDocErrors(self):
self.skipTest('Enable once np signature checking is done.')
# if not np_utils._supports_signature():
# self.skipTest("inspect.signature not supported")
def np_fun(x, y=1, **kwargs):
return
# pylint: disable=unused-variable
with self.assertRaisesRegexp(TypeError, 'Cannot find parameter'):
@np_utils.np_doc(np_fun)
def f1(a):
return
with self.assertRaisesRegexp(TypeError, 'is of kind'):
@np_utils.np_doc(np_fun)
def f2(x, kwargs):
return
with self.assertRaisesRegexp(TypeError,
'Parameter "y" should have a default value'):
@np_utils.np_doc(np_fun)
def f3(x, y):
return
if __name__ == '__main__':
test.main()