update xla data format map ops
This commit is contained in:
parent
25567bd841
commit
7a38d3fd96
tensorflow/compiler
@ -63,6 +63,22 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase):
|
|||||||
self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
|
self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
|
||||||
[3, 2, 1, 0, 3, 2, 1, 0])
|
[3, 2, 1, 0, 3, 2, 1, 0])
|
||||||
|
|
||||||
|
self._test(0, "NDHWC", "NCDHW", 0)
|
||||||
|
self._test(1, "NDHWC", "NCDHW", 2)
|
||||||
|
self._test(2, "NDHWC", "NCDHW", 3)
|
||||||
|
self._test(3, "NDHWC", "NCDHW", 4)
|
||||||
|
self._test(4, "NDHWC", "NCDHW", 1)
|
||||||
|
self._test([1, 4], "NDHWC", "NCDHW", [2, 1])
|
||||||
|
self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4])
|
||||||
|
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
|
||||||
|
self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]])
|
||||||
|
|
||||||
|
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
|
||||||
|
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC",
|
||||||
|
[3, 0, 1, 2, 4, 3, 0, 1, 2, 4])
|
||||||
|
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN",
|
||||||
|
[4, 2, 1, 0, 3, 4, 2, 1, 0, 3])
|
||||||
|
|
||||||
|
|
||||||
class XlaPermuteOpTest(xla_test.XLATestCase):
|
class XlaPermuteOpTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@ -35,15 +35,18 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
||||||
string dst_format;
|
string dst_format;
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||||||
OP_REQUIRES(context, src_format.size() == 4,
|
OP_REQUIRES(context, src_format.size() == 4 or src_format.size() == 5,
|
||||||
errors::InvalidArgument(absl::StrCat(
|
errors::InvalidArgument(absl::StrCat(
|
||||||
"Source format must of length 4, received src_format = ",
|
"Source format must of length 4 or 5, "
|
||||||
src_format)));
|
"received src_format = ", src_format)));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, dst_format.size() == 4,
|
context, dst_format.size() == 4 or dst_format.size() == 5,
|
||||||
errors::InvalidArgument(absl::StrCat(
|
errors::InvalidArgument(absl::StrCat(
|
||||||
"Destination format must of length 4, received dst_format = ",
|
"Destination format must of length 4 or 5, received dst_format = ",
|
||||||
dst_format)));
|
dst_format)));
|
||||||
|
for (int i = 0; i < src_format.size(); ++i) {
|
||||||
|
dst_idx_.push_back(-1);
|
||||||
|
}
|
||||||
for (int i = 0; i < src_format.size(); ++i) {
|
for (int i = 0; i < src_format.size(); ++i) {
|
||||||
for (int j = 0; j < dst_format.size(); ++j) {
|
for (int j = 0; j < dst_format.size(); ++j) {
|
||||||
if (dst_format[j] == src_format[i]) {
|
if (dst_format[j] == src_format[i]) {
|
||||||
@ -61,9 +64,10 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
|||||||
auto builder = context->builder();
|
auto builder = context->builder();
|
||||||
xla::XlaOp dst_indices =
|
xla::XlaOp dst_indices =
|
||||||
xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
|
xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
|
||||||
xla::XlaOp four = xla::ConstantR0<int32>(builder, 4);
|
const int dims = dst_idx_.size();
|
||||||
|
xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
|
||||||
xla::XlaOp src_indices =
|
xla::XlaOp src_indices =
|
||||||
(xla::ConvertElementType(context->Input(0), xla::S32) + four) % four;
|
(xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
|
||||||
xla::XlaOp output =
|
xla::XlaOp output =
|
||||||
xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
|
xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
|
||||||
context->SetOutput(
|
context->SetOutput(
|
||||||
@ -71,7 +75,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::array<int32, 4> dst_idx_ = {{-1, -1, -1, -1}};
|
std::vector<int32> dst_idx_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user