Merge pull request #42539 from DarrenZhang01:master

PiperOrigin-RevId: 330840468
Change-Id: Iafdbf4bd8bc74f1487daf05e15ce15f5c406d6aa
This commit is contained in:
TensorFlower Gardener 2020-09-09 18:25:07 -07:00
commit f56f3a7391
3 changed files with 46 additions and 0 deletions

View File

@ -560,6 +560,21 @@ def _reduce(tf_fn,
tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None): # pylint: disable=missing-docstring
if axis is not None:
raise NotImplementedError('axis argument is not supported in the current '
'`np.size` implementation')
if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
return 1
x = asarray(x).data
if x.shape.is_fully_defined():
return np.prod(x.shape.as_list())
else:
return np_utils.tensor_to_ndarray(array_ops.size_v2(x))
@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin
return _reduce(

View File

@ -25,12 +25,14 @@ from six.moves import range
from six.moves import zip
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
@ -804,6 +806,31 @@ class ArrayMethodsTest(test.TestCase):
def testAmax(self):
self._testReduce(np_array_ops.amax, np.amax, 'amax')
def testSize(self):
def run_test(arr, axis=None):
onp_arr = np.array(arr)
self.assertEqual(np_array_ops.size(arr, axis), np.size(onp_arr, axis))
run_test(np_array_ops.array([1]))
run_test(np_array_ops.array([1, 2, 3, 4, 5]))
run_test(np_array_ops.ones((2, 3, 2)))
run_test(np_array_ops.ones((3, 2)))
run_test(np_array_ops.zeros((5, 6, 7)))
run_test(1)
run_test(np_array_ops.ones((3, 2, 1)))
run_test(constant_op.constant(5))
run_test(constant_op.constant([1, 1, 1]))
self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)),
1)
@def_function.function(input_signature=[tensor_spec.TensorSpec(shape=None)])
def f(arr):
arr = np_array_ops.asarray(arr)
return np_array_ops.size(arr)
self.assertEqual(f(np_array_ops.ones((3, 2))).data.numpy(), 6)
def testRavel(self):
def run_test(arr, *args, **kwargs):

View File

@ -784,6 +784,10 @@ tf_module {
name: "sinh"
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "size"
argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "sort"
argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], "