From 889c063f0e9e8aaeee73f88664ccb631be6c13fd Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 20 Jun 2019 12:39:41 -0700 Subject: [PATCH] Fix variable name (forgot to change padding to padding_value in diag_op_test). PiperOrigin-RevId: 254251013 --- .../python/kernel_tests/diag_op_test.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py index 27bbc98c37f..d8d1cea6c9f 100644 --- a/tensorflow/python/kernel_tests/diag_op_test.py +++ b/tensorflow/python/kernel_tests/diag_op_test.py @@ -348,14 +348,14 @@ class MatrixDiagTest(test.TestCase): self.assertEqual(mat_batch.shape, v_batch_diag.get_shape()) self.assertAllEqual(v_batch_diag.eval(), mat_batch) - # Diagonal bands with padding. - for padding in [0, 555, -11]: + # Diagonal bands with padding_value. + for padding_value in [0, 555, -11]: for _, tests in [self._moreCases(), square_cases()]: for diags, (vecs, solution) in tests.items(): v_diags = array_ops.matrix_diag( - vecs.astype(dtype), k=diags, padding=padding) + vecs.astype(dtype), k=diags, padding_value=padding_value) mask = solution == 0 - solution = (solution + padding * mask).astype(dtype) + solution = (solution + padding_value * mask).astype(dtype) self.assertEqual(v_diags.get_shape(), solution.shape) self.assertAllEqual(v_diags.eval(), solution) @@ -407,7 +407,7 @@ class MatrixDiagTest(test.TestCase): } test_list.append((expected, fat_cases())) - for padding in [0, 555, -11]: + for padding_value in [0, 555, -11]: # Giving both num_rows and num_cols for _, tests in [tall_cases(), fat_cases()]: for diags, (vecs, solution) in tests.items(): @@ -416,9 +416,9 @@ class MatrixDiagTest(test.TestCase): k=diags, num_rows=solution.shape[-2], num_cols=solution.shape[-1], - padding=padding) + padding_value=padding_value) mask = solution == 0 - solution = solution + padding * mask + solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) self.assertAllEqual(v_diags.eval(), solution) @@ -428,9 +428,12 @@ class MatrixDiagTest(test.TestCase): vecs, solution = tests[diags] solution = solution.take(indices=range(new_num_cols), axis=-1) v_diags = array_ops.matrix_diag( - vecs, k=diags, num_rows=solution.shape[-2], padding=padding) + vecs, + k=diags, + num_rows=solution.shape[-2], + padding_value=padding_value) mask = solution == 0 - solution = solution + padding * mask + solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) self.assertAllEqual(v_diags.eval(), solution) @@ -440,9 +443,12 @@ class MatrixDiagTest(test.TestCase): vecs, solution = tests[diags] solution = solution.take(indices=range(new_num_rows), axis=-2) v_diags = array_ops.matrix_diag( - vecs, k=diags, num_cols=solution.shape[-1], padding=padding) + vecs, + k=diags, + num_cols=solution.shape[-1], + padding_value=padding_value) mask = solution == 0 - solution = solution + padding * mask + solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) self.assertAllEqual(v_diags.eval(), solution) @@ -731,15 +737,15 @@ class MatrixDiagPartTest(test.TestCase): self.assertAllEqual(mat_batch_diag.eval(), v_batch) if compat.forward_compatible(2019, 7, 4): - # Diagonal bands with padding. + # Diagonal bands with padding_value. mat, tests = square_cases() - for padding in [0, 555, -11]: + for padding_value in [0, 555, -11]: for diags, pair in tests.items(): solution, _ = pair mat_batch_diag = array_ops.matrix_diag_part( - mat.astype(dtype), k=diags, padding=padding) + mat.astype(dtype), k=diags, padding_value=padding_value) mask = solution == 0 - solution = (solution + padding * mask).astype(dtype) + solution = (solution + padding_value * mask).astype(dtype) self.assertEqual(mat_batch_diag.get_shape(), solution.shape) self.assertAllEqual(mat_batch_diag.eval(), solution) @@ -763,15 +769,15 @@ class MatrixDiagPartTest(test.TestCase): self.assertAllEqual(mat_batch_diag.eval(), v_batch) if compat.forward_compatible(2019, 7, 4): - # Diagonal bands with padding. - for padding in [0, 555, -11]: + # Diagonal bands with padding_value. + for padding_value in [0, 555, -11]: for mat, tests in [tall_cases(), fat_cases()]: for diags, pair in tests.items(): solution, _ = pair mat_batch_diag = array_ops.matrix_diag_part( - mat, k=diags, padding=padding) + mat, k=diags, padding_value=padding_value) mask = solution == 0 - solution = solution + padding * mask + solution = solution + padding_value * mask self.assertEqual(mat_batch_diag.get_shape(), solution.shape) self.assertAllEqual(mat_batch_diag.eval(), solution)