Use subTest on einsum_test to make errors easier to understand.

PiperOrigin-RevId: 311606884
Change-Id: I7f8738ffc26479f98d468431706c7d4f7c6efcfc
Authored by Andrew Selle on 2020-05-14 14:32:59 -07:00; committed by TensorFlower Gardener
parent 501309eef9
commit 6f57007fb8
1 changed file with 40 additions and 32 deletions

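For readers unfamiliar with the pattern this commit applies: unittest.TestCase.subTest attaches the given loop parameters to every failure report and lets the remaining iterations keep running, so a bad equation/shape combination is named explicitly instead of surfacing as an anonymous failure partway through a loop. A minimal self-contained sketch (illustration only, not part of this commit; the test class and values are hypothetical):

import unittest


class SubTestDemo(unittest.TestCase):

  def test_squares(self):
    cases = [(2, 4), (3, 9), (4, 15)]  # (4, 15) is deliberately wrong.
    for n, expected in cases:
      with self.subTest(n=n, expected=expected):
        # On failure unittest reports something like
        #   FAIL: test_squares (__main__.SubTestDemo) (n=4, expected=15)
        # (exact header format varies by Python version) and still runs
        # the remaining cases.
        self.assertEqual(n * n, expected)


if __name__ == '__main__':
  unittest.main()

Without subTest, the same loop would stop at the first failing case and the report would not say which (n, expected) pair broke.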

@@ -42,10 +42,11 @@ class EinsumOpTest(test.TestCase):
     r = np.random.RandomState(0)
     inputs = []
     for shape in input_shapes:
-      arr = np.array(r.randn(*shape)).astype(dtype)
-      if dtype == np.complex64 or dtype == np.complex128:
-        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
-      inputs.append(arr)
+      with self.subTest(s=s, shape=shape):
+        arr = np.array(r.randn(*shape)).astype(dtype)
+        if dtype == np.complex64 or dtype == np.complex128:
+          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+        inputs.append(arr)
     input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
     a = np.einsum(s, *inputs)
     b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
@@ -160,10 +161,11 @@ class EinsumOpTest(test.TestCase):
       input_shapes = [(2, 2), (2, 2)]
       inputs = []
       for shape in input_shapes:
-        arr = np.array(r.randn(*shape)).astype(dtype)
-        if dtype == np.complex64 or dtype == np.complex128:
-          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
-        inputs.append(arr)
+        with self.subTest(dtype=dtype, shape=shape):
+          arr = np.array(r.randn(*shape)).astype(dtype)
+          if dtype == np.complex64 or dtype == np.complex128:
+            arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+          inputs.append(arr)
       input_tensors = [constant_op.constant(x) for x in inputs]
       if dtype == bfloat16:
         # np.einsum doesn't support bfloat16.
@@ -199,14 +201,15 @@ class EinsumOpTest(test.TestCase):
         ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
     ]
     for args in cases:
-      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
+      with self.subTest(args=args):
+        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+          _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
 
-      placeholders = [
-          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
-      ]
-      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
+        placeholders = [
+            array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
+        ]
+        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+          _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
 
   @test_util.run_in_graph_and_eager_modes
   def testPlaceholder(self):
@@ -216,10 +219,12 @@ class EinsumOpTest(test.TestCase):
       inputs = []
       input_placeholders = []
       for actual_shape, placeholder_shape in input_and_placeholder_shapes:
-        input_np = np.array(r.randn(*actual_shape))
-        inputs.append(input_np)
-        input_placeholders.append(
-            array_ops.placeholder_with_default(input_np, placeholder_shape))
+        with self.subTest(equation=equation, actual_shape=actual_shape,
+                          placeholder_shape=placeholder_shape):
+          input_np = np.array(r.randn(*actual_shape))
+          inputs.append(input_np)
+          input_placeholders.append(
+              array_ops.placeholder_with_default(input_np, placeholder_shape))
 
       a = np.einsum(equation, *inputs)
       b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
@@ -288,19 +293,22 @@ class EinsumGradTest(test.TestCase):
     with self.cached_session():
       r = np.random.RandomState(seed=0)
       for dtype in (np.float32, np.float64, np.complex64, np.complex128):
-        tol = 10 * np.sqrt(np.finfo(dtype).resolution)
-        if dtype in (np.complex64, np.complex128):
-          inputs = [
-              np.array(r.randn(*shape), dtype) +
-              1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
-          ]
-        else:
-          inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
-        input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
-        analytical, numerical = gradient_checker_v2.compute_gradient(
-            lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
-        self.assertLess(
-            gradient_checker_v2.max_error(analytical, numerical), tol)
+        with self.subTest(s=s, dtype=dtype):
+          tol = 10 * np.sqrt(np.finfo(dtype).resolution)
+          if dtype in (np.complex64, np.complex128):
+            inputs = [
+                np.array(r.randn(*shape), dtype) +
+                1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
+            ]
+          else:
+            inputs = [
+                np.array(r.randn(*shape), dtype) for shape in input_shapes]
+          input_tensors = [
+              constant_op.constant(x, shape=x.shape) for x in inputs]
+          analytical, numerical = gradient_checker_v2.compute_gradient(
+              lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
+          self.assertLess(
+              gradient_checker_v2.max_error(analytical, numerical), tol)
 
   @test_util.disable_xla('b/131919749')
   def testUnary(self):
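The last hunk parameterizes each subtest on both the equation s and the dtype. The same two-parameter pattern in a self-contained form (a hypothetical test using NumPy only, not the TF code base; the tolerance formula is borrowed from the hunk above):

import unittest

import numpy as np


class TwoParamDemo(unittest.TestCase):

  def test_einsum_matches_matmul(self):
    r = np.random.RandomState(0)
    for dtype in (np.float32, np.float64, np.complex64, np.complex128):
      for shape in ((2, 2), (3, 4)):
        with self.subTest(dtype=dtype, shape=shape):
          # x is (m, n) and y is (n, m), so the contraction is valid.
          x = r.randn(*shape).astype(dtype)
          y = r.randn(shape[1], shape[0]).astype(dtype)
          np.testing.assert_allclose(
              np.einsum('ij,jk->ik', x, y), x @ y,
              rtol=10 * np.sqrt(np.finfo(dtype).resolution))


if __name__ == '__main__':
  unittest.main()

Any failing (dtype, shape) combination is reported individually, which is the same readability win the subTest(s=s, dtype=dtype) call in the gradient check buys.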