diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py index 343f367c1d0..69ae03a06cf 100644 --- a/tensorflow/compiler/tests/matrix_diag_ops_test.py +++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py @@ -26,9 +26,82 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +# LINT.IfChange +matrix_diag_v3_forward_compat_date = (2019, 12, 6) +# LINT.ThenChange( +# //tensorflow/python/kernel_tests/diag_op_test.py, +# //tensorflow/python/ops/array_ops.py, +# //tensorflow/python/ops/parallel_for/array_test.py +# ) + +default_v2_alignment = "LEFT_LEFT" +alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"] + + +def zip_to_first_list_length(a, b): + if len(b) > len(a): + return zip(a, b[:len(a)]) + return zip(a, b + [None] * (len(a) - len(b))) + + +# Routines to convert test cases to have diagonals in a specified alignment. +# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py +def repack_diagonals(packed_diagonals, + diag_index, + num_rows, + num_cols, + align=None): + # The original test cases are LEFT_LEFT aligned. + if align == default_v2_alignment or align is None: + return packed_diagonals + + align = align.split("_") + d_lower, d_upper = diag_index + batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1) + max_diag_len = packed_diagonals.shape[-1] + index = (slice(None),) * batch_dims + repacked_diagonals = np.zeros_like(packed_diagonals) + + # Aligns each diagonal row-by-row. + for diag_index in range(d_lower, d_upper + 1): + diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index)) + row_index = d_upper - diag_index + padding_len = max_diag_len - diag_len + left_align = (diag_index >= 0 and + align[0] == "LEFT") or (diag_index <= 0 and + align[1] == "LEFT") + # Prepares index tuples. 
+ extra_dim = tuple() if d_lower == d_upper else (row_index,) + packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),) + repacked_last_dim = (slice(None),) if left_align else (slice( + padding_len, max_diag_len, 1),) + packed_index = index + extra_dim + packed_last_dim + repacked_index = index + extra_dim + repacked_last_dim + + # Repacks the diagonal. + repacked_diagonals[repacked_index] = packed_diagonals[packed_index] + return repacked_diagonals + + +def repack_diagonals_in_tests(tests, align=None): + # The original test cases are LEFT_LEFT aligned. + if align == default_v2_alignment or align is None: + return tests + + new_tests = dict() + # Loops through each case. + for diag_index, (packed_diagonals, padded_diagonals) in tests.items(): + num_rows, num_cols = padded_diagonals.shape[-2:] + repacked_diagonals = repack_diagonals( + packed_diagonals, diag_index, num_rows, num_cols, align=align) + new_tests[diag_index] = (repacked_diagonals, padded_diagonals) + + return new_tests + + # Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2. 
# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py -def square_cases(): +def square_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 1], @@ -103,10 +176,10 @@ def square_cases(): [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) -def tall_cases(): +def tall_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3], [4, 5, 6], @@ -191,10 +264,10 @@ def tall_cases(): [0, 0, 0], [0, 0, 0]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) -def fat_cases(): +def fat_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], @@ -259,7 +332,11 @@ def fat_cases(): [0, 9, 1, 2], [0, 0, 5, 6]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) + + +def all_tests(align=None): + return [square_cases(align), tall_cases(align), fat_cases(align)] class MatrixDiagTest(xla_test.XLATestCase): @@ -327,39 +404,31 @@ class MatrixDiagTest(xla_test.XLATestCase): # From here onwards are v2-only tests. 
def testSquare(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for _, tests in [square_cases()]: - for diag_index, (vecs, solution) in tests.items(): - self._assertOpOutputMatchesExpected( - { - "diagonal": vecs[0], - "k": diag_index - }, solution[0]) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + for _, tests in [square_cases(align)]: + for diag_index, (vecs, solution) in tests.items(): + params = {"diagonal": vecs[0], "k": diag_index, "align": align} + self._assertOpOutputMatchesExpected(params, solution[0]) def testSquareBatch(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for _, tests in [square_cases()]: - for diag_index, (vecs, solution) in tests.items(): - self._assertOpOutputMatchesExpected( - { - "diagonal": vecs, - "k": diag_index - }, solution) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + for _, tests in [square_cases(align)]: + for diag_index, (vecs, solution) in tests.items(): + params = {"diagonal": vecs, "k": diag_index, "align": align} + self._assertOpOutputMatchesExpected(params, solution) def testRectangularBatch(self): - # LINT.IfChange - if not compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date): return # Stores expected num_rows and num_cols (when the other is given). # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols) test_list = list() + # Do not align the test cases here. Re-alignment needs to happen after the + # solution shape is updated. 
# Square cases: expected = { (-1, -1): (5, 4), @@ -389,43 +458,65 @@ class MatrixDiagTest(xla_test.XLATestCase): test_list.append((expected, fat_cases())) # Giving both num_rows and num_cols - for _, tests in [tall_cases(), fat_cases()]: + align = alignment_list[0] + for _, tests in [tall_cases(align), fat_cases(align)]: for diag_index, (vecs, solution) in tests.items(): self._assertOpOutputMatchesExpected( { "diagonal": vecs, "k": diag_index, "num_rows": solution.shape[-2], - "num_cols": solution.shape[-1] + "num_cols": solution.shape[-1], + "align": align }, solution) + # We go through each alignment in a round-robin manner. + align_index = 0 + # Giving just num_rows or num_cols. for expected, (_, tests) in test_list: for diag_index, (new_num_rows, new_num_cols) in expected.items(): + align = alignment_list[align_index] + align_index = (align_index + 1) % len(alignment_list) vecs, solution = tests[diag_index] solution_given_num_rows = solution.take( indices=range(new_num_cols), axis=-1) + # Repacks the diagonal input according to the new solution shape. + vecs_given_num_rows = repack_diagonals( + vecs, + diag_index, + solution_given_num_rows.shape[-2], + new_num_cols, + align=align) self._assertOpOutputMatchesExpected( { - "diagonal": vecs, + "diagonal": vecs_given_num_rows, "k": diag_index, - "num_rows": solution_given_num_rows.shape[-2] + "num_rows": solution_given_num_rows.shape[-2], + "align": align }, solution_given_num_rows) solution_given_num_cols = solution.take( indices=range(new_num_rows), axis=-2) + # Repacks the diagonal input according to the new solution shape. 
+ vecs_given_num_cols = repack_diagonals( + vecs, + diag_index, + new_num_rows, + solution_given_num_cols.shape[-1], + align=align) self._assertOpOutputMatchesExpected( { - "diagonal": vecs, + "diagonal": vecs_given_num_cols, "k": diag_index, - "num_cols": solution_given_num_cols.shape[-1] + "num_cols": solution_given_num_cols.shape[-1], + "align": align }, solution_given_num_cols) def testPadding(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for padding_value in [555, -11]: - for _, tests in [square_cases(), tall_cases(), fat_cases()]: + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for padding_value, align in zip_to_first_list_length([555, -11], + alignment_list): + for _, tests in all_tests(align): for diag_index, (vecs, solution) in tests.items(): mask = (solution == 0) solution = solution + (mask * padding_value) @@ -435,7 +526,8 @@ class MatrixDiagTest(xla_test.XLATestCase): "k": diag_index, "num_rows": solution.shape[-2], "num_cols": solution.shape[-1], - "padding_value": padding_value + "padding_value": padding_value, + "align": align }, solution) @@ -542,36 +634,36 @@ class MatrixSetDiagTest(xla_test.XLATestCase): # From here onwards are v2-only tests. 
def testSingleMatrix(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for _, tests in [square_cases(), tall_cases(), fat_cases()]: - for diag_index, (vecs, banded_mat) in tests.items(): - mask = (banded_mat[0] == 0) - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat[0] - self._assertOpOutputMatchesExpected( - { - "input": input_mat, - "diagonal": vecs[0], - "k": diag_index - }, solution) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + for _, tests in all_tests(align): + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat[0] == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs[0], + "k": diag_index, + "align": align + }, solution) def testBatch(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for _, tests in [square_cases(), tall_cases(), fat_cases()]: - for diag_index, (vecs, banded_mat) in tests.items(): - mask = (banded_mat == 0) - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat - self._assertOpOutputMatchesExpected( - { - "input": input_mat, - "diagonal": vecs, - "k": diag_index - }, solution) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + for _, tests in all_tests(align): + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs, + "k": diag_index, + "align": align + }, solution) class MatrixDiagPartTest(xla_test.XLATestCase): @@ -613,33 +705,35 @@ 
class MatrixDiagPartTest(xla_test.XLATestCase): # From here onwards are v2-only tests. def testSingleMatrix(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for mat, tests in [square_cases(), tall_cases(), fat_cases()]: - for diag_index, (solution, _) in tests.items(): - self._assertOpOutputMatchesExpected({ - "input": mat[0], - "k": diag_index - }, solution[0]) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + test_list = [square_cases(align), tall_cases(align), fat_cases(align)] + for mat, tests in test_list: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "input": mat[0], + "k": diag_index, + "align": align + }, solution[0]) def testBatch(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for mat, tests in [square_cases(), tall_cases(), fat_cases()]: - for diag_index, (solution, _) in tests.items(): - self._assertOpOutputMatchesExpected({ - "input": mat, - "k": diag_index - }, solution) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for align in alignment_list: + for mat, tests in all_tests(align): + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "input": mat, + "k": diag_index, + "align": align + }, solution) def testPadding(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - for padding_value in [555, -11]: - for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + for padding_value, align in zip_to_first_list_length([555, -11], + alignment_list): + for mat, tests in all_tests(align): for diag_index, (solution, _) in tests.items(): mask = (solution == 0) solution = solution + (mask 
* padding_value) @@ -647,7 +741,8 @@ class MatrixDiagPartTest(xla_test.XLATestCase): { "input": mat, "k": diag_index, - "padding_value": padding_value + "padding_value": padding_value, + "align": align }, solution) diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index c6babf0d5f7..7cf9da0c057 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -26,6 +26,30 @@ limitations under the License. namespace tensorflow { namespace { +// Calculates the diagonal length of a diagonal. +static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) { + return std::min(num_rows + std::min(0, diag_index), + num_cols - std::max(0, diag_index)); +} + +// Checks if a diagonal is to be aligned left or right. +static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal, + bool left_align_subdiagonal) { + return (diag_index >= 0 && left_align_superdiagonal) || + (diag_index <= 0 && left_align_subdiagonal); +} + +// Reads the diagonal packing alignment. +void ReadAlignment(OpKernelConstruction* context, + bool* left_align_superdiagonal, + bool* left_align_subdiagonal) { + string align; + OP_REQUIRES_OK(context, context->GetAttr("align", &align)); + + *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT"; + *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT"; +} + // Reads or infers lower_diag_index and upper_diag_index from kernel's input // parameter "k". Also validates their values. 
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) { @@ -94,7 +118,9 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, const TensorShape& input_shape, const int64 diag_rank, const int64 num_diags, const int64 lower_diag_index, const int64 upper_diag_index, const int64 max_diag_len, - const int64 num_rows, const int64 num_cols) { + const int64 num_rows, const int64 num_cols, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { // Creates a padding config. const int input_rank = input_shape.dims(); xla::PaddingConfig padding_config; @@ -108,29 +134,26 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow. // // For example, - // diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4. + // diag = [[0, 2, 3], k = (0, 2), num_cols = 4, and align="RIGHT_LEFT". // [4, 5, 6], // [7, 8, 9]] - // The expected output is [[4, 2, 0], - // [7, 5, 4], - // [0, 8, 6], - // [0, 0, 9]] + // The expected output is [[7, 4, 2, 0], + // [0, 8, 5, 3], + // [0, 0, 9, 6]]. // The 1st diagonal is created by: - // 1) Extracting diag_slice = [1, 2, 0]. - // 2) Padding the vector to be as long as num_rows, - // diag_slice = [1, 2, 0, 0], - // then broadcasting diag_slice row-wise to a full matrix, - // diag_broadcast = [[1, 1, 1], - // [2, 2, 2], - // [0, 0, 0], - // [0, 0, 0]] + // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned. + // 2) Padding the vector (in the same direction) to be as long as num_cols, + // diag_slice = [0, 0, 2, 3], + // then broadcasting diag_slice column-wise to a full matrix, + // diag_broadcast = [[0, 0, 2, 3], + // [0, 0, 2, 3], + // [0, 0, 2, 3]]. // The padding value can be anything because it will not appear in the // results after masking. Here, we use zero. // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal. 
- // mask = [[0, 1, 0], --> output = [[x, 2, x], - // [0, 0, 1], [x, x, 3], - // [0, 0, 0], [x, x, x], - // [0, 0, 0]] [x, x, x]], + // mask = [[0, 0, 1, 0], --> output = [[x, x, 2, x], + // [0, 0, 0, 1], [x, x, x, 3], + // [0, 0, 0, 0]] [x, x, x, x]], // where x denotes the existing input contents. std::vector<int64> broadcast_dimensions(input_rank - 1); absl::c_iota(broadcast_dimensions, 0); @@ -140,6 +163,8 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, // Extracts a single diagonal. auto diag_slice = diag; if (num_diags > 1) { + // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len]. + // We call Collapse to make the dims: [<batch_dim>, max_diag_len]. const int64 mapped_diag_index = upper_diag_index - diag_index; diag_slice = xla::Collapse( xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1, @@ -147,20 +172,51 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, {diag_rank - 2, diag_rank - 1}); } - // Pads if necessary. Always pad at the end because shorter diagonals in - // the input come padded at the end. - const int64 padding_length = - ((diag_index <= 0) ? num_cols : num_rows) - max_diag_len; - const xla::XlaOp zero = xla::ScalarLike(input, 0); - if (padding_length > 0) { + // Pad if necessary. + // - If the diagonal has the longest length, i.e., min(num_rows, num_cols), + // no padding is necessary. It is broadcast column-wise if it is a sub- + // diagonal, row-wise if superdiagonal. + // - Otherwise, pad and keep the old alignment (shorter diagonals in the + // input come pre-padded). max_len in the table refers to max_diag_len. 
+ // ------------------------------------------------------------------- + // | Diag | Align | Broadcast | padding_low | padding_high | + // ------------------------------------------------------------------- + // | Super | Left | Row-wise | 0 | #rows - max_len | + // | | Right | Column-wise | #cols - max_len | 0 | + // ------------------------------------------------------------------- + // | Sub | Left | Column-wise | 0 | #cols - max_len | + // | | Right | Row-wise | #rows - max_len | 0 | + // ------------------------------------------------------------------- + if (num_cols - num_rows <= diag_index && diag_index <= 0) { + broadcast_dimensions.back() = input_rank - 1; // Column-wise. + } else if (0 <= diag_index && diag_index <= num_cols - num_rows) { + broadcast_dimensions.back() = input_rank - 2; // Row-wise. + } else { + int length_to_pad_to; + if ((diag_index > 0 && left_align_superdiagonal) || + (diag_index < 0 && !left_align_subdiagonal)) { + length_to_pad_to = num_rows; + broadcast_dimensions.back() = input_rank - 2; // Row-wise. + } else { + length_to_pad_to = num_cols; + broadcast_dimensions.back() = input_rank - 1; // Column-wise. + } + int padding_low = length_to_pad_to - max_diag_len; + int padding_high = 0; + if (IsLeftAligned(diag_index, left_align_superdiagonal, + left_align_subdiagonal)) { + std::swap(padding_low, padding_high); + } padding_config.mutable_dimensions(input_rank - 2) - ->set_edge_padding_high(padding_length); + ->set_edge_padding_low(padding_low); + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_high); + + const xla::XlaOp zero = xla::ScalarLike(input, 0); diag_slice = xla::Pad(diag_slice, zero, padding_config); } - // Broadcasts column-wise for subdiagonals; row-wise for superdiagonals. - broadcast_dimensions.back() = - (diag_index <= 0) ? input_rank - 1 : input_rank - 2; + // Broadcast and mask. 
xla::XlaOp diag_broadcast = xla::BroadcastInDim( diag_slice, input_shape.dim_sizes(), broadcast_dimensions); const auto mask = xla::GetDiagonalMask(output, diag_index); @@ -173,11 +229,17 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, class MatrixDiagOp : public XlaOpKernel { public: - explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) { + // MatrixDiagV3-specific. + if (context->HasAttr("align")) { + ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compile(XlaOpKernelContext* context) override { OP_REQUIRES( - context, context->num_inputs() >= 1, + context, context->num_inputs() >= kNumV1Inputs, errors::InvalidArgument("MatrixDiag op must have at least one input")); const TensorShape diag_shape = context->InputShape(0); OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), @@ -199,7 +261,7 @@ class MatrixDiagOp : public XlaOpKernel { // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has // one input, so we have to check the number of inputs before reading // additional parameters for MatrixDiagV2. 
- if (context->num_inputs() > 1) { + if (context->num_inputs() > kNumV1Inputs) { std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows)); OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols)); @@ -252,8 +314,14 @@ class MatrixDiagOp : public XlaOpKernel { context->SetOutput( 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, lower_diag_index, upper_diag_index, max_diag_len, - num_rows, num_cols)); + num_rows, num_cols, left_align_superdiagonal_, + left_align_subdiagonal_)); } + + private: + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 1; }; REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); @@ -263,12 +331,24 @@ REGISTER_XLA_OP(Name("MatrixDiagV2") .CompileTimeConstantInput("num_cols") .CompileTimeConstantInput("padding_value"), MatrixDiagOp); +REGISTER_XLA_OP(Name("MatrixDiagV3") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("num_rows") + .CompileTimeConstantInput("num_cols") + .CompileTimeConstantInput("padding_value"), + MatrixDiagOp); class MatrixDiagPartOp : public XlaOpKernel { public: explicit MatrixDiagPartOp(OpKernelConstruction* context) : XlaOpKernel(context), - is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {} + is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) { + // MatrixDiagPartV3-specific. + if (context->HasAttr("align")) { + ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); @@ -290,7 +370,7 @@ class MatrixDiagPartOp : public XlaOpKernel { // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel. // MatrixDiagPart only has one input, so we have to check the number of // inputs before reading additional parameters in MatrixDiagV2. 
- if (context->num_inputs() > 1) { + if (context->num_inputs() > kNumV1Inputs) { std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); padding_value = context->Input(2); } @@ -314,25 +394,34 @@ class MatrixDiagPartOp : public XlaOpKernel { // Computes output. xla::XlaOp input = context->Input(0); std::vector<xla::XlaOp> diag_list; - xla::PaddingConfig padding_config; + xla::PaddingConfig padding_config = + xla::MakeNoPaddingConfig(input_rank - 1); if (num_diags == 1) { context->SetOutput( 0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index) : xla::GetMatrixDiagonal(input, upper_diag_index)); return; } - padding_config = xla::MakeNoPaddingConfig(input_rank - 1); for (int diag_index = upper_diag_index; diag_index >= lower_diag_index; --diag_index) { xla::XlaOp single_diag = is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index) : xla::GetMatrixDiagonal(input, diag_index); - const int64 diag_length = - (diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index); - const int64 padding_length = max_diag_len - diag_length; - if (padding_length > 0) { - padding_config.mutable_dimensions(input_rank - 2) - ->set_edge_padding_high(padding_length); + const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols); + const int64 padding_len = max_diag_len - diag_len; + if (padding_len > 0) { + if (IsLeftAligned(diag_index, left_align_superdiagonal_, + left_align_subdiagonal_)) { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_low(0); + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_len); + } else { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_low(padding_len); + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(0); + } single_diag = xla::Pad(single_diag, padding_value, padding_config); } diag_list.emplace_back(single_diag); @@ -344,6 +433,9 @@ class MatrixDiagPartOp : public XlaOpKernel { private: const 
bool is_gpu_; + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 1; }; REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); @@ -351,11 +443,21 @@ REGISTER_XLA_OP(Name("MatrixDiagPartV2") .CompileTimeConstantInput("k") .CompileTimeConstantInput("padding_value"), MatrixDiagPartOp); +REGISTER_XLA_OP(Name("MatrixDiagPartV3") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("padding_value"), + MatrixDiagPartOp); class MatrixSetDiagOp : public XlaOpKernel { public: explicit MatrixSetDiagOp(OpKernelConstruction* context) - : XlaOpKernel(context) {} + : XlaOpKernel(context) { + // MatrixSetDiagV3-specific. + if (context->HasAttr("align")) { + ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); @@ -378,7 +480,7 @@ class MatrixSetDiagOp : public XlaOpKernel { // reading additional parameters in MatrixSetDiagV2. 
int64 lower_diag_index = 0; int64 upper_diag_index = 0; - if (context->num_inputs() > 2) { + if (context->num_inputs() > kNumV1Inputs) { std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); } @@ -419,15 +521,21 @@ class MatrixSetDiagOp : public XlaOpKernel { context->SetOutput( 0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags, lower_diag_index, upper_diag_index, max_diag_len, - num_rows, num_cols)); + num_rows, num_cols, left_align_superdiagonal_, + left_align_subdiagonal_)); } private: + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 2; TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); }; REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"), MatrixSetDiagOp); +REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"), + MatrixSetDiagOp); } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt new file mode 100644 index 00000000000..a9fe8802e66 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt @@ -0,0 +1,141 @@ +op { + graph_op_name: "MatrixDiagPartV3" + in_arg { + name: "input" + description: "Rank `r` tensor where `r >= 2`." + } + in_arg { + name: "k" + description: <<END +Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main +diagonal, and negative value means subdiagonals. `k` can be a single integer +(for a single diagonal) or a pair of integers specifying the low and high ends +of a matrix band. `k[0]` must not be larger than `k[1]`. +END + } + in_arg { + name: "padding_value" + description: <<END +The value to fill the area outside the specified diagonal band with. +Default is 0. +END + } + out_arg { + name: "diagonal" + description: "The extracted diagonal(s)." 
+ } + attr { + name: "align" + description: <<END +Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is +a string specifying how superdiagonals and subdiagonals should be aligned, +respectively. There are four possible alignments: "RIGHT_LEFT" (default), +"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals +to the right (left-pads the row) and subdiagonals to the left (right-pads the +row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is +the opposite alignment. +END + } + summary: "Returns the batched diagonal part of a batched tensor." + description: <<END +Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched +`input`. + +Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`. +Let `max_diag_len` be the maximum length among all diagonals to be extracted, +`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` +Let `num_diags` be the number of diagonals to extract, +`num_diags = k[1] - k[0] + 1`. + +If `num_diags == 1`, the output tensor is of rank `r - 1` with shape +`[I, J, ..., L, max_diag_len]` and values: + +``` +diagonal[i, j, ..., l, n] + = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + padding_value ; otherwise. +``` +where `y = max(-k[1], 0)`, `x = max(k[1], 0)`. + +Otherwise, the output tensor has rank `r` with dimensions +`[I, J, ..., L, num_diags, max_diag_len]` with values: + +``` +diagonal[i, j, ..., l, m, n] + = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + padding_value ; otherwise. +``` +where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`. + +`offset` is zero except when the alignment of the diagonal is to the right. +``` +offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise +``` +where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. 
+ +The input must be at least a matrix. + +For example: + +``` +input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4) + [5, 6, 7, 8], + [9, 8, 7, 6]], + [[5, 4, 3, 2], + [1, 2, 3, 4], + [5, 6, 7, 8]]]) + +# A main diagonal from each batch. +tf.matrix_diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3) + [5, 2, 7]] + +# A superdiagonal from each batch. +tf.matrix_diag_part(input, k = 1) + ==> [[2, 7, 6], # Output shape: (2, 3) + [4, 3, 8]] + +# A band from each batch. +tf.matrix_diag_part(input, k = (-1, 2)) + ==> [[[0, 3, 8], # Output shape: (2, 4, 3) + [2, 7, 6], + [1, 6, 7], + [5, 8, 0]], + [[0, 3, 4], + [4, 3, 8], + [5, 2, 7], + [1, 6, 0]]] + +# LEFT_RIGHT alignment. +tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT") + ==> [[[3, 8, 0], # Output shape: (2, 4, 3) + [2, 7, 6], + [1, 6, 7], + [0, 5, 8]], + [[3, 4, 0], + [4, 3, 8], + [5, 2, 7], + [0, 1, 6]]] + +# max_diag_len can be shorter than the main diagonal. +tf.matrix_diag_part(input, k = (-2, -1)) + ==> [[[5, 8], + [9, 0]], + [[1, 6], + [5, 0]]] + +# padding_value = 9 +tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) + ==> [[[9, 9, 4], # Output shape: (2, 3, 3) + [9, 3, 8], + [2, 7, 6]], + [[9, 9, 2], + [9, 3, 4], + [4, 3, 8]]] + +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt new file mode 100644 index 00000000000..d96e48fdbea --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt @@ -0,0 +1,177 @@ +op { + graph_op_name: "MatrixDiagV3" + in_arg { + name: "diagonal" + description: "Rank `r`, where `r >= 1`" + } + in_arg { + name: "k" + description: <<END +Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main +diagonal, and negative value means subdiagonals. `k` can be a single integer +(for a single diagonal) or a pair of integers specifying the low and high ends +of a matrix band. `k[0]` must not be larger than `k[1]`. 
+END + } + in_arg { + name: "num_rows" + description: <<END +The number of rows of the output matrix. If it is not provided, the op assumes +the output matrix is a square matrix and infers the matrix size from k and the +innermost dimension of `diagonal`. +END + } + in_arg { + name: "num_cols" + description: <<END +The number of columns of the output matrix. If it is not provided, the op +assumes the output matrix is a square matrix and infers the matrix size from +k and the innermost dimension of `diagonal`. +END + } + in_arg { + name: "padding_value" + description: <<END +The number to fill the area outside the specified diagonal band with. +Default is 0. +END + } + out_arg { + name: "output" + description: <<END +Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. +END + } + attr { + name: "align" + description: <<END +Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is +a string specifying how superdiagonals and subdiagonals should be aligned, +respectively. There are four possible alignments: "RIGHT_LEFT" (default), +"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals +to the right (left-pads the row) and subdiagonals to the left (right-pads the +row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is +the opposite alignment. +END + } + summary: + "Returns a batched diagonal tensor with given batched diagonal values." + description: <<END +Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th +diagonals of a matrix, with everything else padded with `padding_value`. `num_rows` +and `num_cols` specify the dimension of the innermost matrix of the output. If +both are not specified, the op assumes the innermost matrix is square and infers +its size from `k` and the innermost dimension of `diagonal`. If only one of them +is specified, the op assumes the unspecified value is the smallest possible +based on other criteria. 
+ +Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor has +rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only one +diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has rank +`r` with shape `[I, J, ..., L, num_rows, num_cols]`. + +The second innermost dimension of `diagonal` has double meaning. +When `k` is scalar or `k[0] == k[1]`, `M` is part of the batch size +`[I, J, ..., M]`, and the output tensor is: + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + padding_value ; otherwise +``` + +Otherwise, `M` is treated as the number of diagonals for the matrix in the +same batch (`M = k[1]-k[0]+1`), and the output tensor is: + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + padding_value ; otherwise +``` +where `d = n - m`, `diag_index = k[1] - d`, and +`index_in_diag = n - max(d, 0) + offset`. + +`offset` is zero except when the alignment of the diagonal is to the right. +``` +offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise +``` +where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + +For example: + +``` +# The main diagonal. +diagonal = np.array([[1, 2, 3, 4], # Input shape: (2, 4) + [5, 6, 7, 8]]) +tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0], # Output shape: (2, 4, 4) + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]], + [[5, 0, 0, 0], + [0, 6, 0, 0], + [0, 0, 7, 0], + [0, 0, 0, 8]]] + +# A superdiagonal (per batch). +diagonal = np.array([[1, 2, 3], # Input shape: (2, 3) + [4, 5, 6]]) +tf.matrix_diag(diagonal, k = 1) + ==> [[[0, 1, 0, 0], # Output shape: (2, 4, 4) + [0, 0, 2, 0], + [0, 0, 0, 3], + [0, 0, 0, 0]], + [[0, 4, 0, 0], + [0, 0, 5, 0], + [0, 0, 0, 6], + [0, 0, 0, 0]]] + +# A tridiagonal band (per batch). 
+diagonals = np.array([[[0, 8, 9], # Input shape: (2, 3, 3) + [1, 2, 3], + [4, 5, 0]], + [[0, 2, 3], + [6, 7, 9], + [9, 1, 0]]]) +tf.matrix_diag(diagonals, k = (-1, 1)) + ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + [4, 2, 9], + [0, 5, 3]], + [[6, 2, 0], + [9, 7, 3], + [0, 1, 9]]] + +# LEFT_RIGHT alignment. +diagonals = np.array([[[8, 9, 0], # Input shape: (2, 3, 3) + [1, 2, 3], + [0, 4, 5]], + [[2, 3, 0], + [6, 7, 9], + [0, 9, 1]]]) +tf.matrix_diag(diagonals, k = (-1, 1), align="LEFT_RIGHT") + ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + [4, 2, 9], + [0, 5, 3]], + [[6, 2, 0], + [9, 7, 3], + [0, 1, 9]]] + +# Rectangular matrix. +diagonal = np.array([1, 2]) # Input shape: (2) +tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4) + ==> [[0, 0, 0, 0], # Output shape: (3, 4) + [1, 0, 0, 0], + [0, 2, 0, 0]] + +# Rectangular matrix with inferred num_cols and padding_value = 9. +tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) + ==> [[9, 9], # Output shape: (3, 2) + [1, 9], + [9, 2]] + +``` +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt new file mode 100644 index 00000000000..b3b5980ad44 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt @@ -0,0 +1,148 @@ +op { + graph_op_name: "MatrixSetDiagV3" + in_arg { + name: "input" + description: "Rank `r+1`, where `r >= 1`." + } + in_arg { + name: "diagonal" + description: <<END +Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. +`r >= 1`. +END + } + in_arg { + name: "k" + description: <<END +Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main +diagonal, and negative value means subdiagonals. `k` can be a single integer +(for a single diagonal) or a pair of integers specifying the low and high ends +of a matrix band. `k[0]` must not be larger than `k[1]`. 
+END + } + out_arg { + name: "output" + description: <<END +Rank `r+1`, with `output.shape = input.shape`. +END + } + attr { + name: "align" + description: <<END +Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is +a string specifying how superdiagonals and subdiagonals should be aligned, +respectively. There are four possible alignments: "RIGHT_LEFT" (default), +"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals +to the right (left-pads the row) and subdiagonals to the left (right-pads the +row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is +the opposite alignment. +END + } + summary: "Returns a batched matrix tensor with new batched diagonal values." + description: <<END +Given `input` and `diagonal`, this operation returns a tensor with the +same shape and values as `input`, except for the specified diagonals of the +innermost matrices. These will be overwritten by the values in `diagonal`. + +`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or +`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`. +Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`. +`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`. +`max_diag_len` is the length of the longest diagonal in the range `[k[0], k[1]]`, +`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`. + +The output is a tensor of rank `r+1` with dimensions `[I, J, ..., L, M, N]`. +If `k` is scalar or `k[0] == k[1]`: + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1] + input[i, j, ..., l, m, n] ; otherwise +``` + +Otherwise, + +``` +output[i, j, ..., l, m, n] + = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1] + input[i, j, ..., l, m, n] ; otherwise +``` +where `d = n - m`, `diag_index = k[1] - d`, and +`index_in_diag = n - max(d, 0) + offset`. 
+ +`offset` is zero except when the alignment of the diagonal is to the right. +``` +offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise +``` +where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + +For example: + +``` +# The main diagonal. +input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4) + [7, 7, 7, 7], + [7, 7, 7, 7]], + [[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]]]) +diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3) + [4, 5, 6]]) +tf.matrix_set_diag(input, diagonal) + ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + [7, 2, 7, 7], + [7, 7, 3, 7]], + [[4, 7, 7, 7], + [7, 5, 7, 7], + [7, 7, 6, 7]]] + +# A superdiagonal (per batch). +tf.matrix_set_diag(input, diagonal, k = 1) + ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4) + [7, 7, 2, 7], + [7, 7, 7, 3]], + [[7, 4, 7, 7], + [7, 7, 5, 7], + [7, 7, 7, 6]]] + +# A band of diagonals. +diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3) + [6, 5, 8], + [1, 2, 3], + [4, 5, 0]], + [[0, 1, 2], + [5, 6, 4], + [6, 1, 2], + [3, 4, 0]]]) +tf.matrix_set_diag(input, diagonals, k = (-1, 2)) + ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + [4, 2, 5, 1], + [7, 5, 3, 8]], + [[6, 5, 1, 7], + [3, 1, 6, 2], + [7, 4, 2, 4]]] + +# LEFT_RIGHT alignment. 
+diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3) + [6, 5, 8], + [1, 2, 3], + [0, 4, 5]], + [[1, 2, 0], + [5, 6, 4], + [6, 1, 2], + [0, 3, 4]]]) +tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") + ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + [4, 2, 5, 1], + [7, 5, 3, 8]], + [[6, 5, 1, 7], + [3, 1, 6, 2], + [7, 4, 2, 4]]] + +``` +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt index 753cfe5e235..180fa21ba6d 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt @@ -1,11 +1,4 @@ op { graph_op_name: "MatrixDiagPartV2" - endpoint { - name: "linalg.diag_part" - } - endpoint { - name: "matrix_diag_part" - deprecation_version: 2 - } visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt new file mode 100644 index 00000000000..386502ceea2 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt @@ -0,0 +1,11 @@ +op { + graph_op_name: "MatrixDiagPartV3" + endpoint { + name: "linalg.diag_part" + } + endpoint { + name: "matrix_diag_part" + deprecation_version: 2 + } + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt index 7f76ddba0c4..b6bbdf54a81 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt @@ -1,11 +1,4 @@ op { graph_op_name: "MatrixDiagV2" - endpoint { - name: "linalg.diag" - } - endpoint { - name: "matrix_diag" - deprecation_version: 2 - } visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt 
b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt new file mode 100644 index 00000000000..e1ed2061f06 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt @@ -0,0 +1,11 @@ +op { + graph_op_name: "MatrixDiagV3" + endpoint { + name: "linalg.diag" + } + endpoint { + name: "matrix_diag" + deprecation_version: 2 + } + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt index eed43da2bb9..ecf7353c9d8 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt @@ -1,11 +1,4 @@ op { graph_op_name: "MatrixSetDiagV2" - endpoint { - name: "linalg.set_diag" - } - endpoint { - name: "matrix_set_diag" - deprecation_version: 2 - } visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt new file mode 100644 index 00000000000..a92ff977c58 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt @@ -0,0 +1,11 @@ +op { + graph_op_name: "MatrixSetDiagV3" + endpoint { + name: "linalg.set_diag" + } + endpoint { + name: "matrix_set_diag" + deprecation_version: 2 + } + visibility: HIDDEN +} diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index ea8678de5cf..e78f3970b89 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1189,6 +1189,254 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, + int32* lower_diag_index, int32* upper_diag_index) { + // This function assumes that the shape of diag_index_tensor is fully defined. 
+ if (diag_index_tensor->dims() == 0) { + *lower_diag_index = diag_index_tensor->scalar<int32>()(); + *upper_diag_index = *lower_diag_index; + } else { + int32 num_elements = diag_index_tensor->dim_size(0); + if (num_elements == 1) { + *lower_diag_index = diag_index_tensor->vec<int32>()(0); + *upper_diag_index = *lower_diag_index; + } else if (num_elements == 2) { + *lower_diag_index = diag_index_tensor->vec<int32>()(0); + *upper_diag_index = diag_index_tensor->vec<int32>()(1); + } else { + return errors::InvalidArgument( + "diag_index must be a vector with one or two elements. It has ", + num_elements, " elements."); + } + } + return Status::OK(); +} + +Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape, diag_index_shape, unused_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); + + const Tensor* diag_index_tensor = c->input_tensor(1); + if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || + diag_index_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + int32 lower_diag_index = 0; + int32 upper_diag_index = 0; + TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index, + &upper_diag_index)); + if (lower_diag_index > upper_diag_index) { + return errors::InvalidArgument( + "lower_diag_index is greater than upper_diag_index"); + } + + // Validates lower_diag_index and upper_diag_index. + const int32 input_rank = c->Rank(input_shape); + const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2)); + const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1)); + if (num_rows != InferenceContext::kUnknownDim && + num_cols != InferenceContext::kUnknownDim) { + if (lower_diag_index != 0 && // For when num_rows or num_cols == 0. 
+ (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) { + return errors::InvalidArgument("lower_diag_index is out of bound."); + } + if (upper_diag_index != 0 && // For when num_rows or num_cols == 0. + (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) { + return errors::InvalidArgument("upper_diag_index is out of bound."); + } + } + + const int32 max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0), + num_cols - std::max(lower_diag_index, 0)); + std::vector<DimensionHandle> dims; + dims.reserve(input_rank - 2); + for (int i = 0; i < input_rank - 2; ++i) { + dims.push_back(c->Dim(input_shape, i)); + } + if (lower_diag_index < upper_diag_index) { + dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1)); + } + dims.push_back(c->MakeDim(max_diag_len)); + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); +} + +Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { + // Checks input ranks. + ShapeHandle input_shape, diag_index_shape, unused_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); + + // Reads the diagonal indices. 
+ const Tensor* diag_index_tensor = c->input_tensor(1); + if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || + diag_index_tensor == nullptr) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + } + int32 lower_diag_index = 0; + int32 upper_diag_index = 0; + TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index, + &upper_diag_index)); + if (lower_diag_index > upper_diag_index) { + return errors::InvalidArgument( + "lower_diag_index is greater than upper_diag_index"); + } + + // Checks if the number of diagonals provided matches what we imply from + // lower_diag_index and upper_diag_index. + const int32 input_rank = c->Rank(input_shape); + if (lower_diag_index < upper_diag_index) { + const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2)); + const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1)); + + if (num_diags != (upper_diag_index - lower_diag_index + 1)) { + return errors::InvalidArgument( + "The number of rows of `diagonal` doesn't match the number of " + "diagonals implied from `d_lower` and `d_upper`.\n", + "num_diags = ", num_diags, ", d_lower = ", lower_diag_index, + ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim); + } + } + + // Reads num_rows and num_cols. + const Tensor* num_rows_tensor = c->input_tensor(2); + const Tensor* num_cols_tensor = c->input_tensor(3); + int64 num_rows = -1; + int64 num_cols = -1; + if (num_rows_tensor != nullptr) { + TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows)); + } + if (num_cols_tensor != nullptr) { + TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols)); + } + + // Infers the missing num_rows or num_cols: If both are missing, assume + // output is square. Otherwise, use the smallest possible value. Also + // validates the provided values. 
+ const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1)); + const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0); + const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0); + if (num_rows == -1 && num_cols == -1) { // Special case. + num_rows = std::max(min_num_rows, min_num_cols); + num_cols = num_rows; + } + if (num_rows == -1) { + num_rows = min_num_rows; + } else if (num_rows < min_num_rows) { + return errors::InvalidArgument("num_rows is too small"); + } + if (num_cols == -1) { + num_cols = min_num_cols; + } else if (num_cols < min_num_cols) { + return errors::InvalidArgument("num_cols is too small."); + } + // At least one of them must match the minimum length. + if (num_rows != min_num_rows && num_cols != min_num_cols) { + return errors::InvalidArgument( + "num_rows and num_cols are not consistent with lower_diag_index, " + "upper_diag_index, and the length of the given diagonals.\n", + "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows, + ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols); + } + + // Sets output shape. 
+ ShapeHandle output_shape; + const DimensionHandle output_row_dim = c->MakeDim(num_rows); + const DimensionHandle output_col_dim = c->MakeDim(num_cols); + if (lower_diag_index == upper_diag_index) { + TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1, + output_row_dim, &output_shape)); + TF_RETURN_IF_ERROR( + c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape)); + } else { + TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2, + output_row_dim, &output_shape)); + TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1, + output_col_dim, &output_shape)); + } + c->set_output(0, output_shape); + return Status::OK(); +} + +Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { + ShapeHandle input_shape, diag_shape, diag_index_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape)); + + int32 lower_diag_index = 0; + int32 upper_diag_index = 0; + bool diag_index_known = false; + const Tensor* diag_index_tensor = c->input_tensor(2); + if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) { + diag_index_known = true; + TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index, + &upper_diag_index)); + if (lower_diag_index > upper_diag_index) { + return errors::InvalidArgument( + "lower_diag_index is greater than upper_diag_index"); + } + } + + // Do more checks when input rank is known. + if (c->RankKnown(input_shape)) { + int32 input_rank = c->Rank(input_shape); + + // If diag_index is set, we know the exact rank of diagonal. + if (diag_index_known) { + TF_RETURN_IF_ERROR(c->WithRank( + c->input(1), + (lower_diag_index == upper_diag_index) ? 
input_rank - 1 : input_rank, + &diag_shape)); + } else { + TF_RETURN_IF_ERROR( + c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape)); + TF_RETURN_IF_ERROR( + c->WithRankAtMost(c->input(1), input_rank, &diag_shape)); + } + + // Validates lower_diag_index and upper_diag_index. + const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2)); + const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1)); + if (num_rows != InferenceContext::kUnknownDim && + num_cols != InferenceContext::kUnknownDim) { + if (lower_diag_index != 0 && // For when num_rows or num_cols == 0. + (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) { + return errors::InvalidArgument("lower_diag_index is out of bound."); + } + if (upper_diag_index != 0 && // For when num_rows or num_cols == 0. + (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) { + return errors::InvalidArgument("upper_diag_index is out of bound."); + } + } + } + + ShapeHandle output_shape = input_shape; + if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) { + // Try to infer parts of shape from diag. + ShapeHandle diag_prefix; + TF_RETURN_IF_ERROR(c->Subshape( + diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2, + &diag_prefix)); + + // The inner matrices can be rectangular, so we can't pinpoint their + // exact height and width by just lower_diag_index, upper_diag_index, + // and the longest length of given diagonals. 
+ TF_RETURN_IF_ERROR( + c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape)); + TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape)); + } + c->set_output(0, output_shape); + return Status::OK(); +} + Status MaxPoolShape(shape_inference::InferenceContext* c) { string data_format_str; TensorFormat data_format; diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 590aa98b60b..434948bafa2 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -270,6 +270,15 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c); // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); +// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations. +Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c); + +// Shape function for MatrixDiagV2 and MatrixDiagV3 operations. +Status MatrixDiagV2Shape(shape_inference::InferenceContext* c); + +// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations. +Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c); + // Shape function for MaxPool-like operations. 
Status MaxPoolShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ac9c6299833..3945d5d7f55 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1158,7 +1158,7 @@ tf_kernel_library( tf_kernel_library( name = "matrix_set_diag_op", prefix = "matrix_set_diag_op", - deps = ARRAY_DEPS, + deps = ARRAY_DEPS + [":matrix_diag_op"], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc index ae69e7752f1..9796fd25e39 100644 --- a/tensorflow/core/kernels/matrix_diag_op.cc +++ b/tensorflow/core/kernels/matrix_diag_op.cc @@ -46,8 +46,13 @@ typedef Eigen::GpuDevice GPUDevice; template <typename Device, typename T> class MatrixDiagPartOp : public OpKernel { public: - explicit MatrixDiagPartOp(OpKernelConstruction* context) - : OpKernel(context) {} + explicit MatrixDiagPartOp(OpKernelConstruction* context) : OpKernel(context) { + // MatrixDiagPartV3-specific. + if (context->HasAttr("align")) { + functor::ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); @@ -60,7 +65,7 @@ class MatrixDiagPartOp : public OpKernel { T padding_value(0); // MatrixDiagPartV2-specific. 
- if (context->num_inputs() > 1) { + if (context->num_inputs() > kNumV1Inputs) { auto& diag_index = context->input(1); OP_REQUIRES(context, TensorShapeUtils::IsScalar(diag_index.shape()) || @@ -132,17 +137,26 @@ class MatrixDiagPartOp : public OpKernel { functor::MatrixDiagPart<Device, T>::Compute( context, context->eigen_device<Device>(), input_reshaped, output_reshaped, lower_diag_index, upper_diag_index, max_diag_len, - padding_value); + padding_value, left_align_superdiagonal_, left_align_subdiagonal_); } private: + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 1; TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp); }; template <typename Device, typename T> class MatrixDiagOp : public OpKernel { public: - explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {} + explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) { + // MatrixDiagV3-specific. + if (context->HasAttr("align")) { + functor::ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compute(OpKernelContext* context) override { const Tensor& diagonal = context->input(0); @@ -157,7 +171,7 @@ class MatrixDiagOp : public OpKernel { T padding_value(0); // MatrixDiagOpV2-specific. 
- if (context->num_inputs() > 1) { + if (context->num_inputs() > kNumV1Inputs) { auto& diag_index = context->input(1); OP_REQUIRES(context, TensorShapeUtils::IsScalar(diag_index.shape()) || @@ -242,10 +256,13 @@ class MatrixDiagOp : public OpKernel { functor::MatrixDiag<Device, T>::Compute( context, context->eigen_device<Device>(), diag_reshaped, output_reshaped, lower_diag_index, upper_diag_index, max_diag_len, - padding_value); + padding_value, left_align_superdiagonal_, left_align_subdiagonal_); } private: + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 1; TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp); }; @@ -256,12 +273,19 @@ class MatrixDiagOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("MatrixDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ MatrixDiagOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatrixDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + MatrixDiagOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ MatrixDiagPartOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("MatrixDiagPartV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + MatrixDiagPartOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatrixDiagPartV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ MatrixDiagPartOp<CPUDevice, type>); + TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG); #undef REGISTER_MATRIX_DIAG @@ -280,6 +304,28 @@ TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG); // Implementation of the functor specialization for CPU. 
namespace functor { + +void ReadAlignment(OpKernelConstruction* context, + bool* left_align_superdiagonal, + bool* left_align_subdiagonal) { + string align; + OP_REQUIRES_OK(context, context->GetAttr("align", &align)); + + *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT"; + *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT"; +} + +std::pair<int, int> ComputeDiagLenAndContentOffset( + int diag_index, int max_diag_len, int num_rows, int num_cols, + bool left_align_superdiagonal, bool left_align_subdiagonal) { + const bool left_align = (diag_index >= 0 && left_align_superdiagonal) || + (diag_index <= 0 && left_align_subdiagonal); + const int diag_len = std::min(num_rows + std::min(0, diag_index), + num_cols - std::max(0, diag_index)); + const int content_offset = (left_align) ? 0 : (max_diag_len - diag_len); + return {diag_len, content_offset}; +} + template <typename T> struct MatrixDiag<CPUDevice, T> { static void Compute(OpKernelContext* context, const CPUDevice& device, @@ -287,16 +333,21 @@ struct MatrixDiag<CPUDevice, T> { typename TTypes<T, 3>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len, const T padding_value) { + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { // 10 in cost_per_batch is from existing heuristic. // TODO(penporn): Tune for the best constant in cost_per_batch. 
const Eigen::Index num_batches = output.dimension(0); - const Eigen::Index cost_per_batch = - 10 * output.dimension(1) * output.dimension(2); + const Eigen::Index num_rows = output.dimension(1); + const Eigen::Index num_cols = output.dimension(2); + const Eigen::Index cost_per_batch = 10 * num_rows * num_cols; - auto compute_shard = [&output, &diag, &lower_diag_index, &upper_diag_index, - &max_diag_len, &padding_value](Eigen::Index begin, - Eigen::Index end) { + auto compute_shard = [&output, &num_rows, &num_cols, &diag, + &lower_diag_index, &upper_diag_index, &max_diag_len, + &padding_value, &left_align_superdiagonal, + &left_align_subdiagonal](Eigen::Index begin, + Eigen::Index end) { const int num_diags = upper_diag_index - lower_diag_index + 1; const int diag_elements_in_batch = num_diags * max_diag_len; Eigen::Index diag_batch_base_index = begin * diag_elements_in_batch; @@ -305,8 +356,12 @@ struct MatrixDiag<CPUDevice, T> { for (Eigen::Index j = 0; j < output.dimension(2); ++j) { const int diag_index = j - i; const int diag_index_in_input = upper_diag_index - diag_index; + int diag_len, content_offset; + std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset( + diag_index, max_diag_len, num_rows, num_cols, + left_align_superdiagonal, left_align_subdiagonal); const int index_in_the_diagonal = - j - std::max<Eigen::Index>(diag_index, 0); + j - std::max<Eigen::Index>(diag_index, 0) + content_offset; if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) { output(batch, i, j) = diag(diag_batch_base_index + @@ -334,7 +389,9 @@ struct MatrixDiagPart<CPUDevice, T> { typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len, const T padding_value) { + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { // 10 in cost_per_batch is from existing heuristic. 
// TODO(penporn): Tune for the best constant in cost_per_batch. const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1; @@ -346,24 +403,32 @@ struct MatrixDiagPart<CPUDevice, T> { auto compute_shard = [&output, &input, &num_rows, &num_cols, &upper_diag_index, &max_diag_len, &num_diags, - &output_elements_in_batch, &padding_value]( + &output_elements_in_batch, &padding_value, + &left_align_superdiagonal, &left_align_subdiagonal]( Eigen::Index begin, Eigen::Index end) { Eigen::Index output_base_index = begin * output_elements_in_batch; for (Eigen::Index batch = begin; batch < end; ++batch) { for (Eigen::Index m = 0; m < num_diags; ++m) { - const Eigen::Index d = upper_diag_index - m; - Eigen::Index n = 0; - // Make two separate cases to save some index calculations. - if (d >= 0) { - for (; n < std::min(num_rows, num_cols - d); ++n) { - output(output_base_index + n) = input(batch, n, n + d); - } - } else { - for (; n < std::min(num_rows + d, num_cols); ++n) { - output(output_base_index + n) = input(batch, n - d, n); - } + const Eigen::Index diag_index = upper_diag_index - m; + Eigen::Index y_offset = std::max<Eigen::Index>(0, -diag_index); + Eigen::Index x_offset = std::max<Eigen::Index>(0, diag_index); + int diag_len, content_offset; + std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset( + diag_index, max_diag_len, num_rows, num_cols, + left_align_superdiagonal, left_align_subdiagonal); + + // Fills the diagonal. + for (Eigen::Index n = 0; n < diag_len; ++n) { + output(output_base_index + content_offset + n) = + input(batch, n + y_offset, n + x_offset); } - for (; n < max_diag_len; ++n) { // Padding. + + // Padding. + const bool left_align = (content_offset == 0); + const Eigen::Index padding_start = (left_align) ? diag_len : 0; + const Eigen::Index padding_end = + (left_align) ? 
max_diag_len : content_offset; + for (Eigen::Index n = padding_start; n < padding_end; ++n) { output(output_base_index + n) = padding_value; } output_base_index += max_diag_len; @@ -391,7 +456,8 @@ namespace functor { typename TTypes<T, 3>::Tensor& output, \ const Eigen::Index lower_diag_index, \ const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \ - const T padding_value); \ + const T padding_value, const bool left_align_superdiagonal, \ + const bool left_align_subdiagonal); \ extern template struct MatrixDiag<GPUDevice, T>; \ template <> \ void MatrixDiagPart<GPUDevice, T>::Compute( \ @@ -399,7 +465,8 @@ namespace functor { typename TTypes<T, 3>::ConstTensor& input, \ typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, \ const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \ - const T padding_value); \ + const T padding_value, const bool left_align_superdiagonal, \ + const bool left_align_subdiagonal); \ extern template struct MatrixDiagPart<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); @@ -422,6 +489,14 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); .HostMemory("num_cols") \ .HostMemory("padding_value"), \ MatrixDiagOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("MatrixDiagV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("k") \ + .HostMemory("num_rows") \ + .HostMemory("num_cols") \ + .HostMemory("padding_value"), \ + MatrixDiagOp<GPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ MatrixDiagPartOp<GPUDevice, type>); \ @@ -430,6 +505,12 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); .TypeConstraint<type>("T") \ .HostMemory("k") \ .HostMemory("padding_value"), \ + MatrixDiagPartOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("k") \ + .HostMemory("padding_value"), \ MatrixDiagPartOp<GPUDevice, type>); 
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU); diff --git a/tensorflow/core/kernels/matrix_diag_op.h b/tensorflow/core/kernels/matrix_diag_op.h index 619fb5855eb..707fd9b6c14 100644 --- a/tensorflow/core/kernels/matrix_diag_op.h +++ b/tensorflow/core/kernels/matrix_diag_op.h @@ -19,6 +19,7 @@ limitations under the License. // Generator definition for MatrixDiagOp, must be compilable by nvcc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" @@ -26,26 +27,43 @@ limitations under the License. namespace tensorflow { namespace functor { +// Reads the diagonal packing alignment. +void ReadAlignment(OpKernelConstruction* context, + bool* left_align_superdiagonal, + bool* left_align_subdiagonal); + +// Calculates diagonal length and content offset (from aligning) of a diagonal. +// Returns a pair of integers {diag_len, content_offset}: +// - diag_len: The length of the diag_index-th diagonal. +// - content_offset: Each diagonal is stored as a row in the compact format. +// If the diagonal is shorter than max_diag_len, its content is aligned +// either to the left or right. content_offset is the index in the row +// where the first element of the diag-index-th diagonal is stored. It is +// always zero when the diagonal is left-aligned. 
+std::pair<int, int> ComputeDiagLenAndContentOffset( + int diag_index, int max_diag_len, int num_rows, int num_cols, + bool left_align_superdiagonal, bool left_align_subdiagonal); + template <typename Device, typename T> struct MatrixDiagPart { EIGEN_ALWAYS_INLINE static void Compute( OpKernelContext* context, const Device& device, typename TTypes<T, 3>::ConstTensor& input, - typename TTypes<T>::Tensor& output_original, const Eigen::Index d_lower, - const Eigen::Index d_upper, const Eigen::Index max_diag_len, - const T padding); + typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, + const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, + const T padding_value, const bool left_align_superdiagonal, + const bool left_align_subdiagonal); }; template <typename Device, typename T> struct MatrixDiag { - EIGEN_ALWAYS_INLINE static void Compute(OpKernelContext* context, - const Device& device, - typename TTypes<T>::ConstTensor& diag, - typename TTypes<T, 3>::Tensor& output, - const Eigen::Index d_lower, - const Eigen::Index d_upper, - const Eigen::Index max_diag_len, - const T padding); + EIGEN_ALWAYS_INLINE static void Compute( + OpKernelContext* context, const Device& device, + typename TTypes<T>::ConstTensor& diag, + typename TTypes<T, 3>::Tensor& output, + const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, const bool left_align_subdiagonal); }; } // namespace functor diff --git a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc index 9f6d4a0ea87..53cd2d2dc46 100644 --- a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc @@ -26,13 +26,28 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; +__device__ inline int ComputeContentOffset(const int diag_index, + const int max_diag_len, + const int num_rows, + 
const int num_cols, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { + const bool left_align = (diag_index >= 0 && left_align_superdiagonal) || + (diag_index <= 0 && left_align_subdiagonal); + if (left_align) return 0; + const int y_offset = min(0, diag_index); + const int x_offset = max(0, diag_index); + const int diag_len = min(num_rows + y_offset, num_cols - x_offset); + return max_diag_len - diag_len; +} + template <typename T> -__global__ void MatrixDiagKernel(const int num_threads, const int num_rows, - const int num_cols, const int num_diags, - const int max_diag_len, - const int lower_diag_index, - const int upper_diag_index, const T padding, - const T* diag_ptr, T* output_ptr) { +__global__ void MatrixDiagKernel( + const int num_threads, const int num_rows, const int num_cols, + const int num_diags, const int max_diag_len, const int lower_diag_index, + const int upper_diag_index, const T padding_value, + const bool left_align_superdiagonal, const bool left_align_subdiagonal, + const T* diag_ptr, T* output_ptr) { GPU_1D_KERNEL_LOOP(index, num_threads) { const int batch_and_row_index = index / num_cols; const int col = index - batch_and_row_index * num_cols; @@ -40,13 +55,16 @@ __global__ void MatrixDiagKernel(const int num_threads, const int num_rows, const int row = batch_and_row_index - batch * num_rows; const int diag_index = col - row; const int diag_index_in_input = upper_diag_index - diag_index; - const int index_in_the_diagonal = col - max(diag_index, 0); + const int content_offset = + ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols, + left_align_superdiagonal, left_align_subdiagonal); + const int index_in_the_diagonal = col - max(diag_index, 0) + content_offset; if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) { output_ptr[index] = diag_ptr[batch * num_diags * max_diag_len + diag_index_in_input * max_diag_len + index_in_the_diagonal]; } else { - output_ptr[index] = padding; + 
output_ptr[index] = padding_value; } } } @@ -58,7 +76,9 @@ struct MatrixDiag<GPUDevice, T> { typename TTypes<T, 3>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len, const T padding) { + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { const int batch_size = output.dimension(0); const int num_rows = output.dimension(1); const int num_cols = output.dimension(2); @@ -72,19 +92,19 @@ struct MatrixDiag<GPUDevice, T> { TF_CHECK_OK(GpuLaunchKernel( MatrixDiagKernel<T>, config.block_count, config.thread_per_block, 0, device.stream(), config.virtual_thread_count, num_rows, num_cols, - num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding, + num_diags, max_diag_len, lower_diag_index, upper_diag_index, + padding_value, left_align_superdiagonal, left_align_subdiagonal, diag.data(), output.data())); } }; template <typename T> -__global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows, - const int num_cols, const int num_diags, - const int max_diag_len, - const int lower_diag_index, - const int upper_diag_index, - const T padding, const T* input_ptr, - T* output_ptr) { +__global__ void MatrixDiagPartKernel( + const int num_threads, const int num_rows, const int num_cols, + const int num_diags, const int max_diag_len, const int lower_diag_index, + const int upper_diag_index, const T padding_value, + const bool left_align_superdiagonal, const bool left_align_subdiagonal, + const T* input_ptr, T* output_ptr) { GPU_1D_KERNEL_LOOP(index, num_threads) { const int batch_and_mapped_diag_index = index / max_diag_len; const int index_in_the_diagonal = @@ -93,13 +113,19 @@ __global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows, const int mapped_diag_index = batch_and_mapped_diag_index - batch * num_diags; const int diag_index = upper_diag_index - mapped_diag_index; - const 
int y_index = index_in_the_diagonal + max(0, -diag_index); - const int x_index = index_in_the_diagonal + max(0, diag_index); - if (y_index < num_rows && x_index < num_cols) { + const int content_offset = + ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols, + left_align_superdiagonal, left_align_subdiagonal); + const int y_offset = max(0, -diag_index); + const int x_offset = max(0, diag_index); + const int y_index = index_in_the_diagonal + y_offset - content_offset; + const int x_index = index_in_the_diagonal + x_offset - content_offset; + if (0 <= y_index && y_index < num_rows && 0 <= x_index && + x_index < num_cols) { output_ptr[index] = input_ptr[batch * num_rows * num_cols + y_index * num_cols + x_index]; } else { - output_ptr[index] = padding; + output_ptr[index] = padding_value; } } } @@ -111,7 +137,9 @@ struct MatrixDiagPart<GPUDevice, T> { typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len, const T padding) { + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { const int batch_size = input.dimension(0); const int num_rows = input.dimension(1); const int num_cols = input.dimension(2); @@ -125,7 +153,8 @@ struct MatrixDiagPart<GPUDevice, T> { TF_CHECK_OK(GpuLaunchKernel( MatrixDiagPartKernel<T>, config.block_count, config.thread_per_block, 0, device.stream(), config.virtual_thread_count, num_rows, num_cols, - num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding, + num_diags, max_diag_len, lower_diag_index, upper_diag_index, + padding_value, left_align_superdiagonal, left_align_subdiagonal, input.data(), output.data())); } }; diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc index 6507fca3403..2701ff788f7 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.cc +++ 
b/tensorflow/core/kernels/matrix_set_diag_op.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/matrix_diag_op.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -42,7 +43,13 @@ typedef Eigen::GpuDevice GPUDevice; template <typename Device, typename T> class MatrixSetDiagOp : public OpKernel { public: - explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {} + explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) { + // MatrixSetDiagV3-specific. + if (context->HasAttr("align")) { + functor::ReadAlignment(context, &left_align_superdiagonal_, + &left_align_subdiagonal_); + } + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); @@ -55,7 +62,7 @@ class MatrixSetDiagOp : public OpKernel { int32 upper_diag_index = 0; // MatrixSetDiagV2-specific. 
- if (context->num_inputs() > 2) { + if (context->num_inputs() > kNumV1Inputs) { auto& diag_index = context->input(2); OP_REQUIRES(context, TensorShapeUtils::IsScalar(diag_index.shape()) || @@ -155,10 +162,14 @@ class MatrixSetDiagOp : public OpKernel { auto output_reshaped = output->flat_inner_dims<T, 3>(); functor::MatrixSetDiag<Device, T>::Compute( context, context->eigen_device<Device>(), input_reshaped, diag_reshaped, - output_reshaped, lower_diag_index, upper_diag_index, max_diag_len); + output_reshaped, lower_diag_index, upper_diag_index, max_diag_len, + left_align_superdiagonal_, left_align_subdiagonal_); } private: + bool left_align_superdiagonal_ = true; + bool left_align_subdiagonal_ = true; + static constexpr int kNumV1Inputs = 2; TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); }; @@ -168,7 +179,11 @@ class MatrixSetDiagOp : public OpKernel { MatrixSetDiagOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("MatrixSetDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + MatrixSetDiagOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatrixSetDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ MatrixSetDiagOp<CPUDevice, type>); + TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG); #undef REGISTER_MATRIX_SET_DIAG @@ -192,29 +207,38 @@ struct MatrixSetDiag<CPUDevice, T> { typename TTypes<T, 3>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len) { + const Eigen::Index max_diag_len, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { if (input.data() != output.data()) { output.device(device) = input; } const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1; auto compute_shard = [&output, &diag, &upper_diag_index, &max_diag_len, - &num_diags](Eigen::Index begin, Eigen::Index end) { + &num_diags, &left_align_superdiagonal, + &left_align_subdiagonal](Eigen::Index begin, + Eigen::Index end) { const Eigen::Index num_rows = 
output.dimension(1); const Eigen::Index num_cols = output.dimension(2); Eigen::Index diag_base_index = begin * num_diags * max_diag_len; for (Eigen::Index batch = begin; batch < end; ++batch) { for (Eigen::Index m = 0; m < num_diags; ++m) { - const Eigen::Index d = upper_diag_index - m; + const Eigen::Index diag_index = upper_diag_index - m; + int diag_len, content_offset; + std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset( + diag_index, max_diag_len, num_rows, num_cols, + left_align_superdiagonal, left_align_subdiagonal); + // Make two separate cases to save some index calculations. - if (d >= 0) { - for (Eigen::Index n = 0; n < std::min(num_rows, num_cols - d); - ++n) { - output(batch, n, n + d) = diag(diag_base_index + n); + if (diag_index >= 0) { + for (Eigen::Index n = 0; n < diag_len; ++n) { + output(batch, n, n + diag_index) = + diag(diag_base_index + n + content_offset); } } else { - for (Eigen::Index n = 0; n < std::min(num_rows + d, num_cols); - ++n) { - output(batch, n - d, n) = diag(diag_base_index + n); + for (Eigen::Index n = 0; n < diag_len; ++n) { + output(batch, n - diag_index, n) = + diag(diag_base_index + n + content_offset); } } diag_base_index += max_diag_len; @@ -236,15 +260,16 @@ struct MatrixSetDiag<CPUDevice, T> { // Forward declarations of the functor specializations for GPU. 
namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void MatrixSetDiag<GPUDevice, T>::Compute( \ - OpKernelContext* context, const GPUDevice& device, \ - typename TTypes<T, 3>::ConstTensor& input, \ - typename TTypes<T>::ConstTensor& diag, \ - typename TTypes<T, 3>::Tensor& output, \ - const Eigen::Index lower_diag_index, \ - const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void MatrixSetDiag<GPUDevice, T>::Compute( \ + OpKernelContext* context, const GPUDevice& device, \ + typename TTypes<T, 3>::ConstTensor& input, \ + typename TTypes<T>::ConstTensor& diag, \ + typename TTypes<T, 3>::Tensor& output, \ + const Eigen::Index lower_diag_index, \ + const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \ + const bool left_align_superdiagonal, const bool left_align_subdiagonal); \ extern template struct MatrixSetDiag<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); @@ -263,7 +288,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); .Device(DEVICE_GPU) \ .TypeConstraint<type>("T") \ .HostMemory("k"), \ + MatrixSetDiagOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("MatrixSetDiagV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("k"), \ MatrixSetDiagOp<GPUDevice, type>); + TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU); TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU); TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU); diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h index db30aaee669..04877cd34ca 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.h +++ b/tensorflow/core/kernels/matrix_set_diag_op.h @@ -29,8 +29,11 @@ struct MatrixSetDiag { typename TTypes<T, 3>::ConstTensor& input, typename TTypes<T>::ConstTensor& diag, typename TTypes<T, 3>::Tensor& output, - const Eigen::Index d_lower, const Eigen::Index d_upper, - const Eigen::Index max_diag_len); + const Eigen::Index 
lower_diag_index, + const Eigen::Index upper_diag_index, + const Eigen::Index max_diag_len, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal); }; } // namespace functor diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc index 55489073f93..4f742b90bff 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc @@ -26,26 +26,42 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; +// TODO(penporn): Merge this file with matrix_diag_op_gpu.cu.cc. +__device__ inline int ComputeContentOffset(const int diag_index, + const int max_diag_len, + const int num_rows, + const int num_cols, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { + const bool left_align = (diag_index >= 0 && left_align_superdiagonal) || + (diag_index <= 0 && left_align_subdiagonal); + if (left_align) return 0; + const int y_offset = min(0, diag_index); + const int x_offset = max(0, diag_index); + const int diag_len = min(num_rows + y_offset, num_cols - x_offset); + return max_diag_len - diag_len; +} + template <typename Scalar> -__global__ void MatrixSetDiagKernel(const int num_threads, const int m, - const int n, const int num_diags, - const int max_diag_len, - const int upper_diag_index, - const Scalar* __restrict__ diag_ptr, - Scalar* __restrict__ output_ptr) { +__global__ void MatrixSetDiagKernel( + const int num_threads, const int m, const int n, const int num_diags, + const int max_diag_len, const int upper_diag_index, + const bool left_align_superdiagonal, const bool left_align_subdiagonal, + const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) { GPU_1D_KERNEL_LOOP(index, num_threads) { const int batch_and_diag_index = index / max_diag_len; - const int index_in_the_diagonal = - index - batch_and_diag_index * max_diag_len; + int index_in_the_diagonal = index - batch_and_diag_index * 
max_diag_len; const int batch = batch_and_diag_index / num_diags; const int diag_index_in_input = batch_and_diag_index - batch * num_diags; const int diag_index = upper_diag_index - diag_index_in_input; - const int y_index = index_in_the_diagonal + max(0, -diag_index); + index_in_the_diagonal -= + ComputeContentOffset(diag_index, max_diag_len, m, n, + left_align_superdiagonal, left_align_subdiagonal); + const int y_index = index_in_the_diagonal - min(0, diag_index); const int x_index = index_in_the_diagonal + max(0, diag_index); // Upper-bound checks for diagonals shorter than max_diag_len. - // y_index and x_index are nonnegative by construction. - if (y_index < m && x_index < n) { + if (index_in_the_diagonal >= 0 && y_index < m && x_index < n) { const int out_index = batch * m * n + y_index * n + x_index; output_ptr[out_index] = diag_ptr[index]; } @@ -56,17 +72,21 @@ template <typename Scalar> __global__ void MatrixCopyInputAndSetDiagKernel( const int num_threads, const int m, const int n, const int num_diags, const int max_diag_len, const int lower_diag_index, - const int upper_diag_index, const Scalar* __restrict__ input_ptr, + const int upper_diag_index, const bool left_align_superdiagonal, + const bool left_align_subdiagonal, const Scalar* __restrict__ input_ptr, const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) { GPU_1D_KERNEL_LOOP(index, num_threads) { const int batch_and_row_index = index / n; const int col = index - batch_and_row_index * n; const int batch = batch_and_row_index / m; const int row = batch_and_row_index - batch * m; - const int d = col - row; - const int diag_index_in_input = upper_diag_index - d; - const int index_in_the_diagonal = col - max(d, 0); - if (lower_diag_index <= d && d <= upper_diag_index) { + const int diag_index = col - row; + const int diag_index_in_input = upper_diag_index - diag_index; + const int index_in_the_diagonal = + col - max(0, diag_index) + + ComputeContentOffset(diag_index, max_diag_len, m, n, 
+ left_align_superdiagonal, left_align_subdiagonal); + if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) { output_ptr[index] = diag_ptr[batch * num_diags * max_diag_len + diag_index_in_input * max_diag_len + index_in_the_diagonal]; @@ -84,7 +104,9 @@ struct MatrixSetDiag<GPUDevice, Scalar> { typename TTypes<Scalar, 3>::Tensor& output, const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, - const Eigen::Index max_diag_len) { + const Eigen::Index max_diag_len, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal) { const int batch_size = input.dimension(0); const int m = input.dimension(1); const int n = input.dimension(2); @@ -98,15 +120,16 @@ struct MatrixSetDiag<GPUDevice, Scalar> { MatrixSetDiagKernel<Scalar>, config.block_count, config.thread_per_block, 0, device.stream(), config.virtual_thread_count, m, n, num_diags, max_diag_len, - upper_diag_index, diag.data(), output.data())); + upper_diag_index, left_align_superdiagonal, left_align_subdiagonal, + diag.data(), output.data())); } else { GpuLaunchConfig config = GetGpuLaunchConfig(batch_size * m * n, device); TF_CHECK_OK(GpuLaunchKernel( MatrixCopyInputAndSetDiagKernel<Scalar>, config.block_count, config.thread_per_block, 0, device.stream(), config.virtual_thread_count, m, n, num_diags, max_diag_len, - lower_diag_index, upper_diag_index, input.data(), diag.data(), - output.data())); + lower_diag_index, upper_diag_index, left_align_superdiagonal, + left_align_subdiagonal, input.data(), diag.data(), output.data())); } } }; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index c48fecc6147..dbe357dbfe2 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -279,29 +279,6 @@ Status SetOutputShapeForReshape(InferenceContext* c) { return Status::OK(); } -Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, - int32* lower_diag_index, int32* upper_diag_index) { - // This 
function assumes that the shape of diag_index_tensor is fully defined. - if (diag_index_tensor->dims() == 0) { - *lower_diag_index = diag_index_tensor->scalar<int32>()(); - *upper_diag_index = *lower_diag_index; - } else { - int32 num_elements = diag_index_tensor->dim_size(0); - if (num_elements == 1) { - *lower_diag_index = diag_index_tensor->vec<int32>()(0); - *upper_diag_index = *lower_diag_index; - } else if (num_elements == 2) { - *lower_diag_index = diag_index_tensor->vec<int32>()(0); - *upper_diag_index = diag_index_tensor->vec<int32>()(1); - } else { - return errors::InvalidArgument( - "diag_index must be a vector with one or two elements. It has ", - num_elements, " elements."); - } - } - return Status::OK(); -} - } // namespace REGISTER_OP("ParallelConcat") @@ -861,107 +838,20 @@ REGISTER_OP("MatrixDiagV2") .Input("padding_value: T") .Output("output: T") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { - // Checks input ranks. - ShapeHandle input_shape, diag_index_shape, unused_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); + .SetShapeFn(shape_inference::MatrixDiagV2Shape); - // Reads the diagonal indices. 
- const Tensor* diag_index_tensor = c->input_tensor(1); - if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || - diag_index_tensor == nullptr) { - c->set_output(0, c->UnknownShape()); - return Status::OK(); - } - int32 lower_diag_index = 0; - int32 upper_diag_index = 0; - TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index, - &upper_diag_index)); - if (lower_diag_index > upper_diag_index) { - return errors::InvalidArgument( - "lower_diag_index is greater than upper_diag_index"); - } - - // Checks if the number of diagonals provided matches what we imply from - // lower_diag_index and upper_diag_index. - const int32 input_rank = c->Rank(input_shape); - if (lower_diag_index < upper_diag_index) { - const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2)); - const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1)); - - if (num_diags != (upper_diag_index - lower_diag_index + 1)) { - return errors::InvalidArgument( - "The number of rows of `diagonal` doesn't match the number of " - "diagonals implied from `d_lower` and `d_upper`.\n", - "num_diags = ", num_diags, ", d_lower = ", lower_diag_index, - ", d_upper = ", upper_diag_index, " ", input_rank, " ", - other_dim); - } - } - - // Reads num_rows and num_cols. - const Tensor* num_rows_tensor = c->input_tensor(2); - const Tensor* num_cols_tensor = c->input_tensor(3); - int64 num_rows = -1; - int64 num_cols = -1; - if (num_rows_tensor != nullptr) { - TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows)); - } - if (num_cols_tensor != nullptr) { - TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols)); - } - - // Infers the missing num_rows or num_cols: If both are missing, assume - // output is square. Otherwise, use the smallest possible value. Also - // validates the provided values. 
- const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1)); - const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0); - const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0); - if (num_rows == -1 && num_cols == -1) { // Special case. - num_rows = std::max(min_num_rows, min_num_cols); - num_cols = num_rows; - } - if (num_rows == -1) { - num_rows = min_num_rows; - } else if (num_rows < min_num_rows) { - return errors::InvalidArgument("num_rows is too small"); - } - if (num_cols == -1) { - num_cols = min_num_cols; - } else if (num_cols < min_num_cols) { - return errors::InvalidArgument("num_cols is too small."); - } - // At least one of them must match the minimum length. - if (num_rows != min_num_rows && num_cols != min_num_cols) { - return errors::InvalidArgument( - "num_rows and num_cols are not consistent with lower_diag_index, " - "upper_diag_index, and the length of the given diagonals.\n", - "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows, - ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols); - } - - // Sets output shape. 
- ShapeHandle output_shape; - const DimensionHandle output_row_dim = c->MakeDim(num_rows); - const DimensionHandle output_col_dim = c->MakeDim(num_cols); - if (lower_diag_index == upper_diag_index) { - TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1, - output_row_dim, &output_shape)); - TF_RETURN_IF_ERROR(c->Concatenate( - output_shape, c->Vector(output_col_dim), &output_shape)); - } else { - TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2, - output_row_dim, &output_shape)); - TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1, - output_col_dim, &output_shape)); - } - c->set_output(0, output_shape); - return Status::OK(); - }); +REGISTER_OP("MatrixDiagV3") + .Input("diagonal: T") + .Input("k: int32") + .Input("num_rows: int32") + .Input("num_cols: int32") + .Input("padding_value: T") + .Output("output: T") + .Attr("T: type") + .Attr( + "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = " + "'RIGHT_LEFT'") + .SetShapeFn(shape_inference::MatrixDiagV2Shape); // -------------------------------------------------------------------------- REGISTER_OP("MatrixSetDiag") @@ -995,84 +885,25 @@ REGISTER_OP("MatrixSetDiag") c->set_output(0, output); return Status::OK(); }); + REGISTER_OP("MatrixSetDiagV2") .Input("input: T") .Input("diagonal: T") .Input("k: int32") .Output("output: T") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input_shape, diag_shape, diag_index_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape)); + .SetShapeFn(shape_inference::MatrixSetDiagV2Shape); - int32 lower_diag_index = 0; - int32 upper_diag_index = 0; - bool diag_index_known = false; - const Tensor* diag_index_tensor = c->input_tensor(2); - if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) { - diag_index_known = true; - 
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, - &lower_diag_index, &upper_diag_index)); - if (lower_diag_index > upper_diag_index) { - return errors::InvalidArgument( - "lower_diag_index is greater than upper_diag_index"); - } - } - - // Do more checks when input rank is known. - if (c->RankKnown(input_shape)) { - int32 input_rank = c->Rank(input_shape); - - // If diag_index is set, we know the exact rank of diagonal. - if (diag_index_known) { - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), - (lower_diag_index == upper_diag_index) - ? input_rank - 1 - : input_rank, - &diag_shape)); - } else { - TF_RETURN_IF_ERROR( - c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape)); - TF_RETURN_IF_ERROR( - c->WithRankAtMost(c->input(1), input_rank, &diag_shape)); - } - - // Validates lower_diag_index and upper_diag_index. - const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2)); - const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1)); - if (num_rows != InferenceContext::kUnknownDim && - num_cols != InferenceContext::kUnknownDim) { - if (lower_diag_index != 0 && // For when num_rows or num_cols == 0. - (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) { - return errors::InvalidArgument("lower_diag_index is out of bound."); - } - if (upper_diag_index != 0 && // For when num_rows or num_cols == 0. - (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) { - return errors::InvalidArgument("upper_diag_index is out of bound."); - } - } - } - - ShapeHandle output_shape = input_shape; - if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) { - // Try to infer parts of shape from diag. - ShapeHandle diag_prefix; - TF_RETURN_IF_ERROR(c->Subshape( - diag_shape, 0, (lower_diag_index == upper_diag_index) ? 
-1 : -2, - &diag_prefix)); - - // The inner matrices can be rectangular, so we can't pinpoint their - // exact height and width by just lower_diag_index, upper_diag_index, - // and the longest length of given diagonals. - TF_RETURN_IF_ERROR( - c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape)); - TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape)); - } - c->set_output(0, output_shape); - return Status::OK(); - }); +REGISTER_OP("MatrixSetDiagV3") + .Input("input: T") + .Input("diagonal: T") + .Input("k: int32") + .Output("output: T") + .Attr("T: type") + .Attr( + "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = " + "'RIGHT_LEFT'") + .SetShapeFn(shape_inference::MatrixSetDiagV2Shape); // -------------------------------------------------------------------------- REGISTER_OP("MatrixDiagPart") @@ -1105,58 +936,18 @@ REGISTER_OP("MatrixDiagPartV2") .Input("padding_value: T") .Output("diagonal: T") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input_shape, diag_index_shape, unused_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); + .SetShapeFn(shape_inference::MatrixDiagPartV2Shape); - const Tensor* diag_index_tensor = c->input_tensor(1); - if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || - diag_index_tensor == nullptr) { - c->set_output(0, c->UnknownShape()); - return Status::OK(); - } - int32 lower_diag_index = 0; - int32 upper_diag_index = 0; - TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index, - &upper_diag_index)); - if (lower_diag_index > upper_diag_index) { - return errors::InvalidArgument( - "lower_diag_index is greater than upper_diag_index"); - } - - // Validates lower_diag_index and upper_diag_index. 
- const int32 input_rank = c->Rank(input_shape); - const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2)); - const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1)); - if (num_rows != InferenceContext::kUnknownDim && - num_cols != InferenceContext::kUnknownDim) { - if (lower_diag_index != 0 && // For when num_rows or num_cols == 0. - (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) { - return errors::InvalidArgument("lower_diag_index is out of bound."); - } - if (upper_diag_index != 0 && // For when num_rows or num_cols == 0. - (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) { - return errors::InvalidArgument("upper_diag_index is out of bound."); - } - } - - const int32 max_diag_len = - std::min(num_rows + std::min(upper_diag_index, 0), - num_cols - std::max(lower_diag_index, 0)); - std::vector<DimensionHandle> dims; - dims.reserve(input_rank - 2); - for (int i = 0; i < input_rank - 2; ++i) { - dims.push_back(c->Dim(input_shape, i)); - } - if (lower_diag_index < upper_diag_index) { - dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1)); - } - dims.push_back(c->MakeDim(max_diag_len)); - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - }); +REGISTER_OP("MatrixDiagPartV3") + .Input("input: T") + .Input("k: int32") + .Input("padding_value: T") + .Output("diagonal: T") + .Attr("T: type") + .Attr( + "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = " + "'RIGHT_LEFT'") + .SetShapeFn(shape_inference::MatrixDiagPartV2Shape); // -------------------------------------------------------------------------- REGISTER_OP("MatrixBandPart") diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 0a579b48d47..82e362f514e 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -180,8 +180,8 @@ cc_library( name = "graph_transformations", srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", - 
"graph_transformations/convert_matrix_diag_v2_to_v1.cc", - "graph_transformations/convert_matrix_set_diag_v2_to_v1.cc", + "graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc", + "graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc", "graph_transformations/convert_pure_conv_to_depthwise.cc", "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_squeeze_to_reshape.cc", diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc similarity index 88% rename from tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc rename to tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc index 5037b05f6fd..c0dbea05a26 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc @@ -20,13 +20,16 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertMatrixDiagV2ToV1::Run(Model* model, - std::size_t op_index, - bool* modified) { +// V3 is only different from V2 because it has an extra attribute (align). +// This attribute doesn't affect V1 so we don't have to keep track of it here. 
+::tensorflow::Status ConvertMatrixDiagV2OrV3ToV1::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (op->type != OperatorType::kMatrixDiagV2) { + if (op->type != OperatorType::kMatrixDiagV2 && + op->type != OperatorType::kMatrixDiagV3) { return ::tensorflow::Status::OK(); } diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc similarity index 83% rename from tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc rename to tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc index 61288f626b6..a3801398d71 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc @@ -26,13 +26,16 @@ limitations under the License. namespace toco { -::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model, - std::size_t op_index, - bool* modified) { +// V3 is only different from V2 because it has an extra attribute (align). +// This attribute doesn't affect V1 so we don't have to keep track of it here. 
+::tensorflow::Status ConvertMatrixSetDiagV2OrV3ToV1::Run(Model* model, + std::size_t op_index, + bool* modified) { *modified = false; auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (op->type != OperatorType::kMatrixSetDiagV2) { + if (op->type != OperatorType::kMatrixSetDiagV2 && + op->type != OperatorType::kMatrixSetDiagV3) { return ::tensorflow::Status::OK(); } diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 4a1334f41f1..0b765b1f507 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -123,8 +123,8 @@ inline void RunGraphTransformations( // List of all graph transformations DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) -DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1) -DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1) +DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2OrV3ToV1) +DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2OrV3ToV1) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index d7a56e6d4b2..4b1a6fab607 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -2434,6 +2434,14 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) { // MatrixDiagV2 operators are converted to MatrixDiag, after which their // shapes are propagated. break; + case OperatorType::kMatrixDiagV3: + // MatrixDiagV3 operators are converted to MatrixDiag, after which their + // shapes are propagated. 
+ break; + case OperatorType::kMatrixSetDiagV3: + // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which + // their shapes are propagated. + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index fa8cc9799e1..dd7a9e3d835 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2560,8 +2560,16 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"MatMul", ConvertMatMulOperator}, {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>}, {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>}, + // `MatrixDiagV3` has an `align` attribute. However, Toco only converts + // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols, + // padding_value` inputs. In this case, `align` can be ignored. + {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>}, {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>}, {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>}, + // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts + // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this + // case, `align` can be ignored. 
+ {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>}, {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>}, diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index ef717ee4e18..d8f4b73115c 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -175,7 +175,9 @@ enum class OperatorType : uint8 { kMatrixDiag, kMatrixSetDiag, kMatrixDiagV2, - kMatrixSetDiagV2 + kMatrixSetDiagV2, + kMatrixDiagV3, + kMatrixSetDiagV3 }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -2122,12 +2124,24 @@ struct MatrixDiagOperator : Operator { // Matrix Diag Operator V2: // Construct a batched diagonal tensor with given batched diagonal values. -// Not fully supported, constains 4 extra inputs compared to MatrixDiag, support -// default parameters settings which performs the same as MatrixDiag +// Not fully supported, contains 4 extra inputs compared to MatrixDiag. Behaves +// like MatrixDiag when default parameters are used. struct MatrixDiagV2Operator : Operator { MatrixDiagV2Operator() : Operator(OperatorType::kMatrixDiagV2) {} }; +// Matrix Diag Operator V3: +// Construct a batched diagonal tensor with given batched diagonal values. +// Not fully supported, contains 4 extra inputs compared to MatrixDiag. Behaves +// like MatrixDiag when default parameters are used. +// V3 is only different from V2 because it has an extra attribute (align) which +// controls the alignment of diagonals in the band matrix (compact) format. +// The alignment in V2 contradicts the default alignment in V3 so V2 is +// skipped. (It has never been, and should never be, exposed in the public API.) 
+struct MatrixDiagV3Operator : Operator { + MatrixDiagV3Operator() : Operator(OperatorType::kMatrixDiagV3) {} +}; + // Matrix Set Diag Operator: // Construct a batched diagonal tensor with given input and diagonal values. // Input is a rank (k+1) tensor of values. @@ -2140,12 +2154,24 @@ struct MatrixSetDiagOperator : Operator { // Matrix Set Diag Operator V2: // Construct a batched diagonal tensor with given input and diagonal values. -// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag, -// support default parameters settings which performs the same as MatrixSetDiag +// Not fully supported, contains 1 extra input compared to MatrixSetDiag. +// Behaves like MatrixSetDiag when default parameters are used. struct MatrixSetDiagV2Operator : Operator { MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {} }; +// Matrix Set Diag Operator V3: +// Construct a batched diagonal tensor with given input and diagonal values. +// Not fully supported, contains 1 extra input (k) compared to MatrixSetDiag. +// Behaves like MatrixSetDiag when default parameters are used. +// V3 is only different from V2 because it has an extra attribute (align) which +// controls the alignment of diagonals in the band matrix (compact) format. +// The alignment in V2 contradicts the default alignment in V3 so V2 is +// skipped. (It has never been, and should never be, exposed in the public API.) +struct MatrixSetDiagV3Operator : Operator { + MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand.
The 'start' and 'end' values are diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index 5a448131820..f96c6b83025 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -54,8 +54,8 @@ void MakeGeneralGraphTransformationsSet( GraphTransformationsSet* transformations) { CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); - transformations->Add(new ConvertMatrixDiagV2ToV1); - transformations->Add(new ConvertMatrixSetDiagV2ToV1); + transformations->Add(new ConvertMatrixDiagV2OrV3ToV1); + transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1); transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialPackToReshape); diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index 418361cd0e7..ebcb17599b1 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -449,6 +449,8 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag) HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2) HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2) + HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3) + HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py index afa0850b853..df8409f09e6 100644 --- a/tensorflow/python/kernel_tests/diag_op_test.py +++ b/tensorflow/python/kernel_tests/diag_op_test.py @@ -33,8 +33,80 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +# LINT.IfChange +matrix_diag_v3_forward_compat_date = (2019, 12, 6) +# LINT.ThenChange( +# //tensorflow/compiler/tests/matrix_diag_ops_test.py, +# //tensorflow/python/ops/array_ops.py, +# 
//tensorflow/python/ops/parallel_for/array_test.py +# ) + + +default_v2_alignment = "LEFT_LEFT" +alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"] + + +def zip_to_first_list_length(a, b): + if len(b) > len(a): + return zip(a, b[:len(a)]) + return zip(a, b + [None] * (len(a) - len(b))) + + +def repack_diagonals(packed_diagonals, + diag_index, + num_rows, + num_cols, + align=None): + # The original test cases are LEFT_LEFT aligned. + if align == default_v2_alignment or align is None: + return packed_diagonals + + align = align.split("_") + d_lower, d_upper = diag_index + batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1) + max_diag_len = packed_diagonals.shape[-1] + index = (slice(None),) * batch_dims + repacked_diagonals = np.zeros_like(packed_diagonals) + + # Aligns each diagonal row-by-row. + for diag_index in range(d_lower, d_upper + 1): + diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index)) + row_index = d_upper - diag_index + padding_len = max_diag_len - diag_len + left_align = (diag_index >= 0 and + align[0] == "LEFT") or (diag_index <= 0 and + align[1] == "LEFT") + # Prepares index tuples. + extra_dim = tuple() if d_lower == d_upper else (row_index,) + packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),) + repacked_last_dim = (slice(None),) if left_align else (slice( + padding_len, max_diag_len, 1),) + packed_index = index + extra_dim + packed_last_dim + repacked_index = index + extra_dim + repacked_last_dim + + # Repacks the diagonal. + repacked_diagonals[repacked_index] = packed_diagonals[packed_index] + return repacked_diagonals + + +def repack_diagonals_in_tests(tests, align=None): + # The original test cases are LEFT_LEFT aligned. + if align == default_v2_alignment or align is None: + return tests + + new_tests = dict() + # Loops through each case. 
+ for diag_index, (packed_diagonals, padded_diagonals) in tests.items(): + num_rows, num_cols = padded_diagonals.shape[-2:] + repacked_diagonals = repack_diagonals( + packed_diagonals, diag_index, num_rows, num_cols, align=align) + new_tests[diag_index] = (repacked_diagonals, padded_diagonals) + + return new_tests + + # Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2. -def square_cases(): +def square_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3, 4, 5], [6, 7, 8, 9, 1], @@ -47,71 +119,72 @@ def square_cases(): [6, 7, 8, 9, 1], [2, 3, 4, 5, 6]]]) tests = dict() - tests[(-1, -1)] = (np.array([[6, 4, 1, 7], - [5, 2, 8, 5]]), - np.array([[[0, 0, 0, 0, 0], - [6, 0, 0, 0, 0], - [0, 4, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 7, 0]], - [[0, 0, 0, 0, 0], - [5, 0, 0, 0, 0], - [0, 2, 0, 0, 0], - [0, 0, 8, 0, 0], - [0, 0, 0, 5, 0]]])) - tests[(-4, -3)] = (np.array([[[8, 5], - [4, 0]], - [[6, 3], - [2, 0]]]), - np.array([[[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [8, 0, 0, 0, 0], - [4, 5, 0, 0, 0]], - [[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [6, 0, 0, 0, 0], - [2, 3, 0, 0, 0]]])) - tests[(-2, 1)] = (np.array([[[2, 8, 6, 3, 0], - [1, 7, 5, 2, 8], - [6, 4, 1, 7, 0], - [3, 9, 6, 0, 0]], - [[1, 7, 4, 1, 0], - [9, 6, 3, 9, 6], - [5, 2, 8, 5, 0], - [1, 7, 4, 0, 0]]]), - np.array([[[1, 2, 0, 0, 0], - [6, 7, 8, 0, 0], - [3, 4, 5, 6, 0], - [0, 9, 1, 2, 3], - [0, 0, 6, 7, 8]], - [[9, 1, 0, 0, 0], - [5, 6, 7, 0, 0], - [1, 2, 3, 4, 0], - [0, 7, 8, 9, 1], - [0, 0, 4, 5, 6]]])) - tests[(2, 4)] = (np.array([[[5, 0, 0], - [4, 1, 0], - [3, 9, 7]], - [[4, 0, 0], - [3, 9, 0], - [2, 8, 5]]]), - np.array([[[0, 0, 3, 4, 5], - [0, 0, 0, 9, 1], - [0, 0, 0, 0, 7], + # tests[d_lower, d_upper] = (packed_diagonals, padded_diagonals) + tests[-1, -1] = (np.array([[6, 4, 1, 7], + [5, 2, 8, 5]]), + np.array([[[0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 7, 0]], + [[0, 0, 0, 0, 0], + [5, 0, 0, 
0, 0], + [0, 2, 0, 0, 0], + [0, 0, 8, 0, 0], + [0, 0, 0, 5, 0]]])) + tests[-4, -3] = (np.array([[[8, 5], + [4, 0]], + [[6, 3], + [2, 0]]]), + np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - [[0, 0, 2, 3, 4], - [0, 0, 0, 8, 9], - [0, 0, 0, 0, 5], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]])) + [8, 0, 0, 0, 0], + [4, 5, 0, 0, 0]], + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [2, 3, 0, 0, 0]]])) + tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0], + [1, 7, 5, 2, 8], + [6, 4, 1, 7, 0], + [3, 9, 6, 0, 0]], + [[1, 7, 4, 1, 0], + [9, 6, 3, 9, 6], + [5, 2, 8, 5, 0], + [1, 7, 4, 0, 0]]]), + np.array([[[1, 2, 0, 0, 0], + [6, 7, 8, 0, 0], + [3, 4, 5, 6, 0], + [0, 9, 1, 2, 3], + [0, 0, 6, 7, 8]], + [[9, 1, 0, 0, 0], + [5, 6, 7, 0, 0], + [1, 2, 3, 4, 0], + [0, 7, 8, 9, 1], + [0, 0, 4, 5, 6]]])) + tests[2, 4] = (np.array([[[5, 0, 0], + [4, 1, 0], + [3, 9, 7]], + [[4, 0, 0], + [3, 9, 0], + [2, 8, 5]]]), + np.array([[[0, 0, 3, 4, 5], + [0, 0, 0, 9, 1], + [0, 0, 0, 0, 7], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 2, 3, 4], + [0, 0, 0, 8, 9], + [0, 0, 0, 0, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) -def tall_cases(): +def tall_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3], [4, 5, 6], @@ -124,81 +197,81 @@ def tall_cases(): [7, 8, 9], [9, 8, 7]]]) tests = dict() - tests[(0, 0)] = (np.array([[1, 5, 9], - [3, 2, 6]]), - np.array([[[1, 0, 0], - [0, 5, 0], - [0, 0, 9], - [0, 0, 0]], - [[3, 0, 0], - [0, 2, 0], - [0, 0, 6], - [0, 0, 0]]])) - tests[(-4, -3)] = (np.array([[[9, 5], - [6, 0]], - [[7, 8], - [9, 0]]]), - np.array([[[0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [9, 0, 0], - [6, 5, 0]], - [[0, 0, 0], - [0, 0, 0], - [0, 0, 0], - [7, 0, 0], - [9, 8, 0]]])) - tests[(-2, -1)] = (np.array([[[4, 8, 7], - [7, 8, 4]], - [[1, 5, 9], - [4, 8, 7]]]), - np.array([[[0, 0, 0], - [4, 0, 0], - [7, 8, 0], - [0, 8, 7], - [0, 0, 4]], - [[0, 0, 
0], - [1, 0, 0], - [4, 5, 0], - [0, 8, 9], - [0, 0, 7]]])) - tests[(-2, 1)] = (np.array([[[2, 6, 0], - [1, 5, 9], - [4, 8, 7], - [7, 8, 4]], - [[2, 3, 0], - [3, 2, 6], - [1, 5, 9], - [4, 8, 7]]]), - np.array([[[1, 2, 0], - [4, 5, 6], - [7, 8, 9], - [0, 8, 7], - [0, 0, 4]], - [[3, 2, 0], - [1, 2, 3], - [4, 5, 6], - [0, 8, 9], - [0, 0, 7]]])) - tests[(1, 2)] = (np.array([[[3, 0], - [2, 6]], - [[1, 0], - [2, 3]]]), - np.array([[[0, 2, 3], - [0, 0, 6], + tests[0, 0] = (np.array([[1, 5, 9], + [3, 2, 6]]), + np.array([[[1, 0, 0], + [0, 5, 0], + [0, 0, 9], + [0, 0, 0]], + [[3, 0, 0], + [0, 2, 0], + [0, 0, 6], + [0, 0, 0]]])) + tests[-4, -3] = (np.array([[[9, 5], + [6, 0]], + [[7, 8], + [9, 0]]]), + np.array([[[0, 0, 0], [0, 0, 0], [0, 0, 0], - [0, 0, 0]], - [[0, 2, 1], - [0, 0, 3], + [9, 0, 0], + [6, 5, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], - [0, 0, 0]]])) + [7, 0, 0], + [9, 8, 0]]])) + tests[-2, -1] = (np.array([[[4, 8, 7], + [7, 8, 4]], + [[1, 5, 9], + [4, 8, 7]]]), + np.array([[[0, 0, 0], + [4, 0, 0], + [7, 8, 0], + [0, 8, 7], + [0, 0, 4]], + [[0, 0, 0], + [1, 0, 0], + [4, 5, 0], + [0, 8, 9], + [0, 0, 7]]])) + tests[-2, 1] = (np.array([[[2, 6, 0], + [1, 5, 9], + [4, 8, 7], + [7, 8, 4]], + [[2, 3, 0], + [3, 2, 6], + [1, 5, 9], + [4, 8, 7]]]), + np.array([[[1, 2, 0], + [4, 5, 6], + [7, 8, 9], + [0, 8, 7], + [0, 0, 4]], + [[3, 2, 0], + [1, 2, 3], + [4, 5, 6], + [0, 8, 9], + [0, 0, 7]]])) + tests[1, 2] = (np.array([[[3, 0], + [2, 6]], + [[1, 0], + [2, 3]]]), + np.array([[[0, 2, 3], + [0, 0, 6], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 2, 1], + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) -def fat_cases(): +def fat_cases(align=None): # pyformat: disable mat = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], @@ -207,59 +280,63 @@ def fat_cases(): [8, 9, 1, 2], [3, 4, 5, 6]]]) tests = dict() - tests[(2, 2)] = (np.array([[3, 8], - [6, 2]]), - np.array([[[0, 0, 3, 0], - [0, 
0, 0, 8], - [0, 0, 0, 0]], - [[0, 0, 6, 0], - [0, 0, 0, 2], - [0, 0, 0, 0]]])) - tests[(-2, 0)] = (np.array([[[1, 6, 2], - [5, 1, 0], - [9, 0, 0]], - [[4, 9, 5], - [8, 4, 0], - [3, 0, 0]]]), - np.array([[[1, 0, 0, 0], - [5, 6, 0, 0], - [9, 1, 2, 0]], - [[4, 0, 0, 0], - [8, 9, 0, 0], - [3, 4, 5, 0]]])) - tests[(-1, 1)] = (np.array([[[2, 7, 3], - [1, 6, 2], - [5, 1, 0]], - [[5, 1, 6], - [4, 9, 5], - [8, 4, 0]]]), - np.array([[[1, 2, 0, 0], - [5, 6, 7, 0], - [0, 1, 2, 3]], - [[4, 5, 0, 0], - [8, 9, 1, 0], - [0, 4, 5, 6]]])) - tests[(0, 3)] = (np.array([[[4, 0, 0], - [3, 8, 0], - [2, 7, 3], - [1, 6, 2]], - [[7, 0, 0], - [6, 2, 0], - [5, 1, 6], - [4, 9, 5]]]), - np.array([[[1, 2, 3, 4], - [0, 6, 7, 8], - [0, 0, 2, 3]], - [[4, 5, 6, 7], - [0, 9, 1, 2], - [0, 0, 5, 6]]])) + tests[2, 2] = (np.array([[3, 8], + [6, 2]]), + np.array([[[0, 0, 3, 0], + [0, 0, 0, 8], + [0, 0, 0, 0]], + [[0, 0, 6, 0], + [0, 0, 0, 2], + [0, 0, 0, 0]]])) + tests[-2, 0] = (np.array([[[1, 6, 2], + [5, 1, 0], + [9, 0, 0]], + [[4, 9, 5], + [8, 4, 0], + [3, 0, 0]]]), + np.array([[[1, 0, 0, 0], + [5, 6, 0, 0], + [9, 1, 2, 0]], + [[4, 0, 0, 0], + [8, 9, 0, 0], + [3, 4, 5, 0]]])) + tests[-1, 1] = (np.array([[[2, 7, 3], + [1, 6, 2], + [5, 1, 0]], + [[5, 1, 6], + [4, 9, 5], + [8, 4, 0]]]), + np.array([[[1, 2, 0, 0], + [5, 6, 7, 0], + [0, 1, 2, 3]], + [[4, 5, 0, 0], + [8, 9, 1, 0], + [0, 4, 5, 6]]])) + tests[0, 3] = (np.array([[[4, 0, 0], + [3, 8, 0], + [2, 7, 3], + [1, 6, 2]], + [[7, 0, 0], + [6, 2, 0], + [5, 1, 6], + [4, 9, 5]]]), + np.array([[[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0, 2, 3]], + [[4, 5, 6, 7], + [0, 9, 1, 2], + [0, 0, 5, 6]]])) # pyformat: enable - return (mat, tests) + return (mat, repack_diagonals_in_tests(tests, align)) + + +def all_tests(align=None): + return [square_cases(align), tall_cases(align), fat_cases(align)] class MatrixDiagTest(test.TestCase): - def _moreCases(self): + def _moreCases(self, align=None): # Diagonal bands. 
# pyformat: disable vecs = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4) @@ -269,41 +346,41 @@ class MatrixDiagTest(test.TestCase): [1, 2, 3, 4], [5, 6, 7, 8]]]) tests = dict() - tests[(-3, -1)] = (vecs, - np.array([[[0, 0, 0, 0, 0], - [1, 0, 0, 0, 0], - [5, 2, 0, 0, 0], - [9, 6, 3, 0, 0], - [0, 8, 7, 4, 0]], - [[0, 0, 0, 0, 0], - [5, 0, 0, 0, 0], - [1, 4, 0, 0, 0], - [5, 2, 3, 0, 0], - [0, 6, 3, 2, 0]]])) - tests[(-1, 1)] = (vecs, - np.array([[[5, 1, 0, 0], - [9, 6, 2, 0], - [0, 8, 7, 3], - [0, 0, 7, 8]], - [[1, 5, 0, 0], - [5, 2, 4, 0], - [0, 6, 3, 3], - [0, 0, 7, 4]]])) - tests[(2, 4)] = (vecs, - np.array([[[0, 0, 9, 5, 1, 0], - [0, 0, 0, 8, 6, 2], - [0, 0, 0, 0, 7, 7], - [0, 0, 0, 0, 0, 6], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]], - [[0, 0, 5, 1, 5, 0], - [0, 0, 0, 6, 2, 4], - [0, 0, 0, 0, 7, 3], - [0, 0, 0, 0, 0, 8], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]]])) + tests[-3, -1] = (vecs, + np.array([[[0, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + [5, 2, 0, 0, 0], + [9, 6, 3, 0, 0], + [0, 8, 7, 4, 0]], + [[0, 0, 0, 0, 0], + [5, 0, 0, 0, 0], + [1, 4, 0, 0, 0], + [5, 2, 3, 0, 0], + [0, 6, 3, 2, 0]]])) + tests[-1, 1] = (vecs, + np.array([[[5, 1, 0, 0], + [9, 6, 2, 0], + [0, 8, 7, 3], + [0, 0, 7, 8]], + [[1, 5, 0, 0], + [5, 2, 4, 0], + [0, 6, 3, 3], + [0, 0, 7, 4]]])) + tests[2, 4] = (vecs, + np.array([[[0, 0, 9, 5, 1, 0], + [0, 0, 0, 8, 6, 2], + [0, 0, 0, 0, 7, 7], + [0, 0, 0, 0, 0, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[0, 0, 5, 1, 5, 0], + [0, 0, 0, 6, 2, 4], + [0, 0, 0, 0, 7, 3], + [0, 0, 0, 0, 0, 8], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]])) # pyformat: enable - return (None, tests) + return (None, repack_diagonals_in_tests(tests, align)) @test_util.run_deprecated_v1 def testVector(self): @@ -314,10 +391,7 @@ class MatrixDiagTest(test.TestCase): self.assertEqual((3, 3), v_diag.get_shape()) self.assertAllEqual(v_diag.eval(), mat) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # 
LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # {Sub,Super}diagonals. for offset in [1, -2, 5]: mat = np.diag(v, offset) @@ -326,13 +400,14 @@ class MatrixDiagTest(test.TestCase): self.assertAllEqual(v_diag.eval(), mat) # Diagonal bands. - for _, tests in [self._moreCases(), square_cases()]: - for diags, (vecs, solution) in tests.items(): - v_diags = array_ops.matrix_diag(vecs[0], k=diags) - self.assertEqual(v_diags.get_shape(), solution[0].shape) - self.assertAllEqual(v_diags.eval(), solution[0]) + for align in alignment_list: + for _, tests in [self._moreCases(align), square_cases(align)]: + for diags, (vecs, solution) in tests.items(): + v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align) + self.assertEqual(v_diags.get_shape(), solution[0].shape) + self.assertAllEqual(v_diags.eval(), solution[0]) - def _testBatchVector(self, dtype): + def _testVectorBatch(self, dtype): with self.cached_session(use_gpu=True): v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype) mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]], @@ -342,10 +417,7 @@ class MatrixDiagTest(test.TestCase): self.assertEqual((2, 3, 3), v_batch_diag.get_shape()) self.assertAllEqual(v_batch_diag.eval(), mat_batch) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # {Sub,Super}diagonals. for offset in [1, -2, 5]: v_batch_diag = array_ops.matrix_diag(v_batch, k=offset) @@ -357,33 +429,34 @@ class MatrixDiagTest(test.TestCase): self.assertAllEqual(v_batch_diag.eval(), mat_batch) # Diagonal bands with padding_value. 
- for padding_value in [0, 555, -11]: - for _, tests in [self._moreCases(), square_cases()]: + for padding_value, align in zip_to_first_list_length([0, 555, -11], + alignment_list): + for _, tests in [self._moreCases(align), square_cases(align)]: for diags, (vecs, solution) in tests.items(): v_diags = array_ops.matrix_diag( - vecs.astype(dtype), k=diags, padding_value=padding_value) + vecs.astype(dtype), + k=diags, + padding_value=padding_value, + align=align) mask = solution == 0 solution = (solution + padding_value * mask).astype(dtype) self.assertEqual(v_diags.get_shape(), solution.shape) self.assertAllEqual(v_diags.eval(), solution) @test_util.run_deprecated_v1 - def testBatchVector(self): - self._testBatchVector(np.float32) - self._testBatchVector(np.float64) - self._testBatchVector(np.int32) - self._testBatchVector(np.int64) - self._testBatchVector(np.bool) + def testVectorBatch(self): + self._testVectorBatch(np.float32) + self._testVectorBatch(np.float64) + self._testVectorBatch(np.int32) + self._testVectorBatch(np.int64) + self._testVectorBatch(np.bool) @test_util.run_deprecated_v1 def testRectangularBatch(self): - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): with self.cached_session(use_gpu=True): # Stores expected num_rows and num_cols (when the other is given). - # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols) + # expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols) test_list = list() # Square cases: @@ -393,6 +466,8 @@ class MatrixDiagTest(test.TestCase): (-2, 1): (5, 5), (2, 4): (3, 5), } + # Do not change alignment yet. Re-alignment needs to happen after the + # solution shape is updated. 
test_list.append((expected, square_cases())) # More cases: @@ -418,16 +493,18 @@ class MatrixDiagTest(test.TestCase): } test_list.append((expected, fat_cases())) - for padding_value in [0, 555, -11]: + for padding_value, align in zip_to_first_list_length([0, 555, -11], + alignment_list): # Giving both num_rows and num_cols - for _, tests in [tall_cases(), fat_cases()]: + for _, tests in [tall_cases(align), fat_cases(align)]: for diags, (vecs, solution) in tests.items(): v_diags = array_ops.matrix_diag( vecs, k=diags, num_rows=solution.shape[-2], num_cols=solution.shape[-1], - padding_value=padding_value) + padding_value=padding_value, + align=align) mask = solution == 0 solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) @@ -438,11 +515,15 @@ class MatrixDiagTest(test.TestCase): for diags, (_, new_num_cols) in expected.items(): vecs, solution = tests[diags] solution = solution.take(indices=range(new_num_cols), axis=-1) + # Repacks the diagonal input according to the new solution shape. + vecs = repack_diagonals( + vecs, diags, solution.shape[-2], new_num_cols, align=align) v_diags = array_ops.matrix_diag( vecs, k=diags, num_rows=solution.shape[-2], - padding_value=padding_value) + padding_value=padding_value, + align=align) mask = solution == 0 solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) @@ -453,11 +534,15 @@ class MatrixDiagTest(test.TestCase): for diags, (new_num_rows, _) in expected.items(): vecs, solution = tests[diags] solution = solution.take(indices=range(new_num_rows), axis=-2) + # Repacks the diagonal input according to the new solution shape. 
+ vecs = repack_diagonals( + vecs, diags, new_num_rows, solution.shape[-1], align=align) v_diags = array_ops.matrix_diag( vecs, k=diags, num_cols=solution.shape[-1], - padding_value=padding_value) + padding_value=padding_value, + align=align) mask = solution == 0 solution = solution + padding_value * mask self.assertEqual(v_diags.get_shape(), solution.shape) @@ -489,10 +574,7 @@ class MatrixDiagTest(test.TestCase): y.get_shape().as_list()) self.assertLess(error, 1e-4) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # {Sub,super}diagonals/band. tests = dict() # tests[shape] = (d_lower, d_upper) tests[(3,)] = (-1, -1) @@ -500,12 +582,13 @@ class MatrixDiagTest(test.TestCase): with self.session(use_gpu=True): for shape, diags in tests.items(): x = constant_op.constant(np.random.rand(*shape), np.float32) - y = array_ops.matrix_diag(x, k=diags) - error = gradient_checker.compute_gradient_error( - x, - x.get_shape().as_list(), y, - y.get_shape().as_list()) - self.assertLess(error, 1e-4) + for align in alignment_list: + y = array_ops.matrix_diag(x, k=diags, align=align) + error = gradient_checker.compute_gradient_error( + x, + x.get_shape().as_list(), y, + y.get_shape().as_list()) + self.assertLess(error, 1e-4) class MatrixSetDiagTest(test.TestCase): @@ -521,20 +604,18 @@ class MatrixSetDiagTest(test.TestCase): self.assertEqual((3, 3), output.get_shape()) self.assertAllEqual(mat_set_diag, self.evaluate(output)) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands. 
- _, tests = square_cases() - for diags, pair in tests.items(): - vecs, banded_mat = pair - mask = banded_mat[0] == 0 - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat[0] - output = array_ops.matrix_set_diag(input_mat, vecs[0], k=diags) - self.assertEqual(output.get_shape(), solution.shape) - self.assertAllEqual(output.eval(), solution) + for align in alignment_list: + _, tests = square_cases(align) + for diags, (vecs, banded_mat) in tests.items(): + mask = banded_mat[0] == 0 + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + output = array_ops.matrix_set_diag( + input_mat, vecs[0], k=diags, align=align) + self.assertEqual(output.get_shape(), solution.shape) + self.assertAllEqual(output.eval(), solution) @test_util.run_deprecated_v1 def testRectangular(self): @@ -553,20 +634,18 @@ class MatrixSetDiagTest(test.TestCase): self.assertEqual((3, 2), output.get_shape()) self.assertAllEqual(expected, self.evaluate(output)) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands. 
- for _, tests in [tall_cases(), fat_cases()]: - for diags, pair in tests.items(): - vecs, banded_mat = pair - mask = banded_mat[0] == 0 - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat[0] - output = array_ops.matrix_set_diag(input_mat, vecs[0], k=diags) - self.assertEqual(output.get_shape(), solution.shape) - self.assertAllEqual(output.eval(), solution) + for align in alignment_list: + for _, tests in [tall_cases(align), fat_cases(align)]: + for diags, (vecs, banded_mat) in tests.items(): + mask = banded_mat[0] == 0 + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + output = array_ops.matrix_set_diag( + input_mat, vecs[0], k=diags, align=align) + self.assertEqual(output.get_shape(), solution.shape) + self.assertAllEqual(output.eval(), solution) def _testSquareBatch(self, dtype): with self.cached_session(use_gpu=True): @@ -584,21 +663,18 @@ class MatrixSetDiagTest(test.TestCase): self.assertEqual((2, 3, 3), output.get_shape()) self.assertAllEqual(mat_set_diag_batch, self.evaluate(output)) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands. 
- _, tests = square_cases() - for diags, pair in tests.items(): - vecs, banded_mat = pair - mask = banded_mat == 0 - input_mat = np.random.randint(10, size=mask.shape).astype(dtype) - solution = (input_mat * mask + banded_mat).astype(dtype) - output = array_ops.matrix_set_diag( - input_mat, vecs.astype(dtype), k=diags) - self.assertEqual(output.get_shape(), solution.shape) - self.assertAllEqual(output.eval(), solution) + for align in alignment_list: + _, tests = square_cases(align) + for diags, (vecs, banded_mat) in tests.items(): + mask = banded_mat == 0 + input_mat = np.random.randint(10, size=mask.shape).astype(dtype) + solution = (input_mat * mask + banded_mat).astype(dtype) + output = array_ops.matrix_set_diag( + input_mat, vecs.astype(dtype), k=diags, align=align) + self.assertEqual(output.get_shape(), solution.shape) + self.assertAllEqual(output.eval(), solution) @test_util.run_deprecated_v1 def testSquareBatch(self): @@ -621,20 +697,19 @@ class MatrixSetDiagTest(test.TestCase): self.assertEqual((2, 2, 3), output.get_shape()) self.assertAllEqual(mat_set_diag_batch, self.evaluate(output)) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands. 
- for _, tests in [tall_cases(), fat_cases()]: - for diags, pair in tests.items(): - vecs, banded_mat = pair - mask = banded_mat == 0 - input_mat = np.random.randint(10, size=mask.shape) - solution = input_mat * mask + banded_mat - output = array_ops.matrix_set_diag(input_mat, vecs, k=diags) - self.assertEqual(output.get_shape(), solution.shape) - self.assertAllEqual(output.eval(), solution) + for align in alignment_list: + for _, tests in [tall_cases(align), fat_cases(align)]: + for diags, pair in tests.items(): + vecs, banded_mat = pair + mask = banded_mat == 0 + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat + output = array_ops.matrix_set_diag( + input_mat, vecs, k=diags, align=align) + self.assertEqual(output.get_shape(), solution.shape) + self.assertAllEqual(output.eval(), solution) @test_util.run_deprecated_v1 def testInvalidShape(self): @@ -643,10 +718,7 @@ class MatrixSetDiagTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "must be at least rank 1"): array_ops.matrix_set_diag([[0]], 0) - # TODO(penporn): Un-skip the XLA test when XLA has MatrixSetDiagV2. 
@test_util.run_deprecated_v1 - @test_util.disable_xla("XLA op hasn't supported new features in V2, which" - "change the shape requirements.") def testInvalidShapeAtEval(self): with self.session(use_gpu=True): v = array_ops.placeholder(dtype=dtypes_lib.float32) @@ -655,10 +727,7 @@ class MatrixSetDiagTest(test.TestCase): with self.assertRaisesOpError("diagonal must be at least 1-dim"): array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0}) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): d = array_ops.placeholder(dtype=dtypes_lib.float32) with self.assertRaisesOpError( "first dimensions of diagonal don't match"): @@ -667,17 +736,15 @@ class MatrixSetDiagTest(test.TestCase): d: np.ones((2, 4)) }) - def _testGrad(self, input_shape, diag_shape, diags): + def _testGrad(self, input_shape, diag_shape, diags, align): with self.session(use_gpu=True): x = constant_op.constant( np.random.rand(*input_shape), dtype=dtypes_lib.float32) x_diag = constant_op.constant( np.random.rand(*diag_shape), dtype=dtypes_lib.float32) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - y = array_ops.matrix_set_diag(x, x_diag, k=diags) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align) else: y = array_ops.matrix_set_diag(x, x_diag) error_x = gradient_checker.compute_gradient_error(x, @@ -696,16 +763,15 @@ class MatrixSetDiagTest(test.TestCase): input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)] diag_bands = [(0, 0)] - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): diag_bands.append((-1, 1)) - for input_shape, diags in 
itertools.product(input_shapes, diag_bands): + for input_shape, diags, align in itertools.product(input_shapes, diag_bands, + alignment_list): lower_diag_index, upper_diag_index = diags num_diags = upper_diag_index - lower_diag_index + 1 num_diags_dim = () if num_diags == 1 else (num_diags,) diag_shape = input_shape[:-2] + num_diags_dim + (min(input_shape[-2:]),) - self._testGrad(input_shape, diag_shape, diags) + self._testGrad(input_shape, diag_shape, diags, align) @test_util.run_deprecated_v1 def testGradWithNoShapeInformation(self): @@ -739,9 +805,7 @@ class MatrixDiagPartTest(test.TestCase): self.assertEqual((3,), mat_diag.get_shape()) self.assertAllEqual(mat_diag.eval(), v) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): for offset in [-2, 3]: mat = np.diag(v, offset) mat_diag = array_ops.matrix_diag_part(mat, k=offset) @@ -749,12 +813,13 @@ class MatrixDiagPartTest(test.TestCase): self.assertAllEqual(mat_diag.eval(), v) # Diagonal bands. 
- mat, tests = square_cases() - for diags, pair in tests.items(): - solution, _ = pair - mat_diag = array_ops.matrix_diag_part(mat[0], k=diags) - self.assertEqual(mat_diag.get_shape(), solution[0].shape) - self.assertAllEqual(mat_diag.eval(), solution[0]) + for align in alignment_list: + mat, tests = square_cases(align) + for diags, pair in tests.items(): + solution, _ = pair + mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align) + self.assertEqual(mat_diag.get_shape(), solution[0].shape) + self.assertAllEqual(mat_diag.eval(), solution[0]) @test_util.run_deprecated_v1 def testRectangular(self): @@ -766,17 +831,16 @@ class MatrixDiagPartTest(test.TestCase): mat_diag = array_ops.matrix_diag_part(mat) self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0])) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands. - for mat, tests in [tall_cases(), fat_cases()]: - for diags, pair in tests.items(): - solution, _ = pair - mat_diag = array_ops.matrix_diag_part(mat[0], k=diags) - self.assertEqual(mat_diag.get_shape(), solution[0].shape) - self.assertAllEqual(mat_diag.eval(), solution[0]) + for align in alignment_list: + for mat, tests in [tall_cases(align), fat_cases(align)]: + for diags, pair in tests.items(): + solution, _ = pair + mat_diag = array_ops.matrix_diag_part( + mat[0], k=diags, align=align) + self.assertEqual(mat_diag.get_shape(), solution[0].shape) + self.assertAllEqual(mat_diag.eval(), solution[0]) def _testSquareBatch(self, dtype): with self.cached_session(use_gpu=True): @@ -789,17 +853,18 @@ class MatrixDiagPartTest(test.TestCase): self.assertEqual((2, 3), mat_batch_diag.get_shape()) self.assertAllEqual(mat_batch_diag.eval(), v_batch) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - + if 
compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Diagonal bands with padding_value. - mat, tests = square_cases() - for padding_value in [0, 555, -11]: + for padding_value, align in zip_to_first_list_length([0, 555, -11], + alignment_list): + mat, tests = square_cases(align) for diags, pair in tests.items(): solution, _ = pair mat_batch_diag = array_ops.matrix_diag_part( - mat.astype(dtype), k=diags, padding_value=padding_value) + mat.astype(dtype), + k=diags, + padding_value=padding_value, + align=align) mask = solution == 0 solution = (solution + padding_value * mask).astype(dtype) self.assertEqual(mat_batch_diag.get_shape(), solution.shape) @@ -824,17 +889,15 @@ class MatrixDiagPartTest(test.TestCase): self.assertEqual((2, 2), mat_batch_diag.get_shape()) self.assertAllEqual(mat_batch_diag.eval(), v_batch) - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) - - # Diagonal bands with padding_value. - for padding_value in [0, 555, -11]: - for mat, tests in [tall_cases(), fat_cases()]: + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + # Diagonal bands with padding_value and align. + for padding_value, align in zip_to_first_list_length([0, 555, -11], + alignment_list): + for mat, tests in [tall_cases(align), fat_cases(align)]: for diags, pair in tests.items(): solution, _ = pair mat_batch_diag = array_ops.matrix_diag_part( - mat, k=diags, padding_value=padding_value) + mat, k=diags, padding_value=padding_value, align=align) mask = solution == 0 solution = solution + padding_value * mask self.assertEqual(mat_batch_diag.get_shape(), solution.shape) @@ -866,6 +929,22 @@ class MatrixDiagPartTest(test.TestCase): y.get_shape().as_list()) self.assertLess(error, 1e-4) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + # {Sub,super}diagonals/band. 
+ tests = dict() # tests[shape] = (d_lower, d_upper) + tests[(3, 3)] = (-1, -1) + tests[(7, 3, 4)] = (-1, 1) + with self.session(use_gpu=True): + for align in alignment_list: + for shape, diags in tests.items(): + x = constant_op.constant(np.random.rand(*shape), np.float32) + y = array_ops.matrix_diag_part(input=x, k=diags, align=align) + error = gradient_checker.compute_gradient_error( + x, + x.get_shape().as_list(), y, + y.get_shape().as_list()) + self.assertLess(error, 1e-4) + class DiagTest(test.TestCase): diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 2c1bd445e54..6dd853846db 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -341,6 +341,12 @@ def _MatrixDiagV2Grad(op, grad): grad, k=op.inputs[1]), None, None, None, None +@ops.RegisterGradient("MatrixDiagV3") +def _MatrixDiagV3Grad(op, grad): + return array_ops.matrix_diag_part( + grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None + + @ops.RegisterGradient("MatrixDiagPart") def _MatrixDiagPartGrad(op, grad): matrix_shape = op.inputs[0].get_shape()[-2:] @@ -362,8 +368,25 @@ def _MatrixDiagPartV2Grad(op, grad): num_cols=matrix_shape[1]), None, None else: return array_ops.matrix_set_diag( - array_ops.zeros_like(op.inputs[0]), grad, - k=op.inputs[1]), None, None + array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None + + +@ops.RegisterGradient("MatrixDiagPartV3") +def _MatrixDiagPartV3Grad(op, grad): + """Gradient for MatrixDiagPartV3.""" + matrix_shape = op.inputs[0].get_shape()[-2:] + align = op.get_attr("align") + if matrix_shape.is_fully_defined(): + return array_ops.matrix_diag( + grad, + k=op.inputs[1], + num_rows=matrix_shape[0], + num_cols=matrix_shape[1], + align=align), None, None + else: + return array_ops.matrix_set_diag( + array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1], + align=align), None, None @ops.RegisterGradient("MatrixSetDiag") @@ -392,7 +415,7 @@ def 
_MatrixSetDiagGrad(op, grad): @ops.RegisterGradient("MatrixSetDiagV2") def _MatrixSetDiagGradV2(op, grad): - """Gradient for MatrixSetDiag.""" + """Gradient for MatrixSetDiagV2.""" diag_shape = op.inputs[1].get_shape() if not diag_shape.is_fully_defined(): # Need to know the values of `d_lower` and `d_upper` to infer diag_shape. @@ -426,6 +449,46 @@ def _MatrixSetDiagGradV2(op, grad): return (grad_input, grad_diag, None) +@ops.RegisterGradient("MatrixSetDiagV3") +def _MatrixSetDiagGradV3(op, grad): + """Gradient for MatrixSetDiagV3.""" + diag_shape = op.inputs[1].get_shape() + align = op.get_attr("align") + if not diag_shape.is_fully_defined(): + # Need to know the values of `d_lower` and `d_upper` to infer diag_shape. + grad_shape = array_ops.shape(grad) + batch_shape = grad_shape[:-2] + matrix_shape = grad_shape[-2:] + diag_index = array_ops.reshape(op.inputs[2], [-1]) # Converts to vector. + d_lower = diag_index[0] + d_upper = diag_index[-1] # Works both when len(diag_index) is 1 and 2. 
+ y_offset = control_flow_ops.cond( + math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0) + x_offset = control_flow_ops.cond( + math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0) + + max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset, + matrix_shape[1] + x_offset) + # pylint: disable=g-long-lambda + # pyformat: disable + postfix = control_flow_ops.cond( + math_ops.equal(d_lower, d_upper), + lambda: ops.convert_to_tensor([max_diag_len]), + lambda: ops.convert_to_tensor([d_upper - d_lower + 1, + max_diag_len])) + # pyformat: enable + # pylint: enable=g-long-lambda + diag_shape = array_ops.concat([batch_shape, postfix], 0) + + grad_input = array_ops.matrix_set_diag( + grad, + array_ops.zeros(diag_shape, dtype=grad.dtype), + k=op.inputs[2], + align=align) + grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align) + return (grad_input, grad_diag, None) + + @ops.RegisterGradient("MatrixBandPart") def _MatrixBandPartGrad(op, grad): num_lower = op.inputs[1] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c6234754d20..488a9bafca9 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -54,6 +54,13 @@ tf_export("newaxis").export_constant(__name__, "newaxis") # existing 'slice' for later use in this module. _BaseSlice = slice +# LINT.IfChange +matrix_diag_v3_forward_compat_date = (2019, 12, 6) +# LINT.ThenChange( +# //tensorflow/compiler/tests/matrix_diag_ops_test.py, +# //tensorflow/python/kernel_tests/diag_op_test.py, +# //tensorflow/python/ops/parallel_for/array_test.py +# ) @tf_export("reshape", v1=["reshape", "manip.reshape"]) def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name @@ -2043,7 +2050,8 @@ def matrix_diag(diagonal, k=0, num_rows=-1, num_cols=-1, - padding_value=0): + padding_value=0, + align="RIGHT_LEFT"): """Returns a batched diagonal tensor with given batched diagonal values. 
Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th @@ -2078,7 +2086,17 @@ def matrix_diag(diagonal, padding_value ; otherwise ``` where `d = n - m`, `diag_index = k[1] - d`, and - `index_in_diag = n - max(d, 0)`. + `index_in_diag = n - max(d, 0) + offset`. + + `offset` is zero except when the alignment of the diagonal is to the right. + ``` + offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise + ``` + where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. For example: @@ -2108,17 +2126,34 @@ def matrix_diag(diagonal, [0, 0, 0, 6], [0, 0, 0, 0]]] - # A band of diagonals. - diagonals = np.array([[[1, 2, 3], # Input shape: (2, 2, 3) - [4, 5, 0]], - [[6, 7, 9], - [9, 1, 0]]]) - tf.matrix_diag(diagonals, k = (-1, 0)) - ==> [[[1, 0, 0], # Output shape: (2, 3, 3) - [4, 2, 0], + # A tridiagonal band (per batch). + diagonals = np.array([[[8, 9, 0], # Input shape: (2, 2, 3) + [1, 2, 3], + [0, 4, 5]], + [[2, 3, 0], + [6, 7, 9], + [0, 9, 1]]]) + tf.matrix_diag(diagonals, k = (-1, 1)) + ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + [4, 2, 9], [0, 5, 3]], - [[6, 0, 0], - [9, 7, 0], + [[6, 2, 0], + [9, 7, 3], + [0, 1, 9]]] + + # RIGHT_LEFT alignment. + diagonals = np.array([[[0, 8, 9], # Input shape: (2, 2, 3) + [1, 2, 3], + [4, 5, 0]], + [[0, 2, 3], + [6, 7, 9], + [9, 1, 0]]]) + tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT") + ==> [[[1, 8, 0], # Output shape: (2, 3, 3) + [4, 2, 9], + [0, 5, 3]], + [[6, 2, 0], + [9, 7, 3], [0, 1, 9]]] # Rectangular matrix. @@ -2150,27 +2185,34 @@ def matrix_diag(diagonal, size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`. padding_value: The value to fill the area outside the specified diagonal band with. Default is 0. + align: Some diagonals are shorter than `max_diag_len` and need to be padded. 
+ `align` is a string specifying how superdiagonals and subdiagonals should + be aligned, respectively. There are four possible alignments: "RIGHT_LEFT" + (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" + aligns superdiagonals to the right (left-pads the row) and subdiagonals to + the left (right-pads the row). It is the packing format LAPACK uses. + cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment. Returns: A Tensor. Has the same type as `diagonal`. """ - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Special case to sidestep the tf.constant conversion error: # TypeError: Expected bool, got 0 of type 'int' instead. if hasattr(diagonal, "dtype") and diagonal.dtype == "bool": padding_value = bool(padding_value) - return gen_array_ops.matrix_diag_v2( + + return gen_array_ops.matrix_diag_v3( diagonal=diagonal, k=k, num_rows=num_rows, num_cols=num_cols, padding_value=padding_value, + align=align, name=name) # Call v1 to maintain forward compatibility. + # (We skip v2 because its alignment conflicts with v3's default alignment.) return gen_array_ops.matrix_diag(diagonal=diagonal, name=name) @@ -2181,7 +2223,8 @@ def matrix_diag_part( input, # pylint:disable=redefined-builtin name="diag_part", k=0, - padding_value=0): + padding_value=0, + align="RIGHT_LEFT"): """Returns the batched diagonal part of a batched tensor. Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched @@ -2211,7 +2254,17 @@ def matrix_diag_part( = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, padding_value ; otherwise. ``` - where `d = k[1] - m`, `y = max(-d, 0)`, and `x = max(d, 0)`. + where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`. + + `offset` is zero except when the alignment of the diagonal is to the right. 
+ ``` + offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise + ``` + where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. The input must be at least a matrix. @@ -2234,16 +2287,36 @@ def matrix_diag_part( ==> [[2, 7, 6], # Output shape: (2, 3) [4, 3, 8]] - # A tridiagonal band from each batch. - tf.matrix_diag_part(input, k = (-1, 1)) - ==> [[[2, 7, 6], # Output shape: (2, 3, 3) + # A band from each batch. + tf.matrix_diag_part(input, k = (-1, 2)) + ==> [[[3, 8, 0], # Output shape: (2, 4, 3) + [2, 7, 6], + [1, 6, 7], + [0, 5, 8]], + [[3, 4, 0], + [4, 3, 8], + [5, 2, 7], + [0, 1, 6]]] + + # RIGHT_LEFT alignment. + tf.matrix_diag_part(input, k = (-1, 2), align="RIGHT_LEFT") + ==> [[[0, 3, 8], # Output shape: (2, 4, 3) + [2, 7, 6], [1, 6, 7], [5, 8, 0]], - [[4, 3, 8], + [[0, 3, 4], + [4, 3, 8], [5, 2, 7], [1, 6, 0]]] - # Padding value = 9 + # max_diag_len can be shorter than the main diagonal. + tf.matrix_diag_part(input, k = (-2, -1)) + ==> [[[5, 8], + [0, 9]], + [[1, 6], + [0, 5]]] + + # padding_value = 9 tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) ==> [[[4, 9, 9], # Output shape: (2, 3, 3) [3, 8, 9], @@ -2251,6 +2324,7 @@ def matrix_diag_part( [[2, 9, 9], [3, 4, 9], [4, 3, 8]]] + ``` Args: @@ -2262,23 +2336,28 @@ def matrix_diag_part( and high ends of a matrix band. `k[0]` must not be larger than `k[1]`. padding_value: The value to fill the area outside the specified diagonal band with. Default is 0. + align: Some diagonals are shorter than `max_diag_len` and need to be padded. + `align` is a string specifying how superdiagonals and subdiagonals should + be aligned, respectively. There are four possible alignments: "RIGHT_LEFT" + (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" + aligns superdiagonals to the right (left-pads the row) and subdiagonals to + the left (right-pads the row). 
It is the packing format LAPACK uses. + cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment. Returns: A Tensor containing diagonals of `input`. Has the same type as `input`. """ - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py) - + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): # Special case to sidestep the tf.constant conversion error: # TypeError: Expected bool, got 0 of type 'int' instead. if hasattr(input, "dtype") and input.dtype == "bool": padding_value = bool(padding_value) - return gen_array_ops.matrix_diag_part_v2( - input=input, k=k, padding_value=padding_value, name=name) + return gen_array_ops.matrix_diag_part_v3( + input=input, k=k, padding_value=padding_value, align=align, name=name) # Call v1 to maintain forward compatibility. + # (We skip v2 because its alignment conflicts with v3's default alignment.) return gen_array_ops.matrix_diag_part(input=input, name=name) @@ -2288,7 +2367,8 @@ def matrix_set_diag( input, # pylint:disable=redefined-builtin diagonal, name="set_diag", - k=0): + k=0, + align="RIGHT_LEFT"): """Returns a batched matrix tensor with new batched diagonal values. Given `input` and `diagonal`, this operation returns a tensor with the @@ -2319,7 +2399,17 @@ def matrix_set_diag( input[i, j, ..., l, m, n] ; otherwise ``` where `d = n - m`, `diag_index = k[1] - d`, and - `index_in_diag = n - max(d, 0)`. + `index_in_diag = n - max(d, 0) + offset`. + + `offset` is zero except when the alignment of the diagonal is to the right. + ``` + offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise + ``` + where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. 
For example: @@ -2333,15 +2423,16 @@ def matrix_set_diag( [7, 7, 7, 7]]]) diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3) [4, 5, 6]]) - tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) - [7, 2, 7, 7], - [7, 7, 3, 7]], - [[4, 7, 7, 7], - [7, 5, 7, 7], - [7, 7, 6, 7]]] + tf.matrix_set_diag(input, diagonal) + ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) + [7, 2, 7, 7], + [7, 7, 3, 7]], + [[4, 7, 7, 7], + [7, 5, 7, 7], + [7, 7, 6, 7]]] # A superdiagonal (per batch). - tf.matrix_set_diag(diagonal, k = 1) + tf.matrix_set_diag(input, diagonal, k = 1) ==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4) [7, 7, 2, 7], [7, 7, 7, 3]], @@ -2350,17 +2441,38 @@ def matrix_set_diag( [7, 7, 7, 6]]] # A band of diagonals. - diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3) + diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3) + [6, 5, 8], + [1, 2, 3], + [0, 4, 5]], + [[1, 2, 0], + [5, 6, 4], + [6, 1, 2], + [0, 3, 4]]]) + tf.matrix_set_diag(input, diagonals, k = (-1, 2)) + ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + [4, 2, 5, 1], + [7, 5, 3, 8]], + [[6, 5, 1, 7], + [3, 1, 6, 2], + [7, 4, 2, 4]]] + + # RIGHT_LEFT alignment. + diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3) + [6, 5, 8], + [1, 2, 3], [4, 5, 0]], - [[6, 1, 2], + [[0, 1, 2], + [5, 6, 4], + [6, 1, 2], [3, 4, 0]]]) - tf.matrix_set_diag(diagonals, k = (-1, 0)) - ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4) - [4, 2, 7, 7], - [0, 5, 3, 7]], - [[6, 7, 7, 7], - [3, 1, 7, 7], - [7, 4, 2, 7]]] + tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT") + ==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4) + [4, 2, 5, 1], + [7, 5, 3, 8]], + [[6, 5, 1, 7], + [3, 1, 6, 2], + [7, 4, 2, 4]]] ``` @@ -2373,14 +2485,20 @@ def matrix_set_diag( main diagonal, and negative value means subdiagonals. `k` can be a single integer (for a single diagonal) or a pair of integers specifying the low and high ends of a matrix band. `k[0]` must not be larger than `k[1]`. 
+ align: Some diagonals are shorter than `max_diag_len` and need to be padded. + `align` is a string specifying how superdiagonals and subdiagonals should + be aligned, respectively. There are four possible alignments: "RIGHT_LEFT" + (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" + aligns superdiagonals to the right (left-pads the row) and subdiagonals to + the left (right-pads the row). It is the packing format LAPACK uses. + cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment. """ - # LINT.IfChange - if compat.forward_compatible(2019, 11, 30): - # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py) - return gen_array_ops.matrix_set_diag_v2( - input=input, diagonal=diagonal, k=k, name=name) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + return gen_array_ops.matrix_set_diag_v3( + input=input, diagonal=diagonal, k=k, align=align, name=name) # Call v1 to maintain forward compatibility. + # (We skip v2 because its alignment conflicts with v3's default alignment.) return gen_array_ops.matrix_set_diag( input=input, diagonal=diagonal, name=name) diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 205c16e5197..2efa16ce62a 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -1111,10 +1111,17 @@ def _cholesky(input, name=None): # pylint:disable=redefined-builtin # The signature has to match with the one in python/op/array_ops.py, -# so we have k and padding_value even though we don't use them here. +# so we have k, padding_value, and align even though we don't use them here. 
+# pylint:disable=unused-argument @dispatch.dispatch_for_types(linalg.diag_part, LinearOperator) -def _diag_part(input, name="diag_part", k=0, padding_value=0): # pylint:disable=redefined-builtin, unused-argument +def _diag_part( + input, # pylint:disable=redefined-builtin + name="diag_part", + k=0, + padding_value=0, + align="RIGHT_LEFT"): return input.diag_part(name) +# pylint:enable=unused-argument @dispatch.dispatch_for_types(linalg.det, LinearOperator) diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py index 874e59926af..7986b74ed3b 100644 --- a/tensorflow/python/ops/parallel_for/array_test.py +++ b/tensorflow/python/ops/parallel_for/array_test.py @@ -33,6 +33,15 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase from tensorflow.python.platform import test +# LINT.IfChange +matrix_diag_v3_forward_compat_date = (2019, 12, 6) +# LINT.ThenChange( +# //tensorflow/compiler/tests/matrix_diag_ops_test.py, +# //tensorflow/python/kernel_tests/diag_op_test.py, +# //tensorflow/python/ops/array_ops.py +# ) + + @test_util.run_all_in_graph_and_eager_modes class ArrayTest(PForTestCase): @@ -336,8 +345,9 @@ class ArrayTest(PForTestCase): def loop_fn(i): diagonal = array_ops.gather(x, i) - if compat.forward_compatible(2019, 11, 30): - return array_ops.matrix_diag(diagonal, k=(0, 1), num_rows=4, num_cols=5) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + return array_ops.matrix_diag( + diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT") return array_ops.matrix_diag(diagonal) self._test_loop_fn(loop_fn, 3) @@ -347,8 +357,9 @@ class ArrayTest(PForTestCase): def loop_fn(i): input = array_ops.gather(x, i) # pylint: disable=redefined-builtin - if compat.forward_compatible(2019, 11, 30): - return array_ops.matrix_diag_part(input, k=(-2, 0), padding_value=3) + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): + return array_ops.matrix_diag_part( + 
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT") return array_ops.matrix_diag_part(input) self._test_loop_fn(loop_fn, 3) @@ -356,8 +367,7 @@ class ArrayTest(PForTestCase): def test_matrix_set_diag(self): matrices = random_ops.random_uniform([3, 4, 4]) diags = random_ops.random_uniform([3, 4]) - if compat.forward_compatible(2019, 11, 30): - bands = random_ops.random_uniform([3, 3, 4]) + bands = random_ops.random_uniform([3, 3, 4]) def loop_fn(i): matrix_i = array_ops.gather(matrices, i) @@ -365,16 +375,20 @@ class ArrayTest(PForTestCase): results = [ array_ops.matrix_set_diag(matrix_i, diag_i), array_ops.matrix_set_diag(matrices[0, ...], diag_i), - array_ops.matrix_set_diag(matrix_i, diags[0, ...]) + array_ops.matrix_set_diag(matrix_i, diags[0, ...]), ] - if compat.forward_compatible(2019, 11, 30): + + if compat.forward_compatible(*matrix_diag_v3_forward_compat_date): k = (-1, 1) band_i = array_ops.gather(bands, i) - results.extend([ - array_ops.matrix_set_diag(matrix_i, band_i, k=k), - array_ops.matrix_set_diag(matrices[0, ...], band_i, k=k), - array_ops.matrix_set_diag(matrix_i, bands[0, ...], k=k) - ]) + for align in ["RIGHT_LEFT", "LEFT_RIGHT"]: + results.extend([ + array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align), + array_ops.matrix_set_diag( + matrices[0, ...], band_i, k=k, align=align), + array_ops.matrix_set_diag( + matrix_i, bands[0, ...], k=k, align=align) + ]) return results self._test_loop_fn(loop_fn, 3) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index a59f6af4d1f..0e861ffe2ab 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1902,43 +1902,55 @@ def _convert_matrix_set_diag(pfor_input): return wrap(array_ops.matrix_set_diag(t, diag), True) -# Registrations for MatrixDiagV2, MatrixDiagPartv2, and MatrixSetDiagV2. +# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 
# The input orders defined in the OpKernel and the actual python API are # different (for compatibility with V1), so we cannot use _convert_identity. +# v2 is not compatible with v3 and is never exposed on the public API. @RegisterPFor("MatrixDiagV2") +@RegisterPFor("MatrixDiagV3") def _convert_matrix_diag_v2(pfor_input): - diagonal = pfor_input.stacked_input(0) - k = pfor_input.unstacked_input(1) - num_rows = pfor_input.unstacked_input(2) - num_cols = pfor_input.unstacked_input(3) - padding_value = pfor_input.unstacked_input(4) - return wrap( - array_ops.matrix_diag( - diagonal, - k=k, - num_rows=num_rows, - num_cols=num_cols, - padding_value=padding_value), True) + params = { + "diagonal": pfor_input.stacked_input(0), + "k": pfor_input.unstacked_input(1), + "num_rows": pfor_input.unstacked_input(2), + "num_cols": pfor_input.unstacked_input(3), + "padding_value": pfor_input.unstacked_input(4) + } + if pfor_input.op_type == "MatrixDiagV2": + return wrap(array_ops.matrix_diag_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_diag(**params), True) # See notes for MatrixDiagV2 @RegisterPFor("MatrixDiagPartV2") +@RegisterPFor("MatrixDiagPartV3") def _convert_matrix_diag_part_v2(pfor_input): - input = pfor_input.stacked_input(0) # pylint:disable=redefined-builtin - k = pfor_input.unstacked_input(1) - padding_value = pfor_input.unstacked_input(2) - return wrap( - array_ops.matrix_diag_part(input, k=k, padding_value=padding_value), True) + params = { + "input": pfor_input.stacked_input(0), + "k": pfor_input.unstacked_input(1), + "padding_value": pfor_input.unstacked_input(2) + } + if pfor_input.op_type == "MatrixDiagPartV2": + return wrap(array_ops.matrix_diag_part_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_diag_part(**params), True) # See notes for MatrixDiagV2 @RegisterPFor("MatrixSetDiagV2") +@RegisterPFor("MatrixSetDiagV3") def 
_convert_matrix_set_diag_v2(pfor_input): pfor_input.stack_inputs([0, 1]) - input = pfor_input.stacked_input(0) # pylint:disable=redefined-builtin - diagonal = pfor_input.stacked_input(1) - k = pfor_input.unstacked_input(2) - return wrap(array_ops.matrix_set_diag(input, diagonal, k=k), True) + params = { + "input": pfor_input.stacked_input(0), + "diagonal": pfor_input.stacked_input(1), + "k": pfor_input.unstacked_input(2) + } + if pfor_input.op_type == "MatrixSetDiagV2": + return wrap(array_ops.matrix_set_diag_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_set_diag(**params), True) @RegisterPFor("OneHot") diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index 283fd9c35d6..f645db2f310 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -102,11 +102,11 @@ tf_module { } member_method { name: "diag" - argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], " + argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "diag_part" - argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], " + argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "eigh" @@ -202,7 +202,7 @@ tf_module { } member_method { name: "set_diag" - argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], " + argspec: "args=[\'input\', 
\'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "slogdet" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 6c75ecb5fbf..2883f6d8166 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1646,11 +1646,11 @@ tf_module { } member_method { name: "matrix_diag" - argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], " + argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "matrix_diag_part" - argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], " + argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "matrix_inverse" @@ -1658,7 +1658,7 @@ tf_module { } member_method { name: "matrix_set_diag" - argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], " + argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "matrix_solve" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 57065a6cfc2..edeb3ba0d56 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2172,10 +2172,18 @@ tf_module { name: "MatrixDiagPartV2" 
argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixDiagPartV3" + argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixDiagV2" argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixDiagV3" + argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixExponential" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -2196,6 +2204,10 @@ tf_module { name: "MatrixSetDiagV2" argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixSetDiagV3" + argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixSolve" argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt index a25583d7fdd..a58c988577a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -102,11 +102,11 @@ tf_module { } member_method { name: "diag" - argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], " + argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', 
\'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "diag_part" - argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], " + argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "eig" @@ -210,7 +210,7 @@ tf_module { } member_method { name: "set_diag" - argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], " + argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], " } member_method { name: "slogdet" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 57065a6cfc2..edeb3ba0d56 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2172,10 +2172,18 @@ tf_module { name: "MatrixDiagPartV2" argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixDiagPartV3" + argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixDiagV2" argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixDiagV3" + argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixExponential" argspec: "args=[\'input\', 
\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -2196,6 +2204,10 @@ tf_module { name: "MatrixSetDiagV2" argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "MatrixSetDiagV3" + argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], " + } member_method { name: "MatrixSolve" argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "