Add support for int64 InvertPermutation Op
PiperOrigin-RevId: 321185678 Change-Id: I62fe9daa942728c56bc92f730e1ed7c1adf8a511
This commit is contained in:
parent
a6207b8f2e
commit
6c0fba8590
@ -923,16 +923,22 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
expected=np.array([1, 0x100000003f800000], np.uint64))
|
||||
|
||||
def testInvertPermutation(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.invert_permutation,
|
||||
np.array([1, 2, 0], np.int32),
|
||||
expected=np.array([2, 0, 1], dtype=np.int32))
|
||||
for np_dtype in [np.int32, np.int64]:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.invert_permutation,
|
||||
np.array([1, 2, 0], np_dtype),
|
||||
expected=np.array([2, 0, 1], dtype=np_dtype))
|
||||
|
||||
def testInvertPermutationTwiceIsNoop(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
|
||||
np.array([1, 2, 0], np.int32),
|
||||
expected=np.array([1, 2, 0], dtype=np.int32))
|
||||
|
||||
def invert_twice(x):
|
||||
return array_ops.invert_permutation(array_ops.invert_permutation(x))
|
||||
|
||||
for np_dtype in [np.int32, np.int64]:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
invert_twice,
|
||||
np.array([1, 2, 0], np_dtype),
|
||||
expected=np.array([1, 2, 0], dtype=np_dtype))
|
||||
|
||||
def testRank(self):
|
||||
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
|
||||
|
||||
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -110,11 +111,11 @@ REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"),
|
||||
|
||||
// InvertPermutation frequently forms part of the gradient of Transpose.
|
||||
//
|
||||
// inv = InvertPermutationOp(T<int32> p) takes a permutation of
|
||||
// inv = InvertPermutationOp(p) takes a permutation of
|
||||
// integers 0, 1, ..., n - 1 and returns the inverted
|
||||
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
|
||||
//
|
||||
// REQUIRES: input is a vector of int32.
|
||||
// REQUIRES: input is a vector of int32 or int64.
|
||||
// REQUIRES: input is a permutation of 0, 1, ..., n-1.
|
||||
|
||||
class InvertPermutationOp : public XlaOpKernel {
|
||||
@ -122,11 +123,32 @@ class InvertPermutationOp : public XlaOpKernel {
|
||||
explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
DataType dtype = ctx->expected_output_dtype(0);
|
||||
Status status;
|
||||
switch (dtype) {
|
||||
case DT_INT32:
|
||||
InvertPermutation<int32>(ctx);
|
||||
break;
|
||||
case DT_INT64:
|
||||
InvertPermutation<int64>(ctx);
|
||||
break;
|
||||
default:
|
||||
// This should never happen: the registration of this kernel restricts T
|
||||
// to DT_INT32 and DT_INT64, so no other dtype reaches this switch.
|
||||
OP_REQUIRES_OK(ctx, errors::InvalidArgument(
|
||||
"InvertPermutation expects x as either ",
|
||||
"int32 or int64, not ", DataTypeString(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InvertPermutation(XlaOpKernelContext* ctx) {
|
||||
OP_REQUIRES(ctx,
|
||||
FastBoundsCheck(ctx->InputShape(0).num_elements(),
|
||||
std::numeric_limits<int32>::max()),
|
||||
errors::InvalidArgument("permutation of nonnegative int32s "
|
||||
"must have <= int32 max elements"));
|
||||
std::numeric_limits<T>::max()),
|
||||
errors::InvalidArgument(
|
||||
"permutation of nonnegative integers must have <= ",
|
||||
std::numeric_limits<T>::max(), " elements"));
|
||||
|
||||
auto e = ctx->InputExpression(0);
|
||||
auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client());
|
||||
@ -142,7 +164,7 @@ class InvertPermutationOp : public XlaOpKernel {
|
||||
|
||||
int size = perm.size();
|
||||
|
||||
std::vector<int32> output(size);
|
||||
std::vector<T> output(size);
|
||||
std::fill_n(output.data(), size, -1);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const int64 d = perm[i];
|
||||
@ -153,11 +175,13 @@ class InvertPermutationOp : public XlaOpKernel {
|
||||
output[d] = i;
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
|
||||
ctx->SetOutput(0, xla::ConstantR1<T>(ctx->builder(), output));
|
||||
} else {
|
||||
auto indices = ctx->Input(0);
|
||||
int size = ctx->InputShape(0).num_elements();
|
||||
auto iota = xla::Iota(ctx->builder(), xla::S32, size);
|
||||
T size = ctx->InputShape(0).num_elements();
|
||||
auto iota =
|
||||
xla::Iota(ctx->builder(),
|
||||
xla::primitive_util::NativeToPrimitiveType<T>(), size);
|
||||
auto result = XlaScatter(iota, iota, indices,
|
||||
/*indices_are_vectors=*/false, /*combiner=*/{},
|
||||
ctx->builder());
|
||||
@ -167,8 +191,9 @@ class InvertPermutationOp : public XlaOpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
|
||||
InvertPermutationOp);
|
||||
REGISTER_XLA_OP(
|
||||
Name("InvertPermutation").TypeConstraint("T", {DT_INT32, DT_INT64}),
|
||||
InvertPermutationOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user