diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3da904beb36..67a8292d6cc 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2794,47 +2794,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "dynamic_dimension_simplifier",
-    srcs = ["dynamic_dimension_simplifier.cc"],
-    hdrs = ["dynamic_dimension_simplifier.h"],
-    deps = [
-        ":hlo",
-        ":hlo_pass",
-        "//tensorflow/compiler/xla:status_macros",
-    ],
-)
-
-tf_cc_test(
-    name = "dynamic_dimension_simplifier_test",
-    srcs = ["dynamic_dimension_simplifier_test.cc"],
-    deps = [
-        ":dynamic_dimension_simplifier",
-        ":hlo",
-        ":hlo_casting_utils",
-        ":hlo_creation_utils",
-        ":hlo_parser",
-        ":hlo_pass",
-        ":hlo_pass_pipeline",
-        ":pattern_matcher",
-        ":pattern_matcher_gmock",
-        ":shape_inference",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:test",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:window_util",
-        "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
-    ],
-)
-
 cc_library(
     name = "dynamic_padder",
     srcs = ["dynamic_padder.cc"],
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 2389a33e52c..2328ad99113 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -636,7 +636,7 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
 }
 
 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
-    HloInstruction* gds) {
+    HloInstruction*) {
   // Dynamic dimension doesn't propagate through GetDimensionSize:
   //
   //   Input: F32[x, y, z]
@@ -646,24 +646,6 @@ Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
   // The returned value is a scalar, which doesn't have any dynamic dimension in
   // the shape (although the value contains the real size of the dynamic
   // dimension of the input).
-  int64 dim = gds->dimension();
-  HloInstruction* operand = gds->mutable_operand(0);
-  HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
-  HloComputation* computation = gds->parent();
-  if (dynamic_size != nullptr) {
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
-    // The dependency between an instruction and its dynamic dimensions is not
-    // modeled in the IR. As instr is being replaced by dynamic_size, also tell
-    // dynamic dimension inference that the instruction is being replaced.
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
-  } else {
-    TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
-    int32 size = gds->operand(0)->shape().dimensions(dim);
-    HloInstruction* new_instr = computation->AddInstruction(
-        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
-    TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
-    parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
-  }
   return Status::OK();
 }
 
@@ -812,23 +794,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
 
 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
     HloInstruction* hlo) {
-  HloComputation* comp = hlo->parent();
-  return ForEachOperandDynamicDimension(
-      hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
-               int64 operand_index, HloInstruction* dynamic_size) {
-        HloInstruction* existing_size =
-            parent_->GetDynamicSize(hlo, index, dimension);
-        if (existing_size == nullptr || existing_size == dynamic_size) {
-          parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
-        } else {
-          HloInstruction* max =
-              comp->AddInstruction(HloInstruction::CreateBinary(
-                  ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
-                  dynamic_size, existing_size));
-          parent_->SetDynamicSize(hlo, index, dimension, max);
-        }
-        return Status::OK();
-      });
+  return PassThroughDynamicDimension(hlo);
 }
 
 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
