[XLA]DynamicDimensionInference: Handle multi batch dimension

This cl:
- Rework on dynamic dimension inference's DOT handling.
- Correctly support multi batch dimension on DOT (found in BERT).
- Add tests.

PiperOrigin-RevId: 248741549
This commit is contained in:
Yunxing Dai 2019-05-17 10:13:35 -07:00 committed by TensorFlower Gardener
parent 6d07c2e23f
commit ac0f632592
2 changed files with 112 additions and 43 deletions

View File

@ -222,54 +222,84 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
} }
Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
return ForEachOperandDynamicDimension( return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, ShapeIndex operand_shape_index,
int64 operand_index, HloInstruction* dynamic_size) { int64 operand_dimension,
HloInstruction* dot = hlo; int64 operand_index,
const DotDimensionNumbers& dimension_numbers = HloInstruction* dynamic_size) {
dot->dot_dimension_numbers(); // There are three types of dimensions in a dot:
// A map from the operand dimensions to result dimension. // A. batch dims
absl::flat_hash_map<int64, int64> result_dim_mapping; // B. contracting dims
int64 current_result_dims = 0; // C. non-batch non-contracting dims.
std::unordered_set<int64> batch_dims( // The output dimensions of a dot have three parts with the following order:
dimension_numbers.rhs_batch_dimensions().begin(), // [(type A), (lhs type C), (rhs type C)]
dimension_numbers.rhs_batch_dimensions().end()); //
// Note that both lhs and rhs have the same dimension sizes for batch,
// but the dimension index could be different.
//
// Given one dynamic input dimension, either lhs or rhs, we use a
// mapping to find the corresponding output dimension.
HloInstruction* dot = hlo;
const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
// A map from the operand dimensions to result dimension.
absl::flat_hash_map<int64, int64> result_dim_mapping;
int64 current_result_dims = 0;
for (int64 i : dimension_numbers.rhs_batch_dimensions()) { bool lhs = operand_index == 0;
result_dim_mapping[i] = current_result_dims++;
}
for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) { // The first loop keeps track of batch dimensions. RHS and LHS could have
if (!absl::c_linear_search( // different batch dimension numbers. if (lhs) {
dimension_numbers.lhs_contracting_dimensions(), i)) { if (lhs) {
if (operand_index == 0) { for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
result_dim_mapping[i] = current_result_dims; result_dim_mapping[i] = current_result_dims++;
} }
current_result_dims++; } else {
} for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
} result_dim_mapping[i] = current_result_dims++;
}
}
for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) { // Handle dimensions in the lhs.
if (!absl::c_linear_search( for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
dimension_numbers.rhs_contracting_dimensions(), i) && // Look for non-contracting and non-batching dimension.
!absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
i)) { i)) {
if (operand_index == 1) { continue;
result_dim_mapping[i] = current_result_dims; }
} if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
current_result_dims++; continue;
} }
} if (lhs) {
result_dim_mapping[i] = current_result_dims;
}
current_result_dims++;
}
// Check if the operand dim is in the result shape. If so, add another // Handle dimensions in the rhs.
// work item to trace that dimension. for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
auto iter = result_dim_mapping.find(dimension); // Look for non-contracting and non-batching dimension.
if (iter != result_dim_mapping.end()) { if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); i)) {
} continue;
}
if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
continue;
}
if (!lhs) {
result_dim_mapping[i] = current_result_dims;
}
current_result_dims++;
}
return Status::OK(); // Check if the operand dim is in the result shape. If so, add another
}); // work item to trace that dimension.
auto iter = result_dim_mapping.find(operand_dimension);
if (iter != result_dim_mapping.end()) {
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
}
return Status::OK();
});
} }
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {

View File

@ -344,6 +344,45 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) {
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
} }
TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
auto builder = HloComputation::Builder(TestName());
auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, lhs_shape, "A"));
auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, rhs_shape, "B"));
auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, scalar_shape_, "size_param"));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(3);
dot_dnums.add_rhs_contracting_dimensions(3);
dot_dnums.add_lhs_batch_dimensions(0);
dot_dnums.add_lhs_batch_dimensions(2);
dot_dnums.add_rhs_batch_dimensions(0);
dot_dnums.add_rhs_batch_dimensions(2);
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
HloTestBase::DefaultPrecisionConfig(2)));
module_->AddEntryComputation(builder.Build());
// Set up dynamic parameter binding for batch dimension.
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
DynamicParameterBinding::DynamicParameter{2, {}},
DynamicParameterBinding::DynamicDimension{0, {}, 0}));
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(RunInference());
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
}
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
auto builder = HloComputation::Builder(TestName()); auto builder = HloComputation::Builder(TestName());
constexpr int xdim = 3; constexpr int xdim = 3;