[XLA] Invert the meaning of xla::Permute, add an inverse xla::PermuteInverse().
Also swap its arguments; to my eyes the call is more natural with the permutation second.

It turns out that a significant number of users of xla::Permute() want the inverse behavior. The new semantics of xla::Permute() also match the permutation semantics of operations like xla::Transpose(); it seems more consistent to use a matching convention.

PiperOrigin-RevId: 356796646
Change-Id: Ifd6c0d7d6eea69c50a342601a3f1cba725880696
parent e04eb6b4af
commit 6d24ecfdcc
tensorflow/compiler/xla
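Before the diff, the two conventions can be illustrated with a small standalone sketch. The functions below are simplified stand-ins (plain std::vector<int> with illustrative names PermuteGather and PermuteScatter, rather than the templated XLA helpers over absl::Span<const int64>); they only demonstrate the semantics stated in the commit message and in the header comments in the first hunk, and are not the XLA implementation.

#include <cassert>
#include <cstddef>
#include <vector>

// Gather convention, matching the new xla::Permute:
//   output[i] = input[permutation[i]].
std::vector<int> PermuteGather(const std::vector<int>& input,
                               const std::vector<int>& permutation) {
  std::vector<int> output(input.size());
  for (size_t i = 0; i < permutation.size(); ++i) {
    output[i] = input[permutation[i]];
  }
  return output;
}

// Scatter convention, matching the new xla::PermuteInverse (which is what the
// old xla::Permute computed): output[permutation[i]] = input[i].
std::vector<int> PermuteScatter(const std::vector<int>& input,
                                const std::vector<int>& permutation) {
  std::vector<int> output(input.size());
  for (size_t i = 0; i < permutation.size(); ++i) {
    output[permutation[i]] = input[i];
  }
  return output;
}

int main() {
  const std::vector<int> input = {10, 20, 30};
  const std::vector<int> p = {1, 2, 0};
  assert((PermuteGather(input, p) == std::vector<int>{20, 30, 10}));
  assert((PermuteScatter(input, p) == std::vector<int>{30, 10, 20}));
  // The two conventions are mutual inverses for any permutation p.
  assert(PermuteScatter(PermuteGather(input, p), p) == input);
  assert(PermuteGather(PermuteScatter(input, p), p) == input);
  return 0;
}

The diff renames the scatter behavior to PermuteInverse and makes Permute the gather, which matches how xla::Transpose interprets its dimensions permutation.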
@@ -31,14 +31,33 @@ namespace xla {
 bool IsPermutation(absl::Span<const int64> permutation);
 
 // Applies `permutation` on `input` and returns the permuted array.
-// For each i, output[permutation[i]] = input[i].
+// For each i, output[i] = input[permutation[i]].
 //
 // Precondition:
 // 1. `permutation` is a permutation of 0..permutation.size()-1.
 // 2. permutation.size() == input.size().
 template <typename Container>
 std::vector<typename Container::value_type> Permute(
-    absl::Span<const int64> permutation, const Container& input) {
+    const Container& input, absl::Span<const int64> permutation) {
   using T = typename Container::value_type;
   absl::Span<const T> data(input);
   CHECK_EQ(permutation.size(), data.size());
+  CHECK(IsPermutation(permutation));
+  std::vector<T> output(data.size());
+  for (size_t i = 0; i < permutation.size(); ++i) {
+    output[i] = data[permutation[i]];
+  }
+  return output;
+}
+// Applies the inverse of `permutation` on `input` and returns the permuted
+// array. For each i, output[permutation[i]] = input[i].
+//
+// Precondition:
+// 1. `permutation` is a permutation of 0..permutation.size()-1.
+// 2. permutation.size() == input.size().
+template <typename Container>
+std::vector<typename Container::value_type> PermuteInverse(
+    const Container& input, absl::Span<const int64> permutation) {
+  using T = typename Container::value_type;
+  absl::Span<const T> data(input);
+  CHECK_EQ(permutation.size(), data.size());
@@ -4868,7 +4868,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
         VLOG(3) << "Added shmem buffer for parameter " << id << ": "
                 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
         Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
-            param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims));
+            param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
         param_in_reduced_shape_arrays.push_back(
             param_arrays[id].CastToShape(reduced_shape, &b_));
       } else {
@@ -4903,8 +4903,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
     if (!tiled_param_ids.empty()) {
       // Calculate the input tile origin from the output tile origin.
       const IrArray::Index input_tile_origin(
-          Permute({0, 2, 1}, index.multidim()),
-          Permute({0, 2, 1}, index.dims()), index.GetType());
+          Permute(index.multidim(), {0, 2, 1}),
+          Permute(index.dims(), {0, 2, 1}), index.GetType());
 
       // Copy input parameter values to shared memory buffers:
       // tile[thread_id_y, thread_id_x] = input[index]
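The 0-2-1 tiling sites above only needed their arguments swapped: {0, 2, 1} is its own inverse, so the old scatter-style Permute({0, 2, 1}, x) and the new gather-style Permute(x, {0, 2, 1}) produce the same reordering, {x[0], x[2], x[1]}.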
@@ -203,7 +203,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   bool matched = true;
   root->literal().EachCell<NativeT>(
       [&](absl::Span<const int64> indices, NativeT value) {
-        std::vector<int64> rindexes = Permute(permutation, indices);
+        std::vector<int64> rindexes = PermuteInverse(indices, permutation);
         matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
       });
   EXPECT_TRUE(matched);
@@ -509,7 +509,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) {
 
   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
   result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
-    std::vector<int64> rindexes = Permute(permutation, indices);
+    std::vector<int64> rindexes = PermuteInverse(indices, permutation);
    EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
  });
}
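In both tests the lambda walks the cells of the transposed result and maps each output index back to the corresponding input index. For a transpose with permutation p, output dimension i comes from input dimension p[i], so the input index r satisfies r[p[i]] = indices[i]; that is exactly PermuteInverse(indices, p), i.e. the behavior the old Permute(p, indices) provided.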
@@ -1865,9 +1865,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
     TF_RET_CHECK(shape.dimensions().size() ==
                  transpose->operand(0)->shape().dimensions().size());
     TF_RET_CHECK(std::equal(
-        operand->shape().dimensions().begin(),
-        operand->shape().dimensions().end(),
-        Permute(transpose->dimensions(), shape.dimensions()).begin()))
+        shape.dimensions().begin(), shape.dimensions().end(),
+        Permute(operand->shape().dimensions(), transpose->dimensions())
+            .begin()))
         << "shape: " << shape << ", operand->shape(): " << shape
         << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
         << "}";
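The rewritten check verifies that the transpose's output shape is the operand shape permuted under the new gather convention: for example, an operand of shape [2, 3, 4] transposed with dimensions {2, 1, 0} must have output shape Permute({2, 3, 4}, {2, 1, 0}) = {4, 3, 2}.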
@@ -286,7 +286,7 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose(
     const Shape& shape, const Shape& operand_shape,
     absl::Span<const int64> dimension_mapping) const {
   std::vector<llvm::Value*> operand_multidim_index =
-      Permute(dimension_mapping, multidim());
+      PermuteInverse(multidim(), dimension_mapping);
 
   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
       LayoutUtil::HasLayout(shape) &&
@@ -1058,7 +1058,7 @@ Status ForEachMutableSubshapeHelper(
     absl::Span<const int64> permutation, const Shape& shape) {
   Shape new_shape = shape;
   new_shape.clear_dimensions();
-  for (auto dim : Permute(permutation, shape.dimensions())) {
+  for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
    new_shape.add_dimensions(dim);
  }
  for (int64 i = 0; i < shape.rank(); i++) {
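This shape helper keeps its previous behavior: PermuteInverse(shape.dimensions(), permutation) computes exactly what the old Permute(permutation, shape.dimensions()) did, placing dimension i of the original shape at position permutation[i] of the new shape (e.g., dimensions {2, 3, 4} with permutation {1, 2, 0} become {4, 2, 3}).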