Use subTest on einsum_test to make errors easier to understand.

PiperOrigin-RevId: 311606884
Change-Id: I7f8738ffc26479f98d468431706c7d4f7c6efcfc
This commit is contained in:
Andrew Selle 2020-05-14 14:32:59 -07:00 committed by TensorFlower Gardener
parent 501309eef9
commit 6f57007fb8
1 changed files with 40 additions and 32 deletions

View File

@ -42,10 +42,11 @@ class EinsumOpTest(test.TestCase):
r = np.random.RandomState(0) r = np.random.RandomState(0)
inputs = [] inputs = []
for shape in input_shapes: for shape in input_shapes:
arr = np.array(r.randn(*shape)).astype(dtype) with self.subTest(s=s, shape=shape):
if dtype == np.complex64 or dtype == np.complex128: arr = np.array(r.randn(*shape)).astype(dtype)
arr += 1j * np.array(r.randn(*shape)).astype(dtype) if dtype == np.complex64 or dtype == np.complex128:
inputs.append(arr) arr += 1j * np.array(r.randn(*shape)).astype(dtype)
inputs.append(arr)
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs] input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
a = np.einsum(s, *inputs) a = np.einsum(s, *inputs)
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s)) b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
@ -160,10 +161,11 @@ class EinsumOpTest(test.TestCase):
input_shapes = [(2, 2), (2, 2)] input_shapes = [(2, 2), (2, 2)]
inputs = [] inputs = []
for shape in input_shapes: for shape in input_shapes:
arr = np.array(r.randn(*shape)).astype(dtype) with self.subTest(dtype=dtype, shape=shape):
if dtype == np.complex64 or dtype == np.complex128: arr = np.array(r.randn(*shape)).astype(dtype)
arr += 1j * np.array(r.randn(*shape)).astype(dtype) if dtype == np.complex64 or dtype == np.complex128:
inputs.append(arr) arr += 1j * np.array(r.randn(*shape)).astype(dtype)
inputs.append(arr)
input_tensors = [constant_op.constant(x) for x in inputs] input_tensors = [constant_op.constant(x) for x in inputs]
if dtype == bfloat16: if dtype == bfloat16:
# np.einsum doesn't support bfloat16. # np.einsum doesn't support bfloat16.
@ -199,14 +201,15 @@ class EinsumOpTest(test.TestCase):
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)), ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
] ]
for args in cases: for args in cases:
with self.assertRaises((ValueError, errors.InvalidArgumentError)): with self.subTest(args=args):
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0])) with self.assertRaises((ValueError, errors.InvalidArgumentError)):
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
placeholders = [ placeholders = [
array_ops.placeholder_with_default(x, shape=None) for x in args[1:] array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
] ]
with self.assertRaises((ValueError, errors.InvalidArgumentError)): with self.assertRaises((ValueError, errors.InvalidArgumentError)):
_ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0])) _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testPlaceholder(self): def testPlaceholder(self):
@ -216,10 +219,12 @@ class EinsumOpTest(test.TestCase):
inputs = [] inputs = []
input_placeholders = [] input_placeholders = []
for actual_shape, placeholder_shape in input_and_placeholder_shapes: for actual_shape, placeholder_shape in input_and_placeholder_shapes:
input_np = np.array(r.randn(*actual_shape)) with self.subTest(equation=equation, actual_shape=actual_shape,
inputs.append(input_np) placeholder_shape=placeholder_shape):
input_placeholders.append( input_np = np.array(r.randn(*actual_shape))
array_ops.placeholder_with_default(input_np, placeholder_shape)) inputs.append(input_np)
input_placeholders.append(
array_ops.placeholder_with_default(input_np, placeholder_shape))
a = np.einsum(equation, *inputs) a = np.einsum(equation, *inputs)
b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation)) b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
@ -288,19 +293,22 @@ class EinsumGradTest(test.TestCase):
with self.cached_session(): with self.cached_session():
r = np.random.RandomState(seed=0) r = np.random.RandomState(seed=0)
for dtype in (np.float32, np.float64, np.complex64, np.complex128): for dtype in (np.float32, np.float64, np.complex64, np.complex128):
tol = 10 * np.sqrt(np.finfo(dtype).resolution) with self.subTest(s=s, dtype=dtype):
if dtype in (np.complex64, np.complex128): tol = 10 * np.sqrt(np.finfo(dtype).resolution)
inputs = [ if dtype in (np.complex64, np.complex128):
np.array(r.randn(*shape), dtype) + inputs = [
1j * np.array(r.randn(*shape), dtype) for shape in input_shapes np.array(r.randn(*shape), dtype) +
] 1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
else: ]
inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes] else:
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs] inputs = [
analytical, numerical = gradient_checker_v2.compute_gradient( np.array(r.randn(*shape), dtype) for shape in input_shapes]
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors) input_tensors = [
self.assertLess( constant_op.constant(x, shape=x.shape) for x in inputs]
gradient_checker_v2.max_error(analytical, numerical), tol) analytical, numerical = gradient_checker_v2.compute_gradient(
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
self.assertLess(
gradient_checker_v2.max_error(analytical, numerical), tol)
@test_util.disable_xla('b/131919749') @test_util.disable_xla('b/131919749')
def testUnary(self): def testUnary(self):