Add ability to allocate an uninitialized tensor of a particular shape through XRT.

PiperOrigin-RevId: 254250869
Jeffrey A. Dean 2019-06-20 12:38:59 -07:00 committed by TensorFlower Gardener
parent a88ea5a35c
commit b57953d5d1
6 changed files with 165 additions and 0 deletions

@@ -37,6 +37,15 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
.HostMemory("handle"), .HostMemory("handle"),
XRTAllocateOp<XRTGenericDeviceAccessor>); XRTAllocateOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
.Device(DEVICE_XLA_GPU)
.HostMemory("handle"),
XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
.Device(DEVICE_XLA_CPU)
.HostMemory("handle"),
XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
.Device(DEVICE_XLA_GPU) .Device(DEVICE_XLA_GPU)
.HostMemory("inputs") .HostMemory("inputs")

@@ -205,6 +205,50 @@ class XRTAllocateOp : public OpKernel {
  }
};

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted
    // out from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx,
                   XRTTupleAllocation::CreateUninitialized(
                       xla_shape_, memory_manager.get(), device_ref.backend(),
                       device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};
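
// Editor's sketch (not part of the original commit): the constructor above
// converts the "dtype" and "shape" attrs into an xla::Shape using
// TensorShapeToXLAShape from tensorflow/compiler/tf2xla/shape_util.h. A
// minimal standalone illustration of that conversion; ShapeForAttrs is a
// hypothetical helper name used only for this example.
inline xla::Shape ShapeForAttrs(DataType dtype, const TensorShape& tf_shape) {
  xla::Shape xla_shape;
  // For example, DT_FLOAT with TensorShape({2, 2}) yields f32[2,2] with the
  // default layout.
  TF_CHECK_OK(TensorShapeToXLAShape(dtype, tf_shape, &xla_shape));
  return xla_shape;
}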

// Op that allocates memory for a tensor (with optional layout) and transfers it
// to the device, returning an allocation handle.
template <class DeviceAccessor>

@@ -32,6 +32,21 @@ Reads a literal proto and transfers it to device memory.
'handle' is an id that can be used in other ops to refer to the allocation.
)");
REGISTER_OP("XRTAllocateUninitialized")
.Output("handle: int64")
.Attr("dtype: type")
.Attr("shape: shape")
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
.Doc(
R"(
Allocates a tensor to hold the specified shape in device memory. The values
in the tensor are left uninitialized.
shape: The shapes which the tensor should have on device.
handle: An id that can be used in other ops to refer to the allocation.
)");
REGISTER_OP("XRTAllocateFromTensor") REGISTER_OP("XRTAllocateFromTensor")
.Input("inputs: dtypes") .Input("inputs: dtypes")
.Output("handle: int64") .Output("handle: int64")

@@ -320,6 +320,72 @@ TEST(RawApiTest, AllocFromTensor) {
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}

TEST(RawApiTest, AllocUninitialized) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  auto allocate_op =
      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());

  Tensor handle;
  std::vector<Tensor> outputs;
  XrtClientSession session(root);

  // Allocate the tensor.
  {
    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
    handle = outputs[0];
  }

  // Make sure it has the expected shape.
  {
    auto read_back_op = ops::XRTReadLiteral(root, handle);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    xla::LiteralProto read_back_literal;
    EXPECT_TRUE(
        read_back_literal.ParseFromString(outputs[0].scalar<string>()()));
    Tensor read_back_tensor;
    TF_ASSERT_OK(LiteralToHostTensor(
        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(),
        DT_FLOAT, &read_back_tensor));

    // The shape should be the same as 'tensor', but we don't have any
    // expectations about the values in the tensor yet, since it is
    // uninitialized.
    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
  }

  // Make sure we can write to it.
  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2<float>({{9.0f, 2.0f}, {4.0f, 1.0f}})
          .ToProto();
  {
    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                new_literal.SerializeAsString());
    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({write_op}, &outputs));
  }

  // Now read it back.
  {
    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    xla::LiteralProto response;
    EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
  }
}

TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});

@@ -194,6 +194,29 @@ void XRTTupleAllocation::ReleaseBuffers() {
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateUninitialized(
    const xla::Shape& shape, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
      memory_manager, backend, device_ordinal, shape, &scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(
      device_ordinal, backend->memory_allocator(),
      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                                   device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}
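
// Editor's note (sketch, not part of the commit): the release() call above is
// the standard ScopedShapedBuffer ownership-transfer pattern. release()
// returns a plain ShapedBuffer and disarms the scoped wrapper's destructor,
// so the device memory is not deallocated when the wrapper goes away; the new
// XRTTupleAllocation takes over ownership instead. In miniature (TakeOwnership
// is a hypothetical helper for illustration):
inline xla::ShapedBuffer TakeOwnership(xla::ScopedShapedBuffer* scoped) {
  return scoped->release();  // After this, *scoped no longer owns the memory.
}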

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {

@@ -84,6 +84,14 @@ class XRTTupleAllocation : public core::RefCounted {
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);

  // Allocates new device memory buffers sufficient to store a tensor of
  // the specified shape, and returns an XRTTupleAllocation handle to the
  // allocated buffers. The allocated buffers are not initialized.
  static Status CreateUninitialized(const xla::Shape& shape,
                                    XRTMemoryManager* memory_manager,
                                    xla::Backend* backend, int device_ordinal,
                                    XRTTupleAllocation** allocation);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,