Enhanced CPU SpaceToDepth to take qint8.
PiperOrigin-RevId: 249579694
This commit is contained in:
parent
5786814eda
commit
bf08872ded
@ -37,6 +37,21 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct RawType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RawType<qint8> {
|
||||
// spacetodepth_op_gpu.cu.cc does not instantiate SpaceToDepthOpFunctor for
|
||||
// int8, so we map qint8 to uint8. Instantiating int8 could slow down
|
||||
// compilation and the code generated is almost the same as for uint8.
|
||||
using type = uint8;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
@ -66,17 +81,17 @@ class SpaceToDepthOp : public OpKernel {
|
||||
const Tensor& input = context->input(0);
|
||||
const int dims = input.dims();
|
||||
|
||||
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
|
||||
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
||||
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
|
||||
errors::InvalidArgument(
|
||||
"qint8 should be used with data_format NCHW_VECT_C."));
|
||||
|
||||
constexpr int kVect = is_int8x4 ? 4 : 1;
|
||||
constexpr int kDims = is_int8x4 ? 5 : 4;
|
||||
OP_REQUIRES(context, kDims == dims,
|
||||
errors::InvalidArgument("Input rank should be: ", kDims,
|
||||
" instead of: ", dims));
|
||||
const bool is_int8x4 = (data_format_ == FORMAT_NCHW_VECT_C);
|
||||
const int vect = is_int8x4 ? 4 : 1;
|
||||
if (is_int8x4) {
|
||||
OP_REQUIRES(
|
||||
context, dims == 5,
|
||||
errors::InvalidArgument("Input rank should be 5 instead of ", dims));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
context, dims == 4,
|
||||
errors::InvalidArgument("Input rank should be 4 instead of ", dims));
|
||||
}
|
||||
|
||||
constexpr int kNumSpatialDims = 2;
|
||||
const int batch_size =
|
||||
@ -87,7 +102,7 @@ class SpaceToDepthOp : public OpKernel {
|
||||
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
|
||||
const int input_depth =
|
||||
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
|
||||
kVect;
|
||||
vect;
|
||||
|
||||
// Both width and height must be divisible by block_size.
|
||||
OP_REQUIRES(context,
|
||||
@ -111,32 +126,32 @@ class SpaceToDepthOp : public OpKernel {
|
||||
output_width, output_depth),
|
||||
&outputs_tensor));
|
||||
|
||||
auto Tinput = input.tensor<T, kDims>();
|
||||
auto Toutput = outputs_tensor->tensor<T, kDims>();
|
||||
|
||||
if (std::is_same<Device, GPUDevice>::value) {
|
||||
if (is_int8x4) {
|
||||
using RT = typename RawType<T>::type;
|
||||
if (data_format_ == FORMAT_NCHW_VECT_C) {
|
||||
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
|
||||
auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
|
||||
auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
|
||||
functor::SpaceToDepthOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
|
||||
functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
|
||||
Toutput_v);
|
||||
return;
|
||||
} else if (data_format_ == FORMAT_NCHW) {
|
||||
functor::SpaceToDepthOpFunctor<GPUDevice, T, FORMAT_NCHW> functor;
|
||||
functor(context->eigen_device<GPUDevice>(), Tinput, block_size_,
|
||||
Toutput);
|
||||
return;
|
||||
CHECK((std::is_same<T, RT>::value));
|
||||
functor::SpaceToDepthOpFunctor<GPUDevice, RT, FORMAT_NCHW> functor;
|
||||
functor(context->eigen_device<GPUDevice>(), input.tensor<RT, 4>(),
|
||||
block_size_, outputs_tensor->tensor<RT, 4>());
|
||||
} else {
|
||||
CHECK((std::is_same<T, RT>::value));
|
||||
functor::SpaceToDepthOpFunctor<GPUDevice, RT, FORMAT_NHWC> functor;
|
||||
functor(context->eigen_device<GPUDevice>(), input.tensor<RT, 4>(),
|
||||
block_size_, outputs_tensor->tensor<RT, 4>());
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
|
||||
// (CPU && data_format_ != FORMAT_NHWC) in the constructor.
|
||||
|
||||
if (!is_int8x4) {
|
||||
} else {
|
||||
// NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
|
||||
// (CPU && data_format_ != FORMAT_NHWC) in the constructor.
|
||||
functor::SpaceToDepthOpFunctor<Device, T, FORMAT_NHWC> functor;
|
||||
functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
|
||||
functor(context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||
block_size_, outputs_tensor->tensor<T, 4>());
|
||||
}
|
||||
};
|
||||
|
||||
@ -181,6 +196,7 @@ struct SpaceToDepthOpFunctor<CPUDevice, T, FORMAT_NHWC> {
|
||||
SpaceToDepthOp<CPUDevice, type>);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER);
|
||||
TF_CALL_qint8(REGISTER);
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
@ -235,6 +235,11 @@ class SpaceToDepthTest(test.TestCase):
|
||||
|
||||
def spaceToDepthUsingTranspose(self, tensor, block_size, data_format):
|
||||
block_size_sq = block_size * block_size
|
||||
|
||||
dtype = tensor.dtype
|
||||
if dtype == dtypes.qint8:
|
||||
tensor = array_ops.bitcast(tensor, dtypes.int8)
|
||||
|
||||
if data_format == "NHWC":
|
||||
b, ih, iw, ic = tensor.shape.as_list()
|
||||
assert ih % block_size == 0, (ih, block_size)
|
||||
@ -253,56 +258,87 @@ class SpaceToDepthTest(test.TestCase):
|
||||
[b, ic, oh, block_size, ow, block_size])
|
||||
tensor = array_ops.transpose(tensor, [0, 3, 5, 1, 2, 4])
|
||||
tensor = array_ops.reshape(tensor, [b, oc, oh, ow])
|
||||
|
||||
if dtype == dtypes.qint8:
|
||||
tensor = array_ops.bitcast(tensor, dtype)
|
||||
return tensor
|
||||
|
||||
def compareToTranspose(self, batch_size, out_height, out_width, in_channels,
|
||||
block_size, data_format, use_gpu):
|
||||
block_size, data_format, data_type, use_gpu):
|
||||
in_height = out_height * block_size
|
||||
in_width = out_width * block_size
|
||||
nhwc_input_shape = [batch_size, in_height, in_width, in_channels]
|
||||
nchw_input_shape = [batch_size, in_channels, in_height, in_width]
|
||||
total_size = np.prod(nhwc_input_shape)
|
||||
|
||||
if data_format == "NCHW_VECT_C":
|
||||
# Initialize the input tensor with qint8 values that circle -127..127.
|
||||
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
||||
t = constant_op.constant(x, shape=nhwc_input_shape, dtype=dtypes.float32)
|
||||
expected = self.spaceToDepthUsingTranspose(t, block_size, "NHWC")
|
||||
t = test_util.NHWCToNCHW_VECT_C(t)
|
||||
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
||||
t = array_ops.space_to_depth(t, block_size, data_format="NCHW_VECT_C")
|
||||
t = gen_array_ops.dequantize(t, -128, 127)
|
||||
actual = test_util.NCHW_VECT_CToNHWC(t)
|
||||
else:
|
||||
# Initialize the input tensor with ascending whole numbers as floats.
|
||||
x = [f * 1.0 for f in range(total_size)]
|
||||
shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
|
||||
t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
|
||||
expected = self.spaceToDepthUsingTranspose(t, block_size, data_format)
|
||||
actual = array_ops.space_to_depth(t, block_size, data_format=data_format)
|
||||
# Construct the input tensor in data_type and NHWC.
|
||||
# force_cpu is needed because quantize_v2 runs on only CPU.
|
||||
with test_util.force_cpu():
|
||||
if data_type == dtypes.qint8:
|
||||
# Initialize the input tensor with qint8 values that circle -127..127.
|
||||
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
||||
t = constant_op.constant(
|
||||
x, shape=nhwc_input_shape, dtype=dtypes.float32)
|
||||
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
||||
else:
|
||||
assert data_type == dtypes.float32
|
||||
# Initialize the input tensor with ascending whole numbers as floats.
|
||||
x = [f * 1.0 for f in range(total_size)]
|
||||
shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
|
||||
t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
|
||||
|
||||
with test_util.device(use_gpu):
|
||||
if data_format == "NCHW_VECT_C":
|
||||
assert data_type == dtypes.qint8
|
||||
|
||||
# Convert to int8, then NHWCToNCHW_VECT_C, and then back to qint8.
|
||||
actual = array_ops.bitcast(t, dtypes.int8)
|
||||
actual = test_util.NHWCToNCHW_VECT_C(actual)
|
||||
actual = array_ops.bitcast(actual, dtypes.qint8)
|
||||
actual = array_ops.space_to_depth(
|
||||
actual, block_size, data_format=data_format)
|
||||
actual = array_ops.bitcast(actual, dtypes.int8)
|
||||
actual = test_util.NCHW_VECT_CToNHWC(actual)
|
||||
actual = array_ops.bitcast(actual, dtypes.qint8)
|
||||
|
||||
expected = array_ops.bitcast(t, dtypes.int8)
|
||||
expected = math_ops.cast(expected, dtypes.float32)
|
||||
expected = self.spaceToDepthUsingTranspose(expected, block_size, "NHWC")
|
||||
expected = math_ops.cast(expected, dtypes.int8)
|
||||
expected = array_ops.bitcast(expected, dtypes.qint8)
|
||||
else:
|
||||
# Initialize the input tensor with ascending whole numbers as floats.
|
||||
actual = array_ops.space_to_depth(
|
||||
t, block_size, data_format=data_format)
|
||||
expected = self.spaceToDepthUsingTranspose(t, block_size, data_format)
|
||||
|
||||
with self.cached_session(use_gpu=use_gpu) as sess:
|
||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||
|
||||
# TODO(jingyue): figure out why this test failed in eager mode.
|
||||
@test_util.run_deprecated_v1
|
||||
def testAgainstTranspose(self):
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", False)
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.float32, False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", dtypes.float32, False)
|
||||
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.qint8, False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.qint8, False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", dtypes.qint8, False)
|
||||
|
||||
if not test.is_gpu_available():
|
||||
tf_logging.info("skipping gpu tests since gpu not available")
|
||||
return
|
||||
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", True)
|
||||
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", True)
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", True)
|
||||
self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", True)
|
||||
self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", True)
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, True)
|
||||
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", dtypes.float32, True)
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", dtypes.float32, True)
|
||||
self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", dtypes.float32, True)
|
||||
self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", dtypes.float32, True)
|
||||
|
||||
self.compareToTranspose(3, 2, 3, 4, 2, "NCHW_VECT_C", True)
|
||||
self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", True)
|
||||
self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", True)
|
||||
self.compareToTranspose(3, 2, 3, 4, 2, "NCHW_VECT_C", dtypes.qint8, True)
|
||||
self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", dtypes.qint8, True)
|
||||
self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", dtypes.qint8, True)
|
||||
|
||||
|
||||
class SpaceToDepthGradientTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user