Merge pull request #42539 from DarrenZhang01:master
PiperOrigin-RevId: 330840468 Change-Id: Iafdbf4bd8bc74f1487daf05e15ce15f5c406d6aa
This commit is contained in:
		
						commit
						f56f3a7391
					
				@ -560,6 +560,21 @@ def _reduce(tf_fn,
 | 
			
		||||
      tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO (DarrenZhang01): Add `axis` support to the `size` API.
 | 
			
		||||
@np_utils.np_doc('size')
 | 
			
		||||
def size(x, axis=None):  # pylint: disable=missing-docstring
 | 
			
		||||
  if axis is not None:
 | 
			
		||||
    raise NotImplementedError('axis argument is not supported in the current '
 | 
			
		||||
                              '`np.size` implementation')
 | 
			
		||||
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
 | 
			
		||||
    return 1
 | 
			
		||||
  x = asarray(x).data
 | 
			
		||||
  if x.shape.is_fully_defined():
 | 
			
		||||
    return np.prod(x.shape.as_list())
 | 
			
		||||
  else:
 | 
			
		||||
    return np_utils.tensor_to_ndarray(array_ops.size_v2(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@np_utils.np_doc('sum')
 | 
			
		||||
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
 | 
			
		||||
  return _reduce(
 | 
			
		||||
 | 
			
		||||
@ -25,12 +25,14 @@ from six.moves import range
 | 
			
		||||
from six.moves import zip
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.eager import context
 | 
			
		||||
from tensorflow.python.eager import def_function
 | 
			
		||||
from tensorflow.python.framework import config
 | 
			
		||||
from tensorflow.python.framework import constant_op
 | 
			
		||||
from tensorflow.python.framework import dtypes
 | 
			
		||||
from tensorflow.python.framework import indexed_slices
 | 
			
		||||
from tensorflow.python.framework import ops
 | 
			
		||||
from tensorflow.python.framework import tensor_shape
 | 
			
		||||
from tensorflow.python.framework import tensor_spec
 | 
			
		||||
from tensorflow.python.ops import array_ops
 | 
			
		||||
from tensorflow.python.ops.numpy_ops import np_array_ops
 | 
			
		||||
from tensorflow.python.ops.numpy_ops import np_arrays
 | 
			
		||||
@ -804,6 +806,31 @@ class ArrayMethodsTest(test.TestCase):
 | 
			
		||||
  def testAmax(self):
 | 
			
		||||
    self._testReduce(np_array_ops.amax, np.amax, 'amax')
 | 
			
		||||
 | 
			
		||||
  def testSize(self):
 | 
			
		||||
 | 
			
		||||
    def run_test(arr, axis=None):
 | 
			
		||||
      onp_arr = np.array(arr)
 | 
			
		||||
      self.assertEqual(np_array_ops.size(arr, axis), np.size(onp_arr, axis))
 | 
			
		||||
 | 
			
		||||
    run_test(np_array_ops.array([1]))
 | 
			
		||||
    run_test(np_array_ops.array([1, 2, 3, 4, 5]))
 | 
			
		||||
    run_test(np_array_ops.ones((2, 3, 2)))
 | 
			
		||||
    run_test(np_array_ops.ones((3, 2)))
 | 
			
		||||
    run_test(np_array_ops.zeros((5, 6, 7)))
 | 
			
		||||
    run_test(1)
 | 
			
		||||
    run_test(np_array_ops.ones((3, 2, 1)))
 | 
			
		||||
    run_test(constant_op.constant(5))
 | 
			
		||||
    run_test(constant_op.constant([1, 1, 1]))
 | 
			
		||||
    self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)),
 | 
			
		||||
                      1)
 | 
			
		||||
 | 
			
		||||
    @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=None)])
 | 
			
		||||
    def f(arr):
 | 
			
		||||
      arr = np_array_ops.asarray(arr)
 | 
			
		||||
      return np_array_ops.size(arr)
 | 
			
		||||
 | 
			
		||||
    self.assertEqual(f(np_array_ops.ones((3, 2))).data.numpy(), 6)
 | 
			
		||||
 | 
			
		||||
  def testRavel(self):
 | 
			
		||||
 | 
			
		||||
    def run_test(arr, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
@ -784,6 +784,10 @@ tf_module {
 | 
			
		||||
    name: "sinh"
 | 
			
		||||
    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "size"
 | 
			
		||||
    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "sort"
 | 
			
		||||
    argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], "
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user