Add the vdot op to tf-numpy.

PiperOrigin-RevId: 316949446
Change-Id: I9b0718ec108486096e032729481ec9129863b429
This commit is contained in:
Lukasz Kaiser 2020-06-17 13:07:01 -07:00 committed by TensorFlower Gardener
parent d5ca984c53
commit 9426d35abc
2 changed files with 16 additions and 0 deletions

View File

@ -309,6 +309,16 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=mis
return _bin_op(f, a, b)
@np_utils.np_doc_only(np.vdot)
def vdot(a, b): # pylint: disable=missing-docstring
a, b = np_array_ops._promote_dtype(a, b)
a = np_array_ops.reshape(a, [-1])
b = np_array_ops.reshape(b, [-1])
if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64:
a = conj(a)
return dot(a, b)
@np_utils.np_doc(np.power)
def power(x1, x2):
return _bin_op(math_ops.pow, x1, x2)

View File

@ -124,6 +124,12 @@ class MathTest(test.TestCase, parameterized.TestCase):
np_math_ops.matmul(
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
def testVDot(self):
operands = [([[1, 2], [3, 4]], [[3, 4], [6, 7]]),
([[1, 2], [3, 4]], [3, 4, 6, 7])]
return self._testBinaryOp(
np_math_ops.vdot, np.vdot, 'vdot', operands=operands)
def _testUnaryOp(self, math_fun, np_fun, name):
def run_test(a):