[XLA] [TF:XLA] Move utilities used by linear algebra routines into XLA client library.

Move slicing utilities into a new xla/client/lib/slicing.h.
Move MaybeConjugate into xla/client/lib/math.h.
Move *TransposeInMinorDims into xla/client/lib/matrix.h.\

CL in preparation for moving linear algebra routines into XLA client library.

PiperOrigin-RevId: 224036090
This commit is contained in:
Peter Hawkins 2018-12-04 13:30:12 -08:00 committed by TensorFlower Gardener
parent 1c2c0c9d09
commit 2efc17d7f7
16 changed files with 320 additions and 245 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
namespace tensorflow {

View File

@ -49,6 +49,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
],
)
@ -85,6 +86,7 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
],
)
@ -124,7 +126,9 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
],
)
@ -171,29 +175,6 @@ cc_library(
],
)
xla_test(
name = "util_test",
srcs = ["util_test.cc"],
deps = [
":util",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

View File

@ -20,7 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"

View File

@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
return xla::ConstantLiteral(builder, literal);
}
xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
absl::Span<const int64> end) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_RET_CHECK(start.size() == end.size());
int64 n_minor_dims = start.size();
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
auto major_dims = xla::AsInt64Slice(shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - n_minor_dims);
// Prepends 0s in the major dim
std::vector<int64> padded_start(n_dims, 0);
std::copy(start.begin(), start.end(),
padded_start.begin() + major_dims.size());
// Prepends the shape of the major dims.
std::vector<int64> padded_end(n_dims);
std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
std::vector<int64> strides(n_dims, 1);
return xla::Slice(x, padded_start, padded_end, strides);
});
}
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
absl::Span<const int64> ys) {
@ -152,104 +122,4 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
return output;
}
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
absl::Span<const xla::XlaOp> starts,
absl::Span<const int64> sizes) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
auto major_dims = xla::AsInt64Slice(shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - sizes.size());
auto padded_starts = PrependZerosInMajorDims(x, starts);
auto padded_sizes = ConcatVectors(major_dims, sizes);
return xla::DynamicSlice(x, padded_starts, padded_sizes);
});
}
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
xla::ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
return xla::DynamicUpdateSlice(x, update, start_constant);
});
}
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
std::copy(start.begin(), start.end(),
padded_start.begin() + (n_dims - n_minor_dims));
return UpdateSlice(x, update, padded_start);
});
}
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
absl::Span<const xla::XlaOp> starts) {
auto padded_starts = PrependZerosInMajorDims(x, starts);
return xla::DynamicUpdateSlice(x, update, padded_starts);
}
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
absl::Span<const xla::XlaOp> starts) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
std::vector<xla::XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
}
return xla::ConcatInDim(builder, padded_starts, 0);
});
}
xla::XlaOp TransposeInMinorDims(xla::XlaOp x) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
return xla::Transpose(x, permutation);
});
}
xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose) {
return transpose ? TransposeInMinorDims(x) : x;
}
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
auto perform_conj = shape.element_type() == xla::C64 && conjugate;
return perform_conj ? xla::Conj(x) : x;
});
}
} // namespace tensorflow

View File

@ -38,48 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
int64 value);
// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
absl::Span<const xla::XlaOp> starts);
// Performs a slice in the minor dimensions of a Tensor.
xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
absl::Span<const int64> end);
// Returns the concatenation of `xs` and `ys`.
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
absl::Span<const int64> ys);
// Performs a dynamic slice in the minor dimensions of a Tensor.
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
absl::Span<const xla::XlaOp> starts,
absl::Span<const int64> sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
absl::Span<const int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
absl::Span<const int64> start);
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
absl::Span<const xla::XlaOp> starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
// Transposes `x` in its minor dimensions if `transpose` is true, otherwise
// returns `x` unchanged.
xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_

View File

