[XLA] Avoid hash collisions in CseHash.
PiperOrigin-RevId: 296143190 Change-Id: I16cef346311b419f04911c241462fa55a5aa04ad
This commit is contained in:
parent
d32328e24f
commit
a6ec8dadc4
@ -3434,6 +3434,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:hash",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
],
|
],
|
||||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
#include "tensorflow/core/platform/hash.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -96,17 +97,54 @@ StatusOr<bool> CombineConstants(HloComputation* computation,
|
|||||||
// share the exact same set of operands.
|
// share the exact same set of operands.
|
||||||
int64 CseHash(const HloInstruction* instruction) {
|
int64 CseHash(const HloInstruction* instruction) {
|
||||||
int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
|
int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
|
||||||
|
auto c_hash = [](auto c) {
|
||||||
|
return tensorflow::Hash64(reinterpret_cast<const char*>(c.data()),
|
||||||
|
c.size() * sizeof(c[0]));
|
||||||
|
};
|
||||||
|
auto proto_hash = [](auto proto) {
|
||||||
|
return std::hash<int64>{}(proto.ByteSizeLong());
|
||||||
|
};
|
||||||
hash = tensorflow::Hash64Combine(
|
hash = tensorflow::Hash64Combine(
|
||||||
hash, instruction->opcode() == HloOpcode::kGetTupleElement
|
hash, instruction->opcode() == HloOpcode::kGetTupleElement
|
||||||
? instruction->tuple_index()
|
? instruction->tuple_index()
|
||||||
: -1);
|
: c_hash(instruction->shape().dimensions()));
|
||||||
for (auto operand : instruction->operands()) {
|
for (auto operand : instruction->operands()) {
|
||||||
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
|
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
|
||||||
}
|
}
|
||||||
if (instruction->opcode() == HloOpcode::kConstant) {
|
for (auto c : instruction->called_computations()) {
|
||||||
hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
|
hash = tensorflow::Hash64Combine(
|
||||||
|
hash, std::hash<int64>()(
|
||||||
|
static_cast<int64>(c->root_instruction()->opcode())));
|
||||||
|
}
|
||||||
|
switch (instruction->opcode()) {
|
||||||
|
case HloOpcode::kConstant:
|
||||||
|
return tensorflow::Hash64Combine(hash, instruction->literal().Hash());
|
||||||
|
case HloOpcode::kSlice:
|
||||||
|
return tensorflow::Hash64Combine(
|
||||||
|
tensorflow::Hash64Combine(hash, c_hash(instruction->slice_starts())),
|
||||||
|
c_hash(instruction->slice_strides()));
|
||||||
|
case HloOpcode::kPad:
|
||||||
|
return tensorflow::Hash64Combine(
|
||||||
|
hash, proto_hash(instruction->padding_config()));
|
||||||
|
case HloOpcode::kDot:
|
||||||
|
return tensorflow::Hash64Combine(
|
||||||
|
hash, proto_hash(instruction->dot_dimension_numbers()));
|
||||||
|
case HloOpcode::kConvolution:
|
||||||
|
return tensorflow::Hash64Combine(
|
||||||
|
tensorflow::Hash64Combine(
|
||||||
|
hash, proto_hash(instruction->convolution_dimension_numbers())),
|
||||||
|
proto_hash(instruction->window()));
|
||||||
|
case HloOpcode::kReduceWindow:
|
||||||
|
return tensorflow::Hash64Combine(hash, proto_hash(instruction->window()));
|
||||||
|
case HloOpcode::kConcatenate:
|
||||||
|
case HloOpcode::kBroadcast:
|
||||||
|
case HloOpcode::kTranspose:
|
||||||
|
case HloOpcode::kIota:
|
||||||
|
case HloOpcode::kReduce:
|
||||||
|
return tensorflow::Hash64Combine(hash, c_hash(instruction->dimensions()));
|
||||||
|
default:
|
||||||
|
return hash;
|
||||||
}
|
}
|
||||||
return hash;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user