Fix variable name (forgot to change padding to padding_value in diag_op_test).

PiperOrigin-RevId: 254251013
This commit is contained in:
Penporn Koanantakool 2019-06-20 12:39:41 -07:00 committed by TensorFlower Gardener
parent b57953d5d1
commit 889c063f0e

View File

@ -348,14 +348,14 @@ class MatrixDiagTest(test.TestCase):
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
# Diagonal bands with padding.
for padding in [0, 555, -11]:
# Diagonal bands with padding_value.
for padding_value in [0, 555, -11]:
for _, tests in [self._moreCases(), square_cases()]:
for diags, (vecs, solution) in tests.items():
v_diags = array_ops.matrix_diag(
vecs.astype(dtype), k=diags, padding=padding)
vecs.astype(dtype), k=diags, padding_value=padding_value)
mask = solution == 0
solution = (solution + padding * mask).astype(dtype)
solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@ -407,7 +407,7 @@ class MatrixDiagTest(test.TestCase):
}
test_list.append((expected, fat_cases()))
for padding in [0, 555, -11]:
for padding_value in [0, 555, -11]:
# Giving both num_rows and num_cols
for _, tests in [tall_cases(), fat_cases()]:
for diags, (vecs, solution) in tests.items():
@ -416,9 +416,9 @@ class MatrixDiagTest(test.TestCase):
k=diags,
num_rows=solution.shape[-2],
num_cols=solution.shape[-1],
padding=padding)
padding_value=padding_value)
mask = solution == 0
solution = solution + padding * mask
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@ -428,9 +428,12 @@ class MatrixDiagTest(test.TestCase):
vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_cols), axis=-1)
v_diags = array_ops.matrix_diag(
vecs, k=diags, num_rows=solution.shape[-2], padding=padding)
vecs,
k=diags,
num_rows=solution.shape[-2],
padding_value=padding_value)
mask = solution == 0
solution = solution + padding * mask
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@ -440,9 +443,12 @@ class MatrixDiagTest(test.TestCase):
vecs, solution = tests[diags]
solution = solution.take(indices=range(new_num_rows), axis=-2)
v_diags = array_ops.matrix_diag(
vecs, k=diags, num_cols=solution.shape[-1], padding=padding)
vecs,
k=diags,
num_cols=solution.shape[-1],
padding_value=padding_value)
mask = solution == 0
solution = solution + padding * mask
solution = solution + padding_value * mask
self.assertEqual(v_diags.get_shape(), solution.shape)
self.assertAllEqual(v_diags.eval(), solution)
@ -731,15 +737,15 @@ class MatrixDiagPartTest(test.TestCase):
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
if compat.forward_compatible(2019, 7, 4):
# Diagonal bands with padding.
# Diagonal bands with padding_value.
mat, tests = square_cases()
for padding in [0, 555, -11]:
for padding_value in [0, 555, -11]:
for diags, pair in tests.items():
solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part(
mat.astype(dtype), k=diags, padding=padding)
mat.astype(dtype), k=diags, padding_value=padding_value)
mask = solution == 0
solution = (solution + padding * mask).astype(dtype)
solution = (solution + padding_value * mask).astype(dtype)
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution)
@ -763,15 +769,15 @@ class MatrixDiagPartTest(test.TestCase):
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
if compat.forward_compatible(2019, 7, 4):
# Diagonal bands with padding.
for padding in [0, 555, -11]:
# Diagonal bands with padding_value.
for padding_value in [0, 555, -11]:
for mat, tests in [tall_cases(), fat_cases()]:
for diags, pair in tests.items():
solution, _ = pair
mat_batch_diag = array_ops.matrix_diag_part(
mat, k=diags, padding=padding)
mat, k=diags, padding_value=padding_value)
mask = solution == 0
solution = solution + padding * mask
solution = solution + padding_value * mask
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
self.assertAllEqual(mat_batch_diag.eval(), solution)