update xla data format map ops
This commit is contained in:
parent
25567bd841
commit
7a38d3fd96
tensorflow/compiler
@ -63,6 +63,22 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase):
|
||||
self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
|
||||
[3, 2, 1, 0, 3, 2, 1, 0])
|
||||
|
||||
self._test(0, "NDHWC", "NCDHW", 0)
|
||||
self._test(1, "NDHWC", "NCDHW", 2)
|
||||
self._test(2, "NDHWC", "NCDHW", 3)
|
||||
self._test(3, "NDHWC", "NCDHW", 4)
|
||||
self._test(4, "NDHWC", "NCDHW", 1)
|
||||
self._test([1, 4], "NDHWC", "NCDHW", [2, 1])
|
||||
self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4])
|
||||
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
|
||||
self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]])
|
||||
|
||||
self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
|
||||
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC",
|
||||
[3, 0, 1, 2, 4, 3, 0, 1, 2, 4])
|
||||
self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN",
|
||||
[4, 2, 1, 0, 3, 4, 2, 1, 0, 3])
|
||||
|
||||
|
||||
class XlaPermuteOpTest(xla_test.XLATestCase):
|
||||
|
||||
|
@ -35,15 +35,18 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
||||
string dst_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||||
OP_REQUIRES(context, src_format.size() == 4,
|
||||
OP_REQUIRES(context, src_format.size() == 4 or src_format.size() == 5,
|
||||
errors::InvalidArgument(absl::StrCat(
|
||||
"Source format must of length 4, received src_format = ",
|
||||
src_format)));
|
||||
"Source format must of length 4 or 5, "
|
||||
"received src_format = ", src_format)));
|
||||
OP_REQUIRES(
|
||||
context, dst_format.size() == 4,
|
||||
context, dst_format.size() == 4 or dst_format.size() == 5,
|
||||
errors::InvalidArgument(absl::StrCat(
|
||||
"Destination format must of length 4, received dst_format = ",
|
||||
"Destination format must of length 4 or 5, received dst_format = ",
|
||||
dst_format)));
|
||||
for (int i = 0; i < src_format.size(); ++i) {
|
||||
dst_idx_.push_back(-1);
|
||||
}
|
||||
for (int i = 0; i < src_format.size(); ++i) {
|
||||
for (int j = 0; j < dst_format.size(); ++j) {
|
||||
if (dst_format[j] == src_format[i]) {
|
||||
@ -61,9 +64,10 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
||||
auto builder = context->builder();
|
||||
xla::XlaOp dst_indices =
|
||||
xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
|
||||
xla::XlaOp four = xla::ConstantR0<int32>(builder, 4);
|
||||
const int dims = dst_idx_.size();
|
||||
xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
|
||||
xla::XlaOp src_indices =
|
||||
(xla::ConvertElementType(context->Input(0), xla::S32) + four) % four;
|
||||
(xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
|
||||
xla::XlaOp output =
|
||||
xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
|
||||
context->SetOutput(
|
||||
@ -71,7 +75,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<int32, 4> dst_idx_ = {{-1, -1, -1, -1}};
|
||||
std::vector<int32> dst_idx_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user