diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py index a6b623b828c..47d5d457193 100644 --- a/tensorflow/python/kernel_tests/einsum_op_test.py +++ b/tensorflow/python/kernel_tests/einsum_op_test.py @@ -42,10 +42,11 @@ class EinsumOpTest(test.TestCase): r = np.random.RandomState(0) inputs = [] for shape in input_shapes: - arr = np.array(r.randn(*shape)).astype(dtype) - if dtype == np.complex64 or dtype == np.complex128: - arr += 1j * np.array(r.randn(*shape)).astype(dtype) - inputs.append(arr) + with self.subTest(s=s, shape=shape): + arr = np.array(r.randn(*shape)).astype(dtype) + if dtype == np.complex64 or dtype == np.complex128: + arr += 1j * np.array(r.randn(*shape)).astype(dtype) + inputs.append(arr) input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs] a = np.einsum(s, *inputs) b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s)) @@ -160,10 +161,11 @@ class EinsumOpTest(test.TestCase): input_shapes = [(2, 2), (2, 2)] inputs = [] for shape in input_shapes: - arr = np.array(r.randn(*shape)).astype(dtype) - if dtype == np.complex64 or dtype == np.complex128: - arr += 1j * np.array(r.randn(*shape)).astype(dtype) - inputs.append(arr) + with self.subTest(dtype=dtype, shape=shape): + arr = np.array(r.randn(*shape)).astype(dtype) + if dtype == np.complex64 or dtype == np.complex128: + arr += 1j * np.array(r.randn(*shape)).astype(dtype) + inputs.append(arr) input_tensors = [constant_op.constant(x) for x in inputs] if dtype == bfloat16: # np.einsum doesn't support bfloat16. @@ -199,14 +201,15 @@ class EinsumOpTest(test.TestCase): ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)), ] for args in cases: - with self.assertRaises((ValueError, errors.InvalidArgumentError)): - _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0])) + with self.subTest(args=args): + with self.assertRaises((ValueError, errors.InvalidArgumentError)): + _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0])) - placeholders = [ - array_ops.placeholder_with_default(x, shape=None) for x in args[1:] - ] - with self.assertRaises((ValueError, errors.InvalidArgumentError)): - _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0])) + placeholders = [ + array_ops.placeholder_with_default(x, shape=None) for x in args[1:] + ] + with self.assertRaises((ValueError, errors.InvalidArgumentError)): + _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0])) @test_util.run_in_graph_and_eager_modes def testPlaceholder(self): @@ -216,10 +219,12 @@ class EinsumOpTest(test.TestCase): inputs = [] input_placeholders = [] for actual_shape, placeholder_shape in input_and_placeholder_shapes: - input_np = np.array(r.randn(*actual_shape)) - inputs.append(input_np) - input_placeholders.append( - array_ops.placeholder_with_default(input_np, placeholder_shape)) + with self.subTest(equation=equation, actual_shape=actual_shape, + placeholder_shape=placeholder_shape): + input_np = np.array(r.randn(*actual_shape)) + inputs.append(input_np) + input_placeholders.append( + array_ops.placeholder_with_default(input_np, placeholder_shape)) a = np.einsum(equation, *inputs) b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation)) @@ -288,19 +293,22 @@ class EinsumGradTest(test.TestCase): with self.cached_session(): r = np.random.RandomState(seed=0) for dtype in (np.float32, np.float64, np.complex64, np.complex128): - tol = 10 * np.sqrt(np.finfo(dtype).resolution) - if dtype in (np.complex64, np.complex128): - inputs = [ - np.array(r.randn(*shape), dtype) + - 1j * np.array(r.randn(*shape), dtype) for shape in input_shapes - ] - else: - inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes] - input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs] - analytical, numerical = gradient_checker_v2.compute_gradient( - lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors) - self.assertLess( - gradient_checker_v2.max_error(analytical, numerical), tol) + with self.subTest(s=s, dtype=dtype): + tol = 10 * np.sqrt(np.finfo(dtype).resolution) + if dtype in (np.complex64, np.complex128): + inputs = [ + np.array(r.randn(*shape), dtype) + + 1j * np.array(r.randn(*shape), dtype) for shape in input_shapes + ] + else: + inputs = [ + np.array(r.randn(*shape), dtype) for shape in input_shapes] + input_tensors = [ + constant_op.constant(x, shape=x.shape) for x in inputs] + analytical, numerical = gradient_checker_v2.compute_gradient( + lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors) + self.assertLess( + gradient_checker_v2.max_error(analytical, numerical), tol) @test_util.disable_xla('b/131919749') def testUnary(self):