From 5334adcddb1009ae68316c661f3a40b8c8ff9f5e Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Mon, 22 Jul 2019 12:19:25 -0700 Subject: [PATCH] Add XLA implementations for MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2. PiperOrigin-RevId: 259379918 --- tensorflow/compiler/tests/BUILD | 13 + tensorflow/compiler/tests/binary_ops_test.py | 48 -- .../compiler/tests/matrix_diag_ops_test.py | 655 ++++++++++++++++++ tensorflow/compiler/tests/unary_ops_test.py | 26 - tensorflow/compiler/tf2xla/kernels/BUILD | 2 +- tensorflow/compiler/tf2xla/kernels/diag_op.cc | 49 +- .../tf2xla/kernels/matrix_diag_ops.cc | 425 ++++++++++++ .../tf2xla/kernels/matrix_set_diag_op.cc | 98 --- 8 files changed, 1096 insertions(+), 220 deletions(-) create mode 100644 tensorflow/compiler/tests/matrix_diag_ops_test.py create mode 100644 tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc delete mode 100644 tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 15bb0a863d1..d39d15986be 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -665,6 +665,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_diag_ops_test", + size = "medium", + timeout = "long", + srcs = ["matrix_diag_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "momentum_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0171be42148..14af571d62f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops -from tensorflow.python.ops 
import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -1464,53 +1463,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) - def testMatrixSetDiag(self): - # TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to - # gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2). - for dtype in self.numeric_types: - # Square - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], - dtype=dtype), - np.array([1.0, 2.0, 3.0], dtype=dtype), - expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], - dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], - [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], - dtype=dtype), - np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), - expected=np.array( - [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], - [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], - dtype=dtype)) - - # Rectangular - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), - np.array([3.0, 4.0], dtype=dtype), - expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), - np.array([3.0, 4.0], dtype=dtype), - expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], - [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), - np.array([[-1.0, -2.0], [-4.0, -5.0]], - dtype=dtype), - expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], - [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], - dtype=dtype)) - def testBroadcastTo(self): 
for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py new file mode 100644 index 00000000000..a994be8b29d --- /dev/null +++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py @@ -0,0 +1,655 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA matrix diag ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2. 
+# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py +def square_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [3, 4, 5, 6, 7], + [8, 9, 1, 2, 3], + [4, 5, 6, 7, 8]], + [[9, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [2, 3, 4, 5, 6]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[-1, -1] = (np.array([[6, 4, 1, 7], + [5, 2, 8, 5]]), + np.array([[[0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 7, 0]], + [[0, 0, 0, 0, 0], + [5, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + [0, 0, 8, 0, 0], + [0, 0, 0, 5, 0]]])) + tests[-4, -3] = (np.array([[[8, 5], + [4, 0]], + [[6, 3], + [2, 0]]]), + np.array([[[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [8, 0, 0, 0, 0], + [4, 5, 0, 0, 0]], + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [2, 3, 0, 0, 0]]])) + tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0], + [1, 7, 5, 2, 8], + [6, 4, 1, 7, 0], + [3, 9, 6, 0, 0]], + [[1, 7, 4, 1, 0], + [9, 6, 3, 9, 6], + [5, 2, 8, 5, 0], + [1, 7, 4, 0, 0]]]), + np.array([[[1, 2, 0, 0, 0], + [6, 7, 8, 0, 0], + [3, 4, 5, 6, 0], + [0, 9, 1, 2, 3], + [0, 0, 6, 7, 8]], + [[9, 1, 0, 0, 0], + [5, 6, 7, 0, 0], + [1, 2, 3, 4, 0], + [0, 7, 8, 9, 1], + [0, 0, 4, 5, 6]]])) + tests[2, 4] = (np.array([[[5, 0, 0], + [4, 1, 0], + [3, 9, 7]], + [[4, 0, 0], + [3, 9, 0], + [2, 8, 5]]]), + np.array([[[0, 0, 3, 4, 5], + [0, 0, 0, 9, 1], + [0, 0, 0, 0, 7], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 2, 3, 4], + [0, 0, 0, 8, 9], + [0, 0, 0, 0, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]])) + # pyformat: enable + return (mat, tests) + + +def tall_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7], + [6, 5, 4]], + [[3, 2, 1], + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[0, 
0] = (np.array([[1, 5, 9], + [3, 2, 6]]), + np.array([[[1, 0, 0], + [0, 5, 0], + [0, 0, 9], + [0, 0, 0]], + [[3, 0, 0], + [0, 2, 0], + [0, 0, 6], + [0, 0, 0]]])) + tests[-4, -3] = (np.array([[[9, 5], + [6, 0]], + [[7, 8], + [9, 0]]]), + np.array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [9, 0, 0], + [6, 5, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [7, 0, 0], + [9, 8, 0]]])) + tests[-2, -1] = (np.array([[[4, 8, 7], + [7, 8, 4]], + [[1, 5, 9], + [4, 8, 7]]]), + np.array([[[0, 0, 0], + [4, 0, 0], + [7, 8, 0], + [0, 8, 7], + [0, 0, 4]], + [[0, 0, 0], + [1, 0, 0], + [4, 5, 0], + [0, 8, 9], + [0, 0, 7]]])) + tests[-2, 1] = (np.array([[[2, 6, 0], + [1, 5, 9], + [4, 8, 7], + [7, 8, 4]], + [[2, 3, 0], + [3, 2, 6], + [1, 5, 9], + [4, 8, 7]]]), + np.array([[[1, 2, 0], + [4, 5, 6], + [7, 8, 9], + [0, 8, 7], + [0, 0, 4]], + [[3, 2, 0], + [1, 2, 3], + [4, 5, 6], + [0, 8, 9], + [0, 0, 7]]])) + tests[1, 2] = (np.array([[[3, 0], + [2, 6]], + [[1, 0], + [2, 3]]]), + np.array([[[0, 2, 3], + [0, 0, 6], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 2, 1], + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]])) + # pyformat: enable + return (mat, tests) + + +def fat_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 1, 2, 3]], + [[4, 5, 6, 7], + [8, 9, 1, 2], + [3, 4, 5, 6]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[0, 0] = (np.array([[1, 6, 2], + [4, 9, 5]]), + np.array([[[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0, 2, 0]], + [[4, 0, 0, 0], + [0, 9, 0, 0], + [0, 0, 5, 0]]])) + tests[2, 2] = (np.array([[3, 8], + [6, 2]]), + np.array([[[0, 0, 3, 0], + [0, 0, 0, 8], + [0, 0, 0, 0]], + [[0, 0, 6, 0], + [0, 0, 0, 2], + [0, 0, 0, 0]]])) + tests[-2, 0] = (np.array([[[1, 6, 2], + [5, 1, 0], + [9, 0, 0]], + [[4, 9, 5], + [8, 4, 0], + [3, 0, 0]]]), + np.array([[[1, 0, 0, 0], + [5, 6, 0, 0], + [9, 1, 2, 0]], + [[4, 0, 0, 0], + [8, 9, 0, 0], + [3, 4, 5, 0]]])) + tests[-1, 1] = (np.array([[[2, 7, 3], + [1, 6, 2], + 
[5, 1, 0]], + [[5, 1, 6], + [4, 9, 5], + [8, 4, 0]]]), + np.array([[[1, 2, 0, 0], + [5, 6, 7, 0], + [0, 1, 2, 3]], + [[4, 5, 0, 0], + [8, 9, 1, 0], + [0, 4, 5, 6]]])) + tests[0, 3] = (np.array([[[4, 0, 0], + [3, 8, 0], + [2, 7, 3], + [1, 6, 2]], + [[7, 0, 0], + [6, 2, 0], + [5, 1, 6], + [4, 9, 5]]]), + np.array([[[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0, 2, 3]], + [[4, 5, 6, 7], + [0, 9, 1, 2], + [0, 0, 5, 6]]])) + # pyformat: enable + return (mat, tests) + + +class MatrixDiagTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_diag produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_diag. + solution: numpy array representing the expected output of matrix_diag. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + diagonal = params["diagonal"] + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["diagonal"] = array_ops.placeholder( + dtype, diagonal.shape, name="diagonal") + output = array_ops.matrix_diag(**params) + result = session.run(output, + {params["diagonal"]: diagonal.astype(dtype)}) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from unary_ops_tests.py. 
+ def testV1(self): + # pyformat: disable + vecs1 = np.array([[1, 2], + [3, 4]]) + solution1 = np.array([[[1, 0], [0, 2]], + [[3, 0], [0, 4]]]) + vecs2 = np.array([1, 2, 3, 4]) + solution2 = np.array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + vecs3 = np.array([[[1, 2, 3], + [4, 5, 6]], + [[7, 8, 9], # pylint: disable=bad-whitespace + [10, 11, 12]]]) + solution3 = np.array([[[[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], + [[4, 0, 0], + [0, 5, 0], + [0, 0, 6]]], + [[[7, 0, 0], + [0, 8, 0], + [0, 0, 9]], + [[10, 0, 0], + [0, 11, 0], + [0, 0, 12]]]]) + # pyformat: enable + self._assertOpOutputMatchesExpected({"diagonal": vecs1}, solution1) + self._assertOpOutputMatchesExpected({"diagonal": vecs2}, solution2) + self._assertOpOutputMatchesExpected({"diagonal": vecs3}, solution3) + + # From here onwards are v2-only tests. + def testSquare(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs[0], + "k": diag_index + }, solution[0]) + + def testSquareBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index + }, solution) + + def testRectangularBatch(self): + # LINT.IfChange + if not compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + return + + # Stores expected num_rows and num_cols (when the other is given). 
+ # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols) + test_list = list() + + # Square cases: + expected = { + (-1, -1): (5, 4), + (-4, -3): (5, 2), + (-2, 1): (5, 5), + (2, 4): (3, 5), + } + test_list.append((expected, square_cases())) + + # Tall cases + expected = { + (0, 0): (3, 3), + (-4, -3): (5, 2), + (-2, -1): (4, 3), + (-2, 1): (3, 3), + (1, 2): (2, 3) + } + test_list.append((expected, tall_cases())) + + # Fat cases + expected = { + (2, 2): (2, 4), + (-2, 0): (3, 3), + (-1, 1): (3, 3), + (0, 3): (3, 3) + } + test_list.append((expected, fat_cases())) + + # Giving both num_rows and num_cols + for _, tests in [tall_cases(), fat_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution.shape[-2], + "num_cols": solution.shape[-1] + }, solution) + + # Giving just num_rows or num_cols. + for expected, (_, tests) in test_list: + for diag_index, (new_num_rows, new_num_cols) in expected.items(): + vecs, solution = tests[diag_index] + solution_given_num_rows = solution.take( + indices=range(new_num_cols), axis=-1) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution_given_num_rows.shape[-2] + }, solution_given_num_rows) + solution_given_num_cols = solution.take( + indices=range(new_num_rows), axis=-2) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_cols": solution_given_num_cols.shape[-1] + }, solution_given_num_cols) + + def testPadding(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for padding_value in [555, -11]: + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, solution) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + 
"k": diag_index, + "num_rows": solution.shape[-2], + "num_cols": solution.shape[-1], + "padding_value": padding_value + }, solution) + + +class MatrixSetDiagTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_set_diag produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_set_diag. + solution: numpy array representing the expected output of matrix_set_diag. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + input = params["input"] # pylint: disable=redefined-builtin + diagonal = params["diagonal"] + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["input"] = array_ops.placeholder( + dtype, input.shape, name="input") + params["diagonal"] = array_ops.placeholder( + dtype, diagonal.shape, name="diagonal") + output = array_ops.matrix_set_diag(**params) + result = session.run( + output, { + params["input"]: input.astype(dtype), + params["diagonal"]: diagonal.astype(dtype) + }) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from binary_ops_tests.py. + def testV1(self): + test_cases = list() + + # pyformat: disable + # pylint: disable=bad-whitespace + # Square cases. 
+ input = np.array([[0, 1, 0], # pylint: disable=redefined-builtin + [1, 0, 1], + [1, 1, 1]]) + diag = np.array([1, 2, 3]) + solution = np.array([[1, 1, 0], + [1, 2, 1], + [1, 1, 3]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[[1, 0, 3], + [0, 2, 0], + [1, 0, 3]], + [[4, 0, 4], + [0, 5, 0], + [2, 0, 6]]]) + diag = np.array([[-1, 0, -3], + [-4, -5, -6]]) + solution = np.array([[[-1, 0, 3], + [ 0, 0, 0], + [ 1, 0, -3]], + [[-4, 0, 4], + [ 0, -5, 0], + [ 2, 0, -6]]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + # Rectangular cases. + input = np.array([[0, 1, 0], + [1, 0, 1]]) + diag = np.array([3, 4]) + solution = np.array([[3, 1, 0], + [1, 4, 1]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[0, 1], + [1, 0], + [1, 1]]) + diag = np.array([3, 4]) + solution = np.array([[3, 1], + [1, 4], + [1, 1]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[[1, 0, 3], + [0, 2, 0]], + [[4, 0, 4], + [0, 5, 0]]]) + diag = np.array([[-1, -2], [-4, -5]]) + solution = np.array([[[-1, 0, 3], + [ 0, -2, 0]], + [[-4, 0, 4], + [ 0, -5, 0]]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + # pylint: enable=bad-whitespace + # pyformat: enable + + for test in test_cases: + self._assertOpOutputMatchesExpected(test[0], test[1]) + + # From here onwards are v2-only tests. 
+ def testSingleMatrix(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat[0] == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs[0], + "k": diag_index + }, solution) + + def testBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs, + "k": diag_index + }, solution) + + +class MatrixDiagPartTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_diag_part produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_diag_part. + solution: numpy array representing the expected output. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. 
+ """ + input = params["input"] # pylint: disable=redefined-builtin + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["input"] = array_ops.placeholder( + dtype, input.shape, name="input") + output = array_ops.matrix_diag_part(**params) + result = session.run(output, { + params["input"]: input.astype(dtype), + }) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from unary_ops_tests.py. + def testV1(self): + matrices = np.arange(3 * 2 * 4).reshape([3, 2, 4]) + solution = np.array([[0, 5], [8, 13], [16, 21]]) + self._assertOpOutputMatchesExpected({"input": matrices}, solution) + + # From here onwards are v2-only tests. + def testSingleMatrix(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected({ + "input": mat[0], + "k": diag_index + }, solution[0]) + + def testBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected({ + "input": mat, + "k": diag_index + }, solution) + + def testPadding(self): + # LINT.IfChange + if compat.forward_compatible(2019, 7, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for padding_value in [555, -11]: + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + 
self._assertOpOutputMatchesExpected( + { + "input": mat, + "k": diag_index, + "padding_value": padding_value + }, solution) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index bac30b63bf8..64af33c7a2a 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -27,7 +27,6 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -108,31 +107,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) - # TODO(penporn): Once XLA supports MatrixDiagV2, change the call to - # gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2). 
- self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), - np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), - np.array( - [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], - dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, - np.array( - [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), - np.array( - [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [ - 0, 0, 6 - ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0], - [0, 0, 12]]]], - dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag_part, - np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), - np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype)) - self._assertOpOutputMatchesExpected( array_ops.prevent_gradient, np.array([[-1, 1]], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 139d6709215..ef2202c3931 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -55,8 +55,8 @@ tf_kernel_library( "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", + "matrix_diag_ops.cc", "matrix_inverse_op.cc", - "matrix_set_diag_op.cc", "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", "next_after_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 747ec133983..1f12c7980e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -20,8 +20,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -153,52 +155,5 @@ class DiagPartOp : public XlaOpKernel { REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp); -class MatrixDiagOp : public XlaOpKernel { - public: - explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, - errors::InvalidArgument("MatrixDiag op must have at an input")); - const TensorShape input_shape = ctx->InputShape(0); - - auto dims = input_shape.dim_sizes(); - OP_REQUIRES(ctx, !dims.empty(), - errors::InvalidArgument("Expected 1 <= dims, got shape ", - input_shape.DebugString())); - - - int last_dim = dims.size() - 1; - int64 last_dim_size = input_shape.dim_size(last_dim); - absl::Span other_dims(dims); - other_dims.remove_suffix(1); - - xla::XlaOp input = ctx->Input(0); - xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims); - ctx->SetOutput(0, diag); - } -}; - -REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); - -class MatrixDiagPartOp : public XlaOpKernel { - public: - explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - auto dims = input_shape.dim_sizes(); - - OP_REQUIRES(ctx, 2 <= dims.size(), - errors::InvalidArgument("Expected 2 <= dims, got shape ", - input_shape.DebugString())); - - xla::XlaOp input = ctx->Input(0); - ctx->SetOutput(0, xla::GetMatrixDiagonal(input)); - } -}; - -REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); - } // 
namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc new file mode 100644 index 00000000000..7eeb05a4920 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -0,0 +1,425 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +// Reads or infers lower_diag_index and upper_diag_index from kernel's input +// parameter "k". Also validates their values. +std::pair ProcessDiagIndex(XlaOpKernelContext* context) { + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + TensorShape diag_index_shape = context->InputShape("k"); + + // Wrapping OP_REQUIRES* macros with a function because they can "return;" + // early (without values) which contradicts ProcessDiagIndex's signature. 
+ auto validate_diag_indices = [&]() { + if (diag_index_shape.dims() == 0) { + OP_REQUIRES_OK(context, + context->ConstantInputAsIntScalar("k", &lower_diag_index)); + upper_diag_index = lower_diag_index; + } else { + std::vector diag_index; + OP_REQUIRES_OK(context, + context->ConstantInputAsIntVector("k", &diag_index)); + OP_REQUIRES( + context, !diag_index.empty() && diag_index.size() <= 2, + errors::InvalidArgument( + "diag_index must have only one or two elements, received ", + diag_index.size(), " elements.")); + lower_diag_index = diag_index[0]; + upper_diag_index = + (diag_index.size() > 1) ? diag_index[1] : lower_diag_index; + } + OP_REQUIRES( + context, lower_diag_index <= upper_diag_index, + errors::InvalidArgument( + "lower_diag_index must not be larger than upper_diag_index: ", + lower_diag_index, " > ", upper_diag_index)); + }; + validate_diag_indices(); + return {lower_diag_index, upper_diag_index}; +} + +// Makes sure lower_diag_index and upper_diag_index are consistent with the +// input matrix size. +void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context, + const int64 lower_diag_index, + const int64 upper_diag_index, + const int64 num_rows, + const int64 num_cols) { + // `lower_diag_index == 0` condition is added to handle matrix shape = 0. 
+ OP_REQUIRES(context, + (-num_rows < lower_diag_index && lower_diag_index < num_cols) || + lower_diag_index == 0, + errors::InvalidArgument( + "lower_diag_index is out of bound: ", lower_diag_index, + " It must be between ", -num_rows, " and ", num_cols)); + OP_REQUIRES(context, + (-num_rows < upper_diag_index && upper_diag_index < num_cols) || + upper_diag_index == 0, + errors::InvalidArgument( + "upper_diag_index is out of bound: ", upper_diag_index, + " It must be between ", -num_rows, " and ", num_cols)); + OP_REQUIRES(context, lower_diag_index <= upper_diag_index, + errors::InvalidArgument( + "lower_diag_index must not be larger than upper_diag_index: ", + lower_diag_index, " > ", upper_diag_index)); +} + +// Kernel to set matrix diagonals. +xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, + const TensorShape& input_shape, const int64 diag_rank, + const int64 num_diags, const int64 lower_diag_index, + const int64 upper_diag_index, const int64 max_diag_len, + const int64 num_rows, const int64 num_cols) { + // Creates a padding config. + const int input_rank = input_shape.dims(); + xla::PaddingConfig padding_config; + padding_config = xla::MakeNoPaddingConfig(input_rank - 1); + + // Processes one diagonal at a time: + // 1) Extracts a single diagonal (diag_slice). + // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast). + // 3) Masks diag_broadcast to get the right diagonal shape. + // + // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow. + // + // For example, + // diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4. + // [4, 5, 6], + // [7, 8, 9]] + // The expected output is [[4, 2, 0], + // [7, 5, 3], + // [0, 8, 6], + // [0, 0, 9]] + // The 1st diagonal is created by: + // 1) Extracting diag_slice = [2, 3, 0]. 
+ // 2) Padding the vector to be as long as num_rows, + // diag_slice = [2, 3, 0, 0], + // then broadcasting diag_slice row-wise to a full matrix, + // diag_broadcast = [[2, 2, 2], + // [3, 3, 3], + // [0, 0, 0], + // [0, 0, 0]] + // The padding value can be anything because it will not appear in the + // results after masking. Here, we use zero. + // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal. + // mask = [[0, 1, 0], --> output = [[x, 2, x], + // [0, 0, 1], [x, x, 3], + // [0, 0, 0], [x, x, x], + // [0, 0, 0]] [x, x, x]], + // where x denotes the existing input contents. + std::vector broadcast_dimensions(input_rank - 1); + absl::c_iota(broadcast_dimensions, 0); + auto output = input; + for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index; + ++diag_index) { + // Extracts a single diagonal. + auto diag_slice = diag; + if (num_diags > 1) { + const int64 mapped_diag_index = upper_diag_index - diag_index; + diag_slice = xla::Collapse( + xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1, + diag_rank - 2), + {diag_rank - 2, diag_rank - 1}); + } + + // Pads if necessary. Always pad at the end because shorter diagonals in + // the input come padded at the end. + const int64 padding_length = + ((diag_index <= 0) ? num_cols : num_rows) - max_diag_len; + const xla::XlaOp zero = xla::ScalarLike(input, 0); + if (padding_length > 0) { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_length); + diag_slice = xla::Pad(diag_slice, zero, padding_config); + } + + // Broadcasts column-wise for subdiagonals; row-wise for superdiagonals. + broadcast_dimensions.back() = + (diag_index <= 0) ? 
+ input_rank - 1 : input_rank - 2; + xla::XlaOp diag_broadcast = xla::BroadcastInDim( + diag_slice, input_shape.dim_sizes(), broadcast_dimensions); + const auto mask = xla::GetDiagonalMask(output, diag_index); + output = xla::Select(mask, diag_broadcast, output); + } + return output; +} + +} // namespace + +class MatrixDiagOp : public XlaOpKernel { + public: + explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at least one input")); + const TensorShape diag_shape = context->InputShape(0); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), + errors::InvalidArgument("Expected >= 1 dims, got shape ", + diag_shape.DebugString())); + + const DataType dtype = context->expected_output_dtype(0); + const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype); + + // Initializes MatrixDiagV2-specific variables to the MatrixDiag defaults. + // Input arguments providing the values of num_rows and num_cols can be + // absent (-1) and will be inferred later. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + int64 num_rows = -1; + int64 num_cols = -1; + xla::XlaOp padding_value = zero; + + // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has + // one input, so we have to check the number of inputs before reading + // additional parameters for MatrixDiagV2. + if (context->num_inputs() > 1) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols)); + padding_value = context->Input(4); + } + + // More size validations: diagonal count and minimum num_rows / num_cols.
+ const int64 diag_rank = diag_shape.dims(); + const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1); + const int64 num_diags = upper_diag_index - lower_diag_index + 1; + OP_REQUIRES( + context, + num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2), + errors::InvalidArgument( + "The number of diagonals provided in the input does not " + "match the lower_diag_index and upper_diag_index range.")); + const int64 min_num_rows = max_diag_len - std::min(upper_diag_index, 0LL); + const int64 min_num_cols = max_diag_len + std::max(lower_diag_index, 0LL); + OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows, + errors::InvalidArgument("The number of rows is too small.")); + OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols, + errors::InvalidArgument("The number of columns is too small.")); + + // Infers num_rows and num_cols. If both are unknown, assume that the output + // is square. Otherwise, use the smallest possible values. + if (num_rows == -1 && num_cols == -1) { + num_rows = std::max(min_num_rows, min_num_cols); + num_cols = num_rows; + } else if (num_rows == -1) { + num_rows = min_num_rows; + } else if (num_cols == -1) { + num_cols = min_num_cols; + } + + // At least one of num_rows and num_cols must match its minimum length. + // Otherwise, we'll have some incomplete diagonals. + OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols, + errors::InvalidArgument( + "The number of rows or columns is not consistent with " + "the specified d_lower, d_upper, and diagonal.")); + + // Actual processing. + // Initializes the output tensor with padding_value. + TensorShape output_shape = diag_shape; + output_shape.RemoveLastDims((num_diags == 1) ?
+ 1 : 2); + output_shape.AddDim(num_rows); + output_shape.AddDim(num_cols); + xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes()); + xla::XlaOp diag = context->Input(0); + context->SetOutput( + 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, + lower_diag_index, upper_diag_index, max_diag_len, + num_rows, num_cols)); + } +}; + +REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); +REGISTER_XLA_OP(Name("MatrixDiagV2") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("num_rows") + .CompileTimeConstantInput("num_cols") + .CompileTimeConstantInput("padding_value"), + MatrixDiagOp); + +class MatrixDiagPartOp : public XlaOpKernel { + public: + explicit MatrixDiagPartOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const int input_rank = input_shape.dims(); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + const DataType dtype = context->expected_output_dtype(0); + const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype); + + // Initializes MatrixDiagPartV2-specific variables. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + xla::XlaOp padding_value = zero; + + // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel. + // MatrixDiagPart only has one input, so we have to check the number of + // inputs before reading additional parameters in MatrixDiagPartV2. + if (context->num_inputs() > 1) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + padding_value = context->Input(2); + } + + // Checks if the diagonal indices are consistent with the input size.
+ const int64 num_rows = input_shape.dim_size(input_rank - 2); + const int64 num_cols = input_shape.dim_size(input_rank - 1); + ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index, + upper_diag_index, num_rows, num_cols); + + // Creates output shape. + TensorShape output_shape = input_shape; + output_shape.RemoveLastDims(2); + const int num_diags = upper_diag_index - lower_diag_index + 1; + if (num_diags > 1) output_shape.AddDim(num_diags); + const int32 max_diag_len = + std::min(num_rows + std::min(upper_diag_index, 0LL), + num_cols - std::max(lower_diag_index, 0LL)); + output_shape.AddDim(max_diag_len); + + // Computes output. + xla::XlaOp input = context->Input(0); + std::vector<xla::XlaOp> diag_list; + xla::PaddingConfig padding_config; + if (num_diags == 1) { + context->SetOutput(0, xla::GetMatrixDiagonal(input, upper_diag_index)); + return; + } + padding_config = xla::MakeNoPaddingConfig(input_rank - 1); + for (int diag_index = upper_diag_index; diag_index >= lower_diag_index; + --diag_index) { + auto single_diag = xla::GetMatrixDiagonal(input, diag_index); + const int64 diag_length = + (diag_index >= 0) ?
+ (num_cols - diag_index) : (num_rows + diag_index); + const int64 padding_length = max_diag_len - diag_length; + if (padding_length > 0) { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_length); + single_diag = xla::Pad(single_diag, padding_value, padding_config); + } + diag_list.emplace_back(single_diag); + } + auto concat = + xla::ConcatInDim(context->builder(), diag_list, input_rank - 2); + context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); +REGISTER_XLA_OP(Name("MatrixDiagPartV2") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("padding_value"), + MatrixDiagPartOp); + +class MatrixSetDiagOp : public XlaOpKernel { + public: + explicit MatrixSetDiagOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape diag_shape = context->InputShape(1); + const int input_rank = input_shape.dims(); + const int diag_rank = diag_shape.dims(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), + errors::InvalidArgument( + "diagonal must be at least 1-dim, received shape: ", + diag_shape.DebugString())); + + // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag + // only has two inputs, so we have to check the number of inputs before + // reading additional parameters in MatrixSetDiagV2. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + if (context->num_inputs() > 2) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + } + + // Checks diag against input; guards diag_rank before indexing diag_shape.
+ const int64 num_rows = input_shape.dim_size(input_rank - 2); + const int64 num_cols = input_shape.dim_size(input_rank - 1); + ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index, + upper_diag_index, num_rows, num_cols); + const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1; + OP_REQUIRES( + context, + lower_diag_index == upper_diag_index || + (diag_rank >= 2 && diag_shape.dim_size(diag_rank - 2) == num_diags), + errors::InvalidArgument("The number of diagonals provided in `diag` " + "is not consistent with `lower_diag_index` and " + "`upper_diag_index`")); + + TensorShape expected_diag_shape = input_shape; + expected_diag_shape.RemoveLastDims(2); + if (num_diags > 1) expected_diag_shape.AddDim(num_diags); + const int32 max_diag_len = + std::min(num_rows + std::min(upper_diag_index, 0LL), + num_cols - std::max(lower_diag_index, 0LL)); + expected_diag_shape.AddDim(max_diag_len); + OP_REQUIRES( + context, expected_diag_shape == diag_shape, + errors::InvalidArgument( + "Either first dimensions of diagonal don't match input.shape[:-2], " + "or diagonal.shape[:-1] is not equal to the longests diagonal in " + "range [lower_diag_index:upper_diag_index].\nInput shape: ", + input_shape.DebugString(), + "\nDiagonal shape: ", diag_shape.DebugString(), + "\nExpected diagonal shape: ", expected_diag_shape.DebugString())); + + // Actual processing.
+ xla::XlaOp input = context->Input(0); + xla::XlaOp diag = context->Input(1); + context->SetOutput( + 0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags, + lower_diag_index, upper_diag_index, max_diag_len, + num_rows, num_cols)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); +}; + +REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); +REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"), + MatrixSetDiagOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc deleted file mode 100644 index ee9764c0c35..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/primitive_util.h" - -namespace tensorflow { - -class MatrixSetDiagOp : public XlaOpKernel { - public: - explicit MatrixSetDiagOp(OpKernelConstruction* context) - : XlaOpKernel(context) {} - - void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape(0); - const TensorShape diag_shape = context->InputShape(1); - - const int rank = input_shape.dims(); - - // Preliminary validation of sizes. - OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), - errors::InvalidArgument( - "input must be at least 2-dim, received shape: ", - input_shape.DebugString())); - - // Check to make sure the last dimension of diag is equal to the smaller of - // the last two dimensions of input. - const int64 m = input_shape.dim_size(rank - 2); - const int64 n = input_shape.dim_size(rank - 1); - const int64 min_dim = std::min(m, n); - - TensorShape batch_shape = input_shape; - batch_shape.RemoveLastDims(2); - - TensorShape expected_diag_shape = batch_shape; - expected_diag_shape.AddDim(min_dim); - OP_REQUIRES(context, expected_diag_shape == diag_shape, - errors::InvalidArgument( - "must have diagonal.shape == input.shape[:-2] + " - "min(input.shape[-2:]), but received input shape: ", - input_shape.DebugString(), - " and diagonal shape: ", diag_shape.DebugString())); - - xla::XlaBuilder* builder = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp diag = context->Input(1); - - auto zero = XlaHelpers::Zero(builder, context->input_type(0)); - - // Create an indicator tensor that is true only on the diagonal. 
- xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); - xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); - auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); - - // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or) - // because we need to broadcast on the right. - std::vector diag_broadcast_dims(rank - 1); - std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0); - if (min_dim != m) { - diag_broadcast_dims.back() = rank - 1; - } - if (context->input_xla_type(0) == xla::PRED) { - diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); - - } else { - diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); - } - - auto output = xla::Select(indicator, diag, input); - context->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); -}; - -REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); - -} // namespace tensorflow