[XLA] [TF:XLA] Move utilities used by linear algebra routines into XLA client library.
Move slicing utilities into a new xla/client/lib/slicing.h. Move MaybeConjugate into xla/client/lib/math.h. Move *TransposeInMinorDims into xla/client/lib/matrix.h.

CL in preparation for moving linear algebra routines into the XLA client library.

PiperOrigin-RevId: 224036090
parent 1c2c0c9d09
commit 2efc17d7f7
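For orientation, here is a minimal sketch of what call sites look like once these helpers live in the XLA client library; the builder name, shapes, and the BuildExample wrapper are illustrative, not taken from this CL:

#include "tensorflow/compiler/xla/client/lib/math.h"     // MaybeConjugate
#include "tensorflow/compiler/xla/client/lib/matrix.h"   // TransposeInMinorDims
#include "tensorflow/compiler/xla/client/lib/slicing.h"  // SliceInMinorDims
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void BuildExample() {
  xla::XlaBuilder builder("example");
  // A rank-3 operand: a batch of two 3x4 matrices.
  auto x = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3, 4}), "x");
  // Slices x[:, 0:2, 0:2]; start/end address only the two minor dims.
  auto block = xla::SliceInMinorDims(x, /*start=*/{0, 0}, /*end=*/{2, 2});
  // Swaps the last two dimensions of every matrix in the batch.
  auto t = xla::TransposeInMinorDims(block);
  // A no-op here, since F32 is not complex.
  auto result = xla::MaybeConjugate(t, /*conjugate=*/true);
  (void)result;
}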
tensorflow/compiler
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"

namespace tensorflow {
@@ -49,6 +49,7 @@ cc_library(
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client/lib:constants",
        "//tensorflow/compiler/xla/client/lib:matrix",
        "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/core:lib",
    ],
)
@@ -85,6 +86,7 @@ cc_library(
        "//tensorflow/compiler/xla/client/lib:constants",
        "//tensorflow/compiler/xla/client/lib:math",
        "//tensorflow/compiler/xla/client/lib:matrix",
        "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/core:lib",
    ],
)
@@ -124,7 +126,9 @@ cc_library(
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/client/lib:constants",
        "//tensorflow/compiler/xla/client/lib:math",
        "//tensorflow/compiler/xla/client/lib:matrix",
        "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/core:lib",
    ],
)
@@ -171,29 +175,6 @@ cc_library(
    ],
)

xla_test(
    name = "util_test",
    srcs = ["util_test.cc"],
    deps = [
        ":util",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client/lib:matrix",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "while_loop",
    srcs = ["while_loop.cc"],
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -20,7 +20,9 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
  return xla::ConstantLiteral(builder, literal);
}

xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
                            absl::Span<const int64> end) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64 n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));

    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = xla::AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dim
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64> strides(n_dims, 1);
    return xla::Slice(x, padded_start, padded_end, strides);
  });
}

std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
@@ -152,104 +122,4 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
  return output;
}

xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
                                   absl::Span<const xla::XlaOp> starts,
                                   absl::Span<const int64> sizes) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    int64 n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = xla::AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - sizes.size());
    auto padded_starts = PrependZerosInMajorDims(x, starts);
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return xla::DynamicSlice(x, padded_starts, padded_sizes);
  });
}

xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
                       absl::Span<const int64> start) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
    std::vector<int32> start_as_int32(start.begin(), start.end());
    auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
                        builder->GetShape(start_constant));
    const int64 start_length =
        xla::ShapeUtil::GetDimension(start_constant_shape, -1);
    TF_RET_CHECK(start_length == n_dims);
    return xla::DynamicUpdateSlice(x, update, start_constant);
  });
}

xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
                                  absl::Span<const int64> start) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    const int64 n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
                                         absl::Span<const xla::XlaOp> starts) {
  auto padded_starts = PrependZerosInMajorDims(x, starts);
  return xla::DynamicUpdateSlice(x, update, padded_starts);
}

xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
                                   absl::Span<const xla::XlaOp> starts) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
    std::vector<xla::XlaOp> padded_starts(n_dims, zero);
    for (int i = 0; i < starts.size(); ++i) {
      padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
    }
    return xla::ConcatInDim(builder, padded_starts, 0);
  });
}

xla::XlaOp TransposeInMinorDims(xla::XlaOp x) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    const int64 n_dims = xla::ShapeUtil::Rank(shape);
    TF_RET_CHECK(n_dims >= 2);
    std::vector<int64> permutation(n_dims);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
    return xla::Transpose(x, permutation);
  });
}

xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose) {
  return transpose ? TransposeInMinorDims(x) : x;
}

xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    auto perform_conj = shape.element_type() == xla::C64 && conjugate;
    return perform_conj ? xla::Conj(x) : x;
  });
}

}  // namespace tensorflow
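The tensorflow:: wrappers above are deleted outright rather than forwarded, so existing call sites migrate by switching includes and namespaces. A hedged sketch of the migration pattern (GetRow and its parameters are illustrative, not from this commit):

#include "tensorflow/compiler/xla/client/lib/slicing.h"

// Hypothetical caller: fetch row `i` of each matrix in a batched operand.
xla::XlaOp GetRow(xla::XlaOp a, xla::int64 i, xla::int64 n) {
  // Was tensorflow::SliceInMinorDims from tf2xla/lib/util.h; now xla::.
  return xla::SliceInMinorDims(a, {i, 0}, {i + 1, n});
}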
@@ -38,48 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
                          int64 value);

// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
                                   absl::Span<const xla::XlaOp> starts);

// Performs a slice in the minor dimensions of a Tensor.
xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
                            absl::Span<const int64> end);

// Returns the concatenation of `xs` and `ys`.
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys);

// Performs a dynamic slice in the minor dimensions of a Tensor.
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
                                   absl::Span<const xla::XlaOp> starts,
                                   absl::Span<const int64> sizes);

// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
                       absl::Span<const int64> start);

// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
                                  absl::Span<const int64> start);

xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
                                         absl::Span<const xla::XlaOp> starts);

// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);

// Transposes `x` in its minor dimensions if `transpose` is true, otherwise
// returns `x` unchanged.
xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);

// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
@@ -127,6 +127,7 @@ xla_test(
    tags = ["enable_for_xla_interpreter"],
    deps = [
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
@@ -176,6 +177,38 @@ cc_library(
    ],
)

cc_library(
    name = "slicing",
    srcs = ["slicing.cc"],
    hdrs = ["slicing.h"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "slicing_test",
    srcs = ["slicing_test.cc"],
    tags = ["enable_for_xla_interpreter"],
    deps = [
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "sorting",
    srcs = ["sorting.cc"],
@@ -268,17 +268,16 @@ XlaOp Digamma(XlaOp input) {
 // Implements Banker's rounding: numbers that are equidistant between two
 // integers are rounded towards even.
 XlaOp RoundToEven(XlaOp x) {
-  auto half = xla::ScalarLike(x, 0.5);
-  auto one = xla::ScalarLike(x, 1.0);
-  auto two = xla::ScalarLike(x, 2.0);
+  auto half = ScalarLike(x, 0.5);
+  auto one = ScalarLike(x, 1.0);
+  auto two = ScalarLike(x, 2.0);

-  auto round_val = xla::Floor(x);
+  auto round_val = Floor(x);
   auto fraction = x - round_val;
-  auto nearest_even_int = round_val - two * xla::Floor(half * x);
-  auto is_odd = xla::Eq(nearest_even_int, one);
-  return xla::Select(xla::Or(xla::Gt(fraction, half),
-                             xla::And(xla::Eq(fraction, half), is_odd)),
-                     round_val + one, round_val);
+  auto nearest_even_int = round_val - two * Floor(half * x);
+  auto is_odd = Eq(nearest_even_int, one);
+  return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
+                round_val + one, round_val);
 }

 // Trigonometric functions.
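A quick check of the even-rounding branch above: for x = 2.5, round_val = floor(2.5) = 2, fraction = 0.5, and nearest_even_int = 2 - 2 * floor(1.25) = 0, so is_odd is false and the result stays at 2; for x = 3.5, round_val = 3 and nearest_even_int = 3 - 2 * floor(1.75) = 1, so is_odd is true and the result is bumped to 4. Ties therefore land on the even neighbor, as the comment promises.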
@@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }

 XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }

+XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+    auto perform_conj = shape.element_type() == C64 && conjugate;
+    return perform_conj ? Conj(x) : x;
+  });
+}
+
 }  // namespace xla
@@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x);
 // Computes the hyperbolic sine of 'x'.
 XlaOp Sinh(XlaOp x);

+// Applies a complex conjugation operation if `a` is complex and `conjugate`
+// is true, otherwise returns its argument.
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
+
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
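MaybeConjugate only emits a Conj op when the element type is C64; for real inputs it returns the operand unchanged, so callers can thread a conjugate flag through unconditionally. A minimal sketch, with illustrative builder and parameter names:

xla::XlaBuilder b("maybe_conjugate_example");
auto z = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::C64, {4}), "z");
auto r = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {4}), "r");
auto zc = xla::MaybeConjugate(z, true);  // emits Conj: element type is C64
auto rc = xla::MaybeConjugate(r, true);  // returns r as-is: F32 is real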
@@ -166,4 +166,20 @@ XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
   });
 }

+XlaOp TransposeInMinorDims(XlaOp x) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+    const int64 n_dims = ShapeUtil::Rank(shape);
+    TF_RET_CHECK(n_dims >= 2);
+    std::vector<int64> permutation(n_dims);
+    std::iota(permutation.begin(), permutation.end(), 0);
+    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
+    return Transpose(x, permutation);
+  });
+}
+
+XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
+  return transpose ? TransposeInMinorDims(x) : x;
+}
 }  // namespace xla
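TransposeInMinorDims builds an identity permutation and swaps its last two entries, so a rank-3 operand gets the permutation {0, 2, 1}: batch dimensions stay put and each matrix in the stack is transposed, e.g. a {2, 3, 4} operand comes back as {2, 4, 3}.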
@@ -60,6 +60,14 @@ XlaOp LowerTriangle(XlaOp x);
 xla::XlaOp BatchDot(
     xla::XlaOp x, xla::XlaOp y,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);

+// Transposes a stack of matrices `x` by swapping the last two dimensions.
+xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
+
+// Transposes `x` in its minor dimensions if `transpose` is true, otherwise
+// returns `x` unchanged.
+xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);
+
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/compiler/xla/client/lib/matrix.h"

+#include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -25,13 +26,13 @@ limitations under the License.
 namespace xla {
 namespace {

-class NumericTest : public ClientLibraryTestBase {
+class MatrixTest : public ClientLibraryTestBase {
  protected:
   template <typename T>
   void TestMatrixDiagonal();
 };

-XLA_TEST_F(NumericTest, Triangle) {
+XLA_TEST_F(MatrixTest, Triangle) {
   XlaBuilder builder(TestName());
   Array3D<int32> input(2, 3, 4);
   input.FillIota(0);
@@ -46,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) {
 }

 template <typename T>
-void NumericTest::TestMatrixDiagonal() {
+void MatrixTest::TestMatrixDiagonal() {
   XlaBuilder builder("GetMatrixDiagonal");
   Array3D<T> input(2, 3, 4);
   input.FillIota(0);
@@ -59,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() {
   ComputeAndCompareR2<T>(&builder, expected, {a_data.get()});
 }

-XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }
+XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal<int32>(); }

-XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }
+XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal<int64>(); }

-XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal<float>(); }
+XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal<float>(); }
+
+Array3D<float> BatchedAValsFull() {
+  return {{
+              {2, 0, 1, 2},
+              {3, 6, 0, 1},
+              {4, 7, 9, 0},
+              {5, 8, 10, 11},
+          },
+          {
+              {16, 24, 8, 12},
+              {24, 61, 82, 48},
+              {8, 82, 456, 106},
+              {12, 48, 106, 62},
+          }};
+}
+
+XLA_TEST_F(MatrixTest, RowBatchDot) {
+  XlaBuilder builder(TestName());
+
+  int n = 4;
+
+  XlaOp a, row, index;
+  auto a_data =
+      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
+  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
+                                           "row", &builder, &row);
+  // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
+  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
+
+  auto l_index = DynamicSliceInMinorDims(
+      a, {index, ConstantR0<int32>(&builder, 0)}, {1, n});
+  BatchDot(l_index, TransposeInMinorDims(row));
+
+  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
+                             {a_data.get(), row_data.get(), index_data.get()});
+}
 }  // namespace
 }  // namespace xla
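As a check on the expected values in RowBatchDot: the dynamic slice selects row 1 of each batch, {3, 6, 0, 1} and {24, 61, 82, 48}, and dotting those with the row parameters gives 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292, matching the {{{33}}, {{292}}} literal.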
tensorflow/compiler/xla/client/lib/slicing.cc (new file, 134 lines)
@@ -0,0 +1,134 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"

namespace xla {

XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
                       absl::Span<const int64> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64 n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64 n_dims = ShapeUtil::Rank(shape);
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dim
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
    std::vector<int32> start_as_int32(start.begin(), start.end());
    auto start_constant = ConstantR1<int32>(builder, start_as_int32);
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = ShapeUtil::Rank(shape);
    TF_ASSIGN_OR_RETURN(Shape start_constant_shape,
                        builder->GetShape(start_constant));
    const int64 start_length =
        ShapeUtil::GetDimension(start_constant_shape, -1);
    TF_RET_CHECK(start_length == n_dims);
    return DynamicUpdateSlice(x, update, start_constant);
  });
}

XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = ShapeUtil::Rank(shape);
    const int64 n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = ShapeUtil::Rank(shape);
    auto zero = Reshape(ConstantR0<int32>(builder, 0), {1});
    std::vector<XlaOp> padded_starts(n_dims, zero);
    for (int i = 0; i < starts.size(); ++i) {
      padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1});
    }
    return ConcatInDim(builder, padded_starts, 0);
  });
}

}  // namespace

XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = ShapeUtil::Rank(shape);
    int64 n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - sizes.size());
    auto padded_starts = PrependZerosInMajorDims(x, starts);
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  auto padded_starts = PrependZerosInMajorDims(x, starts);
  return DynamicUpdateSlice(x, update, padded_starts);
}

}  // namespace xla
tensorflow/compiler/xla/client/lib/slicing.h (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_

namespace xla {

// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);

// Performs a slice in the minor dimensions of a tensor.
// x[..., start[0]:end[0], ..., start[n]:end[n]]
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
                       absl::Span<const int64> end);

// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0]:..., ..., start[n]:...] = update
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64> start);

// Performs a dynamic slice in the minor dimensions of a tensor.
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64> sizes);

XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_
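The minor-dims helpers address the trailing dimensions and pad out the leading (major) ones automatically. A hedged usage sketch, with illustrative shapes and names:

xla::XlaBuilder b("slicing_example");
// A batch of two 3x4 matrices.
auto x = xla::Parameter(
    &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3, 4}), "x");
// Equivalent to x[:, 0:2, 1:3]; start/end pad to {0, 0, 1} and {2, 2, 3}.
auto s = xla::SliceInMinorDims(x, /*start=*/{0, 1}, /*end=*/{2, 3});
// Writes a {2, 1, 4} update at row 2 of every batch element.
auto u = xla::Parameter(
    &b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2, 1, 4}), "u");
auto y = xla::UpdateSliceInMinorDims(x, u, /*start=*/{2, 0});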
@@ -13,28 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"

-#include <memory>
-#include <numeric>
-#include <vector>
-
-#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/lib/matrix.h"
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/status_test_util.h"

-namespace tensorflow {
+namespace xla {
 namespace {

-using UtilTest = xla::ClientLibraryTestBase;
-using UtilLeftLookingTest = xla::ClientLibraryTestBase;
+using SlicingTest = xla::ClientLibraryTestBase;

 xla::Array2D<float> BValsRight() {
   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
@@ -63,7 +54,7 @@ xla::Array3D<float> BatchedAValsFull() {
   }};
 }

-XLA_TEST_F(UtilTest, Simple2dLookup) {
+XLA_TEST_F(SlicingTest, Simple2dLookup) {
   xla::XlaBuilder builder(TestName());

   xla::XlaOp a, x, y;
@@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
                    xla::ErrorSpec(1e-2, 1e-2));
 }

-XLA_TEST_F(UtilTest, Simple3dLookup) {
+XLA_TEST_F(SlicingTest, Simple3dLookup) {
   xla::XlaBuilder builder(TestName());

   xla::XlaOp a, index;
@@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
                       {a_data.get(), index_data.get()});
 }

-XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
+XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
   xla::XlaBuilder builder(TestName());

   xla::XlaOp a, b, x, y;
@@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
       {a_data.get(), b_data.get(), x_data.get(), y_data.get()});
 }

-XLA_TEST_F(UtilTest, RowBatchDot) {
-  xla::XlaBuilder builder(TestName());
-
-  int n = 4;
-
-  xla::XlaOp a, row, index;
-  auto a_data =
-      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
-  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
-                                           "row", &builder, &row);
-  // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
-  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
-
-  auto l_index = DynamicSliceInMinorDims(
-      a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n});
-  BatchDot(l_index, TransposeInMinorDims(row));
-
-  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
-                             {a_data.get(), row_data.get(), index_data.get()});
-}
-
 }  // namespace
-}  // namespace tensorflow
+}  // namespace xla