[XLA]DynamicDimensionInference: Handle multi batch dimension
This cl:
- Rework dynamic dimension inference's DOT handling.
- Correctly support multiple batch dimensions on DOT (found in BERT).
- Add tests.

PiperOrigin-RevId: 248741549
This commit is contained in:
parent 6d07c2e23f
commit ac0f632592
@@ -222,48 +222,78 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
 }

 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
-               int64 operand_index, HloInstruction* dynamic_size) {
+  return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
+                                                 ShapeIndex operand_shape_index,
+                                                 int64 operand_dimension,
+                                                 int64 operand_index,
+                                                 HloInstruction* dynamic_size) {
+    // There are three types of dimensions in a dot:
+    // A. batch dims
+    // B. contracting dims
+    // C. non-batch non-contracting dims.
+    // The output dimensions of a dot have three parts with the following order:
+    // [(type A), (lhs type C), (rhs type C)]
+    //
+    // Note that both lhs and rhs have the same dimension sizes for batch,
+    // but the dimension index could be different.
+    //
+    // Given one dynamic input dimension, either lhs or rhs, we use a
+    // mapping to find the corresponding output dimension.
     HloInstruction* dot = hlo;
-    const DotDimensionNumbers& dimension_numbers =
-        dot->dot_dimension_numbers();
+    const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
     // A map from the operand dimensions to result dimension.
     absl::flat_hash_map<int64, int64> result_dim_mapping;
     int64 current_result_dims = 0;
-    std::unordered_set<int64> batch_dims(
-        dimension_numbers.rhs_batch_dimensions().begin(),
-        dimension_numbers.rhs_batch_dimensions().end());
+
+    bool lhs = operand_index == 0;
+
+    // The first loop keeps track of batch dimensions. RHS and LHS could have
+    // different batch dimension numbers.
+    if (lhs) {
+      for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
+        result_dim_mapping[i] = current_result_dims++;
+      }
+    } else {
+      for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
+        result_dim_mapping[i] = current_result_dims++;
+      }
+    }

     // Handle dimensions in the lhs.
     for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
-      if (!absl::c_linear_search(
-              dimension_numbers.lhs_contracting_dimensions(), i)) {
-        if (operand_index == 0) {
+      // Look for non-contracting and non-batching dimension.
+      if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
+                                i)) {
+        continue;
+      }
+      if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
+        continue;
+      }
+      if (lhs) {
         result_dim_mapping[i] = current_result_dims;
       }
       current_result_dims++;
-      }
     }

     // Handle dimensions in the rhs.
     for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
-      if (!absl::c_linear_search(
-              dimension_numbers.rhs_contracting_dimensions(), i) &&
-          !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
-                                 i)) {
-        if (operand_index == 1) {
+      // Look for non-contracting and non-batching dimension.
+      if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
+                                i)) {
+        continue;
+      }
+      if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
+        continue;
+      }
+      if (!lhs) {
         result_dim_mapping[i] = current_result_dims;
       }
       current_result_dims++;
-      }
     }

     // Check if the operand dim is in the result shape. If so, add another
     // work item to trace that dimension.
-    auto iter = result_dim_mapping.find(dimension);
+    auto iter = result_dim_mapping.find(operand_dimension);
     if (iter != result_dim_mapping.end()) {
       parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
     }
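For reference, here is a minimal standalone sketch of the mapping rule that the reworked HandleDot builds, applied to the shapes used in the DotTestBatch test below. This is not XLA code: std:: containers stand in for DotDimensionNumbers and absl types, and the helper name MapOperandDimToResultDim is made up for illustration.

// Sketch only: maps each non-contracting dimension of one dot operand to the
// output dimension it ends up in: batch dims first (in declaration order),
// then lhs non-batch non-contracting dims, then rhs non-batch non-contracting
// dims.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

std::unordered_map<int64_t, int64_t> MapOperandDimToResultDim(
    bool is_lhs, int64_t lhs_rank, int64_t rhs_rank,
    const std::vector<int64_t>& lhs_batch,
    const std::vector<int64_t>& rhs_batch,
    const std::vector<int64_t>& lhs_contracting,
    const std::vector<int64_t>& rhs_contracting) {
  auto contains = [](const std::vector<int64_t>& dims, int64_t dim) {
    return std::find(dims.begin(), dims.end(), dim) != dims.end();
  };
  std::unordered_map<int64_t, int64_t> mapping;
  int64_t out = 0;
  // Batch dimensions occupy the leading output dimensions.
  for (int64_t i : is_lhs ? lhs_batch : rhs_batch) mapping[i] = out++;
  // Then lhs non-batch, non-contracting dimensions.
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (contains(lhs_contracting, i) || contains(lhs_batch, i)) continue;
    if (is_lhs) mapping[i] = out;
    ++out;
  }
  // Then rhs non-batch, non-contracting dimensions.
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (contains(rhs_contracting, i) || contains(rhs_batch, i)) continue;
    if (!is_lhs) mapping[i] = out;
    ++out;
  }
  return mapping;
}

int main() {
  // Shapes from DotTestBatch: lhs f32[4,128,2,8], rhs f32[4,128,2,8],
  // batch dims {0,2}, contracting dims {3}, output f32[4,2,128,128].
  auto lhs_map = MapOperandDimToResultDim(
      /*is_lhs=*/true, /*lhs_rank=*/4, /*rhs_rank=*/4,
      /*lhs_batch=*/{0, 2}, /*rhs_batch=*/{0, 2},
      /*lhs_contracting=*/{3}, /*rhs_contracting=*/{3});
  for (int64_t dim = 0; dim < 4; ++dim) {
    auto it = lhs_map.find(dim);
    if (it != lhs_map.end()) {
      std::cout << "lhs dim " << dim << " -> output dim " << it->second << "\n";
    }
  }
  return 0;
}

With these shapes, lhs dimensions 0, 2, and 1 map to output dimensions 0, 1, and 2, and the contracting dimension 3 has no output dimension. That is why DotTestBatch, which binds only lhs dimension 0 as dynamic, expects only output dimension 0 to pick up size_param.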
@@ -344,6 +344,45 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) {
   EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
 }

+TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
+  auto builder = HloComputation::Builder(TestName());
+  auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
+  auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
+  auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
+
+  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, lhs_shape, "A"));
+  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, rhs_shape, "B"));
+  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/2, scalar_shape_, "size_param"));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(3);
+  dot_dnums.add_rhs_contracting_dimensions(3);
+  dot_dnums.add_lhs_batch_dimensions(0);
+  dot_dnums.add_lhs_batch_dimensions(2);
+  dot_dnums.add_rhs_batch_dimensions(0);
+  dot_dnums.add_rhs_batch_dimensions(2);
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
+                                HloTestBase::DefaultPrecisionConfig(2)));
+
+  module_->AddEntryComputation(builder.Build());
+
+  // Set up dynamic parameter binding for batch dimension.
+  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
+      DynamicParameterBinding::DynamicParameter{2, {}},
+      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
+
+  SCOPED_TRACE(module_->ToString());
+  TF_ASSERT_OK(RunInference());
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
+  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
+}
+
 TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
   auto builder = HloComputation::Builder(TestName());
   constexpr int xdim = 3;