From 0e989df4269bda82d4eb2efce52feb0878c5c339 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 24 Jun 2019 04:20:11 -0700 Subject: [PATCH] [TF:XLA] Simplify the lowering for DataFormatVecPermute Sort is a bit heavy for reordering 4 values. Gather is still not perfect but at least we can easily fuse it. PiperOrigin-RevId: 254731756 --- .../compiler/tf2xla/kernels/permute_op.cc | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc index 94db561ee65..81e15e5f816 100644 --- a/tensorflow/compiler/tf2xla/kernels/permute_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" @@ -65,24 +65,18 @@ class DataFormatVecPermuteOp : public XlaOpKernel { "Second dimension of 2D input must be of size 2, but got shape ", input_tensor_shape.DebugString())); } - std::vector dst_indices(4, 0); + int32 dst_indices[4]; for (int i = 0; i < 4; ++i) { for (int j = 0; j < 4; ++j) { if (src_format_[i] == dst_format_[j]) { - dst_indices[i] = j; + dst_indices[j] = i; break; } } } - auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); - if (input_rank == 2) { - keys = xla::BroadcastInDim(keys, {4, 2}, {0}); - } - auto sorted = xla::Sort({keys, ctx->Input(0)}, - xla::CreateScalarLtComputation( - {xla::S32, ctx->input_xla_type(0)}, builder), - 0); - auto output = xla::GetTupleElement(sorted, 1); + xla::XlaOp indices = + xla::ConstantR1(builder, absl::Span(dst_indices)); + xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0); ctx->SetOutput(0, output); }