[TF:XLA] Simplify the lowering for DataFormatVecPermute
Sort is a bit heavy for reordering 4 values. Gather is still not perfect but at least we can easily fuse it. PiperOrigin-RevId: 254731756
This commit is contained in:
parent
e8ad5b0552
commit
0e989df426
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
@ -65,24 +65,18 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
|
||||
"Second dimension of 2D input must be of size 2, but got shape ",
|
||||
input_tensor_shape.DebugString()));
|
||||
}
|
||||
std::vector<int32> dst_indices(4, 0);
|
||||
int32 dst_indices[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
if (src_format_[i] == dst_format_[j]) {
|
||||
dst_indices[i] = j;
|
||||
dst_indices[j] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
|
||||
if (input_rank == 2) {
|
||||
keys = xla::BroadcastInDim(keys, {4, 2}, {0});
|
||||
}
|
||||
auto sorted = xla::Sort({keys, ctx->Input(0)},
|
||||
xla::CreateScalarLtComputation(
|
||||
{xla::S32, ctx->input_xla_type(0)}, builder),
|
||||
0);
|
||||
auto output = xla::GetTupleElement(sorted, 1);
|
||||
xla::XlaOp indices =
|
||||
xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
|
||||
xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user