[XLA] DynamicDimensionInference: Handle multiple batch dimensions

This CL:
- Reworks dynamic dimension inference's DOT handling.
- Correctly supports multiple batch dimensions on DOT (as found in BERT).
- Adds tests.

PiperOrigin-RevId: 248741549
Yunxing Dai 2019-05-17 10:13:35 -07:00 committed by TensorFlower Gardener
parent 6d07c2e23f
commit ac0f632592
2 changed files with 112 additions and 43 deletions

tensorflow/compiler/xla/service/dynamic_dimension_inference.cc

@@ -222,48 +222,78 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
}

Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(hlo, [&](HloInstruction* operand,
                                                 ShapeIndex operand_shape_index,
                                                 int64 operand_dimension,
                                                 int64 operand_index,
                                                 HloInstruction* dynamic_size) {
    // There are three types of dimensions in a dot:
    // A. batch dims
    // B. contracting dims
    // C. non-batch non-contracting dims.
    // The output dimensions of a dot have three parts, in the following
    // order:
    // [(type A), (lhs type C), (rhs type C)]
    //
    // Note that lhs and rhs have the same dimension sizes for the batch
    // dims, but the dimension indices could be different.
    //
    // Given one dynamic input dimension, either lhs or rhs, we use a
    // mapping to find the corresponding output dimension.
    HloInstruction* dot = hlo;
    const DotDimensionNumbers& dimension_numbers = dot->dot_dimension_numbers();
    // A map from the operand dimensions to result dimensions.
    absl::flat_hash_map<int64, int64> result_dim_mapping;
    int64 current_result_dims = 0;

    bool lhs = operand_index == 0;

    // The first loop keeps track of the batch dimensions. RHS and LHS could
    // have different batch dimension numbers.
    if (lhs) {
      for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
        result_dim_mapping[i] = current_result_dims++;
      }
    } else {
      for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
        result_dim_mapping[i] = current_result_dims++;
      }
    }

    // Handle dimensions in the lhs.
    for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
      // Look for non-contracting and non-batching dimensions.
      if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
                                i)) {
        continue;
      }
      if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
        continue;
      }
      if (lhs) {
        result_dim_mapping[i] = current_result_dims;
      }
      current_result_dims++;
    }

    // Handle dimensions in the rhs.
    for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
      // Look for non-contracting and non-batching dimensions.
      if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
                                i)) {
        continue;
      }
      if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
        continue;
      }
      if (!lhs) {
        result_dim_mapping[i] = current_result_dims;
      }
      current_result_dims++;
    }

    // Check if the operand dim is in the result shape. If so, add another
    // work item to trace that dimension.
    auto iter = result_dim_mapping.find(operand_dimension);
    if (iter != result_dim_mapping.end()) {
      parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
    }

    return Status::OK();
  });
}
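
For reference, the mapping these loops build can be reproduced in a few lines of standalone C++17. This is a sketch for illustration only, not XLA code; ResultDimMapping and its parameter names are invented. It applies the same output ordering rule, [batch dims, lhs type C dims, rhs type C dims], to the shapes used in the DotTestBatch test below.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

namespace {

bool Contains(const std::vector<int64_t>& dims, int64_t dim) {
  return std::find(dims.begin(), dims.end(), dim) != dims.end();
}

// Maps the dimensions of one dot operand to their output positions, using
// the output order [batch dims, lhs non-batch non-contracting dims,
// rhs non-batch non-contracting dims]. Contracting dimensions have no
// output counterpart and are absent from the map.
std::map<int64_t, int64_t> ResultDimMapping(
    const std::vector<int64_t>& lhs_batch,
    const std::vector<int64_t>& rhs_batch,
    const std::vector<int64_t>& lhs_contracting,
    const std::vector<int64_t>& rhs_contracting, int64_t lhs_rank,
    int64_t rhs_rank, bool operand_is_lhs) {
  std::map<int64_t, int64_t> mapping;
  int64_t out = 0;
  // Batch dims come first; lhs and rhs may index them differently.
  for (int64_t i : operand_is_lhs ? lhs_batch : rhs_batch) {
    mapping[i] = out++;
  }
  // Then the lhs non-batch, non-contracting dims.
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (Contains(lhs_contracting, i) || Contains(lhs_batch, i)) continue;
    if (operand_is_lhs) mapping[i] = out;
    ++out;
  }
  // Finally the rhs non-batch, non-contracting dims.
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (Contains(rhs_contracting, i) || Contains(rhs_batch, i)) continue;
    if (!operand_is_lhs) mapping[i] = out;
    ++out;
  }
  return mapping;
}

}  // namespace

int main() {
  // Same dot as DotTestBatch below: lhs/rhs [4, 128, 2, 8], batch dims
  // {0, 2}, contracting dim {3}, output [4, 2, 128, 128].
  auto mapping = ResultDimMapping(/*lhs_batch=*/{0, 2}, /*rhs_batch=*/{0, 2},
                                  /*lhs_contracting=*/{3},
                                  /*rhs_contracting=*/{3}, /*lhs_rank=*/4,
                                  /*rhs_rank=*/4, /*operand_is_lhs=*/true);
  for (const auto& [operand_dim, result_dim] : mapping) {
    std::cout << "lhs dim " << operand_dim << " -> output dim " << result_dim
              << "\n";
  }
  // Prints: lhs dim 0 -> output dim 0, lhs dim 1 -> output dim 2,
  // lhs dim 2 -> output dim 1.
  return 0;
}

In particular, a dynamic size bound to lhs dimension 0 (a batch dimension) lands on output dimension 0, which is exactly what the new test asserts.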

tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc

@@ -344,6 +344,45 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) {
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}

TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
  auto builder = HloComputation::Builder(TestName());
  auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  auto output_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_rhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(2);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(output_shape, a_param, b_param, dot_dnums,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Set up dynamic parameter binding for batch dimension.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 2), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 3), nullptr);
}
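
As a cross-check of the batch-dimension mapping, a hypothetical variation of the binding above (not part of this commit) would bind the size to an rhs batch dimension instead; since rhs batch dims {0, 2} map to output dims {0, 1}, a dynamic rhs dimension 2 would surface at output dimension 1:

  // Hypothetical alternative (not in this commit): make rhs dimension 2
  // dynamic. Inference would then report the size at output dim 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 2}));
  // Expected: inference_->GetDynamicSize(dot, {}, 1) == size_param.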
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
auto builder = HloComputation::Builder(TestName());
constexpr int xdim = 3;