[XLA] Avoid hash collisions in CseHash.
PiperOrigin-RevId: 296143190 Change-Id: I16cef346311b419f04911c241462fa55a5aa04ad
This commit is contained in:
parent
d32328e24f
commit
a6ec8dadc4
@ -3434,6 +3434,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:hash",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
],
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/hash.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -96,17 +97,54 @@ StatusOr<bool> CombineConstants(HloComputation* computation,
|
||||
// share the exact same set of operands.
|
||||
int64 CseHash(const HloInstruction* instruction) {
|
||||
int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
|
||||
auto c_hash = [](auto c) {
|
||||
return tensorflow::Hash64(reinterpret_cast<const char*>(c.data()),
|
||||
c.size() * sizeof(c[0]));
|
||||
};
|
||||
auto proto_hash = [](auto proto) {
|
||||
return std::hash<int64>{}(proto.ByteSizeLong());
|
||||
};
|
||||
hash = tensorflow::Hash64Combine(
|
||||
hash, instruction->opcode() == HloOpcode::kGetTupleElement
|
||||
? instruction->tuple_index()
|
||||
: -1);
|
||||
: c_hash(instruction->shape().dimensions()));
|
||||
for (auto operand : instruction->operands()) {
|
||||
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kConstant) {
|
||||
hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
|
||||
for (auto c : instruction->called_computations()) {
|
||||
hash = tensorflow::Hash64Combine(
|
||||
hash, std::hash<int64>()(
|
||||
static_cast<int64>(c->root_instruction()->opcode())));
|
||||
}
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kConstant:
|
||||
return tensorflow::Hash64Combine(hash, instruction->literal().Hash());
|
||||
case HloOpcode::kSlice:
|
||||
return tensorflow::Hash64Combine(
|
||||
tensorflow::Hash64Combine(hash, c_hash(instruction->slice_starts())),
|
||||
c_hash(instruction->slice_strides()));
|
||||
case HloOpcode::kPad:
|
||||
return tensorflow::Hash64Combine(
|
||||
hash, proto_hash(instruction->padding_config()));
|
||||
case HloOpcode::kDot:
|
||||
return tensorflow::Hash64Combine(
|
||||
hash, proto_hash(instruction->dot_dimension_numbers()));
|
||||
case HloOpcode::kConvolution:
|
||||
return tensorflow::Hash64Combine(
|
||||
tensorflow::Hash64Combine(
|
||||
hash, proto_hash(instruction->convolution_dimension_numbers())),
|
||||
proto_hash(instruction->window()));
|
||||
case HloOpcode::kReduceWindow:
|
||||
return tensorflow::Hash64Combine(hash, proto_hash(instruction->window()));
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kIota:
|
||||
case HloOpcode::kReduce:
|
||||
return tensorflow::Hash64Combine(hash, c_hash(instruction->dimensions()));
|
||||
default:
|
||||
return hash;
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user