[XLA] Invert the meaning of xla::Permute, add an inverse xla::PermuteInverse().

Also swap its arguments; to my eyes the call is more natural with the permutation second.

It turns out that a significant number of users of xla::Permute() want the inverse behavior. The new semantics of xla::Permute() also match the permutation semantics of operations like xla::Transpose(); it seems more consistent to use a matching convention.

PiperOrigin-RevId: 356796646
Change-Id: Ifd6c0d7d6eea69c50a342601a3f1cba725880696
This commit is contained in:
Peter Hawkins 2021-02-10 12:15:43 -08:00 committed by TensorFlower Gardener
parent e04eb6b4af
commit 6d24ecfdcc
7 changed files with 31 additions and 12 deletions

View File

@ -31,14 +31,33 @@ namespace xla {
bool IsPermutation(absl::Span<const int64> permutation);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
// For each i, output[i] = input[permutation[i]].
//
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
template <typename Container>
std::vector<typename Container::value_type> Permute(
absl::Span<const int64> permutation, const Container& input) {
const Container& input, absl::Span<const int64> permutation) {
using T = typename Container::value_type;
absl::Span<const T> data(input);
CHECK_EQ(permutation.size(), data.size());
CHECK(IsPermutation(permutation));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[i] = data[permutation[i]];
}
return output;
}
// Applies the inverse of `permutation` on `input` and returns the permuted
// array. For each i, output[permutation[i]] = input[i].
//
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
template <typename Container>
std::vector<typename Container::value_type> PermuteInverse(
const Container& input, absl::Span<const int64> permutation) {
using T = typename Container::value_type;
absl::Span<const T> data(input);
CHECK_EQ(permutation.size(), data.size());

View File

@ -4868,7 +4868,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
VLOG(3) << "Added shmem buffer for parameter " << id << ": "
<< llvm_ir::DumpToString(*param_shmem_buffers[id]);
Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims));
param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
param_in_reduced_shape_arrays.push_back(
param_arrays[id].CastToShape(reduced_shape, &b_));
} else {
@ -4903,8 +4903,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
if (!tiled_param_ids.empty()) {
// Calculate the input tile origin from the output tile origin.
const IrArray::Index input_tile_origin(
Permute({0, 2, 1}, index.multidim()),
Permute({0, 2, 1}, index.dims()), index.GetType());
Permute(index.multidim(), {0, 2, 1}),
Permute(index.dims(), {0, 2, 1}), index.GetType());
// Copy input parameter values to shared memory buffers:
// tile[thread_id_y, thread_id_x] = input[index]

View File

@ -203,7 +203,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
bool matched = true;
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
std::vector<int64> rindexes = PermuteInverse(indices, permutation);
matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);

View File

@ -509,7 +509,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
std::vector<int64> rindexes = PermuteInverse(indices, permutation);
EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
});
}

View File

@ -1865,9 +1865,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
TF_RET_CHECK(shape.dimensions().size() ==
transpose->operand(0)->shape().dimensions().size());
TF_RET_CHECK(std::equal(
operand->shape().dimensions().begin(),
operand->shape().dimensions().end(),
Permute(transpose->dimensions(), shape.dimensions()).begin()))
shape.dimensions().begin(), shape.dimensions().end(),
Permute(operand->shape().dimensions(), transpose->dimensions())
.begin()))
<< "shape: " << shape << ", operand->shape(): " << shape
<< ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
<< "}";

View File

@ -286,7 +286,7 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose(
const Shape& shape, const Shape& operand_shape,
absl::Span<const int64> dimension_mapping) const {
std::vector<llvm::Value*> operand_multidim_index =
Permute(dimension_mapping, multidim());
PermuteInverse(multidim(), dimension_mapping);
if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
LayoutUtil::HasLayout(shape) &&

View File

@ -1058,7 +1058,7 @@ Status ForEachMutableSubshapeHelper(
absl::Span<const int64> permutation, const Shape& shape) {
Shape new_shape = shape;
new_shape.clear_dimensions();
for (auto dim : Permute(permutation, shape.dimensions())) {
for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
new_shape.add_dimensions(dim);
}
for (int64 i = 0; i < shape.rank(); i++) {