diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD
index 69238456d57..bffb44c1b97 100644
--- a/tensorflow/core/tpu/graph_rewrite/BUILD
+++ b/tensorflow/core/tpu/graph_rewrite/BUILD
@@ -13,6 +13,7 @@ cc_library(
     srcs = ["tpu_rewrite_pass_registration.cc"],
     deps = [
         ":distributed_tpu_configuration_rewrite_pass",
+        ":distributed_tpu_rewrite_pass",
         ":encapsulate_tpu_computations_pass",
         ":variable_merger_pass",
         "//tensorflow/core:core_cpu",
@@ -101,3 +102,120 @@ cc_library(
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "distributed_tpu_rewrite_pass_internal",
+    srcs = ["distributed_tpu_rewrite_pass_internal.cc"],
+    hdrs = ["distributed_tpu_rewrite_pass_internal.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "@com_google_absl//absl/random",
+    ],
+)
+
+cc_library(
+    name = "distributed_tpu_rewrite_pass",
+    srcs = [
+        "distributed_tpu_rewrite_pass.cc",
+    ],
+    hdrs = [
+        "distributed_tpu_rewrite_pass.h",
+    ],
+    deps = [
+        ":cond_builder",
+        ":distributed_tpu_rewrite_helpers",
+        ":distributed_tpu_rewrite_pass_internal",
+        ":host_training_loop_optimization_util",
+        ":incomplete_nodedef_builder",
+        "//tensorflow/compiler/jit:encapsulate_util",
+        "//tensorflow/compiler/jit:shape_inference",
+        "//tensorflow/compiler/tf2xla:resource_operation_table",
+        "//tensorflow/compiler/tf2xla:sharding_util",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:array4d",
+        "//tensorflow/compiler/xla:xla_proto_cc",
+        "//tensorflow/compiler/xla/client:sharding_builder",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core/common_runtime:function",
+        "//tensorflow/core/common_runtime:graph_constructor",
+        "//tensorflow/core/common_runtime:lower_function_call_op",
+        "//tensorflow/core/common_runtime:lower_functional_ops",
+        "//tensorflow/core/common_runtime:lower_if_op",
+        "//tensorflow/core/common_runtime:lower_while_op",
+        "//tensorflow/core/common_runtime:optimization_registry",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
+        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
+        "//tensorflow/core/tpu:tpu_compile_interface",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_topology_external",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "incomplete_nodedef_builder",
+    srcs = ["incomplete_nodedef_builder.cc"],
+    hdrs = ["incomplete_nodedef_builder.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "cond_builder",
+    srcs = ["cond_builder.cc"],
+    hdrs = ["cond_builder.h"],
+    deps = [
+        ":incomplete_nodedef_builder",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "host_training_loop_optimization_util",
+    srcs = [
+        "host_training_loop_optimization_util.cc",
+    ],
+    hdrs = [
+        "host_training_loop_optimization_util.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":distributed_tpu_rewrite_pass_internal",
+        "//tensorflow/compiler/tf2xla:functionalize_control_flow_util",
+        "//tensorflow/compiler/tf2xla:tf2xla_util",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:node_hash_set",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
diff --git a/tensorflow/core/tpu/graph_rewrite/cond_builder.cc b/tensorflow/core/tpu/graph_rewrite/cond_builder.cc
new file mode 100644
index 00000000000..e16ae08aec3
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/cond_builder.cc
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
+
+namespace tensorflow {
+
+CondBuilder::CondBuilder(string name, string device, const NodeDebugInfo& debug,
+                         Graph* graph)
+    : graph_(graph), name_(std::move(name)), device_(std::move(device)) {
+  auto new_name = [graph, this](string suffix) {
+    return graph->NewName(strings::StrCat(name_, "/", suffix));
+  };
+  TF_CHECK_OK(
+      IncompleteNodeDefBuilder::Identity(new_name("pred"), DT_BOOL, debug)
+          .Device(device_)
+          .Build(graph_, &pred_));
+  Node* switch_pred;
+  TF_CHECK_OK(
+      IncompleteNodeDefBuilder::Switch(new_name("switch_pred"), DT_BOOL, debug)
+          .Device(device_)
+          .Build(graph_, &switch_pred));
+  graph_->AddEdge(pred(), 0, switch_pred, 0);
+  graph_->AddEdge(pred(), 0, switch_pred, 1);
+  TF_CHECK_OK(
+      IncompleteNodeDefBuilder::Identity(new_name("switch_f"), DT_BOOL, debug)
+          .Device(device_)
+          .Build(graph_, &switch_f_));
+  TF_CHECK_OK(
+      IncompleteNodeDefBuilder::Identity(new_name("switch_t"), DT_BOOL, debug)
+          .Device(device_)
+          .Build(graph_, &switch_t_));
+  graph_->AddEdge(switch_pred, kElseBranch, switch_f_, 0);
+  graph_->AddEdge(switch_pred, kThenBranch, switch_t_, 0);
+  Node* merge_pred;
+  TF_CHECK_OK(IncompleteNodeDefBuilder::Merge(new_name("merge_pred"), DT_BOOL,
+                                              debug, /*n=*/2)
+                  .Device(device_)
+                  .Build(graph_, &merge_pred));
+  graph_->AddEdge(switch_f_, 0, merge_pred, kElseBranch);
+  graph_->AddEdge(switch_t_, 0, merge_pred, kThenBranch);
+  // Note: when additional return values are added, there should be a
+  // control dependency between those merge nodes and control_successor_ to
+  // ensure that it is the control successor of the conditional.
+  control_successor_ = merge_pred;
+}
+
+Node* CondBuilder::pred() { return pred_; }
+
+Node* CondBuilder::switch_f() { return switch_f_; }
+
+Node* CondBuilder::switch_t() { return switch_t_; }
+
+Node* CondBuilder::control_successor() { return control_successor_; }
+
+Status CondBuilder::AddInput(const string& input_name, const DataType& type,
+                             const string& device, const NodeDebugInfo& debug,
+                             Node** input) {
+  auto b = IncompleteNodeDefBuilder::Switch(
+      graph_->NewName(strings::StrCat(name_, "/", input_name)), type, debug);
+  TF_RETURN_IF_ERROR(b.Device(device).Build(graph_, input));
+  graph_->AddEdge(pred(), 0, *input, 1);
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/cond_builder.h b/tensorflow/core/tpu/graph_rewrite/cond_builder.h
new file mode 100644
index 00000000000..29e264dfc0a
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/cond_builder.h
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
+
+#include <string>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Conditional builder.
+// Convenience builder to make it easy to construct a conditional. E.g.,
+//   Node* pred = ...;
+//   CondBuilder cb("cond", device, debug, g);
+//   Node* switch_var;
+//   TF_CHECK_OK(cb.AddInput("var", DT_RESOURCE, device, debug, &switch_var));
+//   g->AddEdge(pred, 0, cb.pred(), 0);
+// This creates the nodes of a conditional that takes a resource variable
+// ("var") as input and switches on pred.
+//
+// This currently only handles the case needed by distributed_tpu_rewrite_pass
+// and is not completely general.
+class CondBuilder {
+ public:
+  enum Branch { kElseBranch = 0, kThenBranch = 1 };
+
+  CondBuilder(string name, string device, const NodeDebugInfo& debug,
+              Graph* graph);
+
+  // Returns node corresponding to the predicate input.
+  Node* pred();
+
+  // Returns node corresponding to switch_f branch of predicate switch.
+  Node* switch_f();
+
+  // Returns node corresponding to switch_t branch of predicate switch.
+  Node* switch_t();
+
+  // Returns node corresponding to control successor.
+  Node* control_successor();
+
+  // Returns the Switch node to feed a value of the given type into the
+  // conditional.
+  Status AddInput(const string& input_name, const DataType& type,
+                  const string& device, const NodeDebugInfo& debug,
+                  Node** input);
+
+ private:
+  Node* control_successor_;
+  Node* switch_f_;
+  Node* switch_t_;
+  Node* pred_;
+  Graph* const graph_;
+  const string name_;
+  const string device_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
new file mode 100644
index 00000000000..f0032f5dfd9
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -0,0 +1,4105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Compilation for distributed TPU (TPU_REPLICATED_CORE devices).
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
+
+#include <queue>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/sharding_builder.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/common_runtime/lower_function_call_op.h"
+#include "tensorflow/core/common_runtime/lower_functional_ops.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
+#include "tensorflow/core/protobuf/tpu/topology.pb.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
+#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
+#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
+#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
+#include "tensorflow/core/tpu/tpu_compile_interface.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/dump_graph.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
+// topology.
+constexpr int kTPUTopologyRank = 4;
+
+// An upper bound on how many cores may be present in the topology.
+static constexpr int kTPUMaxTopologySize = 4096;
+
+// Attribute containing the serialized xla::OpSharding to be passed to the
+// corresponding XLA HLO operation, which represents how a shape is distributed
+// across logical cores, e.g., replication, single-device, or partitioning.
+const char kShardingAttribute[] = "_XlaSharding";
+
+const char kTPUPartitionedInput[] = "TPUPartitionedInput";
+const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";
+
+static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
+static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";
+
+class IntrusiveHeapLink {
+ public:
+  using size_type = size_t;
+  static constexpr size_type kNotMember = -1;
+
+  IntrusiveHeapLink() = default;
+
+  // Only IntrusiveHeap and LinkAccess objects should make these objects.
+  explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}
+
+  // Only IntrusiveHeap and LinkAccess should get the value.
+  size_type get() const { return pos_; }
+
+ private:
+  size_type pos_{kNotMember};
+};
+
+template <typename T, IntrusiveHeapLink T::*M>
+struct IntrusiveHeapDataMemberLinkAccess {
+  IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
+  void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
+};
+
+template <typename T>
+struct DefaultIntrusiveHeapLinkAccess {
+  IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
+  void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
+};
+
+template <typename T, typename PtrCompare,
+          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
+          typename Alloc = std::allocator<T*>>
+class IntrusiveHeap {
+ public:
+  typedef typename IntrusiveHeapLink::size_type size_type;
+  typedef T value_type;
+  typedef T* pointer;
+  typedef const T* const_pointer;
+  typedef PtrCompare pointer_compare_type;
+  typedef LinkAccess link_access_type;
+  typedef Alloc allocator_type;
+
+  explicit IntrusiveHeap(
+      const pointer_compare_type& comp = pointer_compare_type(),
+      const link_access_type& link_access = link_access_type(),
+      const allocator_type& alloc = allocator_type())
+      : rep_(comp, link_access, alloc) {}
+
+  size_type size() const { return heap().size(); }
+
+  bool empty() const { return heap().empty(); }
+
+  // Return the top element, but don't remove it.
+  pointer top() const {
+    DCHECK(!empty());
+    return heap()[0];
+  }
+
+  // Remove the top() pointer from the heap and return it.
+  pointer Pop() {
+    pointer t = top();
+    Remove(t);
+    return t;
+  }
+
+  // Insert 't' into the heap.
+  void Push(pointer t) {
+    SetPositionOf(t, heap().size());
+    heap().push_back(t);
+    FixHeapUp(t);
+  }
+
+  // Adjust the heap to accommodate changes in '*t'.
+  void Adjust(pointer t) {
+    DCHECK(Contains(t));
+    size_type h = GetPositionOf(t);
+    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
+      FixHeapUp(t);
+    } else {
+      FixHeapDown(t);
+    }
+  }
+
+  // Remove the specified pointer from the heap.
+  void Remove(pointer t) {
+    DCHECK(Contains(t));
+    size_type h = GetPositionOf(t);
+    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
+    if (h == heap().size() - 1) {
+      // Fast path for removing from back of heap.
+      heap().pop_back();
+      return;
+    }
+    // Move the element from the back of the heap to overwrite 't'.
+    pointer& elem = heap()[h];
+    elem = heap().back();
+    SetPositionOf(elem, h);  // Element has moved, so update its link.
+    heap().pop_back();
+    Adjust(elem);  // Restore the heap invariant.
+  }
+
+  void Clear() { heap().clear(); }
+
+  bool Contains(const_pointer t) const {
+    size_type h = GetPositionOf(t);
+    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
+           heap()[h] == t;
+  }
+
+  void reserve(size_type n) { heap().reserve(n); }
+
+  size_type capacity() const { return heap().capacity(); }
+
+  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }
+
+ private:
+  typedef std::vector<pointer, allocator_type> heap_type;
+
+  // Empty base class optimization for pointer_compare and link_access.
+  // The heap_ data member retains a copy of the allocator, so it is not
+  // stored explicitly.
+  struct Rep : pointer_compare_type, link_access_type {
+    explicit Rep(const pointer_compare_type& cmp,
+                 const link_access_type& link_access,
+                 const allocator_type& alloc)
+        : pointer_compare_type(cmp),
+          link_access_type(link_access),
+          heap_(alloc) {}
+    heap_type heap_;  // NOLINT
+  };
+
+  const pointer_compare_type& compare() const { return rep_; }
+
+  const link_access_type& link_access() const { return rep_; }
+
+  const heap_type& heap() const { return rep_.heap_; }
+  heap_type& heap() { return rep_.heap_; }
+
+  size_type GetPositionOf(const_pointer t) const {
+    return link_access().Get(t).get();
+  }
+
+  void SetPositionOf(pointer t, size_type pos) const {
+    return link_access().Set(t, IntrusiveHeapLink(pos));
+  }
+
+  void FixHeapUp(pointer t) {
+    size_type h = GetPositionOf(t);
+    while (h != 0) {
+      size_type parent = (h - 1) >> 1;
+      if (compare()(heap()[parent], t)) {
+        break;
+      }
+      heap()[h] = heap()[parent];
+      SetPositionOf(heap()[h], h);
+      h = parent;
+    }
+    heap()[h] = t;
+    SetPositionOf(t, h);
+  }
+
+  void FixHeapDown(pointer t) {
+    size_type h = GetPositionOf(t);
+    for (;;) {
+      size_type kid = (h << 1) + 1;
+      if (kid >= heap().size()) {
+        break;
+      }
+      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
+        ++kid;
+      }
+      if (compare()(t, heap()[kid])) {
+        break;
+      }
+      heap()[h] = heap()[kid];
+      SetPositionOf(heap()[h], h);
+      h = kid;
+    }
+
+    heap()[h] = t;
+    SetPositionOf(t, h);
+  }
+
+  Rep rep_;
+};
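+
+// Illustrative sketch (editorial, not part of the pass logic): a minimal use
+// of IntrusiveHeap with a hypothetical element type that embeds an
+// IntrusiveHeapLink member named `heap`, as expected by
+// DefaultIntrusiveHeapLinkAccess.
+//
+//   struct Task {
+//     int64 cost;
+//     IntrusiveHeapLink heap;  // Position bookkeeping used by the heap.
+//   };
+//   struct TaskCompare {
+//     // Returns true if `a` should be popped before `b` (min-heap on cost).
+//     bool operator()(const Task* a, const Task* b) const {
+//       return a->cost < b->cost;
+//     }
+//   };
+//   std::vector<Task> tasks(3);
+//   tasks[0].cost = 5; tasks[1].cost = 1; tasks[2].cost = 3;
+//   IntrusiveHeap<Task, TaskCompare> heap;
+//   for (Task& t : tasks) heap.Push(&t);
+//   Task* cheapest = heap.Pop();  // &tasks[1] (cost 1).
+//   tasks[0].cost = 0;            // After changing a key in place...
+//   heap.Adjust(&tasks[0]);       // ...restore the heap invariant.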
+
+string CoreDeviceLabel(int core) {
+  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
+}
+
+// Creates a unique node name with a particular prefix.
+string UniqueNodeName(const StringPiece prefix, Graph* graph) {
+  return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
+}
+
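+// Assigns a host-transfer `node` for TPU communication: records the TPU
+// device ordinal of `device` as an attribute on `node`, then places `node` on
+// device 0 of `target_device_type` within the same task.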
+Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
+                                        const string& target_device_type,
+                                        Node* node) {
+  TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
+  TF_RET_CHECK(device.has_id);
+  TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));
+
+  // Store the device instance as an attr on the Node.
+  TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));
+
+  // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
+  // compiled protos in the resource manager.
+  device.type = target_device_type;
+  device.id = 0;
+
+  node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
+  return Status::OK();
+}
+
+// Iterate over the nodes in the original graph and find all the TPUReplicate
+// nodes, and all the nodes that are part of outside_compilation clusters.
+Status FindTaggedNodes(
+    Graph* graph, std::vector<Node*>* replicate_nodes,
+    std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
+        outside_compilation_nodes,
+    std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
+  for (Node* node : graph->op_nodes()) {
+    if (node->type_string() == "_TPUReplicate") {
+      replicate_nodes->push_back(node);
+      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
+      if (cluster_attr == nullptr) {
+        return errors::Internal("TPUReplicate node ", node->name(), " has no ",
+                                kTPUReplicateAttr, " attr.");
+      } else {
+        const string& cluster = cluster_attr->s();
+        if (cluster.empty()) {
+          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
+                                  node->name(), " has no string value.");
+        }
+        if (outside_compilation_nodes->find(cluster) !=
+            outside_compilation_nodes->end()) {
+          return errors::Internal(
+              "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
+              " attr value '", cluster,
+              "' which is a duplicate of another TPUReplicate node in the "
+              "graph.");
+        }
+        (*outside_compilation_nodes)[cluster] =
+            DistributedTPURewritePass::OutsideCompilationNodeMap();
+        (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
+      }
+    }
+  }
+  for (Node* node : graph->op_nodes()) {
+    if (node->type_string() != "_TPUReplicate") {
+      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
+      const AttrValue* outside_compilation_attr =
+          node->attrs().Find(kOutsideCompilationAttr);
+      if (cluster_attr == nullptr) {
+        if (outside_compilation_attr != nullptr) {
+          return errors::Internal("Node ", node->name(), " has ",
+                                  kOutsideCompilationAttr, " attr but no ",
+                                  kTPUReplicateAttr, " attr.");
+        }
+      } else {
+        const string& cluster = cluster_attr->s();
+        if (cluster.empty()) {
+          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
+                                  node->name(), " has no string value.");
+        }
+        const auto iter = outside_compilation_nodes->find(cluster);
+        if (iter == outside_compilation_nodes->end()) {
+          return errors::Internal(
+              "Attr ", kTPUReplicateAttr, " on node ", node->name(),
+              " does not correspond to a TPUReplicate node.");
+        }
+        if (outside_compilation_attr == nullptr) {
+          return errors::Internal("Node ", node->name(), " has ",
+                                  kTPUReplicateAttr, " attr but no ",
+                                  kOutsideCompilationAttr, " attr.");
+        }
+        const string& oc_cluster = outside_compilation_attr->s();
+        if (oc_cluster.empty()) {
+          return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
+                                  node->name(), " has no string value.");
+        }
+
+        // Outside compilation clusters at the head and tail of a TPU
+        // computation have already been moved to the host and are already
+        // replicated. As such, do not replicate outside compilation nodes
+        // that have a replica id attribute.
+        int replica_id;
+        if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
+          const AttrValue* head_attr =
+              node->attrs().Find("_xla_only_arg_or_oc_input");
+          const AttrValue* tail_attr =
+              node->attrs().Find("_xla_only_ret_or_oc_output");
+          if (((head_attr != nullptr) && (head_attr->b())) ||
+              ((tail_attr != nullptr) && (tail_attr->b()))) {
+            // This is safe as this has the same keys as
+            // outside_compilation_nodes which we already know has this key.
+            (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
+          }
+          continue;
+        }
+        iter->second[oc_cluster].push_back(node);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// Helper class to spread TPU computation arguments and return values
+// across cores.
+// If all shapes are fully defined, balance by their size.
+// If some of them are not fully defined, the sizes of the undefined shapes
+// are estimated as the average size of the fully defined ones.
+// If none are defined, fall back to round-robin.
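+//
+// Illustrative usage sketch (editorial): `user_core` below is a hypothetical
+// per-argument array holding a user-assigned core id, or -1 when the placer
+// should pick the least-loaded core.
+//
+//   TensorDevicePlacer placer(num_cores_per_replica, arg_types, arg_shapes);
+//   for (int64 i = 0; i < arg_types.size(); ++i) {
+//     if (user_core[i] >= 0) {
+//       placer.ReportDeviceAssigned(user_core[i], i);  // Respect user choice.
+//     } else {
+//       int64 core = placer.RetrieveAssignment(i);  // Least-loaded core.
+//       // ... assign argument i to `core` ...
+//     }
+//   }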
+class TensorDevicePlacer {
+ public:
+  // Creates a TensorDevicePlacer object to distribute arguments or
+  // return values to a set of num_devices devices, where the types and
+  // the inferred shapes of the inputs (arguments or return values) are
+  // passed in types and shapes.
+  TensorDevicePlacer(int64 num_devices, const DataTypeVector& types,
+                     const std::vector<InferredShape>& shapes)
+      : index_nodes_(num_devices), sizes_(types.size()) {
+    int64 total_size = 0;
+    int64 num_defined = 0;
+    for (int64 i = 0; i < types.size(); ++i) {
+      sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
+      if (sizes_[i] >= 0) {
+        total_size += sizes_[i];
+        ++num_defined;
+      }
+    }
+    // If a shape is undefined, select a size for it which is the average
+    // of the defined shapes. If no shapes are defined, assign 1 so that we
+    // get round-robin behavior.
+    int64 undefined_shape_size =
+        (num_defined > 0) ? total_size / num_defined : 1;
+    for (int64 i = 0; i < sizes_.size(); ++i) {
+      if (sizes_[i] < 0) {
+        sizes_[i] = undefined_shape_size;
+      }
+    }
+
+    for (int64 i = 0; i < num_devices; ++i) {
+      heap_.Push(&index_nodes_[i]);
+    }
+  }
+
+  // Reports that the argument/return-value at index has been assigned
+  // by the user to a given device.
+  void ReportDeviceAssigned(int64 device, int64 index) {
+    DeviceNode* node = &index_nodes_.at(device);
+    node->size += sizes_.at(index);
+    heap_.Adjust(node);
+  }
+
+  // Retrieves the device to which the argument/return-value at index
+  // should be assigned.
+  int64 RetrieveAssignment(int64 index) {
+    DeviceNode* node = heap_.top();
+    int64 device = node - index_nodes_.data();
+    node->size += sizes_.at(index);
+    heap_.Adjust(node);
+    return device;
+  }
+
+ private:
+  struct DeviceNode {
+    struct Compare {
+      // Compare functor to implement a min-heap using the IntrusiveHeap
+      // infrastructure defined above.
+      bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
+        return lhs->size < rhs->size;
+      }
+    };
+
+    IntrusiveHeapLink heap;
+    int64 size = 0;
+  };
+
+  static int64 GetInferredShapeSize(const InferredShape& ishape,
+                                    DataType dtype) {
+    return ishape.shape.IsFullyDefined()
+               ? ishape.shape.num_elements() * DataTypeSize(dtype)
+               : -1;
+  }
+
+  std::vector<DeviceNode> index_nodes_;
+  IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
+  std::vector<int64> sizes_;
+};
+
+Status ValidateCoreNumber(int64 core, int64 num_cores_per_replica) {
+  if (core < 0 || core >= num_cores_per_replica) {
+    return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
+                                               ". The valid core IDs are [0..",
+                                               num_cores_per_replica, ")");
+  }
+  return Status::OK();
+}
+
+Status FindHostComputeKeyPlaceholderNodes(
+    const Graph* graph, const std::vector<Node*>& replicate_nodes,
+    std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
+  host_compute_key_placeholder_map->clear();
+  for (const auto node : replicate_nodes) {
+    (*host_compute_key_placeholder_map)[node->name()] = nullptr;
+  }
+
+  for (Node* node : graph->op_nodes()) {
+    if (node->type_string() == "Placeholder" &&
+        str_util::EndsWith(node->name(), "_key_placeholder")) {
+      const AttrValue* call_node_attr =
+          node->attrs().Find("_host_compute_call_node");
+      if (call_node_attr != nullptr) {
+        auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
+        if (iter == host_compute_key_placeholder_map->end()) {
+          return errors::InvalidArgument(
+              "Node ", node->name(), " has _host_compute_call_node attribute '",
+              call_node_attr->s(), "' that doesn't correspond to a call node");
+        }
+        if (iter->second != nullptr) {
+          return errors::InvalidArgument(
+              "Key placeholder node ", iter->second->name(), " for call node ",
+              call_node_attr->s(), " previously found as ",
+              iter->second->name());
+        }
+        iter->second = node;
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
+  Node* old_node = *node;
+  // We want to replace the node with an identity node with the same name.
+  const string& node_name = old_node->name();
+
+  // Create identity node.
+  TF_ASSIGN_OR_RETURN(
+      Node * id_node,
+      BuildIdentityNode(graph, node_name, DT_STRING,
+                        /*input=*/nullptr, /*requested_device=*/""));
+
+  // No incoming edges are copied as a new one will be added from compile node
+  // to id_node.
+
+  // Copy outgoing edges to the id node.
+  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+                                     old_node->out_edges().end());
+  for (const Edge* edge : out_edges) {
+    Node* dst = edge->dst();
+    int src_output = edge->src_output();
+    int dst_input = edge->dst_input();
+
+    if (src_output == Graph::kControlSlot) {
+      graph->AddControlEdge(id_node, dst);
+    } else {
+      graph->AddEdge(id_node, src_output, dst, dst_input);
+    }
+    graph->RemoveEdge(edge);
+  }
+  graph->RemoveNode(old_node);
+
+  *node = id_node;
+  return Status::OK();
+}
+
+Status FillPaddingMap(
+    const Node& replicate_node,
+    protobuf::RepeatedPtrField<tpu::PaddingMap>* padding_maps) {
+  std::vector<string> padding_map_strs;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "padding_map", &padding_map_strs));
+  padding_maps->Reserve(padding_map_strs.size());
+  for (const string& padding_map_str : padding_map_strs) {
+    tpu::PaddingMap* padding_map = padding_maps->Add();
+    if (!padding_map->ParseFromString(padding_map_str)) {
+      return errors::InvalidArgument(
+          "Malformed padding_map serialized string: ", padding_map_str);
+    }
+  }
+  return Status::OK();
+}
+
+Status GetStepMarkerLocation(const Node& replicate_node,
+                             xla::DebugOptions::StepMarkerLocation* location) {
+  string step_marker_location_attr;
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
+                                 &step_marker_location_attr));
+  if (step_marker_location_attr.empty()) {
+    *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
+  } else {
+    if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
+                                                     location)) {
+      return errors::InvalidArgument("Malformed step_marker_location: ",
+                                     step_marker_location_attr);
+    }
+  }
+  return Status::OK();
+}
+
+// Extracts a map from dimension index to number of splits for a tiled input,
+// based on the xla sharding attribute.
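+// For example, a sharding with tile_assignment_dimensions of [2, 1, 3] yields
+// the map {0: 2, 2: 3}: dimension 0 is split two ways and dimension 2 three
+// ways, while the unsplit dimension 1 is omitted.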
+Status GetDimensionIndicesAndNumSplitsFromSharding(
+    const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
+  for (int dim_index = 0;
+       dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) {
+    if (sharding.tile_assignment_dimensions(dim_index) > 1) {
+      split_dimension_map->emplace(
+          dim_index, sharding.tile_assignment_dimensions(dim_index));
+    }
+  }
+
+  if (split_dimension_map->empty()) {
+    return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
+                                   sharding.DebugString());
+  }
+  return Status::OK();
+}
+
+// Updates the contents of the function `function_name` in the function library
+// definition `flib_def` to match `new_graph`. This is required when a graph
+// transformation happens inside a function call body.
+Status UpdateFunctionLibDefinition(const Graph& new_graph,
+                                   const std::string& function_name,
+                                   FunctionLibraryDefinition* flib_def) {
+  FunctionDef graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
+  TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
+  return Status::OK();
+}
+
+struct NodeOut {
+  Node* node;
+  int index;
+};
+
+struct ShardedInputIndex {
+  int replica_id;
+  int argument_index;
+
+  bool operator<(const ShardedInputIndex& rhs) const {
+    return std::tie(replica_id, argument_index) <
+           std::tie(rhs.replica_id, rhs.argument_index);
+  }
+};
+
+struct ShardedInputInfo {
+  // Split node that is connected to the tiled input node.
+  Node* split_node;
+  // List of split nodes and the output indices of the split node from which
+  // the sharded input will be connected to the TPUExecute node. The inputs
+  // are ordered by logical core ids.
+  std::vector<NodeOut> sharded_inputs;
+};
+
+// Adds a split node and a split-dimension node to the graph for sharding tiled
+// inputs. |graph| owns the returned Node* instance.
+xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim,
+                                     int orig_src_output, DataType dtype,
+                                     absl::string_view name_prefix,
+                                     Node* control_predecessor, Node* orig_src,
+                                     Graph* graph) {
+  const std::string input_assigned_device = orig_src->assigned_device_name();
+
+  // Add a split dimension node.
+  NodeDef split_dim_def;
+  split_dim_def.set_name(
+      graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
+  split_dim_def.set_op("Const");
+  split_dim_def.set_device(input_assigned_device);
+  AddNodeAttr("dtype", DT_INT32, &split_dim_def);
+  TensorProto tensor_proto;
+  tensor_proto.set_dtype(DT_INT32);
+  tensor_proto.add_int_val(dim);
+  TensorShape shape({});
+  shape.AsProto(tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", tensor_proto, &split_dim_def);
+  Status s;
+  Node* split_dim_node = graph->AddNode(split_dim_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  // Add a split node.
+  NodeDef split_def;
+  split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
+  split_def.set_op("Split");
+  split_def.set_device(input_assigned_device);
+  AddNodeAttr("num_split", num_splits, &split_def);
+  AddNodeAttr("T", dtype, &split_def);
+  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
+  split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output));
+  Node* split_node = graph->AddNode(split_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  graph->AddEdge(split_dim_node, 0, split_node, 0);
+  graph->AddEdge(orig_src, orig_src_output, split_node, 1);
+
+  // Add a control dependency from `control_predecessor` to the newly created
+  // constant node. This ensures that the newly added split/split-dim nodes are
+  // placed inside the correct while loop frames when the TPUExecute node is
+  // inside a host training loop.
+  graph->AddControlEdge(control_predecessor, split_dim_node);
+
+  return split_node;
+}
+
+// Creates a set of split nodes that shard the tiled input node in the graph.
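+// For example (illustrative), with split_dimension_map = {0: 2, 1: 2} this
+// builds one root Split over dimension 0 and, for each of its two outputs, a
+// child Split over dimension 1, producing four leaf outputs in row-major
+// order: [0,0], [0,1], [1,0], [1,1].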
+xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
+    const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
+    int replica_id, int orig_src_output, Node* orig_src,
+    Node* control_predecessor, Graph* graph,
+    std::map<ShardedInputIndex, ShardedInputInfo>*
+        arg_index_to_sharded_input_map) {
+  ShardedInputIndex input_index{replica_id, orig_arg_num};
+  auto iter = arg_index_to_sharded_input_map->find(input_index);
+  if (iter != arg_index_to_sharded_input_map->end()) {
+    return iter->second;
+  }
+  // Maps each input dimension to the number of splits with which that
+  // dimension is sharded.
+  std::map<int, int> split_dimension_map;
+  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
+      sharding, &split_dimension_map));
+  TF_RET_CHECK(!split_dimension_map.empty())
+      << "Unnecessary sharding attribute found.";
+
+  // For a v1 while loop, nodes inside the loop body must either
+  //  1) have data edges from the while loop input node, or
+  //  2) have a direct control dependency from the while loop input control
+  //     node.
+  //
+  // As such, if we are adding Split nodes inside a while loop body, we must
+  // manually add a control dependency from a node inside the while loop
+  // (i.e. `control_predecessor`) to constant nodes without data in-edges, to
+  // make sure that the added split nodes have the correct frame name.
+  // Otherwise, the placer will complain when `BuildControlFlow()` is invoked.
+
+  auto sharding_it = split_dimension_map.begin();
+  std::queue<Node*> split_nodes_for_dimension;
+  int split_dimension = sharding_it->first;
+  int num_split = sharding_it->second;
+
+  // Creates a tree of split nodes for sharding tiled inputs. Split nodes are
+  // created such that the input data is sharded in row-major order.
+  // Split nodes at the ith depth from the original input node split the input
+  // data at the ith dimension.
+  TF_ASSIGN_OR_RETURN(
+      Node * root_split_node,
+      CreateSplitNode(num_split, split_dimension, orig_src_output, dtype,
+                      absl::StrCat("sharded_input/replica_", replica_id,
+                                   "_dim_", split_dimension),
+                      control_predecessor, orig_src, graph));
+  sharding_it++;
+
+  split_nodes_for_dimension.emplace(root_split_node);
+
+  while (sharding_it != split_dimension_map.end()) {
+    split_dimension = sharding_it->first;
+    num_split = sharding_it->second;
+    int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
+    for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
+      Node* input_split_node = split_nodes_for_dimension.front();
+      split_nodes_for_dimension.pop();
+      for (int src_output_index = 0;
+           src_output_index < input_split_node->num_outputs();
+           ++src_output_index) {
+        TF_ASSIGN_OR_RETURN(
+            Node * split_node,
+            CreateSplitNode(num_split, split_dimension, src_output_index, dtype,
+                            absl::StrCat("sharded_input/replica_", replica_id,
+                                         "_dim_", split_dimension),
+                            control_predecessor, input_split_node, graph));
+        split_nodes_for_dimension.emplace(split_node);
+      }
+    }
+    sharding_it++;
+  }
+
+  // `split_nodes_for_dimension` now contains the final split nodes from which
+  // the sharded data will be fed into TPUExecute nodes, in row-major order.
+  std::vector<NodeOut> sharded_inputs_list;
+  sharded_inputs_list.reserve(split_nodes_for_dimension.size());
+  while (!split_nodes_for_dimension.empty()) {
+    Node* split_node = split_nodes_for_dimension.front();
+    split_nodes_for_dimension.pop();
+    int num_splits;
+    TF_RETURN_IF_ERROR(
+        GetNodeAttr(split_node->def(), "num_split", &num_splits));
+    for (int out_index = 0; out_index < num_splits; ++out_index) {
+      sharded_inputs_list.emplace_back(NodeOut{split_node, out_index});
+    }
+  }
+
+  ShardedInputInfo sharded_input_info{root_split_node,
+                                      std::move(sharded_inputs_list)};
+  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
+  return sharded_input_info;
+}
+
+// Creates a concat node to be used for aggregating sharded retvals across
+// logical cores.
+xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
+                                      absl::string_view name_prefix,
+                                      const std::vector<NodeOut>& inputs,
+                                      Graph* graph, absl::string_view device) {
+  // Add a Concat dim node.
+  NodeDef concat_dim_def;
+  concat_dim_def.set_name(
+      graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
+  concat_dim_def.set_op("Const");
+  AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
+  concat_dim_def.set_device(std::string(device));
+  TensorProto tensor_proto;
+  tensor_proto.set_dtype(DT_INT32);
+  tensor_proto.add_int_val(dim);
+  TensorShape shape({});
+  shape.AsProto(tensor_proto.mutable_tensor_shape());
+  AddNodeAttr("value", tensor_proto, &concat_dim_def);
+  Status s;
+  Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  // Add a Concat node.
+  NodeDef concat_def;
+  concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
+  concat_def.set_op("Concat");
+  AddNodeAttr("N", num_splits, &concat_def);
+  AddNodeAttr("T", dtype, &concat_def);
+  concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
+  concat_def.set_device(std::string(device));
+  for (const auto& i : inputs) {
+    concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
+  }
+  Node* concat_node = graph->AddNode(concat_def, &s);
+  TF_RETURN_IF_ERROR(s);
+
+  graph->AddEdge(concat_dim_node, 0, concat_node, 0);
+
+  // The 0th input to the concat node is the concat dim node, so we start from
+  // the 1st input and add all the input edges.
+  int dst_input = 1;
+  for (const auto& i : inputs) {
+    graph->AddEdge(i.node, i.index, concat_node, dst_input);
+    ++dst_input;
+  }
+  return concat_node;
+}
+
+// Creates a set of Concat nodes that aggregate the sharded outputs from
+// TPUExecute nodes into a single output. Sharded outputs are concatenated in
+// row-major order; that is, the tiled outputs along the 0th dimension are
+// concatenated last.
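+// For example (illustrative), with split_dimension_map = {0: 2, 1: 2} and four
+// sharded outputs in row-major order, the first pass (dimension 1) builds two
+// Concat nodes over the pairs ([0,0], [0,1]) and ([1,0], [1,1]); the second
+// pass (dimension 0) then concatenates those two results into the full tensor.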
+xla::StatusOr<Node*> CreateConcatNodesForRetval(
+    const xla::OpSharding& sharding, DataType dtype, int replica_id,
+    const std::vector<NodeOut>& orig_inputs, Graph* graph,
+    absl::string_view device) {
+  std::map<int, int> split_dimension_map;
+  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
+      sharding, &split_dimension_map));
+
+  std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
+
+  for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
+       it++) {
+    auto dim = it->first;
+    auto num_splits = it->second;
+
+    int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
+    int input_index_to_concat_node = 0;
+
+    std::vector<NodeOut> new_concat_nodes;
+    for (int i = 0; i < num_concat_nodes; ++i) {
+      auto concat_input_it =
+          inputs_to_sharded_retval.begin() + input_index_to_concat_node;
+      std::vector<NodeOut> inputs(concat_input_it,
+                                  concat_input_it + num_splits);
+      input_index_to_concat_node += num_splits;
+
+      TF_ASSIGN_OR_RETURN(
+          Node * concat_node,
+          CreateConcatNode(
+              dim, num_splits, dtype,
+              absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
+              inputs, graph, device));
+      new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
+    }
+    inputs_to_sharded_retval = new_concat_nodes;
+  }
+
+  TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
+  return inputs_to_sharded_retval.at(0).node;
+}
+
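+// Returns the position of `core` in the sharding's tile_assignment_devices
+// list, or absl::nullopt if the core is not part of the assignment.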
+absl::optional<int> GetCoreIndexInSharding(const xla::OpSharding& sharding,
+                                           int64 core) {
+  absl::optional<int> output_index;
+  for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) {
+    int64 assigned_core = sharding.tile_assignment_devices(i);
+    if (assigned_core == core) {
+      output_index = i;
+      break;
+    }
+  }
+  return output_index;
+}
+
+// Places the padding ops on the same devices as the original inputs. If the
+// original inputs are on TPUs, the padding ops will be placed on TPUs and XLA
+// on-demand mode will be triggered, so we don't need to copy the data back to
+// the host to do the padding.
+Status SetPaddingNodesDevices(Graph* graph) {
+  for (Node* n : graph->op_nodes()) {
+    bool tpu_padding_attr;
+    if (n->type_string() == "Pad" &&
+        GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
+            .ok()) {
+      Node* unpadded_input;
+      TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
+
+      const string& requested_device = unpadded_input->requested_device();
+      const string& assigned_device = unpadded_input->assigned_device_name();
+      if (!requested_device.empty() || !assigned_device.empty()) {
+        // The output nodes of the original unpadded inputs include the padded
+        // inputs and the real shapes of the inputs; we assign those to the
+        // same device as the original inputs.
+        for (Node* out : unpadded_input->out_nodes()) {
+          if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
+                          &tpu_padding_attr)
+                  .ok()) {
+            out->set_requested_device(requested_device);
+            out->set_assigned_device_name(assigned_device);
+          }
+        }
+        // There might be a tf.shape node added before TPUCompileOp, so we
+        // need to set its device as well.
+        for (Node* out : n->out_nodes()) {
+          if (out->type_string() == "Shape") {
+            out->set_requested_device(requested_device);
+            out->set_assigned_device_name(assigned_device);
+          }
+        }
+      }
+    }
+  }
+  return Status::OK();
+}
+
+const string& AssignedOrRequestedDevice(const Node* node) {
+  if (!node->assigned_device_name().empty()) {
+    return node->assigned_device_name();
+  }
+  return node->requested_device();
+}
+
+bool IsTpuDevice(const string& device_string) {
+  DeviceNameUtils::ParsedName device;
+  return DeviceNameUtils::ParseFullName(device_string, &device) &&
+         device.type == DEVICE_TPU_NODE;
+}
+
+// Returns the set of ops that can be placed on the TPU. There is no strict
+// rule of thumb to decide which ops should be in the list, but empirically
+// they are mostly dummy ops like Identity-like ops or control-flow-related
+// ops. However, other ops like Pad can also be added to allow data to stay on
+// the TPU.
+const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
+  static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
+      {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
+       "NextIteration", "Shape"});
+  return *place_on_tpu_ops;
+}
+
+// If an op satisfies the following conditions, it will be placed on the same
+// TPU device as its inputs:
+//   (1) The op can be placed on the TPU (it is in the PlaceOnTPUOpList).
+//   (2) The op itself has no requested or assigned devices.
+//   (3) All the data inputs of this op are placed on the same TPU device.
+//       There are exceptions, e.g., the LoopCond input of a Switch node can
+//       be placed on the CPU as it is just a boolean.
+//
+// Returns true if the node device has been changed, otherwise returns false.
+bool PlaceOpsOnTPU(Node* node) {
+  if (!AssignedOrRequestedDevice(node).empty() ||
+      !PlaceOnTPUOpList().contains(node->type_string())) {
+    return false;
+  }
+  string src_tpu_device = "";
+  Node* src_node = nullptr;
+  for (const Edge* e : node->in_edges()) {
+    if (e->IsControlEdge()) {
+      continue;
+    }
+    Node* src = e->src();
+    const string& src_device = AssignedOrRequestedDevice(src);
+
+    // Make exceptions so that we don't force some inputs to be placed on TPUs.
+    if (node->IsSwitch() && src->IsLoopCond()) {
+      continue;
+    }
+
+    if (!IsTpuDevice(src_device) ||
+        (!src_tpu_device.empty() && src_device != src_tpu_device)) {
+      return false;
+    }
+    if (src_tpu_device.empty()) {
+      src_tpu_device = src_device;
+      src_node = src;
+    }
+  }
+  if (src_node == nullptr) {
+    return false;
+  }
+  node->set_assigned_device_name(src_node->assigned_device_name());
+  node->set_requested_device(src_node->requested_device());
+  return true;
+}
+
+// Validate sharding configuration derived from XlaSharding attribute.
+// Infer the core id from the OpSharding, if necessary.
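+// For example (illustrative), if two maximal shardings annotate cores 2 and 1,
+// the smaller core id (1) wins; if two non-maximal shardings disagree, the
+// result falls back to assigning core 0.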
+Status ParseAndValidateSharding(const xla::OpSharding& sharding,
+                                const int num_cores_per_replica,
+                                int64* inferred_core_id,
+                                absl::optional<xla::OpSharding>* result) {
+  if (sharding.type() == xla::OpSharding::MAXIMAL) {
+    int64 core_annotation = sharding.tile_assignment_devices(0);
+    TF_RETURN_IF_ERROR(
+        ValidateCoreNumber(core_annotation, num_cores_per_replica));
+    if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
+      *inferred_core_id = core_annotation;
+      result->emplace(sharding);
+    }
+  } else {
+    if (sharding.type() == xla::OpSharding::OTHER) {
+      for (int64 core : sharding.tile_assignment_devices()) {
+        TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
+      }
+    }
+
+    if (!result->has_value()) {
+      *result = sharding;
+    } else {
+      std::string result_value_serialized;
+      std::string sharding_serialized;
+      SerializeToStringDeterministic(result->value(), &result_value_serialized);
+      SerializeToStringDeterministic(sharding, &sharding_serialized);
+
+      if (result_value_serialized != sharding_serialized) {
+        // We see different shardings; assign to core 0.
+        result->emplace(xla::sharding_builder::AssignDevice(0));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// As an XlaSharding node may be followed by a Cast op or an Identity op,
+// recursively walk the graph and aggregate the nodes connected to
+// |input_node|, or to Cast/Identity ops following |input_node|.
+void FindNodesMaybeContainingShardingInfo(const Node& input_node,
+                                          std::vector<const Node*>* nodes) {
+  if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
+    for (const Node* connected_node : input_node.out_nodes())
+      FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
+  }
+  nodes->emplace_back(&input_node);
+}
+
+// Parses the sharding configuration from |node| or its adjacent nodes.
+// The XlaSharding configuration may be derived from
+//   a) a connected Identity op node, or
+//   b) a connected Cast op node.
+xla::StatusOr<absl::optional<xla::OpSharding>>
+ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
+                                   const Node& node) {
+  // If |node| has a `device` attribute or is an XlaSharding op, return the
+  // parsed OpSharding.
+  TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
+                      ParseShardingFromDevice(node, num_cores_per_replica));
+  if (sharding.has_value()) return sharding;
+
+  // An XlaSharding op may be followed by an Identity op, or by an Identity op
+  // and a Cast op.
+  std::vector<const Node*> potential_nodes_with_input_sharding;
+  FindNodesMaybeContainingShardingInfo(node,
+                                       &potential_nodes_with_input_sharding);
+  for (const Node* maybe_node_with_sharding_info :
+       potential_nodes_with_input_sharding) {
+    if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
+
+    TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding_config,
+                        ParseShardingFromDevice(*maybe_node_with_sharding_info,
+                                                num_cores_per_replica));
+    if (sharding_config.has_value()) return sharding_config;
+  }
+  return sharding;
+}
+
+// Walks the graph from an argument node to find an OpSharding configuration
+// from its neighbor nodes. The sharding configuration may be inferred from
+//  1) parsing the XlaSharding attribute of a neighboring node, or
+//  2) if the argument node is a resource, parsing the adjacent nodes of the
+//     connected ReadVariableOp.
+Status ParseAndValidateShardingFromNeighbors(
+    const int num_cores_per_replica, const std::string& arg_node_name,
+    const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
+    absl::optional<xla::OpSharding>* result) {
+  if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
+    *is_fast_mem = true;
+    VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
+            << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
+  }
+
+  // XlaSharding information may be encoded on a node directly connected to the
+  // argument node.
+  TF_ASSIGN_OR_RETURN(
+      absl::optional<xla::OpSharding> sharding,
+      ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
+  if (sharding.has_value()) {
+    TF_RETURN_IF_ERROR(ParseAndValidateSharding(
+        *sharding, num_cores_per_replica, inferred_core_id, result));
+    return Status::OK();
+  }
+
+  // When we use a variable in a TPU computation, we always have an XlaSharding
+  // op followed by a ReadVariableOp. As such, parse the users of the
+  // ReadVariableOp for a potential sharding configuration.
+  if (neighbor_node.type_string() == "ReadVariableOp") {
+    for (const Edge* e : neighbor_node.out_edges()) {
+      if (e->IsControlEdge()) continue;
+
+      if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
+        *is_fast_mem = true;
+        VLOG(2) << "place " << arg_node_name << " on fast memory because "
+                << e->dst()->name() << TPU_FAST_MEM_ATTR << " attribute";
+      }
+
+      TF_ASSIGN_OR_RETURN(
+          absl::optional<xla::OpSharding> sharding,
+          ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
+      if (sharding.has_value()) {
+        TF_RETURN_IF_ERROR(ParseAndValidateSharding(
+            *sharding, num_cores_per_replica, inferred_core_id, result));
+        return Status::OK();
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+// Inputs:
+//   replication_spec_string: the device to which the TPUReplicate node was
+//     assigned.
+//   device_set: the set of TF devices.
+// Outputs:
+//   tpu_compilation_device: the name of the TPU compilation device.
+//   num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
+//     have the same number of TPU devices.
+//   tpu_devices: the TPU devices, indexed by [task][device].
+static Status GetTPUDeviceNames(
+    const string& replication_spec_string, const DeviceSet& device_set,
+    string* tpu_compilation_device, int* num_tpus_per_task,
+    std::vector<std::vector<Device*>>* tpu_devices) {
+  // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
+  // the tpu_system device, which we replace by the cpu device. We do this
+  // replacement because we want to place the TPUCompileOp (and the compile
+  // assert op) explicitly on cpu devices on the same job as the tpu_system
+  // device.
+  DeviceNameUtils::ParsedName replication_spec;
+  Device* replication_device;
+  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
+      replication_spec_string, device_set, &replication_spec,
+      &replication_device));
+  *tpu_compilation_device =
+      str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
+                              DEVICE_CPU, /*replace_all=*/true);
+
+  // Finds the set of TPU devices attached to the tasks in the job.
+  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
+      replication_spec, device_set, num_tpus_per_task, tpu_devices));
+
+  return Status::OK();
+}
+
+// Parses the topology attribute of TPUReplicate, and populates *topology with
+// a physical mesh coordinate to (task, device) mapping.
+static Status ParseTopologyAttr(const string& topology_attr,
+                                const tpu::TpuTopologyExternal& tpu_topology,
+                                int num_tasks, int num_tpus_per_task,
+                                xla::Array4D<std::pair<int, int>>* topology) {
+  static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
+  tpu::TopologyProto proto;
+  if (!proto.ParseFromString(topology_attr)) {
+    return errors::InvalidArgument("Unable to parse TPU topology attribute");
+  }
+  if (proto.mesh_shape_size() != kTPUTopologyRank) {
+    return errors::InvalidArgument("TPU topology must be rank ",
+                                   kTPUTopologyRank);
+  }
+  if (proto.num_tasks() != num_tasks) {
+    return errors::InvalidArgument("Mismatched number of TPU tasks");
+  }
+  if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
+    return errors::InvalidArgument("Mismatched number of TPUs per task (",
+                                   proto.num_tpu_devices_per_task(),
+                                   " != ", num_tpus_per_task, ").");
+  }
+  if (proto.device_coordinates_size() !=
+      num_tasks * num_tpus_per_task * kTPUTopologyRank) {
+    return errors::InvalidArgument(
+        "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
+        kTPUTopologyRank, "; got ", proto.device_coordinates_size());
+  }
+
+  int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
+  *topology = xla::Array4D<std::pair<int, int>>(
+      tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
+      tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
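+  // device_coordinates is consumed as (x, y, z, core) quadruples, ordered by
+  // task and then by device within each task. For example, with 2 tasks and
+  // 2 TPUs per task the attribute carries 2 * 2 * 4 = 16 ints, and the
+  // quadruple for (task 1, device 0) starts at offset (1 * 2 + 0) * 4 = 8.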
+  int pos = 0;
+  for (int task = 0; task < num_tasks; ++task) {
+    for (int device = 0; device < num_tpus_per_task; ++device) {
+      int32 x = proto.device_coordinates(pos++);
+      int32 y = proto.device_coordinates(pos++);
+      int32 z = proto.device_coordinates(pos++);
+      int32 core = proto.device_coordinates(pos++);
+
+      if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
+          core >= devices_per_chip) {
+        return errors::InvalidArgument(
+            "Mesh coordinates (", x, ",", y, ",", z, ",", core,
+            ") are not valid for the current TPU topology");
+      }
+      if ((*topology)(x, y, z, core).first != -1) {
+        return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
+                                       ",", z, ",", core, ") in TPU topology");
+      }
+      (*topology)(x, y, z, core) = {task, device};
+    }
+  }
+  return Status::OK();
+}
+
+// Parses the value of the device_assignment attribute to TPUReplicate.
+// Populates *device_assignment; *device_assignment must be a 2D array with
+// shape (num_replicas, num_cores_per_replica).
+static Status ParseDeviceAssignmentAttr(
+    absl::Span<const int> device_assignment_attr,
+    const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
+    int num_cores_per_replica,
+    xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
+  static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
+
+  const int64 device_assignment_attr_size =
+      num_replicas * num_cores_per_replica * kTPUTopologyRank;
+  if (device_assignment_attr.size() != device_assignment_attr_size) {
+    return errors::InvalidArgument(
+        "Length of device_assignment attribute must be equal to num_replicas (",
+        num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
+        ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
+  }
+  for (int core : device_assignment_attr) {
+    if (core < 0 || core >= kTPUMaxTopologySize) {
+      return errors::InvalidArgument(
+          "Invalid core number in device assignment: ", core);
+    }
+  }
+
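+  // device_assignment_attr is likewise consumed as (x, y, z, core) quadruples,
+  // ordered by replica and then by logical core within each replica. For
+  // example, with num_replicas=2, num_cores_per_replica=1, and a chip that
+  // exposes two logical cores, {0, 0, 0, 0, 0, 0, 0, 1} places replica 0 on
+  // core 0 and replica 1 on core 1 of chip (0, 0, 0).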
+  *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
+      num_replicas, num_cores_per_replica);
+  int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
+  xla::Array4D<int> replica_assignment(
+      tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
+      tpu_topology.chip_bounds().z, devices_per_chip, -1);
+  int pos = 0;
+  for (int replica = 0; replica < num_replicas; ++replica) {
+    for (int logical_core = 0; logical_core < num_cores_per_replica;
+         ++logical_core) {
+      int32 x = device_assignment_attr[pos++];
+      int32 y = device_assignment_attr[pos++];
+      int32 z = device_assignment_attr[pos++];
+      int32 core = device_assignment_attr[pos++];
+
+      if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
+          core >= devices_per_chip) {
+        return errors::InvalidArgument(
+            "Mesh coordinates (", x, ",", y, ",", core,
+            ") are not valid for the current TPU topology");
+      }
+      tpu::TpuCoreLocationExternal core_location =
+          tpu_topology.Core(x, y, z, kTensorCore, core);
+
+      if (replica_assignment(x, y, z, core) != -1) {
+        return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
+                                       ",", z, ",", core,
+                                       ") in TPU device assignment");
+      }
+      replica_assignment(x, y, z, core) = replica;
+      (*device_assignment)(replica, logical_core) = core_location;
+    }
+  }
+  return Status::OK();
+}
+
+// Builds TensorFlow device assignments for the special case of a single core
+// computation that is replicated to every core in the mesh.
+// LINT.IfChange
+static Status BuildFullMeshDeviceAssignment(
+    int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
+    int num_tasks, int num_tpus_per_task,
+    std::vector<std::vector<string>>* tf_device_assignment) {
+  // Assign TensorFlow devices to replicas arbitrarily.
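+  // For example, with num_tpus_per_task = 4, replica 5 is placed on task 1,
+  // device 1.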
+  for (int i = 0; i < num_replicas; ++i) {
+    int task = i / num_tpus_per_task;
+    int device = i % num_tpus_per_task;
+    TF_RET_CHECK(task >= 0 && task < num_tasks);
+    TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
+
+    // We don't actually know which TF device corresponds to which physical
+    // device, but it doesn't matter—they're all identical.
+    (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
+  }
+  return Status::OK();
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+
+// Builds TensorFlow device assignments for a replicated computation and
+// converts device_assignment into xla_device_assignment.
+static Status BuildGeneralDeviceAssignment(
+    int num_replicas, int num_cores_per_replica,
+    const std::vector<std::vector<Device*>>& tpu_devices,
+    const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
+    const xla::Array4D<std::pair<int, int>>& topology,
+    std::vector<std::vector<string>>* tf_device_assignment,
+    std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
+  // Assign TensorFlow devices to each computation's replicas according to
+  // device_assignment and 'topology'.
+  *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
+      num_replicas, num_cores_per_replica);
+  for (int replica = 0; replica < num_replicas; ++replica) {
+    for (int computation = 0; computation < num_cores_per_replica;
+         ++computation) {
+      const tpu::TpuCoreLocationExternal& core_location =
+          device_assignment(replica, computation);
+
+      int task;
+      int device;
+      std::tie(task, device) =
+          topology(core_location.chip_coordinates().x,
+                   core_location.chip_coordinates().y,
+                   core_location.chip_coordinates().z, core_location.index());
+
+      CHECK_LT(computation, num_cores_per_replica);
+      (**xla_device_assignment)(replica, computation) = core_location.Id();
+
+      // The communication pattern between replicas will be determined later by
+      // BuildAllReduceRing.
+      TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
+      TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
+      (*tf_device_assignment)[replica].push_back(
+          tpu_devices[task][device]->name());
+    }
+  }
+  return Status::OK();
+}
+
+/*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
+    const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
+    const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
+    int num_cores_per_replica, const string& topology_attr,
+    absl::Span<const int> device_assignment_attr,
+    std::vector<std::vector<string>>* tf_device_assignment,
+    std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
+  const int num_tasks = tpu_devices.size();
+  const int num_tpu_devices = num_tasks * num_tpus_per_task;
+  VLOG(2) << "num_tasks=" << num_tasks
+          << " num_tpus_per_task=" << num_tpus_per_task;
+
+  // Checks num_replicas is sane first to avoid integer overflow.
+  if (num_replicas > num_tpu_devices) {
+    return errors::InvalidArgument("Requested num_replicas=", num_replicas,
+                                   " but there are only ", num_tpu_devices,
+                                   " cores in the TPU topology.");
+  }
+  if (num_replicas * num_cores_per_replica > num_tpu_devices) {
+    return errors::InvalidArgument(
+        "Requested num_replicas=", num_replicas, " with ",
+        num_cores_per_replica, " cores per replica, but there are only ",
+        num_tpu_devices, " cores in the TPU topology");
+  }
+
+  tf_device_assignment->clear();
+  tf_device_assignment->resize(num_replicas);
+
+  // Special case: we allow the user to omit the topology and device assignment
+  // information in two cases:
+  // * there is only one replica and one core per replica. In this case, we
+  //   don't need to know topology information because we don't communicate with
+  //   other cores.
+  // * the number of replicas is equal to the number of cores in the slice. In
+  //   this case, all cores are running the same program so we don't need to
+  //   know which is which.
+  if (topology_attr.empty()) {
+    // LINT.IfChange
+    if (num_replicas != 1 && num_replicas != num_tpu_devices) {
+      return errors::InvalidArgument(
+          "TPUReplicate asked to create ", num_replicas,
+          " replicas, but the number of cores in the TPU topology is ",
+          num_tpu_devices,
+          " and no TPU device assignment was supplied. "
+          "A TPU device assignment is required if the number of replicas is "
+          "not 1 or the number of cores in the topology (",
+          num_tpu_devices, ")");
+    }
+
+    if (num_cores_per_replica != 1) {
+      return errors::InvalidArgument(
+          "A TPU topology must be provided if num_cores_per_replica != 1");
+    }
+
+    if (!device_assignment_attr.empty()) {
+      return errors::InvalidArgument(
+          "A TPU topology must be provided if device_assignment_attr is "
+          "non-empty");
+    }
+
+    // If there is only one replica, assign the Tensorflow computation to task 0
+    // device 0, and leave the XLA device assignment empty. We don't know which
+    // core this is in the TPU topology, but it doesn't matter—we don't need to
+    // communicate with any other cores.
+    if (num_replicas == 1) {
+      (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
+      return Status::OK();
+    }
+
+    // Otherwise, num_replicas is equal to the number of cores, and we build a
+    // device assignment that covers the entire mesh. We do not need to know
+    // the topology to do so because all cores are identical.
+    return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
+                                         num_tpus_per_task,
+                                         tf_device_assignment);
+    // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
+  }
+
+  // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
+  xla::Array4D<std::pair<int, int>> topology;
+  TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
+                                       num_tpus_per_task, &topology));
+
+  // Array that maps logical (replica, core) pairs to physical mesh coordinates.
+  xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
+  TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
+      device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
+      &device_assignment));
+
+  return BuildGeneralDeviceAssignment(
+      num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
+      topology, tf_device_assignment, xla_device_assignment);
+}
+
+Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
+    const NameAttrList& function, FunctionLibraryRuntime* flr,
+    Graph* computation, DataTypeVector* arg_types,
+    DataTypeVector* retval_types) {
+  FunctionLibraryRuntime::Handle handle;
+
+  TF_RETURN_IF_ERROR(
+      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
+
+  const FunctionBody* fbody = flr->GetFunctionBody(handle);
+
+  CopyGraph(*fbody->graph, computation);
+  *arg_types = fbody->arg_types;
+  *retval_types = fbody->ret_types;
+  return Status::OK();
+}
+
+// Grab the InferredShape corresponding to an edge input.
+static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
+                           const InferredShape** info) {
+  auto it = shape_info.find(edge.src()->name());
+  if (it == shape_info.end()) {
+    return errors::InvalidArgument(
+        "Input to replicated TPU computation is missing InferredShape: ",
+        edge.src()->name());
+  }
+  TF_RET_CHECK(it->second.size() > edge.src_output());
+  *info = &it->second[edge.src_output()];
+  return Status::OK();
+}
+
+Status DistributedTPURewritePass::GetArgAndRetvalShapes(
+    const GraphShapeInfo& shape_info, const Node& node,
+    const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
+    std::vector<InferredShape>* retval_shapes) {
+  std::vector<const Edge*> input_edges;
+  TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
+
+  // If any replica's arg shape is unknown, we will mark the computation's arg
+  // shape as being unknown. If the shapes differ the TpuExecute Op will raise a
+  // runtime error.
+  std::vector<bool> any_replica_shape_unknown(
+      params_info.NumInputsToEachReplica());
+  arg_shapes->clear();
+  arg_shapes->resize(params_info.NumInputsToEachReplica());
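+  // arg_shapes is indexed as [per-replica args][distributed args]
+  // [broadcast args][variables][guaranteed constants], matching the index
+  // offsets used below.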
+  TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
+  // Determines the shapes of the per-replica arguments and checks that all
+  // replicas have identical shapes.
+  int64 edge_pos = 0;
+  auto check_shape = [&](int input_index) -> Status {
+    const InferredShape* info;
+    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
+    ++edge_pos;
+
+    if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
+        (info->handle_type != DT_INVALID &&
+         !info->handle_shape.IsFullyDefined())) {
+      any_replica_shape_unknown[input_index] = true;
+    }
+    xla::StatusOr<InferredShape> status =
+        MergeInferredShapes((*arg_shapes)[input_index], *info);
+    if (!status.ok()) {
+      return errors::InvalidArgument(
+          "Mismatched shapes for input ", input_index, ": ",
+          (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
+          info->shape.DebugString());
+    }
+    (*arg_shapes)[input_index] = status.ValueOrDie();
+    return Status::OK();
+  };
+
+  for (int64 i = 0; i < params_info.NumReplicas(); ++i) {
+    for (int64 j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
+      TF_RETURN_IF_ERROR(check_shape(j));
+    }
+  }
+
+  for (int64 i = 0; i < params_info.NumDistributedArgs(); ++i) {
+    TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
+  }
+
+  for (int64 i = 0;
+       i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
+       ++i) {
+    if (any_replica_shape_unknown[i]) {
+      (*arg_shapes)[i].shape = PartialTensorShape();
+      (*arg_shapes)[i].handle_shape = PartialTensorShape();
+    }
+  }
+
+  // Determines the shape of the broadcast arguments.
+  for (int64 i = 0; i < params_info.NumBroadcastArgs(); ++i) {
+    TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
+    const InferredShape* info;
+    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
+    (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
+                  params_info.NumDistributedArgs()]
+        .shape = info->shape;
+    ++edge_pos;
+  }
+
+  // Determines the handle shape and handle type of the resource variable
+  // arguments.
+  for (int64 i = 0; i < params_info.NumVariables(); ++i) {
+    TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
+    const InferredShape* info;
+    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
+    InferredShape& arg_shape =
+        (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
+                      params_info.NumDistributedArgs() +
+                      params_info.NumBroadcastArgs()];
+    arg_shape.shape = TensorShape();  // Variables are always scalars.
+    arg_shape.handle_shape = info->handle_shape;
+    arg_shape.handle_type = info->handle_type;
+    TF_RET_CHECK(arg_shape.handle_type != DT_INVALID);
+    ++edge_pos;
+  }
+
+  // Determines the shape of the guaranteed constants.
+  // TODO(vinuraja): Can be removed because they are not required for any
+  // calculations. Leaving them here for symmetry with other structures like
+  // arg_types, arg_sharding, etc.
+  for (int64 i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
+    TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
+    const InferredShape* info;
+    TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
+    (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
+                  params_info.NumDistributedArgs() +
+                  params_info.NumBroadcastArgs() + params_info.NumVariables()]
+        .shape = info->shape;
+    ++edge_pos;
+  }
+
+  // Extract the return value shapes.
+  auto it = shape_info.find(node.name());
+  retval_shapes->clear();
+  if (it != shape_info.end()) {
+    TF_RET_CHECK(it->second.size() >= node.num_outputs());
+    retval_shapes->resize(node.num_outputs());
+    for (int i = 0; i < node.num_outputs(); ++i) {
+      (*retval_shapes)[i].shape = it->second[i].shape;
+    }
+  } else if (node.num_outputs() > 0) {
+    return errors::InvalidArgument(
+        "Replicated TPU computation is missing InferredShape: ",
+        FormatNodeForError(node));
+  }
+  return Status::OK();
+}
+
+// Verifies that all nodes have legal sharding.
+static Status ValidateCoreNumbers(const Graph& graph,
+                                  int num_cores_per_replica) {
+  for (Node* n : graph.nodes()) {
+    TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
+                        ParseShardingFromDevice(*n, num_cores_per_replica));
+  }
+  return Status::OK();
+}
+
+static Status InferXlaShardingFromNeighbors(
+    const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
+    CachedFunctionHandles* cached_function_handles,
+    absl::optional<xla::OpSharding>* output_sharding, bool* is_fast_mem) {
+  int64 core = -1;
+  absl::optional<xla::OpSharding> result;
+  // We assume the variable has been allocated on fast memory if any consuming
+  // op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and
+  // compiler.
+  *is_fast_mem = false;
+  for (const Edge* edge : n.out_edges()) {
+    if (edge->IsControlEdge()) continue;
+
+    TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
+        num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
+        &result));
+
+    if (!flr) continue;
+
+    // The nodes deciding this arg's device assignment might live inside a
+    // FunctionDef. Instantiate the FunctionDefs associated with this node and
+    // check the nodes that use this arg.
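+    // For example, if this arg feeds a function call node, we look up the
+    // matching _Arg inside the called function body and inspect its
+    // consumers, recursing further if those consumers are themselves calls.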
+    std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
+        [&](const Edge* call_edge) {
+          auto associated_functions = GetAssociatedFunctions(
+              *call_edge->dst(), flr->GetFunctionLibraryDefinition());
+          for (auto& associated_function : associated_functions) {
+            FunctionLibraryRuntime::Handle handle;
+            TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
+                associated_function.func_name(),
+                AttrSlice(&associated_function.attrs()), &handle));
+            const FunctionBody* body = flr->GetFunctionBody(handle);
+            Graph* g = body->graph;
+
+            for (Node* body_node : g->nodes()) {
+              if (!body_node->IsArg()) continue;
+
+              int index;
+              TF_RETURN_IF_ERROR(
+                  GetNodeAttr(body_node->attrs(), "index", &index));
+              if (index != call_edge->dst_input()) continue;
+
+              for (const Edge* out_edge : body_node->out_edges()) {
+                if (out_edge->IsControlEdge()) continue;
+
+                TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
+                    num_cores_per_replica, n.name(), *out_edge->dst(), &core,
+                    is_fast_mem, &result));
+
+                TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
+              }
+            }
+          }
+          return Status::OK();
+        };
+    TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
+  }
+  *output_sharding = result;
+  return Status::OK();
+}
+
+bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
+  bool spmd_attr;
+  if (!replicate_node ||
+      !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
+                      &spmd_attr)) {
+    spmd_attr = false;
+  }
+  return spmd_attr;
+}
+
+Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
+    int num_cores_per_replica, const ParameterInfo& params_info,
+    const DataTypeVector& arg_types,
+    const std::vector<InferredShape>& arg_shapes,
+    const DataTypeVector& retval_types,
+    const std::vector<InferredShape>& retval_shapes, const Graph& graph,
+    const Node* replicate_node, FunctionLibraryRuntime* flr,
+    std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
+    std::vector<xla::OpSharding>* retval_sharding) {
+  // Builds vectors of the argument and return nodes.
+  std::vector<Node*> args(arg_types.size());
+  std::vector<Node*> retvals(retval_types.size());
+  absl::flat_hash_map<int, Node*> partitioned_output_nodes;
+  for (Node* node : graph.op_nodes()) {
+    if (node->IsArg()) {
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0 && index < args.size());
+      args[index] = node;
+    } else if (node->IsRetval()) {
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0 && index < retvals.size());
+      retvals[index] = node;
+    }
+  }
+  for (const Edge* edge : replicate_node->out_edges()) {
+    int num_partitioned_outputs = 0;
+    for (const Edge* out_edge : edge->dst()->out_edges()) {
+      if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
+        partitioned_output_nodes[edge->src_output()] = out_edge->dst();
+        num_partitioned_outputs++;
+      }
+    }
+    if (num_partitioned_outputs > 1) {
+      return errors::InvalidArgument(
+          "More than one TPUPartitionedOutput per replciated output.");
+    }
+  }
+
+  // Verifies there are no missing arguments/return values.
+  for (int i = 0; i < args.size(); ++i) {
+    if (args[i] == nullptr) {
+      return errors::Internal("Missing function argument: ", i);
+    }
+  }
+  for (int i = 0; i < retvals.size(); ++i) {
+    if (retvals[i] == nullptr) {
+      return errors::Internal("Missing function return value: ", i);
+    }
+  }
+
+  // Assigns a core to each _Arg. Chooses the lowest-numbered core that
+  // consumes the argument. We choose the lowest-numbered core so the
+  // assignment is deterministic.
+  TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
+                                          arg_shapes);
+  arg_sharding->resize(args.size());
+  arg_fast_mem->resize(args.size());
+  CachedFunctionHandles cached_function_handles(flr);
+  const bool use_spmd = UseSpmdForXlaPartitioning(replicate_node) ||
+                        replicate_inputs_outputs_by_default_for_xla_spmd_;
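+  // For each argument, a _XlaSharding attribute on a feeding
+  // TPUPartitionedInput node takes precedence over sharding inferred from
+  // neighboring nodes; if neither is present, the argument is replicated
+  // under SPMD or assigned to a single core below.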
+  for (int i = 0; i < args.size(); ++i) {
+    const Node* n = args[i];
+    absl::optional<int64> assigned_core;
+    absl::optional<xla::OpSharding> sharding;
+    bool is_fast_mem;
+    TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
+        *n, num_cores_per_replica, flr, &cached_function_handles, &sharding,
+        &is_fast_mem));
+
+    if (params_info.IsPerReplicaArg(i) || params_info.IsDistributedArg(i)) {
+      Node* input_node;
+      TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node));
+      if (input_node->type_string() == kTPUPartitionedInput) {
+        TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
+                            GetShardingFromNodeDef(input_node->def()));
+        if (!parsed_sharding.has_value())
+          return errors::InvalidArgument("Missing _XlaSharding attr from: ",
+                                         input_node->DebugString());
+        sharding = parsed_sharding;
+        VLOG(1) << "Arg " << i << " parsed sharding information from "
+                << input_node->name() << " : "
+                << parsed_sharding->DebugString();
+      }
+    }
+
+    if (sharding.has_value() && enable_automatic_model_parallelism_) {
+      return tensorflow::errors::InvalidArgument(
+          "Specifying manual sharding is not allowed when automatic "
+          "model parallelism is enabled.",
+          sharding->DebugString());
+    }
+
+    if (!sharding.has_value()) {
+      if (use_spmd &&
+          (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
+           ((params_info.IsPerReplicaArg(i) ||
+             params_info.IsDistributedArg(i)) &&
+            arg_types[i] != DT_RESOURCE))) {
+        // Use replication for host variables or non-variable per-replica
+        // inputs.
+        sharding = xla::sharding_builder::Replicate();
+      } else {
+        // TODO(dlibenzi): Distributing variables to cores other than 0 makes
+        // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
+        // For now distribute only per replica arguments, unless
+        // tf_jf_distribute_vars is set, to allow debugging the issue.
+        if (((params_info.IsPerReplicaArg(i) ||
+              params_info.IsDistributedArg(i)) &&
+             arg_types[i] != DT_RESOURCE) ||
+            (distribute_vars_ && params_info.IsVariableArg(i))) {
+          assigned_core = args_device_selector.RetrieveAssignment(i);
+        } else {
+          assigned_core = 0;
+        }
+        sharding = xla::sharding_builder::AssignDevice(*assigned_core);
+      }
+    } else if (sharding->type() == xla::OpSharding::MAXIMAL) {
+      assigned_core = sharding->tile_assignment_devices(0);
+    } else if (sharding->type() != xla::OpSharding::REPLICATED &&
+               sharding->type() != xla::OpSharding::OTHER) {
+      return tensorflow::errors::InvalidArgument(
+          "Unsupported argument sharding: ", sharding->DebugString());
+    }
+    if (assigned_core.has_value()) {
+      args_device_selector.ReportDeviceAssigned(*assigned_core, i);
+      VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
+              << ") to core " << *assigned_core;
+      args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
+    } else if (sharding->type() == xla::OpSharding::OTHER) {
+      for (int64 core : sharding->tile_assignment_devices()) {
+        args_device_selector.ReportDeviceAssigned(core, i);
+        VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
+                << ") with tiled sharding to core " << core;
+      }
+    } else {
+      CHECK_EQ(sharding->type(), xla::OpSharding::REPLICATED);
+      for (int64 core = 0; core < num_cores_per_replica; ++core) {
+        args_device_selector.ReportDeviceAssigned(core, i);
+      }
+      VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
+              << ") to all cores";
+    }
+    (*arg_sharding)[i] = *sharding;
+    (*arg_fast_mem)[i] = is_fast_mem;
+    if (is_fast_mem) {
+      VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
+              << args[i]->name();
+    }
+    args[i]->AddAttr(kShardingAttribute, sharding->SerializeAsString());
+  }
+  TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
+
+  // Assigns each _Retval node to the core that produces its value.
+  TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
+                                             retval_types, retval_shapes);
+  retval_sharding->resize(retvals.size());
+  for (int i = 0; i < retvals.size(); ++i) {
+    const Edge* edge;
+    TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
+
+    TF_ASSIGN_OR_RETURN(
+        absl::optional<xla::OpSharding> sharding,
+        ParseShardingFromDevice(*edge->src(), num_cores_per_replica));
+
+    if (partitioned_output_nodes.contains(i)) {
+      Node* output_node = partitioned_output_nodes[i];
+      TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding,
+                          GetShardingFromNodeDef(output_node->def()));
+      if (parsed_sharding.has_value()) {
+        sharding = parsed_sharding;
+        VLOG(1) << "Retval " << i << " parsed sharding information from "
+                << output_node->name() << " : " << sharding->DebugString();
+      }
+    }
+    absl::optional<int64> assigned_core;
+    if (sharding.has_value()) {
+      if (enable_automatic_model_parallelism_) {
+        return tensorflow::errors::InvalidArgument(
+            "Specifying manual sharding is not allowed when automatic "
+            "model parallelism is enabled.",
+            sharding->DebugString());
+      }
+
+      if (sharding.value().type() == xla::OpSharding::MAXIMAL) {
+        assigned_core = sharding.value().tile_assignment_devices(0);
+        TF_RETURN_IF_ERROR(
+            ValidateCoreNumber(*assigned_core, num_cores_per_replica));
+      } else if (sharding.value().type() != xla::OpSharding::REPLICATED &&
+                 sharding.value().type() != xla::OpSharding::OTHER) {
+        return tensorflow::errors::InvalidArgument(
+            "Unsupported argument sharding: ", sharding->DebugString());
+      }
+    } else {
+      if (use_spmd) {
+        sharding = xla::sharding_builder::Replicate();
+      } else {
+        if (distribute_vars_) {
+          assigned_core = retvals_device_selector.RetrieveAssignment(i);
+        } else {
+          assigned_core = 0;
+        }
+        sharding = xla::sharding_builder::AssignDevice(*assigned_core);
+      }
+    }
+    if (assigned_core.has_value()) {
+      retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
+      retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
+      VLOG(3) << "Assigning return value " << i << " ("
+              << retvals[i]->DebugString() << ") to core " << *assigned_core;
+    } else if (sharding->type() == xla::OpSharding::OTHER) {
+      for (int64 core : sharding->tile_assignment_devices()) {
+        retvals_device_selector.ReportDeviceAssigned(core, i);
+        VLOG(3) << "Assigning return value " << i << " ("
+                << retvals[i]->DebugString() << ") with tiled sharding to core "
+                << core;
+      }
+    } else {
+      CHECK_EQ(sharding->type(), xla::OpSharding::REPLICATED);
+      for (int64 core = 0; core < num_cores_per_replica; ++core) {
+        retvals_device_selector.ReportDeviceAssigned(core, i);
+      }
+      VLOG(3) << "Assigning return value " << i << " ("
+              << retvals[i]->DebugString() << ") to all cores.";
+    }
+    retvals[i]->AddAttr(kShardingAttribute, sharding->SerializeAsString());
+    (*retval_sharding)[i] = *sharding;
+  }
+  return Status::OK();
+}
+
+// Builds Shape nodes that compute the shapes of arguments whose shapes are not
+// statically known.
+/* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
+    const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
+    const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
+    Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
+  dynamic_shape_nodes->clear();
+
+  std::vector<const Edge*> replicate_input_edges;
+  TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
+
+  // The compiler determines the shape of each constant by inspecting the value
+  // of its corresponding host-memory tensor; this happens when a step is run.
+  // As a result, the shapes of constants are not needed at graph rewrite time.
+  const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
+  TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
+                               params_info.NumDistributedArgs() +
+                               params_info.NumBroadcastArgs() +
+                               params_info.NumVariables());
+
+  for (int i = 0; i < num_args; ++i) {
+    const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
+                                          ? &arg_shapes[i].shape
+                                          : &arg_shapes[i].handle_shape;
+    if (!shape->IsFullyDefined()) {
+      Node* src;
+      int src_output;
+      if (params_info.IsPerReplicaArg(i)) {
+        TF_RET_CHECK(i < replicate_input_edges.size());
+        // All replicas must have the same input shapes. Uses the shape of the
+        // inputs from the first replica.
+        src = replicate_input_edges[i]->src();
+        src_output = replicate_input_edges[i]->src_output();
+      } else if (params_info.IsDistributedArg(i) ||
+                 params_info.IsBroadcastArg(i)) {
+        int64 input_num =
+            params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
+            params_info.NumPerReplicaArgs();
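+        // Distributed and broadcast args appear once, after all per-replica
+        // inputs; e.g. with 2 replicas and 3 per-replica args, arg i = 3
+        // reads replicate input edge 3 * 2 + 3 - 3 = 6.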
+        TF_RET_CHECK(0 <= input_num &&
+                     input_num < replicate_input_edges.size());
+        src = replicate_input_edges[input_num]->src();
+        src_output = replicate_input_edges[input_num]->src_output();
+      } else {
+        int64 var_num = i - params_info.NumPerReplicaArgs() -
+                        params_info.NumDistributedArgs() -
+                        params_info.NumBroadcastArgs();
+        TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
+        src = variable_reads[var_num];
+        src_output = 0;
+      }
+
+      NodeDef def;
+      def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
+      def.set_op("Shape");
+      def.set_device(src->assigned_device_name());
+      AddNodeAttr("T", src->output_type(src_output), &def);
+      AddNodeAttr("out_type", DT_INT64, &def);
+      MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
+
+      Status status;
+      Node* shape_node = graph->AddNode(def, &status);
+      if (!status.ok()) return status;
+      dynamic_shape_nodes->push_back(shape_node);
+
+      shape_node->set_assigned_device_name(src->assigned_device_name());
+      graph->AddEdge(src, src_output, shape_node, 0);
+    }
+  }
+  return Status::OK();
+}
+
+// Builds a TPUCompile node that compiles the bodies of the function call
+// `nodes`.
+Status DistributedTPURewritePass::BuildCompileNode(
+    const Node* replicate_node, const NameAttrList& function,
+    uint64 library_fingerprint, const ParameterInfo& params_info,
+    const std::vector<InferredShape>& arg_shapes,
+    const DataTypeVector& arg_types,
+    const std::vector<Node*>& guaranteed_constant_nodes,
+    const string& session_handle,
+    const std::vector<xla::OpSharding>& arg_sharding,
+    const std::vector<bool>& arg_fast_mem,
+    const std::vector<xla::OpSharding>& retval_sharding,
+    int num_cores_per_replica, const string& compile_device,
+    const xla::DeviceAssignment* xla_device_assignment,
+    const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
+    Node** compile_node, int64 autotuner_thresh) {
+  VLOG(1) << "BuildCompileNode";
+
+  tpu::TPUCompileMetadataProto proto;
+  proto.set_num_replicas(params_info.NumReplicas());
+  proto.set_num_cores_per_replica(num_cores_per_replica);
+  proto.set_function_library_fingerprint(library_fingerprint);
+  proto.set_enable_automatic_model_parallelism(
+      enable_cross_replica_sharding_mirrored_variables_);
+  const bool use_spmd = UseSpmdForXlaPartitioning(replicate_node);
+  proto.set_use_spmd_for_xla_partitioning(use_spmd);
+
+  // Get and fill padding map.
+  if (replicate_node != nullptr) {
+    TF_RETURN_IF_ERROR(
+        FillPaddingMap(*replicate_node, proto.mutable_padding_maps()));
+    xla::DebugOptions::StepMarkerLocation location;
+    TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
+    proto.set_step_marker_location(location);
+  }
+
+  if (xla_device_assignment != nullptr) {
+    TF_RETURN_IF_ERROR(
+        xla_device_assignment->Serialize(proto.mutable_device_assignment()));
+  }
+
+  const int num_args = arg_types.size();
+  const int num_guaranteed_constants = guaranteed_constant_nodes.size();
+  const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
+  TF_RET_CHECK(num_args == arg_shapes.size());
+  TF_RET_CHECK(num_args == arg_sharding.size())
+      << num_args << " != " << arg_sharding.size();
+
+  for (int i = 0; i < num_args; ++i) {
+    tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
+    DataType type = arg_types[i];
+    const InferredShape& arg_shape = arg_shapes[i];
+    if (type == DT_RESOURCE) {
+      TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
+      arg->set_dtype(arg_shape.handle_type);
+      arg_shape.handle_shape.AsProto(arg->mutable_shape());
+      arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
+      arg->set_fast_mem(arg_fast_mem[i]);
+    } else {
+      arg->set_dtype(type);
+      arg_shape.shape.AsProto(arg->mutable_shape());
+      if (i >= guaranteed_const_start_index) {
+        const DataType edge_type =
+            guaranteed_constant_nodes[i - guaranteed_const_start_index]
+                ->output_type(0);
+        TF_RET_CHECK(type == edge_type)
+            << "Arg type: " << type << " but edge type: " << edge_type;
+        arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
+      } else {
+        arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
+      }
+    }
+    // As long as the argument is not a per-replica one, it should have the same
+    // value for all replicas. For clarity, we keep the (redundant) checks for
+    // variable, broadcast and constant types, to prevent bugs in case new types
+    // with different semantics are introduced in the future.
+    arg->set_is_same_data_across_replicas(
+        !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
+        (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
+         params_info.IsConstantArg(i)));
+    if (params_info.mirrored_variable_indices().count(i) > 0) {
+      CHECK_EQ(type, DT_RESOURCE);
+      arg->set_is_same_data_across_replicas(true);
+      // 64-bit type is not shardable by XLA:TPU yet.
+      bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
+                               arg_shape.handle_type != DT_INT64 &&
+                               arg_shape.handle_type != DT_UINT64 &&
+                               arg_shape.handle_type != DT_DOUBLE);
+      arg->set_enable_xla_sharding(
+          sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
+                           : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
+    }
+    *arg->mutable_sharding() = arg_sharding[i];
+  }
+
+  const int num_retvals = retval_sharding.size();
+  for (int i = 0; i < num_retvals; ++i) {
+    *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
+  }
+  proto.set_session_handle(session_handle);
+
+  DataTypeVector constant_arg_types;
+  constant_arg_types.reserve(num_guaranteed_constants);
+  for (int i = 0; i < num_guaranteed_constants; ++i) {
+    constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
+  }
+  proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
+
+  string metadata;
+  proto.SerializeToString(&metadata);
+
+  NodeDef def;
+  def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
+  def.set_op("TPUCompile");
+  def.set_device(compile_device);
+  if (replicate_node) {
+    MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
+  }
+
+  AddNodeAttr("function", function, &def);
+  AddNodeAttr("num_computations", num_cores_per_replica, &def);
+  AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
+              &def);
+  AddNodeAttr("metadata", metadata, &def);
+  AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
+
+  Status status;
+  *compile_node = graph->AddNode(def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  (*compile_node)->set_assigned_device_name(compile_device);
+
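+  // The data inputs to the TPUCompile node are the dynamic shape tensors
+  // followed by the guaranteed constants, matching the NumDynamicShapes and
+  // Tguaranteed_constants attributes set above.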
+  for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
+    graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
+  }
+
+  for (int i = 0; i < num_guaranteed_constants; ++i) {
+    graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
+                   dynamic_shape_nodes.size() + i);
+  }
+  VLOG(1) << "BuildCompileNode(): " << status;
+  return status;
+}
+
+Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
+    const Node& node, const NameRangeMap& input_range_map,
+    std::vector<Node*>* guaranteed_constants) {
+  std::vector<const Edge*> input_edges;
+  TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
+  std::pair<int, int> variables_limits =
+      input_range_map.at("guaranteed_constants");
+  for (int i = variables_limits.first; i < variables_limits.second; ++i) {
+    guaranteed_constants->push_back(input_edges[i]->src());
+  }
+  return Status::OK();
+}
+
+Status DistributedTPURewritePass::FindVariableInputs(
+    const Node& node, const NameRangeMap& input_range_map,
+    std::vector<VariableInput>* variables) {
+  std::vector<const Edge*> input_edges;
+  TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
+  std::pair<int, int> variables_limits = input_range_map.at("variables");
+  for (int i = variables_limits.first; i < variables_limits.second; ++i) {
+    Node* node = input_edges[i]->src();
+
+    // Find the type of the VarHandleOp that feeds this node, looking through
+    // any wrapping Enter or Switch nodes.
+    while (node->IsEnter() || node->IsSwitch()) {
+      TF_RETURN_IF_ERROR(node->input_node(0, &node));
+    }
+    // Fix the variable device assignment if it is requested with a full name.
+    if (!node->has_assigned_device_name() &&
+        !node->requested_device().empty()) {
+      DeviceNameUtils::ParsedName var_device;
+      TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
+                                                  &var_device));
+      if (var_device.has_job && var_device.has_replica && var_device.has_task &&
+          var_device.has_type && var_device.has_id) {
+        node->set_assigned_device_name(node->requested_device());
+        if (node != input_edges[i]->src() &&
+            !input_edges[i]->src()->has_assigned_device_name()) {
+          input_edges[i]->src()->set_assigned_device_name(
+              node->requested_device());
+        }
+      }
+    }
+    if (node->type_string() == "VarHandleOp") {
+      DataType dtype;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
+      variables->push_back(VariableInput{input_edges[i]->src(),
+                                         input_edges[i]->src_output(), dtype});
+    } else if (node->type_string() == "_Arg") {
+      std::vector<DataType> dtypes;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
+      if (dtypes.empty()) {
+        return errors::Internal(
+            "_Arg node with resource output must have non-empty _handle_dtypes "
+            "attribute: ",
+            node->DebugString());
+      }
+      variables->push_back(VariableInput{
+          input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
+    } else {
+      return errors::Internal(
+          "Cannot handle variable input with node type other than VarHandleOp "
+          "and _Arg: ",
+          node->DebugString());
+    }
+  }
+  return Status::OK();
+}
+
+// Builds a NoOp node, used for building control dependencies.
+static Status BuildNoopNode(const Node& source, StringPiece name,
+                            const string& device, Graph* graph, Node** node) {
+  NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
+  if (!device.empty()) {
+    builder.Device(device);
+  }
+  NodeDef def;
+  TF_RETURN_IF_ERROR(builder.Finalize(&def));
+
+  Status status;
+  *node = graph->AddNode(def, &status);
+  if (!device.empty()) {
+    (*node)->set_assigned_device_name(device);
+  }
+  return status;
+}
+
+Status DistributedTPURewritePass::ConnectHostComputeNodes(
+    Node* compile_node, Node* key_placeholder_node, Graph* graph) {
+  // First find all the downstream nodes of the key placeholder node, since we
+  // want to delete the connecting edges from key_placeholder_node which would
+  // invalidate the out_nodes iterator.
+  std::vector<Node*> host_transfer_nodes;
+  for (Node* node : key_placeholder_node->out_nodes()) {
+    host_transfer_nodes.push_back(node);
+  }
+  for (Node* node : host_transfer_nodes) {
+    int input_index = -1;
+    for (int i = 0; i < node->num_inputs(); i++) {
+      const Edge* e;
+      TF_RETURN_IF_ERROR(node->input_edge(i, &e));
+      if (e->src() == key_placeholder_node) {
+        if (input_index != -1) {
+          return errors::Internal(
+              "Node ", node->name(),
+              " has multiple input edges from key placeholder node");
+        }
+        input_index = e->dst_input();
+      }
+    }
+    if (input_index == -1) {
+      return errors::Internal("Node ", node->name(),
+                              " has no input edge from key placeholder node");
+    }
+    const Edge* key_edge;
+    TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
+    graph->RemoveEdge(key_edge);
+    graph->AddEdge(compile_node, 1, node, input_index);
+  }
+  graph->RemoveNode(key_placeholder_node);
+  return Status::OK();
+}
+
+Status DistributedTPURewritePass::BuildVariableReads(
+    absl::Span<const VariableInput> variables, Node* control_predecessor,
+    Graph* graph, std::vector<Node*>* variable_reads) {
+  variable_reads->resize(variables.size());
+  for (int i = 0; i < variables.size(); ++i) {
+    string name =
+        graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
+    NodeDefBuilder builder(name, "ReadVariableOp",
+                           NodeDebugInfo(*variables[i].node));
+
+    builder.Attr("dtype", variables[i].dtype);
+    builder.Device(variables[i].node->assigned_device_name());
+    builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
+    NodeDef def;
+    TF_RETURN_IF_ERROR(builder.Finalize(&def));
+
+    Status status;
+    Node* read_node;
+    (*variable_reads)[i] = read_node = graph->AddNode(def, &status);
+    if (!status.ok()) return status;
+
+    read_node->set_requested_device(variables[i].node->requested_device());
+    read_node->set_assigned_device_name(
+        variables[i].node->assigned_device_name());
+    graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
+
+    graph->AddControlEdge(control_predecessor, read_node);
+  }
+  return Status::OK();
+}
+
+bool DistributedTPURewritePass::ContainsResourceWriteOp(
+    const Graph& graph, const FunctionLibraryDefinition& fld) {
+  for (const Node* n : graph.nodes()) {
+    const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
+    if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
+      VLOG(2) << "Found write resource op inside computation";
+      return true;
+    }
+  }
+  for (const string& func_name : fld.ListFunctionNames()) {
+    const FunctionDef* func_def = fld.Find(func_name);
+    for (const NodeDef& n : func_def->node_def()) {
+      const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
+      if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
+        VLOG(2) << "Found write resource op inside " << func_name;
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+Status DistributedTPURewritePass::BuildVariableWrites(
+    absl::Span<const VariableInput> variables, Node* control_successor,
+    absl::Span<const VariableWrite> variable_writes, Graph* graph) {
+  CHECK_EQ(variables.size(), variable_writes.size());
+  for (int i = 0; i < variables.size(); ++i) {
+    const VariableWrite& write = variable_writes[i];
+    NodeDebugInfo debug_info(*variables[i].node);
+
+    auto name = [&](string suffix) {
+      return graph->NewName(
+          strings::StrCat(variables[i].node->name(), "/", suffix));
+    };
+
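+    // Each write is gated on write.predicate: the resource handle and the new
+    // value are routed through Switch nodes, and the AssignVariableOp runs
+    // only on the then-branch, i.e. only when the predicate is true.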
+    Node* write_node;
+    TF_RETURN_IF_ERROR(
+        IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
+            .AddAttr("dtype", variables[i].dtype)
+            .Device(variables[i].node->assigned_device_name())
+            .Build(graph, &write_node));
+
+    // Colocate the control flow with the variable.
+    CondBuilder cb(variables[i].node->name(),
+                   variables[i].node->assigned_device_name(), debug_info,
+                   graph);
+
+    // Inputs to conditional.
+    Node* switch_val;
+    TF_RETURN_IF_ERROR(
+        cb.AddInput("switch_val", variables[i].dtype,
+                    /*device=*/write.value->assigned_device_name(), debug_info,
+                    &switch_val));
+    Node* switch_var;
+    TF_RETURN_IF_ERROR(
+        cb.AddInput("switch_var", DT_RESOURCE,
+                    /*device=*/variables[i].node->assigned_device_name(),
+                    debug_info, &switch_var));
+    // Conditionally write the value back.
+    graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
+    graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
+    graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
+    // Add a control edge from the write to the value that will be merged.
+    // There is no output from the write, so this control edge ensures the
+    // write completes.
+    graph->AddControlEdge(write_node, cb.switch_t());
+
+    graph->AddControlEdge(cb.control_successor(), control_successor);
+
+    graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
+    graph->AddEdge(write.value, write.value_output, switch_val, 0);
+  }
+  return Status::OK();
+}
+
+namespace {
+
+// Helper that creates an IdentityN node containing all of the variable
+// values on CPU device 'device', except for those that will be split across
+// cores. (For split variables, this may cause additional cross-host data
+// transfers if more than one device shares the same variable partition on a
+// remote host.)
+//
+// A previous iteration of this code built one Identity node per TPU core per
+// variable, but this can rapidly become hundreds of thousands of nodes. This
+// formulation creates a single IdentityN node containing all of the variables
+// on each host. This may cause some unnecessary variable copies if only a
+// subset of hosts consume a given variable, but has the virtue of being
+// simple, and most models use pure replication where all cores want all the
+// variables.
+//
+// Returns the node and its output index to be consumed by TPUExecute for the
+// requested variable index.
+xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
+    const string& host_cpu_device, int64 var_index,
+    const std::vector<Node*>& variable_reads,
+    const DistributedTPURewritePass::ParameterInfo& params_info,
+    const std::vector<xla::OpSharding>& arg_shardings,
+    const Node& replicate_node,
+    absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
+    Graph* graph) {
+  auto it = per_host_var_copies->find(host_cpu_device);
+  if (it != per_host_var_copies->end()) {
+    return it->second[var_index];
+  }
+
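+  // Cache miss: build one IdentityN for this host and record, for every
+  // variable, the NodeOut that TPUExecute should consume (either an IdentityN
+  // output, or the raw variable read for partitioned variables).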
+  DataTypeVector dtypes;
+  // Per-variable data source for TPUExecute.
+  std::vector<NodeOut> index_mapping;
+  index_mapping.reserve(variable_reads.size());
+  dtypes.reserve(variable_reads.size());
+  for (int64 i = 0; i < variable_reads.size(); ++i) {
+    Node* read = variable_reads[i];
+    int64 orig_arg_num =
+        i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs();
+    if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
+      // We haven't built the IdentityN node yet, so temporarily use nullptr.
+      index_mapping.push_back(
+          NodeOut{nullptr, static_cast<int>(dtypes.size())});
+      dtypes.push_back(read->output_type(0));
+    } else {
+      // Do not copy the full tensor of partitioned variables.
+      index_mapping.push_back(NodeOut{read, 0});
+    }
+  }
+  NodeDef ndef;
+  ndef.set_name(
+      graph->NewName(absl::StrCat(replicate_node.name(), "/_variable_copy")));
+  ndef.set_op("IdentityN");
+  ndef.set_device(host_cpu_device);
+  AddNodeAttr("T", dtypes, &ndef);
+  Status s;
+  Node* id_node = graph->AddNode(ndef, &s);
+  TF_RETURN_IF_ERROR(s);
+  id_node->set_assigned_device_name(host_cpu_device);
+
+  for (int64 i = 0; i < variable_reads.size(); ++i) {
+    if (index_mapping[i].node == nullptr) {
+      // Fill index_mapping with the actual IdentityN node.
+      index_mapping[i].node = id_node;
+      // Add the edge to id_node.
+      graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
+    }
+  }
+
+  auto result = index_mapping[var_index];
+  (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
+  return result;
+}
+
+}  // namespace
+
+Status DistributedTPURewritePass::BuildExecuteNodes(
+    const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
+    const Node& replicate_node, const DataTypeVector& arg_types,
+    const std::vector<InferredShape>& arg_shapes,
+    const DataTypeVector& retval_types,
+    const std::vector<xla::OpSharding>& arg_shardings,
+    const std::vector<xla::OpSharding>& retval_shardings,
+    const std::vector<std::vector<string>>& tpu_device_names,
+    Node* compile_node, const std::vector<Node*>& variable_reads,
+    Node* control_predecessor, Node* control_successor,
+    std::vector<VariableWrite>* variable_writes, Graph* graph) {
+  VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
+  TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
+
+  const int num_variables = variable_reads.size();
+  const int num_retvals_per_replica = retval_types.size();
+
+  variable_writes->resize(num_variables);
+
+  std::vector<const Edge*> replicate_input_edges;
+  TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
+
+  // Map from replicate input index to the fan-in nodes.
+  absl::flat_hash_map<int, std::vector<Node*>> replicate_input_fan_in_nodes;
+  absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
+  absl::flat_hash_map<int, std::vector<int>>
+      replicate_output_fan_out_dst_inputs;
+  std::vector<Node*> to_be_removed_nodes;
+
+  for (const Edge* e : replicate_input_edges) {
+    if (e->src()->type_string() == kTPUPartitionedInput) {
+      int num_users = 0;
+      for (const auto& ue : e->src()->out_edges()) {
+        if (!ue->IsControlEdge()) ++num_users;
+      }
+      if (num_users != 1) {
+        return tensorflow::errors::InvalidArgument(
+            e->src()->name(), " must only have one user. Found ", num_users);
+      }
+      to_be_removed_nodes.push_back(e->src());
+      std::vector<Node*>& nodes = replicate_input_fan_in_nodes[e->dst_input()];
+      nodes.resize(num_cores_per_replica, nullptr);
+      VLOG(2) << "allocate " << num_cores_per_replica
+              << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
+      std::vector<const Edge*> fan_in_edges;
+      TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
+      TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
+
+      for (const Edge* fe : fan_in_edges) {
+        nodes[fe->dst_input()] = fe->src();
+        VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
+                << fe->dst_input() << "] = " << fe->src()->name();
+      }
+    }
+  }
+
+  // Replicate output edges are sorted by replica id and then by output index
+  // within each replica. For example, if the TPU computation has outputs
+  // (output_1, output_2, and output_3) and there are 2 replicas, then the
+  // replicate_output_edges order would be:
+  // output_1_replica_1, output_2_replica_1, output_3_replica_1,
+  // output_1_replica_2, output_2_replica_2, output_3_replica_2.
+  std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
+                                                  nullptr);
+  for (const Edge* edge : replicate_node.out_edges()) {
+    if (edge->IsControlEdge()) continue;
+
+    int num_partitioned_outputs = 0;
+
+    for (const Edge* out_edge : edge->dst()->out_edges()) {
+      if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
+        num_partitioned_outputs++;
+        // Paths between replicate_node and replicate_output_fan_out_nodes:
+        // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out-nodes
+        TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
+        to_be_removed_nodes.push_back(edge->dst());
+        to_be_removed_nodes.push_back(out_edge->dst());
+        // Get the right replicated id from the replicate_output_edge.
+        std::vector<Node*>& nodes =
+            replicate_output_fan_out_nodes[edge->src_output()];
+        std::vector<int>& dst_inputs =
+            replicate_output_fan_out_dst_inputs[edge->src_output()];
+        nodes.resize(num_cores_per_replica, nullptr);
+        dst_inputs.resize(num_cores_per_replica, 0);
+        TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
+                     num_cores_per_replica);
+
+        for (const Edge* fe : out_edge->dst()->out_edges()) {
+          nodes[fe->src_output()] = fe->dst();
+          dst_inputs[fe->src_output()] = fe->dst_input();
+          VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
+                  << "][" << fe->src_output()
+                  << "] = " << fe->dst()->DebugString() << " with dst_input "
+                  << fe->dst_input();
+        }
+      }
+    }
+    replicate_output_edges[edge->src_output()] = edge;
+    if (num_partitioned_outputs > 1) {
+      return errors::InvalidArgument(
+          "More than one TPUPartitionedOutput per replicated output.");
+    }
+  }
+
+  const int num_execute_args =
+      arg_shardings.size() - params_info.NumGuaranteedConstants();
+  // Inverts the arg_shardings and retval_shardings mappings to
+  // form core -> {argument number} maps.
+  std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
+  for (int i = 0; i < num_execute_args; ++i) {
+    const auto& sharding = arg_shardings[i];
+    if (sharding.type() == xla::OpSharding::MAXIMAL) {
+      int core = sharding.tile_assignment_devices(0);
+      TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
+      core_arg_nums[core].push_back(i);
+    } else if (sharding.type() == xla::OpSharding::OTHER) {
+      for (int64 core : sharding.tile_assignment_devices()) {
+        core_arg_nums[core].push_back(i);
+      }
+    } else if (sharding.type() == xla::OpSharding::REPLICATED) {
+      for (int core = 0; core < num_cores_per_replica; ++core) {
+        core_arg_nums[core].push_back(i);
+      }
+    } else {
+      return tensorflow::errors::InvalidArgument(
+          "Unsupported argument sharding: ", sharding.DebugString());
+    }
+  }
+  std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
+  for (int i = 0; i < retval_shardings.size(); ++i) {
+    const auto& sharding = retval_shardings[i];
+    if (sharding.type() == xla::OpSharding::MAXIMAL) {
+      int core = sharding.tile_assignment_devices(0);
+      TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
+      core_retval_nums[core].push_back(i);
+    } else if (sharding.type() == xla::OpSharding::REPLICATED) {
+      for (int core = 0; core < num_cores_per_replica; ++core) {
+        core_retval_nums[core].push_back(i);
+      }
+    } else if (sharding.type() == xla::OpSharding::OTHER) {
+      for (int64 core : sharding.tile_assignment_devices()) {
+        core_retval_nums[core].push_back(i);
+      }
+    } else {
+      return tensorflow::errors::InvalidArgument(
+          "Unsupported return value sharding: ", sharding.DebugString());
+    }
+  }
+
+  // Maps host device name to a list of per-variable pairs (variable_copy_node,
+  // output_index_of_copy_node).
+  absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
+
+  // Mapping from original resource arg number to a second level map. Second
+  // level map is from core id to output index of updated variable value.
+  absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
+      orig_arg_num_to_output_index_mapping;
+  // Mapping from retval index to a second level map. Second level map is from
+  // core id to output index of sharded output value.
+  std::unordered_map<int, std::unordered_map<int, int>>
+      retval_index_to_output_index_mapping;
+
+  // Maps the argument index of each sharded input to a TPUExecute node to its
+  // corresponding Split node and the Split output index from which the
+  // sharded input is fed into that TPUExecute node.
+  std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
+
+  // Builds one TPUExecute node per core per replica.
+  std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
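+  // Note that the outer loop below iterates over cores and the inner loop
+  // over replicas, so execute_nodes[replica] accumulates one node per core in
+  // increasing core order.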
+  for (int core = 0; core < num_cores_per_replica; ++core) {
+    DataTypeVector core_retval_types;
+    for (int output : core_retval_nums[core]) {
+      core_retval_types.push_back(retval_types[output]);
+    }
+    DataTypeVector core_arg_types;
+    std::vector<int> core_variable_writes;
+    for (int input : core_arg_nums[core]) {
+      // Resource variables can be passed either by reference (as a DT_RESOURCE
+      // tensor) or by value (as the variable's current value). Per-replica or
+      // distributed resource arguments are always passed by reference and
+      // broadcast variables are always passed by value.
+      if (arg_types[input] == DT_RESOURCE &&
+          !params_info.IsPerReplicaArg(input) &&
+          !params_info.IsDistributedArg(input)) {
+        DataType handle_type = arg_shapes[input].handle_type;
+        TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
+        core_arg_types.push_back(handle_type);
+        int base = input - params_info.NumPerReplicaArgs() -
+                   params_info.NumDistributedArgs() -
+                   params_info.NumBroadcastArgs();
+        // Variables passed by value will have a corresponding additional output
+        // containing an updated value for the variable.
+        core_variable_writes.push_back(base);
+        core_retval_types.push_back(handle_type);
+      } else {
+        core_arg_types.push_back(arg_types[input]);
+      }
+    }
+
+    NodeDef def;
+    def.set_op("TPUExecute");
+    MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
+    AddNodeAttr("Targs", core_arg_types, &def);
+    AddNodeAttr("Tresults", core_retval_types, &def);
+
+    for (int64 replica = 0; replica < params_info.NumReplicas(); ++replica) {
+      def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
+                                   "_", core));
+
+      Status status;
+      Node* node = graph->AddNode(def, &status);
+      if (!status.ok()) return status;
+      execute_nodes[replica].push_back(node);
+
+      node->set_assigned_device_name(tpu_device_names[replica][core]);
+
+      // Add control edges to ensure that execution happens after
+      // `control_predecessor`, happens before `control_successor`, and is
+      // triggered by evaluating any operator that depends on the original
+      // TPUReplicate operator. See the comment at the top of the header file
+      // for more details.
+      graph->AddControlEdge(control_predecessor, node);
+      graph->AddControlEdge(node, control_successor);
+
+      // Add data input edges.
+      for (int64 i = 0; i < core_arg_nums[core].size(); ++i) {
+        int64 orig_arg_num = core_arg_nums[core][i];
+        VLOG(2) << " replica " << replica << " core " << core << " i " << i
+                << " orig_arg_num " << orig_arg_num;
+        if (params_info.IsPerReplicaArg(orig_arg_num) ||
+            params_info.IsDistributedArg(orig_arg_num)) {
+          // Per-replica input and distributed input
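+          // The flattened replicate inputs are laid out as
+          // [replica 0 per-replica args, replica 1 per-replica args, ...,
+          // distributed args], so a per-replica arg maps to
+          // replica * NumPerReplicaArgs() + arg, while a distributed arg is
+          // shared by all replicas and follows the per-replica blocks.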
+          int64 input_num = params_info.IsPerReplicaArg(orig_arg_num)
+                                ? replica * params_info.NumPerReplicaArgs() +
+                                      core_arg_nums[core][i]
+                                : params_info.NumReplicas() *
+                                          params_info.NumPerReplicaArgs() +
+                                      core_arg_nums[core][i] -
+                                      params_info.NumPerReplicaArgs();
+
+          const Edge* edge = replicate_input_edges[input_num];
+          VLOG(2) << "replicate_input_edges[" << input_num << "]";
+          DataType dtype = edge->src()->output_type(edge->src_output());
+          if (dtype == DT_RESOURCE) {
+            DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
+            if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
+                          handle_dtype) == kTpuAllTypes.end()) {
+              return errors::InvalidArgument(
+                  "Unsupported resource variable data type for TPU: ",
+                  DataTypeString(handle_dtype), ", caused by output ",
+                  edge->src()->name(), ":", edge->src_output());
+            }
+          } else {
+            if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
+                kTpuAllTypes.end()) {
+              return errors::InvalidArgument(
+                  "Unsupported data type for TPU: ", DataTypeString(dtype),
+                  ", caused by output ", edge->src()->name(), ":",
+                  edge->src_output());
+            }
+          }
+          if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
+            // Don't automatically add a split node when the input node is
+            // kTPUPartitionedInput.
+            if (edge->src()->type_string() == kTPUPartitionedInput) {
+              VLOG(2) << "Connect "
+                      << replicate_input_fan_in_nodes[input_num][core]->name()
+                      << " to " << node->name() << " at " << i;
+              graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
+                             node, i);
+            } else {
+              if (dtype == DT_RESOURCE) {
+                return errors::InvalidArgument(
+                    "Tiled sharding for per-replica DT_RESOURCE input must ",
+                    "be TPUPartitionedInput. Got ",
+                    edge->src()->type_string());
+              }
+              const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
+
+              // Create or get the Split node.
+              TF_ASSIGN_OR_RETURN(
+                  ShardedInputInfo sharded_input_info,
+                  CreateOrGetSplitNodesForInputSharding(
+                      sharding, orig_arg_num, dtype, replica,
+                      edge->src_output(), edge->src(), control_predecessor,
+                      graph, &input_index_to_sharded_inputs));
+
+              // Calculate which output we should receive from the Split node.
+              absl::optional<int> output_index =
+                  GetCoreIndexInSharding(sharding, core);
+              TF_RET_CHECK(output_index);
+
+              NodeOut split_node_and_index =
+                  sharded_input_info.sharded_inputs.at(output_index.value());
+              // Connect with Split node output.
+              graph->AddEdge(split_node_and_index.node,
+                             split_node_and_index.index, node, i);
+            }
+          } else if (edge->src()->type_string() == kTPUPartitionedInput &&
+                     arg_shardings[orig_arg_num].type() ==
+                         xla::OpSharding::REPLICATED) {
+            graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0,
+                           node, i);
+          } else {
+            graph->AddEdge(edge->src(), edge->src_output(), node, i);
+          }
+        } else if (params_info.IsBroadcastArg(orig_arg_num)) {
+          // Broadcast input.
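+          // Broadcast args appear only once in the flattened replicate
+          // inputs, starting at FirstBroadcastArgFromHost(); the offset below
+          // converts the original argument number into an index relative to
+          // that start.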
+          int64 input_num = params_info.FirstBroadcastArgFromHost() +
+                            core_arg_nums[core][i] -
+                            params_info.NumPerReplicaArgs() -
+                            params_info.NumDistributedArgs();
+          const Edge* edge = replicate_input_edges[input_num];
+          DataType dtype = edge->src()->output_type(edge->src_output());
+          if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
+              kTpuAllTypes.end()) {
+            return errors::InvalidArgument(
+                "Unsupported data type for TPU: ", DataTypeString(dtype),
+                ", caused by output ", edge->src()->name(), ":",
+                edge->src_output());
+          }
+          graph->AddEdge(edge->src(), edge->src_output(), node, i);
+        } else {
+          // Variable input.
+          int64 variable_num = orig_arg_num - params_info.NumPerReplicaArgs() -
+                               params_info.NumDistributedArgs() -
+                               params_info.NumBroadcastArgs();
+          TF_RET_CHECK(variable_num < num_variables);
+
+          Node* variable_read = variable_reads[variable_num];
+          DataType dtype = variable_read->output_type(0);
+          if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
+              kTpuAllTypes.end()) {
+            return errors::InvalidArgument(
+                "Unsupported resource variable data type for TPU: ",
+                DataTypeString(dtype), ", caused by ReadVariableOp ",
+                variable_read->DebugString());
+          }
+          DeviceNameUtils::ParsedName requested_device;
+          string requested = variable_read->requested_device();
+          TF_RET_CHECK(
+              DeviceNameUtils::ParseFullName(requested, &requested_device));
+          if (requested_device.type != "TPU") {
+            // Stage the value via the CPU device on the remote host. The graph
+            // partitioner will introduce an intermediate copy rather than
+            // copying the same tensor multiple times across the network, and we
+            // would prefer that intermediate copy to be in host memory to avoid
+            // running out of memory if the TPUExecute op on the staging device
+            // starts running before the _Send ops to the other TPU devices on
+            // the same host complete. We don't do this if the variables are
+            // already placed on TPU, otherwise it will cause an unnecessary
+            // round trip copy.
+            // TODO(b/79580121): give each replica its own on-device copy of
+            // the variable, and then delete this code.
+            string device;
+            TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
+                tpu_device_names[replica][core], &device));
+            TF_ASSIGN_OR_RETURN(auto var_data,
+                                CreateOrGetPerHostVariableCopy(
+                                    device, variable_num, variable_reads,
+                                    params_info, arg_shardings, replicate_node,
+                                    &per_host_var_copies, graph));
+
+            if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
+              const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
+              // Create or get the Split node.
+              TF_ASSIGN_OR_RETURN(
+                  ShardedInputInfo sharded_input_info,
+                  CreateOrGetSplitNodesForInputSharding(
+                      sharding, orig_arg_num,
+                      arg_shapes[orig_arg_num].handle_type, replica,
+                      var_data.index, var_data.node, control_predecessor, graph,
+                      &input_index_to_sharded_inputs));
+
+              // Calculate which output we should receive from the Split node.
+              absl::optional<int> output_index =
+                  GetCoreIndexInSharding(sharding, core);
+              TF_RET_CHECK(output_index);
+              NodeOut split_node_and_index =
+                  sharded_input_info.sharded_inputs[output_index.value()];
+              // Connect with Split node output.
+              graph->AddEdge(split_node_and_index.node,
+                             split_node_and_index.index, node, i);
+
+            } else {
+              graph->AddEdge(var_data.node, var_data.index, node, i);
+            }
+          } else {
+            graph->AddEdge(variable_reads[variable_num], 0, node, i);
+          }
+        }
+      }
+
+      // Adds a program input edge from the compiler.
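+      // Output 0 of the compile node carries the compilation status (see
+      // BuildCompilationStatusReturnNodes); outputs 1..num_cores_per_replica
+      // carry the per-core programs, hence the core + 1 output index here.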
+      graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
+
+      // Add data output edges.
+      int num_outputs = core_retval_nums[core].size();
+      for (int i = 0; i < num_outputs; ++i) {
+        int output_num =
+            replica * num_retvals_per_replica + core_retval_nums[core][i];
+        const auto& sharding = retval_shardings[core_retval_nums[core][i]];
+        if (sharding.type() == xla::OpSharding::OTHER) {
+          int retval_index = core_retval_nums[core][i];
+          retval_index_to_output_index_mapping[retval_index][core] = i;
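+          // The is_last_core check below matches only the highest core id in
+          // the tile assignment; since cores are visited in increasing order,
+          // execute_nodes[replica] already holds one node per core for this
+          // replica by then.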
+          bool is_last_core =
+              core ==
+              *std::max_element(sharding.tile_assignment_devices().begin(),
+                                sharding.tile_assignment_devices().end());
+          bool is_partitioned_output = false;
+
+          const Edge* e = replicate_output_edges[output_num];
+          const Edge* e_out = nullptr;
+          for (const Edge* out_edge : e->dst()->out_edges()) {
+            if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
+              is_partitioned_output = true;
+              e_out = out_edge;
+            }
+          }
+          if (is_partitioned_output) {
+            graph->AddEdge(
+                node, i, replicate_output_fan_out_nodes[output_num][core],
+                replicate_output_fan_out_dst_inputs[output_num][core]);
+            VLOG(2) << "Connect " << node->name() << " at " << i << " to "
+                    << replicate_output_fan_out_nodes[output_num][core]->name()
+                    << " at "
+                    << replicate_output_fan_out_dst_inputs[output_num][core];
+            if (is_last_core) {
+              graph->RemoveEdge(e);
+              graph->RemoveEdge(e_out);
+            }
+            continue;
+          }
+
+          // Do this on the iteration for the last core in the tile
+          // assignment, so that all TPUExecute nodes have been created.
+          if (!is_last_core) {
+            continue;
+          }
+
+          // Add a Concat node.
+          std::vector<NodeOut> orig_inputs;
+          for (int64 core_id : sharding.tile_assignment_devices()) {
+            int core_retval_index =
+                retval_index_to_output_index_mapping[retval_index][core_id];
+            orig_inputs.push_back(
+                NodeOut{execute_nodes[replica][core_id],
+                        static_cast<int>(
+                            core_retval_nums[core_id][core_retval_index])});
+          }
+          DataType dtype = e->src()->output_type(e->src_output());
+          TF_ASSIGN_OR_RETURN(
+              Node * concat_node,
+              CreateConcatNodesForRetval(sharding, dtype, replica, orig_inputs,
+                                         graph, /*device=*/""));
+
+          const Edge* edge = replicate_output_edges[output_num];
+          Node* dst = edge->dst();
+          int dst_input = edge->dst_input();
+          graph->RemoveEdge(edge);
+          graph->AddEdge(concat_node, 0, dst, dst_input);
+
+          continue;
+        }
+
+        // If this is a replicated output, outputs on all cores will be the
+        // same, and we only take the output from core 0.
+        if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
+          continue;
+        }
+
+        // If output has maximal sharding, make sure we only use output from
+        // TPUExecute node with logical core id equal to core id defined by the
+        // xla sharding.
+        if (sharding.type() == xla::OpSharding::MAXIMAL &&
+            core != sharding.tile_assignment_devices(0)) {
+          continue;
+        }
+
+        const Edge* replicate_edge_to_replace =
+            replicate_output_edges[output_num];
+        Node* dst = replicate_edge_to_replace->dst();
+        int dst_input = replicate_edge_to_replace->dst_input();
+        graph->RemoveEdge(replicate_edge_to_replace);
+        graph->AddEdge(node, i, dst, dst_input);
+      }
+
+      // Feed the updated variable values from the first replica to the
+      // variable write nodes.
+      if (replica == 0) {
+        for (int i = 0; i < core_variable_writes.size(); ++i) {
+          int orig_arg_num =
+              core_variable_writes[i] + params_info.NumPerReplicaArgs() +
+              params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
+          const auto& sharding = arg_shardings[orig_arg_num];
+          // If this is a tiling sharded variable, concat variable updates from
+          // all cores.
+          if (sharding.type() == xla::OpSharding::OTHER) {
+            orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
+
+            // Do this on the iteration for the last core in the tile
+            // assignment, so that all TPUExecute nodes have been created.
+            if (core !=
+                *std::max_element(sharding.tile_assignment_devices().begin(),
+                                  sharding.tile_assignment_devices().end())) {
+              continue;
+            }
+
+            // Add a Concat node.
+            std::vector<NodeOut> orig_inputs;
+            for (int64 core_id : sharding.tile_assignment_devices()) {
+              int core_retval_num =
+                  orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
+              orig_inputs.push_back(
+                  NodeOut{execute_nodes[0][core_id],
+                          static_cast<int>(core_retval_nums[core_id].size() +
+                                           core_retval_num)});
+            }
+
+            // Use the variable read's device for the concat; both should be
+            // colocated with the variable.
+            absl::string_view device =
+                variable_reads[core_variable_writes[i]]->assigned_device_name();
+            TF_ASSIGN_OR_RETURN(
+                Node * concat_node,
+                CreateConcatNodesForRetval(
+                    sharding, arg_shapes[orig_arg_num].handle_type, replica,
+                    orig_inputs, graph, device));
+            // Populate VariableWrite.
+            VariableWrite& write = variable_writes->at(core_variable_writes[i]);
+            write.value = concat_node;
+            write.value_output = 0;
+            write.predicate = compile_node;
+            write.predicate_output = num_cores_per_replica + core + 1;
+
+            continue;
+          }
+
+          // If this is a replicated variable, outputs on all cores will be the
+          // same, and we only take the output from core 0 for the variable
+          // update.
+          if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
+            continue;
+          }
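+          // Populate the VariableWrite for this variable: the updated value
+          // is the extra retval appended after this core's regular outputs,
+          // and the predicate is the compile node's per-core output that
+          // follows the program outputs; it gates the conditional write built
+          // with CondBuilder above.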
+          VariableWrite& write = variable_writes->at(core_variable_writes[i]);
+          write.value = node;
+          write.value_output = num_outputs + i;
+          write.predicate = compile_node;
+          write.predicate_output = num_cores_per_replica + core + 1;
+        }
+      }
+    }
+  }
+
+  for (Node* node : to_be_removed_nodes) {
+    graph->RemoveNode(node);
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
+    int replica_index, const std::vector<Node*>& outside_compilation_nodes,
+    const DeviceNameUtils::ParsedName& tpu_device,
+    const DeviceNameUtils::ParsedName& partial_device,
+    NodeToNodeReplicasMap* node_images, Graph* graph) {
+  for (Node* node : outside_compilation_nodes) {
+    NodeDef image_def = node->def();
+    MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
+    const string suffix = strings::StrCat("/R", replica_index);
+    // In addition to the node name, make the frame name unique to avoid
+    // multiple LoopCond nodes in one frame.
+    TF_RETURN_IF_ERROR(
+        AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
+    Status status;
+    Node* image = graph->AddNode(image_def, &status);
+    TF_RETURN_IF_ERROR(status);
+    image->AddAttr(kXlaReplicaIdAttrName, replica_index);
+    if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
+      TF_RETURN_IF_ERROR(
+          SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
+    } else {
+      const string& original_device_string =
+          node->assigned_device_name().empty() ? node->requested_device()
+                                               : node->assigned_device_name();
+      DeviceNameUtils::ParsedName device;
+      TF_RET_CHECK(
+          DeviceNameUtils::ParseFullName(original_device_string, &device));
+      // If the requested device can be merged with the replica's host device,
+      // then do so. For example, if the requested device is "/CPU:0" or
+      // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
+      // replica is running. But if the requested device is
+      // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
+      if (DeviceNameUtils::IsSpecification(device, partial_device)) {
+        TF_RETURN_IF_ERROR(
+            DeviceNameUtils::MergeDevNames(&device, partial_device));
+      }
+      image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
+    }
+    std::vector<Node*>& node_image_vector = (*node_images)[node];
+    node_image_vector.resize(replica_index + 1);
+    node_image_vector[replica_index] = image;
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
+    const std::vector<std::vector<string>>& tf_device_assignment,
+    const HostComputeCoreMap& host_compute_core,
+    const OutsideCompilationNodeMap& outside_compilation_nodes,
+    NodeToNodeReplicasMap* node_images, Graph* graph) {
+  // Iterate over replicas.
+  for (int i = 0; i < tf_device_assignment.size(); ++i) {
+    const auto& core_devices = tf_device_assignment[i];
+    for (const auto& oc_cluster_iter : outside_compilation_nodes) {
+      const string& oc_cluster_name = oc_cluster_iter.first;
+      const auto& oc_cluster_nodes = oc_cluster_iter.second;
+      // We previously validated that host_compute_core contains an entry for
+      // each cluster.
+      int core = host_compute_core.at(oc_cluster_name);
+      TF_RET_CHECK(core >= 0 && core < core_devices.size());
+      // tpu_device is the device the HostCompute XLA Op for this cluster runs
+      // on.
+      DeviceNameUtils::ParsedName tpu_device;
+      TF_RET_CHECK(
+          DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
+      // partial_device contains the replica and task but not the type.
+      DeviceNameUtils::ParsedName partial_device = tpu_device;
+      partial_device.has_type = false;
+      partial_device.has_id = false;
+
+      if (tf_device_assignment.size() == 1) {
+        // With a single replica, don't copy any nodes; just put the original
+        // nodes into the image map. We leave the device placement alone,
+        // except that we have to fill in the correct core for the host send
+        // and receive nodes.
+        for (Node* node : oc_cluster_nodes) {
+          (*node_images)[node] = {node};
+          node->AddAttr(kXlaReplicaIdAttrName, 0);
+          if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
+            TF_RETURN_IF_ERROR(
+                SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
+          }
+        }
+      } else {
+        // Iterate over outside_compilation clusters in this computation, adding
+        // all the nodes with appropriate device assignments.
+        TF_RETURN_IF_ERROR(
+            CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
+                                        partial_device, node_images, graph));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
+    const std::vector<Node*>& outside_compilation_nodes,
+    const NodeToNodeReplicasMap& node_images,
+    const std::unordered_map<string, Node*> outside_compilation_inputs,
+    Graph* graph) {
+  for (Node* node : outside_compilation_nodes) {
+    const auto& images = node_images.at(node);
+    // Make a copy of all edges and iterate on "in_edges", because we might
+    // remove edges when iterating through them.
+    std::vector<const Edge*> in_edges(node->in_edges().begin(),
+                                      node->in_edges().end());
+    for (const Edge* edge : in_edges) {
+      Node* src = edge->src();
+      const auto iter = node_images.find(src);
+      if (iter == node_images.end()) {
+        if (images.size() > 1) {
+          // The source node is a 'normal' node not part of any
+          // rewrite. Broadcast the value to all replicas. (If images.size() ==
+          // 1 the cluster is not replicated and we can leave the original edge
+          // in place.)
+          for (Node* dst : images) {
+            graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
+          }
+        }
+        continue;
+      }
+
+      // The source node is a replicated outside_compilation node.
+      const auto& src_images = iter->second;
+      if (src_images.size() != images.size()) {
+        return errors::InvalidArgument(
+            "Graph contains an edge from node ", src->name(),
+            " in an outside_compilation block replicated ", src_images.size(),
+            " ways to node ", node->name(),
+            " in an outside_compilation block replicated ", images.size(),
+            " ways. Replication factors must match. Leave a comment on "
+            "tracking bug b/76419636 if you need this to be supported.");
+      }
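+      // If the source is a lifted argument for this outside compilation
+      // cluster, rewire each replica image to the corresponding output of the
+      // cluster's IdentityN input node instead of replicating the source.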
+      bool is_lifted_arg;
+      string outside_compilation_cluster;
+      if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
+              .ok() &&
+          GetNodeAttr(src->def(), kOutsideCompilationAttr,
+                      &outside_compilation_cluster)
+              .ok()) {
+        const auto input_iter =
+            outside_compilation_inputs.find(outside_compilation_cluster);
+        TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
+        TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
+        int dst_input = edge->dst_input();
+        if (src_images.size() == 1) {
+          graph->RemoveEdge(edge);
+        }
+        for (int i = 0; i < src_images.size(); ++i) {
+          graph->AddEdge(input_iter->second, i, images[i], dst_input);
+        }
+        continue;
+      }
+
+      bool is_placeholder_for_arg;
+      string outside_compilation_input_attr;
+      if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
+                      &is_placeholder_for_arg)
+              .ok() &&
+          GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
+                      &outside_compilation_input_attr)
+              .ok()) {
+        const auto input_iter =
+            outside_compilation_inputs.find(outside_compilation_input_attr);
+        TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
+        TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
+        int dst_input = edge->dst_input();
+        if (src_images.size() == 1) {
+          graph->RemoveEdge(edge);
+        }
+        for (int i = 0; i < src_images.size(); ++i) {
+          graph->AddEdge(input_iter->second, i, images[i], dst_input);
+        }
+        continue;
+      }
+
+      if (images.size() > 1) {
+        // If images.size() == 1 neither cluster is replicated and we can
+        // leave the original edges in place.
+        for (int i = 0; i < src_images.size(); ++i) {
+          graph->AddEdge(src_images[i], edge->src_output(), images[i],
+                         edge->dst_input());
+        }
+      }
+    }
+    for (const Edge* edge : node->out_edges()) {
+      Node* dst = edge->dst();
+      const auto iter = node_images.find(dst);
+      if (iter == node_images.end()) {
+        // The source node is a 'normal' node not part of any rewrite.
+        if (edge->IsControlEdge()) {
+          // Make the dst node have a control dependency on every replica.
+          if (images.size() > 1) {
+            for (int i = 0; i < images.size(); ++i) {
+              graph->AddControlEdge(images[i], dst);
+            }
+          }
+          // else the cluster is not replicated so we can leave the original
+          // edge in place.
+        } else {
+          // The edge is only valid if the outside_compilation block is not
+          // replicated.
+          if (images.size() > 1) {
+            return errors::InvalidArgument(
+                "Graph contains an edge from node ", node->name(),
+                " in an outside_compilation block replicated ", images.size(),
+                " ways to node ", dst->name(),
+                " that is not part of an outside_compilation block. Edges from "
+                "outside_compilation to regular graph nodes are only supported "
+                "for replication factors of 1. Leave a comment on tracking bug "
+                "b/76419636 if you need this to be supported.");
+          }
+          // else the cluster is not replicated so we can leave the original
+          // edge in place.
+        }
+      }
+      // The case where src and dst are both in node_images is covered elsewhere
+      // when iterating over in_edges of dst.
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
+    const OutsideCompilationNodeMap& outside_compilation_nodes,
+    const NodeToNodeReplicasMap& node_images,
+    const std::unordered_map<string, Node*> outside_compilation_inputs,
+    Graph* graph) {
+  for (const auto& oc_cluster_iter : outside_compilation_nodes) {
+    TF_RETURN_IF_ERROR(
+        CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
+                                    outside_compilation_inputs, graph));
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
+    const NodeToNodeReplicasMap& node_images, Graph* graph) {
+  for (const auto& iter : node_images) {
+    if (iter.second.size() > 1) {
+      // The cluster was replicated so remove the original node.
+      Node* node = iter.first;
+      graph->RemoveNode(node);
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status
+DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
+    Graph* g, const FunctionLibraryDefinition& flib_def,
+    const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
+  bool modified = false;
+  do {
+    std::vector<Node*> nodes_to_lower;
+    for (Node* n : g->op_nodes()) {
+      if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
+        continue;
+      }
+
+      if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
+        // Only lower functional ops with a DT_RESOURCE input, because
+        // otherwise the placer will complain. In other cases, lowering would
+        // cause a slowdown when the functions involved are huge
+        // (b/139037679).
+        bool has_resource_input = false;
+        for (const Edge* e : n->in_edges()) {
+          if (!e->IsControlEdge() &&
+              e->src()->output_type(e->src_output()) == DT_RESOURCE) {
+            has_resource_input = true;
+            break;
+          }
+        }
+        if (has_resource_input) {
+          nodes_to_lower.push_back(n);
+        }
+      }
+    }
+
+    modified = !nodes_to_lower.empty();
+
+    auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
+      // Clear the device assignment. Otherwise all lowered nodes will inherit
+      // a device assignment, which is not what we want.
+      n->set_requested_device("");
+
+      int replica_id;
+      TF_RETURN_IF_ERROR(
+          GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
+
+      string outside_compilation_attr;
+      TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
+                                     &outside_compilation_attr));
+
+      // There are two different kinds of functional outside compilation nodes:
+      // 1. Nodes that are in outside compilation blocks already. They are
+      //    generated by FunctionalizeControlFlowForXlaPass, and only have
+      //    attribute kOutsideCompilationAttr.
+      // 2. Mirrored control flow built for outside compilation in functional
+      //    nodes. They are generated by ExtractOutsideCompilationPass, and have
+      //    both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
+      // When lowering them, they need to be treated differently.
+      // For 1), their body functions are always V1 functions written by users,
+      // and their "control outputs" are control inputs of _Retval nodes. They
+      // should be lowered as V1 functions.
+      // For 2), we always add necessary "control outputs"
+      // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
+      // FunctionDef's. They should be lowered as V2 functions.
+      bool is_host_side_mirrored_control_flow =
+          HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
+
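+      // Remember the current number of node ids so that, after lowering, the
+      // fix-up loop below can be restricted to the newly created nodes
+      // (ids >= num_node_ids).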
+      int num_node_ids = g->num_node_ids();
+      bool is_call_node = IsFunctionCall(flib_def, *n);
+      if (n->IsWhileNode()) {
+        TF_RETURN_IF_ERROR(RewriteWhileNode(n, g,
+                                            /*keep_node_fetchable=*/false));
+      } else if (n->IsIfNode()) {
+        TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
+      } else {
+        TF_RET_CHECK(is_call_node);
+        // See comments for "is_host_side_mirrored_control_flow" above.
+        // If this node is inside an outside compilation block, lower it as a
+        // V1 function by removing kLowerAsMultiDeviceFunctionAttr; otherwise
+        // re-add the attribute so that it is lowered as a V2 (multi-device)
+        // function.
+        n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
+        if (is_host_side_mirrored_control_flow) {
+          n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
+                     true);
+        }
+        TF_RETURN_IF_ERROR(
+            RewriteFunctionCallNode(n, g, flib_def,
+                                    /*keep_caller_fetchable=*/false));
+      }
+
+      for (int i = num_node_ids; i < g->num_node_ids(); i++) {
+        Node* node = g->FindNodeId(i);
+        if (!node) {
+          continue;
+        }
+
+        if (!is_call_node && is_host_side_mirrored_control_flow &&
+            IsFunctionCall(flib_def, *node)) {
+          // For If/While nodes, if they are host side mirrored control flow,
+          // mark their body function calls with kXlaHasHostTransferAttrName
+          // attribute to make sure we lower them as V2 function.
+          node->AddAttr(kXlaHasHostTransferAttrName, true);
+        }
+
+        if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
+            node->IsIfNode()) {
+          // Set kOutsideCompilationAttr attribute so we lower these
+          // nested function call nodes later.
+          node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
+          // Set kXlaReplicaIdAttrName attribute so we know replica id when we
+          // lower this function call node.
+          node->AddAttr(kXlaReplicaIdAttrName, replica_id);
+        } else if (node->type_string() == "_XlaRecvAtHost" ||
+                   node->type_string() == "_XlaSendFromHost") {
+          // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
+          // have kXlaReplicaIdAttrName attribute so later we know which host
+          // device to assign.
+          node->AddAttr(kXlaReplicaIdAttrName, replica_id);
+        }
+      }
+      return Status::OK();
+    };
+
+    for (Node* n : nodes_to_lower) {
+      TF_RETURN_IF_ERROR(lower_functional_node(n));
+    }
+  } while (modified);
+
+  // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
+  for (Node* n : g->op_nodes()) {
+    if (n->type_string() != "_XlaRecvAtHost" &&
+        n->type_string() != "_XlaSendFromHost") {
+      continue;
+    }
+
+    string replicate;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
+    auto iter = tpu_replicate_device_names_mapping.find(replicate);
+    TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
+    const auto& tpu_device_names = iter->second;
+
+    int replica_id;
+    TF_RETURN_IF_ERROR(
+        GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
+    TF_RET_CHECK(replica_id < tpu_device_names.size());
+    const string& tpu_device_name = tpu_device_names[replica_id][0];
+    string host_device_name;
+    TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
+        tpu_device_name, &host_device_name));
+    n->set_assigned_device_name(host_device_name);
+    // We may run TPU rewrite passes again on the subgraphs of the resulting
+    // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
+    // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
+    // sure that TPU rewrite passes take no effect on host-side subgraphs for
+    // outside compilation.
+    n->ClearAttr(kTPUReplicateAttr);
+    n->ClearAttr(kOutsideCompilationAttr);
+  }
+
+  // Remove IdentityN nodes generated for outside compilation. IdentityN is
+  // exempt from resource edge colocation, but here we do need the inputs and
+  // outputs of these IdentityN nodes to be colocated.
+  std::vector<Node*> identityn_nodes;
+  for (Node* n : g->op_nodes()) {
+    if (n->type_string() == "IdentityN" &&
+        HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
+      identityn_nodes.push_back(n);
+    }
+  }
+  for (Node* n : identityn_nodes) {
+    std::vector<const Edge*> out_edges(n->out_edges().begin(),
+                                       n->out_edges().end());
+    for (const Edge* e : out_edges) {
+      if (e->IsControlEdge()) {
+        continue;
+      }
+
+      int src_output = e->src_output();
+      const Edge* input_edge;
+      TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
+      Node* dst = e->dst();
+      int dst_input = e->dst_input();
+      g->RemoveEdge(e);
+      g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
+    }
+    g->RemoveNode(n);
+  }
+
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
+    const Node& replicate_node,
+    const OutsideCompilationNodeMap& outside_compilation_nodes,
+    HostComputeCoreMap* host_compute_core) {
+  std::vector<string> hc_core_string;
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
+                                 &hc_core_string));
+  TF_RETURN_IF_ERROR(
+      ParseHostComputeCoreList(hc_core_string, host_compute_core));
+  for (const auto& iter : outside_compilation_nodes) {
+    const string& oc_cluster_name = iter.first;
+    if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
+      // By default, put host compute ops on replicated core 0.
+      (*host_compute_core)[oc_cluster_name] = 0;
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::GetDeviceTopology(
+    const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
+    int* num_cores_per_replica, int* num_tasks,
+    std::vector<std::vector<string>>* tf_device_assignment,
+    std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
+    string* tpu_compilation_device) {
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
+  if (*num_replicas < 1) {
+    return errors::InvalidArgument("num_replicas must be >= 1, got ",
+                                   *num_replicas);
+  }
+
+  // Find the set of TPU devices in the TF job.
+  // Indexed by [task number][tpu device number].
+  std::vector<std::vector<Device*>> tpu_devices;
+  int num_tpus_per_task;
+  TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
+                                       device_set, tpu_compilation_device,
+                                       &num_tpus_per_task, &tpu_devices));
+
+  string topology;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "topology", &topology));
+  TF_RETURN_IF_ERROR(GetNodeAttr(
+      replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
+  std::vector<int> device_assignment;
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
+                                 &device_assignment));
+
+  // TODO(cwhipkey): since we can control multiple pods of different shapes
+  // from a single worker, it may be desirable to propagate the remote device
+  // information around (e.g., in DeviceAttributes). This can lead to the mesh
+  // topology proto being leaked to cloud TPU users (e.g. through GetStatus
+  // calls); this may be okay, but to be conservative, just assume that the
+  // master session has the proper flags set.
+
+  // We do not initialize platform right now, but we can still retrieve the
+  // TPU topology even with an uninitialized platform.
+  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
+      /*initialize_platform=*/false);
+  TF_RET_CHECK(tpu_platform);
+  tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
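+  // The number of TPU devices found per task must match what the topology
+  // reports as logical TensorCore devices per host.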
+  TF_RET_CHECK(num_tpus_per_task ==
+               tpu_topology.LogicalDevicesPerHost(kTensorCore));
+  TF_RETURN_IF_ERROR(BuildDeviceAssignment(
+      tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
+      *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
+      xla_device_assignment));
+
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::GetIOTypes(
+    int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
+    Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
+    std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
+    DataTypeVector* retval_types, ParameterInfo* params_info) {
+  DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
+                                 &broadcast_input_types));
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
+                                 "Tguaranteed_constants",
+                                 &guaranteed_constant_types));
+  int num_distributed_vars;
+  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
+                                 "num_distributed_variables",
+                                 &num_distributed_vars));
+  const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
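+  // num_per_replica_inputs is the total across all replicas: "Tinputs" covers
+  // the flattened per-replica inputs for every replica plus the distributed
+  // variables, so it must divide evenly by num_replicas.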
+
+  if (num_per_replica_inputs % num_replicas != 0) {
+    return errors::InvalidArgument(
+        "Number of inputs to TPUReplicate (", num_per_replica_inputs,
+        ") is not divisible by the number of replicas (", num_replicas, ").");
+  }
+
+  int num_variables;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
+
+  NameRangeMap output_name_map;
+  TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
+                                       input_name_map, &output_name_map));
+
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node.attrs(), "computation", function));
+
+  *computation = absl::make_unique<Graph>(graph->op_registry());
+  TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
+      **function, flr, computation->get(), arg_types, retval_types));
+
+  *params_info = ParameterInfo(
+      num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
+      broadcast_input_types.size(), num_variables,
+      guaranteed_constant_types.size(), retval_types->size());
+
+  if (arg_types->size() != params_info->NumInputsToEachReplica()) {
+    return errors::InvalidArgument(
+        "Computation argument to TPUReplicate has wrong number of "
+        "arguments. Expected ",
+        params_info->NumInputsToEachReplica(), " inputs, got ",
+        arg_types->size());
+  }
+  if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
+    return errors::InvalidArgument(
+        "Wrong number of outputs from TPUReplicate. Expected ",
+        params_info->NumOutputsToHost(), " outputs, got ",
+        replicate_node.num_outputs());
+  }
+  if (enable_cross_replica_sharding_mirrored_variables_) {
+    std::vector<int> mirrored_variable_indices;
+    TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
+                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
+                                   &mirrored_variable_indices));
+    for (int index : mirrored_variable_indices) {
+      TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
+                   params_info->IsDistributedArg(index))
+          << "Mirrored variables not categorized as per-replica arguments, "
+             "index: "
+          << index;
+      params_info->mutable_mirrored_variable_indices()->insert(index);
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
+    const string& tpu_compilation_device, const Node& replicate_node,
+    Graph* graph, Node** host_transfer_sequencer, Node** control_before,
+    Node** control_after) {
+  *host_transfer_sequencer = nullptr;
+
+  TF_RETURN_IF_ERROR(
+      BuildNoopNode(replicate_node,
+                    graph->NewName(strings::StrCat(replicate_node.name(), "/",
+                                                   "control_before")),
+                    /*device=*/"", graph, control_before));
+  for (const Edge* e : replicate_node.in_edges()) {
+    if (!e->IsControlEdge()) {
+      continue;
+    }
+    Node* predecessor = e->src();
+    if (predecessor->IsSource()) continue;
+    if (predecessor->type_string() == "NoOp" &&
+        predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
+      // The node is the sequencer for host transfer operations. Its control
+      // dependency needs to be placed after the execute node, not before.
+      if (*host_transfer_sequencer != nullptr) {
+        return errors::Internal("Replicate node ", replicate_node.name(),
+                                " has two transfer sequencer nodes: ",
+                                (*host_transfer_sequencer)->name(), " and ",
+                                predecessor->name());
+      }
+      // Set the correct device to match the other sequencing nodes.
+      predecessor->set_assigned_device_name(tpu_compilation_device);
+      *host_transfer_sequencer = predecessor;
+    } else {
+      graph->AddControlEdge(predecessor, *control_before);
+    }
+  }
+
+  TF_RETURN_IF_ERROR(
+      BuildNoopNode(replicate_node,
+                    graph->NewName(strings::StrCat(replicate_node.name(), "/",
+                                                   "control_after")),
+                    /*device=*/tpu_compilation_device, graph, control_after));
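+  // Successors tagged as tail outside compilation must finish before
+  // control_after fires; every other successor instead waits on
+  // control_after.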
+  for (Node* successor : replicate_node.out_nodes()) {
+    if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
+      graph->AddControlEdge(successor, *control_after);
+    } else {
+      graph->AddControlEdge(*control_after, successor);
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
+    const Node& replicate_node, const NameRangeMap& input_name_map,
+    Graph* graph, Node* host_transfer_sequencer, Node* control_before,
+    Node* control_after, absl::Span<const VariableInput> variable_nodes,
+    std::vector<Node*>* guaranteed_constant_nodes,
+    std::vector<Node*>* variable_reads) {
+  TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
+      replicate_node, input_name_map, guaranteed_constant_nodes));
+
+  TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
+                                        variable_reads));
+  // Add the control dependency from host transfer nodes.
+  if (host_transfer_sequencer != nullptr) {
+    graph->AddControlEdge(host_transfer_sequencer, control_after);
+  }
+  return Status::OK();
+}
+
+/* static */ Status
+DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
+    Node* replicate_node, Node* compile_node, Node** control_after_compilation,
+    Graph* graph) {
+  const Edge* compilation_edge = nullptr;
+  for (const auto* e : replicate_node->out_edges()) {
+    if (e->IsControlEdge() &&
+        e->dst()->type_string() == "TPUCompilationResult") {
+      TF_RET_CHECK(compilation_edge == nullptr)
+          << "Multiple compilation result nodes attached to the same replicate "
+             "cluster.";
+      compilation_edge = e;
+    }
+  }
+
+  // TODO(jpienaar): This should be checked by default, current tests not using
+  // this are ones that use the "abort upon successful compile flag" which will
+  // be removed. Leaving this in until then.
+  if (compilation_edge != nullptr) {
+    Node* compilation_status = compilation_edge->dst();
+    const AttrValue* compile_status_cluster_attr =
+        compilation_status->attrs().Find(kTPUCompilationResultAttr);
+    TF_RET_CHECK(compile_status_cluster_attr != nullptr);
+    const string& compile_status_cluster = compile_status_cluster_attr->s();
+    TF_RET_CHECK(!compile_status_cluster.empty());
+    const AttrValue* replicate_cluster_attr =
+        replicate_node->attrs().Find(kTPUReplicateAttr);
+    TF_RET_CHECK(replicate_cluster_attr != nullptr);
+    const string& replicate_cluster = replicate_cluster_attr->s();
+    TF_RET_CHECK(!replicate_cluster.empty());
+    TF_RET_CHECK(compile_status_cluster == replicate_cluster);
+
+    TF_RETURN_IF_ERROR(
+        ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
+    graph->AddEdge(compile_node, 0, compilation_status, 0);
+  }
+
+  NodeDef def;
+  def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
+  // Create an op to assert that compilation succeeded. The alternative would
+  // have been to have each execute op check and return an error.
+  def.set_op("TPUCompileSucceededAssert");
+  MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
+  Status status;
+  Node* compile_succeeded = graph->AddNode(def, &status);
+  compile_succeeded->set_assigned_device_name(
+      compile_node->assigned_device_name());
+  TF_RETURN_IF_ERROR(status);
+  graph->AddEdge(compile_node, 0, compile_succeeded, 0);
+
+  // Build a sequencing node for when compilation has completed.
+  TF_RETURN_IF_ERROR(
+      BuildNoopNode(*replicate_node,
+                    graph->NewName(strings::StrCat(compile_node->name(), "/",
+                                                   "after_compilation")),
+                    /*device=*/"", graph, control_after_compilation));
+  graph->AddControlEdge(compile_succeeded, *control_after_compilation);
+
+  return Status::OK();
+}
+
+// Updates the head and tail outside compiled nodes so that nodes have the
+// correct device and removes the replication and outside compilation attributes
+// so that these nodes do not trigger further graph optimization passes.
+/* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
+    const std::vector<std::vector<string>>& tf_device_assignment,
+    const std::vector<Node*>& head_tail_outside_compilation_nodes) {
+  for (Node* node : head_tail_outside_compilation_nodes) {
+    int replica_id;
+    TF_RETURN_IF_ERROR(
+        GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
+    // Since we set the device, this will now run on a task other than 0. We
+    // clear the two following attributes so that we don't trigger encapsulation
+    // again on the remote host (which will fail due to a missing
+    // _TPUReplicateMetadata node for the cluster).
+    for (const Edge* e : node->in_edges()) {
+      // Resource-consuming ops should be colocated with their resource inputs.
+      if (e->src()->IsArg() &&
+          e->src()->output_type(e->src_output()) == DT_RESOURCE) {
+        node->set_requested_device(tf_device_assignment[replica_id][0]);
+      }
+    }
+    if (node->requested_device().empty()) {
+      string cpu_device;
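+      // For example (illustrative device names): a TPU device such as
+      // "/job:worker/replica:0/task:1/device:TPU:2" maps to the CPU device
+      // "/job:worker/replica:0/task:1/device:CPU:0" on the same task.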
+      TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
+          tf_device_assignment[replica_id][0], &cpu_device));
+      node->set_requested_device(cpu_device);
+    }
+    node->ClearAttr(kTPUReplicateAttr);
+    node->ClearAttr(kOutsideCompilationAttr);
+  }
+  return Status::OK();
+}
+
+/* static */
+Status DistributedTPURewritePass::FingerprintFunctionLibrary(
+    const FunctionLibraryDefinition& library, uint64* fingerprint) {
+  // TODO(phawkins): rather than fingerprinting the entire function library,
+  // consider fingerprinting just the transitive dependencies of a
+  // computation.
+  std::string serialized;
+  FunctionDefLibrary library_proto = library.ToProto();
+  if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) {
+    LOG(WARNING) << "Serializing large proto, size: "
+                 << library_proto.ByteSizeLong();
+  }
+  TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized));
+  *fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized);
+  return Status::OK();
+}
+
+// Performs the rewrite on a single TPUReplicate node.
+/* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
+    const string& session_handle, const DeviceSet& device_set,
+    Node* replicate_node, FunctionLibraryDefinition* flib_def,
+    FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
+    const OutsideCompilationNodeMap& outside_compilation_nodes,
+    const std::vector<Node*>& head_tail_outside_compilation_nodes,
+    NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
+    const GraphShapeInfo& shape_info,
+    TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
+    int64 autotuner_thresh) {
+  VLOG(2) << "Rewriting node " << replicate_node->name();
+
+  // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
+  // of the computation) and cores (virtual cores within computations) specified
+  // by the user. They will be mapped to physical TPU cores below.
+  int num_replicas;
+  int num_cores_per_replica;
+  int num_tasks;  // Number of TensorFlow tasks in the slice.
+  std::vector<std::vector<string>> tf_device_assignment;
+  std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
+  string tpu_compilation_device;
+  TF_RETURN_IF_ERROR(GetDeviceTopology(
+      device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
+      &num_tasks, &tf_device_assignment, &xla_device_assignment,
+      &tpu_compilation_device));
+
+  TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
+      tf_device_assignment, head_tail_outside_compilation_nodes));
+
+  string replicate;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
+  tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
+
+  NameRangeMap input_name_map;
+  const NameAttrList* function;
+  std::unique_ptr<Graph> computation;
+  DataTypeVector arg_types, retval_types;
+  ParameterInfo params_info;
+  TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
+                                &input_name_map, &function, &computation,
+                                &arg_types, &retval_types, &params_info));
+
+  std::vector<InferredShape> arg_shapes, retval_shapes;
+  TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
+      shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
+
+  TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
+
+  std::vector<xla::OpSharding> arg_sharding;
+  std::vector<bool> arg_fast_mem;
+  std::vector<xla::OpSharding> retval_sharding;
+  TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
+      num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
+      retval_shapes, *computation, replicate_node, flr, &arg_sharding,
+      &arg_fast_mem, &retval_sharding));
+
+  VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
+                             flib_def);
+
+  GraphDef graph_def;
+  graph->ToGraphDef(&graph_def);
+  FunctionLibraryDefinition reachable_functions =
+      flib_def->ReachableDefinitions(graph_def);
+  uint64 library_fingerprint;
+
+  TF_RETURN_IF_ERROR(
+      FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
+  VLOG(1) << "Fingerprint functions: "
+          << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
+  VLOG(1) << "library_fingerprint: " << library_fingerprint;
+
+  // Builds trigger nodes that put barriers around the expansion of
+  // TPUReplicate. In particular, we must guarantee:
+  // a) variable reads happen after all predecessors of the original
+  //    TPUReplicate.
+  // b) variable writes happen before all successors of the original
+  //    TPUReplicate.
+  // c) all replicas execute, even if output tensors are only requested from
+  //    a subset of replicas. This is necessary both to ensure that variable
+  //    updates happen and because Send/Recv will deadlock if only one half of
+  //    the communicating pair runs.
+  Node* host_transfer_sequencer;
+  Node* control_before;
+  Node* control_after;
+  TF_RETURN_IF_ERROR(BuildSequencingNodes(
+      tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
+      &control_before, &control_after));
+
+  // Build a vector of variable nodes that are inputs.
+  std::vector<VariableInput> variable_inputs;
+  TF_RETURN_IF_ERROR(
+      FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
+
+  std::vector<Node*> guaranteed_constant_nodes;
+  std::vector<Node*> variable_reads;
+  TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
+      *replicate_node, input_name_map, graph, host_transfer_sequencer,
+      control_before, control_after, variable_inputs,
+      &guaranteed_constant_nodes, &variable_reads));
+
+  // Builds Shape nodes that compute the dynamic shapes of arguments whose
+  // shapes are not statically known.
+  std::vector<Node*> dynamic_shape_nodes;
+  TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
+                                            params_info, variable_reads, graph,
+                                            &dynamic_shape_nodes));
+
+  // Builds a TPUCompile node that compiles the computation on
+  // `tpu_compilation_device`.
+  Node* compile_node;
+  TF_RETURN_IF_ERROR(BuildCompileNode(
+      replicate_node, *function, library_fingerprint, params_info, arg_shapes,
+      arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
+      arg_fast_mem, retval_sharding, num_cores_per_replica,
+      /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
+      dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
+
+  // Compilation must be sequenced after the control node if the TPU computation
+  // is in a control-flow construct, such as a loop.
+  graph->AddControlEdge(control_before, compile_node);
+
+  Node* control_after_compilation;
+  TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
+      replicate_node, compile_node, &control_after_compilation, graph));
+
+  std::vector<VariableWrite> variable_writes;
+  TF_RETURN_IF_ERROR(BuildExecuteNodes(
+      params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types,
+      arg_shapes, retval_types, arg_sharding, retval_sharding,
+      tf_device_assignment, compile_node, variable_reads,
+      control_after_compilation, control_after, &variable_writes, graph));
+  bool contains_resource_write_op =
+      ContainsResourceWriteOp(*graph, reachable_functions);
+
+  VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
+  // Skip conditional write if there is no resource writing op inside TPU
+  // computation.
+  if (contains_resource_write_op) {
+    TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
+                                           variable_writes, graph));
+  }
+
+  if (host_compute_key_placeholder_node != nullptr) {
+    TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
+        compile_node, host_compute_key_placeholder_node, graph));
+  }
+
+  HostComputeCoreMap host_compute_core;
+  TF_RETURN_IF_ERROR(ParseHostComputeCores(
+      *replicate_node, outside_compilation_nodes, &host_compute_core));
+  TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
+      tf_device_assignment, host_compute_core, outside_compilation_nodes,
+      outside_compilation_node_images, graph));
+
+  graph->RemoveNode(replicate_node);
+  return Status::OK();
+}
+
+// Adds sharded weight update optimization for each host training loop.
+//
+// For any host training loop found in the graph, TPUVariableReshard ops
+// are inserted to match the best layout chosen by XLA.
+/* static */ Status
+DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
+    Graph* graph, FunctionLibraryDefinition* flib_def,
+    FunctionLibraryRuntime* flr) {
+  std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
+  Status s = tpu::DetectHostTrainingLoop(
+      /*current_function_name=*/nullptr,
+      /*current_function_attr=*/nullptr, flib_def, graph, flr,
+      &host_training_loops_info);
+  if (!s.ok()) {
+    VLOG(2) << "No valid host training loop found. Skipping sharded weight "
+            << "update optimization.";
+    return Status::OK();
+  }
+
+  for (const auto& host_loop : host_training_loops_info) {
+    const auto& function_name = host_loop.encapsulating_function_name;
+    // `function_name` has a value when the host training loop is inside a
+    // function call node. In that case, in addition to adding
+    // TPUVariableReshard ops, the function library definition needs to be
+    // updated as well.
+    if (function_name.has_value()) {
+      const auto& function_attr = host_loop.encapsulating_function_attrs;
+      TF_RET_CHECK(function_attr.has_value())
+          << "Unable to find function attribute for function: "
+          << *function_name;
+
+      const FunctionDef* function_def = flib_def->Find(*function_name);
+      TF_RET_CHECK(function_def)
+          << "Unable to find function : " << *function_name;
+
+      std::unique_ptr<FunctionBody> fbody;
+      TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+          *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
+      Graph* function_graph = fbody->graph;
+      TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
+      TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
+                                                     *function_name, flib_def));
+    } else {
+      TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
+    }
+  }
+  return Status::OK();
+}
+
+Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
+    Graph* graph) {
+  ReverseDFS(*graph, {}, PlaceOpsOnTPU);
+  return Status::OK();
+}
+
+Status DistributedTPURewritePass::Run(
+    const GraphOptimizationPassOptions& options) {
+  VLOG(1) << "DistributedTPURewritePass::Run";
+
+  Graph* graph = options.graph->get();
+
+  VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
+                             options.flib_def);
+
+  const auto* config = &options.session_options->config;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+      new ProcessFunctionLibraryRuntime(
+          nullptr, options.session_options->env, config,
+          graph->versions().producer(), options.flib_def,
+          config ? config->graph_options().optimizer_options()
+                 : OptimizerOptions()));
+
+  FunctionLibraryRuntime* flr =
+      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+  // This pass can only run in the session master, which should fill in the
+  // device_set field of the options.
+  TF_RET_CHECK(options.device_set != nullptr);
+
+  // Find all the replicate nodes before mutating the graph.
+  std::vector<Node*> replicate_nodes;
+  // Map from compiled subgraph cluster name to the outside_compilation nodes in
+  // that cluster.
+  std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
+  std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
+  TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
+                                     &outside_compilation_nodes,
+                                     &head_tail_outside_compilation_nodes));
+
+  if (replicate_nodes.empty()) {
+    // Remove unused TPUPartitionedInput nodes.
+    for (Node* n : graph->nodes()) {
+      if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
+    }
+    return Status::OK();
+  }
+
+  std::unordered_map<string, Node*> host_compute_key_placeholder_map;
+  TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
+      graph, replicate_nodes, &host_compute_key_placeholder_map));
+
+  GraphShapeInfo shape_info;
+  TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
+                                 flr->GetFunctionLibraryDefinition(),
+                                 &shape_info));
+  int64 autotuner_thresh = options.session_options->config.experimental()
+                               .xla_fusion_autotuner_thresh();
+
+  NodeToNodeReplicasMap outside_compilation_node_images;
+  TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
+  for (Node* node : replicate_nodes) {
+    TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
+        options.session_handle, *options.device_set, node, options.flib_def,
+        flr, host_compute_key_placeholder_map[node->name()],
+        outside_compilation_nodes[node->name()],
+        head_tail_outside_compilation_nodes[node->name()],
+        &outside_compilation_node_images, graph, shape_info,
+        &tpu_replicate_device_names_mapping, autotuner_thresh));
+  }
+
+  // Place the padding nodes generated by dynamic padder on the correct devices.
+  // TODO(rxsang): Place padding ops on TPUs in
+  // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
+  TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
+
+  std::unordered_map<string, Node*> outside_compilation_inputs;
+  for (Node* n : graph->op_nodes()) {
+    string lifted_arg_inputs_attr;
+    if (n->type_string() == "IdentityN" &&
+        GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
+                    &lifted_arg_inputs_attr)
+            .ok()) {
+      outside_compilation_inputs[lifted_arg_inputs_attr] = n;
+    }
+  }
+  for (const auto& iter : outside_compilation_nodes) {
+    TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
+        iter.second, outside_compilation_node_images,
+        outside_compilation_inputs, graph));
+  }
+  TF_RETURN_IF_ERROR(
+      RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
+  TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
+      graph, *options.flib_def, tpu_replicate_device_names_mapping));
+
+  TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
+  VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
+                             options.flib_def);
+  VLOG(1) << "DistributedTPURewritePass::Run() finished";
+
+  if (enable_cross_replica_sharding_mirrored_variables_) {
+    VLOG(1) << "Starting host training loop optimization.";
+    VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
+                               options.flib_def);
+    TF_RETURN_IF_ERROR(
+        PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
+    VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
+                               options.flib_def);
+    VLOG(1) << "Host training loop optimization finished.";
+  }
+
+  return Status::OK();
+}
+
+bool DistributedTPURewritePass::distribute_vars_ = false;
+bool DistributedTPURewritePass::
+    replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
+bool DistributedTPURewritePass::
+    enable_cross_replica_sharding_mirrored_variables_ = true;
+bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
+
+/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
+    bool distribute_vars, bool replicate_inputs_outputs_by_default_for_xla_spmd,
+    bool enable_cross_replica_sharding_mirrored_variables,
+    bool enable_automatic_model_parallelism) {
+  distribute_vars_ = distribute_vars;
+  replicate_inputs_outputs_by_default_for_xla_spmd_ =
+      replicate_inputs_outputs_by_default_for_xla_spmd;
+  enable_cross_replica_sharding_mirrored_variables_ =
+      enable_cross_replica_sharding_mirrored_variables;
+  enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
new file mode 100644
index 00000000000..52fae7a7c13
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -0,0 +1,589 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites TPUReplicate nodes into replicated computations on TPU.
+//
+// To represent a distributed TPU computation, we use the
+// TPUReplicate operator, which describes a subgraph (represented as a
+// Tensorflow function) to replicate across a TPU pod.
+//
+// Model parallelism and data parallelism:
+// ---------------------------------------
+// We support two different kinds of parallelism on TPU:
+// * data parallelism (replication), or parallelization across batches, and
+// * model parallelism, or parallelization within a batch.
+//
+// The function passed to a TPUReplicate operator is replicated many
+// times across a TPU pod (data parallelism). The `num_replicas` attribute
+// controls how many replicas of the computation to create. Replicas are mostly
+// independent; replicas can only communicate using the CrossReplicaSum
+// operator, which is typically used to communicate gradients during training.
+//
+// Each replica may optionally use more than one TPU core (model
+// parallelism). The `num_cores_per_replica` attribute controls how many cores
+// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
+// device that is only valid within replicated TPU computations (e.g.,
+// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
+// device corresponds to one TPU core in every replica.
+// Each replica runs its own copy of the computation assigned to each
+// TPU_REPLICATED_CORE device.
+//
+// The Python code is responsible for providing a device_assignment that
+// describes how the replicated logical cores map to physical cores on the TPU
+// topology.
+//
+// Inputs to TPUReplicate:
+// ------------------------------
+// The TPUReplicate operator takes three kinds of inputs, in the
+// following order:
+// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
+//   replicas, the first six arguments to TPUReplicate will be:
+//   A0 B0 C0 A1 B1 C1
+//   where Ai is the A input to the i-th replica.
+// * distributed inputs. These inputs follow the per-replica inputs.
+//   If there are two distributed inputs (E, F) and two replicas, the following
+//   arguments to TPUReplicate will be: E F.
+//   Each replica receives its own local copy of E and F.
+// * broadcast inputs. These inputs follow the distributed inputs. All
+//   replicas receive a copy of each of these inputs.
+// * variables. Resource variables accessed by the computation follow the
+//   broadcast inputs.
+//
+// For example, for a computation with two replicas, three per-replica inputs
+// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and
+// two variables (V, W), the arguments to TPUReplicate will be:
+// A0 B0 C0 A1 B1 C1 E F X Y V W
+// and each replica will receive the following arguments:
+// A B C E F X Y V W
+//
+// Distributed TPU compilation requires that the shapes of all operators
+// be known statically at compilation time, before any nodes have executed.
+// Shapes are determined using shape information emitted by InferShapes. It
+// is not possible to replicate Tensorflow operators with unknown or dynamic
+// shapes for TPU at present.
+//
+// Graph rewrite:
+// --------------
+// Compilation replaces TPUReplicate operators with:
+// * a single TPUCompile node that compiles the computations,
+// * one TPUExecute node for each TPU device in the system that
+//   executes the relevant computation,
+// * one ReadVariableOp for each variable accessed by the replicated
+//   computation,
+// * one AssignVariableOp for each variable accessed by the replicated
+//   computation. An assignment is built even if a variable is only read by the
+//   computation. We do not know which variables are written until we apply the
+//   XlaCompiler to the computation, but that does not happen until after the
+//   rewrite. Conservatively, we write back the values of all variables after
+//   the computation completes.
+//   TODO(phawkins): only write back variables that the computation may write.
+// * one Shape node for each Tensor or Variable input to the computation whose
+//   shape is not statically known at rewrite time. The input shapes are fed
+//   to the TPUCompile node.
+//
+// To ensure that the reads and writes seem to happen at the right time in the
+// graph execution, we add control edges from all predecessors of the original
+// TPUReplicate operator to each of the ReadVariableOp operators.
+// Similarly, we add control edges from all of the AssignVariableOp operators to
+// all of the successors of the TPUReplicate operator.
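+//
+// Schematically, the sequencing around the rewritten cluster looks roughly as
+// follows (a simplified sketch; single replica, data and control edges not
+// distinguished):
+//
+//   predecessors -> control_before -> {ReadVariableOps, TPUCompile}
+//   TPUCompile -> TPUCompileSucceededAssert -> control_after_compilation
+//   control_after_compilation -> TPUExecute
+//   TPUExecute -> AssignVariableOps -> control_after -> successors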
+//
+// The TPUReplicate rewrite must run before placement, since resource
+// variable inputs will have DT_RESOURCE, which cannot be sent across devices,
+// leading to objections from the placer. The pass converts the resource
+// accesses into explicit ReadVariableOp and AssignVariableOp operators that the
+// placer is free to colocate with the variables.
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/stream_executor/tpu/tpu_topology.h"
+
+namespace tensorflow {
+
+// Replaces clusters assigned to TPU_SYSTEM devices with
+// TPUCompile and TPUExecute nodes assigned to the corresponding
+// TPU devices.
+class DistributedTPURewritePass : public GraphOptimizationPass {
+ public:
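+  // Sets the static flags that control the rewrite. A minimal usage sketch
+  // (the flag values below are illustrative, not defaults):
+  //
+  //   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
+  //       /*distribute_vars=*/true,
+  //       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
+  //       /*enable_cross_replica_sharding_mirrored_variables=*/true,
+  //       /*enable_automatic_model_parallelism=*/false);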
+  static void SetDistributedTpuRewritePassOptions(
+      bool distribute_vars,
+      bool replicate_inputs_outputs_by_default_for_xla_spmd,
+      bool enable_cross_replica_sharding_mirrored_variables,
+      bool enable_automatic_model_parallelism);
+
+  Status Run(const GraphOptimizationPassOptions& options) override;
+
+  // The following methods are public only for the use of unit tests.
+
+  // See comment at the top of the file for how the inputs are ordered.
+  // Encapsulates the different TPU replicated node input and output
+  // information, and provides common APIs over them.
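+  //
+  // Illustrative sketch, reusing the example from the file comment (two
+  // replicas; per-replica inputs A, B, C; distributed inputs E, F; broadcast
+  // inputs X, Y; variables V, W). Each replica sees the argument list
+  // A B C E F X Y V W, and the predicates below index into that list:
+  //
+  //   ParameterInfo info(/*num_replicas=*/2, /*num_per_replica_args=*/3,
+  //                      /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
+  //                      /*num_variables=*/2, /*num_guaranteed_constants=*/0,
+  //                      /*num_retvals_per_replica=*/1);
+  //   info.IsPerReplicaArg(0);           // true  (A)
+  //   info.IsDistributedArg(3);          // true  (E)
+  //   info.IsBroadcastArg(5);            // true  (X)
+  //   info.IsVariableArg(7);             // true  (V)
+  //   info.NumInputsFromHost();          // 12: A0 B0 C0 A1 B1 C1 E F X Y V W
+  //   info.FirstBroadcastArgFromHost();  // 8, the position of X in that list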
+  class ParameterInfo {
+   public:
+    ParameterInfo() {}
+    ParameterInfo(int64 num_replicas, int64 num_per_replica_args,
+                  int64 num_distributed_args, int64 num_broadcast_args,
+                  int64 num_variables, int64 num_guaranteed_constants,
+                  int64 num_retvals_per_replica)
+        : num_replicas_(num_replicas),
+          num_per_replica_args_(num_per_replica_args),
+          num_distributed_args_(num_distributed_args),
+          num_broadcast_args_(num_broadcast_args),
+          num_variables_(num_variables),
+          num_guaranteed_constants_(num_guaranteed_constants),
+          num_retvals_per_replica_(num_retvals_per_replica) {}
+
+    int64 NumReplicas() const { return num_replicas_; }
+
+    int64 NumPerReplicaArgs() const { return num_per_replica_args_; }
+
+    int64 NumDistributedArgs() const { return num_distributed_args_; }
+
+    int64 NumBroadcastArgs() const { return num_broadcast_args_; }
+
+    int64 NumVariables() const { return num_variables_; }
+
+    int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; }
+
+    int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; }
+
+    bool IsPerReplicaArg(int64 index) const {
+      return index < num_per_replica_args_;
+    }
+
+    bool IsDistributedArg(int64 index) const {
+      return index >= num_per_replica_args_ &&
+             index < (num_per_replica_args_ + num_distributed_args_);
+    }
+
+    bool IsBroadcastArg(int64 index) const {
+      return index >= num_per_replica_args_ &&
+             index < (num_per_replica_args_ + num_distributed_args_ +
+                      num_broadcast_args_);
+    }
+
+    bool IsVariableArg(int64 index) const {
+      return index >= (num_per_replica_args_ + num_broadcast_args_) &&
+             index < (num_per_replica_args_ + num_distributed_args_ +
+                      num_broadcast_args_ + num_variables_);
+    }
+
+    bool IsConstantArg(int64 index) const {
+      return index >= (num_per_replica_args_ + num_distributed_args_ +
+                       num_broadcast_args_ + num_variables_) &&
+             index < (num_per_replica_args_ + num_distributed_args_ +
+                      num_broadcast_args_ + num_variables_ +
+                      num_guaranteed_constants_);
+    }
+
+    // Returns the number of inputs received by the host.
+    int64 NumInputsFromHost() const {
+      return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
+             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
+    }
+
+    // Returns the number of inputs which will be sent to each replica.
+    int64 NumInputsToEachReplica() const {
+      return num_per_replica_args_ + num_distributed_args_ +
+             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
+    }
+
+    // Returns the total number of output values returned to the host (for all
+    // replicas).
+    int64 NumOutputsToHost() const {
+      return num_replicas_ * num_retvals_per_replica_;
+    }
+
+    // Returns the position of the first broadcast argument within the set of
+    // all host arguments.
+    // Broadcast arguments follow the distributed arguments.
+    int64 FirstBroadcastArgFromHost() const {
+      return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
+    }
+
+    // Indices of mirrored variables across replicas, which should be
+    // categorized as per_replica_args.
+    const std::set<int64>& mirrored_variable_indices() const {
+      return mirrored_variable_indices_;
+    }
+    std::set<int64>* mutable_mirrored_variable_indices() {
+      return &mirrored_variable_indices_;
+    }
+
+   private:
+    int64 num_replicas_ = 1;
+    int64 num_per_replica_args_ = 0;
+    int64 num_distributed_args_ = 0;
+    int64 num_broadcast_args_ = 0;
+    int64 num_variables_ = 0;
+    int64 num_guaranteed_constants_ = 0;
+    int64 num_retvals_per_replica_ = 0;
+    std::set<int64> mirrored_variable_indices_;
+  };
+
+  // Mapping from TPUReplicate cluster name to tpu device names. Value is a
+  // mapping from [replica][core] to a TF device name.
+  typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
+      TPUReplicateDeviceNamesMapping;
+
+  // Determines which devices to use to run the computation.
+  // Inputs:
+  // * num_tpus_per_task: the number of TPU devices attached to each task
+  // * tpu_devices: a [task][device] collection of TPU devices
+  // * num_replicas: the number of replicas requested
+  // * num_cores_per_replica: the number of cores in each computation instance
+  // * topology_attr: the topology TPUReplicate attribute
+  // * device_assignment_attr: the device_assignment TPUReplicate attribute
+  // Outputs:
+  // * tf_device_assignment: a mapping from [replica][core] to a TF device name
+  // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU
+  //   coordinate.
+  // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
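+  //
+  // As an illustrative example (the device names are hypothetical), with two
+  // replicas and one core per replica on a single task, `tf_device_assignment`
+  // might look like:
+  //   {{"/job:worker/replica:0/task:0/device:TPU:0"},
+  //    {"/job:worker/replica:0/task:0/device:TPU:1"}}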
+  static Status BuildDeviceAssignment(
+      const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
+      const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
+      int num_cores_per_replica, const string& topology_attr,
+      absl::Span<const int> device_assignment_attr,
+      std::vector<std::vector<string>>* tf_device_assignment,
+      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);
+
+  // Returns in `computation` the graph of the function `function` attached to
+  // a TPUReplicate operator. `flr` is a FunctionLibraryRuntime to use when
+  // instantiating the function body. Sets `*arg_types` and
+  // `*retval_types` to the argument/return types of the function.
+  static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
+                                                FunctionLibraryRuntime* flr,
+                                                Graph* computation,
+                                                DataTypeVector* arg_types,
+                                                DataTypeVector* retval_types);
+
+  // Returns the shapes of the argument tensors and return values of the
+  // TPUReplicate operator `node` using the _output_shapes,
+  // _output_handle_shapes, and _output_handle_types annotations on the input
+  // nodes. Expects inputs in the following order (see comment at top of file):
+  // * num_replicas * num_per_replica_args per-replica inputs,
+  // * num_broadcast_args broadcast inputs,
+  // * num_variables variable inputs.
+  // Returns an error if the input shapes to `node` are not statically known.
+  // Also verifies that all replicas have identical input shapes for their
+  // per-replica inputs.
+  static Status GetArgAndRetvalShapes(
+      const GraphShapeInfo& shape_info, const Node& node,
+      const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
+      std::vector<InferredShape>* retval_shapes);
+
+  // Assigns arguments and return values to cores. The assignment is represented
+  // as an XLA op sharding, so that an argument can be replicated across cores.
+  // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
+  // argument/retval number.
+  // `arg_fast_mem` is a vector of fast_mem indicators, indexed by argument
+  // number.
+  static Status AssignArgsAndRetvalsToCores(
+      int num_cores_per_replica, const ParameterInfo& params_info,
+      const DataTypeVector& arg_types,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& retval_types,
+      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
+      const Node* replicate_node, FunctionLibraryRuntime* flr,
+      std::vector<::xla::OpSharding>* arg_sharding,
+      std::vector<bool>* arg_fast_mem,
+      std::vector<::xla::OpSharding>* retval_sharding);
+
+  // Computes a fingerprint of the contents of `library`.
+  static Status FingerprintFunctionLibrary(
+      const FunctionLibraryDefinition& library, uint64* fingerprint);
+
+  // Populates `*variables` with the "variables" inputs to `node`.
+  struct VariableInput {
+    Node* node;
+    int index;
+
+    // Type of the variable's value. Note that this is different to the type of
+    // the output of 'variable', which is always DT_RESOURCE.
+    DataType dtype;
+  };
+  static Status FindVariableInputs(const Node& node,
+                                   const NameRangeMap& input_range_map,
+                                   std::vector<VariableInput>* variables);
+
+  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
+  // to 'node'.
+  static Status FindGuaranteedConstantInputs(
+      const Node& node, const NameRangeMap& input_range_map,
+      std::vector<Node*>* guaranteed_constants);
+
+  // Builds Shape nodes that compute the shapes of arguments whose shapes are
+  // not statically known.
+  static Status BuildDynamicShapeNodes(
+      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
+      const ParameterInfo& params_info,
+      const std::vector<Node*>& variable_reads, Graph* graph,
+      std::vector<Node*>* dynamic_shape_nodes);
+
+  // Builds a TPUCompile node that compiles the computation in `function`.
+  // TODO(b/33943292): at present, for model parallelism with Send/Recv to work
+  // the `nodes` must correspond to the computations assigned to TPU:0,
+  // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
+  // executables.
+  static Status BuildCompileNode(
+      const Node* replicate_node, const NameAttrList& function,
+      uint64 library_fingerprint, const ParameterInfo& params_info,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& arg_types,
+      const std::vector<Node*>& guaranteed_constant_nodes,
+      const string& session_handle,
+      const std::vector<::xla::OpSharding>& arg_sharding,
+      const std::vector<bool>& arg_fast_mem,
+      const std::vector<::xla::OpSharding>& retval_sharding,
+      int num_cores_per_replica, const string& compile_device,
+      const xla::DeviceAssignment* xla_device_assignment,
+      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
+      Node** compile_node, int64 autotuner_thresh);
+
+  // Builds a TPUCompileSucceededAssert node that verifies that compilation
+  // succeeded and replaces the TPUCompilationStatus node in the graph.
+  static Status BuildCompilationStatusReturnNodes(
+      Node* replicate_node, Node* compile_node,
+      Node** control_after_compilation, Graph* graph);
+
+  // Builds ReadVariableOp nodes that read `variables`, with control edges
+  // that ensure they happen after `control_predecessor`.
+  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
+                                   Node* control_predecessor, Graph* graph,
+                                   std::vector<Node*>* variable_reads);
+
+  // Returns true if the graph or functions contain a resource write op, and
+  // false otherwise.
+  // TODO(b/137048563): Recognize unused resource rewrite op.
+  static bool ContainsResourceWriteOp(const Graph& graph,
+                                      const FunctionLibraryDefinition& fld);
+  // Struct that describes a variable value to be written back from TPUExecute.
+  struct VariableWrite {
+    // A node:output pair containing a boolean tensor that determines whether
+    // the value should be written back.
+    Node* predicate;
+    int predicate_output;
+
+    // A node:output pair containing the value to be written back.
+    Node* value;
+    int value_output;
+  };
+
+  // Builds AssignVariableOp nodes that write `variables` with the values from
+  // `variable_writes`, with control edges that ensure the writes happen before
+  // `control_successor`.
+  static Status BuildVariableWrites(
+      absl::Span<const VariableInput> variables, Node* control_successor,
+      absl::Span<const VariableWrite> variable_writes, Graph* graph);
+
+  // Builds TPUExecute operators assigned to each TPU device
+  // involved in the computation.
+  // Arguments:
+  // * `params_info` is the structure containing the information about the
+  //    TPUReplicate node inputs and outputs.
+  // * `num_tasks` is the number of TensorFlow tasks in the slice.
+  // * `num_cores_per_replica` is the number of cores which are dedicated to
+  //    each replica.
+  // * `replicate_node` is the original TPUReplicate node.
+  // * `arg_types` are the types of the arguments to the computation function
+  //    passed as argument to TPUReplicate, including per-replica,
+  //    broadcast, and variable arguments.
+  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
+  //    applicable).
+  // * `arg_shardings` and `retval_shardings` are mappings from
+  //    arguments/return indices to shardings, as returned by
+  //    `AssignArgsAndRetvalsToCores`.
+  // * `tpu_device_names` lists the devices to assign to each core of each
+  //    replica.
+  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
+  //    variable argument to the computation.
+  // * The execute operators will have a control edge from
+  //   `control_predecessor` and another control edge to `control_successor`.
+  // Populates '*variable_writes' with information about variable values to
+  // write back.
+  static Status BuildExecuteNodes(
+      const ParameterInfo& params_info, int num_tasks,
+      int num_cores_per_replica, const Node& replicate_node,
+      const DataTypeVector& arg_types,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& retval_types,
+      const std::vector<::xla::OpSharding>& arg_shardings,
+      const std::vector<::xla::OpSharding>& retval_shardings,
+      const std::vector<std::vector<string>>& tpu_device_names,
+      Node* compile_node, const std::vector<Node*>& variable_reads,
+      Node* control_predecessor, Node* control_successor,
+      std::vector<VariableWrite>* variable_writes, Graph* graph);
+
+  // Connects the compile node to all the host transfer nodes, and removes the
+  // key placeholder node that was previously standing in for it.
+  // Arguments:
+  // * `compile_node` is the TPUCompile node that has been added to the graph.
+  // * `key_placeholder_node` is the placeholder node used to send the key to
+  //   all the host transfer nodes in the original graph.
+  // * `graph` is the graph being rewritten.
+  static Status ConnectHostComputeNodes(Node* compile_node,
+                                        Node* key_placeholder_node,
+                                        Graph* graph);
+
+  // Map from a Node in an outside_compilation cluster in the original graph to
+  // the list of Nodes, one for each replica, that it is expanded into during
+  // replication.
+  typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;
+
+  // Map from the name of an outside_compilation cluster to the model-parallel
+  // core index that the HostCompute Op should be placed on in that cluster.
+  typedef std::map<string, int> HostComputeCoreMap;
+
+  // Map from the name of an outside_compilation cluster to the list of Nodes
+  // that should run on the host for that cluster.
+  typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;
+
+  // Copies the outside_compilation nodes in a cluster to create replica
+  // `replica_index`.
+  static Status CopyOutsideCompilationNodes(
+      int replica_index, const std::vector<Node*>& outside_compilation_nodes,
+      const DeviceNameUtils::ParsedName& tpu_device,
+      const DeviceNameUtils::ParsedName& partial_device,
+      NodeToNodeReplicasMap* node_images, Graph* graph);
+
+  // Replicates all the nodes in outside_compilation clusters in a compiled
+  // computation.
+  static Status ReplicateOutsideCompilationNodes(
+      const std::vector<std::vector<string>>& tf_device_assignment,
+      const HostComputeCoreMap& host_compute_core,
+      const OutsideCompilationNodeMap& outside_compilation_nodes,
+      NodeToNodeReplicasMap* node_images, Graph* graph);
+
+  // Lifts the edges between original outside_compilation nodes in a cluster
+  // onto their replicas.
+  static Status CopyOutsideCompilationEdges(
+      const std::vector<Node*>& outside_compilation_nodes,
+      const NodeToNodeReplicasMap& node_images,
+      const std::unordered_map<string, Node*> outside_compilation_inputs,
+      Graph* graph);
+
+  // Lifts all the edges in outside_compilation clusters in a compiled
+  // computation to their replicas.
+  static Status ReplicateOutsideCompilationEdges(
+      const OutsideCompilationNodeMap& outside_compilation_nodes,
+      const NodeToNodeReplicasMap& node_images,
+      const std::unordered_map<string, Node*> outside_compilation_inputs,
+      Graph* graph);
+
+  // Removes all the original outside_compilation nodes from the graph,
+  // following replication.
+  static Status RemoveOutsideCompilationNodes(
+      const NodeToNodeReplicasMap& node_images, Graph* graph);
+
+  // Lowers outside compilation functional nodes (If/While/function call).
+  // Without this, when there are multiple workers, the device placer cannot
+  // place nodes if outside compilation has DT_RESOURCE inputs (e.g. a
+  // DT_RESOURCE input fed into multiple While nodes on different devices).
+  static Status LowerOutsideCompilationFunctionalNodes(
+      Graph* g, const FunctionLibraryDefinition& flib_def,
+      const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);
+
+  // Parses the 'host_compute_core' attribute on replicate_node to get the
+  // replicated core id of each outside_compilation cluster.
+  static Status ParseHostComputeCores(
+      const Node& replicate_node,
+      const OutsideCompilationNodeMap& outside_compilation_nodes,
+      HostComputeCoreMap* host_compute_core);
+
+  // Gets the physical topology information about the TPU system.
+  static Status GetDeviceTopology(
+      const DeviceSet& device_set, const Node& replicate_node,
+      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
+      std::vector<std::vector<string>>* tf_device_assignment,
+      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
+      string* tpu_compilation_device);
+
+  // Gets the types of args, retvals, and parameters.
+  static Status GetIOTypes(
+      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
+      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
+      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
+      DataTypeVector* retval_types, ParameterInfo* params_info);
+
+  // Finds known constants and deals with variable reads.
+  static Status DealWithConstantsAndVariables(
+      const Node& replicate_node, const NameRangeMap& input_name_map,
+      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
+      Node* control_after, absl::Span<const VariableInput> variable_nodes,
+      std::vector<Node*>* guaranteed_constant_nodes,
+      std::vector<Node*>* variable_reads);
+
+  // Adds NoOp nodes for sequencing computation and variable reads/writes.
+  static Status BuildSequencingNodes(const string& tpu_compilation_device,
+                                     const Node& replicate_node, Graph* graph,
+                                     Node** host_transfer_sequencer,
+                                     Node** control_before,
+                                     Node** control_after);
+
+  // Performs the pass's rewrite on a TPUReplicate node `node`.
+  static Status RewriteTPUReplicateNode(
+      const string& session_handle, const DeviceSet& device_set,
+      Node* replicate_node, FunctionLibraryDefinition* flib_def,
+      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
+      const OutsideCompilationNodeMap& outside_compilation_nodes,
+      const std::vector<Node*>& head_tail_outside_compilation_nodes,
+      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
+      const GraphShapeInfo& shape_info,
+      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
+      int64 autotuner_thresh);
+
+  // Performs host training loop optimization. For example, when a TPUExecute
+  // node is inside a while loop, model weight variables can be sharded in the
+  // XLA-preferred layout and unsharded only at the very last iteration, to
+  // reduce the number of all-gather ops.
+  static Status PerformHostTrainingLoopOptimization(
+      Graph* graph, FunctionLibraryDefinition* flib_def,
+      FunctionLibraryRuntime* flr);
+
+  // Heuristically place some nodes with unassigned devices on TPUs for
+  // performance reasons.
+  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);
+
+  // Updates the head and tail outside compiled nodes so that nodes have the
+  // correct device and removes the replication and outside compilation
+  // attributes so that these nodes do not trigger further graph optimization
+  // passes.
+  static Status UpdateHeadTailOutsideCompilation(
+      const std::vector<std::vector<string>>& tf_device_assignment,
+      const std::vector<Node*>& head_tail_outside_compilation_nodes);
+
+ private:
+  static bool distribute_vars_;
+  static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
+  static bool enable_cross_replica_sharding_mirrored_variables_;
+  static bool enable_automatic_model_parallelism_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc
new file mode 100644
index 00000000000..18b158c0335
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
+
+#include <limits>
+
+#include "absl/random/random.h"
+
+namespace tensorflow {
+namespace {
+
+static int64 overridden_node_id = -1;
+
+}  // namespace
+
+namespace internal {
+
+void OverrideNodeIdForTesting(const int64 node_id) {
+  overridden_node_id = node_id;
+}
+
+uint64 GetNodeId() {
+  if (overridden_node_id > -1) {
+    return overridden_node_id;
+  } else {
+    return absl::Uniform(absl::SharedBitGen(), uint64{0},
+                         std::numeric_limits<uint64>::max());
+  }
+}
+
+}  // namespace internal
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h
new file mode 100644
index 00000000000..ce80249c30f
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
+
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT
+// depend on these.
+namespace internal {
+
+// When set to a value >= 0, overrides the node_id. Used for getting
+// deterministic node_ids during testing.
+void OverrideNodeIdForTesting(int64 node_id);
+
+// Retrieves the node id, used to make some node names unique in the rewrite
+// pass.
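+//
+// Example (sketch): a test can make node ids deterministic by pinning the id
+// before the rewrite runs:
+//   tensorflow::internal::OverrideNodeIdForTesting(42);
+//   tensorflow::internal::GetNodeId();  // Returns 42.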
+uint64 GetNodeId();
+
+}  // namespace internal
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
new file mode 100644
index 00000000000..fad8e22399c
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
@@ -0,0 +1,629 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
+
+#include <deque>
+#include <map>
+#include <unordered_map>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/node_hash_set.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
+
+namespace tensorflow {
+namespace tpu {
+
+namespace {
+
+constexpr char kDefaultShardingValue[] = "";
+
+const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
+  for (const auto e : src->out_edges()) {
+    if (e->dst()->name() == dst->name()) return &(*e);
+  }
+  return nullptr;
+}
+
+// Contains TPUExecute node and its DT_RESOURCE input nodes that
+// correspond to model weights.
+struct ExecuteNodeInfo {
+  Node* execute_node;
+  std::vector<const Edge*> var_inputs;
+};
+
+// Returns whether `node` is in `execute_nodes` or `(identity) -> execute`.
+bool IsExecuteNodeOrIdentityToExecuteNode(
+    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
+  if (execute_nodes.find(node) != execute_nodes.end()) return true;
+  if (loop_nodes.find(node) == loop_nodes.end()) return false;
+  if (node->IsNextIteration()) return true;
+  if (!node->IsIdentity()) return false;
+
+  for (const Edge* e : node->out_edges()) {
+    if (e->IsControlEdge()) continue;
+
+    Node* node = e->dst();
+    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
+                                              node)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Given an input node to a TPUExecute op, finds the corresponding Enter node
+// by traversing the following pattern of nodes:
+// Enter ----> (identity) --->  While body input
+// Returns nullptr if the Enter node is not found.
+xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
+  Node* node = input_node;
+  while (node->IsIdentity()) {
+    TF_RETURN_IF_ERROR(node->input_node(0, &node));
+  }
+
+  if (node->IsEnter()) {
+    return node;
+  }
+  return nullptr;
+}
+
+xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
+    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+    const Node* enter_node, const absl::flat_hash_set<Node*> execute_nodes) {
+  for (const Edge* output_edge : enter_node->out_edges()) {
+    Node* output_node = output_edge->dst();
+    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
+
+    // If the output node is not an execute node, it must be an output node
+    // of the while loop body.
+    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
+                                              output_node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Given a TPUCompile node, finds all TPUExecute nodes that execute the
+// compiled program, along with their model weight variable inputs.
+// The TPUCompileMetadataProto of the TPUCompile node must be reset to
+// `new_metadata` if new reshard ops are added.
+Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
+                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+                              std::vector<ExecuteNodeInfo>* execute_node_info,
+                              TPUCompileMetadataProto* new_metadata) {
+  string metadata_string;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
+  new_metadata->ParsePartialFromString(metadata_string);
+  if (new_metadata->num_cores_per_replica() != 1) {
+    // We do not support model parallelism yet.
+    return Status::OK();
+  }
+
+  execute_node_info->clear();
+  for (Node* node : compile_node->out_nodes()) {
+    if (node->type_string() == "TPUExecute") {
+      execute_node_info->push_back({node});
+    }
+  }
+  if (execute_node_info->empty()) {
+    return Status::OK();
+  }
+  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
+      << "Number of replicas does not equal number of execute nodes: "
+      << new_metadata->num_replicas() << " vs " << execute_node_info->size();
+  DataTypeVector arg_types;
+  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
+                                 "Targs", &arg_types));
+  for (int64 i = 0; i < arg_types.size(); ++i) {
+    if (arg_types[i] != DT_RESOURCE) {
+      continue;
+    }
+    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
+    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
+        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
+      continue;
+    }
+    std::vector<const Edge*> edges(execute_node_info->size());
+    bool is_supported = true;
+    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
+        enter_to_execute_nodes;
+    for (int64 j = 0; j < edges.size(); ++j) {
+      auto execute = (*execute_node_info)[j].execute_node;
+      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
+      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
+                   arg_types[i])
+          << "Execute op has an unexpected input type.";
+      // Traverse backwards to find the Enter node from which the input is
+      // passed.
+      // This makes sure that we are checking the usages of all potential
+      // aliases of the input node as well.
+      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
+                                               edges[j]->src()));
+      if (enter_node == nullptr) {
+        is_supported = false;
+        enter_to_execute_nodes.clear();
+        break;
+      }
+      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
+    }
+
+    for (const auto& it : enter_to_execute_nodes) {
+      // The number of execute nodes per Enter node should be either 1
+      // (per-replica variables) or num_replicas (distributed variables).
+      if ((it.second.size() != 1) &&
+          (it.second.size() != new_metadata->num_replicas())) {
+        is_supported = false;
+        break;
+      }
+      TF_ASSIGN_OR_RETURN(bool no_other_use,
+                          ResourceOnlyUsedForTPUExecuteInLoop(
+                              graph, loop_nodes, it.first, it.second));
+      if (!no_other_use) {
+        is_supported = false;
+        break;
+      }
+    }
+
+    // Add the variable input edges only when they are supported for all
+    // execute nodes.
+    if (is_supported) {
+      for (int64 j = 0; j < edges.size(); ++j) {
+        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
+      }
+      new_metadata->mutable_args(i)->set_enable_xla_sharding(
+          TPUCompileMetadataProto::Arg::ALLOWED);
+    }
+  }
+
+  int64 total = 0;
+  for (const auto& a : new_metadata->args()) {
+    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
+      total++;
+    }
+  }
+  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
+      << " total " << total << " var_inputs "
+      << (*execute_node_info)[0].var_inputs.size();
+  if (total == 0) {
+    // We don't need to process anything if no input is added.
+    execute_node_info->clear();
+  }
+  return Status::OK();
+}
+
+bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }
+
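+// Records, for every TPUCompile node found inside a while loop frame, the
+// information needed to optimize that host training loop (compile node name,
+// loop condition, loop arguments, and the set of loop nodes) into
+// `host_training_loops_info`.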
+void FindTPUCompileNodes(
+    const std::string* current_function_name,
+    const AttrValueMap* current_function_attr,
+    const std::unordered_map<string, WhileLoopFrame>& frames,
+    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
+  // Adds frames with no children (i.e., the innermost frames) to a worklist.
+  std::deque<const WhileLoopFrame*> worklist;
+
+  for (auto& frame : frames) {
+    if (frame.second.num_children == 0) {
+      worklist.push_back(&frame.second);
+    }
+  }
+
+  // Check TPUCompile node from the innermost while loop to the outermost
+  // while loop.
+  while (!worklist.empty()) {
+    const WhileLoopFrame* frame = worklist.front();
+    worklist.pop_front();
+
+    for (const auto& n : frame->nodes) {
+      if (!IsTPUCompileOp(*n)) continue;
+
+      HostTrainingLoopInfo host_training_loop_info;
+      host_training_loop_info.compile_node_name = n->name();
+      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
+      host_training_loop_info.while_loop_name = frame->name;
+
+      for (const auto arg : frame->args) {
+        LoopArgInfo arg_info;
+        arg_info.enter_node_name = arg.enter->name();
+        if (arg.exit) arg_info.exit_node_name = arg.exit->name();
+
+        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
+      }
+      host_training_loop_info.loop_nodes = frame->nodes;
+
+      if (current_function_name) {
+        host_training_loop_info.encapsulating_function_name =
+            *current_function_name;
+      }
+      if (current_function_attr) {
+        host_training_loop_info.encapsulating_function_attrs =
+            *current_function_attr;
+      }
+
+      host_training_loops_info->emplace_back(
+          std::move(host_training_loop_info));
+    }
+
+    // If the parent has no remaining children, add it to the worklist.
+    --frame->parent->num_children;
+    if (frame->parent->num_children == 0) {
+      worklist.push_back(frame->parent);
+    }
+  }
+}
+
+// From the while loop cond node, finds all loop exit nodes by traversing the
+// following pattern of nodes:
+//   LoopCond -----> Switch -----> Exit
+std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
+  std::vector<Node*> loop_exit_nodes;
+  for (const auto e_cond : loop_cond.out_edges()) {
+    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
+    auto switch_node = e_cond->dst();
+
+    for (const auto e_switch : switch_node->out_edges()) {
+      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;
+
+      loop_exit_nodes.push_back(e_switch->dst());
+    }
+  }
+  return loop_exit_nodes;
+}
+
+// Finds any one of the switch nodes in the while loop by traversing the graph
+// from the while loop condition node.
+xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
+  Node* loop_switch_node = nullptr;
+  for (auto n : loop_cond_node.out_nodes()) {
+    if (n->IsSwitch()) {
+      loop_switch_node = n;
+      break;
+    }
+  }
+
+  TF_RET_CHECK(loop_switch_node != nullptr)
+      << "Unable to find any switch nodes.";
+  return loop_switch_node;
+}
+
+// Returns or creates a node that is executed before each iteration
+// of the while loop.
+Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
+                                          Node** node_out) {
+  // If the while loop switch node already has an outgoing data edge to the
+  // true branch of the switch op, reuse that node.
+  for (const auto out_edge : loop_switch_node->out_edges()) {
+    if (out_edge->src_output() == 1) {
+      *node_out = out_edge->dst();
+      return Status::OK();
+    }
+  }
+
+  // Create Identity node that represents execution at every loop iteration.
+  NodeDef at_loop_iteration_nodedef;
+  at_loop_iteration_nodedef.set_op("Identity");
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
+
+  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
+  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));
+
+  Status status;
+  Node* at_loop_iteration_node =
+      graph->AddNode(at_loop_iteration_nodedef, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
+  *node_out = at_loop_iteration_node;
+  return Status::OK();
+}
+
+// Injects a node that is executed after the very last iteration
+// of the while loop but before the while loop exit node.
+Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
+                                 Node** node_out) {
+  // Find the exit node from loop switch node.
+  Node* exit_node = nullptr;
+  for (const auto out_node : loop_switch_node->out_nodes()) {
+    if (out_node->IsExit()) {
+      exit_node = out_node;
+      break;
+    }
+  }
+
+  TF_RET_CHECK(exit_node != nullptr)
+      << "Cannot find exit node connected to switch node :"
+      << loop_switch_node->name();
+
+  // Create an Identity node that represents execution at the end of the
+  // last iteration of the while loop.
+  NodeDef after_last_loop_iteration;
+  after_last_loop_iteration.set_op("Identity");
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
+
+  AddNodeAttr("T", dtype, &after_last_loop_iteration);
+  after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
+
+  Status status;
+  Node* after_last_iteration_node =
+      graph->AddNode(after_last_loop_iteration, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Newly created node must be executed once after last iteration of the while
+  // loop and before while loop exits.
+  graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
+  graph->AddControlEdge(after_last_iteration_node, exit_node);
+  *node_out = after_last_iteration_node;
+  return Status::OK();
+}
+
+}  // namespace
+
+Status DetectHostTrainingLoop(
+    const std::string* current_function_name,
+    const AttrValueMap* current_function_attr,
+    const FunctionLibraryDefinition* library, Graph* graph,
+    FunctionLibraryRuntime* flr,
+    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
+  std::vector<AssociatedFunctionInfo> associated_function_list;
+  for (const auto* n : graph->nodes()) {
+    const auto associated_functions = GetAssociatedFunctions(*n, library);
+    if (associated_functions.empty()) continue;
+
+    associated_function_list.insert(associated_function_list.end(),
+                                    associated_functions.begin(),
+                                    associated_functions.end());
+  }
+
+  Status ret_status = Status::OK();
+  for (const auto& function : associated_function_list) {
+    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;
+
+    // Convert the function to Graph.
+    FunctionLibraryRuntime::Handle handle;
+    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
+                                        AttrSlice(&function.attrs()), &handle));
+    auto cleanup_handle = gtl::MakeCleanup([&]() {
+      auto s = flr->ReleaseHandle(handle);
+      if (!s.ok()) {
+        ret_status.Update(s);
+      }
+    });
+    const FunctionBody* body = flr->GetFunctionBody(handle);
+    Graph* function_graph = body->graph;
+    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
+        &function.func_name(), &function.attrs(), library, function_graph, flr,
+        host_training_loops_info));
+  }
+
+  // BuildControlFlowInfo() requires that the graph's source node is connected
+  // to all source nodes in the graph. Many graphs violate this invariant, so
+  // add edges to the source/sink nodes to restore it.
+  FixupSourceAndSinkEdges(graph);
+  std::vector<ControlFlowInfo> cf_info;
+  TF_RETURN_IF_ERROR(
+      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));
+
+  std::unordered_map<string, WhileLoopFrame> frames;
+  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
+  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
+                      host_training_loops_info);
+  return ret_status;
+}
+
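+// Rewrites the host training loop described by `host_loop_info` so that the
+// model weight variables are resharded around the TPUExecute nodes. Roughly,
+// for each TPUExecute node fed by the loop's TPUCompile node this:
+//   * marks the eligible resource variable args as ALLOWED in the compile
+//     metadata,
+//   * adds a TPUReshardVariables op that runs before every loop iteration and
+//     prior to the TPUExecute node, and
+//   * adds a TPUReshardVariables op, fed with the default (no-op) sharding,
+//     that unshards the variables after the last iteration and before the
+//     loop exits.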
+Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
+  const auto& compile_node_name = host_loop_info.compile_node_name;
+  const auto node_name_map = graph->BuildNodeNameIndex();
+  const auto node_it = node_name_map.find(compile_node_name);
+  TF_RET_CHECK(node_it != node_name_map.end())
+      << "Unable to find compile node : " << compile_node_name;
+
+  const auto compile_node = node_it->second;
+  std::vector<ExecuteNodeInfo> execute_nodes_info;
+
+  Status status;
+  TPUCompileMetadataProto metadata;
+  status =
+      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
+                             &execute_nodes_info, &metadata);
+  if (!status.ok()) {
+    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
+                  "skipping host loop optimization. Status: "
+               << status.ToString();
+    return Status::OK();
+  }
+
+  if (execute_nodes_info.empty()) {
+    return Status::OK();
+  }
+
+  // Update the TPUCompileMetadata such that sharding config of the
+  // sharded resource variable inputs is set to ALLOWED instead of
+  // TENTATIVE.
+  string new_metadata_string;
+  metadata.SerializeToString(&new_metadata_string);
+  compile_node->ClearAttr("metadata");
+  compile_node->AddAttr("metadata", new_metadata_string);
+
+  // Unsharding of the model weight variables must happen only at the very
+  // last loop iteration. Therefore, add the while loop condition predicate as
+  // an input to the sharding switch node. If the loop condition is true, we
+  // do not unshard.
+  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
+  auto loop_cond_node_it = node_name_map.find(cond_node_name);
+  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
+      << "Cannot find loop condition node : " << cond_node_name;
+  auto* loop_condition_node = loop_cond_node_it->second;
+
+  // In order to make sure that shard/unshard operations are invoked
+  // at the start of every loop body and at the end of last iteration
+  // of the loop, respectively, traverse the graph and find a switch node
+  // of the host training loop.
+  TF_ASSIGN_OR_RETURN(Node * switch_node,
+                      GetLoopSwitchNode(*loop_condition_node));
+
+  Node* after_last_iteration_node;
+  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
+                                               &after_last_iteration_node));
+
+  Node* before_loop_iteration_node;
+  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
+      graph, switch_node, &before_loop_iteration_node));
+
+  // Create const op that represents default sharding value
+  // (i.e. no-op sharding).
+  NodeDef default_sharding;
+  default_sharding.set_op("Const");
+  default_sharding.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
+  AddNodeAttr("dtype", DT_STRING, &default_sharding);
+
+  Tensor t(DT_STRING, {2});
+  t.vec<tstring>()(0) = kDefaultShardingValue;
+  t.vec<tstring>()(1) = kDefaultShardingValue;
+  t.AsProtoTensorContent(
+      (*default_sharding.mutable_attr())["value"].mutable_tensor());
+
+  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
+  TF_RETURN_IF_ERROR(status);
+  // Add a control edge from the loop condition node to make sure that
+  // `default_sharding_node` is inside the while loop frame.
+  graph->AddControlEdge(loop_condition_node, default_sharding_node);
+
+  // Build a no-op node used to add control edges after unshard nodes.
+  NodeDef after_unshard;
+  after_unshard.set_op("NoOp");
+  after_unshard.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
+  auto after_unshard_node = graph->AddNode(after_unshard, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  for (const auto& info : execute_nodes_info) {
+    auto execute_node = info.execute_node;
+    // Create Reshard op that optionally shards model weight variables
+    // prior to program execution.
+    NodeDef reshard_node_def;
+    reshard_node_def.set_name(graph->NewName(strings::StrCat(
+        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
+    reshard_node_def.set_op("TPUReshardVariables");
+    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
+                &reshard_node_def);
+    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
+    if (!status.ok()) return status;
+
+    reshard_op_node->set_assigned_device_name(
+        execute_node->assigned_device_name());
+
+    // Reshard op must execute at every loop iteration prior to
+    // TPUExecute node.
+    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
+    graph->AddControlEdge(reshard_op_node, execute_node);
+
+    for (int i = 0; i < info.var_inputs.size(); ++i) {
+      const auto variable_edge = info.var_inputs[i];
+      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
+                     reshard_op_node, i);
+    }
+
+    const int new_key_input = info.var_inputs.size();
+    // Add the program input edge from the compiler (i.e. the compilation key).
+    const auto compilation_key_edge =
+        FindEdgeConnecting(compile_node, execute_node);
+    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
+                   reshard_op_node, new_key_input);
+
+    // Create a VarHandleOp to store the sharding state. The sharding state
+    // holds the string compilation key, which identifies whether the graph has
+    // been re-compiled and the variables need to be sharded again.
+    NodeDef var_handle_def;
+    var_handle_def.set_op("VarHandleOp");
+    var_handle_def.set_name(graph->NewName(strings::StrCat(
+        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
+    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
+    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
+    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
+    if (!status.ok()) return status;
+
+    // Add a control edge between the `var_handle_def` node and the while loop
+    // condition so that `var_handle_def` is inside the same while loop frame.
+    // TODO(hongjunchoi): Consider adding control edge from another node--such
+    // as input control node.
+    graph->AddControlEdge(loop_condition_node, var_handle_node);
+
+    // Connect data edge between var handle op and reshard op.
+    const int format_state_input = new_key_input + 1;
+    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);
+
+    // Create Reshard op that represents unsharding after TPUExecute.
+    NodeDef unshard_node_def;
+    unshard_node_def.set_name(graph->NewName(strings::StrCat(
+        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
+    unshard_node_def.set_op("TPUReshardVariables");
+    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
+                &unshard_node_def);
+    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
+    TF_RETURN_IF_ERROR(status);
+
+    unshard_op_node->set_assigned_device_name(
+        execute_node->assigned_device_name());
+
+    for (int i = 0; i < info.var_inputs.size(); ++i) {
+      const auto variable_edge = info.var_inputs[i];
+      // Connect model weight resource variables to the unshard op. Since the
+      // unshard op must only be invoked after the very last loop iteration,
+      // for each while loop input we traverse backwards to find the
+      // corresponding Enter node of the host training loop and connect it to
+      // the unshard op.
+      TF_ASSIGN_OR_RETURN(
+          Node * enter_node,
+          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
+      graph->AddEdge(enter_node, 0, unshard_op_node, i);
+    }
+
+    // Add control dependencies between the unshard node and the before/after
+    // control nodes.
+    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
+    graph->AddControlEdge(unshard_op_node, after_unshard_node);
+
+    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);
+
+    // Add data edge from sharding state var handle op to unshard op.
+    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
+  }
+  // Add a control dependency from after_unshard_node to all exit nodes. This
+  // makes sure that the unshard ops will be executed as long as any of the
+  // exits are used.
+  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
+    graph->AddControlEdge(after_unshard_node, exit);
+  }
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h
new file mode 100644
index 00000000000..822dc9edd51
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h
@@ -0,0 +1,80 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace tpu {
+
+struct LoopArgInfo {
+  std::string enter_node_name;
+  // Exit nodes are optional for loop invariant while loop args.
+  absl::optional<std::string> exit_node_name;
+};
+
+struct HostTrainingLoopInfo {
+  // Name and attribute information about the function in which the host
+  // training loop is included. If the host training loop is not inside a
+  // function call, then `encapsulating_function_name` and
+  // `encapsulating_function_attrs` are nullopt.
+  absl::optional<std::string> encapsulating_function_name;
+  absl::optional<AttrValueMap> encapsulating_function_attrs;
+
+  // Name of the TPUCompile node within the host training loop.
+  std::string compile_node_name;
+
+  // Name of the while loop in which TPU compile op is located.
+  std::string while_loop_name;
+
+  // Name of the node that represents loop condition.
+  std::string loop_cond_node_name;
+
+  // Exit and Enter node names for each loop argument.
+  std::vector<LoopArgInfo> loop_arguments;
+
+  std::unordered_set<Node*> loop_nodes;  // NOLINT
+};
+
+// Walks through the `graph` (recursively, if function call nodes exist) and
+// identifies all host training loops. Host training loops are the innermost
+// while loops that encapsulate a TPUCompile node. This information is later
+// used to introduce host-loop-specific optimizations such as the sharded
+// weight update.
+Status DetectHostTrainingLoop(
+    const std::string* current_function_name,
+    const AttrValueMap* current_function_attr,
+    const FunctionLibraryDefinition* library, Graph* graph,
+    FunctionLibraryRuntime* flr,
+    std::vector<HostTrainingLoopInfo>* host_training_loops_info);
+
+// Injects TPUReshardVariables ops before and after the TPUExecute op inside
+// the host training loop body. This effectively applies the sharded weight
+// update to the model weight variables.
+Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);
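+
+// A minimal usage sketch of how the two utilities above are typically
+// combined by a rewrite pass. Here `graph`, `flib_def`, and `flr` are assumed
+// to be provided by the caller and are not defined in this header:
+//
+//   std::vector<tpu::HostTrainingLoopInfo> host_training_loops;
+//   TF_RETURN_IF_ERROR(tpu::DetectHostTrainingLoop(
+//       /*current_function_name=*/nullptr, /*current_function_attr=*/nullptr,
+//       flib_def, graph, flr, &host_training_loops));
+//   for (const auto& loop_info : host_training_loops) {
+//     TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, loop_info));
+//   }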
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc
new file mode 100644
index 00000000000..47187204f69
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc
@@ -0,0 +1,73 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
+
+namespace tensorflow {
+
+IncompleteNodeDefBuilder::IncompleteNodeDefBuilder(const string& name,
+                                                   const string& op,
+                                                   const NodeDebugInfo& debug) {
+  nodedef_.set_name(name);
+  nodedef_.set_op(op);
+  MergeDebugInfo(debug, &nodedef_);
+}
+
+IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(
+    const string& attr, const DataType& type) {
+  AddNodeAttr(attr, type, &nodedef_);
+  return *this;
+}
+
+IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(const string& attr,
+                                                            int val) {
+  AddNodeAttr(attr, val, &nodedef_);
+  return *this;
+}
+
+IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::Device(
+    const string& device) {
+  nodedef_.set_device(device);
+  return *this;
+}
+
+Status IncompleteNodeDefBuilder::Build(Graph* graph, Node** n) {
+  Status status;
+  *n = graph->AddNode(nodedef_, &status);
+  return status;
+}
+
+IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Identity(
+    const string& name, const DataType& type, const NodeDebugInfo& debug) {
+  return IncompleteNodeDefBuilder(name, "Identity", debug).AddAttr("T", type);
+}
+
+IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Merge(
+    const string& name, const DataType& type, const NodeDebugInfo& debug,
+    int n) {
+  return IncompleteNodeDefBuilder(name, "Merge", debug)
+      .AddAttr("T", type)
+      .AddAttr("N", n);
+}
+
+IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Switch(
+    const string& name, const DataType& type, const NodeDebugInfo& debug) {
+  return IncompleteNodeDefBuilder(name, "Switch", debug).AddAttr("T", type);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h
new file mode 100644
index 00000000000..88e484f00cf
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Convenience builder for constructing NodeDefs without specifying the
+// inputs; otherwise it is similar to NodeDefBuilder.
+// TODO(jpienaar): Clean up NodeDefBuilder and remove this class.
+class IncompleteNodeDefBuilder {
+ public:
+  IncompleteNodeDefBuilder(const string& name, const string& op,
+                           const NodeDebugInfo& debug);
+
+  IncompleteNodeDefBuilder& AddAttr(const string& attr, const DataType& type);
+  IncompleteNodeDefBuilder& AddAttr(const string& attr, int val);
+
+  IncompleteNodeDefBuilder& Device(const string& device);
+
+  Status Build(Graph* graph, Node** n);
+
+  static IncompleteNodeDefBuilder Identity(const string& name,
+                                           const DataType& type,
+                                           const NodeDebugInfo& debug);
+  static IncompleteNodeDefBuilder Merge(const string& name,
+                                        const DataType& type,
+                                        const NodeDebugInfo& debug, int n);
+  static IncompleteNodeDefBuilder Switch(const string& name,
+                                         const DataType& type,
+                                         const NodeDebugInfo& debug);
+
+ private:
+  NodeDef nodedef_;
+};
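+
+// A minimal usage sketch (not part of the API): assuming `graph` is a valid
+// Graph* and `src` is an existing Node* whose debug info should be
+// propagated, an Identity node could be built roughly as follows:
+//
+//   Node* identity = nullptr;
+//   TF_RETURN_IF_ERROR(
+//       IncompleteNodeDefBuilder::Identity(graph->NewName("my_identity"),
+//                                          DT_FLOAT, NodeDebugInfo(*src))
+//           .Device(src->requested_device())
+//           .Build(graph, &identity));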
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc b/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc
index ef1831464e2..83a652d7aaa 100644
--- a/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc
+++ b/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
 #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
 #include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"
 
@@ -30,8 +31,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34,
                       EncapsulateTPUComputationsPass);
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39,
                       ExtractOutsideCompilationPass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 40,
+                      DistributedTPURewritePass);
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
                       VariableMergerPass);
-
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc
index 6c767d1d66e..120245e34b7 100644
--- a/tensorflow/stream_executor/multi_platform_manager.cc
+++ b/tensorflow/stream_executor/multi_platform_manager.cc
@@ -55,8 +55,8 @@ class MultiPlatformManagerImpl {
       TF_LOCKS_EXCLUDED(mu_);
 
   port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
-      const std::function<bool(const Platform*)>& filter)
-      TF_LOCKS_EXCLUDED(mu_);
+      const std::function<bool(const Platform*)>& filter,
+      bool initialize_platform) TF_LOCKS_EXCLUDED(mu_);
 
   using Listener = MultiPlatformManager::Listener;
   port::Status RegisterListener(std::unique_ptr<Listener> listener)
@@ -188,7 +188,8 @@ port::Status MultiPlatformManagerImpl::RegisterListener(
 
 port::StatusOr<std::vector<Platform*>>
 MultiPlatformManagerImpl::PlatformsWithFilter(
-    const std::function<bool(const Platform*)>& filter) {
+    const std::function<bool(const Platform*)>& filter,
+    bool initialize_platform) {
   absl::MutexLock lock(&mu_);
   CHECK_EQ(id_map_.size(), name_map_.size());
   std::vector<Platform*> platforms;
@@ -196,7 +197,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
   for (const auto& entry : id_map_) {
     Platform* platform = entry.second;
     if (filter(platform)) {
-      if (!platform->Initialized()) {
+      if (initialize_platform && !platform->Initialized()) {
         SE_RETURN_IF_ERROR(platform->Initialize({}));
       }
       platforms.push_back(platform);
@@ -299,7 +300,14 @@ MultiPlatformManager::InitializePlatformWithId(
 /*static*/ port::StatusOr<std::vector<Platform*>>
 MultiPlatformManager::PlatformsWithFilter(
     const std::function<bool(const Platform*)>& filter) {
-  return Impl().PlatformsWithFilter(filter);
+  return PlatformsWithFilter(filter, /*initialize_platform=*/true);
+}
+
+/*static*/ port::StatusOr<std::vector<Platform*>>
+MultiPlatformManager::PlatformsWithFilter(
+    const std::function<bool(const Platform*)>& filter,
+    bool initialize_platform) {
+  return Impl().PlatformsWithFilter(filter, initialize_platform);
 }
 
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h
index fbb6effdf83..4fa2d819520 100644
--- a/tensorflow/stream_executor/multi_platform_manager.h
+++ b/tensorflow/stream_executor/multi_platform_manager.h
@@ -130,6 +130,10 @@ class MultiPlatformManager {
   static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
       const std::function<bool(const Platform*)>& filter);
 
+  static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
+      const std::function<bool(const Platform*)>& filter,
+      bool initialize_platform);
+
   // Although the MultiPlatformManager "owns" its platforms, it holds them as
   // undecorated pointers to prevent races during program exit (between this
   // object's data and the underlying platforms (e.g., CUDA, OpenCL).
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index 7fa46ebd8d1..a8557aada48 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -331,6 +331,7 @@ cc_library(
     name = "tpu_topology_external",
     srcs = ["tpu_topology.cc"],
     hdrs = ["tpu_topology.h"],
+    visibility = ["//visibility:public"],
     deps = [
         ":c_api_decl",
         "//tensorflow/core/platform:types",
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
index 7580e709bdf..fa9062c217c 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -23,10 +23,11 @@ namespace tensorflow {
 namespace tpu {
 
 namespace {
-TpuPlatformInterface* GetRegisteredPlatformStatic() {
+TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform) {
   // Prefer TpuPlatform if it's registered.
   auto status_or_tpu_platform =
-      stream_executor::MultiPlatformManager::PlatformWithName("TPU");
+      stream_executor::MultiPlatformManager::PlatformWithName(
+          "TPU", initialize_platform);
   if (status_or_tpu_platform.ok()) {
     return static_cast<TpuPlatformInterface*>(
         status_or_tpu_platform.ValueOrDie());
@@ -43,7 +44,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {
           [](const stream_executor::Platform* platform) {
             return dynamic_cast<const TpuPlatformInterface*>(platform) !=
                    nullptr;
-          });
+          },
+          initialize_platform);
   if (!status_or_other_tpu_platforms.ok()) {
     LOG(WARNING) << "Error when getting other TPU platforms: "
                  << status_or_tpu_platform.status();
@@ -64,9 +66,24 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {
 
 /* static */
 TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
-  // Use a local static variable to avoid data races during initialization.
+  return GetRegisteredPlatform(/*initialize_platform=*/true);
+}
+
+/* static */
+TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
+    bool initialize_platform) {
+  static bool requested_initialize_platform = initialize_platform;
   static TpuPlatformInterface* tpu_registered_platform =
-      GetRegisteredPlatformStatic();
+      GetRegisteredPlatformStatic(initialize_platform);
+
+  if (!requested_initialize_platform && initialize_platform) {
+    // The first time this function was called, platform initialization was
+    // not requested, but this caller wants the platform initialized, so call
+    // GetRegisteredPlatformStatic again to initialize the platform. Remember
+    // that initialization has now been requested so later calls reuse the
+    // cached platform.
+    tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
+    requested_initialize_platform = true;
+  }
+
   return tpu_registered_platform;
 }
 
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
index da9e91ffc1c..889375245a8 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -33,6 +33,9 @@ class TpuPlatformInterface : public stream_executor::Platform {
   // is registered or an error occurred.
   static TpuPlatformInterface* GetRegisteredPlatform();
 
+  // Like GetRegisteredPlatform(), but initializes the platform only if
+  // `initialize_platform` is true.
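+  //
+  // A minimal usage sketch (hypothetical caller): query the platform without
+  // forcing initialization, e.g. to check whether a TPU platform has been
+  // registered at all:
+  //
+  //   TpuPlatformInterface* platform =
+  //       TpuPlatformInterface::GetRegisteredPlatform(
+  //           /*initialize_platform=*/false);
+  //   if (platform == nullptr) {
+  //     // No TPU platform is registered (or an error occurred).
+  //   }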
+  static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform);
+
   virtual Status Reset() { return Reset(false); }
 
   virtual Status Reset(bool only_tear_down) = 0;
diff --git a/tensorflow/stream_executor/tpu/tpu_topology.h b/tensorflow/stream_executor/tpu/tpu_topology.h
index 48371b6e008..b49b1e24386 100644
--- a/tensorflow/stream_executor/tpu/tpu_topology.h
+++ b/tensorflow/stream_executor/tpu/tpu_topology.h
@@ -30,6 +30,7 @@ struct TpuChipCoordinatesExternal {
 
 class TpuCoreLocationExternal {
  public:
+  TpuCoreLocationExternal() : core_location_(nullptr) {}
   explicit TpuCoreLocationExternal(void* core_location)
       : core_location_(core_location) {}
   TpuChipCoordinatesExternal chip_coordinates() const;