Add int8 version of fused_conv2d_bias_activation operator for the forward phase,
and support side_input and scaling parameters in float and int8 versions. PiperOrigin-RevId: 167763219
This commit is contained in:
parent
ca65468a02
commit
2b15badd96
@ -60,12 +60,14 @@ tf_kernel_library(
|
||||
srcs = [
|
||||
"kernels/fused_conv2d_bias_activation_op.cc",
|
||||
"kernels/fused_conv2d_bias_activation_op.h",
|
||||
"kernels/fused_conv_ops_gpu.h",
|
||||
],
|
||||
prefix = "fused_conv2d_bias_activation_op",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core/kernels:bounds_check_lib",
|
||||
"//tensorflow/core/kernels:conv_2d_hdrs",
|
||||
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
|
||||
@ -81,6 +83,7 @@ tf_custom_op_library(
|
||||
srcs = [
|
||||
"kernels/fused_conv2d_bias_activation_op.cc",
|
||||
"kernels/fused_conv2d_bias_activation_op.h",
|
||||
"kernels/fused_conv_ops_gpu.h",
|
||||
"ops/fused_conv2d_bias_activation_op.cc",
|
||||
],
|
||||
deps = [
|
||||
@ -94,12 +97,8 @@ tf_custom_op_library(
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"fused_conv2d_bias_activation_op",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
],
|
||||
op_lib_names = ["fused_conv2d_bias_activation_op"],
|
||||
deps = ["//tensorflow/core:lib_proto_parsing"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
@ -109,7 +108,7 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
cuda_py_test(
|
||||
name = "fused_conv2d_bias_activation_op_test",
|
||||
size = "small",
|
||||
size = "large",
|
||||
srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
|
||||
additional_deps = [
|
||||
":fused_conv_py",
|
||||
|
@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
@ -31,8 +29,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/conv_2d.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -40,38 +38,84 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/activation_mode.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct LaunchConvOp;
|
||||
template <typename T>
|
||||
struct RawType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <>
|
||||
struct RawType<qint8> {
|
||||
using type = int8;
|
||||
};
|
||||
|
||||
// Template struct to convert int8x4 to int32.
|
||||
// (for NCHW_VECT_C with element type int8, we can consider it to be
|
||||
// an NCHW layout with element type int32 for operations like padding).
|
||||
template <typename T>
|
||||
struct Int8x4ToInt32 {
|
||||
// By default, do not change T.
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Int8x4ToInt32<int8> {
|
||||
using type = int32;
|
||||
};
|
||||
|
||||
// T is the element type of the conv_input, filter and side_input tensors.
|
||||
// BiasType is the element type of the bias tensor, which can be different.
|
||||
// ScaleType is the type used for conv_input_scale, side_input_scale.
|
||||
template <typename Device, typename T, typename BiasType, typename ScaleType>
|
||||
class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
public:
|
||||
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
string data_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
|
||||
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
|
||||
string data_format_str, filter_format_str;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
|
||||
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("filter_format", &filter_format_str));
|
||||
OP_REQUIRES(context,
|
||||
(data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
|
||||
errors::InvalidArgument("Current implementation only supports "
|
||||
"NHWC and NCHW data formats."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
|
||||
OP_REQUIRES(context, strides_.size() == 4,
|
||||
FilterFormatFromString(filter_format_str, &filter_format_),
|
||||
errors::InvalidArgument("Invalid filter format"));
|
||||
|
||||
std::vector<int32> strides;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
|
||||
OP_REQUIRES(context, strides.size() == 4,
|
||||
errors::InvalidArgument("Sliding window strides field must "
|
||||
"specify 4 dimensions"));
|
||||
|
||||
stride_rows_ = GetTensorDim(strides, data_format_, 'H');
|
||||
stride_cols_ = GetTensorDim(strides, data_format_, 'W');
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
(GetTensorDim(strides_, data_format_, 'N') == 1 &&
|
||||
GetTensorDim(strides_, data_format_, 'C') == 1),
|
||||
errors::InvalidArgument("Current implementation does not yet support "
|
||||
"strides in the batch and depth dimensions."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
(GetTensorDim(strides, data_format_, 'N') == 1 &&
|
||||
GetTensorDim(strides, data_format_, 'C') == 1),
|
||||
errors::InvalidArgument("Convolutional strides are not supported in "
|
||||
"the batch or depth dimensions."));
|
||||
|
||||
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
|
||||
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
||||
|
||||
// Note: Only NCHW_VECT_C format is supported for int8.
|
||||
// This is because it is expected to be the fastest, and our previous tests
|
||||
// found cudnn 6 does not fully support the other formats for int8 mode.
|
||||
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
|
||||
errors::InvalidArgument(
|
||||
"qint8 should be used with data_format NCHW_VECT_C."));
|
||||
|
||||
OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
|
||||
errors::InvalidArgument(
|
||||
"qint8 should be used with filter_format OIHW_VECT_I."));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
|
||||
eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
|
||||
string activation_mode_str;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("activation_mode", &activation_mode_str));
|
||||
@ -79,130 +123,111 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
&activation_mode_));
|
||||
OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
|
||||
errors::InvalidArgument("Current implementation only supports "
|
||||
"relu as the activation mode."));
|
||||
"RELU as the activation function."));
|
||||
cudnn_use_autotune_ = CudnnUseAutotune();
|
||||
float conv_input_scale_flt, side_input_scale_flt;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("conv_input_scale", &conv_input_scale_flt));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("side_input_scale", &side_input_scale_flt));
|
||||
conv_input_scale_ = conv_input_scale_flt;
|
||||
side_input_scale_ = side_input_scale_flt;
|
||||
}
|
||||
|
||||
Status CheckShape(const Tensor& tensor, const string& tensor_name) {
|
||||
const int num_dims = tensor.dims();
|
||||
for (int i = 0; i < num_dims; i++) {
|
||||
if (!FastBoundsCheck(tensor.dim_size(i),
|
||||
std::numeric_limits<int32>::max())) {
|
||||
return errors::InvalidArgument(tensor_name, " dimension ", i,
|
||||
" too large");
|
||||
}
|
||||
}
|
||||
// If there is a 5th dimension it is the VECT_C or VECT_I dimension.
|
||||
if (num_dims == 5 && tensor.dim_size(4) != 4) {
|
||||
return errors::InvalidArgument("The last dimension of ", tensor_name,
|
||||
" must be of size 4 for qint8.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Input tensor is one of the following shapes:
|
||||
// [ batch, in_rows, in_cols, in_depth ] (for NHWC data format)
|
||||
// [ batch, in_depth, in_rows, in_cols ] (for NCHW data format)
|
||||
const Tensor& input = context->input(0);
|
||||
// The conv_input tensor is one of the following formats:
|
||||
// NHWC, NCHW, NCHW_VECT_C.
|
||||
const Tensor& conv_input = context->input(0);
|
||||
OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
|
||||
|
||||
// Input filter is of the following dimensions:
|
||||
// [ filter_rows, filter_cols, in_depth, out_depth ]
|
||||
// The filter tensor is one of the following formats:
|
||||
// HWIO, OIHW, OIHW_VECT_I.
|
||||
const Tensor& filter = context->input(1);
|
||||
OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
|
||||
|
||||
// Input bias is a 1-D tensor the size of the last
|
||||
// dimension of Output tensor
|
||||
// Input bias is a 1-D tensor, with size matching output depth.
|
||||
const Tensor& bias = context->input(2);
|
||||
OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
|
||||
|
||||
// For 2D convolution, there should be 4 dimensions.
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
OP_REQUIRES(context, filter.dims() == 4,
|
||||
errors::InvalidArgument("filter must be 4-dimensional: ",
|
||||
filter.shape().DebugString()));
|
||||
|
||||
// Bias should be a 1-D tensor.
|
||||
OP_REQUIRES(context, bias.dims() == 1,
|
||||
errors::InvalidArgument("bias must be 1-dimensional: ",
|
||||
bias.shape().DebugString()));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
OP_REQUIRES(context,
|
||||
FastBoundsCheck(filter.dim_size(i),
|
||||
std::numeric_limits<int32>::max()),
|
||||
errors::InvalidArgument("filter dimension too large"));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(input.dim_size(i), std::numeric_limits<int32>::max()),
|
||||
errors::InvalidArgument("input dimension too large"));
|
||||
// If side_input_scale != 0, then side_input is not ignored and
|
||||
// has the same type and dimensions as the output.
|
||||
const Tensor& side_input = context->input(3);
|
||||
if (side_input_scale_ != 0) {
|
||||
OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
|
||||
}
|
||||
|
||||
// The last dimension for input is in_depth. It must be the same as the
|
||||
// filter's in_depth.
|
||||
const int64 in_depth = GetTensorDim(input, data_format_, 'C');
|
||||
OP_REQUIRES(context, in_depth == filter.dim_size(2),
|
||||
errors::InvalidArgument(
|
||||
"input and filter must have the same depth: ", in_depth,
|
||||
" vs ", filter.dim_size(2)));
|
||||
// TODO(pauldonnelly): Switch to a more efficient mechanism to access
|
||||
// dimension indexes and per-dimension attributes.
|
||||
const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
|
||||
const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
|
||||
const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');
|
||||
|
||||
// The last dimension for filter is out_depth.
|
||||
const int32 out_depth = static_cast<int32>(filter.dim_size(3));
|
||||
const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
|
||||
const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
|
||||
const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');
|
||||
|
||||
// The second dimension for input is rows/height.
|
||||
// The first dimension for filter is rows/height.
|
||||
const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
|
||||
const int32 input_rows = static_cast<int32>(input_rows_raw);
|
||||
const int32 filter_rows = static_cast<int32>(filter.dim_size(0));
|
||||
|
||||
// The third dimension for input is columns/width.
|
||||
// The second dimension for filter is columns/width.
|
||||
const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
|
||||
const int32 input_cols = static_cast<int32>(input_cols_raw);
|
||||
const int32 filter_cols = static_cast<int32>(filter.dim_size(1));
|
||||
|
||||
// The first dimension for input is batch.
|
||||
const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
|
||||
const int32 batch = static_cast<int32>(batch_raw);
|
||||
|
||||
// For now we take the stride from the second and third dimensions only (we
|
||||
// do not support striding on the batch or depth dimension).
|
||||
const int32 stride_rows =
|
||||
static_cast<int32>(GetTensorDim(strides_, data_format_, 'H'));
|
||||
const int32 stride_cols =
|
||||
static_cast<int32>(GetTensorDim(strides_, data_format_, 'W'));
|
||||
const int32 bias_size = static_cast<int32>(bias.dim_size(0));
|
||||
|
||||
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
|
||||
padding_, &out_rows, &pad_rows));
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
|
||||
padding_, &out_cols, &pad_cols));
|
||||
// Output tensor is of the following dimensions:
|
||||
// [ in_batch, out_rows, out_cols, out_depth ]
|
||||
TensorShape out_shape =
|
||||
ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
|
||||
int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
|
||||
OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
|
||||
stride_rows_, padding_type_,
|
||||
&output_rows, &pad_rows));
|
||||
OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
|
||||
stride_cols_, padding_type_,
|
||||
&output_cols, &pad_cols));
|
||||
// Initialize the output tensor shape according to data_format_
|
||||
TensorShape output_shape = ShapeFromFormat(
|
||||
data_format_, batch_size, output_rows, output_cols, output_depth);
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
|
||||
|
||||
// Bias size should be the same as the size of the channel dimension of
|
||||
// output.
|
||||
OP_REQUIRES(context, bias_size == out_depth,
|
||||
errors::InvalidArgument(
|
||||
"bias size should equal the channel "
|
||||
"dimension size of output. bias shape: ",
|
||||
bias.shape().DebugString() +
|
||||
", output shape: " + output->shape().DebugString()));
|
||||
|
||||
VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth
|
||||
<< ", input_cols = " << input_cols
|
||||
VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
|
||||
<< conv_input_cols << ", conv_input_rows = " << conv_input_rows
|
||||
<< ", filter_cols = " << filter_cols
|
||||
<< ", input_rows = " << input_rows
|
||||
<< ", filter_rows = " << filter_rows
|
||||
<< ", stride_rows = " << stride_rows
|
||||
<< ", stride_cols = " << stride_cols
|
||||
<< ", bias_size = " << bias_size << ", out_depth = " << out_depth;
|
||||
<< ", stride_cols = " << stride_cols_
|
||||
<< ", stride_rows = " << stride_rows_
|
||||
<< ", output_depth = " << output_depth
|
||||
<< ", output_cols = " << output_cols
|
||||
<< ", output_rows = " << output_rows
|
||||
<< ", output_shape.num_elements = " << output_shape.num_elements();
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (out_shape.num_elements() == 0) {
|
||||
if (output_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows,
|
||||
stride_cols, bias, activation_mode_,
|
||||
BrainPadding2EigenPadding(padding_), data_format_, output);
|
||||
|
||||
launcher_.launch(context, cudnn_use_autotune_, conv_input,
|
||||
conv_input_scale_, filter, stride_rows_, stride_cols_,
|
||||
eigen_padding_type_, side_input, side_input_scale_, bias,
|
||||
activation_mode_, data_format_, filter_format_, output);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int32> strides_;
|
||||
Padding padding_;
|
||||
int32 stride_rows_, stride_cols_;
|
||||
Padding padding_type_;
|
||||
Eigen::PaddingType eigen_padding_type_;
|
||||
ActivationMode activation_mode_;
|
||||
TensorFormat data_format_;
|
||||
LaunchFusedConv2DBiasActivationOp<Device, T> launcher_;
|
||||
FilterTensorFormat filter_format_;
|
||||
ScaleType conv_input_scale_;
|
||||
ScaleType side_input_scale_;
|
||||
LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
|
||||
bool cudnn_use_autotune_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
|
||||
@ -211,67 +236,72 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
#if GOOGLE_CUDA
|
||||
namespace dnn = ::perftools::gputools::dnn;
|
||||
|
||||
dnn::ActivationMode BrainActivationMode2CudnnActivationMode(
|
||||
ActivationMode activation_mode) {
|
||||
switch (activation_mode) {
|
||||
case ActivationMode::SIGMOID:
|
||||
return dnn::ActivationMode::kSigmoid;
|
||||
case ActivationMode::RELU:
|
||||
return dnn::ActivationMode::kRelu;
|
||||
case ActivationMode::RELUX:
|
||||
return dnn::ActivationMode::kReluX;
|
||||
case ActivationMode::RELU6:
|
||||
return dnn::ActivationMode::kRelu6;
|
||||
case ActivationMode::TANH:
|
||||
return dnn::ActivationMode::kTanh;
|
||||
case ActivationMode::BANDPASS:
|
||||
return dnn::ActivationMode::kBandPass;
|
||||
}
|
||||
// Prevent compiler warning about missing return
|
||||
return dnn::ActivationMode::kRelu;
|
||||
}
|
||||
|
||||
// A dummy type to group forward convolution autotune results together.
|
||||
struct ConvBiasActivationAutoTuneGroup {
|
||||
static string name() { return "ConvBiasActivation"; }
|
||||
};
|
||||
typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, ConvParameters,
|
||||
perftools::gputools::dnn::AlgorithmConfig>
|
||||
typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
|
||||
dnn::AlgorithmConfig>
|
||||
AutoTuneConvBiasActivation;
|
||||
|
||||
template <typename T>
|
||||
void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
|
||||
OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param,
|
||||
const Tensor& filter, int32 row_stride, int32 col_stride,
|
||||
const Tensor& bias, const ActivationMode& activation_mode,
|
||||
const Eigen::PaddingType& padding, TensorFormat data_format,
|
||||
Tensor* output) {
|
||||
using perftools::gputools::dnn::AlgorithmConfig;
|
||||
using perftools::gputools::dnn::AlgorithmType;
|
||||
using perftools::gputools::dnn::ProfileResult;
|
||||
using perftools::gputools::dnn::kDefaultAlgorithm;
|
||||
// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
|
||||
// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
|
||||
template <typename T, size_t NDIMS>
|
||||
Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
|
||||
int batch_size, int rows, int cols, int depth,
|
||||
Tensor* transformed_tensor, const Tensor** result) {
|
||||
TensorShape nchw_shape =
|
||||
ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
|
||||
if (depth > 1) {
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
|
||||
transformed_tensor));
|
||||
functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
|
||||
ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
|
||||
transformed_tensor->tensor<T, NDIMS>());
|
||||
} else {
|
||||
// If depth <= 1, then just reshape.
|
||||
CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
|
||||
}
|
||||
*result = transformed_tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T, typename BiasType, typename ScaleType>
|
||||
void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
|
||||
launch(OpKernelContext* ctx, bool cudnn_use_autotune,
|
||||
const Tensor& conv_input_param, ScaleType conv_input_scale,
|
||||
const Tensor& filter_param, int32 row_stride, int32 col_stride,
|
||||
const Eigen::PaddingType& padding, const Tensor& side_input_param,
|
||||
ScaleType side_input_scale, const Tensor& bias,
|
||||
ActivationMode activation_mode, TensorFormat data_format,
|
||||
FilterTensorFormat filter_format, Tensor* output_param) {
|
||||
auto* stream = ctx->op_device_context()->stream();
|
||||
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
|
||||
|
||||
Tensor input = input_param;
|
||||
|
||||
perftools::gputools::dnn::ActivationMode cudnn_activation_mode =
|
||||
BrainActivationMode2CudnnActivationMode(activation_mode);
|
||||
|
||||
// TODO(yangzihao): refactor all the complicated/duplicated code in regular
|
||||
// conv ops to a shared conv utility.
|
||||
int32 padding_rows = 0;
|
||||
int32 padding_cols = 0;
|
||||
const int64 in_batch = GetTensorDim(input, data_format, 'N');
|
||||
int64 in_rows = GetTensorDim(input, data_format, 'H');
|
||||
int64 in_cols = GetTensorDim(input, data_format, 'W');
|
||||
const int64 in_depths = GetTensorDim(input, data_format, 'C');
|
||||
const int64 out_batch = GetTensorDim(*output, data_format, 'N');
|
||||
const int64 out_rows = GetTensorDim(*output, data_format, 'H');
|
||||
const int64 out_cols = GetTensorDim(*output, data_format, 'W');
|
||||
const int64 out_depths = GetTensorDim(*output, data_format, 'C');
|
||||
const int64 patch_rows = filter.dim_size(0);
|
||||
const int64 patch_cols = filter.dim_size(1);
|
||||
|
||||
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
|
||||
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
||||
constexpr int rank = is_int8x4 ? 5 : 4;
|
||||
constexpr int vect = is_int8x4 ? 4 : 1;
|
||||
|
||||
const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
|
||||
int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
|
||||
int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');
|
||||
|
||||
const int conv_input_depth =
|
||||
GetTensorDim(conv_input_param, data_format, 'C') * vect;
|
||||
const int output_rows = GetTensorDim(*output_param, data_format, 'H');
|
||||
const int output_cols = GetTensorDim(*output_param, data_format, 'W');
|
||||
const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
|
||||
const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
|
||||
const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
|
||||
int padding_rows = 0;
|
||||
int padding_cols = 0;
|
||||
const Tensor* conv_input = &conv_input_param;
|
||||
|
||||
Tensor maybe_padded_conv_input;
|
||||
if (padding == Eigen::PADDING_SAME) {
|
||||
// Total padding on rows and cols is
|
||||
// Pr = (R' - 1) * S + Kr - R
|
||||
@ -281,114 +311,152 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
|
||||
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
|
||||
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
|
||||
// we pad more on the right and bottom than on the top and left.
|
||||
padding_rows =
|
||||
std::max<int32>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
|
||||
padding_cols =
|
||||
std::max<int32>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
|
||||
const int rows_parity = padding_rows & 1;
|
||||
const int cols_parity = padding_cols & 1;
|
||||
if ((rows_parity | cols_parity) != 0) {
|
||||
padding_rows = std::max<int>(
|
||||
0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
|
||||
padding_cols = std::max<int>(
|
||||
0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
|
||||
const int padding_rows_parity = padding_rows & 1;
|
||||
const int padding_cols_parity = padding_cols & 1;
|
||||
if ((padding_rows_parity | padding_cols_parity) != 0) {
|
||||
Tensor transformed_input;
|
||||
int64 new_in_rows = in_rows + rows_parity;
|
||||
int64 new_in_cols = in_cols + cols_parity;
|
||||
const int new_conv_input_rows = conv_input_rows + padding_rows_parity;
|
||||
const int new_conv_input_cols = conv_input_cols + padding_cols_parity;
|
||||
|
||||
using VectT = typename Int8x4ToInt32<typename RawType<T>::type>::type;
|
||||
auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(data_format, in_batch, new_in_rows,
|
||||
new_in_cols, in_depths),
|
||||
&transformed_input));
|
||||
ctx, ctx->allocate_temp(
|
||||
DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
|
||||
new_conv_input_cols, conv_input_depth),
|
||||
&maybe_padded_conv_input));
|
||||
|
||||
functor::PadInput<GPUDevice, T, int, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
|
||||
{{0, 0}}, {{rows_parity, cols_parity}},
|
||||
To32Bit(transformed_input.tensor<T, 4>()), data_format);
|
||||
auto conv_input_eigen_tensor =
|
||||
To32Bit(conv_input_param.reinterpret_last_dimension<VectT, 4>());
|
||||
auto padded_conv_input_eigen_tensor = To32Bit(
|
||||
maybe_padded_conv_input.reinterpret_last_dimension<VectT, 4>());
|
||||
|
||||
input = transformed_input;
|
||||
in_rows = new_in_rows;
|
||||
in_cols = new_in_cols;
|
||||
functor::PadInput<GPUDevice, VectT, int, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(), conv_input_eigen_tensor, {{0, 0}},
|
||||
{{padding_rows_parity, padding_cols_parity}},
|
||||
padded_conv_input_eigen_tensor, pad_data_format);
|
||||
|
||||
conv_input = &maybe_padded_conv_input;
|
||||
conv_input_rows = new_conv_input_rows;
|
||||
conv_input_cols = new_conv_input_cols;
|
||||
}
|
||||
}
|
||||
|
||||
if (data_format == FORMAT_NHWC) {
|
||||
// Convert the input tensor from NHWC to NCHW.
|
||||
TensorShape nchw_shape =
|
||||
ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
|
||||
if (in_depths > 1) {
|
||||
Tensor transformed_input;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
|
||||
nchw_shape, &transformed_input));
|
||||
functor::NHWCToNCHW<GPUDevice, T, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(),
|
||||
const_cast<const Tensor&>(input).tensor<T, 4>(),
|
||||
transformed_input.tensor<T, 4>());
|
||||
input = transformed_input;
|
||||
} else {
|
||||
// If depth <= 1, then just reshape.
|
||||
CHECK(input.CopyFrom(input, nchw_shape));
|
||||
Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
|
||||
Tensor maybe_transformed_output;
|
||||
const Tensor* side_input = &side_input_param;
|
||||
Tensor* output = output_param;
|
||||
|
||||
// NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary
|
||||
// and inefficient, but it is actually both a time and code size optimization,
|
||||
// since 'is_int8x4' is a constexpr determined by the template parameter.
|
||||
if (!is_int8x4 && data_format == FORMAT_NHWC) {
|
||||
OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
|
||||
ctx, *conv_input, batch_size, conv_input_rows,
|
||||
conv_input_cols, conv_input_depth,
|
||||
&maybe_transformed_conv_input, &conv_input)));
|
||||
if (side_input_scale != 0) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, (TransformNHWCToNCHW<T, rank>(
|
||||
ctx, side_input_param, batch_size, output_rows, output_cols,
|
||||
output_depth, &maybe_transformed_side_input, &side_input)));
|
||||
}
|
||||
if (output_depth > 1) {
|
||||
// Allocate a tensor for the NCHW output of the kernel and point output
|
||||
// to it. Afterwards, we will transform it to NHWC while copying back to
|
||||
// 'output_param'.
|
||||
TensorShape nchw_shape = ShapeFromFormat(
|
||||
FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
|
||||
&maybe_transformed_output));
|
||||
output = &maybe_transformed_output;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(padding_rows >= 0 && padding_cols >= 0)
|
||||
<< "Negative row or col paddings: (" << padding_rows << ", "
|
||||
<< padding_cols << ")";
|
||||
perftools::gputools::dnn::BatchDescriptor input_desc;
|
||||
input_desc.set_count(in_batch)
|
||||
.set_feature_map_count(in_depths)
|
||||
.set_height(in_rows)
|
||||
.set_width(in_cols)
|
||||
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
|
||||
perftools::gputools::dnn::BatchDescriptor output_desc;
|
||||
output_desc.set_count(out_batch)
|
||||
.set_height(out_rows)
|
||||
.set_width(out_cols)
|
||||
.set_feature_map_count(out_depths)
|
||||
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
|
||||
perftools::gputools::dnn::FilterDescriptor filter_desc;
|
||||
filter_desc.set_input_filter_height(filter.dim_size(0))
|
||||
.set_input_filter_width(filter.dim_size(1))
|
||||
.set_input_feature_map_count(filter.dim_size(2))
|
||||
.set_output_feature_map_count(filter.dim_size(3));
|
||||
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
|
||||
constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4
|
||||
: dnn::DataLayout::kBatchDepthYX;
|
||||
constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
|
||||
: dnn::FilterLayout::kOutputInputYX;
|
||||
|
||||
dnn::BatchDescriptor conv_input_desc;
|
||||
conv_input_desc.set_count(batch_size)
|
||||
.set_feature_map_count(conv_input_depth)
|
||||
.set_height(conv_input_rows)
|
||||
.set_width(conv_input_cols)
|
||||
.set_layout(data_layout);
|
||||
dnn::FilterDescriptor filter_desc;
|
||||
filter_desc.set_input_filter_height(filter_rows)
|
||||
.set_input_filter_width(filter_cols)
|
||||
.set_input_feature_map_count(conv_input_depth)
|
||||
.set_output_feature_map_count(output_depth)
|
||||
.set_layout(filter_layout);
|
||||
dnn::BatchDescriptor side_input_desc;
|
||||
side_input_desc.set_count(batch_size)
|
||||
.set_height(output_rows)
|
||||
.set_width(output_cols)
|
||||
.set_feature_map_count(output_depth)
|
||||
.set_layout(data_layout);
|
||||
dnn::BatchDescriptor bias_desc;
|
||||
bias_desc.set_count(1)
|
||||
.set_height(1)
|
||||
.set_width(1)
|
||||
.set_feature_map_count(output_depth)
|
||||
.set_layout(dnn::DataLayout::kBatchDepthYX);
|
||||
dnn::BatchDescriptor output_desc;
|
||||
output_desc.set_count(batch_size)
|
||||
.set_height(output_rows)
|
||||
.set_width(output_cols)
|
||||
.set_feature_map_count(output_depth)
|
||||
.set_layout(data_layout);
|
||||
dnn::ConvolutionDescriptor conv_desc;
|
||||
conv_desc.set_vertical_filter_stride(row_stride)
|
||||
.set_horizontal_filter_stride(col_stride)
|
||||
.set_zero_padding_height(padding_rows / 2)
|
||||
.set_zero_padding_width(padding_cols / 2);
|
||||
|
||||
// Shuffles a filter tensor from:
|
||||
// [<spatial_dims>, in, out]
|
||||
// to:
|
||||
// [out, in, <spatial_dims>]
|
||||
// TODO(yangzihao): Support a data layout tag for the filter weights, and only
|
||||
// do the transform if the weights are not already in the correct layout.
|
||||
Tensor transformed_filter;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
|
||||
DataTypeToEnum<T>::value,
|
||||
TensorShape({filter.dim_size(3), filter.dim_size(2),
|
||||
filter.dim_size(0), filter.dim_size(1)}),
|
||||
&transformed_filter));
|
||||
Tensor maybe_transformed_filter;
|
||||
const Tensor* filter;
|
||||
if (is_int8x4) {
|
||||
// We have already checked filter is OIHW_VECT_I in the constructor.
|
||||
filter = &filter_param;
|
||||
} else if (filter_format == FORMAT_HWIO) {
|
||||
// Shuffle filter tensor from HWIO to OIHW:
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
|
||||
DataTypeToEnum<T>::value,
|
||||
ShapeFromFilterFormat(
|
||||
FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
|
||||
&maybe_transformed_filter));
|
||||
functor::TransformFilter<GPUDevice, T, int, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
|
||||
To32Bit(maybe_transformed_filter.tensor<T, 4>()));
|
||||
filter = &maybe_transformed_filter;
|
||||
}
|
||||
|
||||
functor::TransformFilter<GPUDevice, T, int, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
|
||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||
|
||||
Tensor transformed_output;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
|
||||
ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
|
||||
out_cols, out_depths),
|
||||
&transformed_output));
|
||||
|
||||
auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
|
||||
input.template flat<T>().size());
|
||||
auto conv_input_ptr =
|
||||
AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
|
||||
conv_input->template flat<T>().data()),
|
||||
conv_input->template flat<T>().size());
|
||||
auto filter_ptr =
|
||||
AsDeviceMemory(transformed_filter.template flat<T>().data(),
|
||||
transformed_filter.template flat<T>().size());
|
||||
AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
|
||||
filter->template flat<T>().data()),
|
||||
filter->template flat<T>().size());
|
||||
auto side_input_ptr =
|
||||
AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
|
||||
side_input->template flat<T>().data()),
|
||||
side_input->template flat<T>().size());
|
||||
auto output_ptr =
|
||||
AsDeviceMemory(transformed_output.template flat<T>().data(),
|
||||
transformed_output.template flat<T>().size());
|
||||
|
||||
auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
|
||||
bias.template flat<T>().size());
|
||||
AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
|
||||
output->template flat<T>().data()),
|
||||
output->template flat<T>().size());
|
||||
auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
|
||||
bias.template flat<BiasType>().size());
|
||||
|
||||
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
|
||||
// default value is in bytes despite the name of the environment variable
|
||||
@ -396,38 +464,42 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
|
||||
);
|
||||
|
||||
int device_id = stream->parent()->device_ordinal();
|
||||
DataType dtype = input.dtype();
|
||||
ConvParameters conv_parameters = {
|
||||
in_batch,
|
||||
in_depths,
|
||||
{{in_rows, in_cols}},
|
||||
out_depths,
|
||||
{{patch_rows, patch_cols}},
|
||||
FusedConvParameters fused_conv_parameters = {
|
||||
batch_size,
|
||||
conv_input_depth,
|
||||
{{conv_input_rows, conv_input_cols}},
|
||||
output_depth,
|
||||
{{filter_rows, filter_cols}},
|
||||
{{row_stride, col_stride}},
|
||||
{{padding_rows, padding_cols}},
|
||||
dtype,
|
||||
conv_input->dtype(),
|
||||
device_id,
|
||||
(side_input_scale != 0),
|
||||
activation_mode,
|
||||
};
|
||||
|
||||
AlgorithmConfig algorithm_config;
|
||||
dnn::AlgorithmConfig algorithm_config;
|
||||
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
|
||||
conv_parameters, &algorithm_config)) {
|
||||
std::vector<AlgorithmType> algorithms;
|
||||
fused_conv_parameters, &algorithm_config)) {
|
||||
std::vector<dnn::AlgorithmType> algorithms;
|
||||
CHECK(stream->parent()->GetConvolveAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
|
||||
&algorithms));
|
||||
dnn::ProfileResult best_result;
|
||||
dnn::ProfileResult best_result_no_scratch;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
ProfileResult profile_result;
|
||||
dnn::ProfileResult profile_result;
|
||||
bool cudnn_launch_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(
|
||||
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
||||
bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
|
||||
&scratch_allocator, AlgorithmConfig(profile_algorithm),
|
||||
->ThenFusedConvolveWithAlgorithm(
|
||||
conv_input_desc, conv_input_ptr, conv_input_scale,
|
||||
filter_desc, filter_ptr, conv_desc, side_input_ptr,
|
||||
side_input_scale, bias_desc, bias_ptr,
|
||||
dnn::ActivationMode::kRelu, output_desc, &output_ptr,
|
||||
&scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
|
||||
&profile_result)
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
@ -454,42 +526,68 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T>::launch(
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters,
|
||||
AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
||||
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
bool cudnn_launch_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(
|
||||
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
||||
bias_ptr, cudnn_activation_mode, output_desc, &output_ptr,
|
||||
&scratch_allocator, algorithm_config,
|
||||
->ThenFusedConvolveWithAlgorithm(
|
||||
conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
|
||||
filter_ptr, conv_desc, side_input_ptr, side_input_scale,
|
||||
bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
|
||||
&output_ptr, &scratch_allocator, algorithm_config,
|
||||
/*output_profile_result=*/nullptr)
|
||||
.ok();
|
||||
|
||||
if (!cudnn_launch_status) {
|
||||
ctx->SetStatus(errors::Internal(
|
||||
"cuDNN launch failure : input shape(", input.shape().DebugString(),
|
||||
") filter shape(", filter.shape().DebugString(), ")"));
|
||||
ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
|
||||
conv_input->shape().DebugString(),
|
||||
") filter shape(",
|
||||
filter->shape().DebugString(), ")"));
|
||||
}
|
||||
|
||||
// Convert the output tensor back from NCHW to NHWC.
|
||||
if (data_format == FORMAT_NHWC) {
|
||||
// Convert the output tensor back from NCHW to NHWC if necessary.
|
||||
if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) {
|
||||
functor::NCHWToNHWC<GPUDevice, T, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(),
|
||||
const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
|
||||
output->tensor<T, 4>());
|
||||
} else {
|
||||
*output = transformed_output;
|
||||
const_cast<const Tensor*>(output)->tensor<T, 4>(),
|
||||
output_param->tensor<T, 4>());
|
||||
}
|
||||
}
|
||||
|
||||
// Forward declarations of the functor specializations for GPU used above.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T, int, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const std::array<int, 2>& padding_left, \
|
||||
const std::array<int, 2>& padding_right, \
|
||||
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
|
||||
extern template struct PadInput<GPUDevice, T, int, 4>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
// Registration of the GPU implementations.
|
||||
REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<float>("T"),
|
||||
FusedConv2DBiasActivationOp<GPUDevice, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("FusedConv2DBiasActivation")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<float>("Tbias"),
|
||||
FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("FusedConv2DBiasActivation")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.TypeConstraint<float>("Tbias"),
|
||||
FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
@ -33,27 +33,30 @@ namespace tensorflow {
|
||||
// Forward declaration.
|
||||
class OpKernelContext;
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename BiasType, typename ScaleType>
|
||||
class LaunchFusedConv2DBiasActivationOp {
|
||||
public:
|
||||
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
|
||||
const Tensor& input, const Tensor& filter, int row_stride,
|
||||
int col_stride, const Tensor& bias,
|
||||
const ActivationMode& activation_mode,
|
||||
const Eigen::PaddingType& padding, TensorFormat data_format,
|
||||
Tensor* output);
|
||||
const Tensor& conv_input, ScaleType conv_input_scale,
|
||||
const Tensor& filter, int32 row_stride, int32 col_stride,
|
||||
const Eigen::PaddingType& padding, const Tensor& side_input,
|
||||
ScaleType side_input_scale, const Tensor& bias,
|
||||
ActivationMode activation_mode, TensorFormat data_format,
|
||||
FilterTensorFormat filter_format, Tensor* output);
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
template <typename T>
|
||||
class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T> {
|
||||
template <typename T, typename BiasType, typename ScaleType>
|
||||
class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T, BiasType,
|
||||
ScaleType> {
|
||||
public:
|
||||
void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
|
||||
const Tensor& input, const Tensor& filter, int32 row_stride,
|
||||
int32 col_stride, const Tensor& bias,
|
||||
const ActivationMode& activation_mode,
|
||||
const Eigen::PaddingType& padding, TensorFormat data_format,
|
||||
Tensor* output);
|
||||
const Tensor& conv_input, ScaleType conv_input_scale,
|
||||
const Tensor& filter, int32 row_stride, int32 col_stride,
|
||||
const Eigen::PaddingType& padding, const Tensor& side_input,
|
||||
ScaleType side_input_scale, const Tensor& bias,
|
||||
ActivationMode activation_mode, TensorFormat data_format,
|
||||
FilterTensorFormat filter_format, Tensor* output);
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
|
74
tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
Normal file
74
tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/util/activation_mode.h"
|
||||
|
||||
// TODO(pauldonnelly): Merge this file into core/kernels/conv_ops_gpu.h.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Add additional parameters specific to fused convolutions.
|
||||
class FusedConvParameters : public ConvParameters {
|
||||
public:
|
||||
FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
|
||||
int64 out_depths, const SpatialArray& filter,
|
||||
const SpatialArray& stride, const SpatialArray& padding,
|
||||
DataType dtype, int device_id, bool has_side_input,
|
||||
ActivationMode activation_mode)
|
||||
: ConvParameters(batch, in_depths, in, out_depths, filter, stride,
|
||||
padding, dtype, device_id),
|
||||
activation_mode_(activation_mode),
|
||||
has_side_input_(has_side_input) {
|
||||
hash_code_ = Hash64Combine(hash_code_, has_side_input);
|
||||
hash_code_ = Hash64Combine(hash_code_, activation_mode);
|
||||
}
|
||||
|
||||
bool operator==(const FusedConvParameters& other) const {
|
||||
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
||||
}
|
||||
|
||||
bool operator!=(const FusedConvParameters& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
string ToString() const {
|
||||
return strings::StrCat(ConvParameters::ToString(), ", ", has_side_input_,
|
||||
", ", activation_mode_, ", ");
|
||||
}
|
||||
|
||||
private:
|
||||
using ParameterDataType =
|
||||
std::tuple<ConvParameters::ParameterDataType, bool, ActivationMode>;
|
||||
|
||||
ParameterDataType get_data_as_tuple() const {
|
||||
return std::make_tuple(ConvParameters::get_data_as_tuple(), has_side_input_,
|
||||
activation_mode_);
|
||||
}
|
||||
|
||||
ActivationMode activation_mode_;
|
||||
bool has_side_input_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
|
@ -33,40 +33,73 @@ string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
|
||||
} // namespace
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TODO(pauldonnelly): Add support for double inputs and scales to this Op,
|
||||
// (currently Attr does not support double).
|
||||
|
||||
REGISTER_OP("FusedConv2DBiasActivation")
|
||||
.Input("input: T")
|
||||
.Input("conv_input: T")
|
||||
.Input("filter: T")
|
||||
.Input("bias: T")
|
||||
.Input("bias: Tbias")
|
||||
.Input("side_input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float}")
|
||||
.Attr("T: {float, half, qint8}")
|
||||
.Attr("Tbias: {float, half}")
|
||||
.Attr("conv_input_scale: float = 1.0")
|
||||
.Attr("side_input_scale: float = 0.0")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr(GetAllActivationModeAttrString())
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
|
||||
.Attr("activation_mode: {'Relu'} = 'Relu'")
|
||||
.SetShapeFn(shape_inference::FusedConvBiasActivationShape)
|
||||
.Doc(R"doc(
|
||||
Computes a fused 2-D convolution, adds bias, and applies an activation function
|
||||
on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode.
|
||||
Computes a fused kernel which implements: 2-D convolution, adds side input,
|
||||
with separate scaling on convolution and side inputs, then adds bias and
|
||||
applies the RELU activation function to the result. Supports both float and
|
||||
qint8 data formats. In the case of qint8, the output is clipped to [0..127].
|
||||
|
||||
input: A 4-D tensor. The dimension order is interpreted according to the value
|
||||
of `data_format`, see below for details.
|
||||
filter: A 4-D tensor of shape
|
||||
`[filter_height, filter_width, in_channels, out_channels]`
|
||||
bias: 1-D with size of the `out_channels` dimension in filter.
|
||||
output: A 4-D tensor. The dimension order is determined by the value of
|
||||
`data_format`, see below for details.
|
||||
T: The data type for the elements of input, filter, bias, and output Tensors.
|
||||
conv_input: A tensor with format as specified by `data_format` (see below).
|
||||
filter: A tensor with format depending on `data_format` as follows:
|
||||
"NHWC", "NCHW":
|
||||
`float [ filter_height, filter_width, in_channels, out_channels ]`
|
||||
"NCHW_VECT_C":
|
||||
`qint8 [ out_channels, in_channels, filter_height, filter_width ]`
|
||||
bias: 1-D float tensor with size matching the `out_channels` dimension of
|
||||
`filter`.
|
||||
Note: this tensor is still float, even if other inputs are qint8.
|
||||
side_input: A tensor with format as specified by `data_format` (see below).
|
||||
This tensor will be ignored and can be [] if side_input_scale == 0.
|
||||
Otherwise, the size of each dimension must match the `output` tensor.
|
||||
output: A tensor with format as specified by `data_format` (see below).
|
||||
The dimension sizes are determined automatically based on other inputs
|
||||
and attributes.
|
||||
T: The element data type of `conv_input`, `side_input` and `output` tensors.
|
||||
Note: must match with the `data_format`.
|
||||
Tbias: The element data type of `bias`.
|
||||
conv_input_scale: scalar float value to be multiplied by `conv_input`.
|
||||
(conceptually.. in reality it is applied after convolution).
|
||||
side_input_scale: scalar float value to be multiplied by `side_input`.
|
||||
strides: 1-D tensor of length 4. The stride of the sliding window for each
|
||||
dimension of `input`. The dimension order is determined by the value of
|
||||
`data_format`, see below for details.
|
||||
Note: the stride for batch and channel dimensions must be 1.
|
||||
padding: The type of padding algorithm to use.
|
||||
data_format: Specify the data format of the input and output data. With the
|
||||
default format "NHWC", the data is stored in the order of:
|
||||
[batch, height, width, channels].
|
||||
Alternatively, the format could be "NCHW", the data storage order of:
|
||||
[batch, channels, height, width].
|
||||
activation_mode: Specify the activation function to apply to the output tensor
|
||||
of bias add. Currently only supports "Relu".
|
||||
data_format: A string specifying the data format of `conv_input`,
|
||||
`side_input` and `output` tensors with the following options:
|
||||
"NHWC": `float [ batch, height, width, channels ]`
|
||||
"NCHW": `float [ batch, channels, height, width ]`
|
||||
"NCHW_VECT_C":
|
||||
`qint8 [ batch, channels / 4, height, width, channels % 4 ]`
|
||||
Note: for "NCHW_VECT_C", `channels` must be a multiple of 4.
|
||||
filter_format: A string specifying the data format of `filter`,
|
||||
"HWIO": `float [ kernel_height, kernel_width, input_channels,
|
||||
output_channels ]`
|
||||
"OIHW_VECT_I":
|
||||
`qint8 [ output_channels, input_channels / 4,
|
||||
kernel_height, kernel_width, input_channels % 4 ]`
|
||||
activation_mode: The activation applied to the output.
|
||||
Currently must be "Relu".
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -26,62 +26,83 @@ _fused_conv2d_bias_activation_op_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so"))
|
||||
|
||||
|
||||
def fused_conv2d_bias_activation(input_tensor,
|
||||
filter_tensor,
|
||||
# pylint: disable=redefined-builtin
|
||||
def fused_conv2d_bias_activation(conv_input,
|
||||
filter,
|
||||
bias,
|
||||
strides,
|
||||
padding,
|
||||
activation_mode,
|
||||
strides=None,
|
||||
padding=None,
|
||||
conv_input_scale=1.0,
|
||||
side_input_scale=0.0,
|
||||
side_input=None,
|
||||
activation_mode="Relu",
|
||||
data_format=None,
|
||||
filter_format=None,
|
||||
name=None):
|
||||
"""Computes a fused 2-D convolution, adds bias, and applies relu.
|
||||
"""Fused 2D conv, bias and activation with optional side input.
|
||||
|
||||
input_tensor: A 4-D tensor. The dimension order is interpreted
|
||||
according to the value of `data_format`, see below for details.
|
||||
filter_tensor: A 4-D tensor of shape
|
||||
`[filter_height, filter_width, in_channels, out_channels]`
|
||||
bias: 1-D with size of the `out_channels` dimension in filter.
|
||||
output: A 4-D tensor. The dimension order is determined by the value of
|
||||
`data_format`, see below for details.
|
||||
T: The data type for the elements of input, filter, bias, and output
|
||||
Tensors.
|
||||
strides: 1-D tensor of length 4. The stride of the sliding window for
|
||||
each
|
||||
dimension of `input`. The dimension order is determined by the value
|
||||
of
|
||||
`data_format`, see below for details.
|
||||
padding: The type of padding algorithm to use.
|
||||
data_format: Specify the data format of the input and output data. With
|
||||
the
|
||||
default format "NHWC", the data is stored in the order of:
|
||||
[batch, height, width, channels].
|
||||
Alternatively, the format could be "NCHW", the data storage order of:
|
||||
[batch, channels, height, width].
|
||||
activation_mode: Specify the activation function to apply to the output
|
||||
tensor
|
||||
of bias add. Currently only supports "Relu".
|
||||
Computes a fused 2-D convolution scaled by conv_input_scale,
|
||||
adds an optional side input scaled by side_input_scale, adds biases,
|
||||
and applies ReLU. As an equation:
|
||||
output = ReLU(conv_input_scale * Conv(conv_input, filter) +
|
||||
side_input_scale * side_input + bias)
|
||||
Note: In int8 mode, The ReLU will clip the output to the range [0..127].
|
||||
|
||||
Args:
|
||||
input_tensor: A `Tensor`. Must be one of the following types: `float32`.
|
||||
filter_tensor: A `Tensor`. Must have the same type as `input`.
|
||||
bias: A `Tensor`. Must have the same type as `input`.
|
||||
strides: A list of `ints`.
|
||||
conv_input: A `Tensor` of the format specified by `data_format`.
|
||||
filter: A `Tensor` whose format depends on `data_format`:
|
||||
if `data_format` is "NCHW_VECT_C", filter should be "OIHW_VECT_I"
|
||||
otherwise, it should be "HWIO" format.
|
||||
bias: A 1-D `Tensor` of type `float32`, and dimensions equal to the
|
||||
number of output channels.
|
||||
strides: A list of 4 `ints` specifying convolution strides.
|
||||
if `data_format` is "NCHW" or "NCHW_VECT_C", the order should be NCHW.
|
||||
if `data_format` is "NHWC", the order should be NHWC.
|
||||
padding: A `string` from: `"SAME", "VALID"`.
|
||||
activation_mode: A `string` from: `"Sigmoid", "Relu", "Relu6", "ReluX",
|
||||
"Tanh", "BandPass"`.
|
||||
data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
|
||||
`"NHWC"`.
|
||||
conv_input_scale: A scalar `float32` that will be multiplied by conv_input.
|
||||
This is optional and defaults to 1. However it should be set to
|
||||
specify the quantization scale when `data_format` is "NCHW_VECT_C".
|
||||
side_input_scale: A scalar `float32` that will be multiplied by side_input.
|
||||
This is optional and defaults to 0.
|
||||
side_input: A `Tensor` of the format specified by `data_format`.
|
||||
This is useful for imlementing ResNet blocks.
|
||||
activation_mode: (optional) currently must be the default "Relu".
|
||||
Note that in qint8 mode, it also clips to 127, so acts like ReluX.
|
||||
data_format: Specifies the data format.
|
||||
Possible values are:
|
||||
"NHWC" float [batch, height, width, channels]
|
||||
"NCHW" float [batch, channels, height, width]
|
||||
"NCHW_VECT_C" qint8 [batch, channels / 4, height, width, channels % 4]
|
||||
Defaults to `"NHWC"`.
|
||||
Performance is worst for `"NHWC"` and best for `"NCHW_VECT_C"`.
|
||||
filter_format: Specifies the filter format.
|
||||
Possible values are:
|
||||
"HWIO" float [kernel_height, kernel_width, input_channels,
|
||||
output_channels ]
|
||||
"OIHW" float [output_channels, input_channels, kernel_height,
|
||||
kernel_width ]
|
||||
"OIHW_VECT_I" qint8 [ output_channels, input_channels / 4,
|
||||
kernel_height, kernel_width, input_channels % 4 ]
|
||||
Defaults to `"HWIO"`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor`. Has the same type as `input`.
|
||||
A `Tensor` of the format specified by `data_format`.
|
||||
"""
|
||||
if strides is None:
|
||||
strides = [1, 1, 1, 1]
|
||||
if side_input is None:
|
||||
side_input = []
|
||||
return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
|
||||
input=input_tensor,
|
||||
filter=filter_tensor,
|
||||
bias=bias,
|
||||
strides=strides,
|
||||
conv_input,
|
||||
filter,
|
||||
bias,
|
||||
padding=padding,
|
||||
strides=strides,
|
||||
conv_input_scale=conv_input_scale,
|
||||
side_input_scale=side_input_scale,
|
||||
side_input=side_input,
|
||||
activation_mode=activation_mode,
|
||||
data_format=data_format,
|
||||
filter_format=filter_format,
|
||||
name=name)
|
||||
|
@ -19,13 +19,16 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
@ -484,7 +487,8 @@ class FusedConv2DBiasActivationTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
# Illegal strides.
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"strides in the batch and depth"):
|
||||
"Convolutional strides are not supported in "
|
||||
"the batch or depth dimensions."):
|
||||
sess.run(
|
||||
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
|
||||
array_ops.placeholder(dtypes.float32),
|
||||
@ -494,7 +498,8 @@ class FusedConv2DBiasActivationTest(test.TestCase):
|
||||
padding="SAME",
|
||||
activation_mode="Relu"))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"strides in the batch and depth"):
|
||||
"Convolutional strides are not supported in "
|
||||
"the batch or depth dimensions."):
|
||||
sess.run(
|
||||
fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
|
||||
array_ops.placeholder(dtypes.float32),
|
||||
@ -552,6 +557,286 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
|
||||
return Test
|
||||
|
||||
|
||||
def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type):
|
||||
"""Calculates the size of an output dimension of a strided convolution.
|
||||
|
||||
Given the sizes of the corresponding dimension of the input and filter shapes,
|
||||
and the stride and padding_types, calculates the size of the output dimension.
|
||||
This function can be called separately for each input dimension.
|
||||
|
||||
Args:
|
||||
input_dim: An `int` specifying the size of the input dimension.
|
||||
filter_dim: An `int` specifying the size of the filter dimension.
|
||||
stride: An `int` specifying the step size of the convolution along the
|
||||
input dimension.
|
||||
padding_type: either 'VALID' or 'SAME'.
|
||||
|
||||
Returns:
|
||||
The size of the output dimension.
|
||||
"""
|
||||
if padding_type == "VALID":
|
||||
return (input_dim - filter_dim + stride) // stride
|
||||
else: # padding_type == 'SAME'
|
||||
return (input_dim + stride - 1) // stride
|
||||
|
||||
|
||||
def NchwVectCToNchw(in_tensor):
|
||||
# [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
|
||||
t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
|
||||
n = in_tensor.shape.dims[0].value
|
||||
c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
|
||||
h = in_tensor.shape.dims[2].value
|
||||
w = in_tensor.shape.dims[3].value
|
||||
return array_ops.reshape(t, [n, c, h, w])
|
||||
|
||||
|
||||
def OihwVectIToHwio(in_tensor):
|
||||
# [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
|
||||
t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
|
||||
o = in_tensor.shape.dims[0].value
|
||||
i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
|
||||
h = in_tensor.shape.dims[2].value
|
||||
w = in_tensor.shape.dims[3].value
|
||||
return array_ops.reshape(t, [h, w, i, o])
|
||||
|
||||
|
||||
def NchwToNchwVectC(in_tensor):
|
||||
n, c, h, w = in_tensor.shape.as_list()
|
||||
assert c % 4 == 0
|
||||
t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
|
||||
return array_ops.transpose(t, [0, 1, 3, 4, 2])
|
||||
|
||||
|
||||
def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
|
||||
padding, strides, side_input_scale,
|
||||
side_input, biases):
|
||||
"""Simulates the int8 fused 2-D convolution op using separate float ops.
|
||||
|
||||
The arguments and return values have the same format, meanings and
|
||||
restrictions as the actual op.
|
||||
Args:
|
||||
conv_input_scale: A scalar 'float'.
|
||||
conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
|
||||
kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
|
||||
padding: A `string` from: `"SAME", "VALID"`.
|
||||
strides: A list of `ints`.
|
||||
side_input_scale: A scalar 'float'.
|
||||
side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
|
||||
biases: A `Tensor` of type `float32` in NCHW layout.
|
||||
Returns:
|
||||
A `Tensor` of type `qint8` in NCHW_VECT_C layout.
|
||||
"""
|
||||
conv_result = nn_ops.conv2d(
|
||||
NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
|
||||
OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format="NCHW") * conv_input_scale
|
||||
|
||||
conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
|
||||
gen_array_ops.dequantize(side_input, -128, 127))
|
||||
|
||||
logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
|
||||
|
||||
result, _, _ = gen_array_ops.quantize_v2(
|
||||
NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8)
|
||||
return result
|
||||
|
||||
|
||||
class FusedConvInt8Tests(test.TestCase):
|
||||
_test_params = [
|
||||
{
|
||||
"batch_size": 2,
|
||||
"input_channels": 8,
|
||||
"output_channels": 16,
|
||||
"input_height": 8,
|
||||
"input_width": 8,
|
||||
"filter_height": 3,
|
||||
"filter_width": 3,
|
||||
"vertical_stride": 2,
|
||||
"horizontal_stride": 2,
|
||||
"conv_input_scale": 0.002,
|
||||
"side_input_scale": 0.0,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "VALID"
|
||||
},
|
||||
{
|
||||
"batch_size": 2,
|
||||
"input_channels": 8,
|
||||
"output_channels": 16,
|
||||
"input_height": 8,
|
||||
"input_width": 8,
|
||||
"filter_height": 3,
|
||||
"filter_width": 3,
|
||||
"vertical_stride": 2,
|
||||
"horizontal_stride": 2,
|
||||
"conv_input_scale": 0.002,
|
||||
"side_input_scale": 0.0,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "SAME"
|
||||
},
|
||||
{
|
||||
"batch_size": 2,
|
||||
"input_channels": 8,
|
||||
"output_channels": 16,
|
||||
"input_height": 8,
|
||||
"input_width": 8,
|
||||
"filter_height": 3,
|
||||
"filter_width": 3,
|
||||
"vertical_stride": 2,
|
||||
"horizontal_stride": 2,
|
||||
"conv_input_scale": 0.002,
|
||||
"side_input_scale": 0.5,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "VALID"
|
||||
},
|
||||
{
|
||||
"batch_size": 2,
|
||||
"input_channels": 16,
|
||||
"output_channels": 16,
|
||||
"input_height": 9,
|
||||
"input_width": 9,
|
||||
"filter_height": 3,
|
||||
"filter_width": 3,
|
||||
"vertical_stride": 1,
|
||||
"horizontal_stride": 1,
|
||||
"conv_input_scale": 0.001,
|
||||
"side_input_scale": 0.5,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "SAME"
|
||||
},
|
||||
{
|
||||
"batch_size": 3,
|
||||
"input_channels": 8,
|
||||
"output_channels": 8,
|
||||
"input_height": 9,
|
||||
"input_width": 9,
|
||||
"filter_height": 5,
|
||||
"filter_width": 5,
|
||||
"vertical_stride": 1,
|
||||
"horizontal_stride": 1,
|
||||
"conv_input_scale": 0.001,
|
||||
"side_input_scale": 0.5,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "SAME"
|
||||
},
|
||||
{
|
||||
"batch_size": 3,
|
||||
"input_channels": 8,
|
||||
"output_channels": 8,
|
||||
"input_height": 9,
|
||||
"input_width": 9,
|
||||
"filter_height": 7,
|
||||
"filter_width": 1,
|
||||
"vertical_stride": 2,
|
||||
"horizontal_stride": 1,
|
||||
"conv_input_scale": 0.002,
|
||||
"side_input_scale": 0.5,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "SAME"
|
||||
},
|
||||
{
|
||||
"batch_size": 3,
|
||||
"input_channels": 8,
|
||||
"output_channels": 8,
|
||||
"input_height": 9,
|
||||
"input_width": 9,
|
||||
"filter_height": 1,
|
||||
"filter_width": 7,
|
||||
"vertical_stride": 1,
|
||||
"horizontal_stride": 1,
|
||||
"conv_input_scale": 0.002,
|
||||
"side_input_scale": 0.5,
|
||||
"bias_scale": 1,
|
||||
"padding_type": "SAME"
|
||||
},
|
||||
]
|
||||
|
||||
def runTest(self, test_param):
|
||||
batch_size = test_param["batch_size"]
|
||||
input_channels = test_param["input_channels"]
|
||||
output_channels = test_param["output_channels"]
|
||||
input_height = test_param["input_height"]
|
||||
input_width = test_param["input_width"]
|
||||
filter_height = test_param["filter_height"]
|
||||
filter_width = test_param["filter_width"]
|
||||
vertical_stride = test_param["vertical_stride"]
|
||||
horizontal_stride = test_param["horizontal_stride"]
|
||||
conv_input_scale = test_param["conv_input_scale"]
|
||||
side_input_scale = test_param["side_input_scale"]
|
||||
bias_scale = test_param["bias_scale"]
|
||||
padding_type = test_param["padding_type"]
|
||||
|
||||
conv_input, _, _ = gen_array_ops.quantize_v2(
|
||||
random_ops.random_uniform(
|
||||
[batch_size, input_channels // 4, input_height, input_width, 4],
|
||||
minval=-0.0,
|
||||
maxval=1.0,
|
||||
dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
|
||||
|
||||
kernel, _, _ = gen_array_ops.quantize_v2(
|
||||
random_ops.random_uniform(
|
||||
[
|
||||
output_channels, input_channels // 4, filter_height,
|
||||
filter_width, 4
|
||||
],
|
||||
minval=-1.0,
|
||||
maxval=1.0,
|
||||
dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
|
||||
|
||||
output_height = CalculateCovolvedOutputDim(input_height, filter_height,
|
||||
vertical_stride, padding_type)
|
||||
output_width = CalculateCovolvedOutputDim(input_width, filter_width,
|
||||
horizontal_stride, padding_type)
|
||||
print("output_height=", output_height, ", output_width=", output_width)
|
||||
|
||||
side_input, _, _ = gen_array_ops.quantize_v2(
|
||||
random_ops.random_uniform(
|
||||
[batch_size, output_channels // 4, output_height, output_width, 4],
|
||||
minval=0.0,
|
||||
maxval=1.0,
|
||||
dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
|
||||
|
||||
biases = random_ops.random_uniform(
|
||||
[output_channels],
|
||||
minval=-10 * bias_scale,
|
||||
maxval=20 * bias_scale,
|
||||
dtype=dtypes.float32)
|
||||
|
||||
strides = [1, 1, vertical_stride, horizontal_stride]
|
||||
|
||||
actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
|
||||
conv_input,
|
||||
kernel,
|
||||
biases,
|
||||
strides=strides,
|
||||
padding=padding_type,
|
||||
conv_input_scale=conv_input_scale,
|
||||
side_input_scale=side_input_scale,
|
||||
side_input=side_input,
|
||||
data_format="NCHW_VECT_C",
|
||||
filter_format="OIHW_VECT_I")
|
||||
|
||||
expected = SimulateFusedConv2dBiasActivationInt8(
|
||||
conv_input_scale, conv_input, kernel, padding_type, strides,
|
||||
side_input_scale, side_input, biases)
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
actual_y, expected_y = sess.run([actual, expected])
|
||||
print("actual_y = ", actual_y)
|
||||
print("expected_y = ", expected_y)
|
||||
self.assertTrue(np.array_equal(actual_y, expected_y))
|
||||
|
||||
def testFusedConvInt8(self):
|
||||
if not test.is_gpu_available(
|
||||
cuda_only=True, min_cuda_compute_capability=(6, 1)):
|
||||
tf_logging.info("int8 test skipped because not run with --config=cuda or "
|
||||
"no GPUs with compute capability >= 6.1 are available.")
|
||||
return
|
||||
for test_param in self._test_params:
|
||||
self.runTest(test_param)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for index, (input_size_, filter_size_, output_size_, stride_,
|
||||
padding_) in enumerate(GetShrunkInceptionShapes()):
|
||||
|
@ -2382,6 +2382,7 @@ tf_cc_tests(
|
||||
"util/semver_test.cc",
|
||||
"util/sparse/sparse_tensor_test.cc",
|
||||
"util/stat_summarizer_test.cc",
|
||||
"util/tensor_format_test.cc",
|
||||
"util/tensor_slice_reader_test.cc",
|
||||
"util/tensor_slice_set_test.cc",
|
||||
"util/tensor_slice_util_test.cc",
|
||||
|
@ -206,15 +206,28 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
|
||||
Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(Conv2DShape(c));
|
||||
|
||||
ShapeHandle bias_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &bias_shape));
|
||||
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
|
||||
string data_format_str, filter_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
|
||||
|
||||
TensorFormat data_format;
|
||||
FormatFromString(data_format_str, &data_format);
|
||||
FilterTensorFormat filter_format;
|
||||
FilterFormatFromString(filter_format_str, &filter_format);
|
||||
|
||||
constexpr int num_spatial_dims = 2;
|
||||
const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
|
||||
ShapeHandle filter_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
|
||||
DimensionHandle output_depth_dim = c->Dim(filter_shape, 3);
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
|
||||
|
||||
DimensionHandle output_depth_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
|
||||
int64 output_depth_dim_val = c->Value(output_depth_dim);
|
||||
|
||||
ShapeHandle bias_shape;
|
||||
// Bias should be a 1-D tensor.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
|
||||
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
|
||||
int64 bias_dim_val = c->Value(bias_dim);
|
||||
|
||||
if (output_depth_dim_val != bias_dim_val) {
|
||||
@ -223,6 +236,14 @@ Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) {
|
||||
") and bias dimension (", bias_dim_val, ") do not match.");
|
||||
}
|
||||
|
||||
// Check side input shape matches the output shape.
|
||||
ShapeHandle side_input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
|
||||
if (c->Rank(side_input_shape) > 1) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -323,24 +344,38 @@ Status ShapeFromDimensions(DimensionHandle batch_dim,
|
||||
}
|
||||
|
||||
Status Conv2DShape(shape_inference::InferenceContext* c) {
|
||||
string data_format_str;
|
||||
Status s = c->GetAttr("data_format", &data_format_str);
|
||||
if (!s.ok()) {
|
||||
string data_format_str, filter_format_str;
|
||||
if (!c->GetAttr("data_format", &data_format_str).ok()) {
|
||||
data_format_str = "NHWC";
|
||||
}
|
||||
if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
|
||||
filter_format_str = "HWIO";
|
||||
}
|
||||
|
||||
TensorFormat data_format;
|
||||
if (!FormatFromString(data_format_str, &data_format)) {
|
||||
return errors::InvalidArgument("Invalid data format string: ",
|
||||
data_format_str);
|
||||
}
|
||||
FilterTensorFormat filter_format;
|
||||
if (!FilterFormatFromString(filter_format_str, &filter_format)) {
|
||||
return errors::InvalidArgument("Invalid filter format string: ",
|
||||
filter_format_str);
|
||||
}
|
||||
|
||||
constexpr int num_spatial_dims = 2;
|
||||
const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
|
||||
ShapeHandle conv_input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
|
||||
TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
|
||||
data_format, conv_input_shape, "conv_input", c));
|
||||
|
||||
const int rank = GetTensorDimsFromSpatialDims(2, data_format);
|
||||
ShapeHandle input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
|
||||
// The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
|
||||
ShapeHandle filter_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
|
||||
|
||||
std::vector<int32> strides;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
|
||||
|
||||
@ -352,38 +387,33 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
|
||||
strides.size());
|
||||
}
|
||||
|
||||
int32 stride_rows, stride_cols;
|
||||
if (data_format == FORMAT_NCHW || data_format == FORMAT_NCHW_VECT_C) {
|
||||
stride_rows = strides[2];
|
||||
stride_cols = strides[3];
|
||||
} else {
|
||||
stride_rows = strides[1];
|
||||
stride_cols = strides[2];
|
||||
}
|
||||
const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
|
||||
const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
|
||||
|
||||
DimensionHandle batch_size_dim;
|
||||
DimensionHandle input_depth_dim;
|
||||
gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
|
||||
TF_RETURN_IF_ERROR(DimensionsFromShape(input_shape, data_format,
|
||||
TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
|
||||
&batch_size_dim, &input_spatial_dims,
|
||||
&input_depth_dim, c));
|
||||
|
||||
DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim,
|
||||
filter_input_depth_dim;
|
||||
// If the input format is NCHW_VECT_C, the filter format is assumed to be
|
||||
// OIHW_VECT_I, otherwise it is assumed to be HWIO.
|
||||
if (data_format == FORMAT_NCHW_VECT_C) {
|
||||
output_depth_dim = c->Dim(filter_shape, 0);
|
||||
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1),
|
||||
c->Dim(filter_shape, 4),
|
||||
&filter_input_depth_dim));
|
||||
filter_rows_dim = c->Dim(filter_shape, 2);
|
||||
filter_cols_dim = c->Dim(filter_shape, 3);
|
||||
DimensionHandle output_depth_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
|
||||
DimensionHandle filter_rows_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
|
||||
DimensionHandle filter_cols_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
|
||||
DimensionHandle filter_input_depth_dim;
|
||||
if (filter_format == FORMAT_OIHW_VECT_I) {
|
||||
TF_RETURN_IF_ERROR(c->Multiply(
|
||||
c->Dim(filter_shape,
|
||||
GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
|
||||
c->Dim(filter_shape,
|
||||
GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
|
||||
&filter_input_depth_dim));
|
||||
} else {
|
||||
filter_rows_dim = c->Dim(filter_shape, 0);
|
||||
filter_cols_dim = c->Dim(filter_shape, 1);
|
||||
filter_input_depth_dim = c->Dim(filter_shape, 2);
|
||||
output_depth_dim = c->Dim(filter_shape, 3);
|
||||
filter_input_depth_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
|
||||
}
|
||||
|
||||
// Check that the input tensor and the filter tensor agree on the input
|
||||
|
@ -412,34 +412,35 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
TEST(CommonShapeFnsTest, Conv2DShapeTest) {
|
||||
ShapeInferenceTestOp op("Conv2D");
|
||||
auto set_op = [&op](const std::vector<int32>& strides, const string& padding,
|
||||
const string& data_format) {
|
||||
const string& data_format, const string& filter_format) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
|
||||
.Input("input", 0, DT_FLOAT)
|
||||
.Input("filter", 0, DT_FLOAT)
|
||||
.Attr("strides", strides)
|
||||
.Attr("padding", padding)
|
||||
.Attr("data_format", data_format)
|
||||
.Attr("filter_format", filter_format)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
// 1x1 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
||||
|
||||
// 2x2 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x2 stride
|
||||
set_op({{1, 2, 2, 1}}, "VALID", "NHWC");
|
||||
set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x1 stride
|
||||
set_op({{1, 2, 1, 1}}, "VALID", "NHWC");
|
||||
set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]");
|
||||
|
||||
// 4x4 input, 2x1 filter, 1x2 stride
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NHWC");
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
|
||||
|
||||
// Invalid rank for input
|
||||
@ -461,77 +462,76 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
|
||||
|
||||
// Tests for NCHW
|
||||
// 1x1 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
|
||||
INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
||||
|
||||
// 2x2 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO");
|
||||
INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x2 stride
|
||||
set_op({{1, 1, 2, 2}}, "VALID", "NCHW");
|
||||
set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO");
|
||||
INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x1 stride
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NCHW");
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO");
|
||||
INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]");
|
||||
|
||||
// 4x4 input, 2x1 filter, 1x2 stride
|
||||
set_op({{1, 1, 1, 2}}, "VALID", "NCHW");
|
||||
set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO");
|
||||
INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]");
|
||||
|
||||
// Tests for NCHW_VECT_C
|
||||
// 1x1 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
||||
INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
|
||||
|
||||
// 2x2 filter
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
||||
INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x2 stride
|
||||
set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C");
|
||||
set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
||||
INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x1 stride
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
||||
INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
|
||||
|
||||
// 4x4 input, 2x1 filter, 1x2 stride
|
||||
set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C");
|
||||
set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I");
|
||||
INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
|
||||
|
||||
// Some tests for "SAME" padding
|
||||
|
||||
// 4x4 input, 1x1 filter, 1x1 stride
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
|
||||
// 3x3 input, 2x2 filter, 1x1 stride
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
|
||||
// 4x4 input, 2x2 filter, 2x2 stride
|
||||
set_op({{1, 2, 2, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
|
||||
|
||||
// 4x4 input, 2x2 filter, 1x1 stride
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
|
||||
// With stride 1x1 and SAME, unknown dims don't matter - filter dims except
|
||||
// for output channels are ignored for output, so all inputs are carried
|
||||
// through to output.
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
|
||||
|
||||
// With stride != 1, the input HW dims are divided to produce output dims.
|
||||
set_op({{1, 2, 2, 1}}, "SAME", "NHWC");
|
||||
set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO");
|
||||
INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]");
|
||||
INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]");
|
||||
INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]");
|
||||
|
@ -92,11 +92,11 @@ class ConvParameters {
|
||||
ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
|
||||
int64 out_depths, const SpatialArray& filter,
|
||||
const SpatialArray& stride, const SpatialArray& padding,
|
||||
const DataType& dtype, int device_id)
|
||||
DataType dtype, int device_id)
|
||||
: batch_(batch),
|
||||
in_depths_(in_depths),
|
||||
in_(in),
|
||||
out_depths_(out_depths),
|
||||
in_(in),
|
||||
filter_(filter),
|
||||
stride_(stride),
|
||||
padding_(padding),
|
||||
@ -130,7 +130,8 @@ class ConvParameters {
|
||||
"(", str_util::Join(filter_, ", "), "), ",
|
||||
"(", str_util::Join(stride_, ", "), "), ",
|
||||
"(", str_util::Join(padding_, ", "), "), ",
|
||||
dtype_, ", ", device_id_);
|
||||
dtype_, ", ",
|
||||
device_id_);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@ -150,26 +151,28 @@ class ConvParameters {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
typedef std::tuple<int64, int64, SpatialArray, int64, SpatialArray,
|
||||
SpatialArray, SpatialArray, DataType, int>
|
||||
ParameterDataType;
|
||||
protected:
|
||||
using ParameterDataType =
|
||||
std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray,
|
||||
SpatialArray, DataType, int>;
|
||||
|
||||
ParameterDataType get_data_as_tuple() const {
|
||||
return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_,
|
||||
stride_, padding_, dtype_, device_id_);
|
||||
}
|
||||
|
||||
uint64 hash_code_;
|
||||
|
||||
private:
|
||||
int64 batch_;
|
||||
int64 in_depths_;
|
||||
SpatialArray in_;
|
||||
int64 out_depths_;
|
||||
SpatialArray in_;
|
||||
SpatialArray filter_;
|
||||
SpatialArray stride_;
|
||||
SpatialArray padding_;
|
||||
DataType dtype_;
|
||||
int device_id_;
|
||||
uint64 hash_code_;
|
||||
};
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -556,6 +556,7 @@ template struct functor::NCHWToNHWC<GPUDevice, double, 4>;
|
||||
template struct functor::NCHWToNHWC<GPUDevice, float, 4>;
|
||||
template struct functor::NCHWToNHWC<GPUDevice, Eigen::half, 4>;
|
||||
|
||||
template struct functor::PadInput<GPUDevice, int, int, 4>;
|
||||
template struct functor::PadInput<GPUDevice, float, int, 4>;
|
||||
template struct functor::PadInput<GPUDevice, Eigen::half, int, 4>;
|
||||
|
||||
|
@ -22,7 +22,9 @@ namespace tensorflow {
|
||||
|
||||
Status GetActivationModeFromString(const string& str_value,
|
||||
ActivationMode* value) {
|
||||
if (str_value == "Sigmoid") {
|
||||
if (str_value == "None") {
|
||||
*value = NONE;
|
||||
} else if (str_value == "Sigmoid") {
|
||||
*value = SIGMOID;
|
||||
} else if (str_value == "Relu") {
|
||||
*value = RELU;
|
||||
|
@ -28,6 +28,7 @@ namespace tensorflow {
|
||||
|
||||
// ActivationMode: the activation function we apply to the input tensor:
|
||||
enum ActivationMode {
|
||||
NONE = 0,
|
||||
SIGMOID = 1,
|
||||
RELU = 2,
|
||||
RELU6 = 3,
|
||||
|
@ -1913,6 +1913,106 @@ bool CudnnSupport::DoRnnBackward(
|
||||
#endif // CUDNN_VERSION
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
|
||||
Stream* stream, CUDAExecutor* parent, void* dnn_handle,
|
||||
const ScopedTensorDescriptor& input_nd,
|
||||
const ScopedFilterDescriptor& filter,
|
||||
const ScopedConvolutionDescriptor& conv,
|
||||
const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
|
||||
ScratchAllocator* scratch_allocator) {
|
||||
cudnnConvolutionFwdPreference_t preference =
|
||||
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
|
||||
: CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
|
||||
auto memory_limit_bytes =
|
||||
scratch_allocator == nullptr
|
||||
? 0
|
||||
: scratch_allocator->GetMemoryLimitInBytes(stream);
|
||||
if (memory_limit_bytes < 0) {
|
||||
memory_limit_bytes = 0;
|
||||
}
|
||||
|
||||
cudnnConvolutionFwdAlgo_t algo_to_use;
|
||||
auto status = wrap::cudnnGetConvolutionForwardAlgorithm(
|
||||
parent, ToHandle(dnn_handle), input_nd.handle(), filter.handle(),
|
||||
conv.handle(), output_nd.handle(), preference, memory_limit_bytes,
|
||||
&algo_to_use);
|
||||
CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
|
||||
<< "Unable to find a suitable algorithm for doing forward convolution";
|
||||
return algo_to_use;
|
||||
}
|
||||
|
||||
dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
|
||||
Stream* stream, CUDAExecutor* parent, void* dnn_handle,
|
||||
int cudnn_type, // Actually cudnnDataType_t.
|
||||
const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
|
||||
const ScopedTensorDescriptor& input_nd,
|
||||
const ScopedFilterDescriptor& filter,
|
||||
const ScopedConvolutionDescriptor& conv,
|
||||
const ScopedTensorDescriptor& output_nd,
|
||||
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
|
||||
cudnnConvolutionFwdAlgo_t algo =
|
||||
(algorithm_config.algorithm() == dnn::kDefaultAlgorithm)
|
||||
? GetCudnnConvolutionForwardAlgo(
|
||||
stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
|
||||
/*specify_workspace_limit=*/scratch_allocator != nullptr,
|
||||
scratch_allocator)
|
||||
: ToConvForwardAlgo(algorithm_config.algorithm());
|
||||
size_t size_in_bytes;
|
||||
auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
|
||||
parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(),
|
||||
/*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
|
||||
/*destDesc=*/output_nd.handle(), /*algo=*/algo,
|
||||
/*sizeInBytes=*/&size_in_bytes);
|
||||
int64 size_in_bytes_int64 = size_in_bytes;
|
||||
if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
|
||||
CHECK(is_profiling) << "Cannot query the size of workspace needed "
|
||||
"for the specified algorithm: "
|
||||
<< algorithm_config.algorithm() << " "
|
||||
<< ToString(status);
|
||||
// Silently return when we are profiling.
|
||||
return dnn::kNoSuitableAlgorithmFound;
|
||||
}
|
||||
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
|
||||
LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
|
||||
"negative sizeInBytes value. This could be a cudnn bug.";
|
||||
if (TF_PREDICT_TRUE(is_profiling)) {
|
||||
return dnn::kNoSuitableAlgorithmFound;
|
||||
}
|
||||
} else if (size_in_bytes_int64 > 0) {
|
||||
port::StatusOr<DeviceMemory<uint8>> allocated;
|
||||
if (TF_PREDICT_TRUE(scratch_allocator)) {
|
||||
allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
|
||||
if (TF_PREDICT_TRUE(allocated.ok())) {
|
||||
*scratch = allocated.ValueOrDie();
|
||||
} else {
|
||||
if (TF_PREDICT_TRUE(is_profiling)) {
|
||||
// Silently return when we are profiling.
|
||||
return dnn::kNoSuitableAlgorithmFound;
|
||||
}
|
||||
LOG(WARNING) << allocated.status().error_message();
|
||||
// For the int8 case, we fail at this point since the no_scratch
|
||||
// algorithm should be set to dnn::kDefaultAlgorithm.
|
||||
CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
|
||||
<< "The primary convolution algorithm failed memory allocation, "
|
||||
"while a secondary algorithm is not provided.";
|
||||
}
|
||||
}
|
||||
if (TF_PREDICT_FALSE(!allocated.ok())) {
|
||||
algo = (algorithm_config.algorithm_no_scratch() == dnn::kDefaultAlgorithm)
|
||||
? GetCudnnConvolutionForwardAlgo(
|
||||
stream, parent, dnn_handle, input_nd, filter, conv,
|
||||
output_nd, /*specify_workspace_limit=*/false, nullptr)
|
||||
: ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
|
||||
}
|
||||
}
|
||||
|
||||
return algo;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <class T>
|
||||
bool CudnnSupport::DoConvolveImpl(
|
||||
Stream* stream, int cudnn_type, // Actually cudnnDataType_t.
|
||||
@ -1920,7 +2020,6 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<T>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<T>& biases, dnn::ActivationMode activation_mode,
|
||||
const BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
@ -1953,6 +2052,8 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
cudnnConvolutionFwdAlgo_t algo;
|
||||
DeviceMemory<uint8> scratch;
|
||||
|
||||
// TODO(pauldonnelly): Replace the following code with a call to
|
||||
// GetCudnnConvolutionForwardAlgorithm().
|
||||
if (algorithm_config.algorithm() == dnn::kDefaultAlgorithm) {
|
||||
// With the default algorithm, use Cudnn's heuristics.
|
||||
auto get_algorithm =
|
||||
@ -2059,27 +2160,9 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
"negative sizeInBytes value. This could be a cudnn bug.";
|
||||
}
|
||||
}
|
||||
const bool has_biases = (biases != nullptr);
|
||||
const bool supported_activation_mode =
|
||||
(activation_mode == dnn::ActivationMode::kRelu);
|
||||
|
||||
if (has_biases && !supported_activation_mode) {
|
||||
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only "
|
||||
"support relu activation.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (has_biases && activation_mode == dnn::ActivationMode::kNone) {
|
||||
LOG(ERROR) << "To use cudnnConvolutionBiasActivationForward() "
|
||||
"with a valid biases tensor, need to also provide "
|
||||
"a valid activation mode (currently only supports "
|
||||
"kRelu).";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<CUDATimer> timer;
|
||||
if (is_profiling) {
|
||||
timer.reset(new CUDATimer(parent_));
|
||||
timer.reset(new CUDATimer(parent_)); // NOLINT
|
||||
if (!timer->Init()) {
|
||||
return false;
|
||||
}
|
||||
@ -2091,50 +2174,15 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (has_biases) {
|
||||
CHECK(supported_activation_mode);
|
||||
#if CUDNN_VERSION < 6000
|
||||
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
|
||||
"supported for cuDNN version >= 6.";
|
||||
return false;
|
||||
#else
|
||||
BatchDescriptor bias_dimensions;
|
||||
bias_dimensions.set_count(1)
|
||||
.set_feature_map_count(output_descriptor.feature_map_count())
|
||||
.set_height(1)
|
||||
.set_width(1)
|
||||
.set_layout(dnn::DataLayout::kBatchYXDepth);
|
||||
ScopedTensorDescriptor bias_descriptor{
|
||||
parent_, bias_dimensions, static_cast<cudnnDataType_t>(cudnn_type)};
|
||||
// CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
|
||||
// activation descriptor. Note that this will change the nan propagation
|
||||
// behavior from separate conv, bias, and relu (which by default is
|
||||
// CUDNN_PROPAGATE_NAN.
|
||||
ScopedActivationDescriptor activation_desc{parent_, activation_mode,
|
||||
CUDNN_NOT_PROPAGATE_NAN,
|
||||
output_descriptor.value_max()};
|
||||
status = wrap::cudnnConvolutionBiasActivationForward(
|
||||
parent_, ToHandle(dnn_handle_),
|
||||
/*alpha1=*/&alpha, /*srcDesc=*/input_nd.handle(),
|
||||
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
|
||||
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
|
||||
/*algo=*/algo, /*workSpace=*/scratch.opaque(),
|
||||
/*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&beta,
|
||||
/*zDesc=*/output_nd.handle(), /*z=*/input_data.opaque(),
|
||||
/*biasDesc=*/bias_descriptor.handle(),
|
||||
/*bias=*/biases.opaque(), /*activationDesc=*/activation_desc.handle(),
|
||||
/*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
|
||||
#endif // CUDNN_VERSION < 6000
|
||||
} else {
|
||||
status = wrap::cudnnConvolutionForward(
|
||||
parent_, ToHandle(dnn_handle_),
|
||||
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
|
||||
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
|
||||
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
|
||||
/*algo=*/algo, /*workSpace=*/scratch.opaque(),
|
||||
/*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
|
||||
/*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
|
||||
}
|
||||
status = wrap::cudnnConvolutionForward(
|
||||
parent_, ToHandle(dnn_handle_),
|
||||
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
|
||||
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
|
||||
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
|
||||
/*algo=*/algo, /*workSpace=*/scratch.opaque(),
|
||||
/*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
|
||||
/*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
|
||||
|
||||
if (is_profiling) {
|
||||
if (!timer->Stop(AsCUDAStream(stream))) {
|
||||
timer->Destroy();
|
||||
@ -2160,6 +2208,144 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Type, typename BiasType, typename ScaleType,
|
||||
int cudnn_data_type, int cudnn_compute_type>
|
||||
bool CudnnSupport::DoFusedConvolveImpl(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Type>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
#if CUDNN_VERSION < 6000
|
||||
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
|
||||
"supported for cuDNN version >= 6";
|
||||
return false;
|
||||
#else
|
||||
ScopedTensorDescriptor conv_input_nd{
|
||||
parent_, conv_input_descriptor,
|
||||
static_cast<cudnnDataType_t>(cudnn_data_type)};
|
||||
ScopedTensorDescriptor output_nd{
|
||||
parent_, output_descriptor,
|
||||
static_cast<cudnnDataType_t>(cudnn_data_type)};
|
||||
ScopedFilterDescriptor filter{parent_, filter_descriptor,
|
||||
conv_input_descriptor,
|
||||
static_cast<cudnnDataType_t>(cudnn_data_type)};
|
||||
ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, CUDNN_DATA_FLOAT};
|
||||
ScopedConvolutionDescriptor conv{
|
||||
parent_, convolution_descriptor,
|
||||
static_cast<cudnnDataType_t>(cudnn_compute_type)};
|
||||
|
||||
mutex_lock lock{dnn_handle_mutex_};
|
||||
auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
|
||||
AsCUDAStreamValue(stream));
|
||||
CHECK(status == CUDNN_STATUS_SUCCESS)
|
||||
<< "failed to set stream for cudnn handle: " << ToString(status);
|
||||
|
||||
const bool is_profiling = output_profile_result != nullptr;
|
||||
DeviceMemory<uint8> scratch;
|
||||
dnn::AlgorithmType algorithm_type = GetCudnnConvolutionForwardAlgorithm(
|
||||
stream, parent_, dnn_handle_, cudnn_data_type, algorithm_config,
|
||||
is_profiling, conv_input_nd, filter, conv, output_nd, scratch_allocator,
|
||||
&scratch);
|
||||
if (algorithm_type == dnn::kNoSuitableAlgorithmFound) {
|
||||
if (!is_profiling) {
|
||||
LOG(ERROR) << "No suitable algorithm found";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algorithm_type);
|
||||
|
||||
if (activation_mode != dnn::ActivationMode::kRelu) {
|
||||
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
|
||||
"activation.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<CUDATimer> timer;
|
||||
if (is_profiling) {
|
||||
timer.reset(new CUDATimer(parent_)); // NOLINT
|
||||
if (!timer->Init()) {
|
||||
return false;
|
||||
}
|
||||
// The start and stop of the timer should be as close to the Cudnn call as
|
||||
// possible. It is still possible for other threads to issue workload on
|
||||
// to this stream. So it could take multiple profiling measurements.
|
||||
if (!timer->Start(AsCUDAStream(stream))) {
|
||||
timer->Destroy();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
|
||||
// activation descriptor. Note that this will change the nan propagation
|
||||
// behavior from separate conv, bias, and relu (which by default is
|
||||
// CUDNN_PROPAGATE_NAN.
|
||||
ScopedActivationDescriptor activation_desc{parent_, activation_mode,
|
||||
CUDNN_NOT_PROPAGATE_NAN,
|
||||
output_descriptor.value_max()};
|
||||
auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
|
||||
: side_input_data.opaque();
|
||||
|
||||
VLOG(2) << "\nconv_input_scale = " << conv_input_scale
|
||||
<< "\nconv_input_nd.handle() = " << conv_input_nd.handle()
|
||||
<< "\nconv_input_data.opaque() = " << conv_input_data.opaque()
|
||||
<< "\nfilter.handle() = " << filter.handle()
|
||||
<< "\nfilter_data.opaque() = " << filter_data.opaque()
|
||||
<< "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo
|
||||
<< "\nscratch.opaque() = " << scratch.opaque()
|
||||
<< "\nscratch.size() = " << scratch.size()
|
||||
<< "\nside_input_scale = " << side_input_scale
|
||||
<< "\noutput_nd.handle() = " << output_nd.handle()
|
||||
<< "\nside_input_data_ptr = " << side_input_data_ptr
|
||||
<< "\nbias_nd.handle() = " << bias_nd.handle()
|
||||
<< "\nbiases.opaque() = " << biases.opaque()
|
||||
<< "\nactivation_desc.handle() = " << activation_desc.handle()
|
||||
<< "\noutput_nd.handle() = " << output_nd.handle()
|
||||
<< "\noutput_data->opaque() = " << output_data->opaque();
|
||||
|
||||
status = wrap::cudnnConvolutionBiasActivationForward(
|
||||
parent_, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
|
||||
/*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
|
||||
/*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
|
||||
/*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
|
||||
/*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
|
||||
/*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
|
||||
/*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
|
||||
/*activationDesc=*/activation_desc.handle(),
|
||||
/*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
|
||||
|
||||
if (is_profiling) {
|
||||
if (!timer->Stop(AsCUDAStream(stream))) {
|
||||
timer->Destroy();
|
||||
return false;
|
||||
}
|
||||
if (status == CUDNN_STATUS_SUCCESS) {
|
||||
output_profile_result->set_algorithm(algo);
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
}
|
||||
timer->Destroy();
|
||||
}
|
||||
|
||||
if (status != CUDNN_STATUS_SUCCESS) {
|
||||
// Silently return when we are profiling.
|
||||
if (!is_profiling) {
|
||||
LOG(ERROR) << "failed to enqueue convolution on stream: "
|
||||
<< ToString(status);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
#endif // CUDNN_VERSION < 6000
|
||||
}
|
||||
|
||||
// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
|
||||
// By default it is turned on, users can explicitly disable them through an
|
||||
// env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
|
||||
@ -2407,48 +2593,16 @@ bool CudnnSupport::DoConvolve(
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoConvolveImpl<float>(
|
||||
stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor,
|
||||
filter_data, convolution_descriptor, biases, activation_mode,
|
||||
output_descriptor, output_data, scratch_allocator, algorithm_config,
|
||||
output_profile_result);
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoConvolve(
|
||||
Stream* stream, const BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoConvolveImpl<float>(
|
||||
stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor,
|
||||
filter_data, convolution_descriptor, /*biases=*/nullptr,
|
||||
dnn::ActivationMode::kNone, output_descriptor, output_data,
|
||||
filter_data, convolution_descriptor, output_descriptor, output_data,
|
||||
scratch_allocator, algorithm_config, output_profile_result);
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoConvolve(
|
||||
Stream* stream, const BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
|
||||
const BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data) {
|
||||
LOG(ERROR) << "double-based DNN not yet implemented";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoConvolve(
|
||||
Stream* stream, const BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<double>& input_data,
|
||||
@ -2467,34 +2621,113 @@ bool CudnnSupport::DoConvolve(
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoConvolveImpl<Eigen::half>(
|
||||
stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor,
|
||||
filter_data, convolution_descriptor, output_descriptor, output_data,
|
||||
scratch_allocator, algorithm_config, output_profile_result);
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<double>& conv_input_data, double conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& side_input_data, double side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
|
||||
CUDNN_DATA_DOUBLE>(
|
||||
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor, side_input_data,
|
||||
side_input_scale, bias_descriptor, biases, activation_mode,
|
||||
output_descriptor, output_data, scratch_allocator, algorithm_config,
|
||||
output_profile_result);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<float>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
|
||||
CUDNN_DATA_FLOAT>(
|
||||
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor, side_input_data,
|
||||
side_input_scale, bias_descriptor, biases, activation_mode,
|
||||
output_descriptor, output_data, scratch_allocator, algorithm_config,
|
||||
output_profile_result);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<Eigen::half>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const BatchDescriptor& output_descriptor,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoConvolveImpl<Eigen::half>(
|
||||
stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor,
|
||||
filter_data, convolution_descriptor, biases, activation_mode,
|
||||
return DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
|
||||
CUDNN_DATA_FLOAT>(
|
||||
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor, side_input_data,
|
||||
side_input_scale, bias_descriptor, biases, activation_mode,
|
||||
output_descriptor, output_data, scratch_allocator, algorithm_config,
|
||||
output_profile_result);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CudnnSupport::DoConvolve(
|
||||
Stream* stream, const BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const ConvolutionDescriptor& convolution_descriptor,
|
||||
const BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
|
||||
bool CudnnSupport::DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<int8>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<int8>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return DoConvolveImpl<Eigen::half>(
|
||||
stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor,
|
||||
filter_data, convolution_descriptor, /*biases=*/nullptr,
|
||||
dnn::ActivationMode::kNone, output_descriptor, output_data,
|
||||
scratch_allocator, algorithm_config, output_profile_result);
|
||||
#if CUDNN_VERSION < 6000
|
||||
LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
|
||||
"supported for cuDNN version >= 6";
|
||||
return false;
|
||||
#else
|
||||
return DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
|
||||
CUDNN_DATA_INT32>(
|
||||
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor, side_input_data,
|
||||
side_input_scale, bias_descriptor, biases, activation_mode,
|
||||
output_descriptor, output_data, scratch_allocator, algorithm_config,
|
||||
output_profile_result);
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@ -2730,7 +2963,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
|
||||
|
||||
std::unique_ptr<CUDATimer> timer;
|
||||
if (is_profiling) {
|
||||
timer.reset(new CUDATimer(parent_));
|
||||
timer.reset(new CUDATimer(parent_)); // NOLINT
|
||||
timer->Init();
|
||||
// The start and stop of the timer should be as close to the Cudnn call as
|
||||
// possible. It is still possible for other threads to issue workload on
|
||||
@ -2981,7 +3214,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
|
||||
|
||||
std::unique_ptr<CUDATimer> timer;
|
||||
if (is_profiling) {
|
||||
timer.reset(new CUDATimer(parent_));
|
||||
timer.reset(new CUDATimer(parent_)); // NOLINT
|
||||
timer->Init();
|
||||
// The start and stop of the timer should be as close to the Cudnn call as
|
||||
// possible. It is still possible for other threads to issue workload on
|
||||
|
@ -178,42 +178,6 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
|
||||
DeviceMemory<float>* offset_backprop) override;
|
||||
|
||||
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<float>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data) override;
|
||||
|
||||
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
@ -244,6 +208,66 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<double>& conv_input_data, double conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& side_input_data, double side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<float>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoFusedConvolve(Stream* stream,
|
||||
const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half>& conv_input_data,
|
||||
float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half>& side_input_data,
|
||||
float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<Eigen::half>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<int8>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<int8>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) override;
|
||||
|
||||
bool DoConvolveQuantized(
|
||||
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
@ -561,14 +585,28 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<T>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<T>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<T>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result);
|
||||
|
||||
template <typename Type, typename BiasType, typename ScaleType,
|
||||
int cudnn_data_type, int cudnn_compute_type>
|
||||
bool DoFusedConvolveImpl(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Type>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result);
|
||||
|
||||
template <class T>
|
||||
bool DoConvolveBackwardDataImpl(
|
||||
Stream* stream,
|
||||
|
@ -669,6 +669,7 @@ class PoolingDescriptor {
|
||||
|
||||
typedef int64 AlgorithmType;
|
||||
constexpr AlgorithmType kDefaultAlgorithm = -1;
|
||||
constexpr AlgorithmType kNoSuitableAlgorithmFound = -2;
|
||||
|
||||
// Describes the result from a perf experiment.
|
||||
//
|
||||
@ -912,6 +913,129 @@ class DnnSupport {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enqueues a fused convolution operation onto the stream.
|
||||
// We provide several variants with different types for inputs, biases and
|
||||
// scaling parameters.
|
||||
//
|
||||
// Arguments (all borrowed):
|
||||
// stream: borrowed pointer to the stream that the 'convolve' operation
|
||||
// should be enqueued onto.
|
||||
// conv_input_descriptor: dimensions of the convolution input layer.
|
||||
// conv_input_data: un-owned device memory region which contains the
|
||||
// convolution input.
|
||||
// conv_input_scale: a floating point scale to multiply with each element
|
||||
// of conv_input_data.
|
||||
// filter_descriptor: dimensions of the convolution filter.
|
||||
// filter_data: un-owned device memory region which contains the
|
||||
// convolution filter weights.
|
||||
// convolution_descriptor: stride of the convolution filter.
|
||||
// biases: un-owned device memory region containing biases to add to the
|
||||
// input.
|
||||
// activation_mode: Type of activation to perform.
|
||||
// side_input_data: un-owned device memory region which contains optional
|
||||
// side input data. If 'side_input_scale' is non-zero, then this must
|
||||
// point to data in the tensor shape specified by output_shape.
|
||||
// It will be scaled by 'side_input_scale' and added to the convolution
|
||||
// result and bias prior to applying the activation function.
|
||||
// side_input_scale: a floating point scale to multiply with each element
|
||||
// of side_input_data.
|
||||
// output_descriptor: dimensions of the output layer.
|
||||
// output_data: un-owned device memory region in which to place the
|
||||
// convolution result.
|
||||
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
|
||||
// space in order to speed up the convolution operation.
|
||||
// algorithm: an integer to specify which algorithm should be used for the
|
||||
// operation. kDefaultAlgorithm means the system will pick an algorithm
|
||||
// by default. The coding of the algorithm is be interpretted by the
|
||||
// underlying implementation.
|
||||
// output_profile_result: the output profile result for this call. The
|
||||
// profiling is only enabled when this is not nullptr.
|
||||
//
|
||||
// conv_input_descriptor, filter_descriptor, convolution_descriptor and
|
||||
// output_descriptor together specify exactly how the convolution is aligned
|
||||
// with the input data:
|
||||
//
|
||||
// * (input dimensions - filter size + 1) / filter stride == output dimensions
|
||||
// corresponds to dist_belief padding = VALID, i.e. the input is not padded.
|
||||
// * input dimensions / filter stride == output dimensions
|
||||
// corresponds to dist_belief padding = SAME, i.e. input and output are the
|
||||
// same size - this requires padding the input.
|
||||
// * (input dimensions + filter size - 1) / filter stride == output dimensions
|
||||
// corresponds to dist_belief padding = FULL, i.e. the output is sized so
|
||||
// that if the inverse of the filter is applied to the output in VALID mode
|
||||
// the result is the same size as the input - this requires even more
|
||||
// padding of the input.
|
||||
virtual bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<double>& conv_input_data, double conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& side_input_data, double side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// This is the float version of DoFusedConvolve.
|
||||
virtual bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<float>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// This is the Eigen::half version of DoFusedConvolve.
|
||||
// The scaling parameters are still floats.
|
||||
virtual bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<Eigen::half>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// This is the int8 version of DoFusedConvolve.
|
||||
// The bias input and scaling parameters are floats.
|
||||
virtual bool DoFusedConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
|
||||
const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<int8>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<int8>& side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor& bias_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
dnn::ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enqueues a single-precision convolution operation onto the stream.
|
||||
//
|
||||
// Arguments (all borrowed):
|
||||
@ -922,10 +1046,8 @@ class DnnSupport {
|
||||
// convolution input.
|
||||
// filter_descriptor: dimensions of the convolution filter.
|
||||
// convolution_descriptor: stride of the convolution filter.
|
||||
// biases: un-owned device memory region containing biases to add to the
|
||||
// input. This can be DeviceMemory pointing to NULL only when activation_mode
|
||||
// is kNone.
|
||||
// activation_mode: Type of activation to perform.
|
||||
// output_descriptor: dimensions of the output layer.
|
||||
// output_data: un-owned device memory region in which to place the
|
||||
// convolution result.
|
||||
@ -952,55 +1074,6 @@ class DnnSupport {
|
||||
// that if the inverse of the filter is applied to the output in VALID mode
|
||||
// the result is the same size as the input - this requires even more
|
||||
// padding of the input.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<float>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enqueues a double-precision fused convolution, bias add, and activation
|
||||
// operation onto the stream. See DoConvolve above for argument details.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<double>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enqueues a half-precision fused convolution, bias add, and activation
|
||||
// operation onto the stream. See DoConvolve above for argument details.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const DeviceMemory<Eigen::half>& filter_data,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half>& biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<Eigen::half>* output_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
ProfileResult* output_profile_result) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enqueues a single-precision convolution operation (without bias add
|
||||
// or activation) onto the stream.
|
||||
// See DoConvolve above for argument details.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
|
||||
const DeviceMemory<float>& input_data,
|
||||
@ -1012,8 +1085,7 @@ class DnnSupport {
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
ProfileResult* output_profile_result) = 0;
|
||||
|
||||
// Enqueues a double-precision convolution operation (without bias add
|
||||
// or activation) onto the stream.
|
||||
// Enqueues a double-precision convolution operation onto the stream.
|
||||
// See DoConvolve above for argument details.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
@ -1024,8 +1096,7 @@ class DnnSupport {
|
||||
const dnn::BatchDescriptor& output_descriptor,
|
||||
DeviceMemory<double>* output_data) = 0;
|
||||
|
||||
// Enqueues a half-precision convolution operation (without bias add
|
||||
// or activation) onto the stream.
|
||||
// Enqueues a half-precision convolution operation onto the stream.
|
||||
// See DoConvolve above for argument details.
|
||||
virtual bool DoConvolve(
|
||||
Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
|
||||
|
@ -361,28 +361,66 @@ Stream &Stream::ThenBatchNormalizationBackward(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
Stream &Stream::ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
|
||||
PARAM(output_descriptor), PARAM(output));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
CheckError(dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
|
||||
} else {
|
||||
SetErrorAndLogNoDnnSupport();
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<Eigen::half> &biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
|
||||
PARAM(filter_descriptor), PARAM(filter_data),
|
||||
PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
|
||||
PARAM(output_descriptor), PARAM(output));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
CheckError(dnn->DoConvolve(
|
||||
this, input_descriptor, input_data, filter_descriptor, filter_data,
|
||||
convolution_descriptor, biases, activation_mode, output_descriptor,
|
||||
output, scratch_allocator, dnn::AlgorithmConfig(),
|
||||
/*output_profile_result=*/nullptr));
|
||||
CheckError(dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
|
||||
} else {
|
||||
SetErrorAndLogNoDnnSupport();
|
||||
}
|
||||
@ -390,27 +428,32 @@ Stream &Stream::ThenConvolveWithScratch(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
Stream &Stream::ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
|
||||
PARAM(filter_descriptor), PARAM(filter_data),
|
||||
PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
|
||||
PARAM(output_descriptor), PARAM(output));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
CheckError(dnn->DoConvolve(
|
||||
this, input_descriptor, input_data, filter_descriptor, filter_data,
|
||||
convolution_descriptor, biases, activation_mode, output_descriptor,
|
||||
output, scratch_allocator, dnn::AlgorithmConfig(),
|
||||
/*output_profile_result=*/nullptr));
|
||||
CheckError(dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
|
||||
} else {
|
||||
SetErrorAndLogNoDnnSupport();
|
||||
}
|
||||
@ -472,29 +515,34 @@ Stream &Stream::ThenConvolveWithScratch(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
Stream &Stream::ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
|
||||
ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result) {
|
||||
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
|
||||
PARAM(filter_descriptor), PARAM(filter_data),
|
||||
PARAM(convolution_descriptor), PARAM(biases),
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
|
||||
PARAM(algorithm_config));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoConvolve(
|
||||
this, input_descriptor, input_data, filter_descriptor, filter_data,
|
||||
convolution_descriptor, biases, activation_mode, output_descriptor,
|
||||
output, scratch_allocator, algorithm_config, output_profile_result);
|
||||
auto status = dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
algorithm_config, output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
@ -505,30 +553,73 @@ Stream &Stream::ThenConvolveWithAlgorithm(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
Stream &Stream::ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<Eigen::half> &biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result) {
|
||||
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
|
||||
PARAM(filter_descriptor), PARAM(filter_data),
|
||||
PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
|
||||
PARAM(algorithm_config));
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
|
||||
PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoConvolve(
|
||||
this, input_descriptor, input_data, filter_descriptor, filter_data,
|
||||
convolution_descriptor, biases, activation_mode, output_descriptor,
|
||||
output, scratch_allocator, algorithm_config, output_profile_result);
|
||||
auto status = dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
algorithm_config, output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
} else {
|
||||
SetErrorAndLogNoDnnSupport();
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
|
||||
ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result) {
|
||||
VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
|
||||
PARAM(conv_input_scale), PARAM(filter_descriptor),
|
||||
PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
|
||||
PARAM(side_input_data), PARAM(side_input_scale),
|
||||
PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
|
||||
PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
|
||||
|
||||
if (ok()) {
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoFusedConvolve(
|
||||
this, conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor,
|
||||
side_input_data, side_input_scale, bias_descriptor, biases,
|
||||
activation_mode, output_descriptor, output, scratch_allocator,
|
||||
algorithm_config, output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
@ -601,19 +692,22 @@ Stream &Stream::ThenConvolveWithAlgorithm(
|
||||
return *this;
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolve(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
Stream &Stream::ThenFusedConvolve(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output) {
|
||||
return ThenConvolveWithScratch(
|
||||
input_descriptor, input_data, filter_descriptor, filter_data,
|
||||
convolution_descriptor, biases, activation_mode, output_descriptor,
|
||||
output, /*scratch_allocator=*/nullptr);
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) {
|
||||
return ThenFusedConvolveWithScratch(
|
||||
conv_input_descriptor, conv_input_data, conv_input_scale,
|
||||
filter_descriptor, filter_data, convolution_descriptor, side_input_data,
|
||||
side_input_scale, bias_descriptor, biases, activation_mode,
|
||||
output_descriptor, output,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenConvolve(
|
||||
|
@ -240,15 +240,17 @@ class Stream {
|
||||
DeviceMemory<float> *offset_backprop);
|
||||
|
||||
// TODO(leary) add double-precision version of this interface.
|
||||
Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output);
|
||||
Stream &ThenFusedConvolve(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<int8> *output);
|
||||
|
||||
Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
@ -278,23 +280,39 @@ class Stream {
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output_data);
|
||||
|
||||
Stream &ThenConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
Stream &ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
|
||||
Stream &ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<Eigen::half> &biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);
|
||||
|
||||
Stream &ThenConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
Stream &ThenFusedConvolveWithScratch(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
|
||||
@ -323,7 +341,6 @@ class Stream {
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
@ -335,6 +352,47 @@ class Stream {
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
Stream &ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<double> &conv_input_data, double conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<double> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<double> &side_input_data, double side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
Stream &ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<float> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
Stream &ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<Eigen::half> &biases,
|
||||
dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
@ -342,25 +400,17 @@ class Stream {
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
Stream &ThenConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<float> &input_data,
|
||||
Stream &ThenFusedConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &conv_input_descriptor,
|
||||
const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<float> &filter_data,
|
||||
const DeviceMemory<int8> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
Stream &ThenConvolveWithAlgorithm(
|
||||
const dnn::BatchDescriptor &input_descriptor,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
const dnn::FilterDescriptor &filter_descriptor,
|
||||
const DeviceMemory<Eigen::half> &filter_data,
|
||||
const dnn::ConvolutionDescriptor &convolution_descriptor,
|
||||
const dnn::BatchDescriptor &output_descriptor,
|
||||
DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
|
||||
const DeviceMemory<int8> &side_input_data, float side_input_scale,
|
||||
const dnn::BatchDescriptor &bias_descriptor,
|
||||
const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
|
||||
const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
|
||||
ScratchAllocator *scratch_allocator,
|
||||
const dnn::AlgorithmConfig &algorithm_config,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user