[TF-numpy] Added einsum
to numpy_ops.
PiperOrigin-RevId: 317431034 Change-Id: I209bc52b5c03369526eb45486c1b6c09352f9f06
This commit is contained in:
parent
728a4a4405
commit
8efcda9caa
@ -34,6 +34,7 @@ from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sort_ops
|
||||
from tensorflow.python.ops import special_math_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.numpy_ops import np_dtypes
|
||||
@ -1336,3 +1337,30 @@ def meshgrid(*xi, **kwargs):
|
||||
outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@np_utils.np_doc('einsum')
|
||||
def einsum(subscripts, *operands, casting='safe', optimize=False): # pylint: disable=missing-docstring
|
||||
if casting == 'safe':
|
||||
operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access
|
||||
elif casting == 'no':
|
||||
operands = [np_array_ops.asarray(x) for x in operands]
|
||||
else:
|
||||
raise ValueError('casting policy not supported: %s' % casting)
|
||||
if not optimize:
|
||||
# TF doesn't have a "no optimization" option.
|
||||
# TODO(wangpeng): Print a warning that np and tf use different
|
||||
# optimizations.
|
||||
tf_optimize = 'greedy'
|
||||
elif optimize == True: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
|
||||
tf_optimize = 'greedy'
|
||||
elif optimize == 'greedy':
|
||||
tf_optimize = 'greedy'
|
||||
elif optimize == 'optimal':
|
||||
tf_optimize = 'optimal'
|
||||
else:
|
||||
raise ValueError('`optimize` method not supported: %s' % optimize)
|
||||
operands = [x.data for x in operands]
|
||||
res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
|
||||
res = np_utils.tensor_to_ndarray(res)
|
||||
return res
|
||||
|
Loading…
Reference in New Issue
Block a user