[TF-numpy] Added einsum to numpy_ops.

PiperOrigin-RevId: 317431034
Change-Id: I209bc52b5c03369526eb45486c1b6c09352f9f06
This commit is contained in:
Peng Wang 2020-06-19 21:56:50 -07:00 committed by TensorFlower Gardener
parent 728a4a4405
commit 8efcda9caa

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes from tensorflow.python.ops.numpy_ops import np_dtypes
@ -1336,3 +1337,30 @@ def meshgrid(*xi, **kwargs):
outputs = [np_utils.tensor_to_ndarray(output) for output in outputs] outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
return outputs return outputs
@np_utils.np_doc('einsum')
def einsum(subscripts, *operands, casting='safe', optimize=False): # pylint: disable=missing-docstring
if casting == 'safe':
operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access
elif casting == 'no':
operands = [np_array_ops.asarray(x) for x in operands]
else:
raise ValueError('casting policy not supported: %s' % casting)
if not optimize:
# TF doesn't have a "no optimization" option.
# TODO(wangpeng): Print a warning that np and tf use different
# optimizations.
tf_optimize = 'greedy'
elif optimize == True: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
tf_optimize = 'greedy'
elif optimize == 'greedy':
tf_optimize = 'greedy'
elif optimize == 'optimal':
tf_optimize = 'optimal'
else:
raise ValueError('`optimize` method not supported: %s' % optimize)
operands = [x.data for x in operands]
res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
res = np_utils.tensor_to_ndarray(res)
return res