Add sign, broadcast_arrays to tf numpy.
Enable tests for geomspace PiperOrigin-RevId: 315772740 Change-Id: I779a6a728093190f9790786bab1c9df0e78c080d
This commit is contained in:
parent
cdfacf6551
commit
3ba9124877
@ -84,8 +84,8 @@ def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name
|
||||
Returns:
|
||||
An ndarray.
|
||||
"""
|
||||
if dtype:
|
||||
dtype = np_utils.result_type(dtype)
|
||||
dtype = (
|
||||
np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
|
||||
if isinstance(shape, np_arrays.ndarray):
|
||||
shape = shape.data
|
||||
return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype))
|
||||
@ -380,28 +380,6 @@ def arange(start, stop=None, step=1, dtype=None):
|
||||
math_ops.cast(math_ops.range(start, limit=stop, delta=step), dtype=dtype))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.geomspace)
|
||||
def geomspace(start, stop, num=50, endpoint=True, dtype=float): # pylint: disable=missing-docstring
|
||||
if dtype:
|
||||
dtype = np_utils.result_type(dtype)
|
||||
if num < 0:
|
||||
raise ValueError('Number of samples {} must be non-negative.'.format(num))
|
||||
if not num:
|
||||
return empty([0])
|
||||
step = 1.
|
||||
if endpoint:
|
||||
if num > 1:
|
||||
step = math_ops.pow((stop / start), 1 / (num - 1))
|
||||
else:
|
||||
step = math_ops.pow((stop / start), 1 / num)
|
||||
result = math_ops.cast(math_ops.range(num), step.dtype)
|
||||
result = math_ops.pow(step, result)
|
||||
result = math_ops.multiply(result, start)
|
||||
if dtype:
|
||||
result = math_ops.cast(result, dtype=dtype)
|
||||
return np_arrays.tensor_to_ndarray(result)
|
||||
|
||||
|
||||
# Building matrices.
|
||||
@np_utils.np_doc(np.diag)
|
||||
def diag(v, k=0): # pylint: disable=missing-docstring
|
||||
@ -1636,3 +1614,35 @@ def ix_(*args): # pylint: disable=missing-docstring
|
||||
'Only integer and bool dtypes are supported, got {}'.format(dtype))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@np_utils.np_doc(np.broadcast_arrays)
|
||||
def broadcast_arrays(*args, **kwargs): # pylint: disable=missing-docstring
|
||||
subok = kwargs.pop('subok', False)
|
||||
if subok:
|
||||
raise ValueError('subok=True is not supported.')
|
||||
if kwargs:
|
||||
raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))
|
||||
|
||||
args = [asarray(arg).data for arg in args]
|
||||
args = np_utils.tf_broadcast(*args)
|
||||
return [np_utils.tensor_to_ndarray(arg) for arg in args]
|
||||
|
||||
|
||||
@np_utils.np_doc_only(np.sign)
|
||||
def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstring,redefined-outer-name
|
||||
if out:
|
||||
raise ValueError('tf.numpy doesnt support setting out.')
|
||||
if where:
|
||||
raise ValueError('tf.numpy doesnt support setting where.')
|
||||
if kwargs:
|
||||
raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))
|
||||
|
||||
x = asarray(x)
|
||||
dtype = x.dtype
|
||||
if np.issubdtype(dtype, np.complex):
|
||||
result = math_ops.cast(math_ops.sign(math_ops.real(x.data)), dtype)
|
||||
else:
|
||||
result = math_ops.sign(x.data)
|
||||
|
||||
return np_utils.tensor_to_ndarray(result)
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
@ -400,27 +399,6 @@ class ArrayCreationTest(test.TestCase):
|
||||
np.arange(start, stop, step, dtype=dtype),
|
||||
msg=msg)
|
||||
|
||||
def testGeomSpace(self):
|
||||
|
||||
def run_test(start, stop, **kwargs):
|
||||
arg1 = start
|
||||
arg2 = stop
|
||||
if test_util.is_gpu_available():
|
||||
decimal = 3
|
||||
else:
|
||||
decimal = 4
|
||||
self.match(
|
||||
np_array_ops.geomspace(arg1, arg2, **kwargs),
|
||||
np.geomspace(arg1, arg2, **kwargs),
|
||||
msg='geomspace({}, {})'.format(arg1, arg2),
|
||||
almost=True,
|
||||
decimal=decimal)
|
||||
|
||||
run_test(1, 1000, num=5)
|
||||
run_test(1, 1000, num=5, endpoint=False)
|
||||
run_test(-1, -1000, num=5)
|
||||
run_test(-1, -1000, num=5, endpoint=False)
|
||||
|
||||
def testDiag(self):
|
||||
array_transforms = [
|
||||
lambda x: x, # Identity,
|
||||
@ -1081,6 +1059,21 @@ class ArrayMethodsTest(test.TestCase):
|
||||
y = np_array_ops.split(x, [3, 5, 6, 10])
|
||||
self.assertListEqual([([0, 1, 2]), ([3, 4]), ([5]), ([6, 7]), ([])], y)
|
||||
|
||||
def testSign(self):
|
||||
state = np.random.RandomState(0)
|
||||
test_types = [np.float16, np.float32, np.float64, np.int32, np.int64,
|
||||
np.complex64, np.complex128]
|
||||
test_shapes = [(), (1,), (2, 3, 4), (2, 3, 0, 4)]
|
||||
|
||||
for dtype in test_types:
|
||||
for shape in test_shapes:
|
||||
if np.issubdtype(dtype, np.complex):
|
||||
arr = (np.asarray(state.randn(*shape) * 100, dtype=dtype) +
|
||||
1j * np.asarray(state.randn(*shape) * 100, dtype=dtype))
|
||||
else:
|
||||
arr = np.asarray(state.randn(*shape) * 100, dtype=dtype)
|
||||
self.match(np_array_ops.sign(arr), np.sign(arr))
|
||||
|
||||
|
||||
class ArrayManipulationTest(test.TestCase):
|
||||
|
||||
|
@ -1043,6 +1043,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
|
||||
return np_arrays.tensor_to_ndarray(result)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.geomspace)
|
||||
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint: disable=missing-docstring
|
||||
dtype = dtype or np_utils.result_type(start, stop, float(num),
|
||||
np_array_ops.zeros((), dtype))
|
||||
computation_dtype = np.promote_types(dtype, np.float32)
|
||||
start = np_array_ops.asarray(start, dtype=computation_dtype)
|
||||
stop = np_array_ops.asarray(stop, dtype=computation_dtype)
|
||||
# follow the numpy geomspace convention for negative and complex endpoints
|
||||
start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
|
||||
stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
|
||||
signflip = 1 - start_sign * stop_sign // 2
|
||||
res = signflip * logspace(log10(signflip * start),
|
||||
log10(signflip * stop), num,
|
||||
endpoint=endpoint, base=10.0,
|
||||
dtype=computation_dtype, axis=0)
|
||||
if axis != 0:
|
||||
res = np_array_ops.moveaxis(res, 0, axis)
|
||||
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.ptp)
|
||||
def ptp(a, axis=None, keepdims=None):
|
||||
return (np_array_ops.amax(a, axis=axis, keepdims=keepdims) -
|
||||
|
@ -326,6 +326,21 @@ class MathTest(test.TestCase, parameterized.TestCase):
|
||||
run_test(0, -5, endpoint=False)
|
||||
run_test(0, -5, base=2.0)
|
||||
|
||||
def testGeomSpace(self):
|
||||
|
||||
def run_test(start, stop, **kwargs):
|
||||
arg1 = start
|
||||
arg2 = stop
|
||||
self.match(
|
||||
np_math_ops.geomspace(arg1, arg2, **kwargs),
|
||||
np.geomspace(arg1, arg2, **kwargs),
|
||||
msg='geomspace({}, {})'.format(arg1, arg2))
|
||||
|
||||
run_test(1, 1000, num=5)
|
||||
run_test(1, 1000, num=5, endpoint=False)
|
||||
run_test(-1, -1000, num=5)
|
||||
run_test(-1, -1000, num=5, endpoint=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
x
Reference in New Issue
Block a user