From 9b9cbbe2a69b7fcec72d82f271cb90839c3035b7 Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Sun, 22 Oct 2017 23:02:28 -0700
Subject: [PATCH] Add int64 Tperm type support for `Transpose` (#13909)

* Add int64 Tperm type support for `Transpose`

This fix adds int64 Tperm support for `Transpose`. In
`array_ops.cc`, `Transpose` and `ConjugateTranspose`
have been specified as accepting int32 and int64 perm
types. However, only int32 kernels has been registered.

This fix adds the int64 perm support by removing
the constraint on Tperm, resolve the type at runtime,
and copying the data type accordingly to correctly handle
the int64/int32 types.

Additional tests have been added as well.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for int64 of perm in Transpose.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add namespace to hide PermutationHelper

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Enable use_gpu=True for perm type test.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* extra // namespace annotation

* Adding a comment about int32 casting that should be safe.

Permutations only contain values that refer to dimensions, and the maximum number of dimensions we have is 254, so an int32 is always safe here.
---
 tensorflow/core/kernels/transpose_op.cc       | 134 ++++++++++--------
 .../python/kernel_tests/transpose_op_test.py  |  13 ++
 2 files changed, 85 insertions(+), 62 deletions(-)

diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index e151b38d90a..20f0edf309a 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -91,6 +91,26 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                         InvertPermutationOp);
 #endif  // TENSORFLOW_USE_SYCL
 
+namespace {
+template <typename Tperm>
+Status PermutationHelper(const Tensor& perm, const int dims,
+                         std::vector<int32>* permutation) {
+  auto Vperm = perm.vec<Tperm>();
+  if (dims != Vperm.size()) {
+    return errors::InvalidArgument("transpose expects a vector of size ", dims,
+                                   ". But input(1) is a vector of size ",
+                                   Vperm.size());
+  }
+  // using volatile instead of SubtleMustCopy here so that the
+  // asynchrony boundary is permutation.
+  const volatile Tperm* perm_begin =
+      reinterpret_cast<const volatile Tperm*>(Vperm.data());
+  *permutation = std::vector<int32>(perm_begin, perm_begin + dims);
+
+  return Status::OK();
+}
+}  // namespace
+
 // output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
 // of type T and rank N, and a permutation of 0, 1, ..., N-1. It
 // shuffles the dimensions of the input tensor according to permutation.
@@ -113,17 +133,16 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
               errors::InvalidArgument("perm must be a vector, not ",
                                       perm.shape().DebugString()));
-  auto Vperm = perm.vec<int32>();
+
+  // Although Tperm may be an int64 type, an int32 is sufficient to hold
+  // dimension range values, so the narrowing here should be safe.
+  std::vector<int32> permutation;
   const int dims = input.dims();
-  OP_REQUIRES(ctx, dims == Vperm.size(),
-              errors::InvalidArgument(
-                  "transpose expects a vector of size ", input.dims(),
-                  ". But input(1) is a vector of size ", Vperm.size()));
-  // using volatile instead of SubtleMustCopy here so that the
-  // asynchrony boundary is permutation.
-  const volatile int32* perm_begin =
-      reinterpret_cast<const volatile int32*>(Vperm.data());
-  const std::vector<int32> permutation(perm_begin, perm_begin + dims);
+  if (perm.dtype() == DT_INT32) {
+    OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
+  } else {
+    OP_REQUIRES_OK(ctx, PermutationHelper<int64>(perm, dims, &permutation));
+  }
   TensorShape shape;
 
   // Check whether permutation is a permutation of integers of [0 .. dims).
@@ -142,10 +161,9 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
     }
   }
   for (int i = 0; i < dims; ++i) {
-    OP_REQUIRES(
-        ctx, bits[i],
-        errors::InvalidArgument(i, " is missing from {",
-                                str_util::Join(permutation, ","), "}."));
+    OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
+                                  i, " is missing from {",
+                                  str_util::Join(permutation, ","), "}."));
   }
 
   // 0-D, 1-D, and identity transposes do nothing.
@@ -185,18 +203,16 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
 }
 
 #ifdef INTEL_MKL
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          MklTransposeCpuOp);                 \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          MklTransposeCpuOp);         \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           MklConjugateTransposeCpuOp);
 TF_CALL_ALL_TYPES(REGISTER);
 REGISTER(bfloat16);
@@ -204,18 +220,16 @@ REGISTER(bfloat16);
 
 #else  // INTEL_MKL
 
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeCpuOp);                    \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeCpuOp);            \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeCpuOp);
 TF_CALL_ALL_TYPES(REGISTER)
 REGISTER(bfloat16);
@@ -238,18 +252,16 @@ Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
                                             perm, out);
 }
 
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_GPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeGpuOp);                    \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_GPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_GPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeGpuOp);            \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_GPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeGpuOp);
 TF_CALL_POD_TYPES(REGISTER);
 #undef REGISTER
@@ -270,18 +282,16 @@ Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
                                             perm, out);
 }
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_SYCL)            \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeSyclOp);                   \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_SYCL)            \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_SYCL)    \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeSyclOp);           \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_SYCL)    \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeSyclOp);
 TF_CALL_POD_TYPES(REGISTER);
 #undef REGISTER
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 3b352937c82..c551d9c3d05 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -317,6 +317,19 @@ class TransposeTest(test.TestCase):
         np.arange(0, 8).reshape([2, 4]).astype(np.float32),
         np.array([1, 0]).astype(np.int32))
 
+  def testPermType(self):
+    for perm_dtype in [np.int64, np.int32]:
+      x = np.arange(0, 8).reshape([2, 4]).astype(np.float32)
+      p = np.array([1, 0]).astype(perm_dtype)
+      np_ans = np.copy(x).transpose(p)
+      with self.test_session(use_gpu=True):
+        inx = ops.convert_to_tensor(x)
+        inp = constant_op.constant(p)
+        y = array_ops.transpose(inx, inp)
+        tf_ans = y.eval()
+        self.assertShapeEqual(np_ans, y)
+        self.assertAllEqual(np_ans, tf_ans)
+
   def testHalf(self):
     self._compare(np.arange(0, 21).reshape([3, 7]).astype(np.float16))
     self._compare(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float16))