Modify Hash() function of HloComputation and HloInstruction to prevent non-termination from infinite recursive calls.
PiperOrigin-RevId: 225412890
This commit is contained in:
parent
d501a62aae
commit
ec702337b8
tensorflow/compiler/xla/service
@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const {
|
||||
return eq(root_instruction(), other.root_instruction());
|
||||
}
|
||||
|
||||
uint64 HloComputation::Hash() const { return root_instruction()->Hash(); }
|
||||
|
||||
Status HloComputation::ReplaceWithNewInstruction(
|
||||
HloInstruction* old_instruction,
|
||||
std::unique_ptr<HloInstruction> new_instruction) {
|
||||
|
@ -264,12 +264,6 @@ class HloComputation {
|
||||
// Return whether `*this` and `other` are functionally equivalent.
|
||||
bool operator==(const HloComputation& other) const;
|
||||
|
||||
// Generates a hash value of an HLO computation. Hash considers
|
||||
// information on opcode, shape, operands, and typically a root instruction.
|
||||
// This function returns the same hash value for equivalent HLO computations,
|
||||
// with respect to HloInstruction::Identical() method.
|
||||
uint64 Hash() const;
|
||||
|
||||
// Replaces old instruction with newly created instruction. Removes old
|
||||
// instruction from computation. Updates uses and root instruction.
|
||||
Status ReplaceWithNewInstruction(
|
||||
|
@ -1761,7 +1761,12 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64 HloInstruction::Hash() const {
|
||||
static uint64 HashOperand(const HloInstruction* hlo) {
|
||||
return ShapeUtil::Hash(hlo->shape());
|
||||
}
|
||||
|
||||
uint64 HloInstruction::Hash(
|
||||
const std::function<uint64(const HloInstruction*)>& hash_operand) const {
|
||||
using tensorflow::Hash64Combine;
|
||||
|
||||
uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
|
||||
@ -1770,7 +1775,7 @@ uint64 HloInstruction::Hash() const {
|
||||
if (!IsCrossModuleAllReduce()) {
|
||||
if (!operands().empty()) {
|
||||
for (size_t i = 0; i < operands().size(); ++i) {
|
||||
hash_value = Hash64Combine(hash_value, operand(i)->Hash());
|
||||
hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1779,6 +1784,11 @@ uint64 HloInstruction::Hash() const {
|
||||
return hash_value;
|
||||
}
|
||||
|
||||
uint64 HloInstruction::Hash() const {
|
||||
// Use HashOperand as an argument to prevent non-termination.
|
||||
return Hash(HashOperand);
|
||||
}
|
||||
|
||||
uint64 HloInstruction::InnerHash() const { return 13; }
|
||||
|
||||
void HloInstruction::RemoveUser(HloInstruction* user) {
|
||||
|
@ -909,6 +909,14 @@ class HloInstruction {
|
||||
// information on opcode, shape, operands, and typically a root instruction.
|
||||
// This function returns the same hash value for equivalent HLO instructions,
|
||||
// with respect to HloInstruction::Identical() method.
|
||||
//
|
||||
// Uses hash_operand function to compute hash values of its operands.
|
||||
// At the very top level, hash_operand should be non-recursive to prevent
|
||||
// non-termination.
|
||||
uint64 Hash(
|
||||
const std::function<uint64(const HloInstruction*)>& hash_operand) const;
|
||||
|
||||
// Calls the above method with non-recursive hash_operand function.
|
||||
uint64 Hash() const;
|
||||
|
||||
// Returns whether the instruction has a constant operand.
|
||||
|
@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath(
|
||||
other.fused_instructions_computation());
|
||||
}
|
||||
|
||||
static uint64 HashOperandRecursive(const HloInstruction* hlo) {
|
||||
return hlo->Hash(HashOperandRecursive);
|
||||
}
|
||||
|
||||
uint64 HloFusionInstruction::InnerHash() const {
|
||||
return fused_instructions_computation()->Hash();
|
||||
// Use HashOperandRecursive to recursively compute hash on inner operands.
|
||||
return fused_instructions_computation()->root_instruction()->Hash(
|
||||
HashOperandRecursive);
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
|
||||
|
@ -136,7 +136,9 @@ class HloModule {
|
||||
// information on opcode, shape, operands, and typically a root instruction.
|
||||
// This function returns the same hash value for equivalent HLO modules,
|
||||
// with respect to HloInstruction::Identical() method.
|
||||
uint64 Hash() const { return entry_computation()->Hash(); }
|
||||
uint64 Hash() const {
|
||||
return entry_computation()->root_instruction()->Hash();
|
||||
}
|
||||
|
||||
// Gets the computations in this module.
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user