diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index f9d3b34e90d..b242e51a2e5 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -1340,7 +1340,9 @@ def meshgrid(*xi, **kwargs): @np_utils.np_doc('einsum') -def einsum(subscripts, *operands, casting='safe', optimize=False): # pylint: disable=missing-docstring +def einsum(subscripts, *operands, **kwargs): # pylint: disable=missing-docstring + casting = kwargs.get('casting', 'safe') + optimize = kwargs.get('optimize', False) if casting == 'safe': operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access elif casting == 'no':