diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index f9e19493a86..c3d8df85b6c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -579,105 +579,119 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse // topological order. - for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - auto hlo = *inst_it; - auto adjust_hlo_output = [this, hlo, ¶meter_changed]( - const Shape& /* subshape */, - const ShapeIndex& index) { - auto output_type = OutputTypeAfterChange(hlo, index); - if (output_type != F32 && output_type != BF16) { - return; - } - PrimitiveType type = BF16; - for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - auto value_type = ValueTypeAfterChange(value); - if (value_type == BF16) { - continue; + while (true) { + bool any_change = false; + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + auto adjust_hlo_output = [&](const Shape& /* subshape */, + const ShapeIndex& index) { + auto output_type = OutputTypeAfterChange(hlo, index); + VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32") + << " for :" << hlo->ToString() << "\n"; + if (output_type != F32 && output_type != BF16) { + return; } - CHECK_EQ(value_type, F32); - type = F32; - break; - } - // In order to find aliases due to in-place operations, use - // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here, - // but this code works with HloModules that aren't ready yet to use - // HloAliasAnalysis (e.g., their computation graphs may not have been - // flattened yet). 
- for (const auto& operand_and_output_index : - HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) { - if (operand_and_output_index.second == index) { - const HloUse& operand = operand_and_output_index.first; - for (const auto* value : - dataflow_ - ->GetValueSet(hlo->operand(operand.operand_number), - operand.operand_index) - .values()) { - auto value_type = ValueTypeAfterChange(value); - if (value_type == BF16) { - continue; + PrimitiveType type = BF16; + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + VLOG(2) << "Adjust to F32 due to aliased dataflow value: " + << value->ToString() << "\n"; + CHECK_EQ(value_type, F32); + type = F32; + break; + } + // In order to find aliases due to in-place operations, use + // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here, + // but this code works with HloModules that aren't ready yet to use + // HloAliasAnalysis (e.g., their computation graphs may not have been + // flattened yet). + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) { + if (operand_and_output_index.second == index) { + const HloUse& operand = operand_and_output_index.first; + for (const auto* value : + dataflow_ + ->GetValueSet(hlo->operand(operand.operand_number), + operand.operand_index) + .values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + VLOG(2) << "Adjust to F32 due to InputOutPair: " + << value->ToString() << "\n"; + CHECK_EQ(value_type, F32); + type = F32; + break; } - CHECK_EQ(value_type, F32); - type = F32; - break; } } - } - // It's possible that a user has been changed from BF16 to F32 - // during this final adjustment pass, so we need to check - // AllUsersConsumeBF16() again. 
- if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { - type = F32; - } - if (type == F32) { - for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - // We rely on the fact that this adjustment works in reverse - // topological order. Adding the value to - // values_that_must_be_kept_as_f32_ will ensure the correctness - // of the adjustment for HLOs that will be processed later. - values_that_must_be_kept_as_f32_.insert(value); + // It's possible that a user has been changed from BF16 to F32 + // during this final adjustment pass, so we need to check + // AllUsersConsumeBF16() again. + if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { + VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n"; + type = F32; + } + if (type == F32) { + for (const auto* value : + dataflow_->GetValueSet(hlo, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the correctness + // of the adjustment for HLOs that will be processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + if (type != output_type) { + any_change = true; + AddToOrRemoveFromBF16ChangeSet(hlo, index, type); + VLOG(2) << "HloInstruction output at shape index " << index + << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": " + << hlo->ToString(); + if (hlo->opcode() == HloOpcode::kParameter) { + parameter_changed = true; + } + } + }; + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + if (hlo->opcode() == HloOpcode::kWhile) { + // We need to run on the while body and condition repeatedly until a + // fixed point is reached, i.e., the parameters do not change any more. 
+ // We may need more than one iteration because the while input and + // output alias each other, so changing one input parameter requires + // changing the corresponding output element and thus may transitively + // require changing another input parameter. A fixed point will be + // reached because the parameters can only be changed from BF16 to F32, + // not the other way around. + absl::flat_hash_set<const HloComputation*> visited_in_while; + while (ResolveInconsistencyOfAliasingBuffersHelper( + hlo->while_condition(), &visited_in_while) || + ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), + &visited_in_while)) { + visited_in_while.clear(); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + } + visited_computations->insert(visited_in_while.begin(), + visited_in_while.end()); + } else if (hlo->opcode() == HloOpcode::kFusion) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->fused_instructions_computation(), visited_computations); + } else if (hlo->opcode() == HloOpcode::kConditional) { + for (auto* branch : hlo->branch_computations()) { + ResolveInconsistencyOfAliasingBuffersHelper(branch, + visited_computations); } } - if (type != output_type) { - AddToOrRemoveFromBF16ChangeSet(hlo, index, type); - VLOG(2) << "HloInstruction output at shape index " << index - << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": " - << hlo->ToString(); - if (hlo->opcode() == HloOpcode::kParameter) { - parameter_changed = true; - } - } - }; - ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); - AdjustCalledComputationRoot(hlo); - if (hlo->opcode() == HloOpcode::kWhile) { - // We need to run on the while body and condition repeatedly until a fixed - // point is reached, i.e., the parameters do not change any more. 
We may - // need more than one iteration because the while input and output alias - // each other, so changing one input parameter requires changing the - // corresponding output element and thus may transitively require changing - // another input parameter. A fixed point will be reached because the - // parameters can only be changed from BF16 to F32, not the other way - // around. - absl::flat_hash_set visited_in_while; - while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), - &visited_in_while) || - ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), - &visited_in_while)) { - visited_in_while.clear(); - ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); - AdjustCalledComputationRoot(hlo); - } - visited_computations->insert(visited_in_while.begin(), - visited_in_while.end()); - } else if (hlo->opcode() == HloOpcode::kFusion) { - ResolveInconsistencyOfAliasingBuffersHelper( - hlo->fused_instructions_computation(), visited_computations); - } else if (hlo->opcode() == HloOpcode::kConditional) { - for (auto* branch : hlo->branch_computations()) { - ResolveInconsistencyOfAliasingBuffersHelper(branch, - visited_computations); - } + } + if (!any_change) { + break; } } // Now adjust parameters of called computations. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 9a898833373..bb99b6454bd 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -1182,4 +1182,87 @@ ENTRY main { EXPECT_FALSE(OutputsBF16(dus)); } +// This test demonstrates the need for invoking the ResolveAliasingBuffer +// multiple times via a fixed-point algorithm. The key was the aliasing of the +// two output buffers of the conditional, at subshape 0 (first element). 
This +// aliasing is not resolved until after the gte0 variable is already processed, +// triggering incorrect type for gte0 if not repeating the aliasing analysis. +TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) { + const string module_str = R"( +HloModule module + +%add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] { + x.1 = f32[4096,4096] parameter(0) + y.1 = f32[4096,4096] parameter(1) + ROOT dot1 = f32[4096,4096] dot(x.1, y.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +%add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] { + x.1 = f32[4096,4096] parameter(0) + y.1 = f32[4096,4096] parameter(1) + ROOT dot1 = f32[4096,4096] dot(x.1, y.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +%add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] { + x.1 = f32[4096,4096] parameter(0) + y.1 = f32[4096,4096] parameter(1) + ROOT dot1 = f32[4096,4096] dot(x.1, y.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +%add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] { + x.1 = f32[4096,4096] parameter(0) + y.1 = f32[4096,4096] parameter(1) + ROOT dot1 = f32[4096,4096] dot(x.1, y.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +true_branch { + true_param = f32[4096,4096] parameter(0) + constant.1 = f32[4096,4096] constant(0) + add0 = f32[4096,4096] fusion(true_param,true_param), kind=kLoop, calls=add.0 + constant.2 = f32[4096,4096] constant(0) + ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(true_param,add0,constant.2) +} + +false_branch { + false_param = f32[4096,4096] parameter(0) + add3 = f32[4096,4096] fusion(false_param,false_param), kind=kLoop, calls=add.1 + constant.1 = f32[4096,4096] constant(0) + ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(add3, add3,constant.1) +} + +ENTRY entry { + param0 = f32[4096,4096] parameter(0) + copy0 = f32[4096,4096] copy(param0) + param1 = pred[] parameter(1) + conditional = (f32[4096,4096], 
f32[4096,4096], f32[4096,4096]) conditional(param1, param0, copy0), + true_computation=true_branch, false_computation=false_branch + gte = f32[4096,4096] get-tuple-element(conditional), index=0 + gte1 = f32[4096,4096] get-tuple-element(conditional), index=1 + gte2 = f32[4096,4096] get-tuple-element(conditional), index=2 + add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2 + ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3 + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_TRUE(PropagatePrecision(module.get())); + VLOG(2) << module->ToString() << "\n"; + EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status() + .ok()); + auto gte = FindInstruction(module.get(), "gte"); + auto gte1 = FindInstruction(module.get(), "gte1"); + auto gte2 = FindInstruction(module.get(), "gte2"); + EXPECT_FALSE(OutputsBF16(gte)); + EXPECT_FALSE(OutputsBF16(gte1)); + EXPECT_TRUE(OutputsBF16(gte2)); +} + } // namespace xla