Add an NCHW_VECT_C kernel to MaxPoolOp and MaxPoolOpV2
PiperOrigin-RevId: 168021874
This commit is contained in:
parent
27542d1e5c
commit
c14550a383
@ -920,6 +920,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
||||
public:
|
||||
explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
string data_format_str;
|
||||
auto status = context->GetAttr("data_format", &data_format_str);
|
||||
if (status.ok()) {
|
||||
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
|
||||
OP_REQUIRES(context, ksize_.size() == 4,
|
||||
errors::InvalidArgument("Sliding window ksize field must "
|
||||
@ -959,6 +966,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
||||
std::vector<int32> ksize_;
|
||||
std::vector<int32> stride_;
|
||||
Padding padding_;
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -1051,17 +1059,36 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
||||
TensorShape out_shape =
|
||||
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
|
||||
params.out_width, params.depth);
|
||||
if (use_dnn_ && data_format_ == FORMAT_NCHW) {
|
||||
|
||||
// Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
|
||||
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
||||
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
|
||||
errors::InvalidArgument(
|
||||
"qint8 should be used with data_format NCHW_VECT_C."));
|
||||
|
||||
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
|
||||
if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) {
|
||||
DnnPoolingOp<T>::Compute(
|
||||
context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
|
||||
stride_, padding_, data_format_, tensor_in, out_shape);
|
||||
} else {
|
||||
CHECK(data_format_ == FORMAT_NHWC)
|
||||
<< "Non-Cudnn MaxPool only supports NHWC format";
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
|
||||
output);
|
||||
if (is_int8x4) {
|
||||
LaunchMaxPoolingNoMask_NCHW_VECT_C<Device>::launch(context, params,
|
||||
tensor_in, output);
|
||||
} else if (data_format_ == FORMAT_NHWC) {
|
||||
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
|
||||
output);
|
||||
} else {
|
||||
LOG(FATAL) << "MaxPool currently only supports the following (layout, "
|
||||
"type) combinations: (NHWC, non-qint8), "
|
||||
"(NCHW, non-qint8) or (NCHW_VECT_C, qint8). The "
|
||||
"requested combination ("
|
||||
<< ToString(data_format_) << ", "
|
||||
<< DataTypeString(DataTypeToEnum<T>::v())
|
||||
<< ") is not supported.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1346,6 +1373,26 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
|
||||
.TypeConstraint<int64>("Targmax"), \
|
||||
MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
|
||||
MaxPoolingNoMaskOp<GPUDevice, qint8>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("ksize")
|
||||
.HostMemory("strides")
|
||||
.TypeConstraint<qint8>("T"),
|
||||
MaxPoolingV2Op<GPUDevice, qint8>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("ksize")
|
||||
.HostMemory("strides")
|
||||
.TypeConstraint<qint8>("T")
|
||||
.Label("eigen_tensor"),
|
||||
MaxPoolingV2Op<GPUDevice, qint8>);
|
||||
|
||||
#undef REGISTER_GPU_ONLY_POOL_KERNELS
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -17,7 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
|
||||
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
|
||||
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/kernels/eigen_pooling.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -37,6 +39,14 @@ struct SpatialMaxPooling {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct SpatialMaxPooling<Device, qint8> {
|
||||
void operator()(const Device& d, typename TTypes<qint8, 4>::Tensor output,
|
||||
typename TTypes<qint8, 4>::ConstTensor input, int window_rows,
|
||||
int window_cols, int row_stride, int col_stride,
|
||||
const Eigen::PaddingType& padding) {}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/type_traits.h"
|
||||
#include "tensorflow/core/kernels/maxpooling_op.h"
|
||||
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
@ -89,6 +90,42 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
|
||||
}
|
||||
}
|
||||
|
||||
// The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for
|
||||
// MaxPoolForwardNCHW above, except that mask is not supported, and each
|
||||
// element of the input and output contains 4 adjacent channel values for
|
||||
// the same X, y coordinate.
|
||||
// (so channels = outer_channels, output_size = real output size / 4).
|
||||
__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
|
||||
const int nthreads, const int32* bottom_data, const int height,
|
||||
const int width, const int channels, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
|
||||
int32* top_data) {
|
||||
// TODO(pauldonnelly): Implement a better optimized version of this kernel.
|
||||
const int32 kMinINT8X4 = 0x80808080;
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
int hstart = ph * stride_h - pad_t;
|
||||
int wstart = pw * stride_w - pad_l;
|
||||
int hend = min(hstart + kernel_h, height);
|
||||
int wend = min(wstart + kernel_w, width);
|
||||
hstart = max(hstart, 0);
|
||||
wstart = max(wstart, 0);
|
||||
int32 maxval = kMinINT8X4;
|
||||
const int32* bottom_data_n = bottom_data + n * channels * height * width;
|
||||
for (int h = hstart; h < hend; ++h) {
|
||||
for (int w = wstart; w < wend; ++w) {
|
||||
int idx = (c * height + h) * width + w;
|
||||
maxval = __vmaxs4(maxval, bottom_data_n[idx]);
|
||||
}
|
||||
}
|
||||
top_data[index] = maxval;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename dtype>
|
||||
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
|
||||
const int height, const int width,
|
||||
@ -328,6 +365,25 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Note: channels is the outer channels (dim 1) which has already been
|
||||
// divided by 4.
|
||||
bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()(
|
||||
const int32* bottom_data, const int batch, const int height,
|
||||
const int width, int channels, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
|
||||
int32* top_data, const Eigen::GpuDevice& d) {
|
||||
const int kThreadsPerBlock = 1024;
|
||||
const int output_size = batch * channels * pooled_height * pooled_width;
|
||||
MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<<
|
||||
(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock,
|
||||
0, d.stream()>>>(output_size, bottom_data, height, width, channels,
|
||||
pooled_height, pooled_width, kernel_h, kernel_w,
|
||||
stride_h, stride_w, pad_t, pad_l, top_data);
|
||||
d.synchronize();
|
||||
return d.ok();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
|
||||
const T* bottom_data, const int batch, const int height, const int width,
|
||||
|
@ -42,6 +42,15 @@ struct MaxPoolForwardWithOptionalArgmax {
|
||||
const Eigen::GpuDevice& d);
|
||||
};
|
||||
|
||||
struct MaxPoolForwardNoMask_NCHW_VECT_C {
|
||||
bool operator()(const int32* bottom_data, const int batch, const int height,
|
||||
const int width, int channels, const int pooled_height,
|
||||
const int pooled_width, const int kernel_h,
|
||||
const int kernel_w, const int stride_h, const int stride_w,
|
||||
const int pad_t, const int pad_l, int32* top_data,
|
||||
const Eigen::GpuDevice& d);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MaxPoolBackwardWithArgmax {
|
||||
bool operator()(const int output_size, const int input_size,
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_2d.h"
|
||||
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
||||
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
@ -34,12 +33,18 @@ PoolParameters::PoolParameters(OpKernelContext* context,
|
||||
const std::vector<int32>& stride,
|
||||
Padding padding, TensorFormat data_format,
|
||||
const TensorShape& tensor_in_shape) {
|
||||
// For maxpooling, tensor_in should have 4 dimensions.
|
||||
OP_REQUIRES(context, tensor_in_shape.dims() == 4,
|
||||
errors::InvalidArgument("tensor_in must be 4-dimensional"));
|
||||
// For maxpooling, tensor_in should have 2 spatial dimensions.
|
||||
// Note: the total number of dimensions could be 4 for NHWC, NCHW,
|
||||
// or 5 for NCHW_VECT_C.
|
||||
OP_REQUIRES(context,
|
||||
GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2,
|
||||
errors::InvalidArgument(
|
||||
"tensor_in_shape must have 2 spatial dimensions. ",
|
||||
tensor_in_shape.dims(), " ", data_format));
|
||||
|
||||
this->data_format = data_format;
|
||||
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
|
||||
depth = GetTensorDim(tensor_in_shape, data_format, 'C') *
|
||||
(data_format == FORMAT_NCHW_VECT_C ? 4 : 1);
|
||||
tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
|
||||
tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
|
||||
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
|
||||
|
@ -29,6 +29,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
@ -256,6 +260,30 @@ class MaxPoolingOp : public OpKernel {
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct LaunchMaxPoolingNoMask_NCHW_VECT_C;
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
template <>
|
||||
struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
|
||||
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||
const Tensor& input, Tensor* output) {
|
||||
bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()(
|
||||
reinterpret_cast<const int32*>(input.flat<qint8>().data()),
|
||||
params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
|
||||
params.depth, params.out_height, params.out_width, params.window_rows,
|
||||
params.window_cols, params.row_stride, params.col_stride,
|
||||
params.pad_rows, params.pad_cols,
|
||||
reinterpret_cast<int32*>(output->flat<qint8>().data()),
|
||||
context->eigen_gpu_device());
|
||||
if (!status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Device, typename T>
|
||||
class MaxPoolingV2Op : public OpKernel {
|
||||
public:
|
||||
@ -266,8 +294,11 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
OP_REQUIRES(
|
||||
context, data_format_ == FORMAT_NHWC,
|
||||
errors::InvalidArgument("Default MaxPoolingOp only supports NHWC."));
|
||||
context,
|
||||
data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C,
|
||||
errors::InvalidArgument(
|
||||
"MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. Got: ",
|
||||
data_format));
|
||||
} else {
|
||||
data_format_ = FORMAT_NHWC;
|
||||
}
|
||||
@ -315,8 +346,8 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
errors::Unimplemented(
|
||||
"Pooling is not yet supported on the batch dimension."));
|
||||
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
||||
PoolParameters params{context, ksize, stride,
|
||||
padding_, data_format_, tensor_in.shape()};
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -368,13 +399,21 @@ class MaxPoolingV2Op : public OpKernel {
|
||||
// Spatial MaxPooling implementation.
|
||||
//
|
||||
// TODO(vrv): Remove this once we no longer need it.
|
||||
#ifdef GOOGLE_CUDA
|
||||
if (std::is_same<Device, GPUDevice>::value) {
|
||||
Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
|
||||
functor::SpatialMaxPooling<Device, T>()(
|
||||
context->eigen_device<Device>(), output->tensor<T, 4>(),
|
||||
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
|
||||
params.row_stride, params.col_stride, pt);
|
||||
} else {
|
||||
if (std::is_same<T, qint8>::value) {
|
||||
LaunchMaxPoolingNoMask_NCHW_VECT_C<GPUDevice>::launch(
|
||||
context, params, tensor_in, output);
|
||||
} else {
|
||||
functor::SpatialMaxPooling<Device, T>()(
|
||||
context->eigen_device<Device>(), output->tensor<T, 4>(),
|
||||
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
|
||||
params.row_stride, params.col_stride, pt);
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
ConstEigenMatrixMap;
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
|
@ -1344,11 +1344,13 @@ output: The gradients for LRN.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("MaxPool")
|
||||
.Attr("T: realnumbertype = DT_FLOAT")
|
||||
.Attr(
|
||||
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
|
||||
"half, qint8} = DT_FLOAT")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::MaxPoolShape)
|
||||
@ -1369,9 +1371,11 @@ output: The max pooled output tensor.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MaxPoolV2")
|
||||
.Attr("T: realnumbertype = DT_FLOAT")
|
||||
.Attr(
|
||||
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
|
||||
"half, qint8} = DT_FLOAT")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
.Input("input: T")
|
||||
.Input("ksize: int32")
|
||||
.Input("strides: int32")
|
||||
|
@ -208,6 +208,71 @@ def NHWCToNCHW(input_tensor):
|
||||
return [input_tensor[a] for a in new_axes[ndims]]
|
||||
|
||||
|
||||
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
|
||||
"""Transforms the input from the NHWC layout to NCHW_VECT_C layout.
|
||||
|
||||
Note: Does not include quantization or type conversion steps, which should
|
||||
be applied afterwards.
|
||||
|
||||
Args:
|
||||
input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape
|
||||
|
||||
Returns:
|
||||
tensor or shape array transformed into NCHW_VECT_C
|
||||
|
||||
Raises:
|
||||
ValueError: if last dimension of `input_shape_or_tensor` is not evenly
|
||||
divisible by 4.
|
||||
"""
|
||||
permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
|
||||
is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
|
||||
temp_shape = (input_shape_or_tensor.shape.as_list()
|
||||
if is_tensor else input_shape_or_tensor)
|
||||
if temp_shape[-1] % 4 != 0:
|
||||
raise ValueError(
|
||||
"Last dimension of input must be evenly divisible by 4 to convert to "
|
||||
"NCHW_VECT_C.")
|
||||
temp_shape[-1] //= 4
|
||||
temp_shape.append(4)
|
||||
permutation = permutations[len(temp_shape)]
|
||||
if is_tensor:
|
||||
t = array_ops.reshape(input_shape_or_tensor, temp_shape)
|
||||
return array_ops.transpose(t, permutation)
|
||||
else:
|
||||
return [temp_shape[a] for a in permutation]
|
||||
|
||||
|
||||
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
|
||||
"""Transforms the input from the NCHW_VECT_C layout to NHWC layout.
|
||||
|
||||
Note: Does not include de-quantization or type conversion steps, which should
|
||||
be applied beforehand.
|
||||
|
||||
Args:
|
||||
input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape
|
||||
|
||||
Returns:
|
||||
tensor or shape array transformed into NHWC
|
||||
|
||||
Raises:
|
||||
ValueError: if last dimension of `input_shape_or_tensor` is not 4.
|
||||
"""
|
||||
permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
|
||||
is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
|
||||
input_shape = (input_shape_or_tensor.shape.as_list()
|
||||
if is_tensor else input_shape_or_tensor)
|
||||
if input_shape[-1] != 4:
|
||||
raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
|
||||
permutation = permutations[len(input_shape)]
|
||||
nhwc_shape = [input_shape[a] for a in permutation[:-1]]
|
||||
nhwc_shape[-1] *= input_shape[-1]
|
||||
if is_tensor:
|
||||
t = array_ops.transpose(input_shape_or_tensor, permutation)
|
||||
return array_ops.reshape(t, nhwc_shape)
|
||||
else:
|
||||
return nhwc_shape
|
||||
|
||||
|
||||
def NCHWToNHWC(input_tensor):
|
||||
"""Converts the input from the NCHW format to NHWC.
|
||||
|
||||
@ -392,7 +457,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
self._cached_session = None
|
||||
|
||||
def setUp(self):
|
||||
logging.info("SET UP: %s" % str(self))
|
||||
logging.info("SET UP: %s", str(self))
|
||||
self._ClearCachedSession()
|
||||
random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
@ -407,7 +472,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
|
||||
def tearDown(self):
|
||||
logging.info("TEAR DOWN: %s" % str(self))
|
||||
logging.info("TEAR DOWN: %s", str(self))
|
||||
for thread in self._threads:
|
||||
self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
|
||||
|
||||
|
@ -25,25 +25,40 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.framework import ops
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
|
||||
def GetTestConfigs():
|
||||
def GetTestConfigs(include_nchw_vect_c=False):
|
||||
"""Get all the valid tests configs to run.
|
||||
|
||||
Args:
|
||||
include_nchw_vect_c: Whether to include NCHW_VECT_C in the test configs.
|
||||
|
||||
Returns:
|
||||
all the valid test configs as tuples of data_format and use_gpu.
|
||||
"""
|
||||
test_configs = [("NHWC", False), ("NHWC", True)]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
# "NCHW" format is currently supported exclusively on CUDA GPUs.
|
||||
test_configs += [("NCHW", True)]
|
||||
if not test.is_gpu_available(cuda_only=True):
|
||||
tf_logging.info("NCHW and NCHW_VECT_C tests skipped because not run with "
|
||||
"--config=cuda or no GPUs available.")
|
||||
return test_configs
|
||||
# "NCHW" format is currently supported exclusively on CUDA GPUs.
|
||||
test_configs += [("NCHW", True)]
|
||||
if include_nchw_vect_c:
|
||||
if test.is_gpu_available(
|
||||
cuda_only=True, min_cuda_compute_capability=(6, 1)):
|
||||
test_configs += [("NCHW_VECT_C", True)]
|
||||
else:
|
||||
tf_logging.info("NCHW_VECT_C test skipped because no GPUs with "
|
||||
"compute capability >= 6.1 are available.")
|
||||
|
||||
return test_configs
|
||||
|
||||
|
||||
@ -95,16 +110,32 @@ class PoolingTest(test.TestCase):
|
||||
total_size = 1
|
||||
for s in input_sizes:
|
||||
total_size *= s
|
||||
if v2 and data_format != "NHWC":
|
||||
tf_logging.info("v2 not supported for %s", data_format)
|
||||
return
|
||||
if data_format == "NCHW_VECT_C":
|
||||
if data_type != dtypes.float32:
|
||||
tf_logging.info("quantization to qint8 not implemented for %r",
|
||||
data_type)
|
||||
return
|
||||
if input_sizes[-1] % 4 != 0:
|
||||
tf_logging.info("Skipping test for depth %d", input_sizes[-1])
|
||||
return
|
||||
tf_logging.info("Running %s test. %r %r %d %r %r %r", data_format, v2,
|
||||
input_sizes, total_size, pool_func, ksize, strides)
|
||||
# Initializes the input tensor with array containing incrementing
|
||||
# numbers from 1.
|
||||
x = [f * 1.0 for f in range(1, total_size + 1)]
|
||||
# numbers from 1, wrapping round to -127 after 127 to support int8.
|
||||
x = [((f + 128) % 255) - 127 for f in range(total_size)]
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
t = constant_op.constant(x, shape=input_sizes, dtype=data_type)
|
||||
if data_format == "NCHW":
|
||||
t = test_util.NHWCToNCHW(t)
|
||||
if data_format in ("NCHW", "NCHW_VECT_C"):
|
||||
if data_format == "NCHW_VECT_C":
|
||||
t = test_util.NHWCToNCHW_VECT_C(t)
|
||||
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
|
||||
else:
|
||||
t = test_util.NHWCToNCHW(t)
|
||||
ksize = test_util.NHWCToNCHW(ksize)
|
||||
strides = test_util.NHWCToNCHW(strides)
|
||||
v2 = v2 and data_format != "NCHW"
|
||||
ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
|
||||
strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4])
|
||||
if v2:
|
||||
@ -121,7 +152,10 @@ class PoolingTest(test.TestCase):
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format)
|
||||
if data_format == "NCHW":
|
||||
if data_format == "NCHW_VECT_C":
|
||||
t = gen_array_ops.dequantize(t, -128, 127)
|
||||
t = test_util.NCHW_VECT_CToNHWC(t)
|
||||
elif data_format == "NCHW":
|
||||
t = test_util.NCHWToNHWC(t)
|
||||
if v2:
|
||||
actual = t.eval(feed_dict={ksize_placeholder: ksize,
|
||||
@ -146,6 +180,13 @@ class PoolingTest(test.TestCase):
|
||||
expected: An array containing the expected operation outputs.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
if data_format == "NCHW_VECT_C":
|
||||
avg_pool_func = nn_ops.avg_pool
|
||||
tf_logging.info("pool_func=%s", pool_func)
|
||||
if pool_func == avg_pool_func:
|
||||
tf_logging.info("NCHW_VECT_C not yet implemented for avg_pool")
|
||||
return
|
||||
|
||||
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, dtypes.float32, expected, use_gpu, v2)
|
||||
|
||||
@ -167,7 +208,7 @@ class PoolingTest(test.TestCase):
|
||||
expected: An array containing the expected operation outputs.
|
||||
use_gpu: Whether we are running on GPU.
|
||||
"""
|
||||
for (data_format, use_gpu_2) in GetTestConfigs():
|
||||
for (data_format, use_gpu_2) in GetTestConfigs(True):
|
||||
if use_gpu_2 == use_gpu:
|
||||
self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
|
||||
data_format, expected, use_gpu, v2)
|
||||
@ -296,20 +337,20 @@ class PoolingTest(test.TestCase):
|
||||
|
||||
def _testAvgPoolSamePaddingPacket8(self, use_gpu):
|
||||
expected_output = [
|
||||
73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 89.0, 90.0, 91.0, 92.0,
|
||||
93.0, 94.0, 95.0, 96.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0,
|
||||
112.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 201.0,
|
||||
202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 217.0, 218.0, 219.0,
|
||||
220.0, 221.0, 222.0, 223.0, 224.0, 233.0, 234.0, 235.0, 236.0, 237.0,
|
||||
238.0, 239.0, 240.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0,
|
||||
252.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 345.0,
|
||||
346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 352.0, 361.0, 362.0, 363.0,
|
||||
364.0, 365.0, 366.0, 367.0, 368.0, 373.0, 374.0, 375.0, 376.0, 377.0,
|
||||
378.0, 379.0, 380.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0,
|
||||
432.0, 441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 457.0,
|
||||
458.0, 459.0, 460.0, 461.0, 462.0, 463.0, 464.0, 469.0, 470.0, 471.0,
|
||||
472.0, 473.0, 474.0, 475.0, 476.0
|
||||
-12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0,
|
||||
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, -3.5, -54.0, -53.0, -52.0,
|
||||
-51.0, -50.0, -49.0, -48.0, -47.0, -38.0, -37.0, -36.0, -35.0, -34.0,
|
||||
-33.0, -32.0, -31.0, -22.0, -21.0, -20.0, -19.0, -18.0, -17.0, -16.0,
|
||||
-15.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -11.0, -10.0,
|
||||
-9.0, -8.0, -7.0, -6.0, -5.0, -4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
|
||||
12.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 33.0, 34.0, 35.0,
|
||||
36.0, 37.0, 38.0, -3.5, -2.5, -85.0, -84.0, -83.0, -82.0, -81.0, -80.0,
|
||||
-79.0, -78.0, -69.0, -68.0, -67.0, -66.0, -65.0, -64.0, -63.0, -62.0,
|
||||
-53.0, -52.0, -51.0, -50.0, -49.0, -48.0, -47.0, -46.0, -41.0, -40.0,
|
||||
-39.0, -38.0, -37.0, -36.0, -35.0, -34.0
|
||||
]
|
||||
|
||||
self._VerifyValues(
|
||||
nn_ops.avg_pool,
|
||||
input_sizes=[1, 8, 8, 8],
|
||||
@ -468,19 +509,18 @@ class PoolingTest(test.TestCase):
|
||||
|
||||
def _testMaxPoolSamePaddingPacket8(self, use_gpu):
|
||||
expected_output = [
|
||||
145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
|
||||
163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
|
||||
181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
|
||||
191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
|
||||
289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
|
||||
307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
|
||||
317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
|
||||
407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
|
||||
433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
|
||||
443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
|
||||
469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
|
||||
487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
|
||||
505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
|
||||
81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 97.0, 98.0, 99.0, 100.0,
|
||||
101.0, 102.0, 103.0, 104.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0,
|
||||
119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 120.0,
|
||||
18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 34.0, 35.0, 36.0, 37.0,
|
||||
38.0, 39.0, 40.0, 41.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0,
|
||||
58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 82.0, 83.0, 84.0, 85.0,
|
||||
86.0, 87.0, 88.0, 89.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0,
|
||||
105.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0,
|
||||
123.0, 124.0, 125.0, 126.0, 127.0, 120.0, 121.0, -45.0, -44.0, -43.0,
|
||||
-42.0, -41.0, -40.0, -39.0, -38.0, -29.0, -28.0, -27.0, -26.0, -25.0,
|
||||
-24.0, -23.0, -22.0, -13.0, -12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0,
|
||||
-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0
|
||||
]
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool,
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops.gen_nn_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
|
||||
# Aliases for some automatically-generated names.
|
||||
local_response_normalization = gen_nn_ops.lrn
|
||||
|
||||
@ -1750,19 +1751,19 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
|
||||
"""Performs the max pooling on the input.
|
||||
|
||||
Args:
|
||||
value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and
|
||||
type `tf.float32`.
|
||||
value: A 4-D `Tensor` of the format specified by `data_format`.
|
||||
ksize: A 1-D int Tensor of 4 elements. The size of the window for
|
||||
each dimension of the input tensor.
|
||||
strides: A 1-D int Tensor of 4 elements. The stride of the sliding
|
||||
window for each dimension of the input tensor.
|
||||
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
|
||||
See the @{tf.nn.convolution$comment here}
|
||||
data_format: A string. 'NHWC' and 'NCHW' are supported.
|
||||
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with type `tf.float32`. The max pooled output tensor.
|
||||
A `Tensor` of format specified by `data_format`.
|
||||
The max pooled output tensor.
|
||||
"""
|
||||
with ops.name_scope(name, "MaxPool", [value]) as name:
|
||||
value = ops.convert_to_tensor(value, name="input")
|
||||
|
Loading…
Reference in New Issue
Block a user