Added support for NCHW_VECT_C and OIHW_VECT_I.

PiperOrigin-RevId: 306284325
Change-Id: I39f4f84beed20f192f79d626a862a915b0a128fa
This commit is contained in:
A. Unique TensorFlower 2020-04-13 12:20:34 -07:00 committed by TensorFlower Gardener
parent 19e0a5d35f
commit c0f84b60d8
2 changed files with 102 additions and 35 deletions

View File

@ -693,46 +693,69 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
VLOG(2) << "op features: " << op_info.DebugString();
VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
auto image_shape =
MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
auto filter_shape =
MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
VLOG(2) << "Image shape: " << image_shape.DebugString();
VLOG(2) << "Filter shape: " << filter_shape.DebugString();
int x_index, y_index, channel_index;
int x_index, y_index, major_channel_index, minor_channel_index = -1;
const string& data_format = GetDataFormat(op_info);
if (data_format == "NCHW") {
channel_index = 1;
major_channel_index = 1;
y_index = 2;
x_index = 3;
} else if (data_format == "NCHW_VECT_C") {
// Use NCHW_VECT_C
minor_channel_index = 1;
y_index = 2;
x_index = 3;
major_channel_index = 4;
} else {
// Use NHWC.
y_index = 1;
x_index = 2;
channel_index = 3;
major_channel_index = 3;
}
const string& filter_format = GetFilterFormat(op_info);
int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
in_minor_channel_index = -1;
if (filter_format == "HWIO") {
filter_y_index = 0;
filter_x_index = 1;
in_channel_index = 2;
in_major_channel_index = 2;
out_channel_index = 3;
} else if (filter_format == "OIHW_VECT_I") {
out_channel_index = 0;
in_minor_channel_index = 1;
filter_y_index = 2;
filter_x_index = 3;
in_major_channel_index = 4;
} else {
// Use OIHW
out_channel_index = 0;
in_channel_index = 1;
in_major_channel_index = 1;
filter_y_index = 2;
filter_x_index = 3;
}
auto image_shape = MaybeGetMinimumShape(original_image_shape,
minor_channel_index >= 0 ? 5 : 4,
found_unknown_shapes);
auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
in_minor_channel_index >= 0 ? 5 : 4,
found_unknown_shapes);
VLOG(2) << "Image shape: " << image_shape.DebugString();
VLOG(2) << "Filter shape: " << filter_shape.DebugString();
int64 batch = image_shape.dim(0).size();
int64 ix = image_shape.dim(x_index).size();
int64 iy = image_shape.dim(y_index).size();
int64 iz = image_shape.dim(channel_index).size();
int64 iz = minor_channel_index >= 0
? image_shape.dim(minor_channel_index).size() *
image_shape.dim(major_channel_index).size()
: image_shape.dim(major_channel_index).size();
int64 kx = filter_shape.dim(filter_x_index).size();
int64 ky = filter_shape.dim(filter_y_index).size();
int64 kz = filter_shape.dim(in_channel_index).size();
int64 kz = in_minor_channel_index >= 0
? filter_shape.dim(in_major_channel_index).size() *
filter_shape.dim(in_minor_channel_index).size()
: filter_shape.dim(in_major_channel_index).size();
std::vector<int64> strides = GetStrides(op_info);
const auto padding = GetPadding(op_info);
int64 sx = strides[x_index];
@ -1303,17 +1326,18 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// For more information, see
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
// TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
// filter formats (OIHW_VECT_I).
// TODO(yaozhang): Support NHWC_VECT_W.
string data_format = GetDataFormat(op_context.op_info);
if (data_format != "NCHW" && data_format != "NHWC") {
if (data_format != "NCHW" && data_format != "NHWC" &&
data_format != "NCHW_VECT_C") {
LOG(WARNING) << "unsupported data format: " << data_format;
Costs cost = Costs::ZeroCosts();
cost.inaccurate = true;
return cost;
}
string filter_format = GetFilterFormat(op_context.op_info);
if (filter_format != "HWIO" && filter_format != "OIHW") {
if (filter_format != "HWIO" && filter_format != "OIHW" &&
filter_format != "OIHW_VECT_I") {
LOG(WARNING) << "unsupported filter format: " << filter_format;
Costs cost = Costs::ZeroCosts();
cost.inaccurate = true;
@ -1337,7 +1361,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// and format, as it may not be available yet.
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
OpInfo::TensorProperties output;
if (data_format == "NCHW") {
if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
} else if (data_format == "NHWC") {
output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});

View File

@ -188,6 +188,19 @@ void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
tensor->set_dtype(DT_FLOAT);
}
// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
OpInfo::TensorProperties* tensor) {
auto shape = tensor->mutable_shape();
shape->add_dim()->set_size(dim0);
shape->add_dim()->set_size(dim1);
shape->add_dim()->set_size(dim2);
shape->add_dim()->set_size(dim3);
shape->add_dim()->set_size(dim4);
tensor->set_dtype(DT_FLOAT);
}
// DescribeConvolution constructs an OpContext for a Conv2D applied to an input
// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
// (kx, ky, iz2, oz).
@ -225,7 +238,10 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
// bias tensor with shape (oz), a side input tensor with shape
// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
// shape (1).
// shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
// default to 4 (the vector size most often used with this format on NVIDIA
// platforms) for the major channel size, and divide the input channel size by
// that amount.
//
// Note that this assumes the NHWC data format.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
@ -233,6 +249,7 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
int oy, int oz, bool has_side_input,
const string& data_format,
const string& filter_format) {
const int kVecWidth = 4;
OpContext op_context;
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("FusedConv2DBiasActivation");
@ -242,15 +259,25 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
if (data_format == "NHWC") {
DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
} else {
// Use the NCHW format.
} else if (data_format == "NCHW") {
DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
} else {
// Use the NCHW_VECT_C format.
EXPECT_EQ(data_format, "NCHW_VECT_C");
EXPECT_EQ(iz1 % kVecWidth, 0);
DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
op_context.op_info.add_inputs());
}
if (filter_format == "HWIO") {
DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
} else {
// Use the OIHW format.
} else if (filter_format == "OIHW") {
DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
} else {
EXPECT_EQ(filter_format, "OIHW_VECT_I");
EXPECT_EQ(iz2 % kVecWidth, 0);
// Use the OIHW_VECT_I format.
DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
op_context.op_info.add_inputs());
}
DescribeTensor1D(oz, op_context.op_info.add_inputs());
@ -259,8 +286,13 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
if (has_side_input) {
if (data_format == "NHWC") {
DescribeTensor4D(batch, ox, oy, oz, side_input);
} else {
} else if (data_format == "NCHW") {
DescribeTensor4D(batch, oz, ox, oy, side_input);
} else {
// Use the NCHW_VECT_C format.
EXPECT_EQ(data_format, "NCHW_VECT_C");
EXPECT_EQ(oz % kVecWidth, 0);
DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
}
}
@ -810,29 +842,40 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
// TODO(yaozhang): Update once NCHW_VECT_C is supported.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
"NCHW_VECT_C", "OIHW"));
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
EXPECT_EQ(Costs::Duration(0), cost.compute_time);
EXPECT_EQ(Costs::Duration(0), cost.execution_time);
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
// TODO(yaozhang): Update once OIHW_VECT_I is supported.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
"NCHW", "OIHW_VECT_I"));
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
EXPECT_EQ(Costs::Duration(0), cost.compute_time);
EXPECT_EQ(Costs::Duration(0), cost.execution_time);
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_TRUE(cost.inaccurate);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
TEST_F(OpLevelCostEstimatorTest,
FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
"NCHW_VECT_C", "OIHW_VECT_I"));
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}