Use subTest on einsum_test to make errors easier to understand.
PiperOrigin-RevId: 311606884 Change-Id: I7f8738ffc26479f98d468431706c7d4f7c6efcfc
This commit is contained in:
parent
501309eef9
commit
6f57007fb8
|
@ -42,10 +42,11 @@ class EinsumOpTest(test.TestCase):
|
|||
r = np.random.RandomState(0)
|
||||
inputs = []
|
||||
for shape in input_shapes:
|
||||
arr = np.array(r.randn(*shape)).astype(dtype)
|
||||
if dtype == np.complex64 or dtype == np.complex128:
|
||||
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
|
||||
inputs.append(arr)
|
||||
with self.subTest(s=s, shape=shape):
|
||||
arr = np.array(r.randn(*shape)).astype(dtype)
|
||||
if dtype == np.complex64 or dtype == np.complex128:
|
||||
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
|
||||
inputs.append(arr)
|
||||
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
|
||||
a = np.einsum(s, *inputs)
|
||||
b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
|
||||
|
@ -160,10 +161,11 @@ class EinsumOpTest(test.TestCase):
|
|||
input_shapes = [(2, 2), (2, 2)]
|
||||
inputs = []
|
||||
for shape in input_shapes:
|
||||
arr = np.array(r.randn(*shape)).astype(dtype)
|
||||
if dtype == np.complex64 or dtype == np.complex128:
|
||||
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
|
||||
inputs.append(arr)
|
||||
with self.subTest(dtype=dtype, shape=shape):
|
||||
arr = np.array(r.randn(*shape)).astype(dtype)
|
||||
if dtype == np.complex64 or dtype == np.complex128:
|
||||
arr += 1j * np.array(r.randn(*shape)).astype(dtype)
|
||||
inputs.append(arr)
|
||||
input_tensors = [constant_op.constant(x) for x in inputs]
|
||||
if dtype == bfloat16:
|
||||
# np.einsum doesn't support bfloat16.
|
||||
|
@ -199,14 +201,15 @@ class EinsumOpTest(test.TestCase):
|
|||
('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
|
||||
]
|
||||
for args in cases:
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
|
||||
with self.subTest(args=args):
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
|
||||
|
||||
placeholders = [
|
||||
array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
|
||||
]
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
|
||||
placeholders = [
|
||||
array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
|
||||
]
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
_ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testPlaceholder(self):
|
||||
|
@ -216,10 +219,12 @@ class EinsumOpTest(test.TestCase):
|
|||
inputs = []
|
||||
input_placeholders = []
|
||||
for actual_shape, placeholder_shape in input_and_placeholder_shapes:
|
||||
input_np = np.array(r.randn(*actual_shape))
|
||||
inputs.append(input_np)
|
||||
input_placeholders.append(
|
||||
array_ops.placeholder_with_default(input_np, placeholder_shape))
|
||||
with self.subTest(equation=equation, actual_shape=actual_shape,
|
||||
placeholder_shape=placeholder_shape):
|
||||
input_np = np.array(r.randn(*actual_shape))
|
||||
inputs.append(input_np)
|
||||
input_placeholders.append(
|
||||
array_ops.placeholder_with_default(input_np, placeholder_shape))
|
||||
|
||||
a = np.einsum(equation, *inputs)
|
||||
b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
|
||||
|
@ -288,19 +293,22 @@ class EinsumGradTest(test.TestCase):
|
|||
with self.cached_session():
|
||||
r = np.random.RandomState(seed=0)
|
||||
for dtype in (np.float32, np.float64, np.complex64, np.complex128):
|
||||
tol = 10 * np.sqrt(np.finfo(dtype).resolution)
|
||||
if dtype in (np.complex64, np.complex128):
|
||||
inputs = [
|
||||
np.array(r.randn(*shape), dtype) +
|
||||
1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
|
||||
]
|
||||
else:
|
||||
inputs = [np.array(r.randn(*shape), dtype) for shape in input_shapes]
|
||||
input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
|
||||
analytical, numerical = gradient_checker_v2.compute_gradient(
|
||||
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
|
||||
self.assertLess(
|
||||
gradient_checker_v2.max_error(analytical, numerical), tol)
|
||||
with self.subTest(s=s, dtype=dtype):
|
||||
tol = 10 * np.sqrt(np.finfo(dtype).resolution)
|
||||
if dtype in (np.complex64, np.complex128):
|
||||
inputs = [
|
||||
np.array(r.randn(*shape), dtype) +
|
||||
1j * np.array(r.randn(*shape), dtype) for shape in input_shapes
|
||||
]
|
||||
else:
|
||||
inputs = [
|
||||
np.array(r.randn(*shape), dtype) for shape in input_shapes]
|
||||
input_tensors = [
|
||||
constant_op.constant(x, shape=x.shape) for x in inputs]
|
||||
analytical, numerical = gradient_checker_v2.compute_gradient(
|
||||
lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
|
||||
self.assertLess(
|
||||
gradient_checker_v2.max_error(analytical, numerical), tol)
|
||||
|
||||
@test_util.disable_xla('b/131919749')
|
||||
def testUnary(self):
|
||||
|
|
Loading…
Reference in New Issue