Use subTest on einsum_test to make errors easier to understand.

PiperOrigin-RevId: 311606884
Change-Id: I7f8738ffc26479f98d468431706c7d4f7c6efcfc
This commit is contained in:
Andrew Selle 2020-05-14 14:32:59 -07:00 committed by TensorFlower Gardener
parent 501309eef9
commit 6f57007fb8
1 changed files with 40 additions and 32 deletions

View File

@ -42,6 +42,7 @@ class EinsumOpTest(test.TestCase):
r = np.random.RandomState(0)
inputs = []
for shape in input_shapes:
with self.subTest(s=s, shape=shape):
arr = np.array(r.randn(*shape)).astype(dtype)
if dtype == np.complex64 or dtype == np.complex128:
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
@ -160,6 +161,7 @@ class EinsumOpTest(test.TestCase):
input_shapes = [(2, 2), (2, 2)]
inputs = []
for shape in input_shapes:
with self.subTest(dtype=dtype, shape=shape):
arr = np.array(r.randn(*shape)).astype(dtype)
if dtype == np.complex64 or dtype == np.complex128:
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
@ -199,6 +201,7 @@ class EinsumOpTest(test.TestCase):
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
]
for args in cases:
with self.subTest(args=args):
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
@ -216,6 +219,8 @@ class EinsumOpTest(test.TestCase):
inputs = []
input_placeholders = []
for actual_shape, placeholder_shape in input_and_placeholder_shapes:
with self.subTest(equation=equation, actual_shape=actual_shape,
placeholder_shape=placeholder_shape):
input_np = np.array(r.randn(*actual_shape))
inputs.append(input_np)
input_placeholders.append(
@ -288,6 +293,7 @@ class EinsumGradTest(test.TestCase):
with self.cached_session():
r = np.random.RandomState(seed=0)
for dtype in (np.float32, np.float64, np.complex64, np.complex128):
with self.subTest(s=s, dtype=dtype):
tol = 10 * np.sqrt(np.finfo(dtype).resolution)
if dtype in (np.complex64, np.complex128):
inputs = [
@ -295,8 +301,10 @@ class EinsumGradTest(test.TestCase):
1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
]
else:
inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
inputs = [
np.array(r.randn(*shape), dtype) for shape in input_shapes]
input_tensors = [
constant_op.constant(x, shape=x.shape) for x in inputs]
analytical, numerical = gradient_checker_v2.compute_gradient(
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
self.assertLess(