Correctly handle batch_group_count in cases with depthwise_multipliers
PiperOrigin-RevId: 330008877 Change-Id: Icbda17222a58efc94aabb4c02e892a2504cb9b34
This commit is contained in:
parent
332f2338ce
commit
ec878bb3e3
@ -1157,7 +1157,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
const int64 feature_group_index =
|
||||
out_index[output_z_dim] / output_feature_group_size;
|
||||
|
||||
const int64 batch_group_index = out_index[output_z_dim];
|
||||
const int64 depthwise_multiplier =
|
||||
batch_group_count > 1 ? output_z_size / input_batch_size : 1;
|
||||
const int64 batch_group_index =
|
||||
out_index[output_z_dim] / depthwise_multiplier;
|
||||
|
||||
ElementwiseT result_val = static_cast<ElementwiseT>(0);
|
||||
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
|
||||
@ -1218,7 +1221,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
feature_group_index * input_feature_group_size + rhs_iz;
|
||||
|
||||
int64 lhs_linear_index = lhs_linear_spatial_index;
|
||||
|
||||
lhs_linear_index += out_index[output_batch_dim] *
|
||||
lhs_dim_multipliers[input_batch_dim];
|
||||
|
||||
@ -1233,7 +1235,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
lhs_dim_multipliers[input_batch_dim];
|
||||
|
||||
lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
|
||||
|
||||
int64 rhs_linear_index = rhs_linear_spatial_index;
|
||||
|
||||
rhs_linear_index += out_index[output_z_dim] *
|
||||
|
@ -385,7 +385,7 @@ xla_test(
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
shard_count = 6,
|
||||
shard_count = 40,
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
|
@ -45,13 +45,20 @@ class BatchGroupedConvolution2DTest
|
||||
public ::testing::WithParamInterface<
|
||||
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
|
||||
|
||||
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
|
||||
class BatchGroupedConvolution2DDepthTest
|
||||
: public HloTestBase,
|
||||
public ::testing::WithParamInterface<
|
||||
::testing::tuple<BatchGroupedConvolution2DSpec, bool>> {};
|
||||
|
||||
static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases(
|
||||
bool use_depth_multiplier) {
|
||||
std::vector<BatchGroupedConvolution2DSpec> config_set;
|
||||
std::vector<std::vector<int64>> config_options = {
|
||||
{8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128},
|
||||
{16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4},
|
||||
{256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}};
|
||||
{129, 10, 3, 2}, {4, 3, 3, 258}, {8, 4, 2, 128},
|
||||
{8, 3, 2, 256}, {256, 7, 5, 4}, {128, 6, 6, 4},
|
||||
{32, 5, 2, 129}, {16, 4, 3, 2}, {16, 3, 2, 64}};
|
||||
|
||||
int64 counter = 2;
|
||||
for (auto option : config_options) {
|
||||
int64 feature = option[3];
|
||||
int64 activation_size = option[1];
|
||||
@ -65,10 +72,16 @@ static std::vector<BatchGroupedConvolution2DSpec> GetConv2DTestCases() {
|
||||
|
||||
config.activation_dims = {batch, activation_size, activation_size, feature};
|
||||
|
||||
config.kernel_dims = {batch, kernel_size, kernel_size, feature};
|
||||
|
||||
const int64 depthwise_multiplier = use_depth_multiplier ? counter++ : 1;
|
||||
config.kernel_dims = {batch, kernel_size, kernel_size,
|
||||
feature * depthwise_multiplier};
|
||||
// Don't let the counter grow too much, else the compute demand will grow.
|
||||
if (counter == 4) {
|
||||
counter = 2;
|
||||
}
|
||||
int64 output_space_size = 3 + activation_size - kernel_size;
|
||||
config.output_dims = {output_space_size, output_space_size, feature, 1};
|
||||
config.output_dims = {output_space_size, output_space_size,
|
||||
feature * depthwise_multiplier, 1};
|
||||
|
||||
config.activation_and_kernel_layout = {0, 3, 1, 2};
|
||||
config.output_layout = {2, 3, 0, 1};
|
||||
@ -123,11 +136,13 @@ string BatchGroupedConvolution2DTestDataToString(
|
||||
}
|
||||
|
||||
string BuildHloTextBatchGroupedConvolution2D(
|
||||
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) {
|
||||
const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16,
|
||||
bool scheduled = false) {
|
||||
const string data_type = GetFloatDataType(use_bfloat16);
|
||||
const string scheduled_tag = scheduled ? ",is_scheduled=true" : "";
|
||||
return absl::StrFormat(
|
||||
R"(
|
||||
HloModule TensorFlowDepthwiseConv, is_scheduled=true
|
||||
HloModule TensorFlowDepthwiseConv %s
|
||||
|
||||
ENTRY main {
|
||||
activation = %s[%s]{%s} parameter(0)
|
||||
@ -137,7 +152,7 @@ string BuildHloTextBatchGroupedConvolution2D(
|
||||
batch_group_count=%d
|
||||
}
|
||||
)",
|
||||
data_type, absl::StrJoin(spec.activation_dims, ","),
|
||||
scheduled_tag, data_type, absl::StrJoin(spec.activation_dims, ","),
|
||||
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
||||
absl::StrJoin(spec.kernel_dims, ","),
|
||||
absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type,
|
||||
@ -161,8 +176,8 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
||||
}
|
||||
#endif
|
||||
|
||||
const string hlo_text =
|
||||
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
||||
const string hlo_text = BuildHloTextBatchGroupedConvolution2D(
|
||||
spec, use_bfloat16, /*scheduled=*/true);
|
||||
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(
|
||||
hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status {
|
||||
@ -176,8 +191,33 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) {
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
BatchGroupedConvolution2DTestWithRandomIndices,
|
||||
BatchGroupedConvolution2DTest,
|
||||
::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()),
|
||||
::testing::Bool()),
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/false)),
|
||||
::testing::Bool()),
|
||||
BatchGroupedConvolution2DTestDataToString);
|
||||
|
||||
XLA_TEST_P(BatchGroupedConvolution2DDepthTest, DoItDepth) {
|
||||
const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam());
|
||||
bool use_bfloat16 = ::testing::get<1>(GetParam());
|
||||
|
||||
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_BFLOAT16
|
||||
if (use_bfloat16) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const string hlo_text =
|
||||
BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16);
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
BatchGroupedConvolution2DDepthMultiplierTestWithRandomIndices,
|
||||
BatchGroupedConvolution2DDepthTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(GetConv2DTestCases(/*use_depth_multiplier=*/true)),
|
||||
::testing::Bool()),
|
||||
BatchGroupedConvolution2DTestDataToString);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user