From d4c2030375551acc9ad36356e090dcc1af8baa51 Mon Sep 17 00:00:00 2001
From: David Majnemer <majnemer@google.com>
Date: Tue, 16 Jun 2020 15:01:21 -0700
Subject: [PATCH] [XLA] Make IdenticalSlowPath more strict

- Map did not check |dimensions|. This has not resulted in miscompliation
  because none of the backends/frontends support/generate map's with dimensions.
- DynamicSlice did not check |dynamic_slice_sizes|. This has not resulted in
  miscompliation because it is not possible for two DynamicSlice operations
  with different |dynamic_slice_sizes| to have the same shape.
- HloCollective did not check |constrain_layout|.

PiperOrigin-RevId: 316765004
Change-Id: Id2629fef71446842eeae18901142a502a634b010
---
 tensorflow/compiler/xla/service/hlo_instructions.cc | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index bcc00d806da..2a53841fd34 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -550,6 +550,7 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
   const auto& casted_other =
       static_cast<const HloCollectiveInstruction&>(other);
   return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
+         constrain_layout() == casted_other.constrain_layout() &&
          absl::c_equal(replica_groups(), casted_other.replica_groups(),
                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
                          return absl::c_equal(a.replica_ids(), b.replica_ids());
@@ -1101,7 +1102,9 @@ bool HloMapInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
         eq_computations) const {
-  return eq_computations(to_apply(), other.to_apply());
+  const auto& casted_other = static_cast<const HloMapInstruction&>(other);
+  return eq_computations(to_apply(), casted_other.to_apply()) &&
+         dimensions() == casted_other.dimensions();
 }
 
 std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
@@ -2515,7 +2518,8 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
         eq_computations) const {
-  return true;
+  const auto& casted_other = static_cast<const HloMapInstruction&>(other);
+  return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
 }
 
 std::unique_ptr<HloInstruction>