Fix variable name (forgot to change padding to padding_value in diag_op_test).

PiperOrigin-RevId: 254251013
This commit is contained in:
Penporn Koanantakool 2019-06-20 12:39:41 -07:00 committed by TensorFlower Gardener
parent b57953d5d1
commit 889c063f0e

View File

@ -348,14 +348,14 @@ class MatrixDiagTest(test.TestCase):
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape()) self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
self.assertAllEqual(v_batch_diag.eval(), mat_batch) self.assertAllEqual(v_batch_diag.eval(), mat_batch)
# Diagonal bands with padding. # Diagonal bands with padding_value.
for padding in [0, 555, -11]: for padding_value in [0, 555, -11]:
for _, tests in [self._moreCases(), square_cases()]: for _, tests in [self._moreCases(), square_cases()]:
for diags, (vecs, solution) in tests.items(): for diags, (vecs, solution) in tests.items():
v_diags = array_ops.matrix_diag( v_diags = array_ops.matrix_diag(
vecs.astype(dtype), k=diags, padding=padding) vecs.astype(dtype), k=diags, padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = (solution + padding * mask).astype(dtype) solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(v_diags.get_shape(), solution.shape) self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution) self.assertAllEqual(v_diags.eval(), solution)
@ -407,7 +407,7 @@ class MatrixDiagTest(test.TestCase):
} }
test_list.append((expected, fat_cases())) test_list.append((expected, fat_cases()))
for padding in [0, 555, -11]: for padding_value in [0, 555, -11]:
# Giving both num_rows and num_cols # Giving both num_rows and num_cols
for _, tests in [tall_cases(), fat_cases()]: for _, tests in [tall_cases(), fat_cases()]:
for diags, (vecs, solution) in tests.items(): for diags, (vecs, solution) in tests.items():
@ -416,9 +416,9 @@ class MatrixDiagTest(test.TestCase):
k=diags, k=diags,
num_rows=solution.shape[-2], num_rows=solution.shape[-2],
num_cols=solution.shape[-1], num_cols=solution.shape[-1],
padding=padding) padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = solution + padding * mask solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape) self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution) self.assertAllEqual(v_diags.eval(), solution)
@ -428,9 +428,12 @@ class MatrixDiagTest(test.TestCase):
vecs, solution = tests[diags] vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_cols), axis=-1) solution = solution.take(indices=range(new_num_cols), axis=-1)
v_diags = array_ops.matrix_diag( v_diags = array_ops.matrix_diag(
vecs, k=diags, num_rows=solution.shape[-2], padding=padding) vecs,
k=diags,
num_rows=solution.shape[-2],
padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = solution + padding * mask solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape) self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution) self.assertAllEqual(v_diags.eval(), solution)
@ -440,9 +443,12 @@ class MatrixDiagTest(test.TestCase):
vecs, solution = tests[diags] vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_rows), axis=-2) solution = solution.take(indices=range(new_num_rows), axis=-2)
v_diags = array_ops.matrix_diag( v_diags = array_ops.matrix_diag(
vecs, k=diags, num_cols=solution.shape[-1], padding=padding) vecs,
k=diags,
num_cols=solution.shape[-1],
padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = solution + padding * mask solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape) self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution) self.assertAllEqual(v_diags.eval(), solution)
@ -731,15 +737,15 @@ class MatrixDiagPartTest(test.TestCase):
self.assertAllEqual(mat_batch_diag.eval(), v_batch) self.assertAllEqual(mat_batch_diag.eval(), v_batch)
if compat.forward_compatible(2019, 7, 4): if compat.forward_compatible(2019, 7, 4):
# Diagonal bands with padding. # Diagonal bands with padding_value.
mat, tests = square_cases() mat, tests = square_cases()
for padding in [0, 555, -11]: for padding_value in [0, 555, -11]:
for diags, pair in tests.items(): for diags, pair in tests.items():
solution, _ = pair solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part( mat_batch_diag = array_ops.matrix_diag_part(
mat.astype(dtype), k=diags, padding=padding) mat.astype(dtype), k=diags, padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = (solution + padding * mask).astype(dtype) solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(mat_batch_diag.get_shape(), solution.shape) self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution) self.assertAllEqual(mat_batch_diag.eval(), solution)
@ -763,15 +769,15 @@ class MatrixDiagPartTest(test.TestCase):
self.assertAllEqual(mat_batch_diag.eval(), v_batch) self.assertAllEqual(mat_batch_diag.eval(), v_batch)
if compat.forward_compatible(2019, 7, 4): if compat.forward_compatible(2019, 7, 4):
# Diagonal bands with padding. # Diagonal bands with padding_value.
for padding in [0, 555, -11]: for padding_value in [0, 555, -11]:
for mat, tests in [tall_cases(), fat_cases()]: for mat, tests in [tall_cases(), fat_cases()]:
for diags, pair in tests.items(): for diags, pair in tests.items():
solution, _ = pair solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part( mat_batch_diag = array_ops.matrix_diag_part(
mat, k=diags, padding=padding) mat, k=diags, padding_value=padding_value)
mask = solution == 0 mask = solution == 0
solution = solution + padding * mask solution = solution + padding_value * mask
self.assertEqual(mat_batch_diag.get_shape(), solution.shape) self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution) self.assertAllEqual(mat_batch_diag.eval(), solution)