[XLA:CLIENT] Add inverse of GetMatrixDiagonal and simplify the slicing logic

PiperOrigin-RevId: 270331448
This commit is contained in:
Blake Hechtman 2019-09-20 12:57:14 -07:00 committed by TensorFlower Gardener
parent a62a0fa255
commit 316a882856
4 changed files with 112 additions and 29 deletions

View File

@ -176,6 +176,7 @@ cc_library(
":arithmetic",
":constants",
":slicing",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@ -195,6 +196,7 @@ xla_test(
name = "matrix_test",
srcs = ["matrix_test.cc"],
deps = [
":constants",
":matrix",
":slicing",
"//tensorflow/compiler/xla:status",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -73,27 +74,67 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) {
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
if (k <= -m || k >= n) {
auto zero_size_shape = shape;
zero_size_shape.DeleteDimension(n_dims - 1);
zero_size_shape.set_dimensions(n_dims - 2, 0);
return ConstantLiteral(builder, Literal{zero_size_shape});
}
auto mask = GetDiagonalMask(x, k);
// TPUs don't support S64 add reduction at the moment. But fortunately
// OR-reductions work just as well for integers.
XlaComputation reducer =
CreateScalarIdentityWithZeroComputation(shape.element_type(), builder);
int64 reduce_dim = n_dims - 1;
if ((k == 0 && m >= n) || k < 0) {
reduce_dim = n_dims - 2;
}
auto result = Reduce(
Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
{reduce_dim});
// k == 0, we can save one slice op.
if (k == 0) {
return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
reducer, {m >= n ? n_dims - 2 : n_dims - 1});
} else if (k > 0) {
auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
ScalarLike(x, 0), reducer, {n_dims - 2});
return SliceInMinorDims(result, {std::min<int64>(k, n)},
{std::min(m + k, n)});
} else {
auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
ScalarLike(x, 0), reducer, {n_dims - 1});
return SliceInMinorDims(result, {std::min<int64>(-k, m)},
{std::min(m, n - k)});
return result;
}
return SliceInMinorDims(result, {0},
{k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
});
}
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
XlaBuilder* builder = matrix.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
auto n_dims = static_cast<int32>(shape.rank());
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
const int64 d = diag_shape.dimensions(n_dims - 2);
std::vector<int64> broadcast_dims(n_dims - 1);
absl::c_iota(broadcast_dims, 0);
int64 pad_high = m - d;
if (k < 0) {
++(broadcast_dims.back());
pad_high = n - d;
}
if (pad_high != 0) {
PaddingConfig padding_config;
for (xla::int64 i = 0; i < diag_shape.rank() - 1; ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(0);
dims->set_interior_padding(0);
dims->set_edge_padding_high(0);
}
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(0);
dims->set_interior_padding(0);
dims->set_edge_padding_high(pad_high);
diag = Pad(diag, ScalarLike(diag, 0), padding_config);
}
return Select(GetDiagonalMask(matrix, k),
BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
matrix);
});
}

View File

@ -45,6 +45,9 @@ XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
// diagonal elements (i.e., with indices [..., i - k, i]).
XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
// Places diag along the kth diagonal of target.
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal
// and false above that diagonal.
XlaOp TriangleMask(XlaOp x, int diagonal);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
@ -32,6 +33,23 @@ class MatrixTest : public ClientLibraryTestBase {
protected:
template <typename T>
void TestMatrixDiagonal();
template <typename T>
void TestSetMatrixDiagonal();
template <typename T>
std::map<int, Array2D<T>> k_and_expected() const {
return std::map<int, Array2D<T>>{
{0, {{0, 5, 10}, {12, 17, 22}}},
{1, {{1, 6, 11}, {13, 18, 23}}},
{2, {{2, 7}, {14, 19}}},
{3, {{3}, {15}}},
{4, {{}, {}}},
{-1, {{4, 9}, {16, 21}}},
{-2, {{8}, {20}}},
{-3, {{}, {}}},
{-4, {{}, {}}},
};
}
};
XLA_TEST_F(MatrixTest, Triangle) {
@ -50,21 +68,10 @@ XLA_TEST_F(MatrixTest, Triangle) {
template <typename T>
void MatrixTest::TestMatrixDiagonal() {
XlaBuilder builder("GetMatrixDiagonal");
XlaBuilder builder("SetMatrixDiagonal");
Array3D<T> input(2, 3, 4);
input.FillIota(0);
std::map<int, Array2D<T>> k_and_expected = {
{0, {{0, 5, 10}, {12, 17, 22}}},
{1, {{1, 6, 11}, {13, 18, 23}}},
{2, {{2, 7}, {14, 19}}},
{3, {{3}, {15}}},
{4, {{}, {}}},
{-1, {{4, 9}, {16, 21}}},
{-2, {{8}, {20}}},
{-3, {{}, {}}},
{-4, {{}, {}}},
};
for (const auto& kv : k_and_expected) {
for (const auto& kv : k_and_expected<T>()) {
XlaOp a;
auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
GetMatrixDiagonal(a, kv.first);
@ -73,6 +80,36 @@ void MatrixTest::TestMatrixDiagonal() {
}
}
template <typename T>
void MatrixTest::TestSetMatrixDiagonal() {
XlaBuilder builder("GetMatrixDiagonal");
Array3D<T> input(2, 3, 4);
input.FillIota(0);
for (const auto& kv : k_and_expected<T>()) {
XlaOp a;
XlaOp b;
auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
auto new_diag =
CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
kv.first) -
ScalarLike(b, 1);
ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
}
}
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
TestSetMatrixDiagonal<int32>();
}
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
TestSetMatrixDiagonal<int64>();
}
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
TestSetMatrixDiagonal<float>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }