Introduced a new XRTWriteLiteral op to allow a client to overwrite the values stored in device memory.
PiperOrigin-RevId: 224004631
This commit is contained in:
parent b2cad94ff7
commit 08cbdcdc92

tensorflow/compiler/xrt
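For orientation, a condensed, illustrative sketch of how a client drives the new op, distilled from the AllocAndRewrite test added below (it assumes that test's scaffolding: DeviceFromFlag(), the generated ops:: wrappers, and a ClientSession; it is not itself part of the diff):

// Sketch only: allocate a literal on the device, then overwrite it in place.
xrt::XLAAllocation alloc;
alloc.set_device_ordinal(0);
*alloc.mutable_value() = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
auto value =
    ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
auto handle = ops::XRTAllocate(root, value);

xla::LiteralProto new_literal =
    xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                            new_literal.SerializeAsString());
// XRTWriteLiteral overwrites the existing allocation and returns its handle.
auto write_op = ops::XRTWriteLiteral(root, handle, new_value);
// Reading through the returned handle now yields the new literal.
auto read_back = ops::XRTReadLiteral(root, write_op);

tensorflow::ClientSession session(root);
std::vector<tensorflow::Tensor> outputs;
TF_CHECK_OK(session.Run({read_back}, &outputs));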
@@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
                            .HostMemory("literal"),
                        XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("handle")
                            .HostMemory("literal")
                            .HostMemory("output_handle"),
                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("handle")
                            .HostMemory("literal")
                            .HostMemory("output_handle"),
                        XRTWriteLiteralOp<XRTGenericDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("handle")
@@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel {
  }
};

// Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(ctx,
                literal_proto.ParseFromString(literal_info.scalar<string>()()),
                errors::InvalidArgument(
                    "Unable to parse allocation input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
    core::ScopedUnref allocation_unref(allocation);
    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards a handle to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
@@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal.
'literal' is a serialized xla::LiteralProto proto.
)");

REGISTER_OP("XRTWriteLiteral")
    .Input("handle: int64")
    .Input("literal: string")
    .Output("output_handle: int64")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
    .Doc(
        R"(
Copies the input literal into the device memory pointed to by handle.
Returns the handle itself.

'handle' is the id returned from the Op that produced the on-device allocation.
'literal' is a serialized xla::LiteralProto proto to be written to device memory.
)");

REGISTER_OP("XRTReadLiteralAndRelease")
    .Input("handle: int64")
    .Output("literal: string")
@@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a,
  auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
  bool equal = l_a == l_b;
  if (!equal) {
-    LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+    LOG(INFO) << "LiteralProtos don't match: " << a.DebugString()
              << " != " << b.DebugString();
  }
  return equal;
@@ -215,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape(
      ->ComputeProgramShape();
}

TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  alloc.set_device_ordinal(0);
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 allocation_handle = outputs[1].scalar<int64>()();
  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
  outputs.clear();

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_EXPECT_OK(session.Run({write_op}, &outputs));
  EXPECT_EQ(outputs.size(), 1);
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
  outputs.clear();

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  auto release =
      ops::XRTReleaseAllocationHandle(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}

TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  alloc.set_device_ordinal(0);
@@ -183,6 +183,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
  return Status::OK();
}

Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   ToShapedBuffer());
}

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
@@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase {
  Status ToLiteral(xla::Backend* backend, int device_ordinal,
                   xla::Literal* literal);

  // Write a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

  // True if none of the buffers in the allocation are aliased by any other live
  // handle.
  bool IsExclusiveOwner();