[XLA:CLIENT] Add inverse of GetMatrixDiagonal and simplify the slicing logic
PiperOrigin-RevId: 270331448
This commit is contained in:
parent
a62a0fa255
commit
316a882856
@ -176,6 +176,7 @@ cc_library(
|
||||
":arithmetic",
|
||||
":constants",
|
||||
":slicing",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -195,6 +196,7 @@ xla_test(
|
||||
name = "matrix_test",
|
||||
srcs = ["matrix_test.cc"],
|
||||
deps = [
|
||||
":constants",
|
||||
":matrix",
|
||||
":slicing",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -73,27 +74,67 @@ XlaOp GetMatrixDiagonal(XlaOp x, int k) {
|
||||
const int64 m = shape.dimensions(n_dims - 2);
|
||||
const int64 n = shape.dimensions(n_dims - 1);
|
||||
|
||||
if (k <= -m || k >= n) {
|
||||
auto zero_size_shape = shape;
|
||||
zero_size_shape.DeleteDimension(n_dims - 1);
|
||||
zero_size_shape.set_dimensions(n_dims - 2, 0);
|
||||
return ConstantLiteral(builder, Literal{zero_size_shape});
|
||||
}
|
||||
auto mask = GetDiagonalMask(x, k);
|
||||
|
||||
// TPUs don't support S64 add reduction at the moment. But fortunately
|
||||
// OR-reductions work just as well for integers.
|
||||
XlaComputation reducer =
|
||||
CreateScalarIdentityWithZeroComputation(shape.element_type(), builder);
|
||||
int64 reduce_dim = n_dims - 1;
|
||||
if ((k == 0 && m >= n) || k < 0) {
|
||||
reduce_dim = n_dims - 2;
|
||||
}
|
||||
auto result = Reduce(
|
||||
Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
|
||||
CreateScalarIdentityWithZeroComputation(shape.element_type(), builder),
|
||||
{reduce_dim});
|
||||
// k == 0, we can save one slice op.
|
||||
if (k == 0) {
|
||||
return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
|
||||
reducer, {m >= n ? n_dims - 2 : n_dims - 1});
|
||||
} else if (k > 0) {
|
||||
auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
|
||||
ScalarLike(x, 0), reducer, {n_dims - 2});
|
||||
return SliceInMinorDims(result, {std::min<int64>(k, n)},
|
||||
{std::min(m + k, n)});
|
||||
} else {
|
||||
auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
|
||||
ScalarLike(x, 0), reducer, {n_dims - 1});
|
||||
return SliceInMinorDims(result, {std::min<int64>(-k, m)},
|
||||
{std::min(m, n - k)});
|
||||
return result;
|
||||
}
|
||||
return SliceInMinorDims(result, {0},
|
||||
{k > 0 ? std::min(m, n - k) : std::min(n, m + k)});
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) {
|
||||
XlaBuilder* builder = matrix.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(matrix));
|
||||
TF_ASSIGN_OR_RETURN(Shape diag_shape, builder->GetShape(diag));
|
||||
auto n_dims = static_cast<int32>(shape.rank());
|
||||
TF_RET_CHECK(n_dims >= 2);
|
||||
const int64 m = shape.dimensions(n_dims - 2);
|
||||
const int64 n = shape.dimensions(n_dims - 1);
|
||||
const int64 d = diag_shape.dimensions(n_dims - 2);
|
||||
std::vector<int64> broadcast_dims(n_dims - 1);
|
||||
absl::c_iota(broadcast_dims, 0);
|
||||
int64 pad_high = m - d;
|
||||
if (k < 0) {
|
||||
++(broadcast_dims.back());
|
||||
pad_high = n - d;
|
||||
}
|
||||
|
||||
if (pad_high != 0) {
|
||||
PaddingConfig padding_config;
|
||||
for (xla::int64 i = 0; i < diag_shape.rank() - 1; ++i) {
|
||||
auto* dims = padding_config.add_dimensions();
|
||||
dims->set_edge_padding_low(0);
|
||||
dims->set_interior_padding(0);
|
||||
dims->set_edge_padding_high(0);
|
||||
}
|
||||
auto* dims = padding_config.add_dimensions();
|
||||
dims->set_edge_padding_low(0);
|
||||
dims->set_interior_padding(0);
|
||||
dims->set_edge_padding_high(pad_high);
|
||||
diag = Pad(diag, ScalarLike(diag, 0), padding_config);
|
||||
}
|
||||
|
||||
return Select(GetDiagonalMask(matrix, k),
|
||||
BroadcastInDim(diag, shape.dimensions(), broadcast_dims),
|
||||
matrix);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,9 @@ XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
|
||||
// diagonal elements (i.e., with indices [..., i - k, i]).
|
||||
XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
|
||||
|
||||
// Places diag along the kth diagonal of target.
|
||||
XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
|
||||
|
||||
// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal
|
||||
// and false above that diagonal.
|
||||
XlaOp TriangleMask(XlaOp x, int diagonal);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
@ -32,6 +33,23 @@ class MatrixTest : public ClientLibraryTestBase {
|
||||
protected:
|
||||
template <typename T>
|
||||
void TestMatrixDiagonal();
|
||||
template <typename T>
|
||||
void TestSetMatrixDiagonal();
|
||||
|
||||
template <typename T>
|
||||
std::map<int, Array2D<T>> k_and_expected() const {
|
||||
return std::map<int, Array2D<T>>{
|
||||
{0, {{0, 5, 10}, {12, 17, 22}}},
|
||||
{1, {{1, 6, 11}, {13, 18, 23}}},
|
||||
{2, {{2, 7}, {14, 19}}},
|
||||
{3, {{3}, {15}}},
|
||||
{4, {{}, {}}},
|
||||
{-1, {{4, 9}, {16, 21}}},
|
||||
{-2, {{8}, {20}}},
|
||||
{-3, {{}, {}}},
|
||||
{-4, {{}, {}}},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
XLA_TEST_F(MatrixTest, Triangle) {
|
||||
@ -50,21 +68,10 @@ XLA_TEST_F(MatrixTest, Triangle) {
|
||||
|
||||
template <typename T>
|
||||
void MatrixTest::TestMatrixDiagonal() {
|
||||
XlaBuilder builder("GetMatrixDiagonal");
|
||||
XlaBuilder builder("SetMatrixDiagonal");
|
||||
Array3D<T> input(2, 3, 4);
|
||||
input.FillIota(0);
|
||||
std::map<int, Array2D<T>> k_and_expected = {
|
||||
{0, {{0, 5, 10}, {12, 17, 22}}},
|
||||
{1, {{1, 6, 11}, {13, 18, 23}}},
|
||||
{2, {{2, 7}, {14, 19}}},
|
||||
{3, {{3}, {15}}},
|
||||
{4, {{}, {}}},
|
||||
{-1, {{4, 9}, {16, 21}}},
|
||||
{-2, {{8}, {20}}},
|
||||
{-3, {{}, {}}},
|
||||
{-4, {{}, {}}},
|
||||
};
|
||||
for (const auto& kv : k_and_expected) {
|
||||
for (const auto& kv : k_and_expected<T>()) {
|
||||
XlaOp a;
|
||||
auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
|
||||
GetMatrixDiagonal(a, kv.first);
|
||||
@ -73,6 +80,36 @@ void MatrixTest::TestMatrixDiagonal() {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixTest::TestSetMatrixDiagonal() {
|
||||
XlaBuilder builder("GetMatrixDiagonal");
|
||||
Array3D<T> input(2, 3, 4);
|
||||
input.FillIota(0);
|
||||
for (const auto& kv : k_and_expected<T>()) {
|
||||
XlaOp a;
|
||||
XlaOp b;
|
||||
auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
|
||||
auto new_diag =
|
||||
CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
|
||||
|
||||
GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
|
||||
kv.first) -
|
||||
ScalarLike(b, 1);
|
||||
|
||||
ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
|
||||
TestSetMatrixDiagonal<int32>();
|
||||
}
|
||||
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
|
||||
TestSetMatrixDiagonal<int64>();
|
||||
}
|
||||
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
|
||||
TestSetMatrixDiagonal<float>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
|
||||
|
||||
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }
|
||||
|
Loading…
Reference in New Issue
Block a user