Transpose / ConjugateTranspose: accept an int64 `perm` tensor.

tensorflow/core/kernels/transpose_op.cc
  * Adds a file-local PermutationHelper<Tperm>() that checks the length of
    `perm` against the input rank and copies its values into a
    std::vector<int32>.
  * TransposeOp::Compute picks the helper instantiation from perm.dtype():
    int32 for DT_INT32, int64 otherwise.
  * Removes the "Tperm" TypeConstraint from the CPU (MKL and non-MKL), GPU
    and SYCL registrations of "Transpose" and "ConjugateTranspose".

tensorflow/python/kernel_tests/transpose_op_test.py
  * Adds testPermType, which transposes a 2x4 float32 array with an int64
    and an int32 `perm` and compares the result against numpy.

NOTE(review): the helper narrows int64 values to int32 while copying; this
  presumably happens before the [0 .. dims) check in Compute, so an
  out-of-int32-range value would wrap instead of being rejected. Confirm.
NOTE(review): indentation and backslash alignment of context/removed lines
  follow clang-format conventions; verify against the target tree (or apply
  with --ignore-whitespace).

diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index e151b38d90a..20f0edf309a 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -91,6 +91,26 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
                         InvertPermutationOp);
 #endif  // TENSORFLOW_USE_SYCL
 
+namespace {
+template <typename Tperm>
+Status PermutationHelper(const Tensor& perm, const int dims,
+                         std::vector<int32>* permutation) {
+  auto Vperm = perm.vec<Tperm>();
+  if (dims != Vperm.size()) {
+    return errors::InvalidArgument("transpose expects a vector of size ", dims,
+                                   ". But input(1) is a vector of size ",
+                                   Vperm.size());
+  }
+  // using volatile instead of SubtleMustCopy here so that the
+  // asynchrony boundary is permutation.
+  const volatile Tperm* perm_begin =
+      reinterpret_cast<const volatile Tperm*>(Vperm.data());
+  *permutation = std::vector<int32>(perm_begin, perm_begin + dims);
+
+  return Status::OK();
+}
+}  // namespace
+
 // output = TransposeOp(T input, T perm) takes a tensor
 // of type T and rank N, and a permutation of 0, 1, ..., N-1. It
 // shuffles the dimensions of the input tensor according to permutation.
@@ -113,17 +133,16 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
               errors::InvalidArgument("perm must be a vector, not ",
                                       perm.shape().DebugString()));
-  auto Vperm = perm.vec<int32>();
+
+  // Although Tperm may be an int64 type, an int32 is sufficient to hold
+  // dimension range values, so the narrowing here should be safe.
+  std::vector<int32> permutation;
   const int dims = input.dims();
-  OP_REQUIRES(ctx, dims == Vperm.size(),
-              errors::InvalidArgument(
-                  "transpose expects a vector of size ", input.dims(),
-                  ". But input(1) is a vector of size ", Vperm.size()));
-  // using volatile instead of SubtleMustCopy here so that the
-  // asynchrony boundary is permutation.
-  const volatile int32* perm_begin =
-      reinterpret_cast<const volatile int32*>(Vperm.data());
-  const std::vector<int32> permutation(perm_begin, perm_begin + dims);
+  if (perm.dtype() == DT_INT32) {
+    OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
+  } else {
+    OP_REQUIRES_OK(ctx, PermutationHelper<int64>(perm, dims, &permutation));
+  }
   TensorShape shape;
 
   // Check whether permutation is a permutation of integers of [0 .. dims).
@@ -142,10 +161,9 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
     }
   }
   for (int i = 0; i < dims; ++i) {
-    OP_REQUIRES(
-        ctx, bits[i],
-        errors::InvalidArgument(i, " is missing from {",
-                                str_util::Join(permutation, ","), "}."));
+    OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
+                                  i, " is missing from {",
+                                  str_util::Join(permutation, ","), "}."));
   }
 
   // 0-D, 1-D, and identity transposes do nothing.
@@ -185,18 +203,16 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
 }
 
 #ifdef INTEL_MKL
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          MklTransposeCpuOp);                 \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          MklTransposeCpuOp);         \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           MklConjugateTransposeCpuOp);
 TF_CALL_ALL_TYPES(REGISTER);
 REGISTER(bfloat16);
@@ -204,18 +220,16 @@ REGISTER(bfloat16);
 
 #else  // INTEL_MKL
 
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeCpuOp);                    \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_CPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeCpuOp);            \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeCpuOp);
 TF_CALL_ALL_TYPES(REGISTER)
 REGISTER(bfloat16);
@@ -238,18 +252,16 @@ Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
                                             perm, out);
 }
 
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_GPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeGpuOp);                    \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_GPU)             \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_GPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeGpuOp);            \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_GPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeGpuOp);
 TF_CALL_POD_TYPES(REGISTER);
 #undef REGISTER
@@ -270,18 +282,16 @@ Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
                                             perm, out);
 }
-#define REGISTER(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                   \
-                              .Device(DEVICE_SYCL)            \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
-                          TransposeSyclOp);                   \
-  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")          \
-                              .Device(DEVICE_SYCL)            \
-                              .TypeConstraint<T>("T")         \
-                              .TypeConstraint<int32>("Tperm") \
-                              .HostMemory("perm"),            \
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_SYCL)    \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeSyclOp);           \
+  REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose")  \
+                              .Device(DEVICE_SYCL)    \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
                           ConjugateTransposeSyclOp);
 TF_CALL_POD_TYPES(REGISTER);
 #undef REGISTER
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 3b352937c82..c551d9c3d05 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -317,6 +317,19 @@ class TransposeTest(test.TestCase):
         np.arange(0, 8).reshape([2, 4]).astype(np.float32),
         np.array([1, 0]).astype(np.int32))
 
+  def testPermType(self):
+    for perm_dtype in [np.int64, np.int32]:
+      x = np.arange(0, 8).reshape([2, 4]).astype(np.float32)
+      p = np.array([1, 0]).astype(perm_dtype)
+      np_ans = np.copy(x).transpose(p)
+      with self.test_session(use_gpu=True):
+        inx = ops.convert_to_tensor(x)
+        inp = constant_op.constant(p)
+        y = array_ops.transpose(inx, inp)
+        tf_ans = y.eval()
+        self.assertShapeEqual(np_ans, y)
+        self.assertAllEqual(np_ans, tf_ans)
+
   def testHalf(self):
     self._compare(np.arange(0, 21).reshape([3, 7]).astype(np.float16))
     self._compare(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float16))