Implement NCHW_VECT_C support for tf.depth_to_space on GPU.

PiperOrigin-RevId: 171904046
Authored by A. Unique TensorFlower on 2017-10-11 18:15:32 -07:00; committed by TensorFlower Gardener
parent c69b959799
commit 9b26ed77dc
6 changed files with 112 additions and 62 deletions

@@ -49,34 +49,33 @@ class DepthToSpaceOp : public OpKernel {
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
OP_REQUIRES(context, block_size_ > 1,
errors::InvalidArgument("Block size should be > 1, but was: ",
block_size_));
if (std::is_same<Device, CPUDevice>::value) {
OP_REQUIRES(
context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument(
"Only NHWC data_format supported on CPU. Got ", data_format_str));
}
// TODO(pauldonnelly): Implement NCHW_VECT_C kernel for the GPU.
OP_REQUIRES(
context, data_format_ != FORMAT_NCHW_VECT_C,
errors::InvalidArgument("NHWC_VECT_C kernel not yet implemented."));
OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
OP_REQUIRES(
context, block_size_ > 1,
errors::InvalidArgument("Block size should be > 1: ", block_size_));
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
// Check on the input dimensions first.
// The input is presumed to be [batch, height, width, depth]
const int dims = input.dims();
constexpr int kRequiredDims = 4;
OP_REQUIRES(context, kRequiredDims == dims,
errors::InvalidArgument("Input rank should be: ", kRequiredDims,
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
errors::InvalidArgument(
"qint8 should be used with data_format NCHW_VECT_C."));
constexpr int kVect = is_int8x4 ? 4 : 1;
constexpr int kDims = is_int8x4 ? 5 : 4;
OP_REQUIRES(context, kDims == dims,
errors::InvalidArgument("Input rank should be: ", kDims,
" instead of: ", dims));
constexpr int kNumSpatialDims = 2;
@@ -87,7 +86,8 @@ class DepthToSpaceOp : public OpKernel {
const int input_width =
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
const int input_depth =
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C'));
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
kVect;
const int block_size_sq = block_size_ * block_size_;
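
In the hunk above, the channel count read from the tensor shape is multiplied by kVect because NCHW_VECT_C stores four int8 channel values in an extra innermost dimension, so the logical depth is four times the stored C dimension and the input rank is 5 rather than 4. A rough Python sketch of that bookkeeping (the helper name is hypothetical, not part of this change):

def logical_depth(shape, data_format):
    # NCHW_VECT_C is 5-D: [N, C/4, H, W, 4]; kVect == 4.
    if data_format == "NCHW_VECT_C":
        assert len(shape) == 5 and shape[-1] == 4
        return shape[1] * 4
    if data_format == "NCHW":  # [N, C, H, W]
        return shape[1]
    return shape[-1]           # NHWC: [N, H, W, C]

assert logical_depth([2, 3, 4, 5, 4], "NCHW_VECT_C") == 12
assert logical_depth([2, 12, 4, 5], "NCHW") == 12
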
@@ -109,13 +109,30 @@ class DepthToSpaceOp : public OpKernel {
ShapeFromFormat(data_format_, batch_size, output_height,
output_width, output_depth),
&outputs_tensor));
auto Tinput = input.tensor<T, kRequiredDims>();
auto Toutput = outputs_tensor->tensor<T, kRequiredDims>();
auto Tinput = input.tensor<T, kDims>();
auto Toutput = outputs_tensor->tensor<T, kDims>();
if (std::is_same<Device, GPUDevice>::value && data_format_ == FORMAT_NCHW) {
functor::DepthToSpaceOpFunctor<Device, T, FORMAT_NCHW> functor;
functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
} else {
if (std::is_same<Device, GPUDevice>::value) {
if (is_int8x4) {
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
Toutput_v);
return;
} else if (data_format_ == FORMAT_NCHW) {
functor::DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> functor;
functor(context->eigen_device<GPUDevice>(), Tinput, block_size_,
Toutput);
return;
}
}
// NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
// (CPU && data_format_ != FORMAT_NHWC) in the constructor.
if (!is_int8x4) {
functor::DepthToSpaceOpFunctor<Device, T, FORMAT_NHWC> functor;
functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
}
@@ -170,6 +187,9 @@ TF_CALL_ALL_TYPES(REGISTER);
REGISTER_KERNEL_BUILDER(
Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<float>("T"),
DepthToSpaceOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
DepthToSpaceOp<GPUDevice, qint8>);
#endif // GOOGLE_CUDA
} // end namespace tensorflow
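
The key idea in the Compute() change above is the reinterpret_last_dimension call: the four qint8 values packed into the innermost dimension of an NCHW_VECT_C tensor are contiguous, so they can be viewed as a single int32, turning the 5-D qint8 tensor into a 4-D NCHW int32 tensor that the existing NCHW kernel already handles. Any spatial rearrangement of the int32 view moves whole groups of four channels together, which is exactly the NCHW_VECT_C semantics. A small NumPy illustration of the reinterpretation (illustrative only, not the kernel code):

import numpy as np

# [N, C/4, H, W, 4] int8 viewed as [N, C/4, H, W] int32, and back.
n, c, h, w = 1, 8, 2, 2
vect_c = np.arange(n * (c // 4) * h * w * 4, dtype=np.int8).reshape(
    n, c // 4, h, w, 4)
as_int32 = vect_c.view(np.int32).reshape(n, c // 4, h, w)
restored = as_int32.view(np.int8).reshape(vect_c.shape)
assert np.array_equal(restored, vect_c)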


@@ -44,6 +44,10 @@ template <typename Device, typename T, TensorFormat data_format>
struct DepthToSpaceOpFunctor {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
int block_size, typename TTypes<T, 4>::Tensor output);
// This 5-D version is to support NCHW_VECT_C.
void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
int block_size, typename TTypes<T, 5>::Tensor output);
};
} // namespace functor
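
The new 5-D operator() overload exists because NCHW_VECT_C tensors carry an explicit trailing dimension of size 4. For orientation, converting an NHWC array into that layout is just a reshape plus transpose; a hedged NumPy sketch of the layout (mirroring what test_util.NHWCToNCHW_VECT_C is expected to do, assuming the channel count is a multiple of 4):

import numpy as np

def nhwc_to_nchw_vect_c(x):
    # [N, H, W, C] -> [N, C/4, H, W, 4], assuming C % 4 == 0.
    n, h, w, c = x.shape
    assert c % 4 == 0
    return x.reshape(n, h, w, c // 4, 4).transpose(0, 3, 1, 2, 4)

x = np.arange(2 * 3 * 3 * 8, dtype=np.float32).reshape(2, 3, 3, 8)
print(nhwc_to_nchw_vect_c(x).shape)  # (2, 2, 3, 3, 4)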


@@ -124,6 +124,10 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NHWC> {
input_height, input_width, input_depth, output_height, output_width,
output_depth, output.data());
}
void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
int block_size, typename TTypes<T, 5>::Tensor output) {
LOG(FATAL) << "5-D tensors should not be used with NHWC format";
}
};
template <typename T>
@@ -143,6 +147,10 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> {
config.virtual_thread_count, input.data(), block_size, input_width,
output_depth * input_height, output.data());
}
void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
int block_size, typename TTypes<T, 5>::Tensor output) {
LOG(FATAL) << "5-D tensors should not be used with NCHW format";
}
};
} // end namespace functor
@@ -150,6 +158,9 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> {
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NCHW>;
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NHWC>;
// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
template struct functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW>;
} // end namespace tensorflow
#endif // GOOGLE_CUDA
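
The new int32 instantiation of the NCHW functor is what the qint8/NCHW_VECT_C path in the op dispatches to after reinterpreting the last dimension. For intuition, the NCHW rearrangement it computes should correspond to the following NumPy sketch (semantics only; the CUDA kernel performs the same mapping with flat index arithmetic):

import numpy as np

def depth_to_space_nchw(x, block_size):
    # [N, C, H, W] -> [N, C/(bs*bs), H*bs, W*bs]
    n, c, h, w = x.shape
    bs = block_size
    assert c % (bs * bs) == 0
    y = x.reshape(n, bs, bs, c // (bs * bs), h, w)
    y = y.transpose(0, 3, 4, 1, 5, 2)
    return y.reshape(n, c // (bs * bs), h * bs, w * bs)

x = np.arange(1 * 4 * 1 * 1, dtype=np.float32).reshape(1, 4, 1, 1)
print(depth_to_space_nchw(x, 2)[0, 0])  # [[0. 1.], [2. 3.]]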


@@ -4244,13 +4244,16 @@ REGISTER_OP("DepthToSpace")
TensorFormat data_format;
FormatFromString(data_format_str, &data_format);
constexpr int num_spatial_dims = 2;
const int dims =
GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
constexpr int num_spatial_dims = 2;
DimensionHandle batch_size =
c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
DimensionHandle input_height =

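With NCHW_VECT_C inputs being rank 5, the shape function above now derives the expected input rank from the data format instead of hard-coding 4. A rough Python analogue of the rank computation (a sketch of what GetTensorDimsFromSpatialDims provides here, not the C++ helper itself):

def tensor_dims_from_spatial_dims(num_spatial_dims, data_format):
    # Batch and channel dimensions on top of the spatial ones, plus one
    # packed inner dimension for the *_VECT_C formats.
    dims = num_spatial_dims + 2
    return dims + 1 if data_format.endswith("_VECT_C") else dims

assert tensor_dims_from_spatial_dims(2, "NHWC") == 4
assert tensor_dims_from_spatial_dims(2, "NCHW") == 4
assert tensor_dims_from_spatial_dims(2, "NCHW_VECT_C") == 5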

@@ -1330,7 +1330,7 @@ cuda_py_test(
cuda_py_test(
name = "depthtospace_op_test",
size = "small",
size = "medium",
srcs = ["depthtospace_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1898,7 +1898,7 @@ cuda_py_test(
cuda_py_test(
name = "spacetodepth_op_test",
size = "small",
size = "medium",
srcs = ["spacetodepth_op_test.py"],
additional_deps = [
"//third_party/py/numpy",

@@ -26,9 +26,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
class DepthToSpaceTest(test.TestCase):
@@ -201,7 +203,8 @@ class DepthToSpaceTest(test.TestCase):
_ = array_ops.space_to_depth(x_np, block_size)
def testUnknownShape(self):
t = array_ops.depth_to_space(array_ops.placeholder(dtypes.float32), block_size=4)
t = array_ops.depth_to_space(
array_ops.placeholder(dtypes.float32), block_size=4)
self.assertEqual(4, t.get_shape().ndims)
def depthToSpaceUsingTranspose(self, tensor, block_size, data_format):
@@ -224,49 +227,58 @@ class DepthToSpaceTest(test.TestCase):
tensor = array_ops.reshape(tensor, [b, oc, oh, ow])
return tensor
def compareToTranspose(self, data_format, batch_size, in_height, in_width,
out_channels, block_size, use_gpu):
if use_gpu and not test.is_gpu_available():
print("gpu not available")
return
dtype = dtypes.float32
def compareToTranspose(self, batch_size, in_height, in_width, out_channels,
block_size, data_format, use_gpu):
in_channels = out_channels * block_size * block_size
nhwc_input_shape = [batch_size, in_height, in_width, in_channels]
nchw_input_shape = [batch_size, in_channels, in_height, in_width]
total_size = np.prod(nhwc_input_shape)
if data_format == "NHWC":
input_shape = [batch_size, in_height, in_width, in_channels]
elif data_format == "NCHW":
input_shape = [batch_size, in_channels, in_height, in_width]
if data_format == "NCHW_VECT_C":
# Initialize the input tensor with qint8 values that circle -127..127.
x = [((f + 128) % 255) - 127 for f in range(total_size)]
t = constant_op.constant(x, shape=nhwc_input_shape, dtype=dtypes.float32)
expected = self.depthToSpaceUsingTranspose(t, block_size, "NHWC")
t = test_util.NHWCToNCHW_VECT_C(t)
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
t = array_ops.depth_to_space(t, block_size, data_format="NCHW_VECT_C")
t = gen_array_ops.dequantize(t, -128, 127)
actual = test_util.NCHW_VECT_CToNHWC(t)
else:
assert False, "unsupported format"
# Initialize the input tensor with ascending whole numbers.
total_size = 1
for dim_size in input_shape:
total_size *= dim_size
x = [f for f in range(total_size)]
inputs = constant_op.constant(x, shape=input_shape, dtype=dtype)
expected = self.depthToSpaceUsingTranspose(inputs, block_size, data_format)
actual = array_ops.depth_to_space(
inputs, block_size, data_format=data_format)
# Initialize the input tensor with ascending whole numbers as floats.
x = [f * 1.0 for f in range(total_size)]
shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
expected = self.depthToSpaceUsingTranspose(t, block_size, data_format)
actual = array_ops.depth_to_space(t, block_size, data_format=data_format)
with self.test_session(use_gpu=use_gpu) as sess:
actual_vals, expected_vals = sess.run([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
def testAgainstTranspose(self):
self.compareToTranspose("NHWC", 3, 2, 3, 1, 2, False)
self.compareToTranspose("NHWC", 3, 2, 3, 2, 2, False)
self.compareToTranspose("NHWC", 3, 2, 3, 1, 2, True)
self.compareToTranspose("NHWC", 3, 2, 3, 2, 2, True)
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", False)
self.compareToTranspose("NCHW", 3, 2, 3, 1, 2, True)
self.compareToTranspose("NCHW", 3, 2, 3, 2, 2, True)
self.compareToTranspose("NCHW", 3, 2, 3, 1, 3, True)
self.compareToTranspose("NCHW", 3, 2, 3, 2, 3, True)
self.compareToTranspose("NCHW", 5, 7, 11, 3, 2, True)
self.compareToTranspose("NCHW", 3, 200, 300, 32, 2, True)
if not test.is_gpu_available():
tf_logging.info("skipping gpu tests since gpu not available")
return
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", True)
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", True)
self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", True)
self.compareToTranspose(3, 2, 3, 2, 2, "NCHW", True)
self.compareToTranspose(3, 2, 3, 1, 3, "NCHW", True)
self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", True)
self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", True)
self.compareToTranspose(3, 200, 300, 32, 2, "NCHW", True)
self.compareToTranspose(3, 2, 3, 8, 2, "NCHW_VECT_C", True)
self.compareToTranspose(3, 2, 3, 4, 3, "NCHW_VECT_C", True)
self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", True)
self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", True)
self.compareToTranspose(3, 200, 300, 32, 2, "NCHW_VECT_C", True)
class DepthToSpaceGradientTest(test.TestCase):