diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index df3d7cf32ab..f9d3b34e90d 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sort_ops +from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.numpy_ops import np_dtypes @@ -1336,3 +1337,30 @@ def meshgrid(*xi, **kwargs): outputs = [np_utils.tensor_to_ndarray(output) for output in outputs] return outputs + + +@np_utils.np_doc('einsum') +def einsum(subscripts, *operands, casting='safe', optimize=False): # pylint: disable=missing-docstring + if casting == 'safe': + operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access + elif casting == 'no': + operands = [np_array_ops.asarray(x) for x in operands] + else: + raise ValueError('casting policy not supported: %s' % casting) + if not optimize: + # TF doesn't have a "no optimization" option. + # TODO(wangpeng): Print a warning that np and tf use different + # optimizations. + tf_optimize = 'greedy' + elif optimize == True: # pylint: disable=singleton-comparison,g-explicit-bool-comparison + tf_optimize = 'greedy' + elif optimize == 'greedy': + tf_optimize = 'greedy' + elif optimize == 'optimal': + tf_optimize = 'optimal' + else: + raise ValueError('`optimize` method not supported: %s' % optimize) + operands = [x.data for x in operands] + res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize) + res = np_utils.tensor_to_ndarray(res) + return res