Remove CanUseCudnn and CudnnConvComputeMode functions.
CanUseCudnn always returned true. Any code that was only used when it returned false has been removed. CudnnConvComputeMode was never used.

PiperOrigin-RevId: 327276464
Change-Id: I26d5d700a9ac37a0cec451fb04e5169cd46a8547
parent 306f6612fc
commit 82f46cb5fa
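The change is mechanical: CanUseCudnn() had been reduced to a constant true, so expressions like use_cudnn_ &= CanUseCudnn(); never changed the flag, and the non-cuDNN fallback branches it guarded were dead code. Below is a minimal standalone C++ sketch of the before/after pattern; the ConvKernelBefore/ConvKernelAfter types are illustrative stand-ins, not the TensorFlow classes.

#include <iostream>

// The helper being removed: by this point it unconditionally returned true.
bool CanUseCudnn() { return true; }

struct ConvKernelBefore {
  bool use_cudnn_ = true;  // stands in for the "use_cudnn_on_gpu" attr value
  ConvKernelBefore() {
    use_cudnn_ &= CanUseCudnn();  // no-op: ANDing with a constant true
  }
};

struct ConvKernelAfter {
  bool use_cudnn_ = true;  // attr value used directly; the masking step is gone
};

int main() {
  // Both variants end up with the same flag, which is why the mask and the
  // fallback paths it guarded could be deleted.
  std::cout << ConvKernelBefore().use_cudnn_ << " == "
            << ConvKernelAfter().use_cudnn_ << std::endl;
  return 0;
}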
@@ -301,7 +301,6 @@ class Conv2DBackpropFilterOp : public OpKernel {
                                 /*num_dims=*/4, data_format_));
 
     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
-    use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
 
     if (std::is_same<Device, CPUDevice>::value) {
@@ -426,7 +426,6 @@ class Conv2DBackpropInputOp : public OpKernel {
                                 /*num_dims=*/4, data_format_));
 
     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
-    use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
 
     if (std::is_same<Device, CPUDevice>::value ||
@@ -512,7 +512,6 @@ class Conv2DOp : public BinaryOp<T> {
     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
 
     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
-    use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
   }
 
@@ -670,7 +670,6 @@ class FusedConv2DOp : public OpKernel {
     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
 
     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
-    use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
 
     using FCT = FusedComputationType;
@@ -579,8 +579,6 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
     OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
 
-    // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
     dtype_ = DataTypeToEnum<T>::value;
 #if CUDNN_VERSION >= 8000
@@ -638,13 +636,13 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
 
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn =
-        use_cudnn_ && (in_depth == 1 ||
-                       (use_cudnn_grouped_conv_ &&
-                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
-                                                   /*filter_cols=*/filter_cols,
-                                                   /*in_depth=*/in_depth,
-                                                   /*out_depth=*/out_depth)));
+    bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
+                     (in_depth == 1 ||
+                      (use_cudnn_grouped_conv_ &&
+                       IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                  /*filter_cols=*/filter_cols,
+                                                  /*in_depth=*/in_depth,
+                                                  /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
@@ -674,7 +672,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
                     "Failed to reshape filter tensor for grouped convolution."));
     // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
     // conv is supported.
-    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
+    launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, out_backprop,
               reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
               stride_, stride_, padding_, explicit_paddings_, in_backprop,
               data_format_);
@@ -701,7 +699,6 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
 
   // For in_depth == 1 and grouped convolutions.
   LaunchConv2DBackpropInputOp<Device, T> launcher_;
-  bool use_cudnn_;
   bool cudnn_use_autotune_;
   DataType dtype_;
 
@@ -1085,8 +1082,6 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
 
-    // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
 
     if (std::is_same<T, Eigen::half>::value) {
@@ -1138,13 +1133,13 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
 
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn =
-        use_cudnn_ && (in_depth == 1 ||
-                       (use_cudnn_grouped_conv_ &&
-                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
-                                                   /*filter_cols=*/filter_cols,
-                                                   /*in_depth=*/in_depth,
-                                                   /*out_depth=*/out_depth)));
+    bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
+                     (in_depth == 1 ||
+                      (use_cudnn_grouped_conv_ &&
+                       IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                  /*filter_cols=*/filter_cols,
+                                                  /*in_depth=*/in_depth,
+                                                  /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
@@ -1175,7 +1170,8 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
 
     // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
     // conv is supported.
-    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+    launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, out_backprop,
+              input,
               /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
               padding_, explicit_paddings_, &reshaped_filter, data_format_);
     return;
@@ -1234,7 +1230,6 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
 
   // For in_depth == 1 and grouped convolutions.
   LaunchConv2DBackpropFilterOp<Device, T> launcher_;
-  bool use_cudnn_;
   bool cudnn_use_autotune_;
   DataType dtype_;
 
@@ -298,8 +298,6 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
 
-    // For in_depth == 1 and grouped convolutions.
-    use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
     dtype_ = DataTypeToEnum<T>::value;
 #if CUDNN_VERSION >= 8000
@@ -407,13 +405,13 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     // TODO(csigg): Have autotune decide if native is faster than cuDNN.
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn =
-        use_cudnn_ && (in_depth == 1 ||
-                       (use_cudnn_grouped_conv_ &&
-                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
-                                                   /*filter_cols=*/filter_cols,
-                                                   /*in_depth=*/in_depth,
-                                                   /*out_depth=*/out_depth)));
+    bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
+                     (in_depth == 1 ||
+                      (use_cudnn_grouped_conv_ &&
+                       IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                  /*filter_cols=*/filter_cols,
+                                                  /*in_depth=*/in_depth,
+                                                  /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNative: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
@@ -443,7 +441,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
                     "Failed to reshape filter tensor for grouped convolution."));
     // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
     // conv is supported.
-    launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
+    launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, input,
               reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
               stride_, stride_, padding_, explicit_paddings_, output,
               data_format_);
@@ -485,7 +483,6 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 
   // For in_depth == 1 and grouped convolutions.
   LaunchConv2DOp<Device, T> launcher_;
-  bool use_cudnn_;
   bool cudnn_use_autotune_;
   DataType dtype_;
 
@@ -321,30 +321,6 @@ class MaxPoolingGradOp : public OpKernel {
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-template <typename T>
-static void MaxPoolingBackwardCustomKernel(
-    OpKernelContext* context, const std::vector<int32>& size,
-    const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
-    const Tensor& out_backprop, const TensorShape& tensor_in_shape) {
-  Tensor* output = nullptr;
-  OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                              {0}, 0, tensor_in_shape, &output));
-
-  PoolParameters params{context, size, stride,
-                        padding, FORMAT_NHWC, tensor_in_shape};
-  if (!context->status().ok()) {
-    return;
-  }
-
-  functor::MaxPoolBackwardNoMask<T>()(
-      tensor_in->flat<T>().data(), params.tensor_in_batch,
-      params.tensor_in_rows, params.tensor_in_cols, params.depth,
-      params.out_height, params.out_width, params.window_rows,
-      params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
-      params.pad_cols, out_backprop.flat<T>().data(), output->flat<T>().data(),
-      context->eigen_device<Eigen::GpuDevice>());
-}
-
 template <class T>
 class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
  public:
@@ -372,7 +348,6 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
     }
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
 
-    use_dnn_ = CanUseCudnn();
     TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                    &propagate_nans_));
   }
@@ -417,18 +392,10 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
     OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                 errors::Unimplemented(
                     "Pooling is not yet supported on the batch dimension."));
-
-    if (use_dnn_) {
-      DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum,
-                                   ksize, stride, padding_, data_format_,
-                                   &tensor_in, &tensor_out, out_backprop,
-                                   output_shape, propagate_nans_);
-    } else {
-      CHECK(data_format_ == FORMAT_NHWC)
-          << "Non-Cudnn MaxPoolGrad only supports NHWC format";
-      MaxPoolingBackwardCustomKernel<T>(context, ksize, stride, padding_,
-                                        &tensor_in, out_backprop, output_shape);
-    }
+    DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
+                                 stride, padding_, data_format_, &tensor_in,
+                                 &tensor_out, out_backprop, output_shape,
+                                 propagate_nans_);
   }
 
  private:
@@ -436,7 +403,6 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
   std::vector<int32> stride_;
   Padding padding_;
   TensorFormat data_format_;
-  bool use_dnn_;
   bool propagate_nans_;
 };
 
@@ -1139,7 +1105,6 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
     OP_REQUIRES(context, ksize_n == 1 && stride_n == 1,
                 errors::Unimplemented(
                     "Pooling is not yet supported on the batch dimension."));
-    use_dnn_ = CanUseCudnn();
 
     TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                    &propagate_nans_));
@@ -1165,17 +1130,15 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
                     "qint8 should be used with data_format NCHW_VECT_C."));
 
 #if CUDNN_VERSION >= 7300
-    if (use_dnn_) {
-      DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
-                               stride_, padding_, data_format_, tensor_in,
-                               out_shape, propagate_nans_);
+    DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
+                             stride_, padding_, data_format_, tensor_in,
+                             out_shape, propagate_nans_);
 #else
     // These is_int8x4 checks avoid linker errors for missing qint8 kernels.
-    if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) {
+    if (!is_int8x4 && data_format_ == FORMAT_NCHW) {
       DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
                                stride_, padding_, data_format_, tensor_in,
                                out_shape, propagate_nans_);
-#endif
     } else {
       Tensor* output = nullptr;
       OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
@@ -1195,6 +1158,7 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
                   << ") is not supported.";
       }
     }
+#endif
   }
 
  private:
@@ -1202,7 +1166,6 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
   std::vector<int32> stride_;
   Padding padding_;
   TensorFormat data_format_;
-  bool use_dnn_;
   bool propagate_nans_;
 };
 
@@ -1232,7 +1195,6 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
                       "Pooling is not yet supported on the batch dimension."));
     }
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    use_dnn_ = CanUseCudnn();
     TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
                                    &propagate_nans_));
   }
@@ -1275,13 +1237,13 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
     TensorShape out_shape =
         ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
                         params.out_width, params.depth);
-    if (use_dnn_ && data_format_ == FORMAT_NCHW) {
+    if (data_format_ == FORMAT_NCHW) {
       DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
                                stride, padding_, data_format_, tensor_in,
                                out_shape, propagate_nans_);
     } else {
       CHECK(data_format_ == FORMAT_NHWC)
-          << "Non-Cudnn MaxPool only supports NHWC format";
+          << "MaxPool only supports NCHW or NHWC format";
       Tensor* output = nullptr;
       OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
       LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
@@ -1294,7 +1256,6 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
   std::vector<int32> stride_;
   Padding padding_;
   TensorFormat data_format_;
-  bool use_dnn_;
   bool propagate_nans_;
 };
 
@@ -180,48 +180,6 @@ __global__ void MaxPoolForwardNHWC(
   }
 }
 
-template <typename dtype>
-__global__ void MaxPoolBackwardNoMaskNHWC(
-    const int nthreads, const dtype* __restrict__ bottom_data, const int height,
-    const int width, const int channels, const int pooled_height,
-    const int pooled_width, const int kernel_h, const int kernel_w,
-    const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    const dtype* __restrict__ top_diff, dtype* __restrict__ bottom_diff) {
-  GPU_1D_KERNEL_LOOP(index, nthreads) {
-    // First find out the index to the maximum, since we have no mask.
-    int n = index;
-    int c = n % channels;
-    n /= channels;
-    int wstart = (n % pooled_width) * stride_w - pad_l;
-    n /= pooled_width;
-    int hstart = (n % pooled_height) * stride_h - pad_t;
-    n /= pooled_height;
-    int hend = min(hstart + kernel_h, height);
-    int wend = min(wstart + kernel_w, width);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    dtype maxval = Eigen::NumTraits<dtype>::lowest();
-    int maxidx = -1;
-    const dtype* bottom_data_n = bottom_data + n * height * width * channels;
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        int idx = (h * width + w) * channels + c;
-        if (bottom_data_n[idx] > maxval) {
-          maxidx = idx;
-          maxval = bottom_data_n[idx];
-        }
-      }
-    }
-
-    // Atomically accumulate the bottom diff. The index could still be
-    // uninitialized, if all the bottom_data are NaN.
-    if (maxidx != -1) {
-      GpuAtomicAdd(bottom_diff + n * height * width * channels + maxidx,
-                   top_diff[index]);
-    }
-  }
-}
-
 // The parameters to the kernels in the backward function is as follows:
 // nthreads: the number of threads, which is equal to the output size.
 // top_diff: the gradient of the output data, of size N*Hout*Wout*C (or
@@ -445,31 +403,6 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
   return d.ok();
 }
 
-template <typename T>
-bool MaxPoolBackwardNoMask<T>::operator()(
-    const T* bottom_data, const int batch, const int height, const int width,
-    const int channels, const int pooled_height, const int pooled_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
-    T* bottom_diff, const Eigen::GpuDevice& d) {
-  const int kThreadsPerBlock = 1024;
-
-  const int bottom_size = batch * channels * height * width;
-  if (bottom_size == 0) return true;
-  TF_CHECK_OK(GpuLaunchKernel(
-      SetZero<T>, (bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-      kThreadsPerBlock, 0, d.stream(), bottom_size, bottom_diff));
-
-  const int top_size = batch * channels * pooled_height * pooled_width;
-  TF_CHECK_OK(GpuLaunchKernel(
-      MaxPoolBackwardNoMaskNHWC<T>,
-      (top_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0,
-      d.stream(), top_size, bottom_data, height, width, channels, pooled_height,
-      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-      top_diff, bottom_diff));
-  return d.ok();
-}
-
 template <typename T>
 bool MaxPoolBackwardWithArgmax<T>::operator()(
     const int output_size, const int input_size, const T* top_diff,
@@ -540,7 +473,6 @@ typedef Eigen::GpuDevice GPUDevice;
   template struct SpatialMaxPooling<GPUDevice, T>;     \
   template struct MaxPoolForwardWithOptionalArgmax<T>; \
   template struct MaxPoolBackwardWithArgmax<T>;        \
-  template struct MaxPoolBackwardNoMask<T>;            \
   template struct MaxPoolGradBackwardWithArgmax<T>;    \
   template struct MaxPoolGradBackwardNoMask<T>;
 
@@ -60,16 +60,6 @@ struct MaxPoolBackwardWithArgmax {
                   const Eigen::GpuDevice& d, const bool include_batch_in_index);
 };
 
-template <typename T>
-struct MaxPoolBackwardNoMask {
-  bool operator()(const T* bottom_data, const int batch, const int height,
-                  const int width, const int channels, const int pooled_height,
-                  const int pooled_width, const int kernel_h,
-                  const int kernel_w, const int stride_h, const int stride_w,
-                  const int pad_t, const int pad_l, const T* top_diff,
-                  T* bottom_diff, const Eigen::GpuDevice& d);
-};
-
 template <typename T>
 struct MaxPoolGradBackwardWithArgmax {
   bool operator()(const int output_size, const int input_size,
@@ -22,9 +22,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-// TODO(b/155239286): Remove this function
-bool CanUseCudnn() { return true; }
-
 #define ADD_BOOL_CUDNN_FLAG(func_name, flag_name, default_value) \
   bool func_name() {                                             \
     bool value = default_value;                                  \
@@ -73,25 +70,6 @@ ADD_BOOL_CUDNN_FLAG(DebugCudnnRnnUseTensorOps,
 ADD_INT64_CUDNN_FLAG(DebugCudnnRnnAlgo, TF_DEBUG_CUDNN_RNN_ALGO, -1);
 #undef ADD_INT64_CUDNN_FLAG
 
-FP16ConvMode CudnnConvComputeMode() {
-  string value;
-  Status status = ReadStringFromEnvVar("TF_FP16_CONV_MODE", "accurate", &value);
-  if (!status.ok()) {
-    LOG(ERROR) << status;
-  }
-  string lowercase_value = absl::AsciiStrToLower(value);
-  if (lowercase_value == "accurate") {
-    return FP16ConvMode::kAccurate;
-  } else if (lowercase_value == "fast") {
-    return FP16ConvMode::kFast;
-  } else {
-    LOG(ERROR) << "FP16ConvMode only supports two modes, ACCURATE and FAST. "
-                  "Got unknown mode: "
-               << value;
-  }
-  return FP16ConvMode::kAccurate;
-}
-
 bool IsCudnnSupportedFilterSize(const int32 filter_rows,
                                 const int32 filter_cols, const int32 in_depth,
                                 const int32 out_depth) {
|
@ -22,20 +22,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FP16ConvMode: The mode to set the internal compute type for cudnn convolution
|
||||
// when the input data type is float16. Two types of modes are supported:
|
||||
// kAccurate: Always use float32 as the internal compute type.
|
||||
// kFast: Include both float32 and float16 compute type in the autotune.
|
||||
enum class FP16ConvMode {
|
||||
kAccurate = 1,
|
||||
kFast = 2,
|
||||
};
|
||||
|
||||
bool CanUseCudnn();
|
||||
bool CudnnUseAutotune();
|
||||
bool CudnnRnnUseAutotune();
|
||||
bool CudnnDisableConv1x1Optimization();
|
||||
FP16ConvMode CudnnConvComputeMode();
|
||||
bool DebugCudnnRnn();
|
||||
bool DebugCudnnRnnUseTensorOps();
|
||||
int64 DebugCudnnRnnAlgo();
|
||||
|