update xla data format map ops

This commit is contained in:
Kaixi Hou 2020-09-23 11:49:20 -07:00
parent 25567bd841
commit 7a38d3fd96
2 changed files with 28 additions and 8 deletions
tensorflow/compiler

View File

@ -63,6 +63,22 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase):
self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
[3, 2, 1, 0, 3, 2, 1, 0])
self._test(0, "NDHWC", "NCDHW", 0)
self._test(1, "NDHWC", "NCDHW", 2)
self._test(2, "NDHWC", "NCDHW", 3)
self._test(3, "NDHWC", "NCDHW", 4)
self._test(4, "NDHWC", "NCDHW", 1)
self._test([1, 4], "NDHWC", "NCDHW", [2, 1])
self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4])
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]])
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC",
[3, 0, 1, 2, 4, 3, 0, 1, 2, 4])
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN",
[4, 2, 1, 0, 3, 4, 2, 1, 0, 3])
class XlaPermuteOpTest(xla_test.XLATestCase):

View File

@ -35,15 +35,18 @@ class DataFormatDimMapOp : public XlaOpKernel {
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
string dst_format;
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
OP_REQUIRES(context, src_format.size() == 4,
OP_REQUIRES(context, src_format.size() == 4 or src_format.size() == 5,
errors::InvalidArgument(absl::StrCat(
"Source format must of length 4, received src_format = ",
src_format)));
"Source format must of length 4 or 5, "
"received src_format = ", src_format)));
OP_REQUIRES(
context, dst_format.size() == 4,
context, dst_format.size() == 4 or dst_format.size() == 5,
errors::InvalidArgument(absl::StrCat(
"Destination format must of length 4, received dst_format = ",
"Destination format must of length 4 or 5, received dst_format = ",
dst_format)));
for (int i = 0; i < src_format.size(); ++i) {
dst_idx_.push_back(-1);
}
for (int i = 0; i < src_format.size(); ++i) {
for (int j = 0; j < dst_format.size(); ++j) {
if (dst_format[j] == src_format[i]) {
@ -61,9 +64,10 @@ class DataFormatDimMapOp : public XlaOpKernel {
auto builder = context->builder();
xla::XlaOp dst_indices =
xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
xla::XlaOp four = xla::ConstantR0<int32>(builder, 4);
const int dims = dst_idx_.size();
xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
xla::XlaOp src_indices =
(xla::ConvertElementType(context->Input(0), xla::S32) + four) % four;
(xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
xla::XlaOp output =
xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
context->SetOutput(
@ -71,7 +75,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
}
private:
std::array<int32, 4> dst_idx_ = {{-1, -1, -1, -1}};
std::vector<int32> dst_idx_;
TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
};