diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py index 88bf4e7499a..65e8273375f 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays.py @@ -262,7 +262,8 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name """ return np.asarray(self.data, dtype) - __array_priority__ = 110 + # NOTE: we currently prefer interop with TF to allow TF to take precedence. + __array_priority__ = 90 def __index__(self): """Returns a python scalar. diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py index 052949dff9d..f52d3dae78b 100644 --- a/tensorflow/python/ops/numpy_ops/np_interop_test.py +++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py @@ -19,12 +19,18 @@ from __future__ import division from __future__ import print_function +import numpy as onp + + from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays +from tensorflow.python.ops.numpy_ops import np_math_ops from tensorflow.python.platform import test @@ -88,6 +94,50 @@ class InteropTest(test.TestCase): self.assertEqual(10000, fn()[0]) self.assertEqual(10000, def_function.function(fn)()[0]) + def testTensorTFNPArrayInterop(self): + arr = np_array_ops.asarray(0.) + t = constant_op.constant(10.) + + arr_plus_t = arr + t + t_plus_arr = t + arr + + self.assertIsInstance(arr_plus_t, ops.Tensor) + self.assertIsInstance(t_plus_arr, ops.Tensor) + self.assertEqual(10., arr_plus_t.numpy()) + self.assertEqual(10., t_plus_arr.numpy()) + + def testTensorTFNPOp(self): + t = constant_op.constant(10.) + + sq = np_math_ops.square(t) + self.assertIsInstance(sq, np_arrays.ndarray) + self.assertEqual(100., sq) + + def testTFNPArrayTFOpInterop(self): + arr = np_array_ops.asarray(10.) + + # TODO(nareshmodi): Test more ops. + sq = math_ops.square(arr) + self.assertIsInstance(sq, ops.Tensor) + self.assertEqual(100., sq.numpy()) + + def testTFNPArrayNPOpInterop(self): + arr = np_array_ops.asarray([10.]) + + # TODO(nareshmodi): Test more ops. + sq = onp.square(arr) + self.assertIsInstance(sq, onp.ndarray) + self.assertEqual(100., sq[0]) + + # TODO(nareshmodi): Fails since the autopacking code doesn't use + # nest.flatten. +# def testAutopacking(self): +# arr1 = np_array_ops.asarray(1.) +# arr2 = np_array_ops.asarray(2.) +# arr3 = np_array_ops.asarray(3.) +# t = ops.convert_to_tensor_v2([arr1, arr2, arr3]) + +# self.assertEqual(t.numpy(), [1., 2., 3.]) if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index abfd9087ffd..361bfb50dec 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -917,29 +917,37 @@ def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring return _scalar(f, a) -def _flip_args(f): +def _wrap(f, reverse=False): + """Wraps binary ops so they can be added as operator overloads on ndarray.""" def _f(a, b): - return f(b, a) + if reverse: + a, b = b, a + + if getattr(b, '__array_priority__', + 0) > np_arrays.ndarray.__array_priority__: + return NotImplemented + + return f(a, b) return _f setattr(np_arrays.ndarray, '__abs__', absolute) -setattr(np_arrays.ndarray, '__floordiv__', floor_divide) -setattr(np_arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide)) -setattr(np_arrays.ndarray, '__mod__', mod) -setattr(np_arrays.ndarray, '__rmod__', _flip_args(mod)) -setattr(np_arrays.ndarray, '__add__', add) -setattr(np_arrays.ndarray, '__radd__', _flip_args(add)) -setattr(np_arrays.ndarray, '__sub__', subtract) -setattr(np_arrays.ndarray, '__rsub__', _flip_args(subtract)) -setattr(np_arrays.ndarray, '__mul__', multiply) -setattr(np_arrays.ndarray, '__rmul__', _flip_args(multiply)) -setattr(np_arrays.ndarray, '__pow__', power) -setattr(np_arrays.ndarray, '__rpow__', _flip_args(power)) -setattr(np_arrays.ndarray, '__truediv__', true_divide) -setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide)) +setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide)) +setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True)) +setattr(np_arrays.ndarray, '__mod__', _wrap(mod)) +setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True)) +setattr(np_arrays.ndarray, '__add__', _wrap(add)) +setattr(np_arrays.ndarray, '__radd__', _wrap(add, True)) +setattr(np_arrays.ndarray, '__sub__', _wrap(subtract)) +setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True)) +setattr(np_arrays.ndarray, '__mul__', _wrap(multiply)) +setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True)) +setattr(np_arrays.ndarray, '__pow__', _wrap(power)) +setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True)) +setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide)) +setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True)) def _comparison(tf_fun, x1, x2, cast_bool_to_int=False): @@ -1031,12 +1039,12 @@ def logical_not(x): setattr(np_arrays.ndarray, '__invert__', logical_not) -setattr(np_arrays.ndarray, '__lt__', less) -setattr(np_arrays.ndarray, '__le__', less_equal) -setattr(np_arrays.ndarray, '__gt__', greater) -setattr(np_arrays.ndarray, '__ge__', greater_equal) -setattr(np_arrays.ndarray, '__eq__', equal) -setattr(np_arrays.ndarray, '__ne__', not_equal) +setattr(np_arrays.ndarray, '__lt__', _wrap(less)) +setattr(np_arrays.ndarray, '__le__', _wrap(less_equal)) +setattr(np_arrays.ndarray, '__gt__', _wrap(greater)) +setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal)) +setattr(np_arrays.ndarray, '__eq__', _wrap(equal)) +setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal)) @np_utils.np_doc(np.linspace)