@ -127,6 +127,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
":matrix",
":slicing",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@ -176,6 +177,38 @@ cc_library(
],
)
cc_library(
name = "slicing",
srcs = ["slicing.cc"],
hdrs = ["slicing.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"@com_google_absl//absl/types:span",
],
)
xla_test(
name = "slicing_test",
srcs = ["slicing_test.cc"],
tags = ["enable_for_xla_interpreter"],
deps = [
":slicing",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "sorting",
srcs = ["sorting.cc"],

View File

@ -268,17 +268,16 @@ XlaOp Digamma(XlaOp input) {
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
XlaOp RoundToEven(XlaOp x) {
auto half = xla::ScalarLike(x, 0.5);
auto one = xla::ScalarLike(x, 1.0);
auto two = xla::ScalarLike(x, 2.0);
auto half = ScalarLike(x, 0.5);
auto one = ScalarLike(x, 1.0);
auto two = ScalarLike(x, 2.0);
auto round_val = xla::Floor(x);
auto round_val = Floor(x);
auto fraction = x - round_val;
auto nearest_even_int = round_val - two * xla::Floor(half * x);
auto is_odd = xla::Eq(nearest_even_int, one);
return xla::Select(xla::Or(xla::Gt(fraction, half),
xla::And(xla::Eq(fraction, half), is_odd)),
round_val + one, round_val);
auto nearest_even_int = round_val - two * Floor(half * x);
auto is_odd = Eq(nearest_even_int, one);
return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
round_val + one, round_val);
}
// Trigonometric functions.
@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }
XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
auto perform_conj = shape.element_type() == C64 && conjugate;
return perform_conj ? Conj(x) : x;
});
}
} // namespace xla

View File

@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x);
// Computes the hyperbolic sine of 'x'.
XlaOp Sinh(XlaOp x);
// Applies a complex conjugation operation if `a` is complex and `conjugate`
// is true, otherwise returns its argument.
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_

View File

@ -166,4 +166,20 @@ XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
});
}
XlaOp TransposeInMinorDims(XlaOp x) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
return Transpose(x, permutation);
});
}
XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
return transpose ? TransposeInMinorDims(x) : x;
}
} // namespace xla

View File

