[TF:XLA] Simplify the lowering for DataFormatVecPermute

Sort is a bit heavy for reordering 4 values. Gather is still not perfect
but at least we can easily fuse it.

PiperOrigin-RevId: 254731756
This commit is contained in:
Benjamin Kramer 2019-06-24 04:20:11 -07:00 committed by TensorFlower Gardener
parent e8ad5b0552
commit 0e989df426

View File

@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -65,24 +65,18 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
"Second dimension of 2D input must be of size 2, but got shape ",
input_tensor_shape.DebugString()));
}
std::vector<int32> dst_indices(4, 0);
int32 dst_indices[4];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
if (src_format_[i] == dst_format_[j]) {
dst_indices[i] = j;
dst_indices[j] = i;
break;
}
}
}
auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
if (input_rank == 2) {
keys = xla::BroadcastInDim(keys, {4, 2}, {0});
}
auto sorted = xla::Sort({keys, ctx->Input(0)},
xla::CreateScalarLtComputation(
{xla::S32, ctx->input_xla_type(0)}, builder),
0);
auto output = xla::GetTupleElement(sorted, 1);
xla::XlaOp indices =
xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0);
ctx->SetOutput(0, output);
}