From c14550a38308a7f516e83f5c8e21748ad76bf972 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 8 Sep 2017 11:23:17 -0700 Subject: [PATCH] Add an NCHW_VECT_C kernel to MaxPoolOp and MaxPoolOpV2 PiperOrigin-RevId: 168021874 --- tensorflow/core/kernels/maxpooling_op.cc | 57 ++++++++- tensorflow/core/kernels/maxpooling_op.h | 10 ++ .../core/kernels/maxpooling_op_gpu.cu.cc | 56 +++++++++ tensorflow/core/kernels/maxpooling_op_gpu.h | 9 ++ tensorflow/core/kernels/pooling_ops_common.cc | 15 ++- tensorflow/core/kernels/pooling_ops_common.h | 57 +++++++-- tensorflow/core/ops/nn_ops.cc | 12 +- tensorflow/python/framework/test_util.py | 69 ++++++++++- .../python/kernel_tests/pooling_ops_test.py | 116 ++++++++++++------ tensorflow/python/ops/nn_ops.py | 9 +- 10 files changed, 343 insertions(+), 67 deletions(-) diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 8d825c13d76..60ed1263a23 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -920,6 +920,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { public: explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context) : OpKernel(context) { + string data_format_str; + auto status = context->GetAttr("data_format", &data_format_str); + if (status.ok()) { + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " @@ -959,6 +966,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { std::vector ksize_; std::vector stride_; Padding padding_; + TensorFormat data_format_; }; template @@ -1051,17 +1059,36 @@ class MaxPoolingNoMaskOp : public OpKernel { TensorShape out_shape = ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height, params.out_width, params.depth); - if (use_dnn_ && data_format_ == FORMAT_NCHW) { + + // Assuming qint8 <--> NCHW_VECT_C (int8x4) here. + constexpr bool is_int8x4 = std::is_same::value; + OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)), + errors::InvalidArgument( + "qint8 should be used with data_format NCHW_VECT_C.")); + + // These is_int8x4 checks avoid linker errors for missing qint8 kernels. + if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) { DnnPoolingOp::Compute( context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_, stride_, padding_, data_format_, tensor_in, out_shape); } else { - CHECK(data_format_ == FORMAT_NHWC) - << "Non-Cudnn MaxPool only supports NHWC format"; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - LaunchMaxPoolingNoMask::launch(context, params, tensor_in, - output); + if (is_int8x4) { + LaunchMaxPoolingNoMask_NCHW_VECT_C::launch(context, params, + tensor_in, output); + } else if (data_format_ == FORMAT_NHWC) { + LaunchMaxPoolingNoMask::launch(context, params, tensor_in, + output); + } else { + LOG(FATAL) << "MaxPool currently only supports the following (layout, " + "type) combinations: (NHWC, non-qint8), " + "(NCHW, non-qint8) or (NCHW_VECT_C, qint8). 
The " + "requested combination (" + << ToString(data_format_) << ", " + << DataTypeString(DataTypeToEnum::v()) + << ") is not supported."; + } } } @@ -1346,6 +1373,26 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS); .TypeConstraint("Targmax"), \ MaxPoolingGradGradWithArgmaxOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS); + +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_GPU).TypeConstraint("T"), + MaxPoolingNoMaskOp); + +REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") + .Device(DEVICE_GPU) + .HostMemory("ksize") + .HostMemory("strides") + .TypeConstraint("T"), + MaxPoolingV2Op); + +REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") + .Device(DEVICE_GPU) + .HostMemory("ksize") + .HostMemory("strides") + .TypeConstraint("T") + .Label("eigen_tensor"), + MaxPoolingV2Op); + #undef REGISTER_GPU_ONLY_POOL_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h index 1670c1b26d8..f82e57d44c2 100644 --- a/tensorflow/core/kernels/maxpooling_op.h +++ b/tensorflow/core/kernels/maxpooling_op.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_ // Functor definition for MaxPoolingOp, must be compilable by nvcc. +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/eigen_pooling.h" #include "tensorflow/core/platform/types.h" @@ -37,6 +39,14 @@ struct SpatialMaxPooling { } }; +template +struct SpatialMaxPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) {} +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index e3a57d2f28a..26f52748045 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/maxpooling_op.h" #include "tensorflow/core/kernels/maxpooling_op_gpu.h" #include "tensorflow/core/util/cuda_kernel_helper.h" @@ -89,6 +90,42 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, } } +// The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for +// MaxPoolForwardNCHW above, except that mask is not supported, and each +// element of the input and output contains 4 adjacent channel values for +// the same X, y coordinate. +// (so channels = outer_channels, output_size = real output size / 4). +__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C( + const int nthreads, const int32* bottom_data, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + int32* top_data) { + // TODO(pauldonnelly): Implement a better optimized version of this kernel. 
+ const int32 kMinINT8X4 = 0x80808080; + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int32 maxval = kMinINT8X4; + const int32* bottom_data_n = bottom_data + n * channels * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = (c * height + h) * width + w; + maxval = __vmaxs4(maxval, bottom_data_n[idx]); + } + } + top_data[index] = maxval; + } +} + template __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, const int height, const int width, @@ -328,6 +365,25 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff, namespace functor { +// Note: channels is the outer channels (dim 1) which has already been +// divided by 4. +bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()( + const int32* bottom_data, const int batch, const int height, + const int width, int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + int32* top_data, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int output_size = batch * channels * pooled_height * pooled_width; + MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<< + (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, + 0, d.stream()>>>(output_size, bottom_data, height, width, channels, + pooled_height, pooled_width, kernel_h, kernel_w, + stride_h, stride_w, pad_t, pad_l, top_data); + d.synchronize(); + return d.ok(); +} + template bool MaxPoolForwardWithOptionalArgmax::operator()( const T* bottom_data, const int batch, const int height, const int width, diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h index d2029f5719a..34203797cf0 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.h +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -42,6 +42,15 @@ struct MaxPoolForwardWithOptionalArgmax { const Eigen::GpuDevice& d); }; +struct MaxPoolForwardNoMask_NCHW_VECT_C { + bool operator()(const int32* bottom_data, const int batch, const int height, + const int width, int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_t, const int pad_l, int32* top_data, + const Eigen::GpuDevice& d); +}; + template struct MaxPoolBackwardWithArgmax { bool operator()(const int output_size, const int input_size, diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 37747a31999..7dee751c4f3 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -22,7 +22,6 @@ limitations under the License. 
#if GOOGLE_CUDA #include "tensorflow/core/kernels/conv_2d.h" -#include "tensorflow/core/kernels/maxpooling_op_gpu.h" #include "tensorflow/core/kernels/pooling_ops_common_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -34,12 +33,18 @@ PoolParameters::PoolParameters(OpKernelContext* context, const std::vector& stride, Padding padding, TensorFormat data_format, const TensorShape& tensor_in_shape) { - // For maxpooling, tensor_in should have 4 dimensions. - OP_REQUIRES(context, tensor_in_shape.dims() == 4, - errors::InvalidArgument("tensor_in must be 4-dimensional")); + // For maxpooling, tensor_in should have 2 spatial dimensions. + // Note: the total number of dimensions could be 4 for NHWC, NCHW, + // or 5 for NCHW_VECT_C. + OP_REQUIRES(context, + GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2, + errors::InvalidArgument( + "tensor_in_shape must have 2 spatial dimensions. ", + tensor_in_shape.dims(), " ", data_format)); this->data_format = data_format; - depth = GetTensorDim(tensor_in_shape, data_format, 'C'); + depth = GetTensorDim(tensor_in_shape, data_format, 'C') * + (data_format == FORMAT_NCHW_VECT_C ? 4 : 1); tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W'); tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H'); tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N'); diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 1b59c18df79..75a6fc371b4 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -29,6 +29,10 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/work_sharder.h" +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#endif // GOOGLE_CUDA + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -256,6 +260,30 @@ class MaxPoolingOp : public OpKernel { TensorFormat data_format_; }; +template +struct LaunchMaxPoolingNoMask_NCHW_VECT_C; + +#ifdef GOOGLE_CUDA +template <> +struct LaunchMaxPoolingNoMask_NCHW_VECT_C { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& input, Tensor* output) { + bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()( + reinterpret_cast(input.flat().data()), + params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, + params.depth, params.out_height, params.out_width, params.window_rows, + params.window_cols, params.row_stride, params.col_stride, + params.pad_rows, params.pad_cols, + reinterpret_cast(output->flat().data()), + context->eigen_gpu_device()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C")); + } + } +}; +#endif + template class MaxPoolingV2Op : public OpKernel { public: @@ -266,8 +294,11 @@ class MaxPoolingV2Op : public OpKernel { OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); OP_REQUIRES( - context, data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Default MaxPoolingOp only supports NHWC.")); + context, + data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C, + errors::InvalidArgument( + "MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. 
Got: ", + data_format)); } else { data_format_ = FORMAT_NHWC; } @@ -315,8 +346,8 @@ class MaxPoolingV2Op : public OpKernel { errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); - PoolParameters params{context, ksize, stride, - padding_, FORMAT_NHWC, tensor_in.shape()}; + PoolParameters params{context, ksize, stride, + padding_, data_format_, tensor_in.shape()}; if (!context->status().ok()) { return; } @@ -368,13 +399,21 @@ class MaxPoolingV2Op : public OpKernel { // Spatial MaxPooling implementation. // // TODO(vrv): Remove this once we no longer need it. +#ifdef GOOGLE_CUDA if (std::is_same::value) { Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); - functor::SpatialMaxPooling()( - context->eigen_device(), output->tensor(), - tensor_in.tensor(), params.window_rows, params.window_cols, - params.row_stride, params.col_stride, pt); - } else { + if (std::is_same::value) { + LaunchMaxPoolingNoMask_NCHW_VECT_C::launch( + context, params, tensor_in, output); + } else { + functor::SpatialMaxPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } + } else +#endif + { typedef Eigen::Map> ConstEigenMatrixMap; typedef Eigen::Map> diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 8a2d5e8c05a..fd0b785b8f7 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1344,11 +1344,13 @@ output: The gradients for LRN. // -------------------------------------------------------------------------- REGISTER_OP("MaxPool") - .Attr("T: realnumbertype = DT_FLOAT") + .Attr( + "T: {float, double, int32, int64, uint8, int16, int8, uint16, " + "half, qint8} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()) + .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Input("input: T") .Output("output: T") .SetShapeFn(shape_inference::MaxPoolShape) @@ -1369,9 +1371,11 @@ output: The max pooled output tensor. )doc"); REGISTER_OP("MaxPoolV2") - .Attr("T: realnumbertype = DT_FLOAT") + .Attr( + "T: {float, double, int32, int64, uint8, int16, int8, uint16, " + "half, qint8} = DT_FLOAT") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()) + .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Input("input: T") .Input("ksize: int32") .Input("strides: int32") diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 9cf222a63ab..c6a7d0833e5 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -208,6 +208,71 @@ def NHWCToNCHW(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] +def NHWCToNCHW_VECT_C(input_shape_or_tensor): + """Transforms the input from the NHWC layout to NCHW_VECT_C layout. + + Note: Does not include quantization or type conversion steps, which should + be applied afterwards. + + Args: + input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape + + Returns: + tensor or shape array transformed into NCHW_VECT_C + + Raises: + ValueError: if last dimension of `input_shape_or_tensor` is not evenly + divisible by 4. 
+ """ + permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} + is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + temp_shape = (input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) + if temp_shape[-1] % 4 != 0: + raise ValueError( + "Last dimension of input must be evenly divisible by 4 to convert to " + "NCHW_VECT_C.") + temp_shape[-1] //= 4 + temp_shape.append(4) + permutation = permutations[len(temp_shape)] + if is_tensor: + t = array_ops.reshape(input_shape_or_tensor, temp_shape) + return array_ops.transpose(t, permutation) + else: + return [temp_shape[a] for a in permutation] + + +def NCHW_VECT_CToNHWC(input_shape_or_tensor): + """Transforms the input from the NCHW_VECT_C layout to NHWC layout. + + Note: Does not include de-quantization or type conversion steps, which should + be applied beforehand. + + Args: + input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape + + Returns: + tensor or shape array transformed into NHWC + + Raises: + ValueError: if last dimension of `input_shape_or_tensor` is not 4. + """ + permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} + is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + input_shape = (input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) + if input_shape[-1] != 4: + raise ValueError("Last dimension of NCHW_VECT_C must be 4.") + permutation = permutations[len(input_shape)] + nhwc_shape = [input_shape[a] for a in permutation[:-1]] + nhwc_shape[-1] *= input_shape[-1] + if is_tensor: + t = array_ops.transpose(input_shape_or_tensor, permutation) + return array_ops.reshape(t, nhwc_shape) + else: + return nhwc_shape + + def NCHWToNHWC(input_tensor): """Converts the input from the NCHW format to NHWC. @@ -392,7 +457,7 @@ class TensorFlowTestCase(googletest.TestCase): self._cached_session = None def setUp(self): - logging.info("SET UP: %s" % str(self)) + logging.info("SET UP: %s", str(self)) self._ClearCachedSession() random.seed(random_seed.DEFAULT_GRAPH_SEED) np.random.seed(random_seed.DEFAULT_GRAPH_SEED) @@ -407,7 +472,7 @@ class TensorFlowTestCase(googletest.TestCase): ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED def tearDown(self): - logging.info("TEAR DOWN: %s" % str(self)) + logging.info("TEAR DOWN: %s", str(self)) for thread in self._threads: self.assertFalse(thread.is_alive(), "A checkedThread did not terminate") diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 9eb1fea8037..c699d50c02d 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -25,25 +25,40 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import nn_ops -from tensorflow.python.framework import ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging -def GetTestConfigs(): +def GetTestConfigs(include_nchw_vect_c=False): """Get all the valid tests configs to run. + Args: + include_nchw_vect_c: Whether to include NCHW_VECT_C in the test configs. 
+ Returns: all the valid test configs as tuples of data_format and use_gpu. """ test_configs = [("NHWC", False), ("NHWC", True)] - if test.is_gpu_available(cuda_only=True): - # "NCHW" format is currently supported exclusively on CUDA GPUs. - test_configs += [("NCHW", True)] + if not test.is_gpu_available(cuda_only=True): + tf_logging.info("NCHW and NCHW_VECT_C tests skipped because not run with " + "--config=cuda or no GPUs available.") + return test_configs + # "NCHW" format is currently supported exclusively on CUDA GPUs. + test_configs += [("NCHW", True)] + if include_nchw_vect_c: + if test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=(6, 1)): + test_configs += [("NCHW_VECT_C", True)] + else: + tf_logging.info("NCHW_VECT_C test skipped because no GPUs with " + "compute capability >= 6.1 are available.") + return test_configs @@ -95,16 +110,32 @@ class PoolingTest(test.TestCase): total_size = 1 for s in input_sizes: total_size *= s + if v2 and data_format != "NHWC": + tf_logging.info("v2 not supported for %s", data_format) + return + if data_format == "NCHW_VECT_C": + if data_type != dtypes.float32: + tf_logging.info("quantization to qint8 not implemented for %r", + data_type) + return + if input_sizes[-1] % 4 != 0: + tf_logging.info("Skipping test for depth %d", input_sizes[-1]) + return + tf_logging.info("Running %s test. %r %r %d %r %r %r", data_format, v2, + input_sizes, total_size, pool_func, ksize, strides) # Initializes the input tensor with array containing incrementing - # numbers from 1. - x = [f * 1.0 for f in range(1, total_size + 1)] + # numbers from 1, wrapping round to -127 after 127 to support int8. + x = [((f + 128) % 255) - 127 for f in range(total_size)] with self.test_session(use_gpu=use_gpu): t = constant_op.constant(x, shape=input_sizes, dtype=data_type) - if data_format == "NCHW": - t = test_util.NHWCToNCHW(t) + if data_format in ("NCHW", "NCHW_VECT_C"): + if data_format == "NCHW_VECT_C": + t = test_util.NHWCToNCHW_VECT_C(t) + t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8) + else: + t = test_util.NHWCToNCHW(t) ksize = test_util.NHWCToNCHW(ksize) strides = test_util.NHWCToNCHW(strides) - v2 = v2 and data_format != "NCHW" ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) if v2: @@ -121,7 +152,10 @@ class PoolingTest(test.TestCase): strides=strides, padding=padding, data_format=data_format) - if data_format == "NCHW": + if data_format == "NCHW_VECT_C": + t = gen_array_ops.dequantize(t, -128, 127) + t = test_util.NCHW_VECT_CToNHWC(t) + elif data_format == "NCHW": t = test_util.NCHWToNHWC(t) if v2: actual = t.eval(feed_dict={ksize_placeholder: ksize, @@ -146,6 +180,13 @@ class PoolingTest(test.TestCase): expected: An array containing the expected operation outputs. use_gpu: Whether we are running on GPU. """ + if data_format == "NCHW_VECT_C": + avg_pool_func = nn_ops.avg_pool + tf_logging.info("pool_func=%s", pool_func) + if pool_func == avg_pool_func: + tf_logging.info("NCHW_VECT_C not yet implemented for avg_pool") + return + self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float32, expected, use_gpu, v2) @@ -167,7 +208,7 @@ class PoolingTest(test.TestCase): expected: An array containing the expected operation outputs. use_gpu: Whether we are running on GPU. 
""" - for (data_format, use_gpu_2) in GetTestConfigs(): + for (data_format, use_gpu_2) in GetTestConfigs(True): if use_gpu_2 == use_gpu: self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding, data_format, expected, use_gpu, v2) @@ -296,20 +337,20 @@ class PoolingTest(test.TestCase): def _testAvgPoolSamePaddingPacket8(self, use_gpu): expected_output = [ - 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 89.0, 90.0, 91.0, 92.0, - 93.0, 94.0, 95.0, 96.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, - 112.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 201.0, - 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 217.0, 218.0, 219.0, - 220.0, 221.0, 222.0, 223.0, 224.0, 233.0, 234.0, 235.0, 236.0, 237.0, - 238.0, 239.0, 240.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, - 252.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 345.0, - 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 352.0, 361.0, 362.0, 363.0, - 364.0, 365.0, 366.0, 367.0, 368.0, 373.0, 374.0, 375.0, 376.0, 377.0, - 378.0, 379.0, 380.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, - 432.0, 441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 457.0, - 458.0, 459.0, 460.0, 461.0, 462.0, 463.0, 464.0, 469.0, 470.0, 471.0, - 472.0, 473.0, 474.0, 475.0, 476.0 + -12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, + 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, -3.5, -54.0, -53.0, -52.0, + -51.0, -50.0, -49.0, -48.0, -47.0, -38.0, -37.0, -36.0, -35.0, -34.0, + -33.0, -32.0, -31.0, -22.0, -21.0, -20.0, -19.0, -18.0, -17.0, -16.0, + -15.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -11.0, -10.0, + -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 33.0, 34.0, 35.0, + 36.0, 37.0, 38.0, -3.5, -2.5, -85.0, -84.0, -83.0, -82.0, -81.0, -80.0, + -79.0, -78.0, -69.0, -68.0, -67.0, -66.0, -65.0, -64.0, -63.0, -62.0, + -53.0, -52.0, -51.0, -50.0, -49.0, -48.0, -47.0, -46.0, -41.0, -40.0, + -39.0, -38.0, -37.0, -36.0, -35.0, -34.0 ] + self._VerifyValues( nn_ops.avg_pool, input_sizes=[1, 8, 8, 8], @@ -468,19 +509,18 @@ class PoolingTest(test.TestCase): def _testMaxPoolSamePaddingPacket8(self, use_gpu): expected_output = [ - 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0, - 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0, - 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, - 191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, - 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0, - 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0, - 317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, - 407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, - 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, - 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0, - 469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, - 487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, - 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0 + 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 97.0, 98.0, 99.0, 100.0, + 101.0, 102.0, 103.0, 104.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, + 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 120.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 34.0, 35.0, 36.0, 37.0, + 38.0, 39.0, 40.0, 41.0, 
50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, + 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 82.0, 83.0, 84.0, 85.0, + 86.0, 87.0, 88.0, 89.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, + 105.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, + 123.0, 124.0, 125.0, 126.0, 127.0, 120.0, 121.0, -45.0, -44.0, -43.0, + -42.0, -41.0, -40.0, -39.0, -38.0, -29.0, -28.0, -27.0, -26.0, -25.0, + -24.0, -23.0, -22.0, -13.0, -12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, + -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0 ] self._VerifyValues( nn_ops.max_pool, diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index d4b16635071..a2e75dd7f27 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.gen_nn_ops import * # pylint: enable=wildcard-import + # Aliases for some automatically-generated names. local_response_normalization = gen_nn_ops.lrn @@ -1750,19 +1751,19 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """Performs the max pooling on the input. Args: - value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and - type `tf.float32`. + value: A 4-D `Tensor` of the format specified by `data_format`. ksize: A 1-D int Tensor of 4 elements. The size of the window for each dimension of the input tensor. strides: A 1-D int Tensor of 4 elements. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See the @{tf.nn.convolution$comment here} - data_format: A string. 'NHWC' and 'NCHW' are supported. + data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported. name: Optional name for the operation. Returns: - A `Tensor` with type `tf.float32`. The max pooled output tensor. + A `Tensor` of format specified by `data_format`. + The max pooled output tensor. """ with ops.name_scope(name, "MaxPool", [value]) as name: value = ops.convert_to_tensor(value, name="input")
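
For reference, a minimal end-to-end sketch of the new NCHW_VECT_C path, following the quantize/transpose pattern used in the updated pooling_ops_test.py. The shapes, the [-128, 127] quantization range, and the incrementing test values are illustrative only, and the qint8 kernel requires a CUDA GPU with compute capability >= 6.1 (the same gate the test config uses).

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import test_util
    from tensorflow.python.ops import gen_array_ops
    from tensorflow.python.ops import nn_ops

    # Input in NHWC with a channel count divisible by 4 (here 1x8x8x8).
    x = constant_op.constant(
        [((i + 128) % 255) - 127 for i in range(1 * 8 * 8 * 8)],
        shape=[1, 8, 8, 8], dtype=dtypes.float32)

    # Reshape/transpose NHWC -> NCHW_VECT_C, then quantize to qint8.
    t = test_util.NHWCToNCHW_VECT_C(x)
    t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)

    # ksize/strides are passed in NCHW order for both NCHW and NCHW_VECT_C,
    # as in the test's NHWCToNCHW conversion of the window parameters.
    out = nn_ops.max_pool(t, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                          padding="SAME", data_format="NCHW_VECT_C")

    # Dequantize and transpose back to NHWC float for comparison.
    out = gen_array_ops.dequantize(out, -128, 127)
    out = test_util.NCHW_VECT_CToNHWC(out)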