Fix variable name (forgot to change padding to padding_value in diag_op_test).
PiperOrigin-RevId: 254251013
This commit is contained in:
parent
b57953d5d1
commit
889c063f0e
@ -348,14 +348,14 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
|
self.assertEqual(mat_batch.shape, v_batch_diag.get_shape())
|
||||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||||
|
|
||||||
# Diagonal bands with padding.
|
# Diagonal bands with padding_value.
|
||||||
for padding in [0, 555, -11]:
|
for padding_value in [0, 555, -11]:
|
||||||
for _, tests in [self._moreCases(), square_cases()]:
|
for _, tests in [self._moreCases(), square_cases()]:
|
||||||
for diags, (vecs, solution) in tests.items():
|
for diags, (vecs, solution) in tests.items():
|
||||||
v_diags = array_ops.matrix_diag(
|
v_diags = array_ops.matrix_diag(
|
||||||
vecs.astype(dtype), k=diags, padding=padding)
|
vecs.astype(dtype), k=diags, padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = (solution + padding * mask).astype(dtype)
|
solution = (solution + padding_value * mask).astype(dtype)
|
||||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(v_diags.eval(), solution)
|
self.assertAllEqual(v_diags.eval(), solution)
|
||||||
|
|
||||||
@ -407,7 +407,7 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
}
|
}
|
||||||
test_list.append((expected, fat_cases()))
|
test_list.append((expected, fat_cases()))
|
||||||
|
|
||||||
for padding in [0, 555, -11]:
|
for padding_value in [0, 555, -11]:
|
||||||
# Giving both num_rows and num_cols
|
# Giving both num_rows and num_cols
|
||||||
for _, tests in [tall_cases(), fat_cases()]:
|
for _, tests in [tall_cases(), fat_cases()]:
|
||||||
for diags, (vecs, solution) in tests.items():
|
for diags, (vecs, solution) in tests.items():
|
||||||
@ -416,9 +416,9 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
k=diags,
|
k=diags,
|
||||||
num_rows=solution.shape[-2],
|
num_rows=solution.shape[-2],
|
||||||
num_cols=solution.shape[-1],
|
num_cols=solution.shape[-1],
|
||||||
padding=padding)
|
padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = solution + padding * mask
|
solution = solution + padding_value * mask
|
||||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(v_diags.eval(), solution)
|
self.assertAllEqual(v_diags.eval(), solution)
|
||||||
|
|
||||||
@ -428,9 +428,12 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
vecs, solution = tests[diags]
|
vecs, solution = tests[diags]
|
||||||
solution = solution.take(indices=range(new_num_cols), axis=-1)
|
solution = solution.take(indices=range(new_num_cols), axis=-1)
|
||||||
v_diags = array_ops.matrix_diag(
|
v_diags = array_ops.matrix_diag(
|
||||||
vecs, k=diags, num_rows=solution.shape[-2], padding=padding)
|
vecs,
|
||||||
|
k=diags,
|
||||||
|
num_rows=solution.shape[-2],
|
||||||
|
padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = solution + padding * mask
|
solution = solution + padding_value * mask
|
||||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(v_diags.eval(), solution)
|
self.assertAllEqual(v_diags.eval(), solution)
|
||||||
|
|
||||||
@ -440,9 +443,12 @@ class MatrixDiagTest(test.TestCase):
|
|||||||
vecs, solution = tests[diags]
|
vecs, solution = tests[diags]
|
||||||
solution = solution.take(indices=range(new_num_rows), axis=-2)
|
solution = solution.take(indices=range(new_num_rows), axis=-2)
|
||||||
v_diags = array_ops.matrix_diag(
|
v_diags = array_ops.matrix_diag(
|
||||||
vecs, k=diags, num_cols=solution.shape[-1], padding=padding)
|
vecs,
|
||||||
|
k=diags,
|
||||||
|
num_cols=solution.shape[-1],
|
||||||
|
padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = solution + padding * mask
|
solution = solution + padding_value * mask
|
||||||
self.assertEqual(v_diags.get_shape(), solution.shape)
|
self.assertEqual(v_diags.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(v_diags.eval(), solution)
|
self.assertAllEqual(v_diags.eval(), solution)
|
||||||
|
|
||||||
@ -731,15 +737,15 @@ class MatrixDiagPartTest(test.TestCase):
|
|||||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||||
|
|
||||||
if compat.forward_compatible(2019, 7, 4):
|
if compat.forward_compatible(2019, 7, 4):
|
||||||
# Diagonal bands with padding.
|
# Diagonal bands with padding_value.
|
||||||
mat, tests = square_cases()
|
mat, tests = square_cases()
|
||||||
for padding in [0, 555, -11]:
|
for padding_value in [0, 555, -11]:
|
||||||
for diags, pair in tests.items():
|
for diags, pair in tests.items():
|
||||||
solution, _ = pair
|
solution, _ = pair
|
||||||
mat_batch_diag = array_ops.matrix_diag_part(
|
mat_batch_diag = array_ops.matrix_diag_part(
|
||||||
mat.astype(dtype), k=diags, padding=padding)
|
mat.astype(dtype), k=diags, padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = (solution + padding * mask).astype(dtype)
|
solution = (solution + padding_value * mask).astype(dtype)
|
||||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||||
|
|
||||||
@ -763,15 +769,15 @@ class MatrixDiagPartTest(test.TestCase):
|
|||||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||||
|
|
||||||
if compat.forward_compatible(2019, 7, 4):
|
if compat.forward_compatible(2019, 7, 4):
|
||||||
# Diagonal bands with padding.
|
# Diagonal bands with padding_value.
|
||||||
for padding in [0, 555, -11]:
|
for padding_value in [0, 555, -11]:
|
||||||
for mat, tests in [tall_cases(), fat_cases()]:
|
for mat, tests in [tall_cases(), fat_cases()]:
|
||||||
for diags, pair in tests.items():
|
for diags, pair in tests.items():
|
||||||
solution, _ = pair
|
solution, _ = pair
|
||||||
mat_batch_diag = array_ops.matrix_diag_part(
|
mat_batch_diag = array_ops.matrix_diag_part(
|
||||||
mat, k=diags, padding=padding)
|
mat, k=diags, padding_value=padding_value)
|
||||||
mask = solution == 0
|
mask = solution == 0
|
||||||
solution = solution + padding * mask
|
solution = solution + padding_value * mask
|
||||||
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
|
||||||
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
self.assertAllEqual(mat_batch_diag.eval(), solution)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user