Support NHWC Conv2D with cuDNN for fp16 (aka Eigen::half and DT_HALF)

PiperOrigin-RevId: 247502362
Eugene Zhulenev, 2019-05-09 14:59:16 -07:00, committed by TensorFlower Gardener
parent 7a8aadc06c
commit 3a5cb493f5
7 changed files with 325 additions and 128 deletions


@ -1606,10 +1606,10 @@ tf_cuda_cc_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/stream_executor/cuda:cudnn_plugin",
],
)


@ -179,42 +179,50 @@ struct MatMulConvFunctor {
// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
//
-// Note: Currently OIHW is the only supported destination format. Support for
-// OHWI format will be added in a follow-up change.
+// Note: Currently supports OIHW and OHWI destination formats.
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
+    // NOTE: Source filter format is always HWIO.
+    Eigen::DSizes<IndexType, NDIMS - 2> spatial_dims;
+    for (int i = 0; i < spatial_dims.rank(); ++i) {
+      spatial_dims[i] = in.dimension(i);
+    }
// Merge the spatial dimensions together to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
-    merged_dims[0] = in.dimension(0);  // spatial dimensions
-    for (int i = 1; i < NDIMS - 2; ++i) {
-      merged_dims[0] *= in.dimension(i);
-    }
-    merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
-    merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
-
-    DCHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported destination filter format: "
-        << ToString(dst_filter_format);
-    // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
-    // in the beginning.
-    Eigen::DSizes<IndexType, 3> shuffling_perm =
-        Eigen::DSizes<IndexType, 3>(2, 1, 0);
+    merged_dims[0] = spatial_dims.TotalSize();  // product of spatial dims [H*W]
+    merged_dims[1] = in.dimension(NDIMS - 2);   // input filters         [I]
+    merged_dims[2] = in.dimension(NDIMS - 1);   // output filters        [O]
+
+    // Shuffle tensor with merged spatial dimensions.
+    Eigen::DSizes<IndexType, 3> shuffling_perm;
// Expand shuffled tensor into final dimensions.
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
-    int out_index = 0;
-    for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
-      if (shuffling_perm[merged_dim] == 0) {
-        for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
-          expanded_dims[out_index++] = in.dimension(spatial_dim);
-        }
-      } else {
-        constexpr int kLastSpatialDim = NDIMS - 3;
-        expanded_dims[out_index++] =
-            in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
-      }
-    }
+    if (dst_filter_format == FORMAT_OIHW) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
+      expanded_dims[0] = merged_dims[2];  // [O]
+      expanded_dims[1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[2 + i] = spatial_dims[i];
+      }
+
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 0, 1);
+
+      expanded_dims[0] = merged_dims[2];          // [O]
+      expanded_dims[NDIMS - 1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[1 + i] = spatial_dims[i];
+      }
+
+    } else {
+      DCHECK(false) << "Unsupported destination filter format: "
+                    << ToString(dst_filter_format);
+    }
    out.device(d) =
        in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
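The merge-shuffle-expand trick above is easier to see on a concrete shape. Below is a minimal standalone sketch (not part of this commit; all names are local to the example) of the HWIO -> OHWI transform written with plain Eigen on the host:

#include <cassert>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const Eigen::Index H = 3, W = 3, I = 2, O = 4;
  Eigen::Tensor<float, 4> hwio(H, W, I, O);  // TensorFlow's HWIO filter layout
  hwio.setRandom();

  // Same three steps as TransformFilter: merge spatial dims, shuffle, expand.
  Eigen::DSizes<Eigen::Index, 3> merged_dims(H * W, I, O);   // [H*W, I, O]
  Eigen::DSizes<Eigen::Index, 3> shuffling_perm(2, 0, 1);    // -> [O, H*W, I]
  Eigen::DSizes<Eigen::Index, 4> expanded_dims(O, H, W, I);  // -> [O, H, W, I]

  Eigen::Tensor<float, 4> ohwi =
      hwio.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);

  // Element (h, w, i, o) must land at (o, h, w, i).
  assert(hwio(1, 2, 0, 3) == ohwi(3, 1, 2, 0));
  return 0;
}

Merging H and W first means the shuffle only ever has to reason about three dimensions, regardless of NDIMS.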


@ -434,13 +434,22 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-    CHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported output layout: " << ToString(dst_filter_format);
-    TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-                                 config.block_count, config.thread_per_block,
-                                 0, d.stream(), config.virtual_thread_count,
-                                 in.data(), combined_dims, out.data()));
+    if (dst_filter_format == FORMAT_OIHW) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+                                   config.block_count, config.thread_per_block, 0,
+                                   d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
+
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
+                                   config.block_count, config.thread_per_block,
+                                   0, d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
+
+    } else {
+      LOG(ERROR) << "Unsupported filter format: "
+                 << ToString(dst_filter_format);
+    }
}
};
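Note that the permutation arguments differ between the two paths even though both implement HWIO -> OHWI: the Eigen shuffle above says "output dim i takes input dim perm[i]" (hence (2, 0, 1)), while the kernel's template arguments appear to say where each input dimension lands in the output, which is the inverse permutation <1, 2, 0>. The OIHW reversal (2, 1, 0) is its own inverse, so the two paths coincide there. A small standalone sketch (illustration only, names are hypothetical):

#include <array>
#include <cstdio>

int main() {
  std::array<const char*, 3> in = {"H*W", "I", "O"};

  std::array<int, 3> eigen_perm = {2, 0, 1};   // out[i] takes in[perm[i]]
  std::array<int, 3> kernel_perm = {1, 2, 0};  // in[i] lands at out[perm[i]]

  std::array<const char*, 3> out_eigen, out_kernel;
  for (int i = 0; i < 3; ++i) out_eigen[i] = in[eigen_perm[i]];
  for (int i = 0; i < 3; ++i) out_kernel[kernel_perm[i]] = in[i];

  // Both lines print "O H*W I", i.e. the merged OHWI layout.
  std::printf("%s %s %s\n", out_eigen[0], out_eigen[1], out_eigen[2]);
  std::printf("%s %s %s\n", out_kernel[0], out_kernel[1], out_kernel[2]);
  return 0;
}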


@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_ops.h"
#include <string.h>
#include <map>
#include <vector>
@ -561,6 +562,15 @@ template struct LaunchConv2DOp<CPUDevice, float>;
template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU.
bool IsVoltaOrLater(const se::StreamExecutor& stream_exec) {
int major, minor;
CHECK(stream_exec // Crash OK
.GetDeviceDescription()
.cuda_compute_capability(&major, &minor));
return major >= 7;
}
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
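For reference, a rough standalone equivalent of the Volta check using the raw CUDA runtime API (a hypothetical helper; TensorFlow itself goes through StreamExecutor as shown above):

#include <cuda_runtime.h>

// Volta and newer GPUs report compute capability major version >= 7.
bool IsVoltaOrLaterRuntime(int device) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
  return prop.major >= 7;
}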
@ -676,6 +686,23 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
return;
}
// Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
// in NHWC data layout. In all other configurations it's more efficient to
// run computation in NCHW data format.
const bool compute_in_nhwc =
DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());
// We only do one directional conversion: NHWC->NCHW. We never convert in the
// other direction. Grappler layout optimizer selects preferred layout and
// adds necessary annotations to the graph.
// TODO(ezhulenev): Convert in other direction for fp16?
const TensorFormat compute_data_format =
compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
VLOG(3) << "Compute Conv2D with cuDNN:"
<< " data_format=" << ToString(data_format)
<< " compute_data_format=" << ToString(compute_data_format);
const int64 out_batch = GetTensorDim(*output, data_format, 'N');
const int64 out_rows = GetTensorDim(*output, data_format, 'H');
const int64 out_cols = GetTensorDim(*output, data_format, 'W');
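The format-selection rule above collapses to a small truth table. A minimal sketch (hypothetical names, not TensorFlow API):

enum Format { NHWC, NCHW };

// Only fp16 on Volta+ keeps NHWC for the cuDNN computation.
Format ComputeFormat(bool is_half, bool volta_or_later, Format data_format) {
  const bool compute_in_nhwc = is_half && volta_or_later;
  return (compute_in_nhwc && data_format == NHWC) ? NHWC : NCHW;
}

// ComputeFormat(true, true, NHWC)  == NHWC  (fp16 + Tensor Cores: stay NHWC)
// ComputeFormat(true, true, NCHW)  == NCHW  (NCHW is never converted to NHWC)
// ComputeFormat(false, true, NHWC) == NCHW  (fp32: transpose to NCHW)
// ComputeFormat(true, false, NHWC) == NCHW  (pre-Volta fp16: NCHW)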
@ -708,6 +735,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
// cuDNN only supports padding the same amount on the left and right sides,
// and on the top and bottom sides. So we manually create a new padded
// input tensor such that we can pass it to cuDNN.
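A worked example of the asymmetry (illustrative numbers): with in_cols = 4, filter_cols = 2, stride = 1 and 'SAME' padding, the output width is ceil(4 / 1) = 4, so the total horizontal padding is (4 - 1) * 1 + 2 - 4 = 1 column, which TensorFlow splits as padding_left = 0, padding_right = 1, putting the excess on the right. Only the common part, min(padding_left, padding_right) = 0, can be expressed in the cuDNN convolution descriptor; the remaining column has to be materialized by copying the input into an explicitly padded tensor.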
VLOG(4) << "Pad input tensor:"
<< " padding_top=" << padding_top
<< " padding_bottom=" << padding_bottom
<< " padding_left=" << padding_left
<< " padding_right=" << padding_right;
// TODO(reedwm): In some cases, we can avoid an allocation even if the two
// padding sides are different. For example, if the input is 2x2, the filter
@ -750,8 +782,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
in_cols = new_in_cols;
}
-  if (data_format == FORMAT_NHWC) {
-    // Convert the input tensor from NHWC to NCHW.
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the input tensor from NHWC to NCHW.";
TensorShape nchw_shape =
ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
if (in_depths > 1) {
@ -767,28 +800,48 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
// If depth <= 1, then just reshape.
CHECK(input.CopyFrom(input, nchw_shape));
}
} else {
CHECK(data_format == compute_data_format) // Crash OK
<< "Illegal data and compute format pair:"
<< " data_format=" << ToString(data_format)
<< " compute_data_format=" << ToString(compute_data_format);
}
CHECK(common_padding_rows >= 0 && common_padding_cols >= 0) // Crash OK
<< "Negative row or col paddings: (" << common_padding_rows << ", "
<< common_padding_cols << ")";
constexpr auto kComputeInNHWC =
std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
se::dnn::FilterLayout::kOutputYXInput);
constexpr auto kComputeInNCHW =
std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
se::dnn::FilterLayout::kOutputInputYX);
se::dnn::DataLayout compute_data_layout;
se::dnn::FilterLayout filter_layout;
std::tie(compute_data_layout, filter_layout) =
compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
se::dnn::BatchDescriptor input_desc;
input_desc.set_count(in_batch)
.set_feature_map_count(in_depths)
.set_height(in_rows)
.set_width(in_cols)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
se::dnn::BatchDescriptor output_desc;
output_desc.set_count(out_batch)
.set_height(out_rows)
.set_width(out_cols)
.set_feature_map_count(out_depths)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(patch_rows)
.set_input_filter_width(patch_cols)
.set_input_feature_map_count(patch_depths)
-      .set_output_feature_map_count(filter.dim_size(3));
+      .set_output_feature_map_count(filter.dim_size(3))
+      .set_layout(filter_layout);
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(row_dilation)
.set_horizontal_dilation_rate(col_dilation)
@ -799,22 +852,42 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_group_count(in_depths / patch_depths);
Tensor transformed_filter;
-  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
-                          DataTypeToEnum<T>::value,
-                          TensorShape({filter.dim_size(3), filter.dim_size(2),
-                                       filter.dim_size(0), filter.dim_size(1)}),
-                          &transformed_filter));
-  functor::TransformFilter<GPUDevice, T, int, 4>()(
-      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
-      To32Bit(filter.tensor<T, 4>()),
-      To32Bit(transformed_filter.tensor<T, 4>()));
+  const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
+    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
+            << " to " << ToString(dst_format);
+
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
+                           filter.dim_size(0), filter.dim_size(1)})
+            : TensorShape({filter.dim_size(3), filter.dim_size(0),
+                           filter.dim_size(1), filter.dim_size(2)});
+
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                           &transformed_filter));
+
+    functor::TransformFilter<GPUDevice, T, int, 4>()(
+        ctx->eigen_device<GPUDevice>(), dst_format,
+        To32Bit(filter.tensor<T, 4>()),
+        To32Bit(transformed_filter.tensor<T, 4>()));
+  };
+
+  if (compute_data_format == FORMAT_NCHW) {
+    transform_filter(FORMAT_OIHW);
+  } else if (compute_data_format == FORMAT_NHWC) {
+    transform_filter(FORMAT_OHWI);
+  } else {
+    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
+                                           ToString(compute_data_format)));
+    return;
+  }
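A worked instance of the two destination shapes (for an illustrative HWIO filter of shape [3, 3, 64, 256], i.e. H = 3, W = 3, I = 64, O = 256):

// dst_shape for each branch of transform_filter above:
//   FORMAT_OIHW: {dim(3), dim(2), dim(0), dim(1)} = {256, 64, 3, 3}
//   FORMAT_OHWI: {dim(3), dim(0), dim(1), dim(2)} = {256, 3, 3, 64}
// NCHW compute pairs with OIHW filters; NHWC compute pairs with OHWI.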
Tensor transformed_output;
-  if (data_format == FORMAT_NHWC) {
+  // Only allocate temporary memory when a layout transformation is needed.
+  if (data_format != compute_data_format) {
+    VLOG(4) << "Allocate temporary memory for output in compute data format";
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                           ShapeFromFormat(FORMAT_NCHW, out_batch,
+                           ShapeFromFormat(compute_data_format, out_batch,
out_rows, out_cols, out_depths),
&transformed_output));
} else {
@ -842,7 +915,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
in_depths, // in_depths
{{in_rows, // in_rows
in_cols}}, // in_cols
-      FORMAT_NCHW,          // compute_data_format
+      compute_data_format,  // compute_data_format
out_depths, // out_depths
{{patch_rows, // filter_rows
patch_cols, // filter_cols
@ -901,6 +974,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
}
VLOG(4) << "Convolution Algorithm: "
<< algorithm_config.algorithm()->algo_id();
VLOG(4) << "tensor_ops_enabled: "
<< algorithm_config.algorithm()->tensor_ops_enabled();
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream
@ -916,8 +994,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
") filter shape(", filter.shape().DebugString(), ")"));
}
-  // Convert the output tensor back from NCHW to NHWC.
-  if (data_format == FORMAT_NHWC) {
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
functor::NCHWToNHWC<GPUDevice, T, 4>()(
ctx->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),


@ -29,7 +29,7 @@ limitations under the License.
namespace tensorflow {
////////////////////////////////////////////////////////////////////////////////
-// Performance benchmarks for the FusedConv2Op.                              //
+// Performance benchmarks for the Conv2DOp and FusedConv2Op.                 //
////////////////////////////////////////////////////////////////////////////////
struct Conv2DGraph {
@ -63,19 +63,27 @@ struct Conv2DWithBatchNormAndActivationGraph {
Node* activation;
};
template <typename T>
static Tensor MakeRandomTensor(const TensorShape& shape) {
-  Tensor tensor(DT_FLOAT, TensorShape(shape));
-  tensor.flat<float>() = tensor.flat<float>().setRandom();
+  Tensor tensor(DataTypeToEnum<T>::value, TensorShape(shape));
+  tensor.flat<T>() = tensor.flat<T>().setRandom();
return tensor;
}
// Creates a simple Tensorflow graph with single Conv2D node.
template <typename T>
static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
-                          int filter_w, int filter_h, int out_depth) {
+                          int filter_w, int filter_h, int out_depth,
+                          TensorFormat data_format = FORMAT_NHWC) {
Graph* graph = new Graph(OpRegistry::Global());
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
Node* images = test::graph::Constant(graph, images_t, "images");
Node* filter = test::graph::Constant(graph, filter_t, "filter");
@ -84,33 +92,35 @@ static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "Conv2D")
.Input(images)
.Input(filter)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Attr("strides", {1, 1, 1, 1})
.Attr("padding", "SAME")
.Attr("data_format", ToString(data_format))
.Finalize(graph, &conv2d));
return {graph, conv2d};
}
// Creates a Tensorflow graph with a Conv2D node followed by BiasAdd.
-static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
-                                          int in_depth, int filter_w,
-                                          int filter_h, int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBiasGraph Conv2DWithBias(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);
Graph* graph = conv_graph.graph;
Node* conv2d = conv_graph.conv2d;
-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});
Node* bias = test::graph::Constant(graph, bias_t, "bias");
Node* out;
TF_CHECK_OK(NodeBuilder(graph->NewName("bias"), "BiasAdd")
.Input(conv2d)
.Input(bias)
-                  .Attr("T", DT_FLOAT)
-                  .Attr("data_format", "NHWC")
+                  .Attr("T", DataTypeToEnum<T>::value)
+                  .Attr("data_format", ToString(data_format))
.Finalize(graph, &out));
return {graph, conv2d, out};
@ -118,11 +128,14 @@ static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
// Creates a Tensorflow graph with a Conv2D node followed by BiasAdd and
// activation (Relu, Relu6, etc...).
template <typename T>
static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBiasGraph conv_graph = Conv2DWithBias(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBiasGraph conv_graph =
+      Conv2DWithBias<T>(batch, height, width, in_depth, filter_w, filter_h,
+                        out_depth, data_format);
Graph* graph = conv_graph.graph;
Node* conv2d = conv_graph.conv2d;
@ -131,27 +144,27 @@ static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
Node* activation;
TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
.Input(bias)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Finalize(graph, &activation));
return {graph, conv2d, bias, activation};
}
// Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm.
-static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
-                                                    int width, int in_depth,
-                                                    int filter_w, int filter_h,
-                                                    int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);
Graph* graph = conv_graph.graph;
Node* conv2d = conv_graph.conv2d;
-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});
Node* scale = test::graph::Constant(graph, scale_t, "scale");
Node* offset = test::graph::Constant(graph, offset_t, "offset");
@ -165,8 +178,9 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
.Input(offset)
.Input(mean)
.Input(variance)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Attr("is_training", false)
.Attr("data_format", ToString(data_format))
.Finalize(graph, &out));
return {graph, conv2d, out};
@ -174,11 +188,14 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
// Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm and
// activation (Relu, Relu6, etc...).
template <typename T>
static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBatchNormGraph conv_graph = Conv2DWithBatchNorm(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBatchNormGraph conv_graph =
+      Conv2DWithBatchNorm<T>(batch, height, width, in_depth, filter_w,
+                             filter_h, out_depth, data_format);
Graph* graph = conv_graph.graph;
Node* conv2d = conv_graph.conv2d;
@ -187,7 +204,7 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
Node* activation;
TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
.Input(batch_norm)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Finalize(graph, &activation));
return {graph, conv2d, batch_norm, activation};
@ -195,15 +212,22 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
// Creates a tensorflow graph with a single FusedConv2D (with BiasAdd) node and
// fuses into it additional computations (e.g. Relu).
template <typename T>
static Graph* FusedConv2DWithBias(int batch, int height, int width,
int in_depth, int filter_w, int filter_h,
int out_depth,
-                                  const std::vector<string>& fused_ops = {}) {
+                                  const std::vector<string>& fused_ops = {},
+                                  TensorFormat data_format = FORMAT_NHWC) {
Graph* graph = new Graph(OpRegistry::Global());
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});
Node* images = test::graph::Constant(graph, images_t, "images");
Node* filter = test::graph::Constant(graph, filter_t, "filter");
@ -217,7 +241,7 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,
.Input(filter)
.Attr("num_args", 1)
.Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Attr("strides", {1, 1, 1, 1})
.Attr("padding", "SAME")
.Attr("fused_ops", fused_ops)
@ -228,17 +252,24 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,
// Creates a tensorflow graph with a single FusedConv2D (with FusedBatchNorm)
// node and fuses into it additional computations (e.g. Relu).
template <typename T>
static Graph* FusedConv2DWithBatchNorm(
int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const std::vector<string>& fused_ops = {}) {
+    int out_depth, const std::vector<string>& fused_ops = {},
+    TensorFormat data_format = FORMAT_NHWC) {
Graph* graph = new Graph(OpRegistry::Global());
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});
Node* images = test::graph::Constant(graph, images_t, "images");
Node* filter = test::graph::Constant(graph, filter_t, "filter");
@ -255,7 +286,7 @@ static Graph* FusedConv2DWithBatchNorm(
.Input(filter)
.Attr("num_args", 4)
.Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
.Attr("strides", {1, 1, 1, 1})
.Attr("padding", "SAME")
.Attr("fused_ops", fused_ops)
@ -273,6 +304,10 @@ static Graph* FusedConv2DWithBatchNorm(
// FH: filter height
// FW: filter width
// -------------------------------------------------------------------------- //
// Following benchmarks are always using 'float' data type with NHWC layout.
// -------------------------------------------------------------------------- //
#define BM_SETUP(N, H, W, C, type, LABEL, NAME) \
testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * (C)); \
testing::SetLabel(LABEL);
@ -280,39 +315,41 @@ static Graph* FusedConv2DWithBatchNorm(
#define BM_NAME(name, type, N, H, W, C, FW, FH, FC) \
name##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
-#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL) \
-  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) { \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2D(N, H, W, C, FW, FH, FC).graph).Run(iters); \
-  } \
+#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL) \
+  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
+    test::Benchmark(#type, Conv2D<float>(N, H, W, C, FW, FH, FC).graph) \
+        .Run(iters); \
+  } \
BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));
#define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBias(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type, \
+                    Conv2DWithBias<float>(N, H, W, C, FW, FH, FC).graph) \
.Run(iters); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));
-#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
-  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
-                      FC)(int iters) { \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark( \
-        #type, \
-        Conv2DWithBiasAndActivation(N, H, W, C, FW, FH, FC, "Relu").graph) \
-        .Run(iters); \
-  } \
+#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
+  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
+                      FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
+    test::Benchmark(#type, Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, \
+                                                              FH, FC, "Relu") \
+                               .graph) \
+        .Run(iters); \
+  } \
BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
#define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, \
-                    FusedConv2DWithBias(N, H, W, C, FW, FH, FC, {"BiasAdd"})) \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, \
+                                                      {"BiasAdd"})) \
.Run(iters); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
@ -321,8 +358,8 @@ static Graph* FusedConv2DWithBatchNorm(
static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, FusedConv2DWithBias(N, H, W, C, FW, FH, FC, \
-                                               {"BiasAdd", "Relu"})) \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, \
+                                                      {"BiasAdd", "Relu"})) \
.Run(iters); \
} \
BENCHMARK( \
@ -332,7 +369,8 @@ static Graph* FusedConv2DWithBatchNorm(
static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type, \
+                    Conv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC).graph) \
.Run(iters); \
} \
BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@ -341,8 +379,8 @@ static Graph* FusedConv2DWithBatchNorm(
static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBatchNormAndActivation(N, H, W, C, FW, \
-                                                            FH, FC, "Relu") \
+    test::Benchmark(#type, Conv2DWithBatchNormAndActivation<float>( \
+                               N, H, W, C, FW, FH, FC, "Relu") \
.graph) \
.Run(iters); \
} \
@ -353,8 +391,8 @@ static Graph* FusedConv2DWithBatchNorm(
static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, \
-                                                    {"FusedBatchNorm"})) \
+    test::Benchmark(#type, FusedConv2DWithBatchNorm<float>( \
+                               N, H, W, C, FW, FH, FC, {"FusedBatchNorm"})) \
.Run(iters); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@ -364,9 +402,9 @@ static Graph* FusedConv2DWithBatchNorm(
static void BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, \
FW, FH, FC)(int iters) { \
BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, \
-                    FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, \
-                                             {"FusedBatchNorm", "Relu"})) \
+    test::Benchmark( \
+        #type, FusedConv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC, \
+                                               {"FusedBatchNorm", "Relu"})) \
.Run(iters); \
} \
BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
@ -500,4 +538,63 @@ BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");
#endif
// Macro arguments names: --------------------------------------------------- //
// T: data type
// FORMAT: data format (NHWC or NCHW)
// N: batch size
// H: height
// W: width
// C: channels
// FC: filter count
// FH: filter height
// FW: filter width
// -------------------------------------------------------------------------- //
// The following benchmarks compare the performance of different data formats
// for different data types. They are only meaningful when CUDA is enabled,
// because on CPU we only support data in NHWC format.
// -------------------------------------------------------------------------- //
#define BM_LONG_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \
name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
#define BM_Conv2DFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \
static void BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, \
FC)(int iters) { \
BM_SETUP(N, H, W, C, type, "", Conv2D); \
test::Benchmark(#type, \
Conv2D<T>(N, H, W, C, FW, FH, FC, FORMAT_##FORMAT).graph) \
.Run(iters); \
} \
BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC));
#if GOOGLE_CUDA
using fp32 = float;
using fp16 = Eigen::half;
// ResNet50-ish convolutions.
#define BENCHMARK_DTYPE(BATCH, T) \
BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 64, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 256, 1, 1, 64, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 3, 3, 64, gpu); \
\
BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 128, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 512, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 1, 1, 128, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 3, 3, 128, gpu); \
\
BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 3, 3, 256, gpu);
BENCHMARK_DTYPE(32, fp32);
BENCHMARK_DTYPE(32, fp16);
BENCHMARK_DTYPE(64, fp32);
BENCHMARK_DTYPE(64, fp16);
#endif // GOOGLE_CUDA
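For orientation, BM_Conv2DFmt(fp16, NHWC, 32, 56, 56, 64, 1, 1, 64, gpu) expands to roughly the following (preprocessor output reconstructed by hand, so treat it as a sketch):

static void BM_Conv2D_fp16_NHWC_gpu_32_56_56_64_1_1_64(int iters) {
  testing::ItemsProcessed(static_cast<int64>(iters) * 32 * 56 * 56 * 64);
  testing::SetLabel("");
  test::Benchmark("gpu",
                  Conv2D<fp16>(32, 56, 56, 64, 1, 1, 64, FORMAT_NHWC).graph)
      .Run(iters);
}
BENCHMARK(BM_Conv2D_fp16_NHWC_gpu_32_56_56_64_1_1_64);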
} // namespace tensorflow


@ -63,6 +63,8 @@ string ToString(FilterTensorFormat format) {
return "HWIO";
case FORMAT_OIHW:
return "OIHW";
case FORMAT_OHWI:
return "OHWI";
case FORMAT_OIHW_VECT_I:
return "OIHW_VECT_I";
default:


@ -80,6 +80,9 @@ enum FilterTensorFormat {
// FORMAT_OIHW often improves performance on GPUs.
FORMAT_OIHW = 1,
// FORMAT_OHWI used by cuDNN for NHWC convolutions.
FORMAT_OHWI = 2,
// OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
// int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
// data format. It is laid out in the same order as OIHW, except that the size
@ -88,7 +91,7 @@ enum FilterTensorFormat {
// int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
// dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
// A pre-condition of this format is that I must be a multiple of 4.
-  FORMAT_OIHW_VECT_I = 2,
+  FORMAT_OIHW_VECT_I = 3,
};
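A concrete instance of the vectorized layout described above (illustrative dimensions): an OIHW filter of shape [64, 32, 3, 3] is stored in OIHW_VECT_I as [64, 32/4, 3, 3, 4] = [64, 8, 3, 3, 4], with each group of 4 adjacent int8 input channels packed into one int32, which is why I must be a multiple of 4.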
// Parse tensor format from the given string.