Add the vdot op to tf-numpy.
PiperOrigin-RevId: 316949446 Change-Id: I9b0718ec108486096e032729481ec9129863b429
This commit is contained in:
parent
d5ca984c53
commit
9426d35abc
@ -309,6 +309,16 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=mis
|
|||||||
return _bin_op(f, a, b)
|
return _bin_op(f, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc_only(np.vdot)
|
||||||
|
def vdot(a, b): # pylint: disable=missing-docstring
|
||||||
|
a, b = np_array_ops._promote_dtype(a, b)
|
||||||
|
a = np_array_ops.reshape(a, [-1])
|
||||||
|
b = np_array_ops.reshape(b, [-1])
|
||||||
|
if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64:
|
||||||
|
a = conj(a)
|
||||||
|
return dot(a, b)
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.power)
|
@np_utils.np_doc(np.power)
|
||||||
def power(x1, x2):
|
def power(x1, x2):
|
||||||
return _bin_op(math_ops.pow, x1, x2)
|
return _bin_op(math_ops.pow, x1, x2)
|
||||||
|
@ -124,6 +124,12 @@ class MathTest(test.TestCase, parameterized.TestCase):
|
|||||||
np_math_ops.matmul(
|
np_math_ops.matmul(
|
||||||
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
|
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
|
||||||
|
|
||||||
|
def testVDot(self):
|
||||||
|
operands = [([[1, 2], [3, 4]], [[3, 4], [6, 7]]),
|
||||||
|
([[1, 2], [3, 4]], [3, 4, 6, 7])]
|
||||||
|
return self._testBinaryOp(
|
||||||
|
np_math_ops.vdot, np.vdot, 'vdot', operands=operands)
|
||||||
|
|
||||||
def _testUnaryOp(self, math_fun, np_fun, name):
|
def _testUnaryOp(self, math_fun, np_fun, name):
|
||||||
|
|
||||||
def run_test(a):
|
def run_test(a):
|
||||||
|
Loading…
Reference in New Issue
Block a user