deleted file mode 100644
index d7253a3fbad..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
-
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-
-namespace xla {
-namespace {
-
-// Concat(Concat(A, B), C) => Concat(A, B, C)
-StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
-  if (concat->opcode() != HloOpcode::kConcatenate) {
-    return false;
-  }
-  bool changed = false;
-
-  auto parent = concat->parent();
-  std::vector<HloInstruction*> new_operands;
-  for (HloInstruction* operand : concat->operands()) {
-    if (operand->opcode() != HloOpcode::kConcatenate ||
-        operand->concatenate_dimension() != concat->concatenate_dimension()) {
-      new_operands.push_back(operand);
-    } else {
-      changed = true;
-      for (HloInstruction* operand_operand : operand->operands()) {
-        new_operands.push_back(operand_operand);
-      }
-    }
-  }
-  if (changed) {
-    auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
-        concat->shape(), new_operands, concat->concatenate_dimension()));
-    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
-  }
-  return changed;
-}
-
-// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
-StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
-  if (slice->opcode() != HloOpcode::kSlice) {
-    return false;
-  }
-  auto concat = slice->mutable_operand(0);
-  if (concat->opcode() != HloOpcode::kConcatenate) {
-    return false;
-  }
-
-  if (slice->shape().rank() != 1) {
-    // Slice concat forwarding only work for size 1 tensor.
-    return false;
-  }
-
-  int64 concat_dim = concat->concatenate_dimension();
-
-  std::vector<HloInstruction*> new_operands;
-  int64 size_so_far = 0;
-  int64 slice_size = slice->shape().dimensions(concat_dim);
-  if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
-    return false;
-  }
-  if (slice->slice_strides(0) != 1) {
-    return false;
-  }
-  for (HloInstruction* operand : concat->operands()) {
-    if (size_so_far == slice->slice_starts(0) &&
-        operand->shape().dimensions(0) == slice_size) {
-      // Found an operand that can be forwarded.
-      TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
-      return true;
-    }
-    size_so_far += operand->shape().dimensions(concat_dim);
-  }
-
-  return false;
-}
-
-// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
-StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto broadcast = reshape->mutable_operand(0);
-  if (broadcast->opcode() != HloOpcode::kBroadcast) {
-    return false;
-  }
-
-  if (reshape->shape().rank() != 0) {
-    return false;
-  }
-
-  if (broadcast->shape().rank() != 1) {
-    return false;
-  }
-
-  if (broadcast->mutable_operand(0)->shape().rank() != 0) {
-    return false;
-  }
-
-  TF_RETURN_IF_ERROR(
-      reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
-
-  return true;
-}
-
-// Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
-StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto reshape_2 = reshape->mutable_operand(0);
-  if (reshape_2->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-
-  if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
-    return false;
-  }
-  TF_RETURN_IF_ERROR(
-      reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
-
-  return true;
-}
-
-// Convert(A, T->T) ==> A
-StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
-  if (convert->opcode() != HloOpcode::kConvert) {
-    return false;
-  }
-  auto operand = convert->mutable_operand(0);
-  if (Shape::Equal()(convert->shape(), operand->shape())) {
-    TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
-    return true;
-  }
-  return false;
-}
-
-// Reshape(A, S->S) ==> A
-StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
-  if (reshape->opcode() != HloOpcode::kReshape) {
-    return false;
-  }
-  auto operand = reshape->mutable_operand(0);
-  if (Shape::Equal()(reshape->shape(), operand->shape())) {
-    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
-    return true;
-  }
-  return false;
-}
-
-}  // namespace
-
-StatusOr<bool> DynamicDimensionSimplifier::Run(HloModule* module) {
-  XLA_VLOG_LINES(
-      2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
-  bool changed = false;
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
-      changed |= local_changed;
-    }
-  }
-  for (auto* comp : module->MakeNonfusionComputations()) {
-    for (auto* inst : comp->MakeInstructionPostOrder()) {
-      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
-      changed |= local_changed;
-    }
-  }
-  XLA_VLOG_LINES(
-      2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
-  return changed;
-}
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
deleted file mode 100644
index e9b99212172..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
-
-#include <utility>
-
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-
-namespace xla {
-
-// This pass simplifies operations on dynamic dimension sizes so that it can be
-// easily analyzed by later passes.
-class DynamicDimensionSimplifier : public HloModulePass {
- public:
-  absl::string_view name() const override {
-    return "dynamic dimension simplifier";
-  }
-
-  StatusOr<bool> Run(HloModule* module) override;
-};
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
deleted file mode 100644
index 1389d06953c..00000000000
--- a/tensorflow/compiler/xla/service/dynamic_dimension_simplifier_test.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
-
-#include <memory>
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher.h"
-#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
-#include "tensorflow/compiler/xla/service/shape_inference.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/window_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-
-namespace xla {
-namespace {
-
-namespace m = match;
-
-class DynamicDimensionSimplifierTest : public HloTestBase {};
-
-TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat1 = s32[2] concatenate(p0, p1), dimensions={0}
-      ROOT concat2 = s32[3] concatenate(concat1, p2), dimensions={0}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1),
-                                        m::Parameter(2))));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatMultipleDims) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1, 1] parameter(0)
-      p1 = s32[1, 1] parameter(1)
-      p2 = s32[2, 1] parameter(2)
-      concat1 = s32[2, 1] concatenate(p0, p1), dimensions={0}
-      ROOT concat2 = s32[2, 2] concatenate(concat1, p2), dimensions={1}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, ForwardConcatSlice) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[1] slice(concat), slice={[1:2]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(1)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceSizeMismatch) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[2] slice(concat), slice={[1:3]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, DoNotForwardConcatSliceStrided) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      p1 = s32[1] parameter(1)
-      p2 = s32[1] parameter(2)
-      concat = s32[3] concatenate(p0, p1, p2), dimensions={0}
-      ROOT slice = s32[1] slice(concat), slice={[1:2:2]}
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, BroadcastReshapeForwarding) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[] parameter(0)
-      broadcast = s32[1] broadcast(p0), dimensions={}
-      ROOT reshape = s32[] reshape(broadcast)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest, ReshapeReshapeForwarding) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[] parameter(0)
-      reshape = s32[1] reshape(p0)
-      ROOT reshape2 = s32[] reshape(reshape)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-TEST_F(DynamicDimensionSimplifierTest,
-       DoNotReshapeReshapeForwardingShapeMismatch) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1, 1] parameter(0)
-      reshape = s32[1] reshape(p0)
-      ROOT reshape2 = s32[] reshape(reshape)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
-}
-
-TEST_F(DynamicDimensionSimplifierTest, IdConvertRemoving) {
-  const char* kModuleStr = R"(
-    HloModule m
-    test {
-      p0 = s32[1] parameter(0)
-      ROOT reshape2 = s32[1] convert(p0)
-    }
-  )";
-  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
-  DynamicDimensionSimplifier simplifier;
-  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
-  EXPECT_THAT(m->entry_computation()->root_instruction(),
-              GmockMatch(m::Parameter(0)));
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index 7785908e15a..ab94695c1e2 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -1282,97 +1282,6 @@ StatusOr<bool> RewriteDynamicSort(
   return true;
 }
 
