Correctly handle batch_group_count in cases with depthwise_multipliers
PiperOrigin-RevId: 330008877
Change-Id: Icbda17222a58efc94aabb4c02e892a2504cb9b34
This commit is contained in:
parent 332f2338ce
commit ec878bb3e3
@ -1157,7 +1157,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
const int64 feature_group_index =
|
const int64 feature_group_index =
|
||||||
out_index[output_z_dim] / output_feature_group_size;
|
out_index[output_z_dim] / output_feature_group_size;
|
||||||
|
|
||||||
const int64 batch_group_index = out_index[output_z_dim];
|
const int64 depthwise_multiplier =
|
||||||
|
batch_group_count > 1 ? output_z_size / input_batch_size : 1;
|
||||||
|
const int64 batch_group_index =
|
||||||
|
out_index[output_z_dim] / depthwise_multiplier;
|
||||||
|
|
||||||
ElementwiseT result_val = static_cast<ElementwiseT>(0);
|
ElementwiseT result_val = static_cast<ElementwiseT>(0);
|
||||||
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
|
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
|
||||||
@ -1218,7 +1221,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
feature_group_index * input_feature_group_size + rhs_iz;
|
feature_group_index * input_feature_group_size + rhs_iz;
|
||||||
|
|
||||||
int64 lhs_linear_index = lhs_linear_spatial_index;
|
int64 lhs_linear_index = lhs_linear_spatial_index;
|
||||||
|
|
||||||
lhs_linear_index += out_index[output_batch_dim] *
|
lhs_linear_index += out_index[output_batch_dim] *
|
||||||
lhs_dim_multipliers[input_batch_dim];
|
lhs_dim_multipliers[input_batch_dim];
|
||||||
|
|
||||||
@ -1233,7 +1235,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
lhs_dim_multipliers[input_batch_dim];
|
lhs_dim_multipliers[input_batch_dim];
|
||||||
|
|
||||||
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
|
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
|
||||||
|
|
||||||
int64 rhs_linear_index = rhs_linear_spatial_index;
|
int64 rhs_linear_index = rhs_linear_spatial_index;
|
||||||
|
|
||||||
rhs_linear_index += out_index[output_z_dim] *
|
rhs_linear_index += out_index[output_z_dim] *
|
||||||
|
@ -385,7 +385,7 @@ xla_test(
|
|||||||
"gpu",
|
"gpu",
|
||||||
"cpu",
|
"cpu",
|
||||||
],
|
],
|
||||||
shard_count = 6,
|
shard_count = 40,
|
||||||
deps = [
|
deps = [
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
"//tensorflow/compiler/xla:execution_options_util",
|
"//tensorflow/compiler/xla:execution_options_util",
|
||||||
|
@ -45,13 +45,20 @@ class BatchGroupedConvolution2DTest
|
|||||||
public ::testing::WithParamInterface<
|
public ::testing::WithParamInterface<
|
||||||
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
|
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
|
||||||
|
|
||||||
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
|
class BatchGroupedConvolution2DDepthTest
|
||||||
|
: public HloTestBase,
|
||||||
|
public ::testing::WithParamInterface<
|
||||||
|
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
|
||||||
|
|
||||||
|
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases(
|
||||||
|
bool use_depth_multiplier) {
|
||||||
std::vector<BatchGroupedConvolution2DSpec> config_set;
|
std::vector<BatchGroupedConvolution2DSpec> config_set;
|
||||||
std::vector<std::vector<int64>> config_options = {
|
std::vector<std::vector<int64>> config_options = {
|
||||||
{8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128},
|
{129, 10, 3, 2}, {4, 3, 3, 258}, {8, 4, 2, 128},
|
||||||
{16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4},
|
{8, 3, 2, 256}, {256, 7, 5, 4}, {128, 6, 6, 4},
|
||||||
{256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}};
|
{32, 5, 2, 129}, {16, 4, 3, 2}, {16, 3, 2, 64}};
|
||||||
|
|
||||||
|
int64 counter = 2;
|
||||||
for (auto option : config_options) {
|
for (auto option : config_options) {
|
||||||
int64 feature = option[3];
|
int64 feature = option[3];
|
||||||
int64 activation_size = option[1];
|
int64 activation_size = option[1];
|
||||||
@ -65,10 +72,16 @@ static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
|
|||||||
|
|
||||||
config.activation_dims = {batch, activation_size, activation_size, feature};
|
config.activation_dims = {batch, activation_size, activation_size, feature};
|
||||||
|
|
||||||
config.kernel_dims = {batch, kernel_size, kernel_size, feature};
|
const int64 depthwise_multiplier = use_depth_multiplier ? counter++ : 1;
|
||||||
|
config.kernel_dims = {batch, kernel_size, kernel_size,
|
||||||
|
feature * depthwise_multiplier};
|
||||||
|
// Don't let the counter grow too much, else the compute demand will grow.
|
||||||
|
if (counter == 4) {
|
||||||
|
counter = 2;
|
||||||
|
}
|
||||||
int64 output_space_size = 3 + activation_size - kernel_size;
|
int64 output_space_size = 3 + activation_size - kernel_size;
|
||||||
config.output_dims = {output_space_size, output_space_size, feature, 1};
|
config.output_dims = {output_space_size, output_space_size,
|
||||||
|
feature * depthwise_multiplier, 1};
|
||||||
|
|
||||||
config.activation_and_kernel_layout = {0, 3, 1, 2};
|
config.activation_and_kernel_layout = {0, 3, 1, 2};
|
||||||
config.output_layout = {2, 3, 0, 1};
|
config.output_layout = {2, 3, 0, 1};
|
||||||
@ -123,11 +136,13 @@ string BatchGroupedConvolution2DTestDataToString(
|
|||||||
}
|
}
|
||||||
|
|
||||||
string BuildHloTextBatchGroupedConvolution2D(
|
string BuildHloTextBatchGroupedConvolution2D(
|
||||||
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) {
|
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16,
|
||||||
|
bool scheduled = false) {
|
||||||
const string data_type = GetFloatDataType(use_bfloat16);
|
const string data_type = GetFloatDataType(use_bfloat16);
|
||||||
|
const string scheduled_tag = scheduled ? ",is_scheduled=true" : "";
|
||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
R"(
|
R"(
|
||||||
HloModule TensorFlowDepthwiseConv, is_scheduled=true
|
HloModule TensorFlowDepthwiseConv %s
|
||||||
|
|
||||||
ENTRY main {
|
ENTRY main {
|
||||||
activation = %s[%s]{%s} parameter(0)
|
activation = %s[%s]{%s} parameter(0)
|
||||||
@ -137,7 +152,7 @@ string BuildHloTextBatchGroupedConvolution2D(
|
|||||||
batch_group_count=%d
|
batch_group_count=%d
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
data_type, absl::StrJoin(spec.activation_dims, ","),
|
scheduled_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
|
||||||
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
||||||
absl::StrJoin(spec.kernel_dims, ","),
|
absl::StrJoin(spec.kernel_dims, ","),
|
||||||
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
||||||
@ -161,8 +176,8 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const string hlo_text =
|
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
|
||||||
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
spec, use_bfloat16, /*scheduled=*/true);
|
||||||
|
|
||||||
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
||||||
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
|
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
|
||||||
@ -176,8 +191,33 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
|||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
BatchGroupedConvolution2DTestWithRandomIndices,
|
BatchGroupedConvolution2DTestWithRandomIndices,
|
||||||
BatchGroupedConvolution2DTest,
|
BatchGroupedConvolution2DTest,
|
||||||
::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
|
::testing::Combine(
|
||||||
::testing::Bool()),
|
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/false)),
|
||||||
|
::testing::Bool()),
|
||||||
|
BatchGroupedConvolution2DTestDataToString);
|
||||||
|
|
||||||
|
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
|
||||||
|
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
|
||||||
|
bool use_bfloat16 = ::testing::get<1>(GetParam());
|
||||||
|
|
||||||
|
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
|
||||||
|
if (use_bfloat16) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const string hlo_text =
|
||||||
|
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
||||||
|
|
||||||
|
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
|
||||||
|
BatchGroupedConvolution2DDepthTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
|
||||||
|
::testing::Bool()),
|
||||||
BatchGroupedConvolution2DTestDataToString);
|
BatchGroupedConvolution2DTestDataToString);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user