From 08cbdcdc920053e3937183a456132c5ac4814a7c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 4 Dec 2018 10:40:47 -0800
Subject: [PATCH] Introduced a new XRTWriteLiteral op to allow a client to
 overwrite the values stored in device memory.

PiperOrigin-RevId: 224004631
---
 .../compiler/xrt/kernels/xrt_state_ops.cc     | 13 +++++
 .../compiler/xrt/kernels/xrt_state_ops.h      | 50 ++++++++++++++++++
 tensorflow/compiler/xrt/ops/xrt_state_ops.cc  | 14 +++++
 tensorflow/compiler/xrt/tests/raw_api_test.cc | 52 ++++++++++++++++++-
 tensorflow/compiler/xrt/xrt_state.cc          | 14 +++++
 tensorflow/compiler/xrt/xrt_state.h           |  3 ++
 6 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
index ffea592491d..3258286c106 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
                             .HostMemory("literal"),
                         XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
 
+REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
+                            .Device(DEVICE_XLA_GPU)
+                            .HostMemory("handle")
+                            .HostMemory("literal")
+                            .HostMemory("output_handle"),
+                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
+                            .Device(DEVICE_XLA_CPU)
+                            .HostMemory("handle")
+                            .HostMemory("literal")
+                            .HostMemory("output_handle"),
+                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
+
 REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
                             .Device(DEVICE_XLA_GPU)
                             .HostMemory("handle")
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 54b06558adc..26a58fa42d8 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel {
   }
 };
 
+// Op that writes a new literal value into device-resident memory.
+template <class DeviceAccessor>
+class XRTWriteLiteralOp : public OpKernel {
+ public:
+  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  ~XRTWriteLiteralOp() override = default;
+  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
+  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "XRTWriteLiteralOp::Compute";
+
+    const Tensor& handle_tensor = ctx->input(0);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+        errors::Internal("computation input should be an int64 scalar"));
+    int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+    const Tensor& literal_info = ctx->input(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
+                errors::Internal("literal input should be a string scalar"));
+    xla::LiteralProto literal_proto;
+    OP_REQUIRES(ctx,
+                literal_proto.ParseFromString(literal_info.scalar<string>()()),
+                errors::InvalidArgument(
+                    "Unable to parse allocation input to LiteralProto"));
+    xla::Literal literal;
+    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));
+
+    ResourceMgr* rm;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    XRTTupleAllocation* allocation;
+    OP_REQUIRES_OK(
+        ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+    core::ScopedUnref allocation_unref(allocation);
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us, while the ScopedRef is live.
+    typename DeviceAccessor::ScopedRef device_ref;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+                            ctx, allocation->device_ordinal(), &device_ref));
+    OP_REQUIRES_OK(ctx,
+                   allocation->WriteLiteral(device_ref.backend(), literal));
+
+    Tensor output(DT_INT64, TensorShape({}));
+    output.scalar<int64>()() = allocation_handle;
+    ctx->set_output(0, output);
+  }
+};
+
 // Op that discards a handle to device memory.
 template <class DeviceAccessor>
 class XRTReleaseAllocationOp : public OpKernel {
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
index 07d025ce343..a3d63106fa1 100644
--- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal.
 'literal' is a serialized xla::LiteralProto proto.
 )");
 
+REGISTER_OP("XRTWriteLiteral")
+    .Input("handle: int64")
+    .Input("literal: string")
+    .Output("output_handle: int64")
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(
+        R"(
+Copies the input literal into the device memory pointed to by handle.
+Returns the handle itself.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto to be written to device memory.
+)");
+
 REGISTER_OP("XRTReadLiteralAndRelease")
     .Input("handle: int64")
     .Output("literal: string")
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index aa0eabfce7e..abaa17e50e3 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
   bool equal = l_a == l_b;
   if (!equal) {
-    LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+    LOG(INFO) << "LiteralProtos don't match: " << a.DebugString()
               << " != " << b.DebugString();
   }
   return equal;
@@ -215,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape(
       ->ComputeProgramShape();
 }
 
+TEST(RawApiTest, AllocAndRewrite) {
+  xrt::XLAAllocation alloc;
+  alloc.set_device_ordinal(0);
+  *alloc.mutable_value() =
+      xla::LiteralUtil::CreateR2<int32>({{4, 5}, {6, 7}}).ToProto();
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto value =
+      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+  auto handle = ops::XRTAllocate(root, value);
+  auto read_back = ops::XRTReadLiteral(root, handle);
+  TF_ASSERT_OK(root.status());
+
+  tensorflow::ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
+  EXPECT_EQ(outputs.size(), 2);
+
+  int64 allocation_handle = outputs[1].scalar<int64>()();
+  xla::LiteralProto response;
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+  outputs.clear();
+
+  xla::LiteralProto new_literal =
+      xla::LiteralUtil::CreateR2<int32>({{9, 2}, {4, 1}}).ToProto();
+  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                              new_literal.SerializeAsString());
+  auto write_op =
+      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
+  TF_ASSERT_OK(root.status());
+  TF_EXPECT_OK(session.Run({write_op}, &outputs));
+  EXPECT_EQ(outputs.size(), 1);
+  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
+  outputs.clear();
+
+  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
+  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
+  EXPECT_EQ(outputs.size(), 1);
+
+  xla::LiteralProto new_response;
+  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
+  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));
+
+  auto release =
+      ops::XRTReleaseAllocationHandle(root, Input(allocation_handle));
+  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
+                           &outputs));
+}
+
 TEST(RawApiTest, ReadAndWriteState) {
   xrt::XLAAllocation alloc;
   alloc.set_device_ordinal(0);
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 3a99820d7aa..5c7c537c340 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -183,6 +183,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
   return Status::OK();
 }
 
+Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
+                                        const xla::Literal& literal) {
+  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
+    return errors::InvalidArgument(
+        "New literal shape not matching the existing one: literal=",
+        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
+        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
+  }
+  auto transfer_manager = backend->transfer_manager();
+  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
+  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
+                                                   ToShapedBuffer());
+}
+
 void XRTTupleAllocation::DiscardAllocation(
     const xla::ShapeIndex& buffer_index) {
   buffers_.element(buffer_index)->DiscardAllocation();
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 73b5584e38f..3664c0cd4e6 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase {
   Status ToLiteral(xla::Backend* backend, int device_ordinal,
                    xla::Literal* literal);
 
+  // Write a new literal value to the allocation.
+  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
+
   // True if none of the buffers in the allocation are aliased by any other live
   // handle.
   bool IsExclusiveOwner();
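
Below is a minimal client-side sketch of how the new op is meant to be used,
condensed from the AllocAndRewrite test added above. It is illustrative only
and not part of the patch: the helper name OverwriteAllocation and its device
argument are hypothetical, and the includes and generated ops::XRT* wrappers
are assumed to match what tensorflow/compiler/xrt/tests/raw_api_test.cc
already uses.

// Sketch (not part of the patch): allocate a literal, overwrite it in place
// with XRTWriteLiteral, and read it back through the same handle.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper; the device string would typically come from a flag,
// e.g. "/device:XLA_CPU:0".
Status OverwriteAllocation(const string& device) {
  // Allocate a 2x2 S32 literal on the device and keep its handle.
  xrt::XLAAllocation alloc;
  alloc.set_device_ordinal(0);
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2<int32>({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(device);
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);

  // Overwrite the allocation in place. The new literal must have the same
  // shape as the existing allocation; the op returns the handle it was given.
  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2<int32>({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op = ops::XRTWriteLiteral(root, handle, new_value);

  // Read back through the returned handle to observe the new value.
  auto read_back = ops::XRTReadLiteral(root, write_op);
  TF_RETURN_IF_ERROR(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run({read_back}, &outputs));

  xla::LiteralProto read_proto;
  if (!read_proto.ParseFromString(outputs[0].scalar<string>()())) {
    return errors::Internal("Failed to parse read-back literal");
  }
  // read_proto now holds {{9, 2}, {4, 1}}.
  return Status::OK();
}

}  // namespace tensorflow

The design point the sketch illustrates is that XRTWriteLiteral mutates an
existing allocation rather than creating a new one, so the handle returned in
output_handle is the same id that was passed in and any prior reference to the
allocation observes the new contents.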