Add int64 Tperm type support for Transpose
(#13909)
* Add int64 Tperm type support for `Transpose` This fix adds int64 Tperm support for `Transpose`. In `array_ops.cc`, `Transpose` and `ConjugateTranspose` have been specified as accepting int32 and int64 perm types. However, only int32 kernels has been registered. This fix adds the int64 perm support by removing the constraint on Tperm, resolve the type at runtime, and copying the data type accordingly to correctly handle the int64/int32 types. Additional tests have been added as well. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for int64 of perm in Transpose. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add namespace to hide PermutationHelper Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Enable use_gpu=True for perm type test. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * extra // namespace annotation * Adding a comment about int32 casting that should be safe. Permutations only contain values that refer to dimensions, and the maximum number of dimensions we have is 254, so an int32 is always safe here.
This commit is contained in:
parent
ac0004e711
commit
9b9cbbe2a6
tensorflow
@ -91,6 +91,26 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
|
||||
InvertPermutationOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
namespace {
|
||||
template <typename Tperm>
|
||||
Status PermutationHelper(const Tensor& perm, const int dims,
|
||||
std::vector<int32>* permutation) {
|
||||
auto Vperm = perm.vec<Tperm>();
|
||||
if (dims != Vperm.size()) {
|
||||
return errors::InvalidArgument("transpose expects a vector of size ", dims,
|
||||
". But input(1) is a vector of size ",
|
||||
Vperm.size());
|
||||
}
|
||||
// using volatile instead of SubtleMustCopy here so that the
|
||||
// asynchrony boundary is permutation.
|
||||
const volatile Tperm* perm_begin =
|
||||
reinterpret_cast<const volatile Tperm*>(Vperm.data());
|
||||
*permutation = std::vector<int32>(perm_begin, perm_begin + dims);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
|
||||
// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
|
||||
// shuffles the dimensions of the input tensor according to permutation.
|
||||
@ -113,17 +133,16 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
|
||||
errors::InvalidArgument("perm must be a vector, not ",
|
||||
perm.shape().DebugString()));
|
||||
auto Vperm = perm.vec<int32>();
|
||||
|
||||
// Although Tperm may be an int64 type, an int32 is sufficient to hold
|
||||
// dimension range values, so the narrowing here should be safe.
|
||||
std::vector<int32> permutation;
|
||||
const int dims = input.dims();
|
||||
OP_REQUIRES(ctx, dims == Vperm.size(),
|
||||
errors::InvalidArgument(
|
||||
"transpose expects a vector of size ", input.dims(),
|
||||
". But input(1) is a vector of size ", Vperm.size()));
|
||||
// using volatile instead of SubtleMustCopy here so that the
|
||||
// asynchrony boundary is permutation.
|
||||
const volatile int32* perm_begin =
|
||||
reinterpret_cast<const volatile int32*>(Vperm.data());
|
||||
const std::vector<int32> permutation(perm_begin, perm_begin + dims);
|
||||
if (perm.dtype() == DT_INT32) {
|
||||
OP_REQUIRES_OK(ctx, PermutationHelper<int32>(perm, dims, &permutation));
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, PermutationHelper<int64>(perm, dims, &permutation));
|
||||
}
|
||||
TensorShape shape;
|
||||
|
||||
// Check whether permutation is a permutation of integers of [0 .. dims).
|
||||
@ -142,10 +161,9 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
OP_REQUIRES(
|
||||
ctx, bits[i],
|
||||
errors::InvalidArgument(i, " is missing from {",
|
||||
str_util::Join(permutation, ","), "}."));
|
||||
OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
|
||||
i, " is missing from {",
|
||||
str_util::Join(permutation, ","), "}."));
|
||||
}
|
||||
|
||||
// 0-D, 1-D, and identity transposes do nothing.
|
||||
@ -185,18 +203,16 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
MklTransposeCpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
MklTransposeCpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
MklConjugateTransposeCpuOp);
|
||||
TF_CALL_ALL_TYPES(REGISTER);
|
||||
REGISTER(bfloat16);
|
||||
@ -204,18 +220,16 @@ REGISTER(bfloat16);
|
||||
|
||||
#else // INTEL_MKL
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeCpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeCpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
ConjugateTransposeCpuOp);
|
||||
TF_CALL_ALL_TYPES(REGISTER)
|
||||
REGISTER(bfloat16);
|
||||
@ -238,18 +252,16 @@ Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
|
||||
perm, out);
|
||||
}
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeGpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeGpuOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
ConjugateTransposeGpuOp);
|
||||
TF_CALL_POD_TYPES(REGISTER);
|
||||
#undef REGISTER
|
||||
@ -270,18 +282,16 @@ Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
|
||||
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
|
||||
perm, out);
|
||||
}
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeSyclOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tperm") \
|
||||
.HostMemory("perm"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Transpose") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
TransposeSyclOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("perm"), \
|
||||
ConjugateTransposeSyclOp);
|
||||
TF_CALL_POD_TYPES(REGISTER);
|
||||
#undef REGISTER
|
||||
|
@ -317,6 +317,19 @@ class TransposeTest(test.TestCase):
|
||||
np.arange(0, 8).reshape([2, 4]).astype(np.float32),
|
||||
np.array([1, 0]).astype(np.int32))
|
||||
|
||||
def testPermType(self):
|
||||
for perm_dtype in [np.int64, np.int32]:
|
||||
x = np.arange(0, 8).reshape([2, 4]).astype(np.float32)
|
||||
p = np.array([1, 0]).astype(perm_dtype)
|
||||
np_ans = np.copy(x).transpose(p)
|
||||
with self.test_session(use_gpu=True):
|
||||
inx = ops.convert_to_tensor(x)
|
||||
inp = constant_op.constant(p)
|
||||
y = array_ops.transpose(inx, inp)
|
||||
tf_ans = y.eval()
|
||||
self.assertShapeEqual(np_ans, y)
|
||||
self.assertAllEqual(np_ans, tf_ans)
|
||||
|
||||
def testHalf(self):
|
||||
self._compare(np.arange(0, 21).reshape([3, 7]).astype(np.float16))
|
||||
self._compare(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float16))
|
||||
|
Loading…
Reference in New Issue
Block a user