Use subTest on einsum_test to make errors easier to understand.
PiperOrigin-RevId: 311606884
Change-Id: I7f8738ffc26479f98d468431706c7d4f7c6efcfc
commit 6f57007fb8
parent 501309eef9
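For context: `self.subTest` is the standard-library `unittest` context manager that this change threads through the einsum tests. A minimal, TensorFlow-free sketch of the pattern (the test class and parameter values below are made up for illustration):

import unittest


class SubTestExample(unittest.TestCase):

  def test_squares(self):
    # Without subTest, the first failing iteration aborts the test and the
    # report gives no hint which parameters were in play. With subTest,
    # every iteration runs, and each failure is labeled with the keyword
    # arguments passed to subTest (e.g. "(n=3)").
    for n, expected in [(2, 4), (3, 9), (4, 16)]:
      with self.subTest(n=n, expected=expected):
        self.assertEqual(n * n, expected)


if __name__ == '__main__':
  unittest.main()

The diff below applies exactly this wrapping inside each parameterized loop of the einsum test file.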
@@ -42,10 +42,11 @@ class EinsumOpTest(test.TestCase):
     r = np.random.RandomState(0)
     inputs = []
     for shape in input_shapes:
-      arr = np.array(r.randn(*shape)).astype(dtype)
-      if dtype == np.complex64 or dtype == np.complex128:
-        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
-      inputs.append(arr)
+      with self.subTest(s=s, shape=shape):
+        arr = np.array(r.randn(*shape)).astype(dtype)
+        if dtype == np.complex64 or dtype == np.complex128:
+          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+        inputs.append(arr)
     input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
     a = np.einsum(s, *inputs)
     b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
@@ -160,10 +161,11 @@ class EinsumOpTest(test.TestCase):
       input_shapes = [(2, 2), (2, 2)]
       inputs = []
       for shape in input_shapes:
-        arr = np.array(r.randn(*shape)).astype(dtype)
-        if dtype == np.complex64 or dtype == np.complex128:
-          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
-        inputs.append(arr)
+        with self.subTest(dtype=dtype, shape=shape):
+          arr = np.array(r.randn(*shape)).astype(dtype)
+          if dtype == np.complex64 or dtype == np.complex128:
+            arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+          inputs.append(arr)
       input_tensors = [constant_op.constant(x) for x in inputs]
       if dtype == bfloat16:
         # np.einsum doesn't support bfloat16.
@@ -199,14 +201,15 @@ class EinsumOpTest(test.TestCase):
         ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
     ]
     for args in cases:
-      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
+      with self.subTest(args=args):
+        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+          _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))

-      placeholders = [
-          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
-      ]
-      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
-        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
+        placeholders = [
+            array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
+        ]
+        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+          _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))

   @test_util.run_in_graph_and_eager_modes
   def testPlaceholder(self):
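One detail worth noting in the hunk above: `subTest` sits outside `assertRaises`, so when an expected exception is never raised, the resulting failure is attributed to the specific `args` case rather than reported anonymously, and the remaining cases still run. A standalone sketch of that nesting (the sample inputs are invented for illustration):

import unittest


class InvalidInputExample(unittest.TestCase):

  def test_rejects_bad_values(self):
    # Each missing-exception failure is reported together with the
    # offending parameter, and the loop continues to the next case.
    for bad in ('', 'not-a-number', None):
      with self.subTest(bad=bad):
        with self.assertRaises((TypeError, ValueError)):
          int(bad)


if __name__ == '__main__':
  unittest.main()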
@@ -216,10 +219,12 @@ class EinsumOpTest(test.TestCase):
       inputs = []
       input_placeholders = []
       for actual_shape, placeholder_shape in input_and_placeholder_shapes:
-        input_np = np.array(r.randn(*actual_shape))
-        inputs.append(input_np)
-        input_placeholders.append(
-            array_ops.placeholder_with_default(input_np, placeholder_shape))
+        with self.subTest(equation=equation, actual_shape=actual_shape,
+                          placeholder_shape=placeholder_shape):
+          input_np = np.array(r.randn(*actual_shape))
+          inputs.append(input_np)
+          input_placeholders.append(
+              array_ops.placeholder_with_default(input_np, placeholder_shape))

       a = np.einsum(equation, *inputs)
       b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
@@ -288,19 +293,22 @@ class EinsumGradTest(test.TestCase):
     with self.cached_session():
       r = np.random.RandomState(seed=0)
       for dtype in (np.float32, np.float64, np.complex64, np.complex128):
-        tol = 10 * np.sqrt(np.finfo(dtype).resolution)
-        if dtype in (np.complex64, np.complex128):
-          inputs = [
-              np.array(r.randn(*shape), dtype) +
-              1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
-          ]
-        else:
-          inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
-        input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
-        analytical, numerical = gradient_checker_v2.compute_gradient(
-            lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
-        self.assertLess(
-            gradient_checker_v2.max_error(analytical, numerical), tol)
+        with self.subTest(s=s, dtype=dtype):
+          tol = 10 * np.sqrt(np.finfo(dtype).resolution)
+          if dtype in (np.complex64, np.complex128):
+            inputs = [
+                np.array(r.randn(*shape), dtype) +
+                1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
+            ]
+          else:
+            inputs = [
+                np.array(r.randn(*shape), dtype) for shape in input_shapes]
+          input_tensors = [
+              constant_op.constant(x, shape=x.shape) for x in inputs]
+          analytical, numerical = gradient_checker_v2.compute_gradient(
+              lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
+          self.assertLess(
+              gradient_checker_v2.max_error(analytical, numerical), tol)

   @test_util.disable_xla('b/131919749')
   def testUnary(self):