Correctly handle batch_group_count in cases with depthwise_multipliers

PiperOrigin-RevId: 330008877
Change-Id: Icbda17222a58efc94aabb4c02e892a2504cb9b34
This commit is contained in:
A. Unique TensorFlower 2020-09-03 15:49:02 -07:00 committed by TensorFlower Gardener
parent 332f2338ce
commit ec878bb3e3
3 changed files with 59 additions and 18 deletions

View File

@ -1157,7 +1157,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 feature_group_index =
out_index[output_z_dim] / output_feature_group_size;
const int64 batch_group_index = out_index[output_z_dim];
const int64 depthwise_multiplier =
batch_group_count > 1 ? output_z_size / input_batch_size : 1;
const int64 batch_group_index =
out_index[output_z_dim] / depthwise_multiplier;
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@ -1218,7 +1221,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
feature_group_index * input_feature_group_size + rhs_iz;
int64 lhs_linear_index = lhs_linear_spatial_index;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
@ -1233,7 +1235,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
lhs_dim_multipliers[input_batch_dim];
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
int64 rhs_linear_index = rhs_linear_spatial_index;
rhs_linear_index += out_index[output_z_dim] *

View File

@ -385,7 +385,7 @@ xla_test(
"gpu",
"cpu",
],
shard_count = 6,
shard_count = 40,
deps = [
":test_macros_header",
"//tensorflow/compiler/xla:execution_options_util",

View File

@ -45,13 +45,20 @@ class BatchGroupedConvolution2DTest
public ::testing::WithParamInterface<
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
// Value-parameterized fixture for the depth-multiplier variant of the
// batch-grouped convolution tests. Each parameter is a
// (BatchGroupedConvolution2DSpec, bool) tuple; the DoItDepth test reads the
// bool as `use_bfloat16`.
class BatchGroupedConvolution2DDepthTest
: public HloTestBase,
public ::testing::WithParamInterface<
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases(
bool use_depth_multiplier) {
std::vector<BatchGroupedConvolution2DSpec> config_set;
std::vector<std::vector<int64>> config_options = {
{8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128},
{16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4},
{256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}};
{129, 10, 3, 2}, {4, 3, 3, 258}, {8, 4, 2, 128},
{8, 3, 2, 256}, {256, 7, 5, 4}, {128, 6, 6, 4},
{32, 5, 2, 129}, {16, 4, 3, 2}, {16, 3, 2, 64}};
int64 counter = 2;
for (auto option : config_options) {
int64 feature = option[3];
int64 activation_size = option[1];
@ -65,10 +72,16 @@ static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
config.activation_dims = {batch, activation_size, activation_size, feature};
config.kernel_dims = {batch, kernel_size, kernel_size, feature};
const int64 depthwise_multiplier = use_depth_multiplier ? counter++ : 1;
config.kernel_dims = {batch, kernel_size, kernel_size,
feature * depthwise_multiplier};
// Don't let the counter grow too much, else the compute demand will grow.
if (counter == 4) {
counter = 2;
}
int64 output_space_size = 3 + activation_size - kernel_size;
config.output_dims = {output_space_size, output_space_size, feature, 1};
config.output_dims = {output_space_size, output_space_size,
feature * depthwise_multiplier, 1};
config.activation_and_kernel_layout = {0, 3, 1, 2};
config.output_layout = {2, 3, 0, 1};
@ -123,11 +136,13 @@ string BatchGroupedConvolution2DTestDataToString(
}
string BuildHloTextBatchGroupedConvolution2D(
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) {
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16,
bool scheduled = false) {
const string data_type = GetFloatDataType(use_bfloat16);
const string scheduled_tag = scheduled ? ",is_scheduled=true" : "";
return absl::StrFormat(
R"(
HloModule TensorFlowDepthwiseConv, is_scheduled=true
HloModule TensorFlowDepthwiseConv %s
ENTRY main {
activation = %s[%s]{%s} parameter(0)
@ -137,7 +152,7 @@ string BuildHloTextBatchGroupedConvolution2D(
batch_group_count=%d
}
)",
data_type, absl::StrJoin(spec.activation_dims, ","),
scheduled_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
absl::StrJoin(spec.kernel_dims, ","),
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
@ -161,8 +176,8 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
}
#endif
const string hlo_text =
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
spec, use_bfloat16, /*scheduled=*/true);
EXPECT_TRUE(RunAndCompareNoHloPasses(
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
@ -176,8 +191,33 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
INSTANTIATE_TEST_CASE_P(
BatchGroupedConvolution2DTestWithRandomIndices,
BatchGroupedConvolution2DTest,
::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
::testing::Bool()),
::testing::Combine(
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/false)),
::testing::Bool()),
BatchGroupedConvolution2DTestDataToString);
// Builds the HLO text for the parameterized convolution spec and checks it
// with RunAndCompare under ErrorSpec{0.01, 0.01}. Unlike the DoIt test, this
// one does not pass `scheduled`, so the HLO module is built with the default
// (scheduled = false), and it uses RunAndCompare rather than
// RunAndCompareNoHloPasses.
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
// Element 0 of the parameter tuple is the convolution spec, element 1 selects
// bfloat16.
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
bool use_bfloat16 = ::testing::get<1>(GetParam());
// When the backend is built without bfloat16 support, the bfloat16 variants
// return early and therefore pass without running anything.
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
if (use_bfloat16) {
return;
}
#endif
const string hlo_text =
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
}
// Instantiates the depth-multiplier suite over every combination of the specs
// returned by GetConv2DTestCases(/*use_depth_multiplier=*/true) and both
// values of the bool parameter. Test names are generated by
// BatchGroupedConvolution2DTestDataToString.
INSTANTIATE_TEST_CASE_P(
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
BatchGroupedConvolution2DDepthTest,
::testing::Combine(
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
::testing::Bool()),
BatchGroupedConvolution2DTestDataToString);
} // namespace