Modify Hash() function of HloComputation and HloInstruction to prevent non-termination from infinite recursive calls.
PiperOrigin-RevId: 225412890
This commit is contained in:
parent
d501a62aae
commit
ec702337b8
@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const {
|
|||||||
return eq(root_instruction(), other.root_instruction());
|
return eq(root_instruction(), other.root_instruction());
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64 HloComputation::Hash() const { return root_instruction()->Hash(); }
|
|
||||||
|
|
||||||
Status HloComputation::ReplaceWithNewInstruction(
|
Status HloComputation::ReplaceWithNewInstruction(
|
||||||
HloInstruction* old_instruction,
|
HloInstruction* old_instruction,
|
||||||
std::unique_ptr<HloInstruction> new_instruction) {
|
std::unique_ptr<HloInstruction> new_instruction) {
|
||||||
|
@ -264,12 +264,6 @@ class HloComputation {
|
|||||||
// Return whether `*this` and `other` are functionally equivalent.
|
// Return whether `*this` and `other` are functionally equivalent.
|
||||||
bool operator==(const HloComputation& other) const;
|
bool operator==(const HloComputation& other) const;
|
||||||
|
|
||||||
// Generates a hash value of an HLO computation. Hash considers
|
|
||||||
// information on opcode, shape, operands, and typically a root instruction.
|
|
||||||
// This function returns the same hash value for equivalent HLO computations,
|
|
||||||
// with respect to HloInstruction::Identical() method.
|
|
||||||
uint64 Hash() const;
|
|
||||||
|
|
||||||
// Replaces old instruction with newly created instruction. Removes old
|
// Replaces old instruction with newly created instruction. Removes old
|
||||||
// instruction from computation. Updates uses and root instruction.
|
// instruction from computation. Updates uses and root instruction.
|
||||||
Status ReplaceWithNewInstruction(
|
Status ReplaceWithNewInstruction(
|
||||||
|
@ -1761,7 +1761,12 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64 HloInstruction::Hash() const {
|
static uint64 HashOperand(const HloInstruction* hlo) {
|
||||||
|
return ShapeUtil::Hash(hlo->shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64 HloInstruction::Hash(
|
||||||
|
const std::function<uint64(const HloInstruction*)>& hash_operand) const {
|
||||||
using tensorflow::Hash64Combine;
|
using tensorflow::Hash64Combine;
|
||||||
|
|
||||||
uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
|
uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
|
||||||
@ -1770,7 +1775,7 @@ uint64 HloInstruction::Hash() const {
|
|||||||
if (!IsCrossModuleAllReduce()) {
|
if (!IsCrossModuleAllReduce()) {
|
||||||
if (!operands().empty()) {
|
if (!operands().empty()) {
|
||||||
for (size_t i = 0; i < operands().size(); ++i) {
|
for (size_t i = 0; i < operands().size(); ++i) {
|
||||||
hash_value = Hash64Combine(hash_value, operand(i)->Hash());
|
hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1779,6 +1784,11 @@ uint64 HloInstruction::Hash() const {
|
|||||||
return hash_value;
|
return hash_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint64 HloInstruction::Hash() const {
|
||||||
|
// Use HashOperand as an argument to prevent non-termination.
|
||||||
|
return Hash(HashOperand);
|
||||||
|
}
|
||||||
|
|
||||||
uint64 HloInstruction::InnerHash() const { return 13; }
|
uint64 HloInstruction::InnerHash() const { return 13; }
|
||||||
|
|
||||||
void HloInstruction::RemoveUser(HloInstruction* user) {
|
void HloInstruction::RemoveUser(HloInstruction* user) {
|
||||||
|
@ -909,6 +909,14 @@ class HloInstruction {
|
|||||||
// information on opcode, shape, operands, and typically a root instruction.
|
// information on opcode, shape, operands, and typically a root instruction.
|
||||||
// This function returns the same hash value for equivalent HLO instructions,
|
// This function returns the same hash value for equivalent HLO instructions,
|
||||||
// with respect to HloInstruction::Identical() method.
|
// with respect to HloInstruction::Identical() method.
|
||||||
|
//
|
||||||
|
// Uses hash_operand function to compute hash values of its operands.
|
||||||
|
// At the very top level, hash_operand should be non-recursive to prevent
|
||||||
|
// non-termination.
|
||||||
|
uint64 Hash(
|
||||||
|
const std::function<uint64(const HloInstruction*)>& hash_operand) const;
|
||||||
|
|
||||||
|
// Calls the above method with non-recursive hash_operand function.
|
||||||
uint64 Hash() const;
|
uint64 Hash() const;
|
||||||
|
|
||||||
// Returns whether the instruction has a constant operand.
|
// Returns whether the instruction has a constant operand.
|
||||||
|
@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath(
|
|||||||
other.fused_instructions_computation());
|
other.fused_instructions_computation());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static uint64 HashOperandRecursive(const HloInstruction* hlo) {
|
||||||
|
return hlo->Hash(HashOperandRecursive);
|
||||||
|
}
|
||||||
|
|
||||||
uint64 HloFusionInstruction::InnerHash() const {
|
uint64 HloFusionInstruction::InnerHash() const {
|
||||||
return fused_instructions_computation()->Hash();
|
// Use HashOperandRecursive to recursively compute hash on inner operands.
|
||||||
|
return fused_instructions_computation()->root_instruction()->Hash(
|
||||||
|
HashOperandRecursive);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
|
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
|
||||||
|
@ -136,7 +136,9 @@ class HloModule {
|
|||||||
// information on opcode, shape, operands, and typically a root instruction.
|
// information on opcode, shape, operands, and typically a root instruction.
|
||||||
// This function returns the same hash value for equivalent HLO modules,
|
// This function returns the same hash value for equivalent HLO modules,
|
||||||
// with respect to HloInstruction::Identical() method.
|
// with respect to HloInstruction::Identical() method.
|
||||||
uint64 Hash() const { return entry_computation()->Hash(); }
|
uint64 Hash() const {
|
||||||
|
return entry_computation()->root_instruction()->Hash();
|
||||||
|
}
|
||||||
|
|
||||||
// Gets the computations in this module.
|
// Gets the computations in this module.
|
||||||
//
|
//
|
||||||
|
Loading…
Reference in New Issue
Block a user