diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
index ffea592491d..3258286c106 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
                             .HostMemory("literal"),
                         XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
 
+REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
+                            .Device(DEVICE_XLA_GPU)
+                            .HostMemory("handle")
+                            .HostMemory("literal")
+                            .HostMemory("output_handle"),
+                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
+                            .Device(DEVICE_XLA_CPU)
+                            .HostMemory("handle")
+                            .HostMemory("literal")
+                            .HostMemory("output_handle"),
+                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
+
 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
                             .Device(DEVICE_XLA_GPU)
                             .HostMemory("handle")
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 54b06558adc..26a58fa42d8 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel {
   }
 };
 
+// Op that writes a new literal value into device-resident memory.
+template <class DeviceAccessor>
+class XRTWriteLiteralOp : public OpKernel {
+ public:
+  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  ~XRTWriteLiteralOp() override = default;
+  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
+  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "XRTWriteLiteralOp::Compute";
+
+    const Tensor& handle_tensor = ctx->input(0);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+        errors::Internal("computation input should be an int64 scalar"));
+    int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+    const Tensor& literal_info = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
+                errors::Internal("literal input should be a string scalar"));
+    xla::LiteralProto literal_proto;
+    OP_REQUIRES(ctx,
+                literal_proto.ParseFromString(literal_info.scalar<string>()()),
+                errors::InvalidArgument(
+                    "Unable to parse allocation input to LiteralProto"));
+    xla::Literal literal;
+    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));
+
+    ResourceMgr* rm;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    XRTTupleAllocation* allocation;
+    OP_REQUIRES_OK(
+        ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+    core::ScopedUnref allocation_unref(allocation);
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us, while the ScopedRef is live.
+    typename DeviceAccessor::ScopedRef device_ref;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+                            ctx, allocation->device_ordinal(), &device_ref));
+    OP_REQUIRES_OK(ctx,
+                   allocation->WriteLiteral(device_ref.backend(), literal));
+
+    Tensor output(DT_INT64, TensorShape({}));
+    output.scalar<int64>()() = allocation_handle;
+    ctx->set_output(0, output);
+  }
+};
+
 // Op that discards a handle to device memory.
 template <class DeviceAccessor>
 class XRTReleaseAllocationOp : public OpKernel {
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
index 07d025ce343..a3d63106fa1 100644
--- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal.
 'literal' is a serialized xla::LiteralProto proto.
 )");
 
+REGISTER_OP("XRTWriteLiteral")
+    .Input("handle: int64")
+    .Input("literal: string")
+    .Output("output_handle: int64")
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(
+        R"(
+Copies the input literal into the device memory pointed to by handle.
+Returns the handle itself.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto to be written to device memory.
+)");
+
 REGISTER_OP("XRTReadLiteralAndRelease")
     .Input("handle: int64")
     .Output("literal: string")
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index aa0eabfce7e..abaa17e50e3 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
   bool equal = l_a == l_b;
   if (!equal) {
-    LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+    LOG(INFO) << "LiteralProtos don't match: " << a.DebugString()
               << " != " << b.DebugString();
   }
   return equal;
@@ -215,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape(
       ->ComputeProgramShape();
 }
 
+TEST(RawApiTest, AllocAndRewrite) {
+  xrt::XLAAllocation alloc;
+  alloc.set_device_ordinal(0);
+  *alloc.mutable_value() =
+      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto value =
+      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+  auto handle = ops::XRTAllocate(root, value);
+  auto read_back = ops::XRTReadLiteral(root, handle);
+  TF_ASSERT_OK(root.status());
+
+  tensorflow::ClientSession session(root);
+  std::vector<tensorflow::Tensor> outputs;
+  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
+  EXPECT_EQ(outputs.size(), 2);
+
+  int64 allocation_handle = outputs[1].scalar<int64>()();
+  xla::LiteralProto response;
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+  outputs.clear();
+
+  xla::LiteralProto new_literal =
+      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
+  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                              new_literal.SerializeAsString());
+  auto write_op =
+      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
+  TF_ASSERT_OK(root.status());
+  TF_EXPECT_OK(session.Run({write_op}, &outputs));
+  EXPECT_EQ(outputs.size(), 1);
+  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
+  outputs.clear();
+
+  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
+  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
+  EXPECT_EQ(outputs.size(), 1);
+
+  xla::LiteralProto new_response;
+  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));
+
+  auto release =
+      ops::XRTReleaseAllocationHandle(root, Input(allocation_handle));
+  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
+                           &outputs));
+}
+
 TEST(RawApiTest, ReadAndWriteState) {
   xrt::XLAAllocation alloc;
   alloc.set_device_ordinal(0);
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 3a99820d7aa..5c7c537c340 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -183,6 +183,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
   return Status::OK();
 }
 
+Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
+                                        const xla::Literal& literal) {
+  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
+    return errors::InvalidArgument(
+        "New literal shape not matching the existing one: literal=",
+        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
+        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
+  }
+  auto transfer_manager = backend->transfer_manager();
+  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
+  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
+                                                   ToShapedBuffer());
+}
+
 void XRTTupleAllocation::DiscardAllocation(
     const xla::ShapeIndex& buffer_index) {
   buffers_.element(buffer_index)->DiscardAllocation();
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 73b5584e38f..3664c0cd4e6 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase {
   Status ToLiteral(xla::Backend* backend, int device_ordinal,
                    xla::Literal* literal);
 
+  // Write a new literal value to the allocation.
+  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
+
   // True if none of the buffers in the allocation are aliased by any other live
   // handle.
   bool IsExclusiveOwner();