Let tensorflow op take precedence when doing "ndarray <op> tensor"

Also add a few more interop tests.

PiperOrigin-RevId: 317339113
Change-Id: Ic28fab7abefea681e1e8d840b8e4cf4f98b63f1e
This commit is contained in:
Akshay Modi 2020-06-19 10:58:32 -07:00 committed by TensorFlower Gardener
parent 89de554b69
commit 8088eddf20
3 changed files with 82 additions and 23 deletions

View File

@ -262,7 +262,8 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
"""
return np.asarray(self.data, dtype)
__array_priority__ = 110
# NOTE: we currently prefer interop with TF to allow TF to take precedence.
__array_priority__ = 90
def __index__(self):
"""Returns a python scalar.

View File

@ -19,12 +19,18 @@ from __future__ import division
from __future__ import print_function
import numpy as onp
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_math_ops
from tensorflow.python.platform import test
@ -88,6 +94,50 @@ class InteropTest(test.TestCase):
self.assertEqual(10000, fn()[0])
self.assertEqual(10000, def_function.function(fn)()[0])
def testTensorTFNPArrayInterop(self):
arr = np_array_ops.asarray(0.)
t = constant_op.constant(10.)
arr_plus_t = arr + t
t_plus_arr = t + arr
self.assertIsInstance(arr_plus_t, ops.Tensor)
self.assertIsInstance(t_plus_arr, ops.Tensor)
self.assertEqual(10., arr_plus_t.numpy())
self.assertEqual(10., t_plus_arr.numpy())
def testTensorTFNPOp(self):
t = constant_op.constant(10.)
sq = np_math_ops.square(t)
self.assertIsInstance(sq, np_arrays.ndarray)
self.assertEqual(100., sq)
def testTFNPArrayTFOpInterop(self):
arr = np_array_ops.asarray(10.)
# TODO(nareshmodi): Test more ops.
sq = math_ops.square(arr)
self.assertIsInstance(sq, ops.Tensor)
self.assertEqual(100., sq.numpy())
def testTFNPArrayNPOpInterop(self):
arr = np_array_ops.asarray([10.])
# TODO(nareshmodi): Test more ops.
sq = onp.square(arr)
self.assertIsInstance(sq, onp.ndarray)
self.assertEqual(100., sq[0])
# TODO(nareshmodi): Fails since the autopacking code doesn't use
# nest.flatten.
# def testAutopacking(self):
# arr1 = np_array_ops.asarray(1.)
# arr2 = np_array_ops.asarray(2.)
# arr3 = np_array_ops.asarray(3.)
# t = ops.convert_to_tensor_v2([arr1, arr2, arr3])
# self.assertEqual(t.numpy(), [1., 2., 3.])
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -917,29 +917,37 @@ def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
return _scalar(f, a)
def _flip_args(f):
def _wrap(f, reverse=False):
"""Wraps binary ops so they can be added as operator overloads on ndarray."""
def _f(a, b):
return f(b, a)
if reverse:
a, b = b, a
if getattr(b, '__array_priority__',
0) > np_arrays.ndarray.__array_priority__:
return NotImplemented
return f(a, b)
return _f
setattr(np_arrays.ndarray, '__abs__', absolute)
setattr(np_arrays.ndarray, '__floordiv__', floor_divide)
setattr(np_arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide))
setattr(np_arrays.ndarray, '__mod__', mod)
setattr(np_arrays.ndarray, '__rmod__', _flip_args(mod))
setattr(np_arrays.ndarray, '__add__', add)
setattr(np_arrays.ndarray, '__radd__', _flip_args(add))
setattr(np_arrays.ndarray, '__sub__', subtract)
setattr(np_arrays.ndarray, '__rsub__', _flip_args(subtract))
setattr(np_arrays.ndarray, '__mul__', multiply)
setattr(np_arrays.ndarray, '__rmul__', _flip_args(multiply))
setattr(np_arrays.ndarray, '__pow__', power)
setattr(np_arrays.ndarray, '__rpow__', _flip_args(power))
setattr(np_arrays.ndarray, '__truediv__', true_divide)
setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide))
setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True))
setattr(np_arrays.ndarray, '__mod__', _wrap(mod))
setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True))
setattr(np_arrays.ndarray, '__add__', _wrap(add))
setattr(np_arrays.ndarray, '__radd__', _wrap(add, True))
setattr(np_arrays.ndarray, '__sub__', _wrap(subtract))
setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True))
setattr(np_arrays.ndarray, '__mul__', _wrap(multiply))
setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True))
setattr(np_arrays.ndarray, '__pow__', _wrap(power))
setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True))
setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide))
setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True))
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
@ -1031,12 +1039,12 @@ def logical_not(x):
setattr(np_arrays.ndarray, '__invert__', logical_not)
setattr(np_arrays.ndarray, '__lt__', less)
setattr(np_arrays.ndarray, '__le__', less_equal)
setattr(np_arrays.ndarray, '__gt__', greater)
setattr(np_arrays.ndarray, '__ge__', greater_equal)
setattr(np_arrays.ndarray, '__eq__', equal)
setattr(np_arrays.ndarray, '__ne__', not_equal)
setattr(np_arrays.ndarray, '__lt__', _wrap(less))
setattr(np_arrays.ndarray, '__le__', _wrap(less_equal))
setattr(np_arrays.ndarray, '__gt__', _wrap(greater))
setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal))
setattr(np_arrays.ndarray, '__eq__', _wrap(equal))
setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal))
@np_utils.np_doc(np.linspace)