[XLA] Avoid hash collisions in CseHash.

PiperOrigin-RevId: 296143190
Change-Id: I16cef346311b419f04911c241462fa55a5aa04ad
This commit is contained in:
Blake Hechtman 2020-02-19 23:58:40 -08:00 committed by TensorFlower Gardener
parent d32328e24f
commit a6ec8dadc4
2 changed files with 43 additions and 4 deletions

View File

@ -3434,6 +3434,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core/platform:hash",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
namespace xla {
@ -96,17 +97,54 @@ StatusOr<bool> CombineConstants(HloComputation* computation,
// share the exact same set of operands.
int64 CseHash(const HloInstruction* instruction) {
int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
auto c_hash = [](auto c) {
return tensorflow::Hash64(reinterpret_cast<const char*>(c.data()),
c.size() * sizeof(c[0]));
};
auto proto_hash = [](auto proto) {
return std::hash<int64>{}(proto.ByteSizeLong());
};
hash = tensorflow::Hash64Combine(
hash, instruction->opcode() == HloOpcode::kGetTupleElement
? instruction->tuple_index()
: -1);
: c_hash(instruction->shape().dimensions()));
for (auto operand : instruction->operands()) {
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
}
if (instruction->opcode() == HloOpcode::kConstant) {
hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
for (auto c : instruction->called_computations()) {
hash = tensorflow::Hash64Combine(
hash, std::hash<int64>()(
static_cast<int64>(c->root_instruction()->opcode())));
}
switch (instruction->opcode()) {
case HloOpcode::kConstant:
return tensorflow::Hash64Combine(hash, instruction->literal().Hash());
case HloOpcode::kSlice:
return tensorflow::Hash64Combine(
tensorflow::Hash64Combine(hash, c_hash(instruction->slice_starts())),
c_hash(instruction->slice_strides()));
case HloOpcode::kPad:
return tensorflow::Hash64Combine(
hash, proto_hash(instruction->padding_config()));
case HloOpcode::kDot:
return tensorflow::Hash64Combine(
hash, proto_hash(instruction->dot_dimension_numbers()));
case HloOpcode::kConvolution:
return tensorflow::Hash64Combine(
tensorflow::Hash64Combine(
hash, proto_hash(instruction->convolution_dimension_numbers())),
proto_hash(instruction->window()));
case HloOpcode::kReduceWindow:
return tensorflow::Hash64Combine(hash, proto_hash(instruction->window()));
case HloOpcode::kConcatenate:
case HloOpcode::kBroadcast:
case HloOpcode::kTranspose:
case HloOpcode::kIota:
case HloOpcode::kReduce:
return tensorflow::Hash64Combine(hash, c_hash(instruction->dimensions()));
default:
return hash;
}
return hash;
}
} // namespace