Let tensorflow op take precedence when doing "ndarray <op> tensor"
Also add a few more interop tests. PiperOrigin-RevId: 317339113 Change-Id: Ic28fab7abefea681e1e8d840b8e4cf4f98b63f1e
This commit is contained in:
parent
89de554b69
commit
8088eddf20
@ -262,7 +262,8 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
||||
"""
|
||||
return np.asarray(self.data, dtype)
|
||||
|
||||
__array_priority__ = 110
|
||||
# NOTE: we currently prefer interop with TF to allow TF to take precedence.
|
||||
__array_priority__ = 90
|
||||
|
||||
def __index__(self):
|
||||
"""Returns a python scalar.
|
||||
|
@ -19,12 +19,18 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import numpy as onp
|
||||
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.numpy_ops import np_math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -88,6 +94,50 @@ class InteropTest(test.TestCase):
|
||||
self.assertEqual(10000, fn()[0])
|
||||
self.assertEqual(10000, def_function.function(fn)()[0])
|
||||
|
||||
def testTensorTFNPArrayInterop(self):
|
||||
arr = np_array_ops.asarray(0.)
|
||||
t = constant_op.constant(10.)
|
||||
|
||||
arr_plus_t = arr + t
|
||||
t_plus_arr = t + arr
|
||||
|
||||
self.assertIsInstance(arr_plus_t, ops.Tensor)
|
||||
self.assertIsInstance(t_plus_arr, ops.Tensor)
|
||||
self.assertEqual(10., arr_plus_t.numpy())
|
||||
self.assertEqual(10., t_plus_arr.numpy())
|
||||
|
||||
def testTensorTFNPOp(self):
|
||||
t = constant_op.constant(10.)
|
||||
|
||||
sq = np_math_ops.square(t)
|
||||
self.assertIsInstance(sq, np_arrays.ndarray)
|
||||
self.assertEqual(100., sq)
|
||||
|
||||
def testTFNPArrayTFOpInterop(self):
|
||||
arr = np_array_ops.asarray(10.)
|
||||
|
||||
# TODO(nareshmodi): Test more ops.
|
||||
sq = math_ops.square(arr)
|
||||
self.assertIsInstance(sq, ops.Tensor)
|
||||
self.assertEqual(100., sq.numpy())
|
||||
|
||||
def testTFNPArrayNPOpInterop(self):
|
||||
arr = np_array_ops.asarray([10.])
|
||||
|
||||
# TODO(nareshmodi): Test more ops.
|
||||
sq = onp.square(arr)
|
||||
self.assertIsInstance(sq, onp.ndarray)
|
||||
self.assertEqual(100., sq[0])
|
||||
|
||||
# TODO(nareshmodi): Fails since the autopacking code doesn't use
|
||||
# nest.flatten.
|
||||
# def testAutopacking(self):
|
||||
# arr1 = np_array_ops.asarray(1.)
|
||||
# arr2 = np_array_ops.asarray(2.)
|
||||
# arr3 = np_array_ops.asarray(3.)
|
||||
# t = ops.convert_to_tensor_v2([arr1, arr2, arr3])
|
||||
|
||||
# self.assertEqual(t.numpy(), [1., 2., 3.])
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -917,29 +917,37 @@ def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
||||
return _scalar(f, a)
|
||||
|
||||
|
||||
def _flip_args(f):
|
||||
def _wrap(f, reverse=False):
|
||||
"""Wraps binary ops so they can be added as operator overloads on ndarray."""
|
||||
|
||||
def _f(a, b):
|
||||
return f(b, a)
|
||||
if reverse:
|
||||
a, b = b, a
|
||||
|
||||
if getattr(b, '__array_priority__',
|
||||
0) > np_arrays.ndarray.__array_priority__:
|
||||
return NotImplemented
|
||||
|
||||
return f(a, b)
|
||||
|
||||
return _f
|
||||
|
||||
|
||||
setattr(np_arrays.ndarray, '__abs__', absolute)
|
||||
setattr(np_arrays.ndarray, '__floordiv__', floor_divide)
|
||||
setattr(np_arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide))
|
||||
setattr(np_arrays.ndarray, '__mod__', mod)
|
||||
setattr(np_arrays.ndarray, '__rmod__', _flip_args(mod))
|
||||
setattr(np_arrays.ndarray, '__add__', add)
|
||||
setattr(np_arrays.ndarray, '__radd__', _flip_args(add))
|
||||
setattr(np_arrays.ndarray, '__sub__', subtract)
|
||||
setattr(np_arrays.ndarray, '__rsub__', _flip_args(subtract))
|
||||
setattr(np_arrays.ndarray, '__mul__', multiply)
|
||||
setattr(np_arrays.ndarray, '__rmul__', _flip_args(multiply))
|
||||
setattr(np_arrays.ndarray, '__pow__', power)
|
||||
setattr(np_arrays.ndarray, '__rpow__', _flip_args(power))
|
||||
setattr(np_arrays.ndarray, '__truediv__', true_divide)
|
||||
setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
|
||||
setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide))
|
||||
setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True))
|
||||
setattr(np_arrays.ndarray, '__mod__', _wrap(mod))
|
||||
setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True))
|
||||
setattr(np_arrays.ndarray, '__add__', _wrap(add))
|
||||
setattr(np_arrays.ndarray, '__radd__', _wrap(add, True))
|
||||
setattr(np_arrays.ndarray, '__sub__', _wrap(subtract))
|
||||
setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True))
|
||||
setattr(np_arrays.ndarray, '__mul__', _wrap(multiply))
|
||||
setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True))
|
||||
setattr(np_arrays.ndarray, '__pow__', _wrap(power))
|
||||
setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True))
|
||||
setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide))
|
||||
setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True))
|
||||
|
||||
|
||||
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
|
||||
@ -1031,12 +1039,12 @@ def logical_not(x):
|
||||
|
||||
|
||||
setattr(np_arrays.ndarray, '__invert__', logical_not)
|
||||
setattr(np_arrays.ndarray, '__lt__', less)
|
||||
setattr(np_arrays.ndarray, '__le__', less_equal)
|
||||
setattr(np_arrays.ndarray, '__gt__', greater)
|
||||
setattr(np_arrays.ndarray, '__ge__', greater_equal)
|
||||
setattr(np_arrays.ndarray, '__eq__', equal)
|
||||
setattr(np_arrays.ndarray, '__ne__', not_equal)
|
||||
setattr(np_arrays.ndarray, '__lt__', _wrap(less))
|
||||
setattr(np_arrays.ndarray, '__le__', _wrap(less_equal))
|
||||
setattr(np_arrays.ndarray, '__gt__', _wrap(greater))
|
||||
setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal))
|
||||
setattr(np_arrays.ndarray, '__eq__', _wrap(equal))
|
||||
setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.linspace)
|
||||
|
Loading…
Reference in New Issue
Block a user