From c0f84b60d803668a18c9dc694b9c387a26166288 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Apr 2020 12:20:34 -0700 Subject: [PATCH] Added support for NCHW_VECT_C and OIHW_VECT_I. PiperOrigin-RevId: 306284325 Change-Id: I39f4f84beed20f192f79d626a862a915b0a128fa --- .../grappler/costs/op_level_cost_estimator.cc | 62 ++++++++++----- .../costs/op_level_cost_estimator_test.cc | 75 +++++++++++++++---- 2 files changed, 102 insertions(+), 35 deletions(-) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index b427da69fce..b8b62cbd6e5 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -693,46 +693,69 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs( VLOG(2) << "op features: " << op_info.DebugString(); VLOG(2) << "Original image shape: " << original_image_shape.DebugString(); VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString(); - auto image_shape = - MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes); - auto filter_shape = - MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes); - VLOG(2) << "Image shape: " << image_shape.DebugString(); - VLOG(2) << "Filter shape: " << filter_shape.DebugString(); - int x_index, y_index, channel_index; + int x_index, y_index, major_channel_index, minor_channel_index = -1; const string& data_format = GetDataFormat(op_info); if (data_format == "NCHW") { - channel_index = 1; + major_channel_index = 1; y_index = 2; x_index = 3; + } else if (data_format == "NCHW_VECT_C") { + // Use NCHW_VECT_C + minor_channel_index = 1; + y_index = 2; + x_index = 3; + major_channel_index = 4; } else { // Use NHWC. y_index = 1; x_index = 2; - channel_index = 3; + major_channel_index = 3; } const string& filter_format = GetFilterFormat(op_info); - int filter_x_index, filter_y_index, in_channel_index, out_channel_index; + int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index, + in_minor_channel_index = -1; if (filter_format == "HWIO") { filter_y_index = 0; filter_x_index = 1; - in_channel_index = 2; + in_major_channel_index = 2; out_channel_index = 3; + } else if (filter_format == "OIHW_VECT_I") { + out_channel_index = 0; + in_minor_channel_index = 1; + filter_y_index = 2; + filter_x_index = 3; + in_major_channel_index = 4; } else { // Use OIHW out_channel_index = 0; - in_channel_index = 1; + in_major_channel_index = 1; filter_y_index = 2; filter_x_index = 3; } + + auto image_shape = MaybeGetMinimumShape(original_image_shape, + minor_channel_index >= 0 ? 5 : 4, + found_unknown_shapes); + auto filter_shape = MaybeGetMinimumShape(original_filter_shape, + in_minor_channel_index >= 0 ? 5 : 4, + found_unknown_shapes); + VLOG(2) << "Image shape: " << image_shape.DebugString(); + VLOG(2) << "Filter shape: " << filter_shape.DebugString(); + int64 batch = image_shape.dim(0).size(); int64 ix = image_shape.dim(x_index).size(); int64 iy = image_shape.dim(y_index).size(); - int64 iz = image_shape.dim(channel_index).size(); + int64 iz = minor_channel_index >= 0 + ? image_shape.dim(minor_channel_index).size() * + image_shape.dim(major_channel_index).size() + : image_shape.dim(major_channel_index).size(); int64 kx = filter_shape.dim(filter_x_index).size(); int64 ky = filter_shape.dim(filter_y_index).size(); - int64 kz = filter_shape.dim(in_channel_index).size(); + int64 kz = in_minor_channel_index >= 0 + ? filter_shape.dim(in_major_channel_index).size() * + filter_shape.dim(in_minor_channel_index).size() + : filter_shape.dim(in_major_channel_index).size(); std::vector strides = GetStrides(op_info); const auto padding = GetPadding(op_info); int64 sx = strides[x_index]; @@ -1303,17 +1326,18 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // For more information, see // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc - // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and - // filter formats (OIHW_VECT_I). + // TODO(yaozhang): Support NHWC_VECT_W. string data_format = GetDataFormat(op_context.op_info); - if (data_format != "NCHW" && data_format != "NHWC") { + if (data_format != "NCHW" && data_format != "NHWC" && + data_format != "NCHW_VECT_C") { LOG(WARNING) << "unsupported data format: " << data_format; Costs cost = Costs::ZeroCosts(); cost.inaccurate = true; return cost; } string filter_format = GetFilterFormat(op_context.op_info); - if (filter_format != "HWIO" && filter_format != "OIHW") { + if (filter_format != "HWIO" && filter_format != "OIHW" && + filter_format != "OIHW_VECT_I") { LOG(WARNING) << "unsupported filter format: " << filter_format; Costs cost = Costs::ZeroCosts(); cost.inaccurate = true; @@ -1337,7 +1361,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation( // and format, as it may not be available yet. // TODO(varomodt): should we centralize the Conv2D input/output shapes? OpInfo::TensorProperties output; - if (data_format == "NCHW") { + if (data_format == "NCHW" || data_format == "NCHW_VECT_C") { output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox}); } else if (data_format == "NHWC") { output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz}); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 6048b3ec16d..a086fe02191 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -188,6 +188,19 @@ void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, tensor->set_dtype(DT_FLOAT); } +// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost +// estimation purposes. +void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4, + OpInfo::TensorProperties* tensor) { + auto shape = tensor->mutable_shape(); + shape->add_dim()->set_size(dim0); + shape->add_dim()->set_size(dim1); + shape->add_dim()->set_size(dim2); + shape->add_dim()->set_size(dim3); + shape->add_dim()->set_size(dim4); + tensor->set_dtype(DT_FLOAT); +} + // DescribeConvolution constructs an OpContext for a Conv2D applied to an input // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape // (kx, ky, iz2, oz). @@ -225,7 +238,10 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1, // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a // bias tensor with shape (oz), a side input tensor with shape // (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with -// shape (1). +// shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll +// default to 4 (the vector size most often used with this format on NVIDIA +// platforms) for the major channel size, and divide the input channel size by +// that amount. // // Note that this assumes the NHWC data format. OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1, @@ -233,6 +249,7 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1, int oy, int oz, bool has_side_input, const string& data_format, const string& filter_format) { + const int kVecWidth = 4; OpContext op_context; SetCpuDevice(&op_context.op_info); op_context.op_info.set_op("FusedConv2DBiasActivation"); @@ -242,15 +259,25 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1, SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]); if (data_format == "NHWC") { DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs()); - } else { - // Use the NCHW format. + } else if (data_format == "NCHW") { DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs()); + } else { + // Use the NCHW_VECT_C format. + EXPECT_EQ(data_format, "NCHW_VECT_C"); + EXPECT_EQ(iz1 % kVecWidth, 0); + DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth, + op_context.op_info.add_inputs()); } if (filter_format == "HWIO") { DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs()); - } else { - // Use the OIHW format. + } else if (filter_format == "OIHW") { DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs()); + } else { + EXPECT_EQ(filter_format, "OIHW_VECT_I"); + EXPECT_EQ(iz2 % kVecWidth, 0); + // Use the OIHW_VECT_I format. + DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth, + op_context.op_info.add_inputs()); } DescribeTensor1D(oz, op_context.op_info.add_inputs()); @@ -259,8 +286,13 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1, if (has_side_input) { if (data_format == "NHWC") { DescribeTensor4D(batch, ox, oy, oz, side_input); - } else { + } else if (data_format == "NCHW") { DescribeTensor4D(batch, oz, ox, oy, side_input); + } else { + // Use the NCHW_VECT_C format. + EXPECT_EQ(data_format, "NCHW_VECT_C"); + EXPECT_EQ(oz % kVecWidth, 0); + DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input); } } @@ -810,29 +842,40 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) { EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); } -// TODO(yaozhang): Update once NCHW_VECT_C is supported. TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) { auto cost = PredictCosts(DescribeFusedConv2DBiasActivation( 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW_VECT_C", "OIHW")); - EXPECT_EQ(Costs::Duration(0), cost.memory_time); - EXPECT_EQ(Costs::Duration(0), cost.compute_time); - EXPECT_EQ(Costs::Duration(0), cost.execution_time); + EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); + EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); EXPECT_EQ(1, cost.num_ops_total); - EXPECT_TRUE(cost.inaccurate); + EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); } -// TODO(yaozhang): Update once OIHW_VECT_I is supported. TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) { auto cost = PredictCosts(DescribeFusedConv2DBiasActivation( 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, "NCHW", "OIHW_VECT_I")); - EXPECT_EQ(Costs::Duration(0), cost.memory_time); - EXPECT_EQ(Costs::Duration(0), cost.compute_time); - EXPECT_EQ(Costs::Duration(0), cost.execution_time); + EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); + EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); EXPECT_EQ(1, cost.num_ops_total); - EXPECT_TRUE(cost.inaccurate); + EXPECT_FALSE(cost.inaccurate); + EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); +} + +TEST_F(OpLevelCostEstimatorTest, + FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) { + auto cost = PredictCosts(DescribeFusedConv2DBiasActivation( + 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true, + "NCHW_VECT_C", "OIHW_VECT_I")); + EXPECT_EQ(Costs::Duration(1416808), cost.memory_time); + EXPECT_EQ(Costs::Duration(355616770), cost.compute_time); + EXPECT_EQ(Costs::Duration(357033578), cost.execution_time); + EXPECT_EQ(1, cost.num_ops_total); + EXPECT_FALSE(cost.inaccurate); EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); }