diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
index 9020fe8ea78..6eab3716391 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -37,6 +37,15 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
                             .HostMemory("handle"),
                         XRTAllocateOp<XRTGenericDeviceAccessor>);
 
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
+                            .Device(DEVICE_XLA_GPU)
+                            .HostMemory("handle"),
+                        XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
+                            .Device(DEVICE_XLA_CPU)
+                            .HostMemory("handle"),
+                        XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
+
 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
                             .Device(DEVICE_XLA_GPU)
                             .HostMemory("inputs")
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index c3511b1d5d4..2ffde52af06 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -205,6 +205,50 @@ class XRTAllocateOp : public OpKernel {
   }
 };
 
+// Op that allocates uninitialized memory on the device for a tensor of a
+// particular shape.
+template <class DeviceAccessor>
+class XRTAllocateUninitializedOp : public OpKernel {
+ public:
+  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
+  }
+  ~XRTAllocateUninitializedOp() override = default;
+  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
+  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
+      delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
+    ResourceMgr* rm;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us, while the ScopedRef is live.
+    class DeviceAccessor::ScopedRef device_ref;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
+
+    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
+    XRTTupleAllocation* allocation;
+    OP_REQUIRES_OK(ctx,
+                   XRTTupleAllocation::CreateUninitialized(
+                       xla_shape_, memory_manager.get(), device_ref.backend(),
+                       device_ref.device_ordinal(), &allocation));
+
+    Tensor output(DT_INT64, TensorShape({}));
+    output.scalar<int64>()() = memory_manager->Register(allocation);
+    ctx->set_output(0, output);
+  }
+
+ private:
+  DataType dtype_;
+  TensorShape tf_shape_;
+  xla::Shape xla_shape_;
+};
+
 // Op that allocates memory for a tensor (with optional layout) and transfers
 // it to the device, returning an allocation handle.
 template <class DeviceAccessor>
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
index 6d4e70fad53..49a2656a0f9 100644
--- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -32,6 +32,21 @@ Reads a literal proto and transfers it to device memory.
 'handle' is an id that can be used in other ops to refer to the allocation.
 )");
 
+REGISTER_OP("XRTAllocateUninitialized")
+    .Output("handle: int64")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(
+        R"(
+Allocates device memory to hold a tensor of the specified shape. The values
+in the tensor are left uninitialized.
+
+shape: The shape which the tensor should have on device.
+
+handle: An id that can be used in other ops to refer to the allocation.
+)");
+
 REGISTER_OP("XRTAllocateFromTensor")
     .Input("inputs: dtypes")
     .Output("handle: int64")
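The three files above add the op definition and its CPU/GPU kernel registrations; the test further down exercises them end to end. As a quick orientation, here is a minimal usage sketch of driving the new op through the C++ ops API. It is not part of this change: the helper name and device string are illustrative, and it assumes the generated wrapper tensorflow::ops::XRTAllocateUninitialized plus the standard ClientSession (the test below uses its own XrtClientSession helper instead).

// Minimal usage sketch (not part of this change): allocate an uninitialized
// 2x2 float buffer on an XLA device and fetch its scalar int64 XRT handle.
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"  // assumed wrapper path
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status AllocateUninitializedBuffer(tensorflow::Tensor* handle) {
  namespace tf = tensorflow;
  tf::Scope root = tf::Scope::NewRootScope().WithDevice("/device:XLA_CPU:0");
  // 'dtype' and 'shape' become the attrs read by XRTAllocateUninitializedOp;
  // the op emits a scalar int64 handle that the other XRT state ops accept.
  auto alloc = tf::ops::XRTAllocateUninitialized(root, tf::DT_FLOAT,
                                                 tf::TensorShape({2, 2}));
  TF_RETURN_IF_ERROR(root.status());
  tf::ClientSession session(root);
  std::vector<tf::Tensor> outputs;
  TF_RETURN_IF_ERROR(session.Run({alloc}, &outputs));
  *handle = outputs[0];  // DT_INT64 scalar naming the device allocation
  return tf::Status::OK();
}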
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index b5108acff16..f0729251eeb 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -320,6 +320,72 @@ TEST(RawApiTest, AllocFromTensor) {
   EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
 }
 
+TEST(RawApiTest, AllocUninitialized) {
+  xla::Literal literal =
+      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
+  Tensor tensor;
+  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  std::vector<int> layout =
+      GetAttrLayout(literal.shape().layout().minor_to_major());
+
+  auto allocate_op =
+      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());
+
+  Tensor handle;
+  std::vector<Tensor> outputs;
+  XrtClientSession session(root);
+  // Allocate the tensor.
+  {
+    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
+    handle = outputs[0];
+  }
+
+  // Make sure it has the expected shape.
+  {
+    auto read_back_op = ops::XRTReadLiteral(root, handle);
+    TF_ASSERT_OK(root.status());
+
+    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
+    EXPECT_EQ(outputs.size(), 1);
+    xla::LiteralProto read_back_literal;
+    EXPECT_TRUE(
+        read_back_literal.ParseFromString(outputs[0].scalar<tstring>()()));
+    Tensor read_back_tensor;
+    TF_ASSERT_OK(LiteralToHostTensor(
+        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
+        &read_back_tensor));
+
+    // The shape should be the same as 'tensor', but we don't have any
+    // expectation about the values yet, since they are uninitialized.
+    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
+  }
+
+  // Make sure we can write to it.
+  xla::LiteralProto new_literal =
+      xla::LiteralUtil::CreateR2<float>({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
+  {
+    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                                new_literal.SerializeAsString());
+    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
+    TF_ASSERT_OK(root.status());
+    TF_EXPECT_OK(session.Run({write_op}, &outputs));
+  }
+
+  // Now read it back.
+  {
+    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
+    TF_ASSERT_OK(root.status());
+    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
+    EXPECT_EQ(outputs.size(), 1);
+
+    xla::LiteralProto response;
+    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
+    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
+  }
+}
+
 TEST(RawApiTest, AllocFromTensorTuple) {
   xla::Literal literal0 =
       xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
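The test above moves literals on and off the device as serialized xla::LiteralProto strings, which is the wire format the XRT read/write ops consume and produce. The helpers below are a minimal sketch of that round-trip, shown once so the SerializeAsString/ParseFromString pattern in the test body is easier to follow; the helper names and error message are ours, not part of this change.

// Sketch of the LiteralProto round-trip used by XRTWriteLiteral /
// XRTReadLiteral in the test above. Helper names are illustrative.
#include <string>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

// Serialize a literal into the string fed to XRTWriteLiteral's 'literal' input.
std::string SerializeLiteralForXrt(const xla::Literal& literal) {
  return literal.ToProto().SerializeAsString();
}

// Parse the string produced by XRTReadLiteral back into an xla::Literal.
xla::StatusOr<xla::Literal> DeserializeLiteralFromXrt(
    const std::string& serialized) {
  xla::LiteralProto proto;
  if (!proto.ParseFromString(serialized)) {
    return tensorflow::errors::InvalidArgument("Malformed xla::LiteralProto");
  }
  return xla::Literal::CreateFromProto(proto);
}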
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index b47d2b61b6f..4ad652edb9a 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -194,6 +194,29 @@ void XRTTupleAllocation::ReleaseBuffers() {
   return Status::OK();
 }
 
+/*static*/ Status XRTTupleAllocation::CreateUninitialized(
+    const xla::Shape& shape, XRTMemoryManager* memory_manager,
+    xla::Backend* backend, int device_ordinal,
+    XRTTupleAllocation** allocation) {
+  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
+  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
+      memory_manager, backend, device_ordinal, shape, &scoped_buffer));
+
+  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
+  // won't be freed when the buffer goes out of scope at the end of this call.
+  // To avoid a leak, there must be no error-case returns from here until the
+  // end of the method.
+  auto shaped_buffer = scoped_buffer->release();
+  *allocation = new XRTTupleAllocation(
+      device_ordinal, backend->memory_allocator(),
+      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
+  (*allocation)
+      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
+                                   device_ordinal);
+  (*allocation)->SetDeviceMemorySize();
+  return Status::OK();
+}
+
 /*static*/ Status XRTTupleAllocation::CreateFromBuffer(
     const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
     int device_ordinal, XRTTupleAllocation** allocation) {
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 929c77b3f5c..287015b83e6 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -84,6 +84,14 @@ class XRTTupleAllocation : public core::RefCounted {
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);
 
+  // Allocates new device memory buffers sufficient to store a tensor of the
+  // specified shape, and returns an XRTTupleAllocation handle to the
+  // allocated buffers. The allocated buffers are not initialized.
+  static Status CreateUninitialized(const xla::Shape& shape,
+                                    XRTMemoryManager* memory_manager,
+                                    xla::Backend* backend, int device_ordinal,
+                                    XRTTupleAllocation** allocation);
+
   // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
   static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  xla::Backend* backend, int device_ordinal,
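CreateUninitialized above relies on the ownership hand-off that xla::ScopedShapedBuffer::release() provides: after the call, the RAII wrapper no longer frees the device memory, and the plain xla::ShapedBuffer it returns describes the buffers without owning them, which is why the in-code comment forbids error returns until the XRTTupleAllocation has taken them over. The sketch below restates that contract in isolation; it is an illustration under those assumptions, not code from this change, and the helper name is ours.

// Illustration of the release() hand-off used by CreateUninitialized above.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"

// After release(), destroying the ScopedShapedBuffer no longer deallocates the
// device memory; the returned ShapedBuffer only records the buffer addresses
// and shapes, so whoever receives it must arrange for deallocation (in
// CreateUninitialized that responsibility passes to the XRTTupleAllocation).
xla::ShapedBuffer ReleaseForManualOwnership(xla::ScopedShapedBuffer scoped) {
  return scoped.release();
}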