diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index f8ab8748285..618cc9a990f 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -40,6 +40,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index f6384c35360..8b8df527041 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -695,7 +695,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
 
 bool ModifiesInputsInPlace(const NodeDef& node) {
   // Some nodes do in-place updates on regular tensor inputs.
-  string op_name = node.op();
+  const string& op_name = node.op();
 
   // Ops that modify resource variables effectively modify one of their inputs.
   if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
@@ -706,8 +706,10 @@ bool ModifiesInputsInPlace(const NodeDef& node) {
     return false;
   }
 
-  std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
-  if (absl::StrContains(op_name, "inplace")) {
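+  // op_name is now a reference to node.op(), so lowercase a separate copy
+  // rather than modifying it.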
+  string lower_op_name = op_name;
+  std::transform(lower_op_name.begin(), lower_op_name.end(),
+                 lower_op_name.begin(), ::tolower);
+  if (absl::StrContains(lower_op_name, "inplace")) {
     return true;
   }
   return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
@@ -862,20 +864,25 @@ bool NeverForwardsInputs(const NodeDef& node) {
       (new gtl::FlatSet<string>{"ArgMax",
                                 "ArgMin",
                                 "AudioSpectrogram",
+                                "AvgPool",
                                 "BatchMatMul",
                                 "BatchMatMulV2",
+                                "BatchNormWithGlobalNormalization",
                                 "BatchToSpace",
                                 "BatchToSpaceND",
                                 "Bincount",
                                 "BroadcastArgs",
                                 "BroadcastGradientArgs",
+                                "Bucketize",
                                 "CTCBeamSearchDecoder",
                                 "CTCGreedyDecoder",
                                 "CTCLoss",
+                                "CompareAndBitpack",
                                 "ComplexAbs",
                                 "Concat",
                                 "ConcatOffset",
                                 "ConcatV2",
+                                "Conv2D",
                                 "Copy",
                                 "CopyHost",
                                 "Cross",
@@ -890,8 +897,8 @@ bool NeverForwardsInputs(const NodeDef& node) {
                                 "CudnnRNNParamsToCanonicalV2",
                                 "CudnnRNNV2",
                                 "CudnnRNNV3",
-                                "CumSum",
                                 "CumProd",
+                                "CumSum",
                                 "DebugNanCount",
                                 "DebugNumericSummary",
                                 "DecodeProtoV2",
@@ -920,15 +927,25 @@ bool NeverForwardsInputs(const NodeDef& node) {
                                 "LowerBound",
                                 "MatMul",
                                 "MatrixDiag",
-                                "MatrixDiagV2",
                                 "MatrixDiagPart",
                                 "MatrixDiagPartV2",
+                                "MatrixDiagV2",
                                 "Mfcc",
+                                "Multinomial",
                                 "OneHot",
                                 "Pack",
+                                "ParameterizedTruncatedNormal",
                                 "PopulationCount",
+                                "RandomGamma",
+                                "RandomPoisson",
+                                "RandomPoissonV2",
+                                "RandomStandardNormal",
+                                "RandomUniform",
+                                "RandomUniformInt",
                                 "Range",
                                 "Rank",
+                                "RequantizationRange",
+                                "Requantize",
                                 "ReverseSequence",
                                 "Shape",
                                 "ShapeN",
@@ -939,6 +956,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
                                 "SparseMatMul",
                                 "Split",
                                 "SplitV",
+                                "TruncatedNormal",
                                 "Unique",
                                 "UniqueV2",
                                 "UniqueWithCounts",
@@ -946,23 +964,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
                                 "Unpack",
                                 "UnravelIndex",
                                 "UpperBound",
-                                "Where",
-                                "CompareAndBitpack",
-                                "Requantize",
-                                "RequantizationRange",
-                                "Bucketize",
-                                "AvgPool",
-                                "BatchNormWithGlobalNormalization",
-                                "Conv2D",
-                                "RandomUniform",
-                                "RandomUniformInt",
-                                "RandomStandardNormal",
-                                "ParameterizedTruncatedNormal",
-                                "TruncatedNormal",
-                                "Multinomial",
-                                "RandomGamma",
-                                "RandomPoisson",
-                                "RandomPoissonV2"}));
+                                "Where"}));
   const string& op_name = node.op();
   return kNonForwardingOps->count(op_name) > 0 ||
          absl::StrContains(op_name, "Segment") ||
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 9be609b3970..56b7754355c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -302,6 +302,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":arithmetic_optimizer",
+        ":common_subgraph_elimination",
         ":constant_folding",
         ":model_pruner",
         "//tensorflow/core:test",
@@ -334,6 +335,59 @@ tf_cuda_cc_test(
     ],
 )
 
+cc_library(
+    name = "common_subgraph_elimination",
+    srcs = ["common_subgraph_elimination.cc"],
+    hdrs = [
+        "common_subgraph_elimination.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:graph_topology_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/utils:canonicalizer",
+        "//tensorflow/core/grappler/utils:topological_sort",
+        "//tensorflow/core/grappler/utils:traversal",
+        "//tensorflow/core/platform:hash",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "common_subgraph_elimination_test",
+    size = "small",
+    srcs = ["common_subgraph_elimination_test.cc"],
+    deps = [
+        ":arithmetic_optimizer_test_utils",
+        ":common_subgraph_elimination",
+        ":model_pruner",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+        "//tensorflow/core/grappler/utils:grappler_test",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "dependency_optimizer",
     srcs = ["dependency_optimizer.cc"],
@@ -605,6 +659,7 @@ cc_library(
         ":arithmetic_optimizer",
         ":auto_mixed_precision",
         ":auto_parallel",
+        ":common_subgraph_elimination",
         ":constant_folding",
         ":custom_graph_optimizer_registry",
         ":debug_stripper",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3281f97457f..0b9701ca0c3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3462,215 +3462,6 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
 
 }  // namespace
 
-class UniqueNodes {
- public:
-  NodeDef* FindOrAddRepresentative(NodeDef* node) {
-    uint64 sig = ComputeSignature(*node);
-    std::vector<NodeDef*>& candidates = rep_[sig];
-    for (auto& candidate : candidates) {
-      if ((candidate == node) || SameNode(*candidate, *node)) {
-        return candidate;
-      }
-    }
-    candidates.push_back(node);
-    return node;
-  }
-
-  void RemoveRepresentative(NodeDef* node) {
-    auto it = memoized_signatures_.find(node);
-    if (it == memoized_signatures_.end()) return;
-
-    std::vector<NodeDef*>& candidates = rep_[it->second];
-    for (int i = 0; i < candidates.size(); ++i) {
-      if (candidates[i] == node) {
-        std::swap(candidates[i], candidates[candidates.size() - 1]);
-        candidates.resize(candidates.size() - 1);
-        break;
-      }
-    }
-    memoized_signatures_.erase(node);
-  }
-
- private:
-  uint64 ComputeSignature(const NodeDef& node);
-  bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
-
-  absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
-  absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
-};
-
-uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
-  auto it = memoized_signatures_.find(&node);
-  if (it != memoized_signatures_.end()) return it->second;
-
-  uint64 h = Hash64(node.op());
-  h = Hash64Combine(Hash64(node.device()), h);
-
-  for (const auto& input : node.input()) {
-    const TensorId input_tensor = ParseTensorName(input);
-    uint64 input_hash = Hash64Combine(
-        Hash64(input_tensor.node().data(), input_tensor.node().size()),
-        std::hash<int>()(input_tensor.index()));
-    h = Hash64CombineUnordered(input_hash, h);
-  }
-  for (const auto& attr : node.attr()) {
-    uint64 attr_hash =
-        Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
-    h = Hash64CombineUnordered(attr_hash, h);
-  }
-  memoized_signatures_.emplace(&node, h);
-  return h;
-}
-
-// PRECONDITION:
-//  Node input orders are assumed to be canonicalized, i.e. control inputs for
-//  all nodes as well as regular inputs for commutative nodes must be sorted.
-bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
-  if (node1.op() != node2.op()) {
-    return false;
-  }
-  if (node1.device() != node2.device()) {
-    return false;
-  }
-  if (node1.input_size() != node2.input_size()) {
-    return false;
-  }
-  if (node1.attr_size() != node2.attr_size()) {
-    return false;
-  }
-
-  // Compare inputs.
-  auto it1 = node1.input().begin();
-  auto it2 = node2.input().begin();
-  for (; it1 != node1.input().end(); ++it1, ++it2) {
-    if (*it1 != *it2) return false;
-  }
-
-  // Compare attributes.
-  for (const auto& attr1 : node1.attr()) {
-    auto it = node2.attr().find(attr1.first);
-    if (it == node2.attr().end()) return false;
-    if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
-  }
-
-  return true;
-}
-
-bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
-  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
-    return false;
-  }
-  if (IsEnter(node) || IsExit(node)) {
-    return false;
-  }
-  if (node.device().find("SPU") != string::npos) {
-    return false;
-  }
-  if (IsAssert(node) || IsPrint(node)) {
-    return true;
-  }
-  return IsFreeOfSideEffect(node);
-}
-
-void ArithmeticOptimizer::DedupComputations() {
-  CanonicalizeGraph(optimized_graph_);
-
-  GraphTopologyView graph_view;
-  if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
-    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
-    return;
-  }
-
-  // Populate feed_inplace_op;
-  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
-  for (const NodeDef& root : optimized_graph_->node()) {
-    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
-
-    if (ModifiesInputsInPlace(root)) {
-      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
-        return node->op() == root.op() || !NeverForwardsInputs(*node);
-      };
-
-      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
-                   DfsPredicates::Advance(is_continue_traversal),
-                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
-                     feeds_inplace_op.insert(node);
-                   }));
-    }
-  }
-
-  bool stop = true;
-  std::set<int> duplicates;
-  UniqueNodes nodes;
-  do {
-    stop = true;
-    for (int i = 0; i < optimized_graph_->node_size(); ++i) {
-      if (duplicates.find(i) != duplicates.end()) {
-        continue;
-      }
-      NodeDef* node = optimized_graph_->mutable_node(i);
-      if (!CanDedup(*node) ||
-          feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
-        continue;
-      }
-      NodeDef* rep = nodes.FindOrAddRepresentative(node);
-      if (rep == node) {
-        continue;
-      }
-      // If either node or rep feeds an inplace op, deduping them may cause data
-      // races. For example: If we dedup nodes initializing two independent
-      // inplace accumulations, they will write to the same buffer, clobbering
-      // each other's results.
-      if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
-        continue;
-      }
-      const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
-      std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
-      for (NodeDef* fanout : fanouts) {
-        // Update consumers of node.
-        bool updated_fanout = false;
-        for (int i = 0; i < fanout->input_size(); ++i) {
-          string* fanout_input = fanout->mutable_input(i);
-
-          const int position =
-              NodePositionIfSameNode(*fanout_input, node->name());
-          // Update name in-place.
-          if (position < -1) {
-            continue;
-          } else {
-            if (!updated_fanout) {
-              // The signature of the fanout node will change. Remove it from
-              // nodes.
-              nodes.RemoveRepresentative(fanout);
-            }
-            updated_fanout = true;
-            if (position > 0) {
-              *fanout_input = StrCat(rep->name(), ":", position);
-            } else if (position == 0) {
-              *fanout_input = rep->name();
-            } else {
-              *fanout_input = StrCat("^", rep->name());
-            }
-          }
-        }
-        if (updated_fanout) {
-          node_map_->UpdateInput(fanout->name(), node->name(), rep->name());
-          CanonicalizeNode(fanout);
-        }
-      }
-      duplicates.insert(i);
-      stop = false;
-    }
-  } while (!stop);
-
-  // Delete duplicates
-  if (fetch_nodes_known_ && !duplicates.empty()) {
-    EraseNodesFromGraph(duplicates, optimized_graph_);
-    // Rebuild the NodeMap which was invalidated by the node  swapping above.
-    node_map_.reset(new NodeMap(optimized_graph_));
-  }
-}
-
 Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
   SetVector<NodeDef*> nodes_to_simplify;
   nodes_to_simplify.Reserve(optimized_graph_->node_size());
@@ -3818,11 +3609,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
 
-  if (options_.dedup_computations) {
-    DedupComputations();
-    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
-  }
-
   graph_properties_.reset(new GraphProperties(optimized_item));
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
   const Status status =
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index a421daa88a5..50896b11923 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -104,115 +104,6 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
   VerifyGraphsMatch(item.graph, output, __LINE__);
 }
 
-TEST_F(ArithmeticOptimizerTest, OpDedupping) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
-  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
-  Output div = ops::Div(s.WithOpName("div"), c1, c2);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  item.fetch = {"div"};
-
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  ASSERT_EQ(tensors_expected.size(), 1);
-
-  ArithmeticOptimizer optimizer;
-  GraphDef output;
-  OptimizeTwice(&optimizer, &item, &output);
-  NodeMap node_map(&output);
-  EXPECT_EQ(output.node_size(), 2);
-  const NodeDef* new_c1 = node_map.GetNode("c1");
-  ASSERT_NE(new_c1, nullptr);
-
-  const NodeDef* new_div = node_map.GetNode("div");
-  ASSERT_NE(new_div, nullptr);
-  ASSERT_EQ(new_div->input_size(), 2);
-  EXPECT_EQ(new_div->input(0), "c1");
-  EXPECT_EQ(new_div->input(1), "c1");
-
-  auto tensors = EvaluateNodes(output, item.fetch);
-  ASSERT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
-}
-
-TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
-  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
-  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
-  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
-  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
-  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
-  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
-                            {assert1.operation, assert2.operation}),
-                        check1, check2);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  item.fetch = {"div"};
-  Tensor bool_t(DT_BOOL, TensorShape({}));
-  bool_t.scalar<bool>().setConstant(true);
-  auto tensors_expected =
-      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
-  ASSERT_EQ(tensors_expected.size(), 1);
-
-  ArithmeticOptimizer optimizer;
-  GraphDef output;
-
-  OptimizeTwice(&optimizer, &item, &output);
-  NodeMap node_map(&output);
-
-  EXPECT_EQ(output.node_size(), 6);
-  const NodeDef* new_div = node_map.GetNode("div");
-  ASSERT_NE(new_div, nullptr);
-  ASSERT_EQ(new_div->input_size(), 3);
-  EXPECT_EQ(new_div->input(0), "check1");
-  EXPECT_EQ(new_div->input(1), "check2");
-  EXPECT_EQ(new_div->input(2), "^assert1");
-
-  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
-  EXPECT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
-}
-
-TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
-  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
-  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
-  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
-  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
-  GrapplerItem item;
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  item.fetch = {"div1"};
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  ASSERT_EQ(tensors_expected.size(), 1);
-
-  ArithmeticOptimizer optimizer;
-  GraphDef output;
-  OptimizeTwice(&optimizer, &item, &output);
-  NodeMap node_map(&output);
-
-  EXPECT_EQ(output.node_size(), 4);
-  const NodeDef* new_c1 = node_map.GetNode("c1");
-  ASSERT_NE(new_c1, nullptr);
-  const NodeDef* new_c2 = node_map.GetNode("c2");
-  ASSERT_NE(new_c2, nullptr);
-  const NodeDef* new_mul1 = node_map.GetNode("mul1");
-  ASSERT_NE(new_mul1, nullptr);
-  ASSERT_EQ(new_mul1->input_size(), 2);
-  EXPECT_EQ(new_mul1->input(0), "c1");
-  EXPECT_EQ(new_mul1->input(1), "c2");
-  const NodeDef* new_div1 = node_map.GetNode("div1");
-  ASSERT_NE(new_div1, nullptr);
-  ASSERT_EQ(new_div1->input_size(), 2);
-  EXPECT_EQ(new_div1->input(0), "mul1");
-  EXPECT_EQ(new_div1->input(1), "mul1");
-
-  auto tensors = EvaluateNodes(output, item.fetch);
-  ASSERT_EQ(tensors.size(), 1);
-  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
-}
-
 TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
@@ -474,6 +365,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   Output id = ops::Identity(s.WithOpName("id"), add6);
 
   GrapplerItem item;
+  item.fetch = {"id"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   const std::vector<string> devices{
@@ -488,16 +380,16 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   DisableAddToAddNCombining(&optimizer);
 
   GraphDef output;
-  OptimizeTwice(&optimizer, &item, &output);
+  DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   // We expect the following rewrite(s) to occur:
   //
   // Mul(p,
   //     Add_6(Add_4(Const(2), Const(2)),
   //           Add_5(Const(2), Const(2))))
   NodeMap node_map(&output);
 
-  EXPECT_EQ(output.node_size(), 17);
+  EXPECT_EQ(output.node_size(), 8);
 
   const NodeDef* id_node = node_map.GetNode("id");
   ASSERT_NE(id_node, nullptr);
@@ -507,8 +399,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
   ASSERT_NE(mul_node, nullptr);
   ASSERT_EQ(mul_node->input_size(), 2);
-  EXPECT_EQ(mul_node->input(0), HoistAddName("Add_6"));
-  EXPECT_EQ(mul_node->input(1), "Placeholder");
+  EXPECT_EQ(mul_node->input(0), "Placeholder");
+  EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
 
   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
   ASSERT_NE(add_6_node, nullptr);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
index 4d3ba976c4f..73bb5a0d97c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
 
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/grappler/utils/grappler_test.h"
@@ -27,9 +28,9 @@ namespace grappler {
 
 class ArithmeticOptimizerTest : public GrapplerTest {
  protected:
-  // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
+  // Optimize a graph using the given optimizer and prune all the nodes that no
   // longer have any output consumers.
-  void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+  void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
                         GraphDef* output) {
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
     item->graph.Swap(output);
@@ -37,8 +38,23 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
   }
 
-  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
-  void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+  // Run CommonSubgraphElimination once, then the given optimizer twice to
+  // make sure the rewrite is idempotent, and finally prune nodes that no
+  // longer have any output consumers.
+  void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
+                                     GrapplerItem* item, GraphDef* output) {
+    TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+  }
+
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item,
                      GraphDef* output) {
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
     item->graph.Swap(output);
@@ -46,9 +62,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
   }
 
-  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+  // Run the optimizer twice to make sure the rewrite is idempotent.
   // Optionally run a constant folding pass before pruning.
-  void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+  void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
                              GraphDef* output, bool const_folding = false) {
     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
 
diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc
new file mode 100644
index 00000000000..8924e4c6bea
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc
@@ -0,0 +1,291 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
+
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/grappler/graph_topology_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/canonicalizer.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/grappler/utils/traversal.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/hash.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace grappler {
+class Cluster;
+}  // namespace grappler
+}  // namespace tensorflow
+
+using tensorflow::strings::StrCat;
+
+namespace tensorflow {
+namespace grappler {
+
+class UniqueNodes {
+ public:
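+  // Returns a node equivalent to the given node if one has been seen
+  // before; otherwise registers the node as the representative of its
+  // equivalence class and returns it.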
+  NodeDef* FindOrAddRepresentative(NodeDef* node) {
+    uint64 sig = ComputeSignature(*node);
+    std::vector<NodeDef*>& candidates = rep_[sig];
+    for (auto& candidate : candidates) {
+      if ((candidate == node) || SameNode(*candidate, *node)) {
+        return candidate;
+      }
+    }
+    candidates.push_back(node);
+    return node;
+  }
+
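+  // Removes the node from the pool of unique nodes and drops its memoized
+  // signature. Called before a node's inputs are rewritten, which would
+  // invalidate the signature.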
+  void RemoveRepresentative(NodeDef* node) {
+    auto it = memoized_signatures_.find(node);
+    if (it == memoized_signatures_.end()) return;
+
+    std::vector<NodeDef*>& candidates = rep_[it->second];
+    for (size_t i = 0; i < candidates.size(); ++i) {
+      if (candidates[i] == node) {
+        std::swap(candidates[i], candidates[candidates.size() - 1]);
+        candidates.resize(candidates.size() - 1);
+        break;
+      }
+    }
+    memoized_signatures_.erase(node);
+  }
+
+ private:
+  uint64 ComputeSignature(const NodeDef& node);
+  bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
+
+  absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
+  absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
+};
+
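+// Computes a signature hash over a node's op, device, inputs and
+// attributes. Inputs and attributes are combined with an unordered hash so
+// the signature is independent of attr iteration order; results are
+// memoized per node.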
+uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
+  auto it = memoized_signatures_.find(&node);
+  if (it != memoized_signatures_.end()) return it->second;
+
+  uint64 h = Hash64(node.op());
+  h = Hash64Combine(Hash64(node.device()), h);
+
+  for (const auto& input : node.input()) {
+    const TensorId input_tensor = ParseTensorName(input);
+    uint64 input_hash = Hash64Combine(
+        Hash64(input_tensor.node().data(), input_tensor.node().size()),
+        std::hash<int>()(input_tensor.index()));
+    h = Hash64CombineUnordered(input_hash, h);
+  }
+  for (const auto& attr : node.attr()) {
+    uint64 attr_hash =
+        Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
+    h = Hash64CombineUnordered(attr_hash, h);
+  }
+  memoized_signatures_.emplace(&node, h);
+  return h;
+}
+
+// PRECONDITION:
+//  Node input orders are assumed to be canonicalized, i.e. control inputs for
+//  all nodes as well as regular inputs for commutative nodes must be sorted.
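+//  For example, a commutative node Mul(y, x) must already have been
+//  canonicalized to Mul(x, y) before two nodes are compared.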
+bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
+  if (node1.op() != node2.op()) {
+    return false;
+  }
+  if (node1.device() != node2.device()) {
+    return false;
+  }
+  if (node1.input_size() != node2.input_size()) {
+    return false;
+  }
+  if (node1.attr_size() != node2.attr_size()) {
+    return false;
+  }
+
+  // Compare inputs.
+  auto it1 = node1.input().begin();
+  auto it2 = node2.input().begin();
+  for (; it1 != node1.input().end(); ++it1, ++it2) {
+    if (*it1 != *it2) return false;
+  }
+
+  // Compare attributes.
+  for (const auto& attr1 : node1.attr()) {
+    auto it = node2.attr().find(attr1.first);
+    if (it == node2.attr().end()) return false;
+    if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
+  }
+
+  return true;
+}
+
+bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
+  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+    return false;
+  }
+  if (IsEnter(node) || IsExit(node)) {
+    return false;
+  }
+  if (node.device().find("SPU") != string::npos) {
+    return false;
+  }
+  // Workaround for Assert and Print mistakenly being labeled as stateful.
+  if (IsAssert(node) || IsPrint(node)) {
+    return true;
+  }
+  return IsFreeOfSideEffect(node);
+}
+
+Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
+  CanonicalizeGraph(optimized_graph);
+
+  GraphTopologyView graph_view;
+  if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
+    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
+    return Status::OK();
+  }
+
+  // If either node or rep feeds an inplace op, deduping them may cause data
+  // races. For example: If we dedup nodes initializing two independent
+  // inplace accumulations, they will write to the same buffer, clobbering
+  // each other's results.
+  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
+  for (int i = 0; i < optimized_graph->node_size(); ++i) {
+    const NodeDef& root = optimized_graph->node(i);
+    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
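+    // Walk the chain of inputs that may share a buffer with this in-place
+    // op and mark every visited node as feeding an in-place op.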
+    if (ModifiesInputsInPlace(root)) {
+      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
+        return node->op() == root.op() || !NeverForwardsInputs(*node);
+      };
+
+      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
+                   DfsPredicates::Advance(is_continue_traversal),
+                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
+                     feeds_inplace_op.insert(node);
+                   }));
+    }
+  }
+
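+  // Precompute which nodes may be deduped so the fixed-point loop below can
+  // skip ineligible nodes cheaply.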
+  std::vector<bool> can_dedup(optimized_graph->node_size());
+  for (int i = 0; i < optimized_graph->node_size(); ++i) {
+    const NodeDef& node = optimized_graph->node(i);
+    can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
+                   CanDedup(node);
+  }
+
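+  // Iterate to a fixed point: rewriting the fanouts of a duplicate changes
+  // their signatures, which may expose further duplicates.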
+  bool stop = true;
+  std::set<int> duplicates;
+  UniqueNodes nodes;
+  NodeMap node_map(optimized_graph);
+  do {
+    stop = true;
+    for (int i = 0; i < optimized_graph->node_size(); ++i) {
+      if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
+        continue;
+      }
+      NodeDef* node = optimized_graph->mutable_node(i);
+      NodeDef* rep = nodes.FindOrAddRepresentative(node);
+      if (rep == node) {
+        continue;
+      }
+      const std::set<NodeDef*>& tmp = node_map.GetOutputs(node->name());
+      std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
+      for (NodeDef* fanout : fanouts) {
+        // Update consumers of node.
+        bool updated_fanout = false;
+        for (int j = 0; j < fanout->input_size(); ++j) {
+          string* fanout_input = fanout->mutable_input(j);
+
+          const int position =
+              NodePositionIfSameNode(*fanout_input, node->name());
+          if (position < -1) {
+            continue;
+          } else {
+            // Update the input name in place to point at the representative.
+            if (!updated_fanout) {
+              // The signature of the fanout node will change. Remove it from
+              // nodes.
+              nodes.RemoveRepresentative(fanout);
+            }
+            updated_fanout = true;
+            if (position > 0) {
+              *fanout_input = StrCat(rep->name(), ":", position);
+            } else if (position == 0) {
+              *fanout_input = rep->name();
+            } else {
+              *fanout_input = StrCat("^", rep->name());
+            }
+          }
+        }
+        if (updated_fanout) {
+          node_map.UpdateInput(fanout->name(), node->name(), rep->name());
+          CanonicalizeNode(fanout);
+        }
+      }
+      duplicates.insert(i);
+      stop = false;
+    }
+  } while (!stop);
+
+  // Delete the duplicate nodes.
+  if (fetch_nodes_known_ && !duplicates.empty()) {
+    EraseNodesFromGraph(duplicates, optimized_graph);
+  }
+
+  return Status::OK();
+}
+
+Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
+                                           const GrapplerItem& item,
+                                           GraphDef* optimized_graph) {
+  // Set up helper data structures.
+  nodes_to_preserve_ = item.NodesToPreserve();
+  fetch_nodes_known_ = !item.fetch.empty();
+  *optimized_graph = item.graph;
+
+  // Perform a topological sort on the graph in order to help
+  // DedupComputations optimize larger subgraphs starting from the roots
+  // with more inputs.
+  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
+
+  return DedupComputations(optimized_graph);
+}
+
+void CommonSubgraphElimination::Feedback(Cluster* /*cluster*/,
+                                         const GrapplerItem& /*item*/,
+                                         const GraphDef& /*optimized_graph*/,
+                                         double /*result*/) {
+  // Nothing to do for CommonSubgraphElimination.
+}
+
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h
new file mode 100644
index 00000000000..eec6ba79b3f
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h
@@ -0,0 +1,73 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/hash.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+struct GrapplerItem;
+
+// Optimizes a TF computation graph by deduplicating equivalent subgraphs.
+class CommonSubgraphElimination : public GraphOptimizer {
+ public:
+  CommonSubgraphElimination() {}
+
+  explicit CommonSubgraphElimination(RewriterConfig::Toggle opt_level)
+      : opt_level_(opt_level) {}
+
+  ~CommonSubgraphElimination() override {}
+
+  string name() const override { return "common_subgraph_elimination"; }
+
+  bool UsesFunctionLibrary() const override { return false; }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override;
+
+ private:
+  friend class CommonSubgraphEliminationTest;
+
+  // Returns true if it is safe to dedup node from the graph.
+  bool CanDedup(const NodeDef& node) const;
+
+  // Dedup redundant nodes in the graph.
+  Status DedupComputations(GraphDef* optimized_graph);
+
+  RewriterConfig::Toggle opt_level_ = RewriterConfig::ON;
+
+  bool fetch_nodes_known_ = false;
+  std::unordered_set<string> nodes_to_preserve_;
+};
+
+}  // namespace grappler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc
new file mode 100644
index 00000000000..3341a8abe56
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc
@@ -0,0 +1,178 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+
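+// Verifies that two graphs are structurally identical: same node count and,
+// node by node, the same name, op, and input list.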
+void VerifyGraphsMatch(const GraphDef& original_graph,
+                       const GraphDef& optimized_graph, int line) {
+  EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
+  for (int i = 0; i < original_graph.node_size(); ++i) {
+    const NodeDef& original = original_graph.node(i);
+    const NodeDef& optimized = optimized_graph.node(i);
+    EXPECT_EQ(original.name(), optimized.name()) << line;
+    EXPECT_EQ(original.op(), optimized.op()) << line;
+    EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
+    for (int j = 0; j < original.input_size(); ++j) {
+      EXPECT_EQ(original.input(j), optimized.input(j)) << line;
+    }
+  }
+}
+}  // namespace
+
+class CommonSubgraphEliminationTest : public ArithmeticOptimizerTest {};
+
+TEST_F(CommonSubgraphEliminationTest, NoOp) {
+  // This trivial graph is so basic there's nothing to optimize.
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  CommonSubgraphElimination optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  VerifyGraphsMatch(item.graph, output, __LINE__);
+}
+
+TEST_F(CommonSubgraphEliminationTest, OpDedupping) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
+  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
+  Output div = ops::Div(s.WithOpName("div"), c1, c2);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"div"};
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  ASSERT_EQ(tensors_expected.size(), 1);
+
+  CommonSubgraphElimination optimizer;
+  GraphDef output;
+  OptimizeTwice(&optimizer, &item, &output);
+  NodeMap node_map(&output);
+  EXPECT_EQ(output.node_size(), 2);
+  const NodeDef* new_c1 = node_map.GetNode("c1");
+  ASSERT_NE(new_c1, nullptr);
+
+  const NodeDef* new_div = node_map.GetNode("div");
+  ASSERT_NE(new_div, nullptr);
+  ASSERT_EQ(new_div->input_size(), 2);
+  EXPECT_EQ(new_div->input(0), "c1");
+  EXPECT_EQ(new_div->input(1), "c1");
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  ASSERT_EQ(tensors.size(), 1);
+  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
+}
+
+TEST_F(CommonSubgraphEliminationTest, OpDeduppingAssertAndCheckNumerics) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
+  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
+  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
+  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
+  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
+  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
+  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
+                            {assert1.operation, assert2.operation}),
+                        check1, check2);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"div"};
+  Tensor bool_t(DT_BOOL, TensorShape({}));
+  bool_t.scalar<bool>().setConstant(true);
+  auto tensors_expected =
+      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
+  ASSERT_EQ(tensors_expected.size(), 1);
+
+  CommonSubgraphElimination optimizer;
+  GraphDef output;
+
+  OptimizeTwice(&optimizer, &item, &output);
+  NodeMap node_map(&output);
+
+  EXPECT_EQ(output.node_size(), 6);
+  const NodeDef* new_div = node_map.GetNode("div");
+  ASSERT_NE(new_div, nullptr);
+  ASSERT_EQ(new_div->input_size(), 3);
+  EXPECT_EQ(new_div->input(0), "check1");
+  EXPECT_EQ(new_div->input(1), "check2");
+  EXPECT_EQ(new_div->input(2), "^assert1");
+
+  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
+  EXPECT_EQ(tensors.size(), 1);
+  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
+}
+
+TEST_F(CommonSubgraphEliminationTest, OpDedupCommutative) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
+  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
+  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
+  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
+  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"div1"};
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  ASSERT_EQ(tensors_expected.size(), 1);
+
+  CommonSubgraphElimination optimizer;
+  GraphDef output;
+  OptimizeTwice(&optimizer, &item, &output);
+  NodeMap node_map(&output);
+
+  EXPECT_EQ(output.node_size(), 4);
+  const NodeDef* new_c1 = node_map.GetNode("c1");
+  ASSERT_NE(new_c1, nullptr);
+  const NodeDef* new_c2 = node_map.GetNode("c2");
+  ASSERT_NE(new_c2, nullptr);
+  const NodeDef* new_mul1 = node_map.GetNode("mul1");
+  ASSERT_NE(new_mul1, nullptr);
+  ASSERT_EQ(new_mul1->input_size(), 2);
+  EXPECT_EQ(new_mul1->input(0), "c1");
+  EXPECT_EQ(new_mul1->input(1), "c2");
+  const NodeDef* new_div1 = node_map.GetNode("div1");
+  ASSERT_NE(new_div1, nullptr);
+  ASSERT_EQ(new_div1->input_size(), 2);
+  EXPECT_EQ(new_div1->input(0), "mul1");
+  EXPECT_EQ(new_div1->input(1), "mul1");
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  ASSERT_EQ(tensors.size(), 1);
+  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
+}
+
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 3cf20ca7dab..77b2caf87b0 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -383,7 +383,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
         op != "TensorArraySizeV3") {
       continue;
     }
-
     const std::vector<OpInfo::TensorProperties>& output =
         properties.GetOutputProperties(node->name());
     const std::vector<OpInfo::TensorProperties>& input =
@@ -410,8 +409,16 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
         continue;
       }
 
+      // TODO(rmlarsen): Remove this workaround for b/150861569.
+      // The bug involves an expression of the form Shape(ExpandDims(x))
+      // with an incorrectly inferred zero-size first dimension.
+      if (op == "Shape") {
+        if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
+      }
+
       // Repurpose the existing node to be the constant.
       // Device placement is preserved.
+      graph_modified_ = true;
       node->set_op("Const");
       node->clear_attr();
       (*node->mutable_attr())["dtype"].set_type(type);
@@ -424,9 +431,8 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
       // the original graph.
       string ctrl_dep =
           AddControlDependency(node->input(0), graph_, node_map_.get());
+      node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
       node->set_input(0, ctrl_dep);
-      node_map_->AddOutput(NodeName(ctrl_dep), node->name());
-
       // Done with the Shape/Size/Rank node, move to the next node.
       continue;
     }
@@ -458,6 +464,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
           continue;
         }
 
+        graph_modified_ = true;
         node->set_op("Const");
         *node->mutable_attr() = array_size->attr();
         node->set_input(0, AsControlDependency(NodeName(node->input(0))));
@@ -519,6 +526,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
             }
             *output->mutable_input(k) = const_name;
             node_map_->AddOutput(const_name, output->name());
+            graph_modified_ = true;
           }
           if (node_name == shape_n_node->name() && port != port_idx) {
             direct_edges_exist = true;
@@ -3705,8 +3713,9 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
 }
 
 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
-                                            const GrapplerItem& item,
+                                            GrapplerItem* item,
                                             GraphDef* optimized_graph) {
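+  // graph_ aliases the item's graph from here on, so the materialization
+  // passes below mutate the item in place before the folded result is
+  // written to optimized_graph.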
+  graph_ = &item->graph;
   node_map_.reset(new NodeMap(graph_));
   nodes_whitelist_.clear();
   // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
@@ -3716,14 +3725,14 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   // replace the node with multiple constants (each for one fanout) with
   // new names, and as a result users would not be able to fetch the node any
   // more with the original node name.
-  for (const auto& fetch : item.fetch) {
+  for (const auto& fetch : item->fetch) {
     const NodeDef* fetch_node = node_map_->GetNode(fetch);
     if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
       nodes_whitelist_.insert(fetch_node->name());
     }
   }
 
-  GraphProperties properties(item);
+  GraphProperties properties(*item);
   // It's possible to feed a placeholder with a tensor of any shape: make sure
   // that the shape inference deals with this conservatively unless we're in
   // aggressive mode.
@@ -3732,15 +3741,18 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_input_tensor_values=*/false,
                                         /*include_output_tensor_values=*/true);
+
   const bool can_use_shape_info = s.ok();
 
+  absl::flat_hash_set<string> nodes_to_not_simplify;
   if (can_use_shape_info) {
     TF_RETURN_IF_ERROR(MaterializeShapes(properties));
     TF_RETURN_IF_ERROR(MaterializeConstants(properties));
+    TF_RETURN_IF_ERROR(
+        FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
+  } else {
+    *optimized_graph = *graph_;
   }
-  absl::flat_hash_set<string> nodes_to_not_simplify;
-  TF_RETURN_IF_ERROR(
-      FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
   node_map_.reset(new NodeMap(optimized_graph));
   TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
                                    &properties, &nodes_to_not_simplify));
@@ -3795,11 +3807,10 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
     graph_modified_ = false;
     item_to_optimize.graph.Swap(optimized_graph);
-    graph_ = &item_to_optimize.graph;
-    *optimized_graph = GraphDef();
-    node_count = graph_->node_size();
+    optimized_graph->Clear();
+    node_count = item_to_optimize.graph.node_size();
     TF_RETURN_IF_ERROR(
-        RunOptimizationPass(cluster, item_to_optimize, optimized_graph));
+        RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
   } while (graph_modified_ || optimized_graph->node_size() != node_count);
   *optimized_graph->mutable_library() = item.graph.library();
   *optimized_graph->mutable_versions() = item.graph.versions();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 88c4094bb1a..074f0c5f057 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -131,8 +131,8 @@ class ConstantFolding : public GraphOptimizer {
   Status SimplifyNode(bool use_shape_info, NodeDef* node,
                       GraphDef* optimized_graph, GraphProperties* properties);
 
-  Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
-                             GraphDef* output);
+  Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item,
+                             GraphDef* optimized_graph);
 
   // Applies partial constant folding for Concat which is not commutative.
   // Returns true if the transformation applied successfully.
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index 238606ee673..de678d0a390 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -17,9 +17,12 @@ limitations under the License.
 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
 
 #include <string>
+
 #include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace grappler {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index da83e413ff6..82758d1f970 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
@@ -184,6 +185,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
   MK_OPT("auto_mixed_precision",
          new AutoMixedPrecision(cfg_.auto_mixed_precision()));
   MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
+  MK_OPT("common_subgraph_elimination",
+         new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
   MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
   MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
   MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
@@ -224,6 +227,11 @@ Status MetaOptimizer::InitializeOptimizers(
         cfg_.function_optimization(),
         /*lower_control_flow=*/!IsSingleThreadedExecutor()));
   }
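+  // Common subgraph elimination was split out of the arithmetic optimizer's
+  // dedup stage, so it is additionally gated on arithmetic optimization
+  // being enabled.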
+  if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF &&
+      cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
+    optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
+        cfg_.common_subgraph_elimination()));
+  }
   if (cfg_.debug_stripper() == RewriterConfig::ON) {
     optimizers->push_back(MakeUnique<DebugStripper>());
   }
@@ -812,6 +820,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
          rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
          rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
          rewrite_cfg.remapping() != RewriterConfig::OFF ||
+         rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
          rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
          rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
          rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 87835245762..a50c6f71fee 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "absl/container/node_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -65,8 +66,8 @@ class NodeMap {
 
  private:
   const std::set<NodeDef*> empty_set_;
-  gtl::FlatMap<string, NodeDef*> nodes_;
-  gtl::FlatMap<string, std::set<NodeDef*>> outputs_;
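+  // absl::node_hash_map guarantees pointer stability for its elements, so
+  // references returned into these maps (e.g. the output sets) stay valid
+  // across rehashes.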
+  absl::node_hash_map<string, NodeDef*> nodes_;
+  absl::node_hash_map<string, std::set<NodeDef*>> outputs_;
 };
 
 // A vector with a set. The set stores the same elements as the vector, and
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index e657a184b65..38c3ad7ae57 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -60,6 +60,9 @@ message RewriterConfig {
   // Remapping (default is ON)
   // Remap subgraphs onto more efficient implementations.
   Toggle remapping = 14;
+  // Common subgraph elimination (default is ON)
+  // e.g. Merge nodes that represent duplicated computations
+  // (such as identical constants).
+  Toggle common_subgraph_elimination = 24;
   // Arithmetic optimizations (default is ON)
   // e.g. Simplify arithmetic ops; merge ops with same value (like constants).
   Toggle arithmetic_optimization = 7;