Add the vdot op to tf-numpy.
PiperOrigin-RevId: 316949446 Change-Id: I9b0718ec108486096e032729481ec9129863b429
This commit is contained in:
parent
d5ca984c53
commit
9426d35abc
@ -309,6 +309,16 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=mis
|
||||
return _bin_op(f, a, b)
|
||||
|
||||
|
||||
@np_utils.np_doc_only(np.vdot)
|
||||
def vdot(a, b): # pylint: disable=missing-docstring
|
||||
a, b = np_array_ops._promote_dtype(a, b)
|
||||
a = np_array_ops.reshape(a, [-1])
|
||||
b = np_array_ops.reshape(b, [-1])
|
||||
if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64:
|
||||
a = conj(a)
|
||||
return dot(a, b)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.power)
|
||||
def power(x1, x2):
|
||||
return _bin_op(math_ops.pow, x1, x2)
|
||||
|
@ -124,6 +124,12 @@ class MathTest(test.TestCase, parameterized.TestCase):
|
||||
np_math_ops.matmul(
|
||||
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
|
||||
|
||||
def testVDot(self):
|
||||
operands = [([[1, 2], [3, 4]], [[3, 4], [6, 7]]),
|
||||
([[1, 2], [3, 4]], [3, 4, 6, 7])]
|
||||
return self._testBinaryOp(
|
||||
np_math_ops.vdot, np.vdot, 'vdot', operands=operands)
|
||||
|
||||
def _testUnaryOp(self, math_fun, np_fun, name):
|
||||
|
||||
def run_test(a):
|
||||
|
Loading…
Reference in New Issue
Block a user