Add sign, broadcast_arrays to tf numpy.

Enable tests for geomspace

PiperOrigin-RevId: 315772740
Change-Id: I779a6a728093190f9790786bab1c9df0e78c080d
This commit is contained in:
Akshay Modi 2020-06-10 14:46:48 -07:00 committed by TensorFlower Gardener
parent cdfacf6551
commit 3ba9124877
4 changed files with 84 additions and 46 deletions

View File

@ -84,8 +84,8 @@ def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name
Returns:
An ndarray.
"""
if dtype:
dtype = np_utils.result_type(dtype)
dtype = (
np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
if isinstance(shape, np_arrays.ndarray):
shape = shape.data
return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype))
@ -380,28 +380,6 @@ def arange(start, stop=None, step=1, dtype=None):
math_ops.cast(math_ops.range(start, limit=stop, delta=step), dtype=dtype))
@np_utils.np_doc(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=float): # pylint: disable=missing-docstring
if dtype:
dtype = np_utils.result_type(dtype)
if num < 0:
raise ValueError('Number of samples {} must be non-negative.'.format(num))
if not num:
return empty([0])
step = 1.
if endpoint:
if num > 1:
step = math_ops.pow((stop / start), 1 / (num - 1))
else:
step = math_ops.pow((stop / start), 1 / num)
result = math_ops.cast(math_ops.range(num), step.dtype)
result = math_ops.pow(step, result)
result = math_ops.multiply(result, start)
if dtype:
result = math_ops.cast(result, dtype=dtype)
return np_arrays.tensor_to_ndarray(result)
# Building matrices.
@np_utils.np_doc(np.diag)
def diag(v, k=0): # pylint: disable=missing-docstring
@ -1636,3 +1614,35 @@ def ix_(*args): # pylint: disable=missing-docstring
'Only integer and bool dtypes are supported, got {}'.format(dtype))
return output
@np_utils.np_doc(np.broadcast_arrays)
def broadcast_arrays(*args, **kwargs): # pylint: disable=missing-docstring
subok = kwargs.pop('subok', False)
if subok:
raise ValueError('subok=True is not supported.')
if kwargs:
raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))
args = [asarray(arg).data for arg in args]
args = np_utils.tf_broadcast(*args)
return [np_utils.tensor_to_ndarray(arg) for arg in args]
@np_utils.np_doc_only(np.sign)
def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstring,redefined-outer-name
if out:
raise ValueError('tf.numpy doesnt support setting out.')
if where:
raise ValueError('tf.numpy doesnt support setting where.')
if kwargs:
raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))
x = asarray(x)
dtype = x.dtype
if np.issubdtype(dtype, np.complex):
result = math_ops.cast(math_ops.sign(math_ops.real(x.data)), dtype)
else:
result = math_ops.sign(x.data)
return np_utils.tensor_to_ndarray(result)

View File

@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
@ -400,27 +399,6 @@ class ArrayCreationTest(test.TestCase):
np.arange(start, stop, step, dtype=dtype),
msg=msg)
def testGeomSpace(self):
def run_test(start, stop, **kwargs):
arg1 = start
arg2 = stop
if test_util.is_gpu_available():
decimal = 3
else:
decimal = 4
self.match(
np_array_ops.geomspace(arg1, arg2, **kwargs),
np.geomspace(arg1, arg2, **kwargs),
msg='geomspace({}, {})'.format(arg1, arg2),
almost=True,
decimal=decimal)
run_test(1, 1000, num=5)
run_test(1, 1000, num=5, endpoint=False)
run_test(-1, -1000, num=5)
run_test(-1, -1000, num=5, endpoint=False)
def testDiag(self):
array_transforms = [
lambda x: x, # Identity,
@ -1081,6 +1059,21 @@ class ArrayMethodsTest(test.TestCase):
y = np_array_ops.split(x, [3, 5, 6, 10])
self.assertListEqual([([0, 1, 2]), ([3, 4]), ([5]), ([6, 7]), ([])], y)
def testSign(self):
state = np.random.RandomState(0)
test_types = [np.float16, np.float32, np.float64, np.int32, np.int64,
np.complex64, np.complex128]
test_shapes = [(), (1,), (2, 3, 4), (2, 3, 0, 4)]
for dtype in test_types:
for shape in test_shapes:
if np.issubdtype(dtype, np.complex):
arr = (np.asarray(state.randn(*shape) * 100, dtype=dtype) +
1j * np.asarray(state.randn(*shape) * 100, dtype=dtype))
else:
arr = np.asarray(state.randn(*shape) * 100, dtype=dtype)
self.match(np_array_ops.sign(arr), np.sign(arr))
class ArrayManipulationTest(test.TestCase):

View File

@ -1043,6 +1043,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
return np_arrays.tensor_to_ndarray(result)
@np_utils.np_doc(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint: disable=missing-docstring
dtype = dtype or np_utils.result_type(start, stop, float(num),
np_array_ops.zeros((), dtype))
computation_dtype = np.promote_types(dtype, np.float32)
start = np_array_ops.asarray(start, dtype=computation_dtype)
stop = np_array_ops.asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
signflip = 1 - start_sign * stop_sign // 2
res = signflip * logspace(log10(signflip * start),
log10(signflip * stop), num,
endpoint=endpoint, base=10.0,
dtype=computation_dtype, axis=0)
if axis != 0:
res = np_array_ops.moveaxis(res, 0, axis)
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
@np_utils.np_doc(np.ptp)
def ptp(a, axis=None, keepdims=None):
return (np_array_ops.amax(a, axis=axis, keepdims=keepdims) -

View File

@ -326,6 +326,21 @@ class MathTest(test.TestCase, parameterized.TestCase):
run_test(0, -5, endpoint=False)
run_test(0, -5, base=2.0)
def testGeomSpace(self):
def run_test(start, stop, **kwargs):
arg1 = start
arg2 = stop
self.match(
np_math_ops.geomspace(arg1, arg2, **kwargs),
np.geomspace(arg1, arg2, **kwargs),
msg='geomspace({}, {})'.format(arg1, arg2))
run_test(1, 1000, num=5)
run_test(1, 1000, num=5, endpoint=False)
run_test(-1, -1000, num=5)
run_test(-1, -1000, num=5, endpoint=False)
if __name__ == '__main__':
ops.enable_eager_execution()