Support NHWC Conv2D with cuDNN for fp16 (aka Eigen::half and DT_HALF)
PiperOrigin-RevId: 247502362
commit 3a5cb493f5 (parent 7a8aadc06c)
@@ -1606,10 +1606,10 @@ tf_cuda_cc_test(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/stream_executor/cuda:cudnn_plugin",
     ],
 )
@@ -179,42 +179,50 @@ struct MatMulConvFunctor {

 // Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
 //
-// Note: Currently OIHW is the only supported destination format. Support for
-// OHWI format will be added in a follow-up change.
+// Note: Currently supports OIHW and OHWI destination formats.
 template <typename Device, typename T, typename IndexType, int NDIMS>
 struct TransformFilter {
   void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
                   typename TTypes<T, NDIMS, IndexType>::Tensor out) {
+    // NOTE: Source filter format is always HWIO.
+    Eigen::DSizes<IndexType, NDIMS - 2> spatial_dims;
+    for (int i = 0; i < spatial_dims.rank(); ++i) {
+      spatial_dims[i] = in.dimension(i);
+    }
+
     // Merge the spatial dimensions together to speed up the shuffle operation.
     Eigen::DSizes<IndexType, 3> merged_dims;
-    merged_dims[0] = in.dimension(0);  // spatial dimensions
-    for (int i = 1; i < NDIMS - 2; ++i) {
-      merged_dims[0] *= in.dimension(i);
-    }
-    merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
-    merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
-
-    DCHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported destination filter format: "
-        << ToString(dst_filter_format);
-    // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
-    // in the beginning.
-    Eigen::DSizes<IndexType, 3> shuffling_perm =
-        Eigen::DSizes<IndexType, 3>(2, 1, 0);
+    merged_dims[0] = spatial_dims.TotalSize();  // product of spatial dims [H*W]
+    merged_dims[1] = in.dimension(NDIMS - 2);   // input filters         [I]
+    merged_dims[2] = in.dimension(NDIMS - 1);   // output filters        [O]

+    // Shuffle tensor with merged spatial dimensions.
+    Eigen::DSizes<IndexType, 3> shuffling_perm;
     // Expand shuffled tensor into final dimensions.
     Eigen::DSizes<IndexType, NDIMS> expanded_dims;
-    int out_index = 0;
-    for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
-      if (shuffling_perm[merged_dim] == 0) {
-        for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
-          expanded_dims[out_index++] = in.dimension(spatial_dim);
-        }
-      } else {
-        constexpr int kLastSpatialDim = NDIMS - 3;
-        expanded_dims[out_index++] =
-            in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+
+    if (dst_filter_format == FORMAT_OIHW) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
+      expanded_dims[0] = merged_dims[2];  // [O]
+      expanded_dims[1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[2 + i] = spatial_dims[i];
+      }
+
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 0, 1);
+
+      expanded_dims[0] = merged_dims[2];          // [O]
+      expanded_dims[NDIMS - 1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[1 + i] = spatial_dims[i];
+      }
+
+    } else {
+      DCHECK(false) << "Unsupported destination filter format: "
+                    << ToString(dst_filter_format);
     }

     out.device(d) =
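For intuition, here is a minimal standalone sketch (plain C++, no Eigen or TensorFlow dependencies; all names are illustrative, not part of the patch) of what the merged shuffle above computes for the new HWIO -> OHWI case: the spatial dimensions act as one merged axis [HW], and the permutation (2, 0, 1) maps [HW, I, O] to [O, HW, I].

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int H = 2, W = 2, I = 3, O = 4;  // tiny filter: HWIO = [2, 2, 3, 4]
  std::vector<float> hwio(H * W * I * O);
  for (int i = 0; i < static_cast<int>(hwio.size()); ++i) hwio[i] = i;

  // Permute HWIO -> OHWI with plain loops; this is the dense equivalent of
  // shuffling the merged [HW, I, O] tensor with permutation (2, 0, 1).
  std::vector<float> ohwi(hwio.size());
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      for (int ci = 0; ci < I; ++ci)    // input channel
        for (int co = 0; co < O; ++co)  // output channel
          ohwi[((co * H + h) * W + w) * I + ci] =
              hwio[((h * W + w) * I + ci) * O + co];

  std::printf("transposed %zu filter elements\n", ohwi.size());
  return 0;
}
```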
@@ -434,13 +434,22 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
     combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);

-    CHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported output layout: " << ToString(dst_filter_format);
-    TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-                                 config.block_count, config.thread_per_block, 0,
-                                 d.stream(), config.virtual_thread_count,
-                                 in.data(), combined_dims, out.data()));
+    if (dst_filter_format == FORMAT_OIHW) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+                                   config.block_count, config.thread_per_block,
+                                   0, d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
+
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
+                                   config.block_count, config.thread_per_block,
+                                   0, d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
+
+    } else {
+      LOG(ERROR) << "Unsupported filter format: "
+                 << ToString(dst_filter_format);
+    }
   }
 };
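A reading note (not part of the patch): the GPU path requests the OHWI shuffle as ShuffleInTensor3Simple<T, 1, 2, 0>, while the CPU path above uses the Eigen permutation (2, 0, 1). The two are inverses of each other, which is consistent if the kernel's template parameters give, for each input dimension of the merged [HW, I, O] tensor, its destination position in the output: HW goes to position 1, I to position 2, and O to position 0, again yielding [O, HW, I], i.e. OHWI with merged spatial dimensions.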
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/conv_ops.h"

 #include <string.h>

+#include <map>
 #include <vector>
@@ -561,6 +562,15 @@ template struct LaunchConv2DOp<CPUDevice, float>;
 template struct LaunchConv2DOp<CPUDevice, double>;

 #if GOOGLE_CUDA
+// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU.
+bool IsVoltaOrLater(const se::StreamExecutor& stream_exec) {
+  int major, minor;
+  CHECK(stream_exec  // Crash OK
+            .GetDeviceDescription()
+            .cuda_compute_capability(&major, &minor));
+  return major >= 7;
+}
+
 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
                            int64 default_value_in_bytes) {
   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
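For comparison, the same "Volta or newer" test can be sketched against the plain CUDA runtime API instead of StreamExecutor (illustrative only; the patch itself goes through se::StreamExecutor as shown above, and the function name here is hypothetical):

```cpp
#include <cuda_runtime.h>

// Tensor Cores first shipped with compute capability 7.0 (Volta).
bool IsVoltaOrLaterRuntime(int device_ordinal) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device_ordinal) != cudaSuccess) {
    return false;  // Be conservative if the query fails.
  }
  return prop.major >= 7;
}
```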
@@ -676,6 +686,23 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     return;
   }

+  // Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
+  // in NHWC data layout. In all other configurations it's more efficient to
+  // run computation in NCHW data format.
+  const bool compute_in_nhwc =
+      DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());
+
+  // We only do one directional conversion: NHWC->NCHW. We never convert in the
+  // other direction. Grappler layout optimizer selects preferred layout and
+  // adds necessary annotations to the graph.
+  // TODO(ezhulenev): Convert in other direction for fp16?
+  const TensorFormat compute_data_format =
+      compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
+
+  VLOG(3) << "Compute Conv2D with cuDNN:"
+          << " data_format=" << ToString(data_format)
+          << " compute_data_format=" << ToString(compute_data_format);
+
   const int64 out_batch = GetTensorDim(*output, data_format, 'N');
   const int64 out_rows = GetTensorDim(*output, data_format, 'H');
   const int64 out_cols = GetTensorDim(*output, data_format, 'W');
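The layout decision above boils down to a small pure function; a sketch for clarity (the helper name is hypothetical, the enums are TensorFlow's):

```cpp
// Stay in NHWC only when the data is already NHWC and Tensor Cores can use
// fp16 in that layout; otherwise compute in NCHW (converting if necessary).
TensorFormat ChooseComputeFormat(DataType dtype, bool volta_or_later,
                                 TensorFormat data_format) {
  const bool compute_in_nhwc = dtype == DT_HALF && volta_or_later;
  return compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC
                                                       : FORMAT_NCHW;
}
```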
@@ -708,6 +735,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     // cuDNN only supports padding the same amount on the left and right sides,
     // and on the top and bottom sides. So we manually create a new padded
     // input tensor such that we can pass it to cuDNN.
+    VLOG(4) << "Pad input tensor:"
+            << " padding_top=" << padding_top
+            << " padding_bottom=" << padding_bottom
+            << " padding_left=" << padding_left
+            << " padding_right=" << padding_right;

     // TODO(reedwm): In some cases, we can avoid an allocation even if the two
     // padding sides are different. For example, if the input is 2x2, the filter
@@ -750,8 +782,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     in_cols = new_in_cols;
   }

-  if (data_format == FORMAT_NHWC) {
-    // Convert the input tensor from NHWC to NCHW.
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the input tensor from NHWC to NCHW.";
+
     TensorShape nchw_shape =
         ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
     if (in_depths > 1) {
@@ -767,28 +800,48 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       // If depth <= 1, then just reshape.
       CHECK(input.CopyFrom(input, nchw_shape));
     }
+  } else {
+    CHECK(data_format == compute_data_format)  // Crash OK
+        << "Illegal data and compute format pair:"
+        << " data_format=" << ToString(data_format)
+        << " compute_data_format=" << ToString(compute_data_format);
   }

   CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
       << "Negative row or col paddings: (" << common_padding_rows << ", "
       << common_padding_cols << ")";

+  constexpr auto kComputeInNHWC =
+      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                      se::dnn::FilterLayout::kOutputYXInput);
+  constexpr auto kComputeInNCHW =
+      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                      se::dnn::FilterLayout::kOutputInputYX);
+
+  se::dnn::DataLayout compute_data_layout;
+  se::dnn::FilterLayout filter_layout;
+
+  std::tie(compute_data_layout, filter_layout) =
+      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
   se::dnn::BatchDescriptor input_desc;
   input_desc.set_count(in_batch)
       .set_feature_map_count(in_depths)
       .set_height(in_rows)
       .set_width(in_cols)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
   se::dnn::BatchDescriptor output_desc;
   output_desc.set_count(out_batch)
       .set_height(out_rows)
       .set_width(out_cols)
       .set_feature_map_count(out_depths)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(patch_rows)
       .set_input_filter_width(patch_cols)
       .set_input_feature_map_count(patch_depths)
-      .set_output_feature_map_count(filter.dim_size(3));
+      .set_output_feature_map_count(filter.dim_size(3))
+      .set_layout(filter_layout);
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(row_dilation)
       .set_horizontal_dilation_rate(col_dilation)
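The tuple constants pair each data layout with the filter layout cuDNN expects for it, so a mismatched combination cannot be picked by accident. Spelled out as an equivalent helper (hypothetical name; assumes the surrounding TensorFlow/StreamExecutor headers and an `#include <utility>`):

```cpp
// NHWC activations (kBatchYXDepth) pair with OHWI filters (kOutputYXInput);
// NCHW activations (kBatchDepthYX) pair with OIHW filters (kOutputInputYX).
std::pair<se::dnn::DataLayout, se::dnn::FilterLayout> ChooseCudnnLayouts(
    TensorFormat compute_data_format) {
  if (compute_data_format == FORMAT_NHWC) {
    return {se::dnn::DataLayout::kBatchYXDepth,
            se::dnn::FilterLayout::kOutputYXInput};
  }
  return {se::dnn::DataLayout::kBatchDepthYX,
          se::dnn::FilterLayout::kOutputInputYX};
}
```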
@@ -799,22 +852,42 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_group_count(in_depths / patch_depths);

   Tensor transformed_filter;
-  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
-                          DataTypeToEnum<T>::value,
-                          TensorShape({filter.dim_size(3), filter.dim_size(2),
-                                       filter.dim_size(0), filter.dim_size(1)}),
-                          &transformed_filter));
-  functor::TransformFilter<GPUDevice, T, int, 4>()(
-      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
-      To32Bit(filter.tensor<T, 4>()),
-      To32Bit(transformed_filter.tensor<T, 4>()));
+
+  const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
+    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
+            << " to " << ToString(dst_format);
+
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
+                           filter.dim_size(0), filter.dim_size(1)})
+            : TensorShape({filter.dim_size(3), filter.dim_size(0),
+                           filter.dim_size(1), filter.dim_size(2)});
+
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                           &transformed_filter));
+    functor::TransformFilter<GPUDevice, T, int, 4>()(
+        ctx->eigen_device<GPUDevice>(), dst_format,
+        To32Bit(filter.tensor<T, 4>()),
+        To32Bit(transformed_filter.tensor<T, 4>()));
+  };
+
+  if (compute_data_format == FORMAT_NCHW) {
+    transform_filter(FORMAT_OIHW);
+  } else if (compute_data_format == FORMAT_NHWC) {
+    transform_filter(FORMAT_OHWI);
+  } else {
+    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
+                                           ToString(compute_data_format)));
+    return;
+  }

   Tensor transformed_output;
-  if (data_format == FORMAT_NHWC) {
+  // Only allocate temporary memory when a layout transformation is needed.
+  if (data_format != compute_data_format) {
+    VLOG(4) << "Allocate temporary memory for output in compute data format";
     OP_REQUIRES_OK(
         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                ShapeFromFormat(FORMAT_NCHW, out_batch,
+                                ShapeFromFormat(compute_data_format, out_batch,
                                                 out_rows, out_cols, out_depths),
                                 &transformed_output));
   } else {
@@ -842,7 +915,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       in_depths,            // in_depths
       {{in_rows,            // in_rows
         in_cols}},          // in_cols
-      FORMAT_NCHW,          // compute_data_format
+      compute_data_format,  // compute_data_format
       out_depths,           // out_depths
       {{patch_rows,         // filter_rows
         patch_cols,         // filter_cols
@@ -901,6 +974,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
   }

+  VLOG(4) << "Convolution Algorithm: "
+          << algorithm_config.algorithm()->algo_id();
+  VLOG(4) << "tensor_ops_enabled: "
+          << algorithm_config.algorithm()->tensor_ops_enabled();
+
   DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
   bool cudnn_launch_status =
       stream
@@ -916,8 +994,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
                     ") filter shape(", filter.shape().DebugString(), ")"));
   }

-  // Convert the output tensor back from NCHW to NHWC.
-  if (data_format == FORMAT_NHWC) {
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
     functor::NCHWToNHWC<GPUDevice, T, 4>()(
         ctx->eigen_device<GPUDevice>(),
         const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
@@ -29,7 +29,7 @@ limitations under the License.
 namespace tensorflow {

 ////////////////////////////////////////////////////////////////////////////////
-// Performance benchmarks for the FusedConv2Op.                               //
+// Performance benchmarks for the Conv2DOp and FusedConv2Op.                  //
 ////////////////////////////////////////////////////////////////////////////////

 struct Conv2DGraph {
@@ -63,19 +63,27 @@ struct Conv2DWithBatchNormAndActivationGraph {
   Node* activation;
 };

+template <typename T>
 static Tensor MakeRandomTensor(const TensorShape& shape) {
-  Tensor tensor(DT_FLOAT, TensorShape(shape));
-  tensor.flat<float>() = tensor.flat<float>().setRandom();
+  Tensor tensor(DataTypeToEnum<T>::value, TensorShape(shape));
+  tensor.flat<T>() = tensor.flat<T>().setRandom();
   return tensor;
 }

 // Creates a simple Tensorflow graph with single Conv2D node.
+template <typename T>
 static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
-                          int filter_w, int filter_h, int out_depth) {
+                          int filter_w, int filter_h, int out_depth,
+                          TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());

-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});

   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
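A hedged usage sketch of the templated builder (parameter values are arbitrary; this mirrors how the benchmark macros below instantiate it, and assumes the test file's includes and namespaces):

```cpp
// Build an fp16 Conv2D benchmark graph in NCHW layout and run it on GPU.
Conv2DGraph g = Conv2D<Eigen::half>(/*batch=*/32, /*height=*/56, /*width=*/56,
                                    /*in_depth=*/64, /*filter_w=*/1,
                                    /*filter_h=*/1, /*out_depth=*/64,
                                    FORMAT_NCHW);
test::Benchmark("gpu", g.graph).Run(/*iters=*/100);
```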
@@ -84,33 +84,35 @@ static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
   TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "Conv2D")
                   .Input(images)
                   .Input(filter)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &conv2d));

   return {graph, conv2d};
 }

 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd.
-static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
-                                          int in_depth, int filter_w,
-                                          int filter_h, int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBiasGraph Conv2DWithBias(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);

   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;

-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});
   Node* bias = test::graph::Constant(graph, bias_t, "bias");

   Node* out;
   TF_CHECK_OK(NodeBuilder(graph->NewName("bias"), "BiasAdd")
                   .Input(conv2d)
                   .Input(bias)
-                  .Attr("T", DT_FLOAT)
-                  .Attr("data_format", "NHWC")
+                  .Attr("T", DataTypeToEnum<T>::value)
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &out));

   return {graph, conv2d, out};
@@ -118,11 +128,14 @@ static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,

 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd and
 // activation (Relu, Relu6, etc...).
+template <typename T>
 static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBiasGraph conv_graph = Conv2DWithBias(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBiasGraph conv_graph =
+      Conv2DWithBias<T>(batch, height, width, in_depth, filter_w, filter_h,
+                        out_depth, data_format);

   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
@@ -131,27 +144,27 @@ static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
   Node* activation;
   TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
                   .Input(bias)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Finalize(graph, &activation));

   return {graph, conv2d, bias, activation};
 }

 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm.
-static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
-                                                    int width, int in_depth,
-                                                    int filter_w, int filter_h,
-                                                    int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);

   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;

-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});

   Node* scale = test::graph::Constant(graph, scale_t, "scale");
   Node* offset = test::graph::Constant(graph, offset_t, "offset");
@@ -165,8 +178,9 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
                   .Input(offset)
                   .Input(mean)
                   .Input(variance)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("is_training", false)
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &out));

   return {graph, conv2d, out};
@@ -174,11 +188,14 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,

 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm and
 // activation (Relu, Relu6, etc...).
+template <typename T>
 static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBatchNormGraph conv_graph = Conv2DWithBatchNorm(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBatchNormGraph conv_graph =
+      Conv2DWithBatchNorm<T>(batch, height, width, in_depth, filter_w,
+                             filter_h, out_depth, data_format);

   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
@@ -187,7 +204,7 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
   Node* activation;
   TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
                   .Input(batch_norm)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Finalize(graph, &activation));

   return {graph, conv2d, batch_norm, activation};
@@ -195,15 +212,22 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(

 // Creates a tensorflow graph with a single FusedConv2D (with BiasAdd) node and
 // fuses into it additional computations (e.g. Relu).
+template <typename T>
 static Graph* FusedConv2DWithBias(int batch, int height, int width,
                                   int in_depth, int filter_w, int filter_h,
                                   int out_depth,
-                                  const std::vector<string>& fused_ops = {}) {
+                                  const std::vector<string>& fused_ops = {},
+                                  TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());

-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});

   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
@@ -217,7 +241,7 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,
                   .Input(filter)
                   .Attr("num_args", 1)
                   .Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
                   .Attr("fused_ops", fused_ops)
@@ -228,17 +252,24 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,

 // Creates a tensorflow graph with a single FusedConv2D (with FusedBatchNorm)
 // node and fuses into it additional computations (e.g. Relu).
+template <typename T>
 static Graph* FusedConv2DWithBatchNorm(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const std::vector<string>& fused_ops = {}) {
+    int out_depth, const std::vector<string>& fused_ops = {},
+    TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());

-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});

   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
@@ -255,7 +286,7 @@ static Graph* FusedConv2DWithBatchNorm(
                   .Input(filter)
                   .Attr("num_args", 4)
                   .Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
                   .Attr("fused_ops", fused_ops)
@@ -273,6 +304,10 @@ static Graph* FusedConv2DWithBatchNorm(
 // FH: filter height
 // FW: filter width

+// -------------------------------------------------------------------------- //
+// The following benchmarks always use the 'float' data type with NHWC layout.
+// -------------------------------------------------------------------------- //
+
 #define BM_SETUP(N, H, W, C, type, LABEL, NAME) \
   testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * (C)); \
   testing::SetLabel(LABEL);
@@ -280,39 +315,41 @@ static Graph* FusedConv2DWithBatchNorm(
 #define BM_NAME(name, type, N, H, W, C, FW, FH, FC) \
   name##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC

-#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL) \
-  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) { \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2D(N, H, W, C, FW, FH, FC).graph).Run(iters); \
-  } \
+#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL) \
+  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
+    test::Benchmark(#type, Conv2D<float>(N, H, W, C, FW, FH, FC).graph) \
+        .Run(iters); \
+  } \
   BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));

 #define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
   static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBias(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type, \
+                    Conv2DWithBias<float>(N, H, W, C, FW, FH, FC).graph) \
         .Run(iters); \
   } \
   BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));

-#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
-  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
-                      FC)(int iters) { \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark( \
-        #type, \
-        Conv2DWithBiasAndActivation(N, H, W, C, FW, FH, FC, "Relu").graph) \
-        .Run(iters); \
-  } \
+#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL) \
+  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
+                      FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
+    test::Benchmark(#type, Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, \
+                                                              FH, FC, "Relu") \
+                        .graph) \
+        .Run(iters); \
+  } \
   BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));

 #define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL) \
   static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, \
-                    FusedConv2DWithBias(N, H, W, C, FW, FH, FC, {"BiasAdd"})) \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, \
+                                                      {"BiasAdd"})) \
         .Run(iters); \
   } \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
@@ -321,8 +358,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, FusedConv2DWithBias(N, H, W, C, FW, FH, FC, \
-                                               {"BiasAdd", "Relu"})) \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, \
+                                                      {"BiasAdd", "Relu"})) \
         .Run(iters); \
   } \
   BENCHMARK( \
@@ -332,7 +369,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type, \
+                    Conv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC).graph) \
         .Run(iters); \
   } \
   BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@@ -341,8 +379,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, Conv2DWithBatchNormAndActivation(N, H, W, C, FW, \
-                                                            FH, FC, "Relu") \
+    test::Benchmark(#type, Conv2DWithBatchNormAndActivation<float>( \
+                               N, H, W, C, FW, FH, FC, "Relu") \
                         .graph) \
         .Run(iters); \
   } \
@@ -353,8 +391,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
                       FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, \
-                                                    {"FusedBatchNorm"})) \
+    test::Benchmark(#type, FusedConv2DWithBatchNorm<float>( \
+                               N, H, W, C, FW, FH, FC, {"FusedBatchNorm"})) \
         .Run(iters); \
   } \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@@ -364,9 +402,9 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, \
                       FW, FH, FC)(int iters) { \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D); \
-    test::Benchmark(#type, \
-                    FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC, \
-                                             {"FusedBatchNorm", "Relu"})) \
+    test::Benchmark( \
+        #type, FusedConv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC, \
+                                               {"FusedBatchNorm", "Relu"})) \
         .Run(iters); \
   } \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
@@ -500,4 +538,63 @@ BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
 BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");
 #endif

+// Macro argument names: ------------------------------------------------------ //
+// T: data type
+// FORMAT: data format (NHWC or NCHW)
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+// -------------------------------------------------------------------------- //
+// The following benchmarks compare data format performance for different data
+// types. They are meaningful only when CUDA is enabled, because the CPU
+// supports only NHWC data.
+// -------------------------------------------------------------------------- //
+
+#define BM_LONG_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \
+  name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
+
+#define BM_Conv2DFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type) \
+  static void BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, \
+                           FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, "", Conv2D); \
+    test::Benchmark(#type, \
+                    Conv2D<T>(N, H, W, C, FW, FH, FC, FORMAT_##FORMAT).graph) \
+        .Run(iters); \
+  } \
+  BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC));
+
+#if GOOGLE_CUDA
+using fp32 = float;
+using fp16 = Eigen::half;
+
+// ResNet50-ish convolutions.
+#define BENCHMARK_DTYPE(BATCH, T) \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 64, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 256, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 256, 1, 1, 64, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 3, 3, 64, gpu); \
+ \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 128, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 512, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 1, 1, 128, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 3, 3, 128, gpu); \
+ \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 256, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 3, 3, 256, gpu);
+
+BENCHMARK_DTYPE(32, fp32);
+BENCHMARK_DTYPE(32, fp16);
+
+BENCHMARK_DTYPE(64, fp32);
+BENCHMARK_DTYPE(64, fp16);
+
+#endif  // GOOGLE_CUDA
+
 }  // namespace tensorflow
@@ -63,6 +63,8 @@ string ToString(FilterTensorFormat format) {
       return "HWIO";
     case FORMAT_OIHW:
       return "OIHW";
+    case FORMAT_OHWI:
+      return "OHWI";
     case FORMAT_OIHW_VECT_I:
       return "OIHW_VECT_I";
     default:
@@ -80,6 +80,9 @@ enum FilterTensorFormat {
   // FORMAT_OIHW often improves performance on GPUs.
   FORMAT_OIHW = 1,

+  // FORMAT_OHWI used by cuDNN for NHWC convolutions.
+  FORMAT_OHWI = 2,
+
   // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
   // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
   // data format. It is laid out in the same order as OIHW, except that the size
@@ -88,7 +91,7 @@ enum FilterTensorFormat {
   // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
   // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
   // A pre-condition of this format is that I must be a multiple of 4.
-  FORMAT_OIHW_VECT_I = 2,
+  FORMAT_OIHW_VECT_I = 3,
 };

 // Parse tensor format from the given string.
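A worked instance of the dimension rule in the comment above (values illustrative):

```cpp
// A filter with O=64, I=32, H=3, W=3:
//   OIHW        shape: [64, 32, 3, 3]
//   OIHW_VECT_I shape: [64, 8, 3, 3, 4]  // 32 input channels -> 8 groups of 4
```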