Enhanced CPU SpaceToDepth to take qint8.
PiperOrigin-RevId: 249579694
This commit is contained in:
parent
5786814eda
commit
bf08872ded
@ -37,6 +37,21 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
struct RawType {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct RawType<qint8> {
|
||||||
|
// spacetodepth_op_gpu.cu.cc does not instantiate SpaceToDepthOpFunctor for
|
||||||
|
// int8, so we map qint8 to uint8. Instantiating int8 could slow down
|
||||||
|
// compilation and the code generated is almost the same as for uint8.
|
||||||
|
using type = uint8;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
@ -66,17 +81,17 @@ class SpaceToDepthOp : public OpKernel {
|
|||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
const int dims = input.dims();
|
const int dims = input.dims();
|
||||||
|
|
||||||
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
|
const bool is_int8x4 = (data_format_ == FORMAT_NCHW_VECT_C);
|
||||||
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
const int vect = is_int8x4 ? 4 : 1;
|
||||||
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
|
if (is_int8x4) {
|
||||||
errors::InvalidArgument(
|
OP_REQUIRES(
|
||||||
"qint8 should be used with data_format NCHW_VECT_C."));
|
context, dims == 5,
|
||||||
|
errors::InvalidArgument("Input rank should be 5 instead of ", dims));
|
||||||
constexpr int kVect = is_int8x4 ? 4 : 1;
|
} else {
|
||||||
constexpr int kDims = is_int8x4 ? 5 : 4;
|
OP_REQUIRES(
|
||||||
OP_REQUIRES(context, kDims == dims,
|
context, dims == 4,
|
||||||
errors::InvalidArgument("Input rank should be: ", kDims,
|
errors::InvalidArgument("Input rank should be 4 instead of ", dims));
|
||||||
" instead of: ", dims));
|
}
|
||||||
|
|
||||||
constexpr int kNumSpatialDims = 2;
|
constexpr int kNumSpatialDims = 2;
|
||||||
const int batch_size =
|
const int batch_size =
|
||||||
@ -87,7 +102,7 @@ class SpaceToDepthOp : public OpKernel {
|
|||||||
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
|
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
|
||||||
const int input_depth =
|
const int input_depth =
|
||||||
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
|
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
|
||||||
kVect;
|
vect;
|
||||||
|
|
||||||
// Both width and height must be divisible by block_size.
|
// Both width and height must be divisible by block_size.
|
||||||
OP_REQUIRES(context,
|
OP_REQUIRES(context,
|
||||||
@ -111,32 +126,32 @@ class SpaceToDepthOp : public OpKernel {
|
|||||||
output_width, output_depth),
|
output_width, output_depth),
|
||||||
&outputs_tensor));
|
&outputs_tensor));
|
||||||
|
|
||||||
auto Tinput = input.tensor<T, kDims>();
|
|
||||||
auto Toutput = outputs_tensor->tensor<T, kDims>();
|
|
||||||
|
|
||||||
if (std::is_same<Device, GPUDevice>::value) {
|
if (std::is_same<Device, GPUDevice>::value) {
|
||||||
if (is_int8x4) {
|
using RT = typename RawType<T>::type;
|
||||||
|
if (data_format_ == FORMAT_NCHW_VECT_C) {
|
||||||
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
|
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
|
||||||
auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
|
auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
|
||||||
auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
|
auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
|
||||||
functor::SpaceToDepthOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
|
functor::SpaceToDepthOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
|
||||||
functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
|
functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
|
||||||
Toutput_v);
|
Toutput_v);
|
||||||
return;
|
|
||||||
} else if (data_format_ == FORMAT_NCHW) {
|
} else if (data_format_ == FORMAT_NCHW) {
|
||||||
functor::SpaceToDepthOpFunctor<GPUDevice, T, FORMAT_NCHW> functor;
|
CHECK((std::is_same<T, RT>::value));
|
||||||
functor(context->eigen_device<GPUDevice>(), Tinput, block_size_,
|
functor::SpaceToDepthOpFunctor<GPUDevice, RT, FORMAT_NCHW> functor;
|
||||||
Toutput);
|
functor(context->eigen_device<GPUDevice>(), input.tensor<RT, 4>(),
|
||||||
return;
|
block_size_, outputs_tensor->tensor<RT, 4>());
|
||||||
|
} else {
|
||||||
|
CHECK((std::is_same<T, RT>::value));
|
||||||
|
functor::SpaceToDepthOpFunctor<GPUDevice, RT, FORMAT_NHWC> functor;
|
||||||
|
functor(context->eigen_device<GPUDevice>(), input.tensor<RT, 4>(),
|
||||||
|
block_size_, outputs_tensor->tensor<RT, 4>());
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
// NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
|
||||||
// NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
|
// (CPU && data_format_ != FORMAT_NHWC) in the constructor.
|
||||||
// (CPU && data_format_ != FORMAT_NHWC) in the constructor.
|
|
||||||
|
|
||||||
if (!is_int8x4) {
|
|
||||||
functor::SpaceToDepthOpFunctor<Device, T, FORMAT_NHWC> functor;
|
functor::SpaceToDepthOpFunctor<Device, T, FORMAT_NHWC> functor;
|
||||||
functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
|
functor(context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||||
|
block_size_, outputs_tensor->tensor<T, 4>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -181,6 +196,7 @@ struct SpaceToDepthOpFunctor<CPUDevice, T, FORMAT_NHWC> {
|
|||||||
SpaceToDepthOp<CPUDevice, type>);
|
SpaceToDepthOp<CPUDevice, type>);
|
||||||
|
|
||||||
TF_CALL_ALL_TYPES(REGISTER);
|
TF_CALL_ALL_TYPES(REGISTER);
|
||||||
|
TF_CALL_qint8(REGISTER);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
@ -235,6 +235,11 @@ class SpaceToDepthTest(test.TestCase):
|
|||||||
|
|
||||||
def spaceToDepthUsingTranspose(self, tensor, block_size, data_format):
|
def spaceToDepthUsingTranspose(self, tensor, block_size, data_format):
|
||||||
block_size_sq = block_size * block_size
|
block_size_sq = block_size * block_size
|
||||||
|
|
||||||
|
dtype = tensor.dtype
|
||||||
|
if dtype == dtypes.qint8:
|
||||||
|
tensor = array_ops.bitcast(tensor, dtypes.int8)
|
||||||
|
|
||||||
if data_format == "NHWC":
|
if data_format == "NHWC":
|
||||||
b, ih, iw, ic = tensor.shape.as_list()
|
b, ih, iw, ic = tensor.shape.as_list()
|
||||||
assert ih % block_size == 0, (ih, block_size)
|
assert ih % block_size == 0, (ih, block_size)
|
||||||
@ -253,56 +258,87 @@ class SpaceToDepthTest(test.TestCase):
|
|||||||
[b, ic, oh, block_size, ow, block_size])
|
[b, ic, oh, block_size, ow, block_size])
|
||||||
tensor = array_ops.transpose(tensor, [0, 3, 5, 1, 2, 4])
|
tensor = array_ops.transpose(tensor, [0, 3, 5, 1, 2, 4])
|
||||||
tensor = array_ops.reshape(tensor, [b, oc, oh, ow])
|
tensor = array_ops.reshape(tensor, [b, oc, oh, ow])
|
||||||
|
|
||||||
|
if dtype == dtypes.qint8:
|
||||||
|
tensor = array_ops.bitcast(tensor, dtype)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def compareToTranspose(self, batch_size, out_height, out_width, in_channels,
|
def compareToTranspose(self, batch_size, out_height, out_width, in_channels,
|
||||||
block_size, data_format, use_gpu):
|
block_size, data_format, data_type, use_gpu):
|
||||||
in_height = out_height * block_size
|
in_height = out_height * block_size
|
||||||
in_width = out_width * block_size
|
in_width = out_width * block_size
|
||||||
nhwc_input_shape = [batch_size, in_height, in_width, in_channels]
|
nhwc_input_shape = [batch_size, in_height, in_width, in_channels]
|
||||||
nchw_input_shape = [batch_size, in_channels, in_height, in_width]
|
nchw_input_shape = [batch_size, in_channels, in_height, in_width]
|
||||||
total_size = np.prod(nhwc_input_shape)
|
total_size = np.prod(nhwc_input_shape)
|
||||||
|
|
||||||
if data_format == "NCHW_VECT_C":
|
# Construct the input tensor in data_type and NHWC.
|
||||||
# Initialize the input tensor with qint8 values that circle -127..127.
|
# force_cpu is needed because quantize_v2 runs on only CPU.
|
||||||
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
with test_util.force_cpu():
|
||||||
t = constant_op.constant(x, shape=nhwc_input_shape, dtype=dtypes.float32)
|
if data_type == dtypes.qint8:
|
||||||
expected = self.spaceToDepthUsingTranspose(t, block_size, "NHWC")
|
# Initialize the input tensor with qint8 values that circle -127..127.
|
||||||
t = test_util.NHWCToNCHW_VECT_C(t)
|
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
||||||
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
t = constant_op.constant(
|
||||||
t = array_ops.space_to_depth(t, block_size, data_format="NCHW_VECT_C")
|
x, shape=nhwc_input_shape, dtype=dtypes.float32)
|
||||||
t = gen_array_ops.dequantize(t, -128, 127)
|
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
||||||
actual = test_util.NCHW_VECT_CToNHWC(t)
|
else:
|
||||||
else:
|
assert data_type == dtypes.float32
|
||||||
# Initialize the input tensor with ascending whole numbers as floats.
|
# Initialize the input tensor with ascending whole numbers as floats.
|
||||||
x = [f * 1.0 for f in range(total_size)]
|
x = [f * 1.0 for f in range(total_size)]
|
||||||
shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
|
shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
|
||||||
t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
|
t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
|
||||||
expected = self.spaceToDepthUsingTranspose(t, block_size, data_format)
|
|
||||||
actual = array_ops.space_to_depth(t, block_size, data_format=data_format)
|
with test_util.device(use_gpu):
|
||||||
|
if data_format == "NCHW_VECT_C":
|
||||||
|
assert data_type == dtypes.qint8
|
||||||
|
|
||||||
|
# Convert to int8, then NHWCToNCHW_VECT_C, and then back to qint8.
|
||||||
|
actual = array_ops.bitcast(t, dtypes.int8)
|
||||||
|
actual = test_util.NHWCToNCHW_VECT_C(actual)
|
||||||
|
actual = array_ops.bitcast(actual, dtypes.qint8)
|
||||||
|
actual = array_ops.space_to_depth(
|
||||||
|
actual, block_size, data_format=data_format)
|
||||||
|
actual = array_ops.bitcast(actual, dtypes.int8)
|
||||||
|
actual = test_util.NCHW_VECT_CToNHWC(actual)
|
||||||
|
actual = array_ops.bitcast(actual, dtypes.qint8)
|
||||||
|
|
||||||
|
expected = array_ops.bitcast(t, dtypes.int8)
|
||||||
|
expected = math_ops.cast(expected, dtypes.float32)
|
||||||
|
expected = self.spaceToDepthUsingTranspose(expected, block_size, "NHWC")
|
||||||
|
expected = math_ops.cast(expected, dtypes.int8)
|
||||||
|
expected = array_ops.bitcast(expected, dtypes.qint8)
|
||||||
|
else:
|
||||||
|
# Initialize the input tensor with ascending whole numbers as floats.
|
||||||
|
actual = array_ops.space_to_depth(
|
||||||
|
t, block_size, data_format=data_format)
|
||||||
|
expected = self.spaceToDepthUsingTranspose(t, block_size, data_format)
|
||||||
|
|
||||||
with self.cached_session(use_gpu=use_gpu) as sess:
|
|
||||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||||
|
|
||||||
|
# TODO(jingyue): figure out why this test failed in eager mode.
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
def testAgainstTranspose(self):
|
def testAgainstTranspose(self):
|
||||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, False)
|
||||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
|
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.float32, False)
|
||||||
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", False)
|
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", dtypes.float32, False)
|
||||||
|
|
||||||
|
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.qint8, False)
|
||||||
|
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.qint8, False)
|
||||||
|
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", dtypes.qint8, False)
|
||||||
|
|
||||||
if not test.is_gpu_available():
|
if not test.is_gpu_available():
|
||||||
tf_logging.info("skipping gpu tests since gpu not available")
|
tf_logging.info("skipping gpu tests since gpu not available")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", True)
|
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, True)
|
||||||
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", True)
|
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", dtypes.float32, True)
|
||||||
self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", True)
|
self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", dtypes.float32, True)
|
||||||
self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", True)
|
self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", dtypes.float32, True)
|
||||||
self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", True)
|
self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", dtypes.float32, True)
|
||||||
|
|
||||||
self.compareToTranspose(3, 2, 3, 4, 2, "NCHW_VECT_C", True)
|
self.compareToTranspose(3, 2, 3, 4, 2, "NCHW_VECT_C", dtypes.qint8, True)
|
||||||
self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", True)
|
self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", dtypes.qint8, True)
|
||||||
self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", True)
|
self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", dtypes.qint8, True)
|
||||||
|
|
||||||
|
|
||||||
class SpaceToDepthGradientTest(test.TestCase):
|
class SpaceToDepthGradientTest(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user