Fix issues where int64 crops could not be passed to batch_to_space. (#13862)
* Fix issues where int64 crops could not be passed to batch_to_space. This fix tries to address the issue where int64 `crops` could not be passed to `batch_to_space` even though both int32 and int64 are specified as supported in the docs (tf.batch_to_space.__doc__) The reason is that BatchToSpace kernel puts a constraint of int32 to crops data types. This fix removed the constraint so that int64 `crops` could be supported. NOTE: Just removing the constraint should work and it is not necessary to add specification to the kernel class template, as `SubtleMustCopyFlat` called in the class already correctly handled both int32 and int64 cases. Besides, other data types (e.g., float or double) will not be passed to the kernel as they are guarded by the specification in `array_ops.cc`. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Also remove int64/int32 type constraints for SpaceToBatch kernels Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for int64 crops of batch_to_space and space_to_batch Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix test failures. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
9c825d32c9
commit
c77090a0ae
tensorflow
@ -249,40 +249,34 @@ class BatchToSpaceOp : public OpKernel {
|
||||
Tensor block_shape_;
|
||||
};
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tblock_shape") \
|
||||
.TypeConstraint<int32>("Tcrops") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceNDOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("crops"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceNDOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tblock_shape") \
|
||||
.TypeConstraint<int32>("Tcrops") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceNDOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("crops"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceNDOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchToSpace") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("crops"), \
|
||||
BatchToSpaceOp<GPUDevice, T>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
|
@ -248,40 +248,34 @@ class SpaceToBatchOp : public OpKernel {
|
||||
Tensor block_shape_;
|
||||
};
|
||||
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tblock_shape") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchNDOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatch") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.HostMemory("paddings"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchNDOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatch") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tblock_shape") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchNDOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatch") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.HostMemory("paddings"), \
|
||||
#define REGISTER(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("block_shape") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchNDOp<GPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToBatch") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("paddings"), \
|
||||
SpaceToBatchOp<GPUDevice, T>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -52,14 +53,15 @@ class BatchToSpaceDepthToSpace(test.TestCase, PythonOpImpl):
|
||||
def testDepthToSpaceTranspose(self):
|
||||
x = np.arange(20 * 5 * 8 * 7, dtype=np.float32).reshape([20, 5, 8, 7])
|
||||
block_size = 2
|
||||
crops = np.zeros((2, 2), dtype=np.int32)
|
||||
y1 = self.batch_to_space(x, crops, block_size=block_size)
|
||||
y2 = array_ops.transpose(
|
||||
array_ops.depth_to_space(
|
||||
array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
|
||||
[3, 1, 2, 0])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(y1.eval(), y2.eval())
|
||||
for crops_dtype in [dtypes.int64, dtypes.int32]:
|
||||
crops = array_ops.zeros((2, 2), dtype=crops_dtype)
|
||||
y1 = self.batch_to_space(x, crops, block_size=block_size)
|
||||
y2 = array_ops.transpose(
|
||||
array_ops.depth_to_space(
|
||||
array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
|
||||
[3, 1, 2, 0])
|
||||
with self.test_session():
|
||||
self.assertAllEqual(y1.eval(), y2.eval())
|
||||
|
||||
|
||||
class BatchToSpaceDepthToSpaceCpp(BatchToSpaceDepthToSpace, CppOpImpl):
|
||||
@ -287,9 +289,10 @@ class BatchToSpaceGradientCppTest(BatchToSpaceGradientTest, CppOpImpl):
|
||||
class BatchToSpaceNDGradientTest(test.TestCase):
|
||||
|
||||
# Check the gradients.
|
||||
def _checkGrad(self, x, block_shape, crops):
|
||||
def _checkGrad(self, x, block_shape, crops, crops_dtype):
|
||||
block_shape = np.array(block_shape)
|
||||
crops = np.array(crops).reshape((len(block_shape), 2))
|
||||
crops = constant_op.constant(
|
||||
np.array(crops).reshape((len(block_shape), 2)), crops_dtype)
|
||||
with self.test_session():
|
||||
tf_x = ops.convert_to_tensor(x)
|
||||
tf_y = array_ops.batch_to_space_nd(tf_x, block_shape, crops)
|
||||
@ -304,23 +307,26 @@ class BatchToSpaceNDGradientTest(test.TestCase):
|
||||
|
||||
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=epsilon)
|
||||
|
||||
def _compare(self, input_shape, block_shape, crops):
|
||||
def _compare(self, input_shape, block_shape, crops, crops_dtype):
|
||||
input_shape = list(input_shape)
|
||||
input_shape[0] *= np.prod(block_shape)
|
||||
x = np.random.normal(
|
||||
0, 1, np.prod(input_shape)).astype(np.float32).reshape(input_shape)
|
||||
self._checkGrad(x, block_shape, crops)
|
||||
self._checkGrad(x, block_shape, crops, crops_dtype)
|
||||
|
||||
# Don't use very large numbers as dimensions here as the result is tensor
|
||||
# with cartesian product of the dimensions.
|
||||
def testSmall(self):
|
||||
self._compare([1, 2, 3, 5], [2, 2], [[0, 0], [0, 0]])
|
||||
for dtype in [dtypes.int64, dtypes.int32]:
|
||||
self._compare([1, 2, 3, 5], [2, 2], [[0, 0], [0, 0]], dtype)
|
||||
|
||||
def testSmall2(self):
|
||||
self._compare([2, 4, 3, 2], [2, 2], [[0, 0], [0, 0]])
|
||||
for dtype in [dtypes.int64, dtypes.int32]:
|
||||
self._compare([2, 4, 3, 2], [2, 2], [[0, 0], [0, 0]], dtype)
|
||||
|
||||
def testSmallCrop1x1(self):
|
||||
self._compare([1, 2, 3, 5], [2, 2], [[1, 1], [1, 1]])
|
||||
for dtype in [dtypes.int64, dtypes.int32]:
|
||||
self._compare([1, 2, 3, 5], [2, 2], [[1, 1], [1, 1]], dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user