[TF-numpy] Added einsum
to numpy_ops.
PiperOrigin-RevId: 317431034 Change-Id: I209bc52b5c03369526eb45486c1b6c09352f9f06
This commit is contained in:
parent
728a4a4405
commit
8efcda9caa
@ -34,6 +34,7 @@ from tensorflow.python.ops import clip_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import sort_ops
|
from tensorflow.python.ops import sort_ops
|
||||||
|
from tensorflow.python.ops import special_math_ops
|
||||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
from tensorflow.python.ops.numpy_ops import np_dtypes
|
from tensorflow.python.ops.numpy_ops import np_dtypes
|
||||||
@ -1336,3 +1337,30 @@ def meshgrid(*xi, **kwargs):
|
|||||||
outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
|
outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc('einsum')
|
||||||
|
def einsum(subscripts, *operands, casting='safe', optimize=False): # pylint: disable=missing-docstring
|
||||||
|
if casting == 'safe':
|
||||||
|
operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access
|
||||||
|
elif casting == 'no':
|
||||||
|
operands = [np_array_ops.asarray(x) for x in operands]
|
||||||
|
else:
|
||||||
|
raise ValueError('casting policy not supported: %s' % casting)
|
||||||
|
if not optimize:
|
||||||
|
# TF doesn't have a "no optimization" option.
|
||||||
|
# TODO(wangpeng): Print a warning that np and tf use different
|
||||||
|
# optimizations.
|
||||||
|
tf_optimize = 'greedy'
|
||||||
|
elif optimize == True: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
|
||||||
|
tf_optimize = 'greedy'
|
||||||
|
elif optimize == 'greedy':
|
||||||
|
tf_optimize = 'greedy'
|
||||||
|
elif optimize == 'optimal':
|
||||||
|
tf_optimize = 'optimal'
|
||||||
|
else:
|
||||||
|
raise ValueError('`optimize` method not supported: %s' % optimize)
|
||||||
|
operands = [x.data for x in operands]
|
||||||
|
res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
|
||||||
|
res = np_utils.tensor_to_ndarray(res)
|
||||||
|
return res
|
||||||
|
Loading…
Reference in New Issue
Block a user