-StatusOr<bool> RewriteDynamicBinaryOp(
-    HloInstruction* binary,
-    DynamicDimensionInference* dynamic_dimension_inference) {
-  HloInstruction* operand_0 = binary->mutable_operand(0);
-  HloInstruction* operand_1 = binary->mutable_operand(1);
-
-  HloComputation* comp = binary->parent();
-  TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
-  auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
-  auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
-  bool changed = false;
-  for (int64 i = 0; i < dims_0.size(); ++i) {
-    HloInstruction* dim_0 = dims_0[i];
-    HloInstruction* dim_1 = dims_1[i];
-
-    if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
-        dims_1[i] != nullptr) {
-      changed = true;
-      // It is possible that a dynamic dimension of one operand is size 1 while
-      // the other is greater than one. According to implicit broadcast
-      // semantics, we need to insert broadcast in this case to make the dynamic
-      // shape match.
-
-      // An implicit broadcast is inserted by slicing the small shape into a
-      // size 1 slice, reshape out the size 1 dimension then broadcast to the
-      // full shape:
-      //
-      // Input [2, <=5, 3]
-      //   |
-      // Slice [2, 1, 3]
-      //   |
-      // Reshape [2, 3]
-      //   |
-      // Broadcast [2, 5, 3]
-      auto rewrite_operand = [&](HloInstruction* pred,
-                                 HloInstruction* operand) -> HloInstruction* {
-        Shape static_shape = operand->shape();
-        static_shape.clear_dynamic_dimensions();
-        pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
-            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
-        Shape slice_shape = static_shape;
-        slice_shape.set_dimensions(i, 1);
-        std::vector<int64> start_indices(slice_shape.rank(), 0);
-        std::vector<int64> strides(slice_shape.rank(), 1);
-        HloInstruction* slice = comp->AddInstruction(
-            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
-                                        slice_shape.dimensions(), strides));
-        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
-        HloInstruction* reshape = comp->AddInstruction(
-            HloInstruction::CreateReshape(reshape_shape, slice));
-        std::vector<int64> broadcast_dims;
-        broadcast_dims.reserve(static_shape.rank() - 1);
-        // Broadcast to all dims execpt for i.
-        for (int64 j = 0; j < static_shape.rank(); ++j) {
-          if (j != i) {
-            broadcast_dims.push_back(j);
-          }
-        }
-
-        HloInstruction* broadcast =
-            comp->AddInstruction(HloInstruction::CreateBroadcast(
-                                     static_shape, reshape, broadcast_dims),
-                                 "implicit_broadcast");
-
-        // Use a select instead of conditional as elementwise operations promote
-        // more fusion.
-        HloInstruction* select =
-            comp->AddInstruction(HloInstruction::CreateTernary(
-                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
-        return select;
-      };
-      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
-          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
-                                        dim_1, ComparisonDirection::kLt),
-          "lhs_needs_implicit_broadcast");
-      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);
-
-      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
-          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
-                                        dim_0, ComparisonDirection::kLt),
-          "rhs_needs_implicit_broadcast");
-      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
-    }
-  }
-  if (changed) {
-    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
-    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
-  }
-  return changed;
-}
-
 StatusOr<bool> RewriteDynamicReshape(
     HloInstruction* reshape,
     DynamicDimensionInference* dynamic_dimension_inference) {
@@ -1823,14 +1732,6 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
         continue;
       }
 
-      // Elementwise binary with dynamic shapes have implicit broadcast
-      // semantics.
-      if (inst->IsElementwiseBinary()) {
-        TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
-                                         inst, &dynamic_dimension_inference));
-        continue;
-      }
-
       if (inst->opcode() == HloOpcode::kDynamicReshape) {
         TF_ASSIGN_OR_RETURN(
             changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));