Add ability to allocate an uninitialized tensor of a particular shape
through XRT. PiperOrigin-RevId: 254250869
This commit is contained in:
parent
a88ea5a35c
commit
b57953d5d1
@ -37,6 +37,15 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateOp<XRTGenericDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
|
||||
.Device(DEVICE_XLA_CPU)
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
.HostMemory("inputs")
|
||||
|
@ -205,6 +205,50 @@ class XRTAllocateOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
// Op that allocates uninitialized memory on the device for a tensor of
|
||||
// a particular shape.
|
||||
template <class DeviceAccessor>
|
||||
class XRTAllocateUninitializedOp : public OpKernel {
|
||||
public:
|
||||
explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
|
||||
}
|
||||
~XRTAllocateUninitializedOp() override = default;
|
||||
XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
|
||||
XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
|
||||
delete;
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
VLOG(1) << "XRTAllocateUninitializedOp::Compute";
|
||||
ResourceMgr* rm;
|
||||
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
|
||||
|
||||
// We are guaranteed that the underlying device object won't be deleted out
|
||||
// from under us, while the ScopedRef is live.
|
||||
class DeviceAccessor::ScopedRef device_ref;
|
||||
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
|
||||
|
||||
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
|
||||
XRTTupleAllocation* allocation;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
XRTTupleAllocation::CreateUninitialized(
|
||||
xla_shape_, memory_manager.get(), device_ref.backend(),
|
||||
device_ref.device_ordinal(), &allocation));
|
||||
|
||||
Tensor output(DT_INT64, TensorShape({}));
|
||||
output.scalar<int64>()() = memory_manager->Register(allocation);
|
||||
ctx->set_output(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
TensorShape tf_shape_;
|
||||
xla::Shape xla_shape_;
|
||||
};
|
||||
|
||||
// Op that allocates memory for a tensor (with optional layout) and transfers it
|
||||
// to the device, returning an allocation handle.
|
||||
template <class DeviceAccessor>
|
||||
|
@ -32,6 +32,21 @@ Reads a literal proto and transfers it to device memory.
|
||||
'handle' is an id that can be used in other ops to refer to the allocation.
|
||||
)");
|
||||
|
||||
REGISTER_OP("XRTAllocateUninitialized")
|
||||
.Output("handle: int64")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape")
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
|
||||
.Doc(
|
||||
R"(
|
||||
Allocates a tensor to hold the specified shape in device memory. The values
|
||||
in the tensor are left uninitialized.
|
||||
|
||||
shape: The shapes which the tensor should have on device.
|
||||
|
||||
handle: An id that can be used in other ops to refer to the allocation.
|
||||
)");
|
||||
|
||||
REGISTER_OP("XRTAllocateFromTensor")
|
||||
.Input("inputs: dtypes")
|
||||
.Output("handle: int64")
|
||||
|
@ -320,6 +320,72 @@ TEST(RawApiTest, AllocFromTensor) {
|
||||
EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
|
||||
}
|
||||
|
||||
TEST(RawApiTest, AllocUninitialized) {
|
||||
xla::Literal literal =
|
||||
xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
|
||||
Tensor tensor;
|
||||
TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));
|
||||
|
||||
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
|
||||
std::vector<int> layout =
|
||||
GetAttrLayout(literal.shape().layout().minor_to_major());
|
||||
|
||||
auto allocate_op =
|
||||
ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());
|
||||
|
||||
Tensor handle;
|
||||
std::vector<Tensor> outputs;
|
||||
XrtClientSession session(root);
|
||||
// Allocate the tensor
|
||||
{
|
||||
TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
|
||||
handle = outputs[0];
|
||||
}
|
||||
|
||||
// Make sure it has the expected shape
|
||||
{
|
||||
auto read_back_op = ops::XRTReadLiteral(root, handle);
|
||||
TF_ASSERT_OK(root.status());
|
||||
|
||||
TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
|
||||
EXPECT_EQ(outputs.size(), 1);
|
||||
xla::LiteralProto read_back_literal;
|
||||
EXPECT_TRUE(
|
||||
read_back_literal.ParseFromString(outputs[0].scalar<string>()()));
|
||||
Tensor read_back_tensor;
|
||||
TF_ASSERT_OK(LiteralToHostTensor(
|
||||
xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
|
||||
&read_back_tensor));
|
||||
|
||||
// The shape should be the same as 'tensor', but we don't have any
|
||||
// expectation about the value of the tensors yet since it is uninitialized
|
||||
EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
|
||||
}
|
||||
|
||||
// Make sure we can write to it
|
||||
xla::LiteralProto new_literal =
|
||||
xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
|
||||
{
|
||||
auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
|
||||
new_literal.SerializeAsString());
|
||||
auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
|
||||
TF_ASSERT_OK(root.status());
|
||||
TF_EXPECT_OK(session.Run({write_op}, &outputs));
|
||||
}
|
||||
|
||||
// Now read it back
|
||||
{
|
||||
auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
|
||||
TF_ASSERT_OK(root.status());
|
||||
TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
|
||||
EXPECT_EQ(outputs.size(), 1);
|
||||
|
||||
xla::LiteralProto response;
|
||||
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
|
||||
EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(RawApiTest, AllocFromTensorTuple) {
|
||||
xla::Literal literal0 =
|
||||
xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
|
||||
|
@ -194,6 +194,29 @@ void XRTTupleAllocation::ReleaseBuffers() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*static*/ Status XRTTupleAllocation::CreateUninitialized(
|
||||
const xla::Shape& shape, XRTMemoryManager* memory_manager,
|
||||
xla::Backend* backend, int device_ordinal,
|
||||
XRTTupleAllocation** allocation) {
|
||||
std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
|
||||
TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
|
||||
memory_manager, backend, device_ordinal, shape, &scoped_buffer));
|
||||
|
||||
// By releasing the ScopedShapedBuffer we ensure that the underlying storage
|
||||
// won't be freed when the buffer goes out of scope at the end of this
|
||||
// call. To avoid a leak, there must be no error-case returns from here until
|
||||
// the end of the method.
|
||||
auto shaped_buffer = scoped_buffer->release();
|
||||
*allocation = new XRTTupleAllocation(
|
||||
device_ordinal, backend->memory_allocator(),
|
||||
shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
|
||||
(*allocation)
|
||||
->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
|
||||
device_ordinal);
|
||||
(*allocation)->SetDeviceMemorySize();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
|
||||
const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
|
||||
int device_ordinal, XRTTupleAllocation** allocation) {
|
||||
|
@ -84,6 +84,14 @@ class XRTTupleAllocation : public core::RefCounted {
|
||||
xla::Backend* backend, int device_ordinal,
|
||||
XRTTupleAllocation** allocation);
|
||||
|
||||
// Allocates new device memory buffers sufficient to store a tensor of
|
||||
// the specified shape, and returns a XRTTupleAllocation handle to the
|
||||
// allocated buffers. The allocated buffers are not initialized.
|
||||
static Status CreateUninitialized(const xla::Shape& shape,
|
||||
XRTMemoryManager* memory_manager,
|
||||
xla::Backend* backend, int device_ordinal,
|
||||
XRTTupleAllocation** allocation);
|
||||
|
||||
// Wraps an existing ShapeBuffer in a new XRTTupleAllocation handle.
|
||||
static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
|
||||
xla::Backend* backend, int device_ordinal,
|
||||
|
Loading…
Reference in New Issue
Block a user