Add XLA implementations for MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
PiperOrigin-RevId: 259379918
commit 5334adcddb (parent 1f7959a055)
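In Python terms, these kernels back the optional `k`, `num_rows`, `num_cols`, and `padding_value` arguments of `array_ops.matrix_diag`, `array_ops.matrix_diag_part`, and `array_ops.matrix_set_diag` exercised by the new test file below. A minimal sketch of that surface (illustrative only; it assumes a build in which the 2019-07-31 forward-compatibility window gating the V2 ops has passed):

import numpy as np

from tensorflow.python.ops import array_ops

# Place [1, 2, 3] on the first superdiagonal (k=1) of a 4x4 matrix and
# fill every off-band entry with -1 instead of the default 0.
mat = array_ops.matrix_diag(
    np.array([1, 2, 3]), k=1, num_rows=4, num_cols=4, padding_value=-1)

# Read the band k = -1..1 back out as a compact stack of diagonals.
band = array_ops.matrix_diag_part(mat, k=(-1, 1))

# Overwrite the main diagonal of an existing matrix.
updated = array_ops.matrix_set_diag(mat, np.array([7, 7, 7, 7]))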
@@ -665,6 +665,19 @@ tf_xla_py_test(
    ],
)

tf_xla_py_test(
    name = "matrix_diag_ops_test",
    size = "medium",
    timeout = "long",
    srcs = ["matrix_diag_ops_test.py"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:platform_test",
    ],
)

tf_xla_py_test(
    name = "momentum_test",
    size = "small",
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -1464,53 +1463,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
          np.array([4, 5, 6], dtype=np.int32),
          expected=None)

  def testMatrixSetDiag(self):
    # TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to
    # gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2).
    for dtype in self.numeric_types:
      # Square
      self._testBinary(
          gen_array_ops.matrix_set_diag,
          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
                   dtype=dtype),
          np.array([1.0, 2.0, 3.0], dtype=dtype),
          expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
                            dtype=dtype))

      self._testBinary(
          gen_array_ops.matrix_set_diag,
          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
                   dtype=dtype),
          np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
          expected=np.array(
              [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
               [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
              dtype=dtype))

      # Rectangular
      self._testBinary(
          gen_array_ops.matrix_set_diag,
          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
          np.array([3.0, 4.0], dtype=dtype),
          expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))

      self._testBinary(
          gen_array_ops.matrix_set_diag,
          np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
          np.array([3.0, 4.0], dtype=dtype),
          expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))

      self._testBinary(
          gen_array_ops.matrix_set_diag,
          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
          np.array([[-1.0, -2.0], [-4.0, -5.0]],
                   dtype=dtype),
          expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
                             [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
                            dtype=dtype))

  def testBroadcastTo(self):
    for dtype in self.all_types:
      x = np.random.randint(0, high=100, size=[2, 3])
tensorflow/compiler/tests/matrix_diag_ops_test.py (new file, 655 lines)
@@ -0,0 +1,655 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA matrix diag ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
def square_cases():
  # pyformat: disable
  mat = np.array([[[1, 2, 3, 4, 5],
                   [6, 7, 8, 9, 1],
                   [3, 4, 5, 6, 7],
                   [8, 9, 1, 2, 3],
                   [4, 5, 6, 7, 8]],
                  [[9, 1, 2, 3, 4],
                   [5, 6, 7, 8, 9],
                   [1, 2, 3, 4, 5],
                   [6, 7, 8, 9, 1],
                   [2, 3, 4, 5, 6]]])
  tests = dict()
  # tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
  tests[-1, -1] = (np.array([[6, 4, 1, 7],
                             [5, 2, 8, 5]]),
                   np.array([[[0, 0, 0, 0, 0],
                              [6, 0, 0, 0, 0],
                              [0, 4, 0, 0, 0],
                              [0, 0, 1, 0, 0],
                              [0, 0, 0, 7, 0]],
                             [[0, 0, 0, 0, 0],
                              [5, 0, 0, 0, 0],
                              [0, 2, 0, 0, 0],
                              [0, 0, 8, 0, 0],
                              [0, 0, 0, 5, 0]]]))
  tests[-4, -3] = (np.array([[[8, 5],
                              [4, 0]],
                             [[6, 3],
                              [2, 0]]]),
                   np.array([[[0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0],
                              [8, 0, 0, 0, 0],
                              [4, 5, 0, 0, 0]],
                             [[0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0],
                              [6, 0, 0, 0, 0],
                              [2, 3, 0, 0, 0]]]))
  tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0],
                             [1, 7, 5, 2, 8],
                             [6, 4, 1, 7, 0],
                             [3, 9, 6, 0, 0]],
                            [[1, 7, 4, 1, 0],
                             [9, 6, 3, 9, 6],
                             [5, 2, 8, 5, 0],
                             [1, 7, 4, 0, 0]]]),
                  np.array([[[1, 2, 0, 0, 0],
                             [6, 7, 8, 0, 0],
                             [3, 4, 5, 6, 0],
                             [0, 9, 1, 2, 3],
                             [0, 0, 6, 7, 8]],
                            [[9, 1, 0, 0, 0],
                             [5, 6, 7, 0, 0],
                             [1, 2, 3, 4, 0],
                             [0, 7, 8, 9, 1],
                             [0, 0, 4, 5, 6]]]))
  tests[2, 4] = (np.array([[[5, 0, 0],
                            [4, 1, 0],
                            [3, 9, 7]],
                           [[4, 0, 0],
                            [3, 9, 0],
                            [2, 8, 5]]]),
                 np.array([[[0, 0, 3, 4, 5],
                            [0, 0, 0, 9, 1],
                            [0, 0, 0, 0, 7],
                            [0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0]],
                           [[0, 0, 2, 3, 4],
                            [0, 0, 0, 8, 9],
                            [0, 0, 0, 0, 5],
                            [0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0]]]))
  # pyformat: enable
  return (mat, tests)


def tall_cases():
  # pyformat: disable
  mat = np.array([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9],
                   [9, 8, 7],
                   [6, 5, 4]],
                  [[3, 2, 1],
                   [1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9],
                   [9, 8, 7]]])
  tests = dict()
  # tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
  tests[0, 0] = (np.array([[1, 5, 9],
                           [3, 2, 6]]),
                 np.array([[[1, 0, 0],
                            [0, 5, 0],
                            [0, 0, 9],
                            [0, 0, 0]],
                           [[3, 0, 0],
                            [0, 2, 0],
                            [0, 0, 6],
                            [0, 0, 0]]]))
  tests[-4, -3] = (np.array([[[9, 5],
                              [6, 0]],
                             [[7, 8],
                              [9, 0]]]),
                   np.array([[[0, 0, 0],
                              [0, 0, 0],
                              [0, 0, 0],
                              [9, 0, 0],
                              [6, 5, 0]],
                             [[0, 0, 0],
                              [0, 0, 0],
                              [0, 0, 0],
                              [7, 0, 0],
                              [9, 8, 0]]]))
  tests[-2, -1] = (np.array([[[4, 8, 7],
                              [7, 8, 4]],
                             [[1, 5, 9],
                              [4, 8, 7]]]),
                   np.array([[[0, 0, 0],
                              [4, 0, 0],
                              [7, 8, 0],
                              [0, 8, 7],
                              [0, 0, 4]],
                             [[0, 0, 0],
                              [1, 0, 0],
                              [4, 5, 0],
                              [0, 8, 9],
                              [0, 0, 7]]]))
  tests[-2, 1] = (np.array([[[2, 6, 0],
                             [1, 5, 9],
                             [4, 8, 7],
                             [7, 8, 4]],
                            [[2, 3, 0],
                             [3, 2, 6],
                             [1, 5, 9],
                             [4, 8, 7]]]),
                  np.array([[[1, 2, 0],
                             [4, 5, 6],
                             [7, 8, 9],
                             [0, 8, 7],
                             [0, 0, 4]],
                            [[3, 2, 0],
                             [1, 2, 3],
                             [4, 5, 6],
                             [0, 8, 9],
                             [0, 0, 7]]]))
  tests[1, 2] = (np.array([[[3, 0],
                            [2, 6]],
                           [[1, 0],
                            [2, 3]]]),
                 np.array([[[0, 2, 3],
                            [0, 0, 6],
                            [0, 0, 0],
                            [0, 0, 0],
                            [0, 0, 0]],
                           [[0, 2, 1],
                            [0, 0, 3],
                            [0, 0, 0],
                            [0, 0, 0],
                            [0, 0, 0]]]))
  # pyformat: enable
  return (mat, tests)


def fat_cases():
  # pyformat: disable
  mat = np.array([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 1, 2, 3]],
                  [[4, 5, 6, 7],
                   [8, 9, 1, 2],
                   [3, 4, 5, 6]]])
  tests = dict()
  # tests[d_lower, d_upper] = (compact_diagonals, padded_diagonals)
  tests[0, 0] = (np.array([[1, 6, 2],
                           [4, 9, 5]]),
                 np.array([[[1, 0, 0, 0],
                            [0, 6, 0, 0],
                            [0, 0, 2, 0]],
                           [[4, 0, 0, 0],
                            [0, 9, 0, 0],
                            [0, 0, 5, 0]]]))
  tests[2, 2] = (np.array([[3, 8],
                           [6, 2]]),
                 np.array([[[0, 0, 3, 0],
                            [0, 0, 0, 8],
                            [0, 0, 0, 0]],
                           [[0, 0, 6, 0],
                            [0, 0, 0, 2],
                            [0, 0, 0, 0]]]))
  tests[-2, 0] = (np.array([[[1, 6, 2],
                             [5, 1, 0],
                             [9, 0, 0]],
                            [[4, 9, 5],
                             [8, 4, 0],
                             [3, 0, 0]]]),
                  np.array([[[1, 0, 0, 0],
                             [5, 6, 0, 0],
                             [9, 1, 2, 0]],
                            [[4, 0, 0, 0],
                             [8, 9, 0, 0],
                             [3, 4, 5, 0]]]))
  tests[-1, 1] = (np.array([[[2, 7, 3],
                             [1, 6, 2],
                             [5, 1, 0]],
                            [[5, 1, 6],
                             [4, 9, 5],
                             [8, 4, 0]]]),
                  np.array([[[1, 2, 0, 0],
                             [5, 6, 7, 0],
                             [0, 1, 2, 3]],
                            [[4, 5, 0, 0],
                             [8, 9, 1, 0],
                             [0, 4, 5, 6]]]))
  tests[0, 3] = (np.array([[[4, 0, 0],
                            [3, 8, 0],
                            [2, 7, 3],
                            [1, 6, 2]],
                           [[7, 0, 0],
                            [6, 2, 0],
                            [5, 1, 6],
                            [4, 9, 5]]]),
                 np.array([[[1, 2, 3, 4],
                            [0, 6, 7, 8],
                            [0, 0, 2, 3]],
                           [[4, 5, 6, 7],
                            [0, 9, 1, 2],
                            [0, 0, 5, 6]]]))
  # pyformat: enable
  return (mat, tests)


class MatrixDiagTest(xla_test.XLATestCase):

  def _assertOpOutputMatchesExpected(self,
                                     params,
                                     solution,
                                     rtol=1e-3,
                                     atol=1e-5):
    """Verifies that matrix_diag produces `solution` when fed `params`.

    Args:
      params: dictionary containing input parameters to matrix_diag.
      solution: numpy array representing the expected output of matrix_diag.
      rtol: relative tolerance for equality test.
      atol: absolute tolerance for equality test.
    """
    diagonal = params["diagonal"]
    with self.session() as session:
      for dtype in self.numeric_types - {np.int8, np.uint8}:
        expected = solution.astype(dtype)
        with self.test_scope():
          params["diagonal"] = array_ops.placeholder(
              dtype, diagonal.shape, name="diagonal")
          output = array_ops.matrix_diag(**params)
        result = session.run(output,
                             {params["diagonal"]: diagonal.astype(dtype)})
        self.assertEqual(output.dtype, expected.dtype)
        self.assertAllCloseAccordingToType(
            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)

  # Generic tests applicable to both v1 and v2 ops.
  # Originally from unary_ops_tests.py.
  def testV1(self):
    # pyformat: disable
    vecs1 = np.array([[1, 2],
                      [3, 4]])
    solution1 = np.array([[[1, 0], [0, 2]],
                          [[3, 0], [0, 4]]])
    vecs2 = np.array([1, 2, 3, 4])
    solution2 = np.array([[1, 0, 0, 0],
                          [0, 2, 0, 0],
                          [0, 0, 3, 0],
                          [0, 0, 0, 4]])
    vecs3 = np.array([[[1, 2, 3],
                       [4, 5, 6]],
                      [[7, 8, 9],  # pylint: disable=bad-whitespace
                       [10, 11, 12]]])
    solution3 = np.array([[[[1, 0, 0],
                            [0, 2, 0],
                            [0, 0, 3]],
                           [[4, 0, 0],
                            [0, 5, 0],
                            [0, 0, 6]]],
                          [[[7, 0, 0],
                            [0, 8, 0],
                            [0, 0, 9]],
                           [[10, 0, 0],
                            [0, 11, 0],
                            [0, 0, 12]]]])
    # pyformat: enable
    self._assertOpOutputMatchesExpected({"diagonal": vecs1}, solution1)
    self._assertOpOutputMatchesExpected({"diagonal": vecs2}, solution2)
    self._assertOpOutputMatchesExpected({"diagonal": vecs3}, solution3)

  # From here onwards are v2-only tests.
  def testSquare(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for _, tests in [square_cases()]:
        for diag_index, (vecs, solution) in tests.items():
          self._assertOpOutputMatchesExpected(
              {
                  "diagonal": vecs[0],
                  "k": diag_index
              }, solution[0])

  def testSquareBatch(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for _, tests in [square_cases()]:
        for diag_index, (vecs, solution) in tests.items():
          self._assertOpOutputMatchesExpected(
              {
                  "diagonal": vecs,
                  "k": diag_index
              }, solution)

  def testRectangularBatch(self):
    # LINT.IfChange
    if not compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      return

    # Stores expected num_rows and num_cols (when the other is given).
    # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
    test_list = list()

    # Square cases:
    expected = {
        (-1, -1): (5, 4),
        (-4, -3): (5, 2),
        (-2, 1): (5, 5),
        (2, 4): (3, 5),
    }
    test_list.append((expected, square_cases()))

    # Tall cases
    expected = {
        (0, 0): (3, 3),
        (-4, -3): (5, 2),
        (-2, -1): (4, 3),
        (-2, 1): (3, 3),
        (1, 2): (2, 3)
    }
    test_list.append((expected, tall_cases()))

    # Fat cases
    expected = {
        (2, 2): (2, 4),
        (-2, 0): (3, 3),
        (-1, 1): (3, 3),
        (0, 3): (3, 3)
    }
    test_list.append((expected, fat_cases()))

    # Giving both num_rows and num_cols
    for _, tests in [tall_cases(), fat_cases()]:
      for diag_index, (vecs, solution) in tests.items():
        self._assertOpOutputMatchesExpected(
            {
                "diagonal": vecs,
                "k": diag_index,
                "num_rows": solution.shape[-2],
                "num_cols": solution.shape[-1]
            }, solution)

    # Giving just num_rows or num_cols.
    for expected, (_, tests) in test_list:
      for diag_index, (new_num_rows, new_num_cols) in expected.items():
        vecs, solution = tests[diag_index]
        solution_given_num_rows = solution.take(
            indices=range(new_num_cols), axis=-1)
        self._assertOpOutputMatchesExpected(
            {
                "diagonal": vecs,
                "k": diag_index,
                "num_rows": solution_given_num_rows.shape[-2]
            }, solution_given_num_rows)
        solution_given_num_cols = solution.take(
            indices=range(new_num_rows), axis=-2)
        self._assertOpOutputMatchesExpected(
            {
                "diagonal": vecs,
                "k": diag_index,
                "num_cols": solution_given_num_cols.shape[-1]
            }, solution_given_num_cols)

  def testPadding(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for padding_value in [555, -11]:
        for _, tests in [square_cases(), tall_cases(), fat_cases()]:
          for diag_index, (vecs, solution) in tests.items():
            mask = (solution == 0)
            solution = solution + (mask * padding_value)
            self._assertOpOutputMatchesExpected(
                {
                    "diagonal": vecs,
                    "k": diag_index,
                    "num_rows": solution.shape[-2],
                    "num_cols": solution.shape[-1],
                    "padding_value": padding_value
                }, solution)


class MatrixSetDiagTest(xla_test.XLATestCase):

  def _assertOpOutputMatchesExpected(self,
                                     params,
                                     solution,
                                     rtol=1e-3,
                                     atol=1e-5):
    """Verifies that matrix_set_diag produces `solution` when fed `params`.

    Args:
      params: dictionary containing input parameters to matrix_set_diag.
      solution: numpy array representing the expected output of matrix_set_diag.
      rtol: relative tolerance for equality test.
      atol: absolute tolerance for equality test.
    """
    input = params["input"]  # pylint: disable=redefined-builtin
    diagonal = params["diagonal"]
    with self.session() as session:
      for dtype in self.numeric_types - {np.int8, np.uint8}:
        expected = solution.astype(dtype)
        with self.test_scope():
          params["input"] = array_ops.placeholder(
              dtype, input.shape, name="input")
          params["diagonal"] = array_ops.placeholder(
              dtype, diagonal.shape, name="diagonal")
          output = array_ops.matrix_set_diag(**params)
        result = session.run(
            output, {
                params["input"]: input.astype(dtype),
                params["diagonal"]: diagonal.astype(dtype)
            })
        self.assertEqual(output.dtype, expected.dtype)
        self.assertAllCloseAccordingToType(
            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)

  # Generic tests applicable to both v1 and v2 ops.
  # Originally from binary_ops_tests.py.
  def testV1(self):
    test_cases = list()

    # pyformat: disable
    # pylint: disable=bad-whitespace
    # Square cases.
    input = np.array([[0, 1, 0],  # pylint: disable=redefined-builtin
                      [1, 0, 1],
                      [1, 1, 1]])
    diag = np.array([1, 2, 3])
    solution = np.array([[1, 1, 0],
                         [1, 2, 1],
                         [1, 1, 3]])
    test_cases.append(({"input": input, "diagonal": diag}, solution))

    input = np.array([[[1, 0, 3],
                       [0, 2, 0],
                       [1, 0, 3]],
                      [[4, 0, 4],
                       [0, 5, 0],
                       [2, 0, 6]]])
    diag = np.array([[-1, 0, -3],
                     [-4, -5, -6]])
    solution = np.array([[[-1,  0,  3],
                          [ 0,  0,  0],
                          [ 1,  0, -3]],
                         [[-4,  0,  4],
                          [ 0, -5,  0],
                          [ 2,  0, -6]]])
    test_cases.append(({"input": input, "diagonal": diag}, solution))

    # Rectangular cases.
    input = np.array([[0, 1, 0],
                      [1, 0, 1]])
    diag = np.array([3, 4])
    solution = np.array([[3, 1, 0],
                         [1, 4, 1]])
    test_cases.append(({"input": input, "diagonal": diag}, solution))

    input = np.array([[0, 1],
                      [1, 0],
                      [1, 1]])
    diag = np.array([3, 4])
    solution = np.array([[3, 1],
                         [1, 4],
                         [1, 1]])
    test_cases.append(({"input": input, "diagonal": diag}, solution))

    input = np.array([[[1, 0, 3],
                       [0, 2, 0]],
                      [[4, 0, 4],
                       [0, 5, 0]]])
    diag = np.array([[-1, -2], [-4, -5]])
    solution = np.array([[[-1,  0,  3],
                          [ 0, -2,  0]],
                         [[-4,  0,  4],
                          [ 0, -5,  0]]])
    test_cases.append(({"input": input, "diagonal": diag}, solution))
    # pylint: enable=bad-whitespace
    # pyformat: enable

    for test in test_cases:
      self._assertOpOutputMatchesExpected(test[0], test[1])

  # From here onwards are v2-only tests.
  def testSingleMatrix(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for _, tests in [square_cases(), tall_cases(), fat_cases()]:
        for diag_index, (vecs, banded_mat) in tests.items():
          mask = (banded_mat[0] == 0)
          input_mat = np.random.randint(10, size=mask.shape)
          solution = input_mat * mask + banded_mat[0]
          self._assertOpOutputMatchesExpected(
              {
                  "input": input_mat,
                  "diagonal": vecs[0],
                  "k": diag_index
              }, solution)

  def testBatch(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for _, tests in [square_cases(), tall_cases(), fat_cases()]:
        for diag_index, (vecs, banded_mat) in tests.items():
          mask = (banded_mat == 0)
          input_mat = np.random.randint(10, size=mask.shape)
          solution = input_mat * mask + banded_mat
          self._assertOpOutputMatchesExpected(
              {
                  "input": input_mat,
                  "diagonal": vecs,
                  "k": diag_index
              }, solution)


class MatrixDiagPartTest(xla_test.XLATestCase):

  def _assertOpOutputMatchesExpected(self,
                                     params,
                                     solution,
                                     rtol=1e-3,
                                     atol=1e-5):
    """Verifies that matrix_diag_part produces `solution` when fed `params`.

    Args:
      params: dictionary containing input parameters to matrix_diag_part.
      solution: numpy array representing the expected output.
      rtol: relative tolerance for equality test.
      atol: absolute tolerance for equality test.
    """
    input = params["input"]  # pylint: disable=redefined-builtin
    with self.session() as session:
      for dtype in self.numeric_types - {np.int8, np.uint8}:
        expected = solution.astype(dtype)
        with self.test_scope():
          params["input"] = array_ops.placeholder(
              dtype, input.shape, name="input")
          output = array_ops.matrix_diag_part(**params)
        result = session.run(output, {
            params["input"]: input.astype(dtype),
        })
        self.assertEqual(output.dtype, expected.dtype)
        self.assertAllCloseAccordingToType(
            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)

  # Generic tests applicable to both v1 and v2 ops.
  # Originally from unary_ops_tests.py.
  def testV1(self):
    matrices = np.arange(3 * 2 * 4).reshape([3, 2, 4])
    solution = np.array([[0, 5], [8, 13], [16, 21]])
    self._assertOpOutputMatchesExpected({"input": matrices}, solution)

  # From here onwards are v2-only tests.
  def testSingleMatrix(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
        for diag_index, (solution, _) in tests.items():
          self._assertOpOutputMatchesExpected({
              "input": mat[0],
              "k": diag_index
          }, solution[0])

  def testBatch(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
        for diag_index, (solution, _) in tests.items():
          self._assertOpOutputMatchesExpected({
              "input": mat,
              "k": diag_index
          }, solution)

  def testPadding(self):
    # LINT.IfChange
    if compat.forward_compatible(2019, 7, 31):
      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
      for padding_value in [555, -11]:
        for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
          for diag_index, (solution, _) in tests.items():
            mask = (solution == 0)
            solution = solution + (mask * padding_value)
            self._assertOpOutputMatchesExpected(
                {
                    "input": mat,
                    "k": diag_index,
                    "padding_value": padding_value
                }, solution)


if __name__ == "__main__":
  googletest.main()
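The helper functions above pair each band (d_lower, d_upper) with both a compact stack of diagonals and the equivalent zero-padded matrix. For intuition, the following NumPy sketch (the helper name is ours, not part of the test file) rebuilds the padded form from the compact one under the same convention: row 0 of the compact array holds the d_upper diagonal, and shorter diagonals are zero-padded at the end.

import numpy as np

def banded_from_compact(compact, d_lower, d_upper, num_rows, num_cols):
  # Scatters compactly stored diagonals back into a zero-padded matrix.
  banded = np.zeros((num_rows, num_cols), dtype=compact.dtype)
  for i, k in enumerate(range(d_upper, d_lower - 1, -1)):
    diag_len = min(num_rows + min(k, 0), num_cols - max(k, 0))
    rows = np.arange(diag_len) - min(k, 0)
    cols = np.arange(diag_len) + max(k, 0)
    banded[rows, cols] = compact[i, :diag_len]
  return banded

# E.g., checked against the square cases above:
_, tests = square_cases()
compact, padded = tests[-2, 1]
assert np.array_equal(banded_from_compact(compact[0], -2, 1, 5, 5), padded[0])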
@@ -27,7 +27,6 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -108,31 +107,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
        np.array([[-1, 1]], dtype=dtype),
        expected=np.array([[-1, 1]], dtype=dtype))

    # TODO(penporn): Once XLA supports MatrixDiagV2, change the call to
    # gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2).
    self._assertOpOutputMatchesExpected(
        gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
        np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
    self._assertOpOutputMatchesExpected(
        gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
        np.array(
            [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
            dtype=dtype))
    self._assertOpOutputMatchesExpected(
        gen_array_ops.matrix_diag,
        np.array(
            [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
        np.array(
            [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [
                0, 0, 6
            ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0],
                  [0, 0, 12]]]],
            dtype=dtype))
    self._assertOpOutputMatchesExpected(
        gen_array_ops.matrix_diag_part,
        np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
        np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))

    self._assertOpOutputMatchesExpected(
        array_ops.prevent_gradient,
        np.array([[-1, 1]], dtype=dtype),
@@ -55,8 +55,8 @@ tf_kernel_library(
        "lrn_ops.cc",
        "matmul_op.cc",
        "matrix_band_part_op.cc",
        "matrix_diag_ops.cc",
        "matrix_inverse_op.cc",
        "matrix_set_diag_op.cc",
        "matrix_triangular_solve_op.cc",
        "mirror_pad_op.cc",
        "next_after_op.cc",
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
@@ -153,52 +155,5 @@ class DiagPartOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument("MatrixDiag op must have at an input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    int last_dim = dims.size() - 1;
    int64 last_dim_size = input_shape.dim_size(last_dim);
    absl::Span<const int64> other_dims(dims);
    other_dims.remove_suffix(1);

    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims);
    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    OP_REQUIRES(ctx, 2 <= dims.size(),
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::XlaOp input = ctx->Input(0);
    ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
  }
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);

}  // namespace
}  // namespace tensorflow
tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc (new file, 425 lines)
@@ -0,0 +1,425 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"  // for absl::c_iota, used below
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

// Reads or infers lower_diag_index and upper_diag_index from the kernel's
// input parameter "k". Also validates their values.
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
  int64 lower_diag_index = 0;
  int64 upper_diag_index = 0;
  TensorShape diag_index_shape = context->InputShape("k");

  // Wrapping OP_REQUIRES* macros with a function because they can "return;"
  // early (without values), which contradicts ProcessDiagIndex's signature.
  auto validate_diag_indices = [&]() {
    if (diag_index_shape.dims() == 0) {
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar("k", &lower_diag_index));
      upper_diag_index = lower_diag_index;
    } else {
      std::vector<int64> diag_index;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntVector("k", &diag_index));
      OP_REQUIRES(
          context, !diag_index.empty() && diag_index.size() <= 2,
          errors::InvalidArgument(
              "diag_index must have only one or two elements, received ",
              diag_index.size(), " elements."));
      lower_diag_index = diag_index[0];
      upper_diag_index =
          (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
    }
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
  };
  validate_diag_indices();
  return {lower_diag_index, upper_diag_index};
}

// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64 lower_diag_index,
                                           const int64 upper_diag_index,
                                           const int64 num_rows,
                                           const int64 num_cols) {
  // The `lower_diag_index == 0` condition is added to handle matrix shape = 0.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bound: ", lower_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bound: ", upper_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}

// Kernel to set matrix diagonals.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape, const int64 diag_rank,
                         const int64 num_diags, const int64 lower_diag_index,
                         const int64 upper_diag_index, const int64 max_diag_len,
                         const int64 num_rows, const int64 num_cols) {
  // Creates a padding config.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[2, 3, 0],     k = (-1, 1), and num_rows = 4.
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[4, 2, 0],
  //                         [7, 5, 3],
  //                         [0, 8, 6],
  //                         [0, 0, 9]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [2, 3, 0].
  // 2) Padding the vector to be as long as num_rows,
  //      diag_slice = [2, 3, 0, 0],
  //    then broadcasting diag_slice row-wise to a full matrix,
  //      diag_broadcast = [[2, 2, 2],
  //                        [3, 3, 3],
  //                        [0, 0, 0],
  //                        [0, 0, 0]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 1, 0],    -->  output = [[x, 2, x],
  //              [0, 0, 1],                   [x, x, 3],
  //              [0, 0, 0],                   [x, x, x],
  //              [0, 0, 0]]                   [x, x, x]],
  //    where x denotes the existing input contents.
  std::vector<int64> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal.
    auto diag_slice = diag;
    if (num_diags > 1) {
      const int64 mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pads if necessary. Always pad at the end because shorter diagonals in
    // the input come padded at the end.
    const int64 padding_length =
        ((diag_index <= 0) ? num_cols : num_rows) - max_diag_len;
    const xla::XlaOp zero = xla::ScalarLike(input, 0);
    if (padding_length > 0) {
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_length);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcasts column-wise for subdiagonals; row-wise for superdiagonals.
    broadcast_dimensions.back() =
        (diag_index <= 0) ? input_rank - 1 : input_rank - 2;
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}

}  // namespace

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= 1,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument("Expected >= 1 dims, got shape ",
                                        diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    int64 num_rows = -1;
    int64 num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > 1) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64 diag_rank = diag_shape.dims();
    const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    const int64 min_num_rows = max_diag_len - std::min(upper_diag_index, 0LL);
    const int64 min_num_cols = max_diag_len + std::max(lower_diag_index, 0LL);
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the
    // output is square. Otherwise, use the smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols));
  }
};

REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagPartV2.
    if (context->num_inputs() > 1) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape.
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, 0LL),
                 num_cols - std::max(lower_diag_index, 0LL));
    output_shape.AddDim(max_diag_len);

    // Computes output.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config;
    if (num_diags == 1) {
      context->SetOutput(0, xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    padding_config = xla::MakeNoPaddingConfig(input_rank - 1);
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      auto single_diag = xla::GetMatrixDiagonal(input, diag_index);
      const int64 diag_length =
          (diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index);
      const int64 padding_length = max_diag_len - diag_length;
      if (padding_length > 0) {
        padding_config.mutable_dimensions(input_rank - 2)
            ->set_edge_padding_high(padding_length);
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > 2) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, 0LL),
                 num_cols - std::max(lower_diag_index, 0LL));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longest diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);

}  // namespace tensorflow
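For intuition, the SetMatrixDiag loop above (extract one diagonal, pad it at the end, broadcast it across the matrix, then select it through a diagonal mask) behaves like the following NumPy sketch for a single 2-D matrix. The function name is ours, not part of the kernel; it is an illustration of the technique, not the XLA implementation:

import numpy as np

def set_matrix_diag_sketch(mat, diag, d_lower, d_upper):
  # NumPy analogue of the broadcast-and-mask loop in SetMatrixDiag.
  num_rows, num_cols = mat.shape
  max_diag_len = diag.shape[-1]
  rows, cols = np.indices((num_rows, num_cols))
  out = mat.copy()
  for k in range(d_lower, d_upper + 1):
    # 1) Extract one diagonal; row 0 of `diag` holds the d_upper diagonal.
    diag_slice = diag[d_upper - k] if diag.ndim > 1 else diag
    # 2) Pad at the end, then broadcast column-wise for subdiagonals and
    #    row-wise for superdiagonals.
    target_len = num_cols if k <= 0 else num_rows
    diag_slice = np.pad(diag_slice, (0, target_len - max_diag_len))
    if k <= 0:
      broadcast = np.broadcast_to(diag_slice, (num_rows, num_cols))
    else:
      broadcast = np.broadcast_to(diag_slice[:, None], (num_rows, num_cols))
    # 3) Select through a mask that is true only on diagonal k.
    out = np.where(cols - rows == k, broadcast, out)
  return out

# Reproduces the worked example from the kernel comment:
diag = np.array([[2, 3, 0], [4, 5, 6], [7, 8, 9]])
print(set_matrix_diag_sketch(np.zeros((4, 3), dtype=int), diag, -1, 1))
# [[4 2 0]
#  [7 5 3]
#  [0 8 6]
#  [0 0 9]]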
@@ -1,98 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"

namespace tensorflow {

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);

    const int rank = input_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    // Check to make sure the last dimension of diag is equal to the smaller of
    // the last two dimensions of input.
    const int64 m = input_shape.dim_size(rank - 2);
    const int64 n = input_shape.dim_size(rank - 1);
    const int64 min_dim = std::min(m, n);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);

    TensorShape expected_diag_shape = batch_shape;
    expected_diag_shape.AddDim(min_dim);
    OP_REQUIRES(context, expected_diag_shape == diag_shape,
                errors::InvalidArgument(
                    "must have diagonal.shape == input.shape[:-2] + "
                    "min(input.shape[-2:]), but received input shape: ",
                    input_shape.DebugString(),
                    " and diagonal shape: ", diag_shape.DebugString()));

    xla::XlaBuilder* builder = context->builder();
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);

    auto zero = XlaHelpers::Zero(builder, context->input_type(0));

    // Create an indicator tensor that is true only on the diagonal.
    xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
    xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
    auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
                             /*broadcast_dimensions=*/{0});
    indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());

    // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or)
    // because we need to broadcast on the right.
    std::vector<int64> diag_broadcast_dims(rank - 1);
    std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
    if (min_dim != m) {
      diag_broadcast_dims.back() = rank - 1;
    }
    if (context->input_xla_type(0) == xla::PRED) {
      diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
                     /*broadcast_dimensions=*/diag_broadcast_dims);
    } else {
      diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
                      /*broadcast_dimensions=*/diag_broadcast_dims);
    }

    auto output = xla::Select(indicator, diag, input);
    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);

}  // namespace tensorflow
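For reference, the deleted kernel above built its main-diagonal indicator by comparing two iotas. The same trick in NumPy terms (illustrative only; `indicator` mirrors the result of xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), {0})):

import numpy as np

m, n = 3, 4
# indicator[i, j] is True exactly where i == j, i.e. on the main diagonal.
indicator = np.equal(np.arange(m)[:, None], np.arange(n)[None, :])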