Moved the XRT memory allocation logic into a new memory manager class, which is a ResourceBase hosted within the ResourceMgr.

Refactored the XRT memory compaction API into the memory manager.
Allowed the memory manager to swap out unpinned tuple allocations and to perform memory compaction under memory pressure.
Marked the clearly-const XRT state APIs as const.
Removed a ResourceMgr API which no longer has any users.

PiperOrigin-RevId: 249115186
Davide Libenzi 2019-05-20 13:34:45 -07:00 committed by TensorFlower Gardener
parent 273981699d
commit ce43d4d7d9
14 changed files with 1324 additions and 497 deletions
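Taken together, the change replaces per-tuple interning in the ResourceMgr with handle management in XRTMemoryManager. Below is a rough sketch of the new handle lifecycle, assuming the tensorflow namespace and the usual TF status macros; the helper names are invented for illustration and are not part of the change:

// Hypothetical helpers sketching the XRTMemoryManager handle lifecycle.
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_state.h"

Status AllocateAndRegister(ResourceMgr* rm, xla::Backend* backend,
                           int device_ordinal, const xla::Literal& literal,
                           int64* out_handle) {
  // One memory manager per ResourceMgr; Get() creates it on first use.
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  XRTTupleAllocation* allocation;
  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
      literal, memory_manager.get(), backend, device_ordinal, &allocation));
  // Register() takes over our reference and hands back an int64 handle.
  *out_handle = memory_manager->Register(allocation);
  return Status::OK();
}

Status LookupAndRelease(ResourceMgr* rm, int64 handle,
                        RefPtr<XRTTupleAllocation>* tuple) {
  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
  TF_RETURN_IF_ERROR(memory_manager->Lookup(handle, tuple));
  // Release() only drops the manager's reference; the RefPtr obtained above
  // keeps the allocation alive for the caller.
  return memory_manager->Release(handle);
}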

tensorflow/compiler/xrt/BUILD

@ -44,12 +44,15 @@ cc_library(
srcs = [
"xrt_compilation_cache.cc",
"xrt_device.cc",
"xrt_memory_manager.cc",
"xrt_state.cc",
"xrt_util.cc",
],
hdrs = [
"xrt_compilation_cache.h",
"xrt_device.h",
"xrt_memory_manager.h",
"xrt_refptr.h",
"xrt_state.h",
"xrt_util.h",
],

tensorflow/compiler/xrt/kernels/xrt_execute_op.cc

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -72,28 +73,31 @@ uint32 GetXLARandomSeed() {
}
xla::StatusOr<InputBuffers> GetInputBuffers(
ResourceMgr* rm, const std::vector<InputCoords>& input_coords,
bool release_inputs) {
XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
const std::vector<InputCoords>& input_coords, bool release_inputs) {
InputBuffers input_buffers;
input_buffers.input_tuples.reserve(input_coords.size());
input_buffers.input_allocations.reserve(input_coords.size());
input_buffers.input_pointers.reserve(input_coords.size());
for (size_t i = 0; i < input_coords.size(); ++i) {
XRTTupleAllocation* tuple;
TF_RETURN_IF_ERROR(
XRTTupleAllocation::Lookup(rm, input_coords[i].handle, &tuple));
working_set->LookupAndPin(backend, input_coords[i].handle));
auto tuple = working_set->PinnedTuples().back();
input_buffers.input_tuples.emplace_back(tuple);
if (release_inputs) {
// We are holding a reference to the tuple, so we can safely delete it
// from the resource manager here.
TF_RETURN_IF_ERROR(XRTTupleAllocation::DeleteFromResourceManager(
rm, input_coords[i].handle));
TF_RETURN_IF_ERROR(
working_set->MemoryManager()->Release(input_coords[i].handle));
VLOG(2) << "Released allocation handle " << input_coords[i].handle;
}
if (input_coords[i].index.empty()) {
input_buffers.input_allocations.emplace_back(tuple->ToShapedBuffer());
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
tuple->ToShapedBuffer());
input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
} else {
xla::ShapedBuffer shaped_buffer = tuple->ToShapedBuffer();
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
tuple->ToShapedBuffer());
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer,
shaped_buffer.SubShapedBuffer(input_coords[i].index));
input_buffers.input_allocations.emplace_back(
@ -107,28 +111,25 @@ xla::StatusOr<InputBuffers> GetInputBuffers(
}
xla::StatusOr<InputBuffers> GetChainedOpInputs(
const xrt::XRTChainedExecuteOp& op, int current_index,
absl::Span<const RefPtr<XRTTupleAllocation>> ops_outputs) {
const xrt::XRTChainedExecuteOp& op,
absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs) {
InputBuffers input_buffers;
input_buffers.input_tuples.reserve(op.inputs_size());
input_buffers.input_allocations.reserve(op.inputs_size());
input_buffers.input_pointers.reserve(op.inputs_size());
for (auto& input : op.inputs()) {
if (input.op_index() >= current_index) {
return errors::InvalidArgument(
"Input index ", input.op_index(),
" is above the current position: ", current_index);
}
input_buffers.input_tuples.emplace_back(ops_outputs[input.op_index()]);
for (int i = 0; i < op.inputs_size(); ++i) {
auto& input = op.inputs(i);
input_buffers.input_tuples.emplace_back(op_inputs[i]);
// Thanks to the greatness of proto3, there is no way to query for
// explicitly set fields, so the default for output_index (zero) means no
// sub-index. As a consequence, the real index is output_index - 1.
if (input.output_index() == 0) {
input_buffers.input_allocations.emplace_back(
input_buffers.input_tuples.back()->ToShapedBuffer());
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
input_buffers.input_tuples.back()->ToShapedBuffer());
input_buffers.input_allocations.emplace_back(std::move(shaped_buffer));
} else {
xla::ShapedBuffer shaped_buffer =
input_buffers.input_tuples.back()->ToShapedBuffer();
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer,
input_buffers.input_tuples.back()->ToShapedBuffer());
TF_ASSIGN_OR_RETURN(
xla::ShapedBuffer sub_shaped_buffer,
shaped_buffer.SubShapedBuffer({input.output_index() - 1}));
@ -142,7 +143,7 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
return std::move(input_buffers);
}
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed) {
@ -190,15 +191,35 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
}
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
OpKernelContext* context, ResourceMgr* rm,
OpKernelContext* context, XRTMemoryManager* memory_manager,
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed) {
auto runfn = [&]() {
return RunExecutable(context, device_ref, executable, input_buffers, stream,
rng_seed);
};
// We pass zero as requested_free_size as there is no simple way to get the
// peak heap size. When given zero, the Run() API will try to free chunks of
// device memory until either the runfn can run, or we run out of freeable
// memory.
return memory_manager->Run<RefPtr<XRTTupleAllocation>>(
runfn, device_ref->backend(), device_ref->device_ordinal(),
/*requested_free_size=*/0);
}
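As a side note on the Run() contract described in the comment above, the same retry pattern can wrap any allocation-heavy callable. The sketch below is illustrative only (AllocateWithReclaim is a made-up helper), this time passing a non-zero requested_free_size hint:

xla::StatusOr<se::OwningDeviceMemory> AllocateWithReclaim(
    const RefPtr<XRTMemoryManager>& memory_manager, xla::Backend* backend,
    int device_ordinal, size_t size) {
  auto allocfn = [&]() -> xla::StatusOr<se::OwningDeviceMemory> {
    // Fail fast so Run() can react to RESOURCE_EXHAUSTED and reclaim memory.
    return backend->memory_allocator()->Allocate(device_ordinal, size,
                                                 /*retry_on_failure=*/false);
  };
  // Passing `size` as the hint lets the manager stop swapping out unpinned
  // tuples once roughly that much device memory has been freed.
  return memory_manager->Run<se::OwningDeviceMemory>(
      allocfn, backend, device_ordinal, /*requested_free_size=*/size);
}

A similar reclaim-and-retry path is implemented directly in XRTMemoryManager::Allocate() (see xrt_memory_manager.cc below), which is what AllocateScopedShapedBuffer() in xrt_state.cc now calls into.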
xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
OpKernelContext* context, const RefPtr<XRTMemoryManager>& memory_manager,
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable,
const std::vector<InputCoords>& input_coords, bool release_inputs,
se::Stream* stream, int rng_seed) {
XRTMemoryManager::WorkingSet working_set(memory_manager);
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
GetInputBuffers(rm, input_coords, release_inputs));
return ExecuteComputation(context, device_ref, executable, input_buffers,
stream, rng_seed);
GetInputBuffers(&working_set, device_ref->backend(),
input_coords, release_inputs));
return ExecuteComputation(context, memory_manager.get(), device_ref,
executable, input_buffers, stream, rng_seed);
}
// XRTExecuteOp
@ -265,8 +286,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
se::Stream* stream = context->op_device_context()
? context->op_device_context()->stream()
: nullptr;
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
GetComputationInputs(context, rm, "input_handles"));
GetComputationInputs(context, "input_handles"));
std::unique_ptr<XRTCompilationCacheEntryRef> entry;
TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));
@ -279,10 +301,11 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
TF_ASSIGN_OR_RETURN(
RefPtr<XRTTupleAllocation> output_tuple,
ExecuteComputation(context, rm, &device_ref, executable, input_coords,
release_inputs, stream, rng_seed));
ExecuteComputation(context, memory_manager, &device_ref, executable,
input_coords, release_inputs, stream, rng_seed));
return CreateExecuteOutput(context, rm, std::move(output_tuple),
return CreateExecuteOutput(context, memory_manager.get(),
std::move(output_tuple),
config_proto.return_exploded_tuple());
}
@ -346,22 +369,23 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
se::Stream* stream = context->op_device_context()
? context->op_device_context()->stream()
: nullptr;
auto execute_op =
[&](const xrt::XRTChainedExecuteOp& op, int current_index,
absl::Span<const RefPtr<XRTTupleAllocation>> ops_outputs)
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
-> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
GetChainedOpInputs(op, current_index, ops_outputs));
GetChainedOpInputs(op, op_inputs));
std::unique_ptr<XRTCompilationCacheEntryRef> entry;
TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry));
xla::LocalExecutable* executable = entry->get().get_executable();
return ExecuteComputation(context, &device_ref, executable, input_buffers,
stream, rng_seed);
return ExecuteComputation(context, memory_manager.get(), &device_ref,
executable, input_buffers, stream, rng_seed);
};
return ExecuteChained(context, rm, plan, config, execute_op);
return ExecuteChained(context, memory_manager, device_ref.backend(),
device_ref.device_ordinal(), plan, config, execute_op);
}
XRTExecuteChainedOp::~XRTExecuteChainedOp() = default;

tensorflow/compiler/xrt/kernels/xrt_state_ops.h

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -103,8 +104,8 @@ class XRTStateHelpers {
TF_RET_CHECK(
TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
int64 key = input_tensor_list[input_index].scalar<int64>()();
TF_RETURN_IF_ERROR(
XRTTupleAllocation::Lookup(rm, key, &input.allocation));
TF_ASSIGN_OR_RETURN(input.allocation,
XRTMemoryManager::Get(rm)->Lookup(key));
input.release_allocation_after_use = release_this_input;
}
}
@ -192,17 +193,14 @@ class XRTAllocateOp : public OpKernel {
class DeviceAccessor::ScopedRef device_ref;
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
literal, device_ref.backend(),
literal, memory_manager.get(), device_ref.backend(),
device_ref.device_ordinal(), &allocation));
// Intern takes ownership of our reference to allocation.
int64 key;
OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
Tensor output(DT_INT64, TensorShape({}));
output.scalar<int64>()() = key;
output.scalar<int64>()() = memory_manager->Register(allocation);
ctx->set_output(0, output);
}
};
@ -291,17 +289,14 @@ class XRTAllocateFromTensorOp : public OpKernel {
class DeviceAccessor::ScopedRef device_ref;
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
literal, device_ref.backend(),
literal, memory_manager.get(), device_ref.backend(),
device_ref.device_ordinal(), &allocation));
// Intern takes ownership of our reference to allocation.
int64 key;
OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
Tensor output(DT_INT64, TensorShape({}));
output.scalar<int64>()() = key;
output.scalar<int64>()() = memory_manager->Register(allocation);
ctx->set_output(0, output);
}
@ -342,28 +337,22 @@ class XRTSubTupleOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
core::ScopedUnref allocation_unref(allocation);
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
RefPtr<XRTTupleAllocation> allocation;
OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
if (discard_) {
VLOG(2) << "Releasing handle " << allocation_handle;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
rm, allocation_handle));
OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
}
XRTTupleAllocation* suballocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::MakeSubBuffer(allocation, shape_index,
ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
&suballocation, !discard_));
// Intern takes ownership of our reference to suballocation.
int64 key;
OP_REQUIRES_OK(ctx, suballocation->Intern(rm, &key));
Tensor output(DT_INT64, TensorShape({}));
output.scalar<int64>()() = key;
output.scalar<int64>()() = memory_manager->Register(suballocation);
ctx->set_output(0, output);
}
};
@ -398,14 +387,6 @@ class XRTMakeTupleOp : public OpKernel {
// exit.
std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
arg_list.size());
auto cleanup = gtl::MakeCleanup([&input_vector] {
for (auto& input : input_vector) {
if (input.allocation != nullptr) {
input.allocation->Unref();
}
}
});
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
@ -425,28 +406,22 @@ class XRTMakeTupleOp : public OpKernel {
OP_REQUIRES_OK(
ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
XRTTupleAllocation* output_allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
device_ref.backend(), device_ref.device_ordinal(),
tuple_shape_tree, &output_allocation));
// Add a ScopedUnref to simplify the error path while calling
// DeleteFromResourceManager.
core::ScopedUnref unref(output_allocation);
memory_manager.get(), device_ref.backend(),
device_ref.device_ordinal(), tuple_shape_tree,
&output_allocation));
RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
for (int i = 0; i < input_vector.size(); ++i) {
if (input_vector[i].release_allocation_after_use) {
OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
rm, arg_list[i].scalar<int64>()()));
OP_REQUIRES_OK(ctx,
memory_manager->Release(arg_list[i].scalar<int64>()()));
}
}
// Intern takes ownership of a reference to output_allocation, so add
// another since the ScopedUnref will release one when this method exits.
output_allocation->Ref();
int64 key;
OP_REQUIRES_OK(ctx, output_allocation->Intern(rm, &key));
Tensor output(DT_INT64, TensorShape({}));
output.scalar<int64>()() = key;
output.scalar<int64>()() = memory_manager->Register(std::move(output_ptr));
ctx->set_output(0, output);
}
};
@ -473,15 +448,13 @@ class XRTReadLiteralOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
core::ScopedUnref allocation_unref(allocation);
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
RefPtr<XRTTupleAllocation> allocation;
OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
if (discard_) {
VLOG(2) << "Releasing handle " << allocation_handle;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
rm, allocation_handle));
OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
}
// We are guaranteed that the underlying device object won't be deleted out
@ -491,9 +464,7 @@ class XRTReadLiteralOp : public OpKernel {
ctx, allocation->device_ordinal(), &device_ref));
xla::Literal literal(allocation->on_host_shape());
OP_REQUIRES_OK(
ctx, allocation->ToLiteral(device_ref.backend(),
device_ref.device_ordinal(), &literal));
OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
xla::LiteralProto literal_proto = literal.ToProto();
Tensor output(DT_STRING, TensorShape({}));
@ -529,15 +500,13 @@ class XRTReadToTensorOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
core::ScopedUnref allocation_unref(allocation);
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
RefPtr<XRTTupleAllocation> allocation;
OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
if (discard_) {
VLOG(2) << "Releasing handle " << allocation_handle;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
rm, allocation_handle));
OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
}
// We are guaranteed that the underlying device object won't be deleted out
@ -573,15 +542,14 @@ class XRTReadToTensorOp : public OpKernel {
XRTTupleAllocation* sub;
TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
allocation, index, &sub, /*alias_parent_allocation=*/true));
allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
core::ScopedUnref sub_unref(sub);
xla::MutableBorrowingLiteral literal;
TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
&literal));
TF_RETURN_IF_ERROR(sub->ToLiteral(
device_ref.backend(), device_ref.device_ordinal(), &literal));
TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));
++output;
return Status::OK();
@ -624,10 +592,10 @@ class XRTWriteLiteralOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(
ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
core::ScopedUnref allocation_unref(allocation);
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
RefPtr<XRTTupleAllocation> allocation;
OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
typename DeviceAccessor::ScopedRef device_ref;
@ -657,12 +625,12 @@ class XRTReleaseAllocationOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
const Tensor& allocation_handle = ctx->input(0);
auto flat_keys = allocation_handle.flat<int64>();
for (int64 i = 0; i < flat_keys.size(); ++i) {
int64 key = flat_keys(i);
OP_REQUIRES_OK(ctx,
XRTTupleAllocation::DeleteFromResourceManager(rm, key));
OP_REQUIRES_OK(ctx, memory_manager->Release(key));
VLOG(2) << "Released allocation handle " << key;
}
}
@ -684,7 +652,7 @@ class XRTReleaseAllAllocationsOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm));
XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
}
};
@ -701,11 +669,11 @@ class XRTCompactAllocationsOp : public OpKernel {
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
class DeviceAccessor::ScopedRef device_ref;
OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
OP_REQUIRES_OK(ctx,
XRTTupleAllocation::CompactAllocations(
rm, device_ref.backend(), device_ref.device_ordinal()));
OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
device_ref.backend(), device_ref.device_ordinal()));
}
};

tensorflow/compiler/xrt/tests/BUILD

@ -34,6 +34,8 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xrt:xrt_proto",
"//tensorflow/compiler/xrt:xrt_server",

tensorflow/compiler/xrt/tests/raw_api_test.cc

@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -235,6 +237,26 @@ xla::XlaComputation AddAndSubTuple() {
return builder.Build().ValueOrDie();
}
xla::XlaComputation BroadcastComputation(
const xla::Shape& shape, absl::Span<const xla::int64> dimensions) {
xla::XlaBuilder builder("BroadcastComputation");
auto p0 = xla::Parameter(&builder, 0, shape, "P0");
xla::Broadcast(p0, dimensions);
return builder.Build().ValueOrDie();
}
xla::XlaComputation IsEqualComputation(const xla::Shape& shape) {
xla::XlaBuilder builder("IsEqualComputation");
auto p0 = xla::Parameter(&builder, 0, shape, "P0");
auto p1 = xla::Parameter(&builder, 1, shape, "P1");
auto cmp =
xla::Ne(xla::Sub(p0, p1), xla::Zero(&builder, shape.element_type()));
auto icmp = xla::ConvertElementType(cmp, xla::S32);
xla::ReduceAll(icmp, xla::Zero(&builder, xla::S32),
xla::CreateScalarAddComputation(xla::S32, &builder));
return builder.Build().ValueOrDie();
}
void StoreComputationSnapshot(const xla::XlaComputation& computation,
xla::HloSnapshot* dst) {
auto snapshot = computation.Snapshot().ValueOrDie();
@ -1488,6 +1510,95 @@ TEST(RawApiTest, TestDeviceMemoryCompaction) {
}
}
TEST(RawApiTest, TestDeviceMemorySwap) {
const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
// 100MB F32 tensor.
const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000});
const xla::int64 tensor_size = xla::ShapeUtil::ByteSizeOf(shape);
// On CPU we cannot trigger OOM/swap. For TPU and GPU we select 16GB as
// maximum memory.
xla::int64 device_memory_size = 8LL * 1024 * 1024 * 1024;
if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") {
device_memory_size = 16LL * 1024 * 1024 * 1024;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(0.90434).ToProto();
// Create a computation which broadcasts a scalar to a big tensor.
xrt::XLAComputation c_bcast;
{
auto shapes = c_bcast.mutable_config()->mutable_program_shape();
*shapes->add_parameters() = scalar_shape.ToProto();
*shapes->mutable_result() = shape.ToProto();
StoreComputationSnapshot(
BroadcastComputation(scalar_shape, shape.dimensions()),
c_bcast.mutable_hlo_snapshot());
}
// Create a computation which compares two tensors.
xrt::XLAComputation c_equal;
{
auto shapes = c_equal.mutable_config()->mutable_program_shape();
*shapes->add_parameters() = shape.ToProto();
*shapes->add_parameters() = shape.ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
StoreComputationSnapshot(IsEqualComputation(shape),
c_equal.mutable_hlo_snapshot());
}
xrt::XRTExecutionConfig e;
e.set_release_input_handles(false);
e.set_release_compilation_handle(false);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
ClientSession session(root);
auto e_config =
ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
auto bcast_computation =
ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString());
auto c_bcast_handle = ops::XRTCompile(root, bcast_computation);
auto equal_computation =
ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString());
auto c_equal_handle = ops::XRTCompile(root, equal_computation);
auto p0_value =
ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
std::vector<Tensor> outputs;
std::vector<xla::int64> device_handles;
// Create more data than the device can take, using the broadcast computation.
xla::int64 num_tensors = 8 + device_memory_size / tensor_size;
for (xla::int64 i = 0; i < num_tensors; ++i) {
auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config,
{Output(p0_handle)});
TF_ASSERT_OK(root.status());
TF_EXPECT_OK(session.Run({result}, &outputs));
EXPECT_EQ(outputs.size(), 1);
device_handles.push_back(outputs[0].scalar<int64>()());
}
// Trigger computations on the XRT handles to verify the swap-out/swap-in
// logic, by comparing consecutive pairs of tensors.
auto zero_literal = xla::LiteralUtil::CreateR0<xla::int32>(0);
for (size_t i = 0; i + 1 < device_handles.size(); ++i) {
auto exec_op = ops::XRTExecute(
root, c_equal_handle.handle, e_config,
{Input(device_handles[i]), Input(device_handles[i + 1])});
auto read_back = ops::XRTReadLiteral(root, exec_op);
TF_ASSERT_OK(root.status());
TF_EXPECT_OK(session.Run({read_back}, &outputs));
EXPECT_EQ(outputs.size(), 1);
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
EXPECT_EQ(literal, zero_literal);
}
}
} // namespace
} // namespace tensorflow

tensorflow/compiler/xrt/xrt_memory_manager.cc

@ -0,0 +1,353 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include <algorithm>
#include <list>
#include <unordered_map>
#include "absl/memory/memory.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace {
// We use kDeviceBits to store the device ordinal in the handle. We store the
// device in the upper part of the int64 handle to make sure the random bits are
// in the lower part which is better when storing the handle as a key for
// unordered maps.
const int kDeviceBits = 12;
int64 MakeDeviceHandle(int64 device_ordinal, int64 rnd_value) {
const int64 kUidMask = (static_cast<int64>(1) << (64 - kDeviceBits)) - 1;
return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}
int GetDeviceFromHandle(int64 handle) {
return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}
} // namespace
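As a quick standalone illustration of the handle layout produced by the helpers above (device ordinal in the top 12 bits, random uid in the low 52 bits), duplicated here outside of TensorFlow so it compiles on its own:

#include <cassert>
#include <cstdint>

constexpr int kDeviceBits = 12;

int64_t MakeDeviceHandle(int64_t device_ordinal, int64_t rnd_value) {
  const int64_t kUidMask = (int64_t{1} << (64 - kDeviceBits)) - 1;
  return (device_ordinal << (64 - kDeviceBits)) | (rnd_value & kUidMask);
}

int GetDeviceFromHandle(int64_t handle) {
  return (handle >> (64 - kDeviceBits)) & ((1 << kDeviceBits) - 1);
}

int main() {
  // Device ordinal 5 with an arbitrary random payload.
  const int64_t rnd = 0x0123456789abcdef;
  const int64_t handle = MakeDeviceHandle(5, rnd);
  // The device ordinal rides in the upper 12 bits...
  assert(GetDeviceFromHandle(handle) == 5);
  // ...and the low 52 bits keep the random uid, so handles hash well as
  // unordered_map keys.
  const int64_t kUidMask = (int64_t{1} << 52) - 1;
  assert((handle & kUidMask) == (rnd & kUidMask));
  return 0;
}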
class XRTMemoryManager::DeviceContext {
struct Alloc {
explicit Alloc(RefPtr<XRTTupleAllocation> tuple)
: tuple(std::move(tuple)) {}
RefPtr<XRTTupleAllocation> tuple;
};
using AllocList = std::list<Alloc>;
public:
int64 Register(RefPtr<XRTTupleAllocation> tuple) {
while (true) {
int64 handle = MakeDeviceHandle(tuple->device_ordinal(), CreateUid());
mutex_lock lock(lock_);
allocs_.emplace_front(tuple);
if (alloc_map_.emplace(handle, allocs_.begin()).second) {
return handle;
}
// The chances of hitting an existing handle are so remote that it is more
// convenient to add to the list up front and remove the entry again in the
// (rare) case of a collision.
allocs_.erase(allocs_.begin());
}
}
bool Release(int64 handle) {
mutex_lock lock(lock_);
auto it = alloc_map_.find(handle);
if (it == alloc_map_.end()) {
return false;
}
allocs_.erase(it->second);
alloc_map_.erase(it);
return true;
}
RefPtr<XRTTupleAllocation> Lookup(int64 handle) {
mutex_lock lock(lock_);
auto it = alloc_map_.find(handle);
if (it == alloc_map_.end()) {
return nullptr;
}
// LRU
allocs_.splice(allocs_.begin(), allocs_, it->second);
return it->second->tuple;
}
void Clear() {
mutex_lock lock(lock_);
alloc_map_.clear();
allocs_.clear();
}
Status CompactAllocations(XRTMemoryManager* memory_manager,
xla::Backend* backend) {
VLOG(4) << "CompactAllocations started";
mutex_lock lock(lock_);
Status status;
std::vector<AllocList::iterator> swapped;
// We swap out starting from the most recently used allocations. This is
// desirable since the most recently used will be the ones sitting at the
// bottom of the allocation space. Since these are also the ones more likely
// to be pinned, a further trim done by a following TryFreeMemory() call will
// eventually drop the allocations located higher up, with a better chance of
// reducing fragmentation.
// Also, by swapping out the pinned allocations first, they will also be the
// first to be restored, and hence if we ever hit OOM on the way back in, we
// are more likely to be swapping in unpinned ones.
for (auto it = allocs_.begin(); it != allocs_.end(); ++it) {
// We are compacting all the allocations, so we will temporarily swap out
// even pinned allocations.
auto swap_result_or = it->tuple->SwapOut(backend, /*swap_pinned=*/true);
if (!swap_result_or.ok()) {
status = swap_result_or.status();
break;
}
if (swap_result_or.ValueOrDie()) {
swapped.push_back(it);
}
}
// At this point we have released all the device memory we could release.
// Load back the tuple allocations we have swapped out above.
for (auto& it : swapped) {
auto swap_result_or = it->tuple->SwapIn(memory_manager, backend);
if (!swap_result_or.ok()) {
// If we failed to restore a pinned allocation, it is better to CHECK here
// than to wonder later why XRTTupleAllocation calls fail with errors about
// missing buffers.
CHECK(!it->tuple->IsPinned()); // Crash OK
if (status.ok()) {
status = swap_result_or.status();
}
}
}
VLOG(4) << "CompactAllocations finished: " << status;
return status;
}
// Tries to free size bytes by freeing some unpinned device memory. Returns
// the amount of memory it was able to free.
xla::StatusOr<size_t> TryFreeMemory(xla::Backend* backend, size_t size) {
mutex_lock lock(lock_);
size_t swapped_size = 0;
for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
TF_ASSIGN_OR_RETURN(bool swap_result,
it->tuple->SwapOut(backend, /*swap_pinned=*/false));
if (swap_result) {
swapped_size += it->tuple->GetDeviceMemorySize();
if (swapped_size >= size) {
break;
}
}
}
VLOG(3) << "Swapped out " << swapped_size << " bytes";
return swapped_size;
}
private:
static int64 CreateUid() {
int64 uid;
do {
uid = random::New64() & INT64_MAX;
} while (uid == InvalidKey());
return uid;
}
// We store Alloc records inside an std::list<Alloc> so we can LRU it, and
// store the list iterators within the handle map, as list iterators don't get
// invalidated by (other elements) removals or position swaps.
mutex lock_;
AllocList allocs_;
std::unordered_map<int64, AllocList::iterator> alloc_map_;
};
XRTMemoryManager::WorkingSet::WorkingSet(
RefPtr<XRTMemoryManager> memory_manager)
: memory_manager_(std::move(memory_manager)) {}
XRTMemoryManager::WorkingSet::~WorkingSet() {
for (auto& tuple : pinned_tuples_) {
tuple->Unpin();
}
}
Status XRTMemoryManager::WorkingSet::LookupAndPin(xla::Backend* backend,
int64 handle) {
TF_ASSIGN_OR_RETURN(auto tuple, memory_manager_->Lookup(handle));
TF_RETURN_IF_ERROR(
tuple->PinAndSwapIn(memory_manager_.get(), backend).status());
pinned_tuples_.push_back(std::move(tuple));
return Status::OK();
}
/* static */ RefPtr<XRTMemoryManager> XRTMemoryManager::Get(ResourceMgr* rm) {
static string* container = new string("XrtState");
static string* name = new string("MemoryManager");
XRTMemoryManager* memory_manager = nullptr;
TF_CHECK_OK(rm->LookupOrCreate<XRTMemoryManager>(
*container, *name, &memory_manager, [](XRTMemoryManager** ret) {
*ret = new XRTMemoryManager();
return Status::OK();
}));
return memory_manager;
}
int64 XRTMemoryManager::Register(RefPtr<XRTTupleAllocation> tuple) {
DeviceContext* device_context = GetDeviceContext(tuple->device_ordinal(),
/*create_if_missing=*/true);
return device_context->Register(std::move(tuple));
}
xla::StatusOr<RefPtr<XRTTupleAllocation>> XRTMemoryManager::Lookup(
int64 handle) {
int device_ordinal = GetDeviceFromHandle(handle);
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
RefPtr<XRTTupleAllocation> tuple = device_context->Lookup(handle);
if (tuple == nullptr) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
return std::move(tuple);
}
Status XRTMemoryManager::Release(int64 handle) {
int device_ordinal = GetDeviceFromHandle(handle);
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr || !device_context->Release(handle)) {
return errors::NotFound("XRT memory handle not found: ", handle);
}
return Status::OK();
}
Status XRTMemoryManager::CompactAllocations(xla::Backend* backend,
int device_ordinal) {
DeviceContext* device_context = GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
return device_context != nullptr
? device_context->CompactAllocations(this, backend)
: Status::OK();
}
void XRTMemoryManager::ReleaseAllAllocations() {
mutex_lock lock(lock_);
for (auto& device_context : device_contexts_) {
if (device_context != nullptr) {
device_context->Clear();
}
}
}
xla::StatusOr<se::OwningDeviceMemory> XRTMemoryManager::Allocate(
xla::Backend* backend, int device_ordinal, size_t size) {
se::DeviceMemoryAllocator* allocator = backend->memory_allocator();
auto memory_or =
allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false);
if (memory_or.status().code() == error::RESOURCE_EXHAUSTED) {
VLOG(4) << "Allocate of " << size << " bytes failed on device "
<< device_ordinal;
DeviceContext* device_context =
GetDeviceContext(device_ordinal,
/*create_if_missing=*/false);
if (device_context != nullptr) {
Status status = device_context->TryFreeMemory(backend, size).status();
if (status.ok()) {
// As long as there is no error, we retry the allocation even if the
// TryFreeMemory() call ended up freeing less memory than the required
// size, since reduced fragmentation could make the allocation succeed
// anyway.
memory_or = allocator->Allocate(device_ordinal, size,
/*retry_on_failure=*/false);
} else if (status.code() != error::RESOURCE_EXHAUSTED) {
VLOG(4) << "Allocate of " << size << " bytes on device "
<< device_ordinal << ": " << status;
return status;
}
}
}
return memory_or;
}
string XRTMemoryManager::DebugString() const {
// We might want to emit more detailed information here, like per device
// memory allocations.
return "XRTMemoryManager";
}
XRTMemoryManager::DeviceContext* XRTMemoryManager::GetDeviceContext(
int device_ordinal, bool create_if_missing) {
mutex_lock lock(lock_);
if (device_ordinal >= device_contexts_.size()) {
if (!create_if_missing) {
return nullptr;
}
device_contexts_.resize(device_ordinal + 1);
}
DeviceContext* device_context = device_contexts_[device_ordinal].get();
if (device_context == nullptr && create_if_missing) {
device_contexts_[device_ordinal] = absl::make_unique<DeviceContext>();
device_context = device_contexts_[device_ordinal].get();
}
return device_context;
}
Status XRTMemoryManager::TryFreeMemoryStep(MemoryReclaimContext* mrctx,
const Status& status) {
DeviceContext* device_context = GetDeviceContext(mrctx->device_ordinal,
/*create_if_missing=*/false);
if (device_context == nullptr) {
return status;
}
if (!mrctx->done_freeing) {
// If the caller passed us a zero requested_free_size, we try to free chunks
// of kMaxFreeSize memory, until either the run function succeeds, or we run
// out of freeable memory.
const size_t kMaxFreeSize = 1000000000;
size_t free_size =
(mrctx->requested_free_size > 0)
? std::min<size_t>(mrctx->requested_free_size - mrctx->free_size,
kMaxFreeSize)
: kMaxFreeSize;
if (free_size > 0) {
auto free_size_or =
device_context->TryFreeMemory(mrctx->backend, free_size);
if (!free_size_or.ok()) {
return status;
}
size_t size = free_size_or.ValueOrDie();
mrctx->free_size += size;
if (size > 0) {
return Status::OK();
}
}
mrctx->done_freeing = true;
}
if (!mrctx->done_compacting) {
mrctx->done_compacting = true;
if (device_context->CompactAllocations(this, mrctx->backend).ok()) {
return Status::OK();
}
}
return status;
}
} // namespace tensorflow

tensorflow/compiler/xrt/xrt_memory_manager.h

@ -0,0 +1,177 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#define TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace tensorflow {
// The XRTMemoryManager manages all the XRT allocations. It is a ResourceBase
// object which lives within the ResourceMgr. There is only one XRT memory
// manager object per ResourceMgr container.
class XRTMemoryManager : public ResourceBase {
// The DeviceContext class, defined and implemented locally inside the
// xrt_memory_manager.cc file, holds, for each device, all the information
// related to the XRT memory management for that device.
class DeviceContext;
public:
// A working set is a set of tuple allocations which are the inputs of a given
// operation, and as such must be pinned in device memory. The tuple
// allocations added to the WorkingSet are unpinned when the object is destroyed.
class WorkingSet {
public:
explicit WorkingSet(RefPtr<XRTMemoryManager> memory_manager);
~WorkingSet();
// Looks up the tuple handle within the memory manager, and pins it to the
// device (if not already pinned).
Status LookupAndPin(xla::Backend* backend, int64 handle);
const std::vector<RefPtr<XRTTupleAllocation>>& PinnedTuples() const {
return pinned_tuples_;
}
const RefPtr<XRTMemoryManager>& MemoryManager() const {
return memory_manager_;
}
private:
RefPtr<XRTMemoryManager> memory_manager_;
std::vector<RefPtr<XRTTupleAllocation>> pinned_tuples_;
};
// Retrieves the XRTMemoryManager singleton stored within the ResourceMgr.
static RefPtr<XRTMemoryManager> Get(ResourceMgr* rm);
// Registers an XRTTupleAllocation and returns the unique handle identifying
// it.
int64 Register(RefPtr<XRTTupleAllocation> tuple);
// Looks up a handle returned by the Register() API and returns the
// XRTTupleAllocation behind it.
xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(int64 handle);
Status Lookup(int64 handle, RefPtr<XRTTupleAllocation>* tuple) {
TF_ASSIGN_OR_RETURN(*tuple, Lookup(handle));
return Status::OK();
}
// Releases a handle by dropping the reference count held on the
// XRTTupleAllocation by the XRTMemoryManager. Existing XRTTupleAllocation
// references will continue to be valid.
Status Release(int64 handle);
// Tries to compact all the memory allocations on a given device. This is
// currently done by swapping out all the existing allocations, and swapping
// them back in.
Status CompactAllocations(xla::Backend* backend, int device_ordinal);
// Releases all the device memory allocated by XRT within the resource
// manager.
void ReleaseAllAllocations();
// Tries to allocate size bytes of device memory from the device_ordinal
// device. Might attempt to free some unpinned device memory if the underlying
// allocator call fails, and then retry the allocation.
xla::StatusOr<se::OwningDeviceMemory> Allocate(xla::Backend* backend,
int device_ordinal,
size_t size);
// Runs the specified function, handling any error::RESOURCE_EXHAUSTED
// status code coming out of it. In such cases, we run different memory
// freeing operations trying to make runfn succeed. The requested_free_size
// argument is a hint of the amount of memory which would make
// runfn succeed.
template <typename T>
xla::StatusOr<T> Run(const std::function<xla::StatusOr<T>()>& runfn,
xla::Backend* backend, int device_ordinal,
size_t requested_free_size);
string DebugString() const override;
// Returns the invalid key value, which will never be generated by the
// Register() API.
static int64 InvalidKey() { return 0; }
private:
// Structure used to track the progress of a try-to-free operation. It is
// initialized and then passed to the TryFreeMemoryStep() API.
struct MemoryReclaimContext {
MemoryReclaimContext(xla::Backend* backend, int device_ordinal,
size_t requested_free_size)
: backend(backend),
device_ordinal(device_ordinal),
requested_free_size(requested_free_size) {}
xla::Backend* const backend = nullptr;
const int device_ordinal = 0;
const size_t requested_free_size = 0;
size_t free_size = 0;
bool done_freeing = false;
bool done_compacting = false;
};
DeviceContext* GetDeviceContext(int device_ordinal, bool create_if_missing);
// Called multiple times while trying to make a memory consuming function call
// to fit. Performs progressively more expensive memory reduction operations,
// until returning error::RESOURCE_EXHAUSTED when no further reductions are
// possible.
Status TryFreeMemoryStep(MemoryReclaimContext* mrctx, const Status& status);
mutex lock_;
std::vector<std::unique_ptr<DeviceContext>> device_contexts_;
};
template <typename T>
xla::StatusOr<T> XRTMemoryManager::Run(
const std::function<xla::StatusOr<T>()>& runfn, xla::Backend* backend,
int device_ordinal, size_t requested_free_size) {
MemoryReclaimContext mrctx(backend, device_ordinal, requested_free_size);
while (true) {
// We assume that runfn is a relatively fast-fail function compared to the
// operations required to free up the required memory. Here we call into the
// TryFreeMemoryStep() API multiple times, which will run progressively more
// expensive operations.
auto result_or = runfn();
if (result_or.status().code() != error::RESOURCE_EXHAUSTED) {
return result_or;
}
TF_RETURN_IF_ERROR(TryFreeMemoryStep(&mrctx, result_or.status()));
}
}
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_XRT_XRT_MEMORY_MANAGER_H_
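A short usage sketch for the WorkingSet class above, mirroring how GetInputBuffers() in xrt_execute_op.cc builds its input set; RunWithPinnedInputs is a hypothetical helper and the per-input work is elided:

Status RunWithPinnedInputs(const RefPtr<XRTMemoryManager>& memory_manager,
                           xla::Backend* backend,
                           const std::vector<int64>& handles) {
  // The WorkingSet unpins everything it pinned when it goes out of scope,
  // so it must stay alive for as long as the pinned tuples are in use.
  XRTMemoryManager::WorkingSet working_set(memory_manager);
  for (int64 handle : handles) {
    // LookupAndPin() swaps the allocation back in if it had been swapped
    // out, and pins it so compaction/swap-out leaves it alone meanwhile.
    TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, handle));
  }
  for (const RefPtr<XRTTupleAllocation>& tuple : working_set.PinnedTuples()) {
    (void)tuple;  // Stand-in for the real per-input work, e.g. ToShapedBuffer().
  }
  return Status::OK();
}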

tensorflow/compiler/xrt/xrt_refptr.h

@ -0,0 +1,108 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility functions in support of the XRT API.
#ifndef TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_
#define TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_
#include <cstddef>
namespace tensorflow {
// Reference counted smart pointer for XRT objects providing the standard
// Ref()/Unref() APIs.
template <typename T>
class RefPtr {
public:
RefPtr() = default;
// Creates a RefPtr from a pointer. This is an ownership transfer operation,
// and the caller has to own a valid reference to ptr (unless ptr is nullptr).
RefPtr(T* ptr) : ptr_(ptr) {} // NOLINT
RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); }
RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; }
~RefPtr() { Release(ptr_); }
RefPtr& operator=(const RefPtr& other) {
if (this != &other) {
Acquire(other.ptr_);
Release(ptr_);
ptr_ = other.ptr_;
}
return *this;
}
RefPtr& operator=(RefPtr&& other) {
if (this != &other) {
Release(ptr_);
ptr_ = other.ptr_;
other.ptr_ = nullptr;
}
return *this;
}
operator bool() const { return ptr_ != nullptr; } // NOLINT
bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; }
bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; }
bool operator==(const T* ptr) const { return ptr_ == ptr; }
bool operator!=(const T* ptr) const { return ptr_ != ptr; }
bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; }
bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; }
T* get() const { return ptr_; }
T* operator->() const {
CHECK(ptr_ != nullptr); // Crash OK
return ptr_;
}
T& operator*() const {
CHECK(ptr_ != nullptr); // Crash OK
return *ptr_;
}
T* release() {
T* ptr = ptr_;
ptr_ = nullptr;
return ptr;
}
// Resets the RefPtr from a pointer. This is an ownership transfer operation,
// and the caller has to own a valid reference to ptr (unless ptr is nullptr).
void reset(T* ptr = nullptr) {
Release(ptr_);
ptr_ = ptr;
}
private:
static void Release(T* ptr) {
if (ptr != nullptr) {
ptr->Unref();
}
}
static void Acquire(T* ptr) {
if (ptr != nullptr) {
ptr->Ref();
}
}
T* ptr_ = nullptr;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_XRT_XRT_REFPTR_H_
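A brief sketch of the ownership rules documented above, using XRTTupleAllocation as the pointee; any type exposing Ref()/Unref() behaves the same way:

// Assumes the caller hands us one owned reference to `raw`.
void RefPtrOwnershipExample(XRTTupleAllocation* raw) {
  RefPtr<XRTTupleAllocation> owned(raw);     // takes over the caller's reference
  RefPtr<XRTTupleAllocation> alias = owned;  // Ref(): two references now
  XRTTupleAllocation* again = alias.release();  // hands one reference back out
  RefPtr<XRTTupleAllocation> reowned;
  reowned.reset(again);                      // ...and takes it over again
}  // both live RefPtrs Unref() here; the object is freed if these were the last refs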

tensorflow/compiler/xrt/xrt_state.cc

@ -18,31 +18,24 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt_state.h"
#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
namespace tensorflow {
namespace {
// Helper typedef to make ShapeTree ForEach helper lambda signatures more
// readable. They need a type of const T& where in this case T is the
// following pointer.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;
class BufferAllocStats {
public:
struct Stats {
@ -71,26 +64,15 @@ class BufferAllocStats {
std::map<int64, Stats> stats_;
};
const char* kTupleContainer = "tuples";
int64 get_uid() {
int64 uid;
do {
uid = random::New64() & INT64_MAX;
} while (uid == XRTTupleAllocation::InvalidKey());
return uid;
}
BufferAllocStats* GetAllocStats() {
static BufferAllocStats* stats = new BufferAllocStats();
return stats;
}
Status AllocateScopedShapedBuffer(
xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
const xla::Shape& shape, std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
auto transfer_manager = backend->transfer_manager();
auto allocator = backend->memory_allocator();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
// XLA may use a different representation on device than the representation on
@ -111,14 +93,14 @@ Status AllocateScopedShapedBuffer(
// it goes out of scope. That's useful if we return early as the result of an
// error allocating one of the later buffers.
*buffer = absl::make_unique<xla::ScopedShapedBuffer>(
shape, on_device_shape, allocator, device_ordinal);
shape, on_device_shape, backend->memory_allocator(), device_ordinal);
for (auto& index_to_buffer : (*buffer)->buffers()) {
xla::Shape subshape =
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory buffer,
allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
memory_manager->Allocate(backend, device_ordinal, size));
// Move our buffer into shaped_buffer, which takes ownership of it.
index_to_buffer.second = buffer.Release();
VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
@ -136,8 +118,7 @@ Status AllocateScopedShapedBuffer(
XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
int device_ordinal,
se::DeviceMemoryAllocator* allocator)
: size_(allocation.size()),
allocation_(allocation),
: allocation_(allocation),
device_ordinal_(device_ordinal),
allocator_(allocator) {
if (VLOG_IS_ON(2)) {
@ -163,11 +144,6 @@ const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
return allocation_;
}
void XRTBufferAllocation::DiscardAllocation() {
// Replace the allocation with a null.
allocation_ = se::DeviceMemoryBase();
}
XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
se::DeviceMemoryAllocator* allocator,
const xla::Shape& on_host_shape,
@ -176,23 +152,29 @@ XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
allocator_(allocator),
on_host_shape_(on_host_shape),
on_device_shape_(on_device_shape),
buffers_(&on_device_shape_) {}
buffers_(&on_device_shape_),
pin_count_(0) {}
XRTTupleAllocation::~XRTTupleAllocation() {
for (auto& buffer : buffers_) {
buffer.second->Unref();
XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); }
void XRTTupleAllocation::ReleaseBuffers() {
for (auto& index_buffer : buffers_) {
if (index_buffer.second != nullptr) {
index_buffer.second->Unref();
index_buffer.second = nullptr;
}
}
}
/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal,
const xla::LiteralBase& literal, XRTMemoryManager* memory_manager,
xla::Backend* backend, int device_ordinal,
XRTTupleAllocation** allocation) {
auto transfer_manager = backend->transfer_manager();
auto allocator = backend->memory_allocator();
std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
backend, device_ordinal, literal.shape(), &scoped_buffer));
TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
device_ordinal, literal.shape(),
&scoped_buffer));
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
stream.get(), literal, *scoped_buffer));
@ -202,11 +184,13 @@ XRTTupleAllocation::~XRTTupleAllocation() {
// call. To avoid a leak, there must be no error-case returns from here until
// the end of the method.
auto shaped_buffer = scoped_buffer->release();
*allocation = new XRTTupleAllocation(device_ordinal, allocator,
shaped_buffer.on_host_shape(),
shaped_buffer.on_device_shape());
*allocation = new XRTTupleAllocation(
device_ordinal, backend->memory_allocator(),
shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
(*allocation)
->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
device_ordinal);
(*allocation)->SetDeviceMemorySize();
return Status::OK();
}
@ -220,24 +204,22 @@ XRTTupleAllocation::~XRTTupleAllocation() {
shaped_buffer.on_device_shape());
(*allocation)
->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
(*allocation)->SetDeviceMemorySize();
return Status::OK();
}
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
Status XRTTupleAllocation::ToLiteral(xla::Backend* backend,
xla::MutableLiteralBase* literal) {
auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
mutex_lock lock(lock_);
return literal_ == nullptr ? StoreToLiteral(backend, literal)
: literal->CopyFrom(*literal_);
}
// Validate the allocation buffers, as a CHECK is issued if nulls get to
// TransferLiteralFromDevice().
xla::ShapedBuffer shaped_buffer = ToShapedBuffer();
for (auto& index_buffer : shaped_buffer.buffers()) {
if (index_buffer.second.is_null()) {
return errors::InvalidArgument("Literal buffer at index ",
index_buffer.first.ToString(),
" has been released");
}
}
Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend,
xla::MutableLiteralBase* literal) {
auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
return transfer_manager->TransferLiteralFromDevice(stream.get(),
shaped_buffer, *literal);
}
@ -250,52 +232,102 @@ Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
" device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
}
mutex_lock lock(lock_);
if (literal_ != nullptr) {
// The allocation is currently swapped out, and we have a host literal for
// its content. Just update the host literal with the new value.
return literal_->CopyFrom(literal);
}
TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
auto transfer_manager = backend->transfer_manager();
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
ToShapedBuffer());
shaped_buffer);
}
xla::StatusOr<bool> XRTTupleAllocation::SwapOut(xla::Backend* backend,
bool swap_pinned) {
mutex_lock lock(lock_);
if (literal_ == nullptr && (!IsPinned() || swap_pinned)) {
xla::Literal literal(on_host_shape());
TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal));
ReleaseBuffers();
literal_ = absl::make_unique<xla::Literal>(std::move(literal));
return true;
}
return false;
}
xla::StatusOr<bool> XRTTupleAllocation::SwapIn(XRTMemoryManager* memory_manager,
xla::Backend* backend) {
// We need to call AllocateScopedShapedBuffer() outside the locks, since the
// XRTMemoryManager might end up calling back into the SwapOut() API.
// So we do a quick check before using the IsSwapped() API, and it can happen
// that the allocation becomes swapped in after the check. This means that we
// will end up doing an allocation, and then releasing it soon after (via its
// scoped variables). This is an unlikely scenario (two threads calling
// SwapIn() on the same allocation) though.
if (!IsSwapped()) {
return false;
}
auto transfer_manager = backend->transfer_manager();
std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
TF_RETURN_IF_ERROR(
AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(),
on_host_shape(), &scoped_buffer));
TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
mutex_lock lock(lock_);
if (literal_ != nullptr) {
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
stream.get(), *literal_, *scoped_buffer));
auto shaped_buffer = scoped_buffer->release();
InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
device_ordinal());
literal_ = nullptr;
return true;
}
return false;
}
xla::StatusOr<bool> XRTTupleAllocation::PinAndSwapIn(
XRTMemoryManager* memory_manager, xla::Backend* backend) {
Pin();
return SwapIn(memory_manager, backend);
}
bool XRTTupleAllocation::IsSwapped() const {
mutex_lock lock(lock_);
return literal_ != nullptr;
}
int64 XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); }
int64 XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); }
bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; }
void XRTTupleAllocation::DiscardAllocation(
const xla::ShapeIndex& buffer_index) {
buffers_.element(buffer_index)->DiscardAllocation();
}
const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }
const xla::Shape& XRTTupleAllocation::on_host_shape() const {
return on_host_shape_;
}
const xla::Shape& XRTTupleAllocation::on_device_shape() {
const xla::Shape& XRTTupleAllocation::on_device_shape() const {
return on_device_shape_;
}
int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }
int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; }
const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const {
return buffers_.element({})->allocation();
}
/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
XRTTupleAllocation** allocation) {
string key_string = absl::StrCat(key);
TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
return Status::OK();
}
/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
int64 key) {
string key_string = absl::StrCat(key);
return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
}
/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) {
VLOG(1) << "Releasing all XRT held device memory";
return rm->Cleanup(kTupleContainer);
}
// Helper typedef to make ShapeTree ForEach helper lambda signatures more
// readable. They need a const T& parameter, where in this case T is the
// following pointer.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;
/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
XRTTupleAllocation** allocation, bool alias_parent_allocation) {
@ -330,46 +362,21 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
parent_index.push_back(index[i]);
}
*buffer = parent->buffers_.element(parent_index);
*parent->buffers_.mutable_element(parent_index) =
new XRTBufferAllocation(se::DeviceMemoryBase(),
parent->device_ordinal(),
parent->allocator_);
*parent->buffers_.mutable_element(parent_index) = nullptr;
});
}
(*allocation)->SetDeviceMemorySize();
return Status::OK();
}
/* static */ Status XRTTupleAllocation::CompactAllocations(
ResourceMgr* rm, xla::Backend* backend, int device_ordinal) {
std::vector<ResourceMgr::ResourceEntry> tuples;
rm->GetContainerResources(kTupleContainer, &tuples);
std::vector<std::pair<string, xla::Literal>> host_tuples;
for (auto& rm_tuple : tuples) {
XRTTupleAllocation* tuple =
dynamic_cast<XRTTupleAllocation*>(rm_tuple.resource.get());
if (tuple->device_ordinal() == device_ordinal) {
xla::Literal literal(tuple->on_host_shape());
TF_RETURN_IF_ERROR(tuple->ToLiteral(backend, device_ordinal, &literal));
host_tuples.emplace_back(rm_tuple.name, std::move(literal));
// At this point there are two references held onto the XRTTupleAllocation
// object. One in the ResourceMgr, which we release here, and one held
// within the tuples vector, which we release in the tuples.clear() call
// below.
TF_RETURN_IF_ERROR(
rm->Delete<XRTTupleAllocation>(kTupleContainer, rm_tuple.name));
void XRTTupleAllocation::SetDeviceMemorySize() {
size_t size = 0;
for (auto& index_buffer : buffers_) {
if (index_buffer.second != nullptr) {
size += index_buffer.second->allocation().size();
}
}
tuples.clear();
for (auto& name_literal : host_tuples) {
XRTTupleAllocation* tuple;
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
name_literal.second, backend, device_ordinal, &tuple));
TF_RETURN_IF_ERROR(rm->Create(kTupleContainer, name_literal.first, tuple));
}
return Status::OK();
device_memory_size_ = size;
}
/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
@ -414,7 +421,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
}
/*static*/ Status XRTTupleAllocation::MakeTuple(
xla::Backend* backend, int device_ordinal,
XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
const xla::ShapeTree<ExpandedTupleInput>& elements,
XRTTupleAllocation** allocation) {
auto transfer_manager = backend->transfer_manager();
@ -429,8 +436,8 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
// The aliasing is determined below based on whether or not all the inputs are
// released while being transferred. allocation_tmp is a local pointer that is
// copied to *allocation at the end only if the method succeeds.
auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
host_shape, device_shape);
XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation(
device_ordinal, allocator, host_shape, device_shape);
core::ScopedUnref allocation_unref(allocation_tmp);
// First allocate device memory for the new tuple index tables, one at each
// internal node of the elements tree. Do this in a separate pass into a
@ -444,12 +451,12 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
[&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
if (!elements.IsLeaf(index)) {
xla::Shape subshape =
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(device_shape, index);
uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
allocator->Allocate(device_ordinal, size,
/*retry_on_failure=*/false));
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory buffer,
memory_manager->Allocate(backend, device_ordinal, size));
VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index "
<< index.ToString();
// Move the new buffer into new_tuple_buffers, which takes ownership
@ -487,10 +494,8 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
// validated that release_allocation_after_use is false if
// element.allocation appears in more than one leaf.
element.allocation->buffers_.ForEachMutableElement(
[&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
*buffer = new XRTBufferAllocation(
se::DeviceMemoryBase(), element.allocation->device_ordinal(),
element.allocation->allocator_);
[&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) {
*buffer = nullptr;
});
} else {
// Increment the refcount on each newly-aliased buffer.
@ -506,6 +511,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
allocator);
}
});
allocation_tmp->SetDeviceMemorySize();
// Because the internal nodes of tuple_buffers are exactly the new index
// tables, WriteTupleIndexTables will write only the new index tables and not
// rewrite the index tables for the existing allocations.
@ -519,36 +525,47 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr;
return Status::OK();
}
Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
*key = get_uid();
string key_string = absl::StrCat(*key);
return rm->Create(kTupleContainer, key_string, this);
}
bool XRTTupleAllocation::IsExclusiveOwner() {
for (const auto& buffer : buffers_) {
if (!buffer.second->RefCountIsOne()) return false;
bool XRTTupleAllocation::IsExclusiveOwner() const {
for (const auto& index_buffer : buffers_) {
if (index_buffer.second != nullptr &&
!index_buffer.second->RefCountIsOne()) {
return false;
}
}
return true;
}
size_t XRTTupleAllocation::GetDeviceMemorySize() const {
return device_memory_size_;
}
void XRTTupleAllocation::InitializeFromShapedBuffer(
const xla::ShapedBuffer& shaped_buffer,
se::DeviceMemoryAllocator* allocator, int device_ordinal) {
for (auto& buffer : buffers_) {
for (auto& index_buffer : buffers_) {
if (index_buffer.second != nullptr) {
index_buffer.second->Unref();
}
// Make a reference-counted version of the allocated buffer.
buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
device_ordinal, allocator);
index_buffer.second = new XRTBufferAllocation(
shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator);
}
}
xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
allocator_->platform(), device_ordinal_);
for (const auto& buffer : buffers_) {
shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
for (const auto& index_buffer : buffers_) {
if (index_buffer.second == nullptr ||
index_buffer.second->allocation().is_null()) {
return errors::InvalidArgument("Literal buffer at index ",
index_buffer.first.ToString(),
" has been released");
}
shaped_buffer.set_buffer(index_buffer.second->allocation(),
index_buffer.first);
}
return shaped_buffer;
return std::move(shaped_buffer);
}
Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
@ -556,37 +573,69 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
const xla::ShapeIndex& dest_index) {
XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
// We allow the destination size being zero, because there are cases where we
// are coming in later filling in null/uninitialized device buffers.
// In all other cases, the size of the new buffer must match.
if (source_buffer->size() != dest_buffer->size() &&
dest_buffer->size() != 0) {
return errors::InvalidArgument(
"Source buffer at index ", source_index.ToString(),
" does not match the size of destination buffer at index ",
dest_index.ToString(), ": ", source_buffer->size(), " vs ",
dest_buffer->size());
if (dest_buffer != nullptr) {
// We allow the destination size to be zero, because there are cases where
// we come in later to fill in null/uninitialized device buffers. In
// all other cases, the size of the new buffer must match.
if (source_buffer->allocation().size() !=
dest_buffer->allocation().size() &&
dest_buffer->allocation().size() != 0) {
return errors::InvalidArgument(
"Source buffer at index ", source_index.ToString(),
" does not match the size of destination buffer at index ",
dest_index.ToString(), ": ", source_buffer->allocation().size(),
" vs ", dest_buffer->allocation().size());
}
} else {
const xla::Shape& source_subshape =
xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index);
const xla::Shape& dest_subshape =
xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index);
if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) {
return errors::InvalidArgument(
"Source and destination subshapes do not match: source=",
xla::ShapeUtil::HumanStringWithLayout(source_subshape),
" dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape));
}
}
*buffers_.mutable_element(dest_index) = source_buffer;
source_buffer->Ref();
dest_buffer->Unref();
if (dest_buffer != nullptr) {
// If we handed over the ownership of a buffer in ToDeviceMemoryTree(), we
// will be called here on the way back from execution, to alias back the
// buffer at that index. In that case the buffers will be the same. So we
// need to discard the memory at the destination buffer, before releasing
// the reference.
if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) &&
dest_buffer != source_buffer) {
dest_buffer->DiscardAllocation();
}
dest_buffer->Unref();
}
return Status::OK();
}
xla::ShapeTree<xla::MaybeOwningDeviceMemory>
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
XRTTupleAllocation::ToDeviceMemoryTree(
const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
for (const auto& buffer : buffers_) {
if (!release_checker(buffer.first)) {
*shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
for (const auto& index_buffer : buffers_) {
if (index_buffer.second == nullptr ||
index_buffer.second->allocation().is_null()) {
return errors::InvalidArgument("Literal buffer at index ",
index_buffer.first.ToString(),
" has been released");
}
if (!release_checker(index_buffer.first)) {
*shaped_tree.mutable_element(index_buffer.first) =
index_buffer.second->allocation();
} else {
*shaped_tree.mutable_element(buffer.first) = se::OwningDeviceMemory(
buffer.second->allocation(), device_ordinal_, allocator_);
DiscardAllocation(buffer.first);
// We keep the ownership of the device memory here.
*shaped_tree.mutable_element(index_buffer.first) = se::OwningDeviceMemory(
index_buffer.second->allocation(), device_ordinal_, allocator_);
}
}
return shaped_tree;
return std::move(shaped_tree);
}
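// Illustration (hypothetical caller, assumed names): a read-only user keeps
// ownership within the tuple by having the release checker return false for
// every index; when it returns true instead, the buffers are handed out as
// owning memory and are expected to be re-aliased via AliasBufferFrom() once
// the execution completes, as described in the header comment.
//
//   TF_ASSIGN_OR_RETURN(
//       xla::ShapeTree<xla::MaybeOwningDeviceMemory> memory_tree,
//       tuple->ToDeviceMemoryTree(
//           [](const xla::ShapeIndex&) { return false; }));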
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#include <atomic>
#include <functional>
#include <memory>
#include <string>
@ -27,17 +28,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace tensorflow {
// Cannot include xrt_memory_manager.h here, as it needs to include this file.
class XRTMemoryManager;
// TODO(misard) make this a Tensor if and when that makes sense.
// A reference-counted wrapper around a buffer allocation. This maps an XLA
// tuple index or a non-tuple XLA shape to a region of device memory. The device
@ -51,36 +56,23 @@ class XRTBufferAllocation : public core::RefCounted {
// The region of device memory being wrapped.
const se::DeviceMemoryBase& allocation();
// Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
// when ownership of the underlying buffer has been transferred, e.g., to an
// output buffer when input and output buffers are aliased during
// execution. The call to DiscardAllocation prevents any device buffer being
// freed when the reference count drops to zero.
void DiscardAllocation();
// Returns the expected size of the allocation. Since DiscardAllocation() will
// set allocation_ to {null,0}, and since later we might want to replace the
// discarded buffer with a new one, we need to be able to verify the size
// compatibility.
uint64 size() const { return size_; }
void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); }
private:
uint64 size_ = 0;
se::DeviceMemoryBase allocation_;
int device_ordinal_;
se::DeviceMemoryAllocator* allocator_;
};
// Entry in the resource manager corresponding to an allocation handle returned
// to a client. The handle identifies an immutable tuple of data in device
// memory. New handles can be created in three ways: by passing a literal in
// which case device memory is allocated and the literal is transferred to that
// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
// aliasing a vector of existing handles to create a new tuple. The underlying
// storage is reference-counted. When a handle is released, the reference count
// of each storage buffer is decremented, and buffers with no outstanding
// references are freed.
class XRTTupleAllocation : public ResourceBase {
// An XRTTupleAllocation represents an allocated memory area on the device.
// New tuples can be created in three ways: by passing a literal in which case
// device memory is allocated and the literal is transferred to that memory; by
// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a
// vector of existing handles to create a new tuple. The underlying storage is
// reference-counted. When a handle is released, the reference count of each
// storage buffer is decremented, and buffers with no outstanding references are
// freed.
class XRTTupleAllocation : public core::RefCounted {
public:
~XRTTupleAllocation() override;
@ -88,6 +80,7 @@ class XRTTupleAllocation : public ResourceBase {
// literal to that memory, and returns a XRTTupleAllocation handle to the
// allocated buffers.
static Status CreateAndTransfer(const xla::LiteralBase& literal,
XRTMemoryManager* memory_manager,
xla::Backend* backend, int device_ordinal,
XRTTupleAllocation** allocation);
@ -106,16 +99,11 @@ class XRTTupleAllocation : public ResourceBase {
XRTTupleAllocation** allocation,
bool alias_parent_allocation);
// Runs a compaction cycle which copies the device data to host, frees the
// device data, and then reallocates and sends back the data.
static Status CompactAllocations(ResourceMgr* rm, xla::Backend* backend,
int device_ordinal);
// A structure describing a leaf of a tree of tuples to expand. Each leaf
// contains an allocation and indicates whether or not the allocation's handle
// should be freed after incorporating its buffers into the expanded tree.
struct ExpandedTupleInput {
XRTTupleAllocation* allocation;
RefPtr<XRTTupleAllocation> allocation;
bool release_allocation_after_use;
};
@ -129,52 +117,70 @@ class XRTTupleAllocation : public ResourceBase {
// an input is repeated, release_input_handle must be false for every leaf
// where that input appears. The latter property is not validated by MakeTuple
// and must be enforced by the caller.
static Status MakeTuple(xla::Backend* backend, int device_ordinal,
static Status MakeTuple(XRTMemoryManager* memory_manager,
xla::Backend* backend, int device_ordinal,
const xla::ShapeTree<ExpandedTupleInput>& elements,
XRTTupleAllocation** allocation);
// Retrieves the allocation interned under key from rm. The caller owns a
// reference to allocation after looking it up.
static Status Lookup(ResourceMgr* rm, int64 key,
XRTTupleAllocation** allocation);
// Deletes the reference in the rm to an allocation interned under key.
static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);
// Releases all the device memory allocated by XRT within the resource
// manager.
static Status ReleaseAllAllocations(ResourceMgr* rm);
// Returns the invalid key value, which will never be generated by the
// Intern() API.
static int64 InvalidKey() { return 0; }
// Adds the allocation to a ResourceMgr and returns the key that will be used
// to retrieve it. Transfers a reference on *this to rm.
Status Intern(ResourceMgr* rm, int64* key);
// Copies the allocation from device to host and returns it in literal.
Status ToLiteral(xla::Backend* backend, int device_ordinal,
xla::MutableLiteralBase* literal);
Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal);
// Write a new literal value to the allocation.
Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
// Stores the content of the tuple allocation into the internal literal, and
// releases all the device buffers. The swap_pinned flag tells whether a
// pinned allocation should be swapped out. It should be false in all cases
// except during the memory compaction operation run by the XRTMemoryManager.
// Returns a boolean telling whether the allocation was swapped out.
xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned);
// Allocates the device memory required to store the tuple value held within
// the internal literal, and transfers the literal value into the device
// memory. Returns a boolean telling whether the allocation was swapped in.
xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager,
xla::Backend* backend);
// Pins the allocation first, then swaps it in (if it is not already). After
// this API returns, the allocation is pinned and its content is in device
// memory. The caller is responsible for releasing the pin-count using the
// Unpin() API.
xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager,
xla::Backend* backend);
// Checks whether the allocation is currently swapped out.
bool IsSwapped() const;
// Increases the pin-count of this allocation. If the pin-count is greater
// than 0, the allocation cannot be swapped out. Returns the pin-count value
// before the increase.
int64 Pin();
// Decreases the pin-count of this allocation. Returns the pin-count value
// before the decrease.
int64 Unpin();
// Checks whether the allocation is currently pinned.
bool IsPinned() const;
// True if none of the buffers in the allocation are aliased by any other live
// handle.
bool IsExclusiveOwner();
bool IsExclusiveOwner() const;
// Retrieves the footprint in terms of device memory, of this allocation.
size_t GetDeviceMemorySize() const;
// The ordinal of the device holding this tuple.
int device_ordinal();
int device_ordinal() const;
// Returns the shape of the tuple as seen by the host.
const xla::Shape& on_host_shape();
const xla::Shape& on_host_shape() const;
// Returns the shape of the tuple as stored on the device.
const xla::Shape& on_device_shape();
const xla::Shape& on_device_shape() const;
// Returns the buffer pointed to by the root of the tuple.
const se::DeviceMemoryBase& root_allocation();
const se::DeviceMemoryBase& root_allocation() const;
// Stops managing the storage for the allocation at buffer_index, e.g.,
// because it has been aliased to the output buffer of a computation.
@ -182,7 +188,7 @@ class XRTTupleAllocation : public ResourceBase {
// Returns the tree of allocations as a ShapedBuffer. This tree may not have
// the same shape as on_host_shape.
xla::ShapedBuffer ToShapedBuffer();
xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer();
// Aliases the source buffer at source_index into this tuple allocation at
// dest_index.
@ -191,14 +197,22 @@ class XRTTupleAllocation : public ResourceBase {
const xla::ShapeIndex& dest_index);
// Returns the device memory tree of this allocation. If the release_checker
// function returns true for a given index, the ownership of the device memory
// at that index is transferred to the result. Every attempt to read the value
// at that index will fail.
xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(
// function returns true for a given index, an owned device memory is returned
// to the caller. But the tuple allocation cannot release the ownership in
// full, as the execute operation might fail. So we rely on a call to
// AliasBufferFrom() to re-alias back the buffers. This is not great (to say
// the least), but the current aliasing logic relies on
// MaybeOwningDeviceMemory being owned, to detect the fact that the user may
// want to alias a buffer. Unfortunately to do that, it needs to release the
// ownership, which is a problem if the execute operation fails.
// This calls for a refactoring of the whole owning/maybe-owning interface to
// introduce a sharing concept (IOW shared_ptr model vs. unique_ptr).
// We'd need something similar to XRTTupleAllocation instead of
// ScopedShapedBuffer, which wants ownership and does not allow sharing.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
ToDeviceMemoryTree(
const std::function<bool(const xla::ShapeIndex&)>& release_checker);
string DebugString() const override { return "XLA allocation handle"; }
private:
// Creates a new handle with (tuple) shape.
XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator,
@ -211,6 +225,21 @@ class XRTTupleAllocation : public ResourceBase {
se::DeviceMemoryAllocator* allocator,
int device_ordinal);
// Releases all the XRTBufferAllocation buffer references and sets the
// corresponding shape tree entries to nullptr.
void ReleaseBuffers();
// Stores the content of the allocation from device memory to the target host
// literal.
Status StoreToLiteral(xla::Backend* backend,
xla::MutableLiteralBase* literal);
// Sets the total size of the buffers held within this allocation.
// This API should be called once, when an XRTTupleAllocation object is
// created, as the XRTTupleAllocation shapes never change, and hence neither
// does the device memory size.
void SetDeviceMemorySize();
// Takes a tree 'elements' where each leaf is an allocation, validates that
// they are all on device_ordinal managed by allocator, and returns in
// host_shape and device_shape the host/device shapes of the expanded tree,
@ -221,9 +250,13 @@ class XRTTupleAllocation : public ResourceBase {
se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
xla::Shape* device_shape);
// The lock which protects the internal operations of the tuple allocation. It
// is mutable so that logically-const operations can still be declared const.
mutable mutex lock_;
// Location of the memory that is being managed.
int device_ordinal_;
se::DeviceMemoryAllocator* allocator_;
const int device_ordinal_;
se::DeviceMemoryAllocator* const allocator_;
// The shape that the caller thinks the tuple has.
const xla::Shape on_host_shape_;
@ -233,6 +266,13 @@ class XRTTupleAllocation : public ResourceBase {
// The tree of reference-counted buffers, which uses on_device_shape_ as its
// shape.
xla::ShapeTree<XRTBufferAllocation*> buffers_;
// The footprint of the allocation, when residing on device memory.
size_t device_memory_size_ = 0;
// If the allocation is swapped out, this is the literal storing its content.
std::unique_ptr<xla::Literal> literal_;
// A pinned allocation is one which cannot be swapped out. If pin_count_ > 0
// then the allocation is pinned.
std::atomic<int64> pin_count_;
};
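// Creation-path sketch (illustrative only; memory_manager, backend,
// device_ordinal and literal are assumed to exist in the caller): the three
// ways of building an XRTTupleAllocation described above.
//
//   XRTTupleAllocation* root_raw = nullptr;
//   // 1) Allocate device memory and transfer a host literal into it.
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, memory_manager, backend, device_ordinal, &root_raw));
//   RefPtr<XRTTupleAllocation> root(root_raw);
//
//   // 2) Alias a sub-shape of an existing tuple-shaped allocation.
//   XRTTupleAllocation* child_raw = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
//       root.get(), /*subshape=*/{0}, &child_raw,
//       /*alias_parent_allocation=*/true));
//   RefPtr<XRTTupleAllocation> child(child_raw);
//
//   // 3) MakeTuple() builds a new tuple from a ShapeTree<ExpandedTupleInput>
//   //    whose leaves reference existing allocations.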
} // namespace tensorflow

View File

@ -25,6 +25,88 @@ limitations under the License.
namespace tensorflow {
namespace {
// The ScopedHandles data structure is used in the ExecuteChained() API and its
// task is to track tuple allocation registrations. It is used to track both
// intermediate results of a chained computation and its final results. Anything
// which is marked to be released will be released using the XRTMemoryManager
// once the object is destroyed (unless an explicit call to Drop() or Release()
// is made).
class ScopedHandles {
public:
explicit ScopedHandles(RefPtr<XRTMemoryManager> memory_manager)
: memory_manager_(std::move(memory_manager)) {}
~ScopedHandles() {
for (size_t i = 0; i < handles_.size(); ++i) {
if (handles_release_[i]) {
memory_manager_->Release(handles_[i]).IgnoreError();
}
}
}
int64 operator[](size_t index) const { return handles_.at(index); }
size_t size() const { return handles_.size(); }
// Adds the given handle at the index position, marking it releasable
// according to the release argument. If a to-be-released handle already
// exists at the same index, it will be released first.
Status Add(size_t index, int64 handle, bool release) {
if (index >= handles_.size()) {
handles_.resize(index + 1, XRTMemoryManager::InvalidKey());
handles_release_.resize(index + 1, false);
}
if (handles_release_[index]) {
Status status = memory_manager_->Release(handles_[index]);
if (!status.ok()) {
if (release) {
memory_manager_->Release(handle).IgnoreError();
}
return status;
}
}
handles_[index] = handle;
handles_release_[index] = release;
return Status::OK();
}
// Adds a to-be-released tuple allocation at the given index.
Status Add(size_t index, RefPtr<XRTTupleAllocation> tuple) {
return Add(index, memory_manager_->Register(std::move(tuple)),
/*release=*/true);
}
// Drops the handle at the given index, and releases it via
// XRTMemoryManager::Release() if it is marked as to-be-released.
Status Drop(size_t index) {
if (handles_release_.at(index)) {
TF_RETURN_IF_ERROR(memory_manager_->Release(handles_[index]));
}
Release(index);
return Status::OK();
}
// Releases the handle at the given index. The destructor will not call the
// XRTMemoryManager::Release() API on that handle.
int64 Release(size_t index) {
int64 handle = handles_.at(index);
handles_[index] = XRTMemoryManager::InvalidKey();
handles_release_[index] = false;
return handle;
}
// Looks up the handle stored at the given index, and returns the matching
// tuple allocation.
xla::StatusOr<RefPtr<XRTTupleAllocation>> Lookup(size_t index) const {
return memory_manager_->Lookup(handles_.at(index));
}
private:
RefPtr<XRTMemoryManager> memory_manager_;
std::vector<int64> handles_;
std::vector<bool> handles_release_;
};
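// Usage sketch (illustrative only; intermediate_tuple and result_tuple are
// assumed names): intermediate results are registered as to-be-released and
// get dropped automatically, while results handed to the caller are detached
// from the scope via Release().
//
//   ScopedHandles outputs(memory_manager);
//   // Intermediate result: released when `outputs` goes out of scope, or
//   // earlier through Drop().
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/0, std::move(intermediate_tuple)));
//   // Final result: hand the handle to the caller; the destructor will not
//   // release it once Release() has been called.
//   TF_RETURN_IF_ERROR(outputs.Add(/*index=*/1, std::move(result_tuple)));
//   int64 result_handle = outputs.Release(1);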
bool DebugOptionsPassThroughEnabled() {
const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
bool enabled =
@ -61,6 +143,23 @@ Status MakeOutput(const RefPtr<XRTTupleAllocation>& output, int64 index,
return Status::OK();
}
Status PopulateOpWorkingSet(xla::Backend* backend,
const xrt::XRTChainedExecuteOp& op,
int current_index, const ScopedHandles& outputs,
XRTMemoryManager::WorkingSet* working_set) {
for (int i = 0; i < op.inputs_size(); ++i) {
auto& input = op.inputs(i);
if (input.op_index() >= current_index) {
return errors::InvalidArgument(
"Input index ", input.op_index(),
" is above the current position: ", current_index);
}
TF_RETURN_IF_ERROR(
working_set->LookupAndPin(backend, outputs[input.op_index()]));
}
return Status::OK();
}
} // namespace
xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
@ -81,7 +180,7 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
}
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
OpKernelContext* context, ResourceMgr* rm, const char* input_name) {
OpKernelContext* context, const char* input_name) {
OpInputList arg_list;
TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list));
// Concatenate all input uids from list of scalars-or-vectors carrying them.
@ -102,7 +201,8 @@ xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
return std::move(input_coords);
}
Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm,
Status CreateExecuteOutput(OpKernelContext* context,
XRTMemoryManager* memory_manager,
RefPtr<XRTTupleAllocation> output_tuple,
bool return_exploded_tuple) {
if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) {
@ -117,23 +217,21 @@ Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm,
TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
output_tuple.get(), {i}, &suballocation,
/*alias_parent_allocation=*/false));
int64 key;
TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key));
output_tensor->vec<int64>()(i) = key;
output_tensor->vec<int64>()(i) = memory_manager->Register(suballocation);
}
} else {
Tensor* output_tensor;
TF_RETURN_IF_ERROR(
context->allocate_output(0, TensorShape({}), &output_tensor));
int64 key;
TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
output_tuple.release();
output_tensor->scalar<int64>()() = key;
output_tensor->scalar<int64>()() =
memory_manager->Register(std::move(output_tuple));
}
return Status::OK();
}
Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm,
Status ExecuteChained(OpKernelContext* context,
const RefPtr<XRTMemoryManager>& memory_manager,
xla::Backend* backend, int device_ordinal,
const xrt::XRTChainedExecutePlan& plan,
const xrt::XRTChainedExecuteConfig& config,
const ChainedExecuteFn& execute_op) {
@ -145,41 +243,43 @@ Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm,
uses[input.op_index()] += 1;
}
}
std::vector<RefPtr<XRTTupleAllocation>> ops_outputs(plan.ops_size());
std::vector<RefPtr<XRTTupleAllocation>> results;
ScopedHandles outputs(memory_manager);
ScopedHandles results(memory_manager);
for (int i = 0; i < plan.ops_size(); ++i) {
auto& op = plan.ops(i);
if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) {
// This operation is a device data load. Fetch the proper
// XRTTupleAllocation behind the user handle and fill up the op output at
// the current position.
XRTTupleAllocation* tuple;
TF_RETURN_IF_ERROR(
XRTTupleAllocation::Lookup(rm, op.data_handle(), &tuple));
ops_outputs[i].reset(tuple);
// This operation is a device data load. Set the handle as output and
// leave the release flag off, since this is not an intermediate output.
TF_RETURN_IF_ERROR(outputs.Add(i, op.data_handle(), /*release=*/false));
} else if (op.op_oneof_case() ==
xrt::XRTChainedExecuteOp::kComputationHandle) {
// This is an XRT execute operation, forward to the device specific
// handler.
TF_ASSIGN_OR_RETURN(ops_outputs[i], execute_op(op, i, ops_outputs));
// handler. Populating the working set makes sure the input allocations
// for this execute operation are pinned to device memory.
XRTMemoryManager::WorkingSet working_set(memory_manager);
TF_RETURN_IF_ERROR(
PopulateOpWorkingSet(backend, op, i, outputs, &working_set));
TF_ASSIGN_OR_RETURN(auto tuple,
execute_op(op, working_set.PinnedTuples()));
TF_RETURN_IF_ERROR(outputs.Add(i, std::move(tuple)));
} else {
return errors::InvalidArgument(
"Undefined operation kind at post-order position ", i);
}
// If the result of this chained operation is an output result, feed the
// results vector at the desired position.
// results at the desired position.
for (auto& output : op.outputs()) {
if (output.result_index() >= results.size()) {
results.resize(output.result_index() + 1);
}
TF_RETURN_IF_ERROR(MakeOutput(ops_outputs[i], output.output_index(),
&results[output.result_index()]));
TF_ASSIGN_OR_RETURN(auto tuple, outputs.Lookup(i));
RefPtr<XRTTupleAllocation> result;
TF_RETURN_IF_ERROR(MakeOutput(tuple, output.output_index(), &result));
TF_RETURN_IF_ERROR(results.Add(output.result_index(), std::move(result)));
}
// Drop intermediate results which have no more users.
for (auto& input : op.inputs()) {
uses[input.op_index()] -= 1;
if (uses[input.op_index()] == 0) {
ops_outputs[input.op_index()].reset();
TF_RETURN_IF_ERROR(outputs.Drop(input.op_index()));
}
}
}
@ -188,12 +288,7 @@ Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm,
TF_RETURN_IF_ERROR(context->allocate_output(
0, TensorShape({static_cast<int64>(results.size())}), &output_tensor));
for (size_t i = 0; i < results.size(); ++i) {
int64 key = XRTTupleAllocation::InvalidKey();
if (results[i] != nullptr) {
TF_RETURN_IF_ERROR(results[i]->Intern(rm, &key));
results[i].release();
}
output_tensor->vec<int64>()(i) = key;
output_tensor->vec<int64>()(i) = results.Release(i);
}
return Status::OK();
}
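// Caller-side sketch (illustrative only; RunOnDevice is an assumed helper):
// the device-specific callback now receives the already-pinned input tuples
// instead of the raw outputs vector and the op index.
//
//   auto execute_op =
//       [&](const xrt::XRTChainedExecuteOp& op,
//           absl::Span<const RefPtr<XRTTupleAllocation>> inputs)
//       -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
//     return RunOnDevice(op.computation_handle(), inputs);
//   };
//   TF_RETURN_IF_ERROR(ExecuteChained(context, memory_manager, backend,
//                                     device_ordinal, plan, config,
//                                     execute_op));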

View File

@ -18,97 +18,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Reference counted smart pointer for XRT objects providing the standard
// Ref()/Unref() APIs.
template <typename T>
class RefPtr {
public:
RefPtr() = default;
// Creates a RefPtr from a pointer. This is an ownership transfer operation,
// and the caller has to own a valid reference to ptr (unless ptr is nullptr).
RefPtr(T* ptr) : ptr_(ptr) {}
RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); }
RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; }
~RefPtr() { Release(ptr_); }
RefPtr& operator=(const RefPtr& other) {
if (this != &other) {
Acquire(other.ptr_);
Release(ptr_);
ptr_ = other.ptr_;
}
return *this;
}
RefPtr& operator=(RefPtr&& other) {
if (this != &other) {
Release(ptr_);
ptr_ = other.ptr_;
other.ptr_ = nullptr;
}
return *this;
}
operator bool() const { return ptr_ != nullptr; }
bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; }
bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; }
bool operator==(const T* ptr) const { return ptr_ == ptr; }
bool operator!=(const T* ptr) const { return ptr_ != ptr; }
bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; }
bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; }
T* get() const { return ptr_; }
T* operator->() const {
CHECK(ptr_ != nullptr); // Crash OK
return ptr_;
}
T& operator*() const {
CHECK(ptr_ != nullptr); // Crash OK
return *ptr_;
}
T* release() {
T* ptr = ptr_;
ptr_ = nullptr;
return ptr;
}
// Resets the RefPtr from a pointer. This is an ownership transfer operation,
// and the caller has to own a valid reference to ptr (unless ptr is nullptr).
void reset(T* ptr = nullptr) {
Release(ptr_);
ptr_ = ptr;
}
private:
static void Release(T* ptr) {
if (ptr != nullptr) {
ptr->Unref();
}
}
static void Acquire(T* ptr) {
if (ptr != nullptr) {
ptr->Ref();
}
}
T* ptr_ = nullptr;
};
struct InputCoords {
explicit InputCoords(int64 handle) : handle(handle) {}
InputCoords(int64 handle, xla::ShapeIndex index)
@ -128,12 +50,13 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);
// Populates the input_coords with a list of input coordinates from an
// input_name op argument.
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
OpKernelContext* context, ResourceMgr* rm, const char* input_name);
OpKernelContext* context, const char* input_name);
// Creates the XRT execute output tensor given the computation result
// (output_tuple). The return_exploded_tuple flag tells whether a tuple result
// should be returned as a vector of handles representing each tuple child.
Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm,
Status CreateExecuteOutput(OpKernelContext* context,
XRTMemoryManager* memory_manager,
RefPtr<XRTTupleAllocation> output_tuple,
bool return_exploded_tuple);
@ -141,9 +64,11 @@ Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm,
// function.
using ChainedExecuteFn =
std::function<xla::StatusOr<RefPtr<XRTTupleAllocation>>(
const xrt::XRTChainedExecuteOp&, int,
const xrt::XRTChainedExecuteOp&,
absl::Span<const RefPtr<XRTTupleAllocation>>)>;
Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm,
Status ExecuteChained(OpKernelContext* context,
const RefPtr<XRTMemoryManager>& memory_manager,
xla::Backend* backend, int device_ordinal,
const xrt::XRTChainedExecutePlan& plan,
const xrt::XRTChainedExecuteConfig& config,
const ChainedExecuteFn& execute_op);
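// Sketch of how an execute kernel is expected to wire these helpers together
// (illustrative only; "input_handles" and RunComputation are assumed names):
//
//   TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
//                       GetComputationInputs(context, "input_handles"));
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   for (const InputCoords& coords : input_coords) {
//     TF_RETURN_IF_ERROR(working_set.LookupAndPin(backend, coords.handle));
//   }
//   TF_ASSIGN_OR_RETURN(RefPtr<XRTTupleAllocation> output_tuple,
//                       RunComputation(working_set.PinnedTuples()));
//   TF_RETURN_IF_ERROR(CreateExecuteOutput(context, memory_manager.get(),
//                                          std::move(output_tuple),
//                                          /*return_exploded_tuple=*/false));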

View File

@ -104,21 +104,6 @@ ResourceMgr::ResourceMgr(const string& default_container)
ResourceMgr::~ResourceMgr() { Clear(); }
void ResourceMgr::GetContainerResources(
const string& container, std::vector<ResourceEntry>* resources) const {
resources->clear();
mutex_lock l(mu_);
Container* b = gtl::FindPtrOrNull(containers_, container);
if (b != nullptr) {
resources->reserve(b->size());
for (auto& key_resource : *b) {
ResourceBase* resource = key_resource.second;
resource->Ref();
resources->emplace_back(key_resource.first.second, resource);
}
}
}
void ResourceMgr::Clear() {
mutex_lock l(mu_);
for (const auto& p : containers_) {

View File

@ -145,19 +145,6 @@ class ResourceMgr {
std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
resources) const TF_MUST_USE_RESULT;
// Retrieves all the resources within a container. If the container does not
// exist, it will not be created and the result vector will be empty. The
// resource member of the returned ResourceEntry data structures will own
// a reference to the ResourceBase object(s).
struct ResourceEntry {
ResourceEntry(string name, ResourceBase* resource)
: name(std::move(name)), resource(resource) {}
string name;
std::unique_ptr<ResourceBase, core::RefCountDeleter> resource;
};
void GetContainerResources(const string& container,
std::vector<ResourceEntry>* resources) const;
// If "container" has a resource "name", returns it in
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".