Fix variable name (forgot to change padding to padding_value in diag_op_test).
PiperOrigin-RevId: 254251013
This commit is contained in:
parent
b57953d5d1
commit
889c063f0e
@ -348,14 +348,14 @@ class MatrixDiagTest(test.TestCase):
|
||||
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
|
||||
# Diagonal bands with padding.
|
||||
for padding in [0, 555, -11]:
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value in [0, 555, -11]:
|
||||
for _, tests in [self._moreCases(), square_cases()]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs.astype(dtype), k=diags, padding=padding)
|
||||
vecs.astype(dtype), k=diags, padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding * mask).astype(dtype)
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@ -407,7 +407,7 @@ class MatrixDiagTest(test.TestCase):
|
||||
}
|
||||
test_list.append((expected, fat_cases()))
|
||||
|
||||
for padding in [0, 555, -11]:
|
||||
for padding_value in [0, 555, -11]:
|
||||
# Giving both num_rows and num_cols
|
||||
for _, tests in [tall_cases(), fat_cases()]:
|
||||
for diags, (vecs, solution) in tests.items():
|
||||
@ -416,9 +416,9 @@ class MatrixDiagTest(test.TestCase):
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
num_cols=solution.shape[-1],
|
||||
padding=padding)
|
||||
padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = solution + padding * mask
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@ -428,9 +428,12 @@ class MatrixDiagTest(test.TestCase):
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_cols), axis=-1)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs, k=diags, num_rows=solution.shape[-2], padding=padding)
|
||||
vecs,
|
||||
k=diags,
|
||||
num_rows=solution.shape[-2],
|
||||
padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = solution + padding * mask
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@ -440,9 +443,12 @@ class MatrixDiagTest(test.TestCase):
|
||||
vecs, solution = tests[diags]
|
||||
solution = solution.take(indices=range(new_num_rows), axis=-2)
|
||||
v_diags = array_ops.matrix_diag(
|
||||
vecs, k=diags, num_cols=solution.shape[-1], padding=padding)
|
||||
vecs,
|
||||
k=diags,
|
||||
num_cols=solution.shape[-1],
|
||||
padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = solution + padding * mask
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||
self.assertAllEqual(v_diags.eval(), solution)
|
||||
|
||||
@ -731,15 +737,15 @@ class MatrixDiagPartTest(test.TestCase):
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(2019, 7, 4):
|
||||
# Diagonal bands with padding.
|
||||
# Diagonal bands with padding_value.
|
||||
mat, tests = square_cases()
|
||||
for padding in [0, 555, -11]:
|
||||
for padding_value in [0, 555, -11]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat.astype(dtype), k=diags, padding=padding)
|
||||
mat.astype(dtype), k=diags, padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = (solution + padding * mask).astype(dtype)
|
||||
solution = (solution + padding_value * mask).astype(dtype)
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
|
||||
@ -763,15 +769,15 @@ class MatrixDiagPartTest(test.TestCase):
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
if compat.forward_compatible(2019, 7, 4):
|
||||
# Diagonal bands with padding.
|
||||
for padding in [0, 555, -11]:
|
||||
# Diagonal bands with padding_value.
|
||||
for padding_value in [0, 555, -11]:
|
||||
for mat, tests in [tall_cases(), fat_cases()]:
|
||||
for diags, pair in tests.items():
|
||||
solution, _ = pair
|
||||
mat_batch_diag = array_ops.matrix_diag_part(
|
||||
mat, k=diags, padding=padding)
|
||||
mat, k=diags, padding_value=padding_value)
|
||||
mask = solution == 0
|
||||
solution = solution + padding * mask
|
||||
solution = solution + padding_value * mask
|
||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user