[XLA]DynamicDimensionInference: Handle multi batch dimension
This cl: - Rework on dynamic dimension inference's DOT handling. - Correctly support multi batch dimension on DOT (found in BERT). - Add tests. PiperOrigin-RevId: 248741549
This commit is contained in:
parent
6d07c2e23f
commit
ac0f632592
@ -222,54 +222,84 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
|
Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
|
||||||
return ForEachOperandDynamicDimension(
|
return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
|
||||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
ShapeIndex operand_shape_index,
|
||||||
int64 operand_index, HloInstruction* dynamic_size) {
|
int64 operand_dimension,
|
||||||
HloInstruction* dot = hlo;
|
int64 operand_index,
|
||||||
const DotDimensionNumbers& dimension_numbers =
|
HloInstruction* dynamic_size) {
|
||||||
dot->dot_dimension_numbers();
|
// There are three types of dimensions in a dot:
|
||||||
// A map from the operand dimensions to result dimension.
|
// A. batch dims
|
||||||
absl::flat_hash_map<int64, int64> result_dim_mapping;
|
// B. contracting dims
|
||||||
int64 current_result_dims = 0;
|
// C. non-batch non-contracting dims.
|
||||||
std::unordered_set<int64> batch_dims(
|
// The output dimemsions of a dot has three parts with the following order:
|
||||||
dimension_numbers.rhs_batch_dimensions().begin(),
|
// [(type A), (lhs type C), (rhs type C)]
|
||||||
dimension_numbers.rhs_batch_dimensions().end());
|
//
|
||||||
|
// Note that both lhs and rhs have the same dimension sizes for batch,
|
||||||
|
// but the dimension index could be different.
|
||||||
|
//
|
||||||
|
// Given one dynamic input dimension, either lhs or rhs, we use a
|
||||||
|
// mapping to find the corresponding output dimension.
|
||||||
|
HloInstruction* dot = hlo;
|
||||||
|
const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
|
||||||
|
// A map from the operand dimensions to result dimension.
|
||||||
|
absl::flat_hash_map<int64, int64> result_dim_mapping;
|
||||||
|
int64 current_result_dims = 0;
|
||||||
|
|
||||||
for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
|
bool lhs = operand_index == 0;
|
||||||
result_dim_mapping[i] = current_result_dims++;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
|
// The first loop keep tracks of batch dimension. RHS and LHS could have
|
||||||
if (!absl::c_linear_search(
|
// diffrent batch dimension numbers.
|
||||||
dimension_numbers.lhs_contracting_dimensions(), i)) {
|
if (lhs) {
|
||||||
if (operand_index == 0) {
|
for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
|
||||||
result_dim_mapping[i] = current_result_dims;
|
result_dim_mapping[i] = current_result_dims++;
|
||||||
}
|
}
|
||||||
current_result_dims++;
|
} else {
|
||||||
}
|
for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
|
||||||
}
|
result_dim_mapping[i] = current_result_dims++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
|
// Handle dimensions in the lhs.
|
||||||
if (!absl::c_linear_search(
|
for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
|
||||||
dimension_numbers.rhs_contracting_dimensions(), i) &&
|
// Look for non-contracting and non-batching dimension.
|
||||||
!absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
|
if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
|
||||||
i)) {
|
i)) {
|
||||||
if (operand_index == 1) {
|
continue;
|
||||||
result_dim_mapping[i] = current_result_dims;
|
}
|
||||||
}
|
if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
|
||||||
current_result_dims++;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
if (lhs) {
|
||||||
|
result_dim_mapping[i] = current_result_dims;
|
||||||
|
}
|
||||||
|
current_result_dims++;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the operand dim is in the result shape. If so, add another
|
// Handle dimensions in the rhs.
|
||||||
// work item to trace that dimension.
|
for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
|
||||||
auto iter = result_dim_mapping.find(dimension);
|
// Look for non-contracting and non-batching dimension.
|
||||||
if (iter != result_dim_mapping.end()) {
|
if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
|
||||||
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
|
i)) {
|
||||||
}
|
continue;
|
||||||
|
}
|
||||||
|
if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!lhs) {
|
||||||
|
result_dim_mapping[i] = current_result_dims;
|
||||||
|
}
|
||||||
|
current_result_dims++;
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
// Check if the operand dim is in the result shape. If so, add another
|
||||||
});
|
// work item to trace that dimension.
|
||||||
|
auto iter = result_dim_mapping.find(operand_dimension);
|
||||||
|
if (iter != result_dim_mapping.end()) {
|
||||||
|
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
||||||
|
@ -344,6 +344,45 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) {
|
|||||||
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
|
||||||
|
auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
|
||||||
|
auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
|
||||||
|
|
||||||
|
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/0, lhs_shape, "A"));
|
||||||
|
auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/1, rhs_shape, "B"));
|
||||||
|
auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
/*parameter_number=*/2, scalar_shape_, "size_param"));
|
||||||
|
|
||||||
|
DotDimensionNumbers dot_dnums;
|
||||||
|
dot_dnums.add_lhs_contracting_dimensions(3);
|
||||||
|
dot_dnums.add_rhs_contracting_dimensions(3);
|
||||||
|
dot_dnums.add_lhs_batch_dimensions(0);
|
||||||
|
dot_dnums.add_lhs_batch_dimensions(2);
|
||||||
|
dot_dnums.add_rhs_batch_dimensions(0);
|
||||||
|
dot_dnums.add_rhs_batch_dimensions(2);
|
||||||
|
auto dot = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
|
||||||
|
HloTestBase::DefaultPrecisionConfig(2)));
|
||||||
|
|
||||||
|
module_->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
// Set up dynamic parameter binding for batch dimension.
|
||||||
|
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
|
||||||
|
DynamicParameterBinding::DynamicParameter{2, {}},
|
||||||
|
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
|
||||||
|
|
||||||
|
SCOPED_TRACE(module_->ToString());
|
||||||
|
TF_ASSERT_OK(RunInference());
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
|
||||||
|
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
|
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
constexpr int xdim = 3;
|
constexpr int xdim = 3;
|
||||||
|
Loading…
Reference in New Issue
Block a user