Introduced a new XRTWriteLiteral op to allow a client to overwrite the values stored in device memory.

PiperOrigin-RevId: 224004631
This commit is contained in:
A. Unique TensorFlower 2018-12-04 10:40:47 -08:00 committed by TensorFlower Gardener
parent b2cad94ff7
commit 08cbdcdc92
6 changed files with 145 additions and 1 deletion

View File

@@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
.HostMemory("literal"),
XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
// Register the XRTWriteLiteral kernel for the XLA CPU and GPU devices.
// All of the op's tensors (handle, literal, output_handle) are small host
// values, so each is pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("handle")
                            .HostMemory("literal")
                            .HostMemory("output_handle"),
                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("handle")
                            .HostMemory("literal")
                            .HostMemory("output_handle"),
                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
.Device(DEVICE_XLA_GPU)
.HostMemory("handle")

View File

@@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel {
}
};
// Op that writes a new literal value into device-resident memory.
//
// Inputs:
//   0 "handle": int64 scalar identifying an existing on-device allocation.
//   1 "literal": string scalar holding a serialized xla::LiteralProto whose
//     shape must match the allocation's shape.
// Output:
//   0 "output_handle": the same handle, echoed back for convenient chaining.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;
  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    const Tensor& handle_tensor = ctx->input(0);
    // Malformed inputs are caller errors, so classify them as InvalidArgument
    // (matching the parse-failure path below) rather than Internal.
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::InvalidArgument("handle input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();
    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
        errors::InvalidArgument("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(ctx,
                literal_proto.ParseFromString(literal_info.scalar<string>()()),
                errors::InvalidArgument(
                    "Unable to parse literal input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
    core::ScopedUnref allocation_unref(allocation);
    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));
    // Echo the handle back so the write can be sequenced with later reads.
    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};
// Op that discards a handle to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {

View File

@@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal.
'literal' is a serialized xla::LiteralProto proto.
)");
// Interface registration for XRTWriteLiteral: takes an allocation handle and
// a serialized literal, writes the literal into the allocation's device
// memory, and returns the handle. Both inputs and the output are scalars,
// hence ScalarShape for shape inference.
REGISTER_OP("XRTWriteLiteral")
.Input("handle: int64")
.Input("literal: string")
.Output("output_handle: int64")
.SetShapeFn(tensorflow::shape_inference::ScalarShape)
.Doc(
R"(
Copies the input literal into the device memory pointed to by handle.
Returns the handle itself.
'handle' is the id returned from the Op that produced the on-device allocation.
'literal' is a serialized xla::LiteralProto proto to be written to device memory.
)");
REGISTER_OP("XRTReadLiteralAndRelease")
.Input("handle: int64")
.Output("literal: string")

View File

@@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
bool equal = l_a == l_b;
if (!equal) {
LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
LOG(INFO) << "LiteralProtos don't match: " << a.DebugString()
<< " != " << b.DebugString();
}
return equal;
@@ -215,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape(
->ComputeProgramShape();
}
// End-to-end test for XRTWriteLiteral: allocate a literal on device, read it
// back, overwrite it with a new literal through the handle, read it back
// again, then release the allocation.
TEST(RawApiTest, AllocAndRewrite) {
// Build an allocation request holding a 2x2 integer literal.
xrt::XLAAllocation alloc;
alloc.set_device_ordinal(0);
*alloc.mutable_value() =
xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
auto value =
ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
auto handle = ops::XRTAllocate(root, value);
auto read_back = ops::XRTReadLiteral(root, handle);
TF_ASSERT_OK(root.status());
tensorflow::ClientSession session(root);
std::vector<tensorflow::Tensor> outputs;
// Fetch both the read-back literal and the allocation handle so the handle
// can be reused in the subsequent write/read/release steps.
TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
EXPECT_EQ(outputs.size(), 2);
int64 allocation_handle = outputs[1].scalar<int64>()();
// The literal read back from device must match what was allocated.
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
outputs.clear();
// Overwrite the allocation in place with a different 2x2 literal.
xla::LiteralProto new_literal =
xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
new_literal.SerializeAsString());
auto write_op =
ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
TF_ASSERT_OK(root.status());
TF_EXPECT_OK(session.Run({write_op}, &outputs));
EXPECT_EQ(outputs.size(), 1);
// XRTWriteLiteral echoes the handle it was given.
EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
outputs.clear();
// Reading through the same handle must now return the new literal.
auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
EXPECT_EQ(outputs.size(), 1);
xla::LiteralProto new_response;
EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));
// Release the device allocation; fetches nothing, only runs the target.
auto release =
ops::XRTReleaseAllocationHandle(root, Input(allocation_handle));
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
&outputs));
}
TEST(RawApiTest, ReadAndWriteState) {
xrt::XLAAllocation alloc;
alloc.set_device_ordinal(0);

View File

@@ -183,6 +183,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
return Status::OK();
}
// Overwrites this allocation's device buffers with `literal`. The literal's
// shape must exactly match the allocation's on-host shape; the buffers
// themselves are reused, so existing handles remain valid.
Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  // Refuse shape-changing writes: the destination buffers are fixed-size.
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  // Borrow a stream for the device this allocation lives on and copy the
  // literal into the allocation's existing ShapedBuffer.
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return backend->transfer_manager()->TransferLiteralToDevice(
      stream.get(), literal, ToShapedBuffer());
}
void XRTTupleAllocation::DiscardAllocation(
const xla::ShapeIndex& buffer_index) {
buffers_.element(buffer_index)->DiscardAllocation();

View File

@@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase {
Status ToLiteral(xla::Backend* backend, int device_ordinal,
xla::Literal* literal);
// Write a new literal value to the allocation.
Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
// True if none of the buffers in the allocation are aliased by any other live
// handle.
bool IsExclusiveOwner();