diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index 9d787d26433..f4dbcc6cead 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -7,10 +7,26 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+# Currently pybind extension shared objects must use only C API headers since
+# the C API has static initializers duplicated in the Python bindings. So we
+# need a second rule that omits .cc files, in
+# tensorflow/python:_pywrap_parallel_device.
+filegroup(
+    name = "headers",
+    srcs = ["parallel_device.h"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
+filegroup(
+    name = "sources",
+    srcs = ["parallel_device.cc"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
 cc_library(
     name = "parallel_device",
-    srcs = ["parallel_device.cc"],
-    hdrs = ["parallel_device.h"],
+    srcs = [":sources"],
+    hdrs = [":headers"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index bd5d8e777f2..e6846809fcf 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -574,23 +574,21 @@ void DeleteParallelDevice(void* device_info) {
 
 }  // namespace
 
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status) {
-  TFE_CustomDevice custom_device;
-  custom_device.copy_tensor_to_device = &CopyToParallelDevice;
-  custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
-  custom_device.delete_device = &DeleteParallelDevice;
-  custom_device.execute = &ParallelDeviceExecute;
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info) {
+  device->copy_tensor_to_device = &CopyToParallelDevice;
+  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
+  device->delete_device = &DeleteParallelDevice;
+  device->execute = &ParallelDeviceExecute;
   std::vector<std::string> underlying_devices_vector;
   underlying_devices_vector.reserve(num_underlying_devices);
   for (int device_index = 0; device_index < num_underlying_devices;
        ++device_index) {
     underlying_devices_vector.push_back(underlying_devices[device_index]);
   }
-  ParallelDevice* d =
-      new ParallelDevice(device_name, underlying_devices_vector);
-  TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
+  *device_info = new ParallelDevice(device_name, underlying_devices_vector);
 }
 
 }  // namespace eager
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h b/tensorflow/c/eager/parallel_device/parallel_device.h
index b106524401f..f448a4c5b83 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device.h
@@ -16,12 +16,14 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
 
+#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
 namespace eager {
 
-// Register a parallel device named `device_name` which forwards operations to
+// Allocate a parallel device named `device_name` which forwards operations to
 // `underlying_devices`, maintaining "parallel tensors" with components placed
 // on each underlying device.
 //
@@ -50,11 +52,12 @@ namespace eager {
 // TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
 // into its components.
 //
-// `context` owns the parallel device. `underlying_devices` must stay valid
-// while the parallel device is in use.
-void RegisterParallelDevice(TFE_Context* context, const char* device_name,
-                            const char** underlying_devices,
-                            int num_underlying_devices, TF_Status* status);
+// The filled `device` struct and the allocated `device_info` struct may be
+// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
+void AllocateParallelDevice(const char* device_name,
+                            const char* const* underlying_devices,
+                            int num_underlying_devices,
+                            TFE_CustomDevice* device, void** device_info);
 
 }  // namespace eager
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index 41c7d64e231..9b0613b0391 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -288,6 +288,19 @@ void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
             *static_cast<float*>(TF_TensorData(value_zero.get())));
 }
 
+template <std::size_t num_devices>
+void RegisterParallelDevice(
+    TFE_Context* context, const char* device_name,
+    const std::array<const char*, num_devices>& underlying_devices,
+    TF_Status* status) {
+  TFE_CustomDevice device;
+  void* device_info;
+  tensorflow::eager::AllocateParallelDevice(
+      device_name, underlying_devices.data(), underlying_devices.size(),
+      &device, &device_info);
+  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
+}
+
 // Create and modify a variable placed on a parallel device which composes
 // `first_device` and `second_device`.
 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
@@ -297,9 +310,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
       TF_NewStatus(), TF_DeleteStatus);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
   std::array<const char*, 2> underlying_devices{first_device, second_device};
-  tensorflow::eager::RegisterParallelDevice(
-      context, device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  RegisterParallelDevice(context, device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a variable handle (uninitialized to start) placed on the parallel
@@ -456,16 +468,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
   const char* first_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:0";
-  underlying_devices.push_back(first_device_name);
   const char* second_device_name =
       "/job:localhost/replica:0/task:0/device:CPU:1";
-  underlying_devices.push_back(second_device_name);
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{first_device_name,
+                                                second_device_name};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
@@ -524,12 +534,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create two vectors with different lengths
@@ -570,24 +579,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
   // Create a parallel device with two CPUs
   const char* first_device_name =
       "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> first_underlying_devices{
+  std::array<const char*, 2> first_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CPU:0",
       "/job:localhost/replica:0/task:0/device:CPU:1"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), first_device_name, first_underlying_devices.data(),
-      first_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), first_device_name,
+                         first_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a second parallel device with the first parallel device and one
   // additional CPU.
   const char* second_device_name =
       "/job:localhost/replica:0/task:0/device:CUSTOM:1";
-  std::vector<const char*> second_underlying_devices{
+  std::array<const char*, 2> second_underlying_devices{
       "/job:localhost/replica:0/task:0/device:CUSTOM:0",
       "/job:localhost/replica:0/task:0/device:CPU:2"};
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), second_device_name, second_underlying_devices.data(),
-      second_underlying_devices.size(), status.get());
+  RegisterParallelDevice(context.get(), second_device_name,
+                         second_underlying_devices, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the first parallel device
@@ -656,11 +663,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) {
   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 1> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
@@ -775,12 +781,11 @@ TEST(PARALLEL_DEVICE, TestCollective) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a tensor on the parallel device
@@ -867,12 +872,11 @@ TEST(PARALLEL_DEVICE, TestFunction) {
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  std::vector<const char*> underlying_devices;
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
-  underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
-  tensorflow::eager::RegisterParallelDevice(
-      context.get(), device_name, underlying_devices.data(),
-      underlying_devices.size(), status.get());
+  std::array<const char*, 2> underlying_devices{
+      "/job:localhost/replica:0/task:0/device:CPU:0",
+      "/job:localhost/replica:0/task:0/device:CPU:1"};
+  RegisterParallelDevice(context.get(), device_name, underlying_devices,
+                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   const char* function_name = "test_reduce_mul";
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 27a1fd9645c..ee0ee4dd95d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -8032,6 +8032,29 @@ py_binary(
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_parallel_device",
+    srcs = [
+        "lib/core/safe_ptr.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/c/eager/parallel_device:headers",
+        "//tensorflow/c/eager/parallel_device:sources",
+        "//tensorflow/python/distribute/parallel_device:pywrap_parallel_device.cc",
+    ],
+    module_name = "_pywrap_parallel_device",
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 pyx_library(
     name = "framework_fast_tensor_util",
     srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
new file mode 100644
index 00000000000..e7526a56f66
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -0,0 +1,45 @@
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Pybind rules must live in tensorflow/python due to header rule visibility.
+exports_files(
+    ["pywrap_parallel_device.cc"],
+    visibility = ["//tensorflow/python:__pkg__"],
+)
+
+py_library(
+    name = "parallel_device",
+    srcs = ["parallel_device.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":saving",
+        "//tensorflow/python:_pywrap_parallel_device",
+    ],
+)
+
+py_library(
+    name = "saving",
+    srcs = ["saving.py"],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow/python:framework_ops"],
+)
+
+py_test(
+    name = "parallel_device_test",
+    srcs = ["parallel_device_test.py"],
+    python_version = "PY3",
+    tags = [
+        # Dependencies aren't otherwise included in the pip package yet.
+        "no_pip",
+    ],
+    deps = [
+        ":parallel_device",
+        ":saving",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:collective_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/module",
+        "//tensorflow/python/tpu",
+    ],
+)
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py
new file mode 100644
index 00000000000..982b061cdb7
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/parallel_device.py
@@ -0,0 +1,95 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility for eagerly executing operations in parallel on multiple devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import threading
+
+from tensorflow.python import _pywrap_parallel_device
+from tensorflow.python.distribute.parallel_device import saving
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu.ops import tpu_ops
+
+_next_device_number = 0
+_next_device_number_lock = threading.Lock()
+
+
+# TODO(allenl): Expand this docstring once things like getting components on and
+# off the device are stable.
+class ParallelDevice(object):
+  """A device which executes operations in parallel."""
+
+  def __init__(self, components):
+    """Creates a device which executes operations in parallel on `components`.
+
+    Args:
+      components: A list of device names. Each operation executed on the
+        returned device executes on these component devices.
+
+    Returns:
+      A string with the name of the newly created device.
+    """
+    global _next_device_number, _next_device_number_lock
+    self.components = tuple(components)
+    ctx = context.context()
+    with _next_device_number_lock:
+      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
+      # special-cased).
+      self.name = "{}/device:CUSTOM:{}".format(
+          ctx.host_address_space(), _next_device_number)
+      _next_device_number += 1
+    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
+        self.name, self.components)
+    context.register_custom_device(device, self.name, device_info)
+
+  def pack(self, tensors):
+    """Create a tensor on the parallel device from a sequence of tensors.
+
+    Args:
+      tensors: A flat list of tensors, one per device in `self.components`.
+
+    Returns:
+      A single tensor placed on `self.name`.
+    """
+    with ops.device(self.name):
+      return tpu_ops.tpu_replicated_input(inputs=tensors)
+
+  def unpack(self, parallel_tensor):
+    """Unpack a parallel tensor into its components.
+
+    Args:
+      parallel_tensor: A tensor placed on `self.name`.
+
+    Returns:
+      A flat list of tensors, one per `self.components`.
+    """
+    with ops.device(self.name):
+      return tpu_ops.tpu_replicated_output(
+          parallel_tensor, num_replicas=len(self.components))
+
+  # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
+  # to provide a hook for the custom device to create save specs/etc., then call
+  # that hook from the default variable implementation if the variable is on a
+  # custom device. We'll likely want similar hooks for repr() and such.
+  @contextlib.contextmanager
+  def scope(self):
+    """Runs ops in parallel, makes variables which save independent buffers."""
+    with ops.device(self.name), saving.independent_buffers(self):
+      yield
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
new file mode 100644
index 00000000000..d3f3417eca9
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -0,0 +1,254 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+
+from tensorflow.python.distribute.parallel_device import parallel_device
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.module import module
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.tracking import util as tracking
+from tensorflow.python.util import nest
+
+# When running collectives asynchronously, we need to give each parallel device
+# execution a unique ID so the collectives don't interfere. Since the op is
+# replicated with group/instance key intact, the replicated nodes will
+# communicate.
+# TODO(allenl): Switch to using a collective manager.
+_COUNTER_LOCK = threading.Lock()
+_COUNTER = 0
+
+
+def _collective_reduce(inputs, operation, num_replicas):
+
+  def _reduce_tensor(tensor):
+    with _COUNTER_LOCK:
+      global _COUNTER
+      keys = _COUNTER
+      _COUNTER += 1
+    return collective_ops.all_reduce(
+        t=tensor,
+        group_size=num_replicas,
+        merge_op=operation,
+        group_key=keys,
+        instance_key=keys,
+        final_op="Id")
+
+  return nest.map_structure(_reduce_tensor, inputs)
+
+
+def _collective_sum(inputs, num_replicas):
+  return _collective_reduce(
+      inputs=inputs, operation="Add", num_replicas=num_replicas)
+
+
+class _Dense(module.Module):
+
+  def __init__(self, output_size):
+    self.output_size = output_size
+    self.kernel = None
+    self.bias = None
+
+  def __call__(self, x):
+    if self.kernel is None:
+      self.kernel = variables.Variable(
+          array_ops.ones(
+              array_ops.stack([self.output_size,
+                               array_ops.shape(x)[-1]])))
+      self.bias = variables.Variable(array_ops.ones([self.output_size]))
+    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias
+
+
+class _VirtualDeviceTestCase(test.TestCase):
+
+  def setUp(self):
+    super(_VirtualDeviceTestCase, self).setUp()
+    cpus = context.context().list_physical_devices("CPU")
+    # Set 4 virtual CPUs
+    context.context().set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration()
+    ])
+
+    # TODO(allenl): Make CPU:0 and CPU:1 work (right now "CPU:1" soft-places
+    # onto CPU:0, which seems wrong).
+    components = [
+        "/job:localhost/replica:0/task:0/device:CPU:0",
+        "/job:localhost/replica:0/task:0/device:CPU:1"
+    ]
+    self.device = parallel_device.ParallelDevice(components)
+
+
+class ParallelDeviceTests(_VirtualDeviceTestCase):
+
+  def test_register_parallel_device(self):
+    with ops.device(self.device.name):
+      c = constant_op.constant(1.)
+      d = constant_op.constant(2.)
+      e = c + d
+      outputs = self.device.unpack(e)
+    self.assertAllClose([3., 3.], outputs)
+
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_collective_reduce(self):
+    with ops.device(self.device.name):
+      x = self.device.pack(
+          [constant_op.constant(-1.5),
+           constant_op.constant(3.5)])
+      reduced = _collective_sum(x, num_replicas=2)
+      outputs = self.device.unpack(reduced)
+    self.assertAllClose([2., 2.], outputs)
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_checkpointing(self):
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    with self.device.scope():
+      different_values = self.device.pack(
+          [constant_op.constant(-1.),
+           constant_op.constant(3.)])
+      v = variables.Variable(different_values)
+      checkpoint = tracking.Checkpoint(v=v)
+    save_path = checkpoint.save(prefix)
+    with ops.device(self.device.name):
+      v.assign(constant_op.constant(0.))
+    # Make sure the checkpoint is actually written before we try to read it
+    context.async_wait()
+    checkpoint.restore(save_path).assert_consumed()
+    with ops.device(self.device.name):
+      outputs = self.device.unpack(v)
+    self.assertAllClose([-1., 3.], outputs)
+
+
+class LayerTests(_VirtualDeviceTestCase):
+
+  def test_layer_forward(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+      x = constant_op.constant([[2.]])
+      y = layer(x)
+      outputs = self.device.unpack(y)
+    self.assertAllClose([[3.] * 5], outputs[0])
+    self.assertAllClose([[3.] * 5], outputs[1])
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+    # With different Layer inputs we get different outputs
+    with ops.device(self.device.name):
+      x = self.device.pack(
+          [constant_op.constant([[-0.5]]),
+           constant_op.constant([[0.5]])])
+      y = layer(x)
+      outputs = self.device.unpack(y)
+    self.assertGreater(
+        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
+    self.assertIn(self.device.components[0], outputs[0].backing_device)
+    self.assertIn(self.device.components[1], outputs[1].backing_device)
+
+  def test_layer_sync_training(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+
+      with backprop.GradientTape() as tape:
+        x = self.device.pack(
+            [constant_op.constant([[-0.5]]),
+             constant_op.constant([[0.5]])])
+        y = layer(x)
+        loss = (y - math_ops.range(5.))**2.
+      parameters = layer.trainable_variables
+      unreduced_gradients = tape.gradient(loss, parameters)
+      reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
+      for grad, param in zip(reduced_gradients, parameters):
+        param.assign_sub(0.01 * grad)
+    final_kernels = self.device.unpack(layer.kernel)
+    self.assertAllClose(final_kernels[0], final_kernels[1])
+    final_bias = self.device.unpack(layer.bias)
+    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
+                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
+    self.assertAllClose(expected_bias, final_bias[0])
+    self.assertAllClose(expected_bias, final_bias[1])
+    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
+    self.assertIn(self.device.components[1], final_kernels[1].backing_device)
+
+  def test_layer_divergent_buffer_training(self):
+    with ops.device(self.device.name):
+      layer = _Dense(5)
+
+      with backprop.GradientTape() as tape:
+        x = self.device.pack(
+            [constant_op.constant([[-0.5]]),
+             constant_op.constant([[0.5]])])
+        y = layer(x)
+        loss = (y - math_ops.range(5.))**2.
+      parameters = layer.trainable_variables
+      unreduced_gradients = tape.gradient(loss, parameters)
+      for grad, param in zip(unreduced_gradients, parameters):
+        param.assign_sub(0.01 * grad)
+    final_kernels = self.device.unpack(layer.kernel)
+    self.assertNotAllClose(final_kernels[0], final_kernels[1])
+    final_bias = self.device.unpack(layer.bias)
+    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
+                        final_bias[0])
+    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
+                        final_bias[1])
+    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
+    self.assertIn(self.device.components[1], final_kernels[1].backing_device)
+
+  def test_training_loop(self):
+    for _ in range(5):
+      layer = _Dense(5)
+      checkpoint = tracking.Checkpoint(layer=layer)
+      manager = checkpoint_management.CheckpointManager(
+          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
+      manager.restore_or_initialize()
+
+      for _ in range(10):
+        with self.device.scope():
+          with backprop.GradientTape() as tape:
+            x = self.device.pack(
+                [constant_op.constant([[-0.5]]),
+                 constant_op.constant([[0.5]])])
+            y = layer(x)
+            loss = (y - math_ops.range(5.))**2.
+          parameters = layer.trainable_variables
+          unreduced_gradients = tape.gradient(loss, parameters)
+          reduced_gradients = _collective_sum(
+              unreduced_gradients, num_replicas=len(self.device.components))
+          for grad, param in zip(reduced_gradients, parameters):
+            param.assign_sub(0.01 * grad)
+
+        manager.save()
+
+
+if __name__ == "__main__":
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
new file mode 100644
index 00000000000..62488cb31e7
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Python.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/parallel_device/parallel_device.h"
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
+
+namespace py = pybind11;
+
+void CallDelete_Device(PyObject* capsule) {
+  delete reinterpret_cast<TFE_CustomDevice*>(
+      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
+}
+
+void CallDelete_DeviceInfo(PyObject* capsule) {
+  void (*destructor)(void*) =
+      reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
+  destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
+}
+
+PYBIND11_MODULE(_pywrap_parallel_device, m) {
+  m.def("GetParallelDeviceCapsules",
+        [](const char* name, std::vector<std::string> underlying_devices) {
+          std::vector<const char*> underlying_devices_c;
+          underlying_devices_c.reserve(underlying_devices.size());
+          for (const std::string& element : underlying_devices) {
+            underlying_devices_c.push_back(element.c_str());
+          }
+          // `device` is owned by `device_capsule`.
+          TFE_CustomDevice* device = new TFE_CustomDevice;
+          tensorflow::Safe_PyObjectPtr device_capsule(
+              PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
+          void* device_info;
+          tensorflow::eager::AllocateParallelDevice(
+              name, underlying_devices_c.data(), underlying_devices_c.size(),
+              device, &device_info);
+          if (PyErr_Occurred()) throw py::error_already_set();
+          tensorflow::Safe_PyObjectPtr device_info_capsule(
+              PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
+                            &CallDelete_DeviceInfo));
+          if (PyErr_Occurred()) throw py::error_already_set();
+          // The PyCapsule destructor needs a pointer to the destructor for
+          // DeviceInfo.
+          PyCapsule_SetContext(device_info_capsule.get(),
+                               reinterpret_cast<void*>(device->delete_device));
+          return tensorflow::PyoOrThrow(
+              PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
+        });
+}
diff --git a/tensorflow/python/distribute/parallel_device/saving.py b/tensorflow/python/distribute/parallel_device/saving.py
new file mode 100644
index 00000000000..f2e7dadae41
--- /dev/null
+++ b/tensorflow/python/distribute/parallel_device/saving.py
@@ -0,0 +1,131 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special-cased checkpointing for variables on a parallel device."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import functools
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training.saving import saveable_object
+
+
+def _read_component(handle, dtype, replica_id, parallel_device):
+  """Read one component of a parallel variable and discard the rest."""
+  with ops.device(handle.device):
+    read = gen_resource_variable_ops.read_variable_op(
+        resource=handle, dtype=dtype)
+  all_components = parallel_device.unpack(read)
+  # We're pretending that parallel variables have a first axis with length
+  # num_components, so we need to add a dummy first axis to the shape that gets
+  # saved.
+  return all_components[replica_id][None, ...]
+
+
+class _ParallelDeviceSaveable(saveable_object.SaveableObject):
+  """Saves and restores a parallel variable."""
+
+  def __init__(self, name, handle, dtype, component_shape, parallel_device):
+    # Each component device gets one spec with a tensor to save.
+    specs = []
+    for replica_id, device_name in enumerate(parallel_device.components):
+      # TODO(b/151773535): SaveableObjects with SaveSpecs on different devices
+      # will cause extra copying at the moment. We should fix that before doing
+      # anything serious with this code.
+      specs.append(
+          saveable_object.SaveSpec(
+              tensor=functools.partial(
+                  _read_component,
+                  handle=handle,
+                  dtype=dtype,
+                  replica_id=replica_id,
+                  parallel_device=parallel_device),
+              slice_spec=variables.Variable.SaveSliceInfo(
+                  full_shape=([len(parallel_device.components)] +
+                              component_shape),
+                  var_offset=[replica_id] + [0] * len(component_shape),
+                  var_shape=[1] + component_shape).spec,
+              device=device_name,
+              dtype=dtype,
+              name=name))
+    self._handle = handle
+    self._parallel_device = parallel_device
+    self._component_shape = component_shape
+    super(_ParallelDeviceSaveable, self).__init__(None, specs, name)
+
+  def restore(self, tensors, restored_shapes=None):
+    with ops.device(self._handle.device):
+      # Combine the restored tensors into one parallel tensor to assign.
+      bundled = self._parallel_device.pack(tensors)
+      gen_resource_variable_ops.assign_variable_op(
+          resource=self._handle,
+          # Squeeze out the dummy first axis we added when saving.
+          value=array_ops.squeeze(bundled, axis=0))
+
+
+class VariableWithFixedCheckpointing(resource_variable_ops.ResourceVariable):
+  """Overrides checkpointing behavior to save like a partitioned variable."""
+
+  def __init__(self, parallel_device, **kwargs):
+    self._parallel_device = parallel_device
+    kwargs = {k: v for k, v in kwargs.items()
+              if k not in ["use_resource", "expected_shape"]}
+    super(VariableWithFixedCheckpointing, self).__init__(**kwargs)
+
+  def _gather_saveables_for_checkpoint(self):
+    # Note VARIABLE_VALUE is the usual attribute name for variables. Using
+    # something different means (a) the checkpointing infrastructure won't try
+    # doing restore-on-create (which has shape issues), and (b) the saved
+    # variables won't be compatible with regular variables. Both of those are
+    # good in this case.
+    return dict(
+        PARALLEL_VARIABLE_VALUE=functools.partial(
+            _ParallelDeviceSaveable,
+            handle=self.handle,
+            dtype=self.dtype,
+            component_shape=self.shape,
+            parallel_device=self._parallel_device))
+
+
+def _variable_creator(next_creator, parallel_device, **kwargs):
+  del next_creator
+  return VariableWithFixedCheckpointing(
+      parallel_device=parallel_device, **kwargs)
+
+
+@contextlib.contextmanager
+def independent_buffers(parallel_device):
+  """Context manager which saves parallel buffers independently.
+
+  Creates a ParallelDevice-aware variable subclass which saves buffers for each
+  device separately.
+
+  Args:
+    parallel_device: A ParallelDevice object on which variables are placed.
+
+  Yields:
+    Nothing.
+  """
+  with variable_scope.variable_creator_scope(
+      functools.partial(_variable_creator, parallel_device=parallel_device)):
+    yield