TPU layout method cleanup by removing one parameter.
PiperOrigin-RevId: 360292316 Change-Id: Ib8332ba5b0d7576cfb02773f9974a48ead7ae47c
This commit is contained in:
parent
6441ec4b0c
commit
1038045628
@ -345,11 +345,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns whether we should avoid changing the precision of inst regardless of
|
||||
// the producers and users.
|
||||
bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) {
|
||||
bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
|
||||
const HloInstruction* inst) {
|
||||
if (inst->opcode() == HloOpcode::kFusion &&
|
||||
inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
|
||||
return ShouldKeepPrecisionUnchanged(
|
||||
@ -358,14 +355,12 @@ bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) {
|
||||
// Do not change precision for side-effecting instructions, control flow, and
|
||||
// bitcast-convert, because this pass might break the interfaces or
|
||||
// assumptions for them.
|
||||
return inst->opcode() == HloOpcode::kCustomCall || //
|
||||
inst->opcode() == HloOpcode::kCall || //
|
||||
inst->opcode() == HloOpcode::kBitcastConvert || //
|
||||
return inst->opcode() == HloOpcode::kCustomCall ||
|
||||
inst->opcode() == HloOpcode::kCall ||
|
||||
inst->opcode() == HloOpcode::kBitcastConvert ||
|
||||
inst->HasSideEffectNoRecurse();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
|
||||
bool skip_parameters) {
|
||||
// We handle any fusion computation, while body/condition or conditional
|
||||
|
@ -72,6 +72,10 @@ class BFloat16Propagation : public HloModulePass {
|
||||
// (precision reductions were added).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Returns whether we should avoid changing the precision of inst regardless
|
||||
// of the producers and users.
|
||||
virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst);
|
||||
|
||||
private:
|
||||
// ***************************
|
||||
// Function called and state produced by the forward analysis pass (from
|
||||
|
Loading…
x
Reference in New Issue
Block a user