[XLA] Bug fix for bfloat16 propagation where fixed-point repetition of resolving aliasing operands are needed.
PiperOrigin-RevId: 340723186 Change-Id: I9e8769bd4e35b7e41d2feea9a7036cc7e2c2d303
This commit is contained in:
		
							parent
							
								
									9c510da34b
								
							
						
					
					
						commit
						4e9997d049
					
				@ -579,105 +579,119 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
 | 
			
		||||
  auto insts = computation->MakeInstructionPostOrder();
 | 
			
		||||
  // Do the adjustment on each instruction in the computation in reverse
 | 
			
		||||
  // topological order.
 | 
			
		||||
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
 | 
			
		||||
    auto hlo = *inst_it;
 | 
			
		||||
    auto adjust_hlo_output = [this, hlo, ¶meter_changed](
 | 
			
		||||
                                 const Shape& /* subshape */,
 | 
			
		||||
                                 const ShapeIndex& index) {
 | 
			
		||||
      auto output_type = OutputTypeAfterChange(hlo, index);
 | 
			
		||||
      if (output_type != F32 && output_type != BF16) {
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      PrimitiveType type = BF16;
 | 
			
		||||
      for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
 | 
			
		||||
        auto value_type = ValueTypeAfterChange(value);
 | 
			
		||||
        if (value_type == BF16) {
 | 
			
		||||
          continue;
 | 
			
		||||
  while (true) {
 | 
			
		||||
    bool any_change = false;
 | 
			
		||||
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
 | 
			
		||||
      auto hlo = *inst_it;
 | 
			
		||||
      auto adjust_hlo_output = [&](const Shape& /* subshape */,
 | 
			
		||||
                                   const ShapeIndex& index) {
 | 
			
		||||
        auto output_type = OutputTypeAfterChange(hlo, index);
 | 
			
		||||
        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
 | 
			
		||||
                << " for :" << hlo->ToString() << "\n";
 | 
			
		||||
        if (output_type != F32 && output_type != BF16) {
 | 
			
		||||
          return;
 | 
			
		||||
        }
 | 
			
		||||
        CHECK_EQ(value_type, F32);
 | 
			
		||||
        type = F32;
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      // In order to find aliases due to in-place operations, use
 | 
			
		||||
      // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
 | 
			
		||||
      // but this code works with HloModules that aren't ready yet to use
 | 
			
		||||
      // HloAliasAnalysis (e.g., their computation graphs may not have been
 | 
			
		||||
      // flattened yet).
 | 
			
		||||
      for (const auto& operand_and_output_index :
 | 
			
		||||
           HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
 | 
			
		||||
        if (operand_and_output_index.second == index) {
 | 
			
		||||
          const HloUse& operand = operand_and_output_index.first;
 | 
			
		||||
          for (const auto* value :
 | 
			
		||||
               dataflow_
 | 
			
		||||
                   ->GetValueSet(hlo->operand(operand.operand_number),
 | 
			
		||||
                                 operand.operand_index)
 | 
			
		||||
                   .values()) {
 | 
			
		||||
            auto value_type = ValueTypeAfterChange(value);
 | 
			
		||||
            if (value_type == BF16) {
 | 
			
		||||
              continue;
 | 
			
		||||
        PrimitiveType type = BF16;
 | 
			
		||||
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
 | 
			
		||||
          auto value_type = ValueTypeAfterChange(value);
 | 
			
		||||
          if (value_type == BF16) {
 | 
			
		||||
            continue;
 | 
			
		||||
          }
 | 
			
		||||
          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
 | 
			
		||||
                  << value->ToString() << "\n";
 | 
			
		||||
          CHECK_EQ(value_type, F32);
 | 
			
		||||
          type = F32;
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
        // In order to find aliases due to in-place operations, use
 | 
			
		||||
        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
 | 
			
		||||
        // but this code works with HloModules that aren't ready yet to use
 | 
			
		||||
        // HloAliasAnalysis (e.g., their computation graphs may not have been
 | 
			
		||||
        // flattened yet).
 | 
			
		||||
        for (const auto& operand_and_output_index :
 | 
			
		||||
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
 | 
			
		||||
          if (operand_and_output_index.second == index) {
 | 
			
		||||
            const HloUse& operand = operand_and_output_index.first;
 | 
			
		||||
            for (const auto* value :
 | 
			
		||||
                 dataflow_
 | 
			
		||||
                     ->GetValueSet(hlo->operand(operand.operand_number),
 | 
			
		||||
                                   operand.operand_index)
 | 
			
		||||
                     .values()) {
 | 
			
		||||
              auto value_type = ValueTypeAfterChange(value);
 | 
			
		||||
              if (value_type == BF16) {
 | 
			
		||||
                continue;
 | 
			
		||||
              }
 | 
			
		||||
              VLOG(2) << "Adjust to F32 due to InputOutPair: "
 | 
			
		||||
                      << value->ToString() << "\n";
 | 
			
		||||
              CHECK_EQ(value_type, F32);
 | 
			
		||||
              type = F32;
 | 
			
		||||
              break;
 | 
			
		||||
            }
 | 
			
		||||
            CHECK_EQ(value_type, F32);
 | 
			
		||||
            type = F32;
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // It's possible that a user has been changed from BF16 to F32
 | 
			
		||||
      // during this final adjustment pass, so we need to check
 | 
			
		||||
      // AllUsersConsumeBF16() again.
 | 
			
		||||
      if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
 | 
			
		||||
        type = F32;
 | 
			
		||||
      }
 | 
			
		||||
      if (type == F32) {
 | 
			
		||||
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
 | 
			
		||||
          // We rely on the fact that this adjustment works in reverse
 | 
			
		||||
          // topological order. Adding the value to
 | 
			
		||||
          // values_that_must_be_kept_as_f32_ will ensure the correctness
 | 
			
		||||
          // of the adjustment for HLOs that will be processed later.
 | 
			
		||||
          values_that_must_be_kept_as_f32_.insert(value);
 | 
			
		||||
        // It's possible that a user has been changed from BF16 to F32
 | 
			
		||||
        // during this final adjustment pass, so we need to check
 | 
			
		||||
        // AllUsersConsumeBF16() again.
 | 
			
		||||
        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
 | 
			
		||||
          VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
 | 
			
		||||
          type = F32;
 | 
			
		||||
        }
 | 
			
		||||
        if (type == F32) {
 | 
			
		||||
          for (const auto* value :
 | 
			
		||||
               dataflow_->GetValueSet(hlo, index).values()) {
 | 
			
		||||
            // We rely on the fact that this adjustment works in reverse
 | 
			
		||||
            // topological order. Adding the value to
 | 
			
		||||
            // values_that_must_be_kept_as_f32_ will ensure the correctness
 | 
			
		||||
            // of the adjustment for HLOs that will be processed later.
 | 
			
		||||
            values_that_must_be_kept_as_f32_.insert(value);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        if (type != output_type) {
 | 
			
		||||
          any_change = true;
 | 
			
		||||
          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
 | 
			
		||||
          VLOG(2) << "HloInstruction output at shape index " << index
 | 
			
		||||
                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
 | 
			
		||||
                  << hlo->ToString();
 | 
			
		||||
          if (hlo->opcode() == HloOpcode::kParameter) {
 | 
			
		||||
            parameter_changed = true;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      };
 | 
			
		||||
      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
 | 
			
		||||
      AdjustCalledComputationRoot(hlo);
 | 
			
		||||
      if (hlo->opcode() == HloOpcode::kWhile) {
 | 
			
		||||
        // We need to run on the while body and condition repeatedly until a
 | 
			
		||||
        // fixed point is reached, i.e., the parameters do not change any more.
 | 
			
		||||
        // We may need more than one iteration because the while input and
 | 
			
		||||
        // output alias each other, so changing one input parameter requires
 | 
			
		||||
        // changing the corresponding output element and thus may transitively
 | 
			
		||||
        // require changing another input parameter. A fixed point will be
 | 
			
		||||
        // reached because the parameters can only be changed from BF16 to F32,
 | 
			
		||||
        // not the other way around.
 | 
			
		||||
        absl::flat_hash_set<const HloComputation*> visited_in_while;
 | 
			
		||||
        while (ResolveInconsistencyOfAliasingBuffersHelper(
 | 
			
		||||
                   hlo->while_condition(), &visited_in_while) ||
 | 
			
		||||
               ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
 | 
			
		||||
                                                           &visited_in_while)) {
 | 
			
		||||
          visited_in_while.clear();
 | 
			
		||||
          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
 | 
			
		||||
          AdjustCalledComputationRoot(hlo);
 | 
			
		||||
        }
 | 
			
		||||
        visited_computations->insert(visited_in_while.begin(),
 | 
			
		||||
                                     visited_in_while.end());
 | 
			
		||||
      } else if (hlo->opcode() == HloOpcode::kFusion) {
 | 
			
		||||
        ResolveInconsistencyOfAliasingBuffersHelper(
 | 
			
		||||
            hlo->fused_instructions_computation(), visited_computations);
 | 
			
		||||
      } else if (hlo->opcode() == HloOpcode::kConditional) {
 | 
			
		||||
        for (auto* branch : hlo->branch_computations()) {
 | 
			
		||||
          ResolveInconsistencyOfAliasingBuffersHelper(branch,
 | 
			
		||||
                                                      visited_computations);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      if (type != output_type) {
 | 
			
		||||
        AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
 | 
			
		||||
        VLOG(2) << "HloInstruction output at shape index " << index
 | 
			
		||||
                << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
 | 
			
		||||
                << hlo->ToString();
 | 
			
		||||
        if (hlo->opcode() == HloOpcode::kParameter) {
 | 
			
		||||
          parameter_changed = true;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
 | 
			
		||||
    AdjustCalledComputationRoot(hlo);
 | 
			
		||||
    if (hlo->opcode() == HloOpcode::kWhile) {
 | 
			
		||||
      // We need to run on the while body and condition repeatedly until a fixed
 | 
			
		||||
      // point is reached, i.e., the parameters do not change any more. We may
 | 
			
		||||
      // need more than one iteration because the while input and output alias
 | 
			
		||||
      // each other, so changing one input parameter requires changing the
 | 
			
		||||
      // corresponding output element and thus may transitively require changing
 | 
			
		||||
      // another input parameter. A fixed point will be reached because the
 | 
			
		||||
      // parameters can only be changed from BF16 to F32, not the other way
 | 
			
		||||
      // around.
 | 
			
		||||
      absl::flat_hash_set<const HloComputation*> visited_in_while;
 | 
			
		||||
      while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
 | 
			
		||||
                                                         &visited_in_while) ||
 | 
			
		||||
             ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
 | 
			
		||||
                                                         &visited_in_while)) {
 | 
			
		||||
        visited_in_while.clear();
 | 
			
		||||
        ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
 | 
			
		||||
        AdjustCalledComputationRoot(hlo);
 | 
			
		||||
      }
 | 
			
		||||
      visited_computations->insert(visited_in_while.begin(),
 | 
			
		||||
                                   visited_in_while.end());
 | 
			
		||||
    } else if (hlo->opcode() == HloOpcode::kFusion) {
 | 
			
		||||
      ResolveInconsistencyOfAliasingBuffersHelper(
 | 
			
		||||
          hlo->fused_instructions_computation(), visited_computations);
 | 
			
		||||
    } else if (hlo->opcode() == HloOpcode::kConditional) {
 | 
			
		||||
      for (auto* branch : hlo->branch_computations()) {
 | 
			
		||||
        ResolveInconsistencyOfAliasingBuffersHelper(branch,
 | 
			
		||||
                                                    visited_computations);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (!any_change) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // Now adjust parameters of called computations.
 | 
			
		||||
 | 
			
		||||
@ -1182,4 +1182,87 @@ ENTRY main {
 | 
			
		||||
  EXPECT_FALSE(OutputsBF16(dus));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This test demonstrates the need for invoking the ResolveAliasingBuffer
 | 
			
		||||
// multiple times via a fixed-point algorithm. The key was the aliasing of the
 | 
			
		||||
// two output buffers of the conditional, at subshape 0 (first element). This
 | 
			
		||||
// aliasing is not resolved until after the gte0 variale is already processed,
 | 
			
		||||
// triggering incorrect type for gte0 if not repeating the aliasing analysis.
 | 
			
		||||
TEST_F(BFloat16PropagationTest, ConditionalGTEWithFusion) {
 | 
			
		||||
  const string module_str = R"(
 | 
			
		||||
HloModule module
 | 
			
		||||
 | 
			
		||||
%add.0 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
 | 
			
		||||
  x.1 = f32[4096,4096] parameter(0)
 | 
			
		||||
  y.1 = f32[4096,4096] parameter(1)
 | 
			
		||||
  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
 | 
			
		||||
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%add.1 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
 | 
			
		||||
  x.1 = f32[4096,4096] parameter(0)
 | 
			
		||||
  y.1 = f32[4096,4096] parameter(1)
 | 
			
		||||
  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
 | 
			
		||||
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%add.2 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
 | 
			
		||||
  x.1 = f32[4096,4096] parameter(0)
 | 
			
		||||
  y.1 = f32[4096,4096] parameter(1)
 | 
			
		||||
  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
 | 
			
		||||
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
%add.3 (x: f32[4096,4096], y: f32[4096,4096]) -> f32[4096,4096] {
 | 
			
		||||
  x.1 = f32[4096,4096] parameter(0)
 | 
			
		||||
  y.1 = f32[4096,4096] parameter(1)
 | 
			
		||||
  ROOT dot1 = f32[4096,4096] dot(x.1, y.1),
 | 
			
		||||
    lhs_contracting_dims={1}, rhs_contracting_dims={0}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
true_branch {
 | 
			
		||||
  true_param = f32[4096,4096] parameter(0)
 | 
			
		||||
  constant.1 = f32[4096,4096] constant(0)
 | 
			
		||||
  add0 = f32[4096,4096] fusion(true_param,true_param), kind=kLoop, calls=add.0
 | 
			
		||||
  constant.2 = f32[4096,4096] constant(0)
 | 
			
		||||
  ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(true_param,add0,constant.2)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
false_branch {
 | 
			
		||||
  false_param = f32[4096,4096] parameter(0)
 | 
			
		||||
  add3 = f32[4096,4096] fusion(false_param,false_param), kind=kLoop, calls=add.1
 | 
			
		||||
  constant.1 = f32[4096,4096] constant(0)
 | 
			
		||||
  ROOT tuple.2 = (f32[4096,4096], f32[4096,4096], f32[]) tuple(add3, add3,constant.1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ENTRY entry {
 | 
			
		||||
  param0 = f32[4096,4096] parameter(0)
 | 
			
		||||
  copy0 = f32[4096,4096] copy(param0)
 | 
			
		||||
  param1 = pred[] parameter(1)
 | 
			
		||||
  conditional = (f32[4096,4096], f32[4096,4096], f32[4096,4096]) conditional(param1, param0, copy0),
 | 
			
		||||
    true_computation=true_branch, false_computation=false_branch
 | 
			
		||||
  gte = f32[4096,4096] get-tuple-element(conditional), index=0
 | 
			
		||||
  gte1 = f32[4096,4096] get-tuple-element(conditional), index=1
 | 
			
		||||
  gte2 = f32[4096,4096] get-tuple-element(conditional), index=2
 | 
			
		||||
  add2 = f32[4096,4096] fusion(gte, gte1), kind=kLoop, calls=add.2
 | 
			
		||||
  ROOT add3 = f32[4096,4096] fusion(add2, gte2), kind=kLoop, calls=add.3
 | 
			
		||||
  }
 | 
			
		||||
)";
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(module_str));
 | 
			
		||||
  EXPECT_TRUE(PropagatePrecision(module.get()));
 | 
			
		||||
  VLOG(2) << module->ToString() << "\n";
 | 
			
		||||
  EXPECT_TRUE(HloVerifier(/*layout_sensitive=*/false,
 | 
			
		||||
                          /*allow_mixed_precision=*/true)
 | 
			
		||||
                  .Run(module.get())
 | 
			
		||||
                  .status()
 | 
			
		||||
                  .ok());
 | 
			
		||||
  auto gte = FindInstruction(module.get(), "gte");
 | 
			
		||||
  auto gte1 = FindInstruction(module.get(), "gte1");
 | 
			
		||||
  auto gte2 = FindInstruction(module.get(), "gte2");
 | 
			
		||||
  EXPECT_FALSE(OutputsBF16(gte));
 | 
			
		||||
  EXPECT_FALSE(OutputsBF16(gte1));
 | 
			
		||||
  EXPECT_TRUE(OutputsBF16(gte2));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user