Add support for int64 InvertPermutation Op

PiperOrigin-RevId: 321185678
Change-Id: I62fe9daa942728c56bc92f730e1ed7c1adf8a511
This commit is contained in:
Gaurav Jain 2020-07-14 10:31:50 -07:00 committed by TensorFlower Gardener
parent a6207b8f2e
commit 6c0fba8590
2 changed files with 50 additions and 19 deletions

View File

@ -923,16 +923,22 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, 0x100000003f800000], np.uint64))
def testInvertPermutation(self):
self._assertOpOutputMatchesExpected(
array_ops.invert_permutation,
np.array([1, 2, 0], np.int32),
expected=np.array([2, 0, 1], dtype=np.int32))
for np_dtype in [np.int32, np.int64]:
self._assertOpOutputMatchesExpected(
array_ops.invert_permutation,
np.array([1, 2, 0], np_dtype),
expected=np.array([2, 0, 1], dtype=np_dtype))
def testInvertPermutationTwiceIsNoop(self):
self._assertOpOutputMatchesExpected(
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
np.array([1, 2, 0], np.int32),
expected=np.array([1, 2, 0], dtype=np.int32))
def invert_twice(x):
return array_ops.invert_permutation(array_ops.invert_permutation(x))
for np_dtype in [np.int32, np.int64]:
self._assertOpOutputMatchesExpected(
invert_twice,
np.array([1, 2, 0], np_dtype),
expected=np.array([1, 2, 0], dtype=np_dtype))
def testRank(self):
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@ -110,11 +111,11 @@ REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"),
// InvertPermutation frequently forms part of the gradient of Transpose.
//
// inv = InvertPermutationOp(T<int32> p) takes a permutation of
// inv = InvertPermutationOp(p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
//
// REQUIRES: input is a vector of int32.
// REQUIRES: input is a vector of int32 or int64.
// REQUIRES: input is a permutation of 0, 1, ..., n-1.
class InvertPermutationOp : public XlaOpKernel {
@ -122,11 +123,32 @@ class InvertPermutationOp : public XlaOpKernel {
explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType dtype = ctx->expected_output_dtype(0);
Status status;
switch (dtype) {
case DT_INT32:
InvertPermutation<int32>(ctx);
break;
case DT_INT64:
InvertPermutation<int64>(ctx);
break;
default:
// This should never happen since we restrict this kernel to only match
// inputs with supported Tensor datatype.
OP_REQUIRES_OK(ctx, errors::InvalidArgument(
"InvertPermutation expects x as either ",
"int32 or int64, not ", DataTypeString(dtype)));
}
}
template <typename T>
void InvertPermutation(XlaOpKernelContext* ctx) {
OP_REQUIRES(ctx,
FastBoundsCheck(ctx->InputShape(0).num_elements(),
std::numeric_limits<int32>::max()),
errors::InvalidArgument("permutation of nonnegative int32s "
"must have <= int32 max elements"));
std::numeric_limits<T>::max()),
errors::InvalidArgument(
"permutation of nonnegative integers must have <= ",
std::numeric_limits<T>::max(), " elements"));
auto e = ctx->InputExpression(0);
auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client());
@ -142,7 +164,7 @@ class InvertPermutationOp : public XlaOpKernel {
int size = perm.size();
std::vector<int32> output(size);
std::vector<T> output(size);
std::fill_n(output.data(), size, -1);
for (int i = 0; i < size; ++i) {
const int64 d = perm[i];
@ -153,11 +175,13 @@ class InvertPermutationOp : public XlaOpKernel {
output[d] = i;
}
ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
ctx->SetOutput(0, xla::ConstantR1<T>(ctx->builder(), output));
} else {
auto indices = ctx->Input(0);
int size = ctx->InputShape(0).num_elements();
auto iota = xla::Iota(ctx->builder(), xla::S32, size);
T size = ctx->InputShape(0).num_elements();
auto iota =
xla::Iota(ctx->builder(),
xla::primitive_util::NativeToPrimitiveType<T>(), size);
auto result = XlaScatter(iota, iota, indices,
/*indices_are_vectors=*/false, /*combiner=*/{},
ctx->builder());
@ -167,8 +191,9 @@ class InvertPermutationOp : public XlaOpKernel {
}
};
REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
InvertPermutationOp);
REGISTER_XLA_OP(
Name("InvertPermutation").TypeConstraint("T", {DT_INT32, DT_INT64}),
InvertPermutationOp);
} // namespace
} // namespace tensorflow