From a6ec8dadc4a8fb5d3df6577cb903483f2582c0a8 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 19 Feb 2020 23:58:40 -0800 Subject: [PATCH] [XLA] Avoid hash collisions in CseHash. PiperOrigin-RevId: 296143190 Change-Id: I16cef346311b419f04911c241462fa55a5aa04ad --- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/hlo_cse.cc | 46 ++++++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 34fd40f11d8..bb6219eb584 100755 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3434,6 +3434,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core/platform:hash", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index a58fcf4460a..373f4f12ba4 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/hash.h" namespace xla { @@ -96,17 +97,54 @@ StatusOr CombineConstants(HloComputation* computation, // share the exact same set of operands. int64 CseHash(const HloInstruction* instruction) { int64 hash = std::hash()(static_cast(instruction->opcode())); + auto c_hash = [](auto c) { + return tensorflow::Hash64(reinterpret_cast(c.data()), + c.size() * sizeof(c[0])); + }; + auto proto_hash = [](auto proto) { + return std::hash{}(proto.ByteSizeLong()); + }; hash = tensorflow::Hash64Combine( hash, instruction->opcode() == HloOpcode::kGetTupleElement ? instruction->tuple_index() - : -1); + : c_hash(instruction->shape().dimensions())); for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } - if (instruction->opcode() == HloOpcode::kConstant) { - hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + for (auto c : instruction->called_computations()) { + hash = tensorflow::Hash64Combine( + hash, std::hash()( + static_cast(c->root_instruction()->opcode()))); + } + switch (instruction->opcode()) { + case HloOpcode::kConstant: + return tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + case HloOpcode::kSlice: + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine(hash, c_hash(instruction->slice_starts())), + c_hash(instruction->slice_strides())); + case HloOpcode::kPad: + return tensorflow::Hash64Combine( + hash, proto_hash(instruction->padding_config())); + case HloOpcode::kDot: + return tensorflow::Hash64Combine( + hash, proto_hash(instruction->dot_dimension_numbers())); + case HloOpcode::kConvolution: + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine( + hash, proto_hash(instruction->convolution_dimension_numbers())), + proto_hash(instruction->window())); + case HloOpcode::kReduceWindow: + return tensorflow::Hash64Combine(hash, proto_hash(instruction->window())); + case HloOpcode::kConcatenate: + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + case HloOpcode::kIota: + case HloOpcode::kReduce: + return tensorflow::Hash64Combine(hash, c_hash(instruction->dimensions())); + default: + return hash; } - return hash; } } // namespace