Add XLA implementations for MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.

PiperOrigin-RevId: 259379918
This commit is contained in:
Penporn Koanantakool 2019-07-22 12:19:25 -07:00 committed by TensorFlower Gardener
parent 1f7959a055
commit 5334adcddb
8 changed files with 1096 additions and 220 deletions

View File

@ -665,6 +665,19 @@ tf_xla_py_test(
],
)
# Tests for the XLA lowerings of MatrixDiagV2, MatrixDiagPartV2, and
# MatrixSetDiagV2 (see matrix_diag_ops_test.py).
tf_xla_py_test(
name = "matrix_diag_ops_test",
size = "medium",
timeout = "long",
srcs = ["matrix_diag_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "momentum_test",
size = "small",

View File

@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@ -1464,53 +1463,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([4, 5, 6], dtype=np.int32),
expected=None)
def testMatrixSetDiag(self):
# Verifies matrix_set_diag (V1) on square and rectangular inputs: the main
# diagonal of each (batched) matrix is replaced by the given vector while
# every off-diagonal entry is left untouched.
# TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to
# gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2).
for dtype in self.numeric_types:
# Square
self._testBinary(
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
dtype=dtype),
np.array([1.0, 2.0, 3.0], dtype=dtype),
expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
dtype=dtype))
# Batched square case: one diagonal vector per batch entry.
self._testBinary(
gen_array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
dtype=dtype),
np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
expected=np.array(
[[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
dtype=dtype))
# Rectangular
# For non-square matrices the diagonal length is min(num_rows, num_cols).
self._testBinary(
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
self._testBinary(
gen_array_ops.matrix_set_diag,
np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
# Batched rectangular case.
self._testBinary(
gen_array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
np.array([[-1.0, -2.0], [-4.0, -5.0]],
dtype=dtype),
expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
def testBroadcastTo(self):
for dtype in self.all_types:
x = np.random.randint(0, high=100, size=[2, 3])

View File

@ -0,0 +1,655 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA matrix diag ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
def square_cases():
"""Returns a batch of two 5x5 matrices plus expected diagonals.

Returns a tuple (mat, tests) where tests[d_lower, d_upper] maps a diagonal
range to (compact_diagonals, padded_diagonals): the compact form packs the
requested diagonals row-wise (shorter diagonals zero-padded at the end),
while the padded form places them back in a zero-filled 5x5 matrix.
"""
# pyformat: disable
mat = np.array([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 1],
[3, 4, 5, 6, 7],
[8, 9, 1, 2, 3],
[4, 5, 6, 7, 8]],
[[9, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 1],
[2, 3, 4, 5, 6]]])
tests = dict()
# tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
tests[-1, -1] = (np.array([[6, 4, 1, 7],
[5, 2, 8, 5]]),
np.array([[[0, 0, 0, 0, 0],
[6, 0, 0, 0, 0],
[0, 4, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 7, 0]],
[[0, 0, 0, 0, 0],
[5, 0, 0, 0, 0],
[0, 2, 0, 0, 0],
[0, 0, 8, 0, 0],
[0, 0, 0, 5, 0]]]))
tests[-4, -3] = (np.array([[[8, 5],
[4, 0]],
[[6, 3],
[2, 0]]]),
np.array([[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[8, 0, 0, 0, 0],
[4, 5, 0, 0, 0]],
[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[6, 0, 0, 0, 0],
[2, 3, 0, 0, 0]]]))
tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0],
[1, 7, 5, 2, 8],
[6, 4, 1, 7, 0],
[3, 9, 6, 0, 0]],
[[1, 7, 4, 1, 0],
[9, 6, 3, 9, 6],
[5, 2, 8, 5, 0],
[1, 7, 4, 0, 0]]]),
np.array([[[1, 2, 0, 0, 0],
[6, 7, 8, 0, 0],
[3, 4, 5, 6, 0],
[0, 9, 1, 2, 3],
[0, 0, 6, 7, 8]],
[[9, 1, 0, 0, 0],
[5, 6, 7, 0, 0],
[1, 2, 3, 4, 0],
[0, 7, 8, 9, 1],
[0, 0, 4, 5, 6]]]))
tests[2, 4] = (np.array([[[5, 0, 0],
[4, 1, 0],
[3, 9, 7]],
[[4, 0, 0],
[3, 9, 0],
[2, 8, 5]]]),
np.array([[[0, 0, 3, 4, 5],
[0, 0, 0, 9, 1],
[0, 0, 0, 0, 7],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]],
[[0, 0, 2, 3, 4],
[0, 0, 0, 8, 9],
[0, 0, 0, 0, 5],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]))
# pyformat: enable
return (mat, tests)
def tall_cases():
"""Returns a batch of two tall (5x3) matrices plus expected diagonals.

Same (mat, tests) structure as square_cases(), for matrices with more rows
than columns.
"""
# pyformat: disable
mat = np.array([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[9, 8, 7],
[6, 5, 4]],
[[3, 2, 1],
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[9, 8, 7]]])
tests = dict()
# tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
tests[0, 0] = (np.array([[1, 5, 9],
[3, 2, 6]]),
np.array([[[1, 0, 0],
[0, 5, 0],
[0, 0, 9],
[0, 0, 0]],
[[3, 0, 0],
[0, 2, 0],
[0, 0, 6],
[0, 0, 0]]]))
tests[-4, -3] = (np.array([[[9, 5],
[6, 0]],
[[7, 8],
[9, 0]]]),
np.array([[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[9, 0, 0],
[6, 5, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[7, 0, 0],
[9, 8, 0]]]))
tests[-2, -1] = (np.array([[[4, 8, 7],
[7, 8, 4]],
[[1, 5, 9],
[4, 8, 7]]]),
np.array([[[0, 0, 0],
[4, 0, 0],
[7, 8, 0],
[0, 8, 7],
[0, 0, 4]],
[[0, 0, 0],
[1, 0, 0],
[4, 5, 0],
[0, 8, 9],
[0, 0, 7]]]))
tests[-2, 1] = (np.array([[[2, 6, 0],
[1, 5, 9],
[4, 8, 7],
[7, 8, 4]],
[[2, 3, 0],
[3, 2, 6],
[1, 5, 9],
[4, 8, 7]]]),
np.array([[[1, 2, 0],
[4, 5, 6],
[7, 8, 9],
[0, 8, 7],
[0, 0, 4]],
[[3, 2, 0],
[1, 2, 3],
[4, 5, 6],
[0, 8, 9],
[0, 0, 7]]]))
tests[1, 2] = (np.array([[[3, 0],
[2, 6]],
[[1, 0],
[2, 3]]]),
np.array([[[0, 2, 3],
[0, 0, 6],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 2, 1],
[0, 0, 3],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]]))
# pyformat: enable
return (mat, tests)
def fat_cases():
"""Returns a batch of two fat (3x4) matrices plus expected diagonals.

Same (mat, tests) structure as square_cases(), for matrices with more
columns than rows.
"""
# pyformat: disable
mat = np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 1, 2, 3]],
[[4, 5, 6, 7],
[8, 9, 1, 2],
[3, 4, 5, 6]]])
tests = dict()
# tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
tests[0, 0] = (np.array([[1, 6, 2],
[4, 9, 5]]),
np.array([[[1, 0, 0, 0],
[0, 6, 0, 0],
[0, 0, 2, 0]],
[[4, 0, 0, 0],
[0, 9, 0, 0],
[0, 0, 5, 0]]]))
tests[2, 2] = (np.array([[3, 8],
[6, 2]]),
np.array([[[0, 0, 3, 0],
[0, 0, 0, 8],
[0, 0, 0, 0]],
[[0, 0, 6, 0],
[0, 0, 0, 2],
[0, 0, 0, 0]]]))
tests[-2, 0] = (np.array([[[1, 6, 2],
[5, 1, 0],
[9, 0, 0]],
[[4, 9, 5],
[8, 4, 0],
[3, 0, 0]]]),
np.array([[[1, 0, 0, 0],
[5, 6, 0, 0],
[9, 1, 2, 0]],
[[4, 0, 0, 0],
[8, 9, 0, 0],
[3, 4, 5, 0]]]))
tests[-1, 1] = (np.array([[[2, 7, 3],
[1, 6, 2],
[5, 1, 0]],
[[5, 1, 6],
[4, 9, 5],
[8, 4, 0]]]),
np.array([[[1, 2, 0, 0],
[5, 6, 7, 0],
[0, 1, 2, 3]],
[[4, 5, 0, 0],
[8, 9, 1, 0],
[0, 4, 5, 6]]]))
tests[0, 3] = (np.array([[[4, 0, 0],
[3, 8, 0],
[2, 7, 3],
[1, 6, 2]],
[[7, 0, 0],
[6, 2, 0],
[5, 1, 6],
[4, 9, 5]]]),
np.array([[[1, 2, 3, 4],
[0, 6, 7, 8],
[0, 0, 2, 3]],
[[4, 5, 6, 7],
[0, 9, 1, 2],
[0, 0, 5, 6]]]))
# pyformat: enable
return (mat, tests)
class MatrixDiagTest(xla_test.XLATestCase):
# Tests the XLA lowering of matrix_diag / MatrixDiagV2.
def _assertOpOutputMatchesExpected(self,
params,
solution,
rtol=1e-3,
atol=1e-5):
"""Verifies that matrix_diag produces `solution` when fed `params`.
Args:
params: dictionary containing input parameters to matrix_diag.
solution: numpy array representing the expected output of matrix_diag.
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
diagonal = params["diagonal"]
with self.session() as session:
# int8/uint8 are excluded; presumably unsupported for this op — confirm.
for dtype in self.numeric_types - {np.int8, np.uint8}:
expected = solution.astype(dtype)
with self.test_scope():
params["diagonal"] = array_ops.placeholder(
dtype, diagonal.shape, name="diagonal")
output = array_ops.matrix_diag(**params)
result = session.run(output,
{params["diagonal"]: diagonal.astype(dtype)})
self.assertEqual(output.dtype, expected.dtype)
self.assertAllCloseAccordingToType(
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
# Generic tests applicable to both v1 and v2 ops.
# Originally from unary_ops_tests.py.
def testV1(self):
# pyformat: disable
vecs1 = np.array([[1, 2],
[3, 4]])
solution1 = np.array([[[1, 0], [0, 2]],
[[3, 0], [0, 4]]])
vecs2 = np.array([1, 2, 3, 4])
solution2 = np.array([[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]])
vecs3 = np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9], # pylint: disable=bad-whitespace
[10, 11, 12]]])
solution3 = np.array([[[[1, 0, 0],
[0, 2, 0],
[0, 0, 3]],
[[4, 0, 0],
[0, 5, 0],
[0, 0, 6]]],
[[[7, 0, 0],
[0, 8, 0],
[0, 0, 9]],
[[10, 0, 0],
[0, 11, 0],
[0, 0, 12]]]])
# pyformat: enable
self._assertOpOutputMatchesExpected({"diagonal": vecs1}, solution1)
self._assertOpOutputMatchesExpected({"diagonal": vecs2}, solution2)
self._assertOpOutputMatchesExpected({"diagonal": vecs3}, solution3)
# From here onwards are v2-only tests.
def testSquare(self):
# Single matrix: builds a square matrix from each diagonal range `k`.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases()]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs[0],
"k": diag_index
}, solution[0])
def testSquareBatch(self):
# Batched version of testSquare.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases()]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index
}, solution)
def testRectangularBatch(self):
# Exercises num_rows/num_cols inference when one or both are omitted.
# LINT.IfChange
if not compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
return
# Stores expected num_rows and num_cols (when the other is given).
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
test_list = list()
# Square cases:
expected = {
(-1, -1): (5, 4),
(-4, -3): (5, 2),
(-2, 1): (5, 5),
(2, 4): (3, 5),
}
test_list.append((expected, square_cases()))
# Tall cases
expected = {
(0, 0): (3, 3),
(-4, -3): (5, 2),
(-2, -1): (4, 3),
(-2, 1): (3, 3),
(1, 2): (2, 3)
}
test_list.append((expected, tall_cases()))
# Fat cases
expected = {
(2, 2): (2, 4),
(-2, 0): (3, 3),
(-1, 1): (3, 3),
(0, 3): (3, 3)
}
test_list.append((expected, fat_cases()))
# Giving both num_rows and num_cols
for _, tests in [tall_cases(), fat_cases()]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index,
"num_rows": solution.shape[-2],
"num_cols": solution.shape[-1]
}, solution)
# Giving just num_rows or num_cols.
for expected, (_, tests) in test_list:
for diag_index, (new_num_rows, new_num_cols) in expected.items():
vecs, solution = tests[diag_index]
solution_given_num_rows = solution.take(
indices=range(new_num_cols), axis=-1)
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index,
"num_rows": solution_given_num_rows.shape[-2]
}, solution_given_num_rows)
solution_given_num_cols = solution.take(
indices=range(new_num_rows), axis=-2)
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index,
"num_cols": solution_given_num_cols.shape[-1]
}, solution_given_num_cols)
def testPadding(self):
# Off-diagonal entries should be filled with padding_value instead of 0.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for padding_value in [555, -11]:
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (vecs, solution) in tests.items():
mask = (solution == 0)
solution = solution + (mask * padding_value)
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index,
"num_rows": solution.shape[-2],
"num_cols": solution.shape[-1],
"padding_value": padding_value
}, solution)
class MatrixSetDiagTest(xla_test.XLATestCase):
# Tests the XLA lowering of matrix_set_diag / MatrixSetDiagV2.
def _assertOpOutputMatchesExpected(self,
params,
solution,
rtol=1e-3,
atol=1e-5):
"""Verifies that matrix_set_diag produces `solution` when fed `params`.
Args:
params: dictionary containing input parameters to matrix_set_diag.
solution: numpy array representing the expected output of matrix_set_diag.
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
input = params["input"]  # pylint: disable=redefined-builtin
diagonal = params["diagonal"]
with self.session() as session:
# int8/uint8 are excluded; presumably unsupported for this op — confirm.
for dtype in self.numeric_types - {np.int8, np.uint8}:
expected = solution.astype(dtype)
with self.test_scope():
params["input"] = array_ops.placeholder(
dtype, input.shape, name="input")
params["diagonal"] = array_ops.placeholder(
dtype, diagonal.shape, name="diagonal")
output = array_ops.matrix_set_diag(**params)
result = session.run(
output, {
params["input"]: input.astype(dtype),
params["diagonal"]: diagonal.astype(dtype)
})
self.assertEqual(output.dtype, expected.dtype)
self.assertAllCloseAccordingToType(
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
# Generic tests applicable to both v1 and v2 ops.
# Originally from binary_ops_tests.py.
def testV1(self):
test_cases = list()
# pyformat: disable
# pylint: disable=bad-whitespace
# Square cases.
input = np.array([[0, 1, 0],  # pylint: disable=redefined-builtin
[1, 0, 1],
[1, 1, 1]])
diag = np.array([1, 2, 3])
solution = np.array([[1, 1, 0],
[1, 2, 1],
[1, 1, 3]])
test_cases.append(({"input": input, "diagonal": diag}, solution))
input = np.array([[[1, 0, 3],
[0, 2, 0],
[1, 0, 3]],
[[4, 0, 4],
[0, 5, 0],
[2, 0, 6]]])
diag = np.array([[-1, 0, -3],
[-4, -5, -6]])
solution = np.array([[[-1, 0, 3],
[ 0, 0, 0],
[ 1, 0, -3]],
[[-4, 0, 4],
[ 0, -5, 0],
[ 2, 0, -6]]])
test_cases.append(({"input": input, "diagonal": diag}, solution))
# Rectangular cases.
input = np.array([[0, 1, 0],
[1, 0, 1]])
diag = np.array([3, 4])
solution = np.array([[3, 1, 0],
[1, 4, 1]])
test_cases.append(({"input": input, "diagonal": diag}, solution))
input = np.array([[0, 1],
[1, 0],
[1, 1]])
diag = np.array([3, 4])
solution = np.array([[3, 1],
[1, 4],
[1, 1]])
test_cases.append(({"input": input, "diagonal": diag}, solution))
input = np.array([[[1, 0, 3],
[0, 2, 0]],
[[4, 0, 4],
[0, 5, 0]]])
diag = np.array([[-1, -2], [-4, -5]])
solution = np.array([[[-1, 0, 3],
[ 0, -2, 0]],
[[-4, 0, 4],
[ 0, -5, 0]]])
test_cases.append(({"input": input, "diagonal": diag}, solution))
# pylint: enable=bad-whitespace
# pyformat: enable
for test in test_cases:
self._assertOpOutputMatchesExpected(test[0], test[1])
# From here onwards are v2-only tests.
def testSingleMatrix(self):
# Random input; the banded region must be replaced, the rest preserved.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat[0] == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat[0]
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs[0],
"k": diag_index
}, solution)
def testBatch(self):
# Batched version of testSingleMatrix.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs,
"k": diag_index
}, solution)
class MatrixDiagPartTest(xla_test.XLATestCase):
# Tests the XLA lowering of matrix_diag_part / MatrixDiagPartV2.
def _assertOpOutputMatchesExpected(self,
params,
solution,
rtol=1e-3,
atol=1e-5):
"""Verifies that matrix_diag_part produces `solution` when fed `params`.
Args:
params: dictionary containing input parameters to matrix_diag_part.
solution: numpy array representing the expected output.
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
input = params["input"]  # pylint: disable=redefined-builtin
with self.session() as session:
# int8/uint8 are excluded; presumably unsupported for this op — confirm.
for dtype in self.numeric_types - {np.int8, np.uint8}:
expected = solution.astype(dtype)
with self.test_scope():
params["input"] = array_ops.placeholder(
dtype, input.shape, name="input")
output = array_ops.matrix_diag_part(**params)
result = session.run(output, {
params["input"]: input.astype(dtype),
})
self.assertEqual(output.dtype, expected.dtype)
self.assertAllCloseAccordingToType(
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
# Generic tests applicable to both v1 and v2 ops.
# Originally from unary_ops_tests.py.
def testV1(self):
matrices = np.arange(3 * 2 * 4).reshape([3, 2, 4])
solution = np.array([[0, 5], [8, 13], [16, 21]])
self._assertOpOutputMatchesExpected({"input": matrices}, solution)
# From here onwards are v2-only tests.
def testSingleMatrix(self):
# Extracts the compact diagonals of each test's first batch entry.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected({
"input": mat[0],
"k": diag_index
}, solution[0])
def testBatch(self):
# Batched version of testSingleMatrix.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected({
"input": mat,
"k": diag_index
}, solution)
def testPadding(self):
# Shorter diagonals should be padded with padding_value instead of 0.
# LINT.IfChange
if compat.forward_compatible(2019, 7, 31):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for padding_value in [555, -11]:
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (solution, _) in tests.items():
mask = (solution == 0)
solution = solution + (mask * padding_value)
self._assertOpOutputMatchesExpected(
{
"input": mat,
"k": diag_index,
"padding_value": padding_value
}, solution)
if __name__ == "__main__":
googletest.main()

View File

@ -27,7 +27,6 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@ -108,31 +107,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
# TODO(penporn): Once XLA supports MatrixDiagV2, change the call to
# gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2).
self._assertOpOutputMatchesExpected(
gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
self._assertOpOutputMatchesExpected(
gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
[[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
gen_array_ops.matrix_diag,
np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
np.array(
[[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [
0, 0, 6
]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0],
[0, 0, 12]]]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
gen_array_ops.matrix_diag_part,
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.prevent_gradient,
np.array([[-1, 1]], dtype=dtype),

View File

@ -55,8 +55,8 @@ tf_kernel_library(
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
"matrix_diag_ops.cc",
"matrix_inverse_op.cc",
"matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
"next_after_op.cc",

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@ -153,52 +155,5 @@ class DiagPartOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);
// Lowers MatrixDiag (V1): builds a (batched) square matrix whose main
// diagonal is the last dimension of the input and whose other entries are
// zero.
class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Fixed garbled error text ("must have at an input").
    OP_REQUIRES(
        ctx, ctx->num_inputs() >= 1,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    // The last dimension holds the diagonal values; any leading dimensions
    // are batch dimensions.
    int last_dim = dims.size() - 1;
    int64 last_dim_size = input_shape.dim_size(last_dim);
    absl::Span<const int64> other_dims(dims);
    other_dims.remove_suffix(1);

    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims);
    ctx->SetOutput(0, diag);
  }
};
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
// Lowers MatrixDiagPart (V1): extracts the main diagonal of each innermost
// matrix of the input via xla::GetMatrixDiagonal.
class MatrixDiagPartOp : public XlaOpKernel {
public:
explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
auto dims = input_shape.dim_sizes();
// The input must be at least a matrix; leading dimensions are batched.
OP_REQUIRES(ctx, 2 <= dims.size(),
errors::InvalidArgument("Expected 2 <= dims, got shape ",
input_shape.DebugString()));
xla::XlaOp input = ctx->Input(0);
ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
}
};
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,425 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
// Reads or infers lower_diag_index and upper_diag_index from kernel's input
// parameter "k". Also validates their values.
// Reads or infers lower_diag_index and upper_diag_index from kernel's input
// parameter "k". Also validates their values.
// "k" may be a scalar (single diagonal) or a length-1/2 vector giving the
// [lower, upper] range; a length-1 vector means lower == upper.
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
int64 lower_diag_index = 0;
int64 upper_diag_index = 0;
TensorShape diag_index_shape = context->InputShape("k");
// Wrapping OP_REQUIRES* macros with a function because they can "return;"
// early (without values) which contradicts ProcessDiagIndex's signature.
auto validate_diag_indices = [&]() {
if (diag_index_shape.dims() == 0) {
// Scalar "k": a single diagonal index.
OP_REQUIRES_OK(context,
context->ConstantInputAsIntScalar("k", &lower_diag_index));
upper_diag_index = lower_diag_index;
} else {
std::vector<int64> diag_index;
OP_REQUIRES_OK(context,
context->ConstantInputAsIntVector("k", &diag_index));
OP_REQUIRES(
context, !diag_index.empty() && diag_index.size() <= 2,
errors::InvalidArgument(
"diag_index must have only one or two elements, received ",
diag_index.size(), " elements."));
lower_diag_index = diag_index[0];
upper_diag_index =
(diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
}
OP_REQUIRES(
context, lower_diag_index <= upper_diag_index,
errors::InvalidArgument(
"lower_diag_index must not be larger than upper_diag_index: ",
lower_diag_index, " > ", upper_diag_index));
};
validate_diag_indices();
// On validation failure the context already carries the error status; the
// returned pair is then unused by the caller's OP_REQUIRES chain.
return {lower_diag_index, upper_diag_index};
}
// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
// NOTE(review): the checks are exclusive (-num_rows < k < num_cols) while
// the error message reads as inclusive ("between") — confirm intended bound.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
const int64 lower_diag_index,
const int64 upper_diag_index,
const int64 num_rows,
const int64 num_cols) {
// `lower_diag_index == 0` condition is added to handle matrix shape = 0.
OP_REQUIRES(context,
(-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
lower_diag_index == 0,
errors::InvalidArgument(
"lower_diag_index is out of bound: ", lower_diag_index,
" It must be between ", -num_rows, " and ", num_cols));
OP_REQUIRES(context,
(-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
upper_diag_index == 0,
errors::InvalidArgument(
"upper_diag_index is out of bound: ", upper_diag_index,
" It must be between ", -num_rows, " and ", num_cols));
OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
errors::InvalidArgument(
"lower_diag_index must not be larger than upper_diag_index: ",
lower_diag_index, " > ", upper_diag_index));
}
// Kernel to set matrix diagonals.
// Kernel to set matrix diagonals.
// Writes the diagonals packed in `diag` (one row per diagonal, upper-most
// first) into `input` for the band [lower_diag_index, upper_diag_index],
// leaving all other entries of `input` unchanged.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
const TensorShape& input_shape, const int64 diag_rank,
const int64 num_diags, const int64 lower_diag_index,
const int64 upper_diag_index, const int64 max_diag_len,
const int64 num_rows, const int64 num_cols) {
// Creates a padding config.
const int input_rank = input_shape.dims();
xla::PaddingConfig padding_config;
padding_config = xla::MakeNoPaddingConfig(input_rank - 1);
// Processes one diagonal at a time:
// 1) Extracts a single diagonal (diag_slice).
// 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
// 3) Masks diag_broadcast to get the right diagonal shape.
//
// XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
//
// For example,
// diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4.
// [4, 5, 6],
// [7, 8, 9]]
// The expected output is [[4, 2, 0],
// [7, 5, 3],
// [0, 8, 6],
// [0, 0, 9]]
// The k = 1 superdiagonal (1st row of diag) is created by:
// 1) Extracting diag_slice = [2, 3, 0].
// 2) Padding the vector to be as long as num_rows,
// diag_slice = [2, 3, 0, 0],
// then broadcasting diag_slice row-wise to a full matrix,
// diag_broadcast = [[2, 2, 2],
// [3, 3, 3],
// [0, 0, 0],
// [0, 0, 0]]
// The padding value can be anything because it will not appear in the
// results after masking. Here, we use zero.
// 3) Masking diag_broadcast with a mask of the shape of the k = 1 diagonal.
// mask = [[0, 1, 0], --> output = [[x, 2, x],
// [0, 0, 1], [x, x, 3],
// [0, 0, 0], [x, x, x],
// [0, 0, 0]] [x, x, x]],
// where x denotes the existing input contents.
std::vector<int64> broadcast_dimensions(input_rank - 1);
absl::c_iota(broadcast_dimensions, 0);
auto output = input;
for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
++diag_index) {
// Extracts a single diagonal.
auto diag_slice = diag;
if (num_diags > 1) {
// Rows of diag are ordered from the upper-most diagonal downwards.
const int64 mapped_diag_index = upper_diag_index - diag_index;
diag_slice = xla::Collapse(
xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
diag_rank - 2),
{diag_rank - 2, diag_rank - 1});
}
// Pads if necessary. Always pad at the end because shorter diagonals in
// the input come padded at the end.
const int64 padding_length =
((diag_index <= 0) ? num_cols : num_rows) - max_diag_len;
const xla::XlaOp zero = xla::ScalarLike(input, 0);
if (padding_length > 0) {
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_length);
diag_slice = xla::Pad(diag_slice, zero, padding_config);
}
// Broadcasts column-wise for subdiagonals; row-wise for superdiagonals.
broadcast_dimensions.back() =
(diag_index <= 0) ? input_rank - 1 : input_rank - 2;
xla::XlaOp diag_broadcast = xla::BroadcastInDim(
diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
// Only positions on the current diagonal take the broadcast values.
const auto mask = xla::GetDiagonalMask(output, diag_index);
output = xla::Select(mask, diag_broadcast, output);
}
return output;
}
} // namespace
class MatrixDiagOp : public XlaOpKernel {
public:
explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
OP_REQUIRES(
context, context->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at least one input"));
const TensorShape diag_shape = context->InputShape(0);
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
errors::InvalidArgument("Expected >= 1 dims, got shape ",
diag_shape.DebugString()));
const DataType dtype = context->expected_output_dtype(0);
const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);
// Initializes MatrixDiagV2-specific variables.
// Input arguments providing the values of num_rows and num_cols can be
// absent (-1) and will be inferred later.
int64 lower_diag_index = 0;
int64 upper_diag_index = 0;
int64 num_rows = -1;
int64 num_cols = -1;
xla::XlaOp padding_value = zero;
// MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
// one input, so we have to check the number of inputs before reading
// additional parameters for MatrixDiagV2.
if (context->num_inputs() > 1) {
std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
padding_value = context->Input(4);
}
// More size validations.
const int64 diag_rank = diag_shape.dims();
const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
const int64 num_diags = upper_diag_index - lower_diag_index + 1;
OP_REQUIRES(
context,
num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
errors::InvalidArgument(
"The number of diagonals provided in the input does not "
"match the lower_diag_index and upper_diag_index range."));
const int64 min_num_rows = max_diag_len - std::min(upper_diag_index, 0LL);
const int64 min_num_cols = max_diag_len + std::max(lower_diag_index, 0LL);
OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
errors::InvalidArgument("The number of rows is too small."));
OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
errors::InvalidArgument("The number of columns is too small."));
// Infers num_rows and num_cols. If both are unknown, assume that the output
// is square. Otherwise, use smallest possible values.
if (num_rows == -1 && num_cols == -1) {
num_rows = std::max(min_num_rows, min_num_cols);
num_cols = num_rows;
} else if (num_rows == -1) {
num_rows = min_num_rows;
} else if (num_cols == -1) {
num_cols = min_num_cols;
}
// At least one of num_rows and num_cols must match its minimum length.
// Otherwise, we'll have some incomplete diagonals.
OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
errors::InvalidArgument(
"The number of rows or columns is not consistent with "
"the specified d_lower, d_upper, and diagonal."));
// Actual processing.
// Initializes the output tensor with padding_value.
TensorShape output_shape = diag_shape;
output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
output_shape.AddDim(num_rows);
output_shape.AddDim(num_cols);
xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
xla::XlaOp diag = context->Input(0);
context->SetOutput(
0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
lower_diag_index, upper_diag_index, max_diag_len,
num_rows, num_cols));
}
};
// MatrixDiag (V1) has a single input; MatrixDiagV2 additionally takes the
// diagonal band `k`, output dimensions, and a padding value. The extra inputs
// must be compile-time constants because they determine the static output
// shape of the compiled XLA computation.
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
class MatrixDiagPartOp : public XlaOpKernel {
public:
explicit MatrixDiagPartOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const int input_rank = input_shape.dims();
OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
errors::InvalidArgument(
"input must be at least 2-dim, received shape: ",
input_shape.DebugString()));
const DataType dtype = context->expected_output_dtype(0);
const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);
// Initializes MatrixDiagPartV2-specific variables.
int64 lower_diag_index = 0;
int64 upper_diag_index = 0;
xla::XlaOp padding_value = zero;
// MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
// MatrixDiagPart only has one input, so we have to check the number of
// inputs before reading additional parameters in MatrixDiagV2.
if (context->num_inputs() > 1) {
std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
padding_value = context->Input(2);
}
// Checks if diag sizes are consistent with input.
const int64 num_rows = input_shape.dim_size(input_rank - 2);
const int64 num_cols = input_shape.dim_size(input_rank - 1);
ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
upper_diag_index, num_rows, num_cols);
// Creates output shape.
TensorShape output_shape = input_shape;
output_shape.RemoveLastDims(2);
const int num_diags = upper_diag_index - lower_diag_index + 1;
if (num_diags > 1) output_shape.AddDim(num_diags);
const int32 max_diag_len =
std::min(num_rows + std::min(upper_diag_index, 0LL),
num_cols - std::max(lower_diag_index, 0LL));
output_shape.AddDim(max_diag_len);
// Computes output.
xla::XlaOp input = context->Input(0);
std::vector<xla::XlaOp> diag_list;
xla::PaddingConfig padding_config;
if (num_diags == 1) {
context->SetOutput(0, xla::GetMatrixDiagonal(input, upper_diag_index));
return;
}
padding_config = xla::MakeNoPaddingConfig(input_rank - 1);
for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
--diag_index) {
auto single_diag = xla::GetMatrixDiagonal(input, diag_index);
const int64 diag_length =
(diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index);
const int64 padding_length = max_diag_len - diag_length;
if (padding_length > 0) {
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_length);
single_diag = xla::Pad(single_diag, padding_value, padding_config);
}
diag_list.emplace_back(single_diag);
}
auto concat =
xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
}
};
// MatrixDiagPartV2's `k` and `padding_value` must be compile-time constants
// because they determine the static output shape of the compiled computation.
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  // Overwrites the diagonals in the band [lower_diag_index, upper_diag_index]
  // of the innermost matrices of input 0 with the values from input 1. Serves
  // both MatrixSetDiag (two inputs; sets the main diagonal) and
  // MatrixSetDiagV2 (extra compile-time-constant input `k`).
  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading the additional parameter of MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > 2) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag indices are consistent with the input matrix size.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // When the band holds more than one diagonal, `diag` must carry a
    // dimension of size num_diags just before its last dimension. Guard the
    // rank and index into diag_shape (not input_shape, as before) so that a
    // rank-mismatched `diag` yields InvalidArgument instead of an
    // out-of-range dim_size access. num_diags stays int64 for consistency
    // with the rest of this file (the previous version used Eigen::Index).
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_rank >= 2 &&
             diag_shape.dim_size(diag_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    // Expected diag shape: batch dims + [num_diags (only when > 1),
    // max_diag_len]. max_diag_len is int64 (not int32) to match
    // num_rows/num_cols and avoid narrowing.
    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int64 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, 0LL),
                 num_cols - std::max(lower_diag_index, 0LL));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longest diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing: scatter `diag` onto `input` in place.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
// MatrixSetDiagV2's `k` must be a compile-time constant because the band
// bounds determine how the kernel builds the scatter of `diag` onto `input`.
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
} // namespace tensorflow

View File

@ -1,98 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
namespace tensorflow {
// NOTE(review): This is the pre-V2 MatrixSetDiag kernel from the file deleted
// by this commit; it only handles the main diagonal (no `k` band support).
// Kept byte-identical here; superseded by the banded implementation above.
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int rank = input_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    // Check to make sure the last dimension of diag is equal to the smaller of
    // the last two dimensions of input.
    const int64 m = input_shape.dim_size(rank - 2);
    const int64 n = input_shape.dim_size(rank - 1);
    const int64 min_dim = std::min(m, n);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);

    TensorShape expected_diag_shape = batch_shape;
    expected_diag_shape.AddDim(min_dim);
    OP_REQUIRES(context, expected_diag_shape == diag_shape,
                errors::InvalidArgument(
                    "must have diagonal.shape == input.shape[:-2] + "
                    "min(input.shape[-2:]), but received input shape: ",
                    input_shape.DebugString(),
                    " and diagonal shape: ", diag_shape.DebugString()));

    xla::XlaBuilder* builder = context->builder();
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);

    auto zero = XlaHelpers::Zero(builder, context->input_type(0));

    // Create an indicator tensor that is true only on the diagonal.
    xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
    xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
    auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
                             /*broadcast_dimensions=*/{0});
    indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());

    // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or)
    // because we need to broadcast on the right.
    std::vector<int64> diag_broadcast_dims(rank - 1);
    std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
    if (min_dim != m) {
      // Diagonal aligns with the column dimension when rows > columns.
      diag_broadcast_dims.back() = rank - 1;
    }
    // PRED has no Add; use Or as the identity-broadcasting op for booleans.
    if (context->input_xla_type(0) == xla::PRED) {
      diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
                     /*broadcast_dimensions=*/diag_broadcast_dims);
    } else {
      diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
                      /*broadcast_dimensions=*/diag_broadcast_dims);
    }

    // Select diagonal values where the indicator is set, original elsewhere.
    auto output = xla::Select(indicator, diag, input);
    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
} // namespace tensorflow