@ -60,6 +60,14 @@ XlaOp LowerTriangle(XlaOp x);
xla::XlaOp BatchDot(
xla::XlaOp x, xla::XlaOp y,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
// Transposes `x` in its minor dimensions if `transpose` is true, otherwise
// returns `x` unchanged.
xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@ -25,13 +26,13 @@ limitations under the License.
namespace xla {
namespace {
class NumericTest : public ClientLibraryTestBase {
class MatrixTest : public ClientLibraryTestBase {
protected:
template <typename T>
void TestMatrixDiagonal();
};
XLA_TEST_F(NumericTest, Triangle) {
XLA_TEST_F(MatrixTest, Triangle) {
XlaBuilder builder(TestName());
Array3D<int32> input(2, 3, 4);
input.FillIota(0);
@ -46,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) {
}
template <typename T>
void NumericTest::TestMatrixDiagonal() {
void MatrixTest::TestMatrixDiagonal() {
XlaBuilder builder("GetMatrixDiagonal");
Array3D<T> input(2, 3, 4);
input.FillIota(0);
@ -59,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() {
ComputeAndCompareR2<T>(&builder, expected, {a_data.get()});
}
XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }
XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal<float>(); }
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal<float>(); }
Array3D<float> BatchedAValsFull() {
return {{
{2, 0, 1, 2},
{3, 6, 0, 1},
{4, 7, 9, 0},
{5, 8, 10, 11},
},
{
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
}};
}
XLA_TEST_F(MatrixTest, RowBatchDot) {
XlaBuilder builder(TestName());
int n = 4;
XlaOp a, row, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
"row", &builder, &row);
// Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
auto l_index = DynamicSliceInMinorDims(
a, {index, ConstantR0<int32>(&builder, 0)}, {1, n});
BatchDot(l_index, TransposeInMinorDims(row));
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,134 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/slicing.h"
namespace xla {
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
absl::Span<const int64> end) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RET_CHECK(start.size() == end.size());
int64 n_minor_dims = start.size();
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
auto major_dims = AsInt64Slice(shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - n_minor_dims);
// Prepends 0s in the major dim
std::vector<int64> padded_start(n_dims, 0);
std::copy(start.begin(), start.end(),
padded_start.begin() + major_dims.size());
// Prepends the shape of the major dims.
std::vector<int64> padded_end(n_dims);
std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
std::vector<int64> strides(n_dims, 1);
return Slice(x, padded_start, padded_end, strides);
});
}
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = ConstantR1<int32>(builder, start_as_int32);
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
TF_ASSIGN_OR_RETURN(Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
return DynamicUpdateSlice(x, update, start_constant);
});
}
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
absl::Span<const int64> start) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
std::copy(start.begin(), start.end(),
padded_start.begin() + (n_dims - n_minor_dims));
return UpdateSlice(x, update, padded_start);
});
}
namespace {
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
absl::Span<const int64> ys) {
std::vector<int64> output(xs.size() + ys.size());
std::copy(xs.begin(), xs.end(), output.begin());
std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
return output;
}
XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span<const XlaOp> starts) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
auto zero = Reshape(ConstantR0<int32>(builder, 0), {1});
std::vector<XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1});
}
return ConcatInDim(builder, padded_starts, 0);
});
}
} // namespace
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
absl::Span<const int64> sizes) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = ShapeUtil::Rank(shape);
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
auto major_dims = AsInt64Slice(shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - sizes.size());
auto padded_starts = PrependZerosInMajorDims(x, starts);
auto padded_sizes = ConcatVectors(major_dims, sizes);
return DynamicSlice(x, padded_starts, padded_sizes);
});
}
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
absl::Span<const XlaOp> starts) {
auto padded_starts = PrependZerosInMajorDims(x, starts);
return DynamicUpdateSlice(x, update, padded_starts);
}
} // namespace xla

View File

@ -0,0 +1,48 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_
namespace xla {
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);
// Performs a slice in the minor dimensions of a tensor.
// x[..., start[0]:end[0], ..., start[n]:end[n]]
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
absl::Span<const int64> end);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0]:..., ..., start[n]:...] = update
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
absl::Span<const int64> start);
// Performs a dynamic slice in the minor dimensions of a tensor.
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
absl::Span<const int64> sizes);
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
absl::Span<const XlaOp> starts);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_

View File

@ -13,28 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include <memory>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace xla {
namespace {
using UtilTest = xla::ClientLibraryTestBase;
using UtilLeftLookingTest = xla::ClientLibraryTestBase;
using SlicingTest = xla::ClientLibraryTestBase;
xla::Array2D<float> BValsRight() {
return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
@ -63,7 +54,7 @@ xla::Array3D<float> BatchedAValsFull() {
}};
}
XLA_TEST_F(UtilTest, Simple2dLookup) {
XLA_TEST_F(SlicingTest, Simple2dLookup) {
xla::XlaBuilder builder(TestName());
xla::XlaOp a, x, y;
@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(UtilTest, Simple3dLookup) {
XLA_TEST_F(SlicingTest, Simple3dLookup) {
xla::XlaBuilder builder(TestName());
xla::XlaOp a, index;
@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
{a_data.get(), index_data.get()});
}
XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
xla::XlaBuilder builder(TestName());
xla::XlaOp a, b, x, y;
@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
{a_data.get(), b_data.get(), x_data.get(), y_data.get()});
}
XLA_TEST_F(UtilTest, RowBatchDot) {
xla::XlaBuilder builder(TestName());
int n = 4;
xla::XlaOp a, row, index;
auto a_data =
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
"row", &builder, &row);
// Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
auto l_index = DynamicSliceInMinorDims(
a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n});
BatchDot(l_index, TransposeInMinorDims(row));
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});
}
} // namespace
} // namespace tensorflow
} // namespace xla