Add ability to allocate an uninitialized tensor of a particular shape through XRT.

PiperOrigin-RevId: 254250869
Parent: a88ea5a35c
Commit: b57953d5d1
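In short, the new XRTAllocateUninitialized op reserves device memory for a given dtype and shape without transferring any data, and returns the same kind of int64 allocation handle the other XRT state ops use. The snippet below is a condensed, illustrative version of the AllocUninitialized test added further down in this commit; XrtClientSession, DeviceFromFlag, and the ops:: wrappers are assumed to be the helpers the XRT raw API tests already use, so read it as an outline of the intended flow rather than standalone code.

// Reserve an uninitialized 2x2 f32 buffer on the device and get back a handle.
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
auto allocate_op =
    ops::XRTAllocateUninitialized(root, DT_FLOAT, TensorShape({2, 2}));

XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_CHECK_OK(session.Run({allocate_op}, &outputs));
Tensor handle = outputs[0];

// The buffer holds unspecified values until something writes to it, e.g. via
// XRTWriteLiteral with a serialized xla::LiteralProto.
xla::LiteralProto new_literal =
    xla::LiteralUtil::CreateR2<float>({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                            new_literal.SerializeAsString());
auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
TF_CHECK_OK(session.Run({write_op}, &outputs));

// Reading back with the ...AndRelease variant also frees the device memory.
auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
TF_CHECK_OK(session.Run({read_back_op}, &outputs));

Until the write runs, the only thing a reader can rely on is the allocation's shape, which is exactly what the new test asserts: it checks the shape of the read-back literal and makes no claim about its values.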
@@ -37,6 +37,15 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
                             .HostMemory("handle"),
                         XRTAllocateOp<XRTGenericDeviceAccessor>);
 
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
+                            .Device(DEVICE_XLA_GPU)
+                            .HostMemory("handle"),
+                        XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
+                            .Device(DEVICE_XLA_CPU)
+                            .HostMemory("handle"),
+                        XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
+
 REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
                             .Device(DEVICE_XLA_GPU)
                             .HostMemory("inputs")
@@ -205,6 +205,50 @@ class XRTAllocateOp : public OpKernel {
   }
 };
 
+// Op that allocates uninitialized memory on the device for a tensor of
+// a particular shape.
+template <class DeviceAccessor>
+class XRTAllocateUninitializedOp : public OpKernel {
+ public:
+  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
+  }
+  ~XRTAllocateUninitializedOp() override = default;
+  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
+  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
+      delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
+    ResourceMgr* rm;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us, while the ScopedRef is live.
+    class DeviceAccessor::ScopedRef device_ref;
+    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
+
+    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
+    XRTTupleAllocation* allocation;
+    OP_REQUIRES_OK(ctx,
+                   XRTTupleAllocation::CreateUninitialized(
+                       xla_shape_, memory_manager.get(), device_ref.backend(),
+                       device_ref.device_ordinal(), &allocation));
+
+    Tensor output(DT_INT64, TensorShape({}));
+    output.scalar<int64>()() = memory_manager->Register(allocation);
+    ctx->set_output(0, output);
+  }
+
+ private:
+  DataType dtype_;
+  TensorShape tf_shape_;
+  xla::Shape xla_shape_;
+};
+
 // Op that allocates memory for a tensor (with optional layout) and transfers it
 // to the device, returning an allocation handle.
 template <class DeviceAccessor>
@@ -32,6 +32,21 @@ Reads a literal proto and transfers it to device memory.
 'handle' is an id that can be used in other ops to refer to the allocation.
 )");
 
+REGISTER_OP("XRTAllocateUninitialized")
+    .Output("handle: int64")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(
+        R"(
+Allocates a tensor to hold the specified shape in device memory. The values
+in the tensor are left uninitialized.
+
+shape: The shapes which the tensor should have on device.
+
+handle: An id that can be used in other ops to refer to the allocation.
+)");
+
 REGISTER_OP("XRTAllocateFromTensor")
     .Input("inputs: dtypes")
     .Output("handle: int64")
@@ -320,6 +320,72 @@ TEST(RawApiTest, AllocFromTensor) {
   EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
 }
 
+TEST(RawApiTest, AllocUninitialized) {
+  xla::Literal literal =
+      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
+  Tensor tensor;
+  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  std::vector<int> layout =
+      GetAttrLayout(literal.shape().layout().minor_to_major());
+
+  auto allocate_op =
+      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());
+
+  Tensor handle;
+  std::vector<Tensor> outputs;
+  XrtClientSession session(root);
+  // Allocate the tensor
+  {
+    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
+    handle = outputs[0];
+  }
+
+  // Make sure it has the expected shape
+  {
+    auto read_back_op = ops::XRTReadLiteral(root, handle);
+    TF_ASSERT_OK(root.status());
+
+    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
+    EXPECT_EQ(outputs.size(), 1);
+    xla::LiteralProto read_back_literal;
+    EXPECT_TRUE(
+        read_back_literal.ParseFromString(outputs[0].scalar<string>()()));
+    Tensor read_back_tensor;
+    TF_ASSERT_OK(LiteralToHostTensor(
+        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
+        &read_back_tensor));
+
+    // The shape should be the same as 'tensor', but we don't have any
+    // expectation about the value of the tensors yet since it is uninitialized
+    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
+  }
+
+  // Make sure we can write to it
+  xla::LiteralProto new_literal =
+      xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
+  {
+    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
+                                new_literal.SerializeAsString());
+    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
+    TF_ASSERT_OK(root.status());
+    TF_EXPECT_OK(session.Run({write_op}, &outputs));
+  }
+
+  // Now read it back
+  {
+    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
+    TF_ASSERT_OK(root.status());
+    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
+    EXPECT_EQ(outputs.size(), 1);
+
+    xla::LiteralProto response;
+    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
+  }
+}
+
 TEST(RawApiTest, AllocFromTensorTuple) {
   xla::Literal literal0 =
       xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
@@ -194,6 +194,29 @@ void XRTTupleAllocation::ReleaseBuffers() {
   return Status::OK();
 }
 
+/*static*/ Status XRTTupleAllocation::CreateUninitialized(
+    const xla::Shape& shape, XRTMemoryManager* memory_manager,
+    xla::Backend* backend, int device_ordinal,
+    XRTTupleAllocation** allocation) {
+  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
+  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
+      memory_manager, backend, device_ordinal, shape, &scoped_buffer));
+
+  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
+  // won't be freed when the buffer goes out of scope at the end of this
+  // call. To avoid a leak, there must be no error-case returns from here until
+  // the end of the method.
+  auto shaped_buffer = scoped_buffer->release();
+  *allocation = new XRTTupleAllocation(
+      device_ordinal, backend->memory_allocator(),
+      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
+  (*allocation)
+      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
+                                   device_ordinal);
+  (*allocation)->SetDeviceMemorySize();
+  return Status::OK();
+}
+
 /*static*/ Status XRTTupleAllocation::CreateFromBuffer(
     const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
     int device_ordinal, XRTTupleAllocation** allocation) {
@@ -84,6 +84,14 @@ class XRTTupleAllocation : public core::RefCounted {
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);
 
+  // Allocates new device memory buffers sufficient to store a tensor of
+  // the specified shape, and returns a XRTTupleAllocation handle to the
+  // allocated buffers. The allocated buffers are not initialized.
+  static Status CreateUninitialized(const xla::Shape& shape,
+                                    XRTMemoryManager* memory_manager,
+                                    xla::Backend* backend, int device_ordinal,
+                                    XRTTupleAllocation** allocation);
+
   // Wraps an existing ShapeBuffer in a new XRTTupleAllocation handle.
   static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  xla::Backend* backend, int device_ordinal,
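For reference, the call sequence the new XRTAllocateUninitializedOp kernel builds on top of this API boils down to the sketch below; rm, backend, and device_ordinal are assumed to come from the surrounding device context (the DeviceAccessor::ScopedRef in the kernel), so this is an outline of the server-side flow rather than a drop-in function.

// Allocate uninitialized device buffers for a 2x2 f32 tensor and register the
// allocation so clients can refer to it by an int64 handle.
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
XRTTupleAllocation* allocation = nullptr;
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateUninitialized(
    xla::ShapeUtil::MakeShape(xla::F32, {2, 2}), memory_manager.get(),
    backend, device_ordinal, &allocation));
// Register() transfers ownership to the memory manager and returns the handle
// that client ops such as XRTWriteLiteral and XRTReadLiteralAndRelease accept.
int64 handle = memory_manager->Register(allocation);

Internally, CreateUninitialized allocates a ScopedShapedBuffer and then release()s it so the freshly allocated device memory outlives the scoped wrapper; the XRTTupleAllocation takes over ownership, which is why the implementation avoids any error-path return between the release and the final Status::OK().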