Added support for NCHW_VECT_C and OIHW_VECT_I.
PiperOrigin-RevId: 306284325 Change-Id: I39f4f84beed20f192f79d626a862a915b0a128fa
This commit is contained in:
parent
19e0a5d35f
commit
c0f84b60d8
@ -693,46 +693,69 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
|
||||
VLOG(2) << "op features: " << op_info.DebugString();
|
||||
VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
|
||||
VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
|
||||
auto image_shape =
|
||||
MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
|
||||
auto filter_shape =
|
||||
MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
|
||||
VLOG(2) << "Image shape: " << image_shape.DebugString();
|
||||
VLOG(2) << "Filter shape: " << filter_shape.DebugString();
|
||||
|
||||
int x_index, y_index, channel_index;
|
||||
int x_index, y_index, major_channel_index, minor_channel_index = -1;
|
||||
const string& data_format = GetDataFormat(op_info);
|
||||
if (data_format == "NCHW") {
|
||||
channel_index = 1;
|
||||
major_channel_index = 1;
|
||||
y_index = 2;
|
||||
x_index = 3;
|
||||
} else if (data_format == "NCHW_VECT_C") {
|
||||
// Use NCHW_VECT_C
|
||||
minor_channel_index = 1;
|
||||
y_index = 2;
|
||||
x_index = 3;
|
||||
major_channel_index = 4;
|
||||
} else {
|
||||
// Use NHWC.
|
||||
y_index = 1;
|
||||
x_index = 2;
|
||||
channel_index = 3;
|
||||
major_channel_index = 3;
|
||||
}
|
||||
const string& filter_format = GetFilterFormat(op_info);
|
||||
int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
|
||||
int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
|
||||
in_minor_channel_index = -1;
|
||||
if (filter_format == "HWIO") {
|
||||
filter_y_index = 0;
|
||||
filter_x_index = 1;
|
||||
in_channel_index = 2;
|
||||
in_major_channel_index = 2;
|
||||
out_channel_index = 3;
|
||||
} else if (filter_format == "OIHW_VECT_I") {
|
||||
out_channel_index = 0;
|
||||
in_minor_channel_index = 1;
|
||||
filter_y_index = 2;
|
||||
filter_x_index = 3;
|
||||
in_major_channel_index = 4;
|
||||
} else {
|
||||
// Use OIHW
|
||||
out_channel_index = 0;
|
||||
in_channel_index = 1;
|
||||
in_major_channel_index = 1;
|
||||
filter_y_index = 2;
|
||||
filter_x_index = 3;
|
||||
}
|
||||
|
||||
auto image_shape = MaybeGetMinimumShape(original_image_shape,
|
||||
minor_channel_index >= 0 ? 5 : 4,
|
||||
found_unknown_shapes);
|
||||
auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
|
||||
in_minor_channel_index >= 0 ? 5 : 4,
|
||||
found_unknown_shapes);
|
||||
VLOG(2) << "Image shape: " << image_shape.DebugString();
|
||||
VLOG(2) << "Filter shape: " << filter_shape.DebugString();
|
||||
|
||||
int64 batch = image_shape.dim(0).size();
|
||||
int64 ix = image_shape.dim(x_index).size();
|
||||
int64 iy = image_shape.dim(y_index).size();
|
||||
int64 iz = image_shape.dim(channel_index).size();
|
||||
int64 iz = minor_channel_index >= 0
|
||||
? image_shape.dim(minor_channel_index).size() *
|
||||
image_shape.dim(major_channel_index).size()
|
||||
: image_shape.dim(major_channel_index).size();
|
||||
int64 kx = filter_shape.dim(filter_x_index).size();
|
||||
int64 ky = filter_shape.dim(filter_y_index).size();
|
||||
int64 kz = filter_shape.dim(in_channel_index).size();
|
||||
int64 kz = in_minor_channel_index >= 0
|
||||
? filter_shape.dim(in_major_channel_index).size() *
|
||||
filter_shape.dim(in_minor_channel_index).size()
|
||||
: filter_shape.dim(in_major_channel_index).size();
|
||||
std::vector<int64> strides = GetStrides(op_info);
|
||||
const auto padding = GetPadding(op_info);
|
||||
int64 sx = strides[x_index];
|
||||
@ -1303,17 +1326,18 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
|
||||
// For more information, see
|
||||
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
|
||||
|
||||
// TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
|
||||
// filter formats (OIHW_VECT_I).
|
||||
// TODO(yaozhang): Support NHWC_VECT_W.
|
||||
string data_format = GetDataFormat(op_context.op_info);
|
||||
if (data_format != "NCHW" && data_format != "NHWC") {
|
||||
if (data_format != "NCHW" && data_format != "NHWC" &&
|
||||
data_format != "NCHW_VECT_C") {
|
||||
LOG(WARNING) << "unsupported data format: " << data_format;
|
||||
Costs cost = Costs::ZeroCosts();
|
||||
cost.inaccurate = true;
|
||||
return cost;
|
||||
}
|
||||
string filter_format = GetFilterFormat(op_context.op_info);
|
||||
if (filter_format != "HWIO" && filter_format != "OIHW") {
|
||||
if (filter_format != "HWIO" && filter_format != "OIHW" &&
|
||||
filter_format != "OIHW_VECT_I") {
|
||||
LOG(WARNING) << "unsupported filter format: " << filter_format;
|
||||
Costs cost = Costs::ZeroCosts();
|
||||
cost.inaccurate = true;
|
||||
@ -1337,7 +1361,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
|
||||
// and format, as it may not be available yet.
|
||||
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
|
||||
OpInfo::TensorProperties output;
|
||||
if (data_format == "NCHW") {
|
||||
if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
|
||||
output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
|
||||
} else if (data_format == "NHWC") {
|
||||
output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
|
||||
|
@ -188,6 +188,19 @@ void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
|
||||
tensor->set_dtype(DT_FLOAT);
|
||||
}
|
||||
|
||||
// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
|
||||
// estimation purposes.
|
||||
void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
|
||||
OpInfo::TensorProperties* tensor) {
|
||||
auto shape = tensor->mutable_shape();
|
||||
shape->add_dim()->set_size(dim0);
|
||||
shape->add_dim()->set_size(dim1);
|
||||
shape->add_dim()->set_size(dim2);
|
||||
shape->add_dim()->set_size(dim3);
|
||||
shape->add_dim()->set_size(dim4);
|
||||
tensor->set_dtype(DT_FLOAT);
|
||||
}
|
||||
|
||||
// DescribeConvolution constructs an OpContext for a Conv2D applied to an input
|
||||
// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
|
||||
// (kx, ky, iz2, oz).
|
||||
@ -225,7 +238,10 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
|
||||
// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
|
||||
// bias tensor with shape (oz), a side input tensor with shape
|
||||
// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
|
||||
// shape (1).
|
||||
// shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
|
||||
// default to 4 (the vector size most often used with this format on NVIDIA
|
||||
// platforms) for the major channel size, and divide the input channel size by
|
||||
// that amount.
|
||||
//
|
||||
// Note that this assumes the NHWC data format.
|
||||
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
|
||||
@ -233,6 +249,7 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
|
||||
int oy, int oz, bool has_side_input,
|
||||
const string& data_format,
|
||||
const string& filter_format) {
|
||||
const int kVecWidth = 4;
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
op_context.op_info.set_op("FusedConv2DBiasActivation");
|
||||
@ -242,15 +259,25 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
|
||||
SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
|
||||
if (data_format == "NHWC") {
|
||||
DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
|
||||
} else {
|
||||
// Use the NCHW format.
|
||||
} else if (data_format == "NCHW") {
|
||||
DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
|
||||
} else {
|
||||
// Use the NCHW_VECT_C format.
|
||||
EXPECT_EQ(data_format, "NCHW_VECT_C");
|
||||
EXPECT_EQ(iz1 % kVecWidth, 0);
|
||||
DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
|
||||
op_context.op_info.add_inputs());
|
||||
}
|
||||
if (filter_format == "HWIO") {
|
||||
DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
|
||||
} else {
|
||||
// Use the OIHW format.
|
||||
} else if (filter_format == "OIHW") {
|
||||
DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
|
||||
} else {
|
||||
EXPECT_EQ(filter_format, "OIHW_VECT_I");
|
||||
EXPECT_EQ(iz2 % kVecWidth, 0);
|
||||
// Use the OIHW_VECT_I format.
|
||||
DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
|
||||
op_context.op_info.add_inputs());
|
||||
}
|
||||
DescribeTensor1D(oz, op_context.op_info.add_inputs());
|
||||
|
||||
@ -259,8 +286,13 @@ OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
|
||||
if (has_side_input) {
|
||||
if (data_format == "NHWC") {
|
||||
DescribeTensor4D(batch, ox, oy, oz, side_input);
|
||||
} else {
|
||||
} else if (data_format == "NCHW") {
|
||||
DescribeTensor4D(batch, oz, ox, oy, side_input);
|
||||
} else {
|
||||
// Use the NCHW_VECT_C format.
|
||||
EXPECT_EQ(data_format, "NCHW_VECT_C");
|
||||
EXPECT_EQ(oz % kVecWidth, 0);
|
||||
DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
|
||||
}
|
||||
}
|
||||
|
||||
@ -810,29 +842,40 @@ TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
// TODO(yaozhang): Update once NCHW_VECT_C is supported.
|
||||
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
|
||||
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
|
||||
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
|
||||
"NCHW_VECT_C", "OIHW"));
|
||||
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(0), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(0), cost.execution_time);
|
||||
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
// TODO(yaozhang): Update once OIHW_VECT_I is supported.
|
||||
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
|
||||
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
|
||||
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
|
||||
"NCHW", "OIHW_VECT_I"));
|
||||
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(0), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(0), cost.execution_time);
|
||||
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_TRUE(cost.inaccurate);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest,
|
||||
FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
|
||||
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
|
||||
16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
|
||||
"NCHW_VECT_C", "OIHW_VECT_I"));
|
||||
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user