Enable the XRT compilation cache to be shared among multiple GPU devices.

Allow XRT on GPU to work with multi-threaded replication, where a single process sees all the available devices.
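As a rough illustration of the intended use (a sketch only, not part of this change): with two local GPU devices and one driver thread per replica, each thread would fill in the new CommonExecutionConfig added to xrt.proto below, sharing the same device-to-replica mapping and run ID while differing only in replica_id.

// Hypothetical driver-side helper; the field names come from the xrt.proto
// change below, everything else (two devices, the run_id value) is assumed.
#include <cstdint>
#include "tensorflow/compiler/xrt/xrt.pb.h"

xrt::XRTExecutionConfig MakeExecConfigForThread(int replica_id, int64_t run_id) {
  xrt::XRTExecutionConfig exec_config;
  xrt::CommonExecutionConfig* common = exec_config.mutable_common_config();
  common->set_replica_id(replica_id);    // the replica this thread drives
  common->add_local_replica_mapping(0);  // local device ordinal 0 -> global replica 0
  common->add_local_replica_mapping(1);  // local device ordinal 1 -> global replica 1
  common->set_run_id(run_id);            // same run_id for every thread of one step
  return exec_config;
}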

PiperOrigin-RevId: 310376508
Change-Id: I25715feaf74ceca421ba8939405f58a0bf68ee59
Davide Libenzi 2020-05-07 09:32:16 -07:00 committed by TensorFlower Gardener
parent 1d4b4a6706
commit 70e9708e23
6 changed files with 91 additions and 28 deletions

View File

@@ -50,6 +50,7 @@ class RunId {
public:
// Creates a new, unique RunId.
RunId();
explicit RunId(int64 value) : data_(value) {}
RunId(const RunId&) = default;
RunId& operator=(const RunId&) = default;
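For context (not part of the hunk above), the new explicit constructor is what lets the XRT execute path rebuild a RunId from the value serialized in the execution config; the corresponding use appears in xrt_execute_op.cc later in this commit, roughly:

// Mirrors the usage added below; `config` is the xrt::CommonExecutionConfig
// attached to the request and `run_options` the xla::ExecutableRunOptions.
if (config.run_id() != 0) {
  run_options.set_run_id(xla::RunId(config.run_id()));
}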

View File

@@ -158,7 +158,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
argument_layout_ptrs[i] = &argument_layouts[i];
}
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client->default_device_ordinal());
build_options.set_device_ordinal(device_ref.device_ordinal());
build_options.set_num_replicas(num_replicas);
build_options.set_result_layout(xla::Shape(config.program_shape().result()));
build_options.set_device_allocator(device_ref.backend()->memory_allocator());
@@ -206,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
// Process-wide cache of XLA executables.
auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
ctx, /*max_number_of_entries=*/0);
OP_REQUIRES_OK(ctx, cache_or.status());
auto cache = cache_or.ConsumeValueOrDie();
@@ -259,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
ResourceMgr* rm;
OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
// Process-wide cache of XLA executables.
XRTCompilationCache* cache;
OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
rm->default_container(),
kXRTCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
ctx, /*max_number_of_entries=*/0);
OP_REQUIRES_OK(ctx, cache_or.status());
auto cache = cache_or.ConsumeValueOrDie();
const Tensor& keys_tensor = ctx->input(0);
auto flat_keys = keys_tensor.flat<int64>();

View File

@@ -149,13 +149,17 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
VLOG(2) << "Executing computation.";
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(device_ref->backend()->memory_allocator());
run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
run_options.set_rng_seed(rng_seed);
if (config.run_id() != 0) {
run_options.set_run_id(xla::RunId(config.run_id()));
}
if (executable->executable()
->module_config()
.has_static_device_assignment()) {
@@ -164,8 +168,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
}
xla::GpuExecutableRunOptions gpu_options;
std::vector<xla::GlobalDeviceId> gpu_global_ids;
if (replica_id >= 0) {
gpu_global_ids.emplace_back(replica_id);
if (config.local_replica_mapping_size() > 0) {
gpu_global_ids.reserve(config.local_replica_mapping_size());
for (auto& gid : config.local_replica_mapping()) {
gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
}
gpu_options.set_gpu_global_device_ids(gpu_global_ids);
}
std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
@@ -222,10 +229,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
OpKernelContext* context, XRTMemoryManager* memory_manager,
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
auto runfn = [&]() {
return RunExecutable(context, device_ref, executable, input_buffers, stream,
rng_seed, replica_id);
rng_seed, config);
};
// We pass zero as requested_free_size as there is no simple way to get the
@@ -241,14 +249,15 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable,
const std::vector<InputCoords>& input_coords, bool release_inputs,
se::Stream* stream, int rng_seed, int replica_id) {
se::Stream* stream, int rng_seed,
const xrt::CommonExecutionConfig& config) {
XRTMemoryManager::WorkingSet working_set(memory_manager);
TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
GetInputBuffers(&working_set, device_ref->backend(),
input_coords, release_inputs));
return ExecuteComputation(context, memory_manager.get(), device_ref,
executable, input_buffers, stream, rng_seed,
replica_id);
config);
}
// XRTExecuteOp
@@ -297,8 +306,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
bool release_inputs = config_proto.release_input_handles();
bool release_compilation = config_proto.release_compilation_handle();
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
TF_ASSIGN_OR_RETURN(auto cache,
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
context, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -330,7 +340,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
RefPtr<XRTTupleAllocation> output_tuple,
ExecuteComputation(context, memory_manager, &device_ref, executable,
input_coords, release_inputs, stream, rng_seed,
config_proto.replica_id()));
config_proto.common_config()));
return CreateExecuteOutput(context, memory_manager.get(),
std::move(output_tuple),
@@ -379,8 +389,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
xrt::XRTChainedExecuteConfig config;
TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));
TF_ASSIGN_OR_RETURN(
auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
TF_ASSIGN_OR_RETURN(auto cache,
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
context, /*max_number_of_entries=*/0));
// We are guaranteed that the underlying device object won't be deleted out
// from under us, while the ScopedRef is live.
class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -408,7 +419,7 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
return ExecuteComputation(context, memory_manager.get(), &device_ref,
executable, input_buffers, stream, rng_seed,
config.replica_id());
config.common_config());
};
return ExecuteChained(context, memory_manager, device_ref.backend(),

View File

@@ -111,6 +111,17 @@ message XLATupleNode {
repeated XLATupleNode tuples = 3;
}
message CommonExecutionConfig {
// The replica index this execute is driving.
int32 replica_id = 1;
// Mapping local device ordinals to global replica IDs.
// local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
repeated int32 local_replica_mapping = 2;
// The execution run ID used to correlate different XRT execute operations
// happening in parallel from different threads.
int64 run_id = 3;
}
// Options for an XLA execution.
message XRTExecutionConfig {
// Local device to run on. This is present because the execute Op
@@ -133,8 +144,9 @@ message XRTExecutionConfig {
// a single tuple allocation the execution will return a vector of
// allocations, one for each of the first-level elements of the result tuple.
bool return_exploded_tuple = 7;
// The replica index this execute is driving.
int32 replica_id = 8;
reserved 8;
// The common configuration for XRT execute operations.
CommonExecutionConfig common_config = 9;
}
message XRTChainedExecuteConfig {
@@ -145,8 +157,9 @@ message XRTChainedExecuteConfig {
// Optional key to disambiguate between executions. This is only needed if
// multiple host send/recvs may be outstanding concurrently with executions.
string execution_instance_key = 3;
// The replica index this execute is driving.
int32 replica_id = 4;
reserved 4;
// The common configuration for XRT execute operations.
CommonExecutionConfig common_config = 5;
}
// A single chained execute operation. An operation can either be a device data
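To make the semantics of the CommonExecutionConfig message above concrete (illustrative values only, not taken from this commit): three driver threads running one replicated step on a three-GPU host would all send the same local_replica_mapping and run_id, and differ only in replica_id. In protobuf text format, thread 1 would send something like:

common_config {
  replica_id: 1
  local_replica_mapping: 0
  local_replica_mapping: 1
  local_replica_mapping: 2
  run_id: 7
}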

View File

@@ -17,19 +17,56 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt_device.h"
#include <map>
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace {
class ResourceMgrArena {
public:
static ResourceMgrArena* Get() {
static ResourceMgrArena* arena = new ResourceMgrArena();
return arena;
}
ResourceMgr* GetResourceMgr(const std::string& platform_name) {
mutex_lock lock(mutex_);
auto it = resource_managers_.find(platform_name);
if (it == resource_managers_.end()) {
it = resource_managers_.emplace(platform_name, new ResourceMgr()).first;
}
return it->second;
}
private:
mutex mutex_;
std::map<std::string, ResourceMgr*> resource_managers_;
};
} // namespace
/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
OpKernelContext* ctx, ResourceMgr** rm) {
*rm = ctx->resource_manager();
const XlaDevice::Metadata* metadata;
TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
*rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name());
return Status::OK();
}
/* static */ xla::StatusOr<RefPtr<XRTCompilationCache>>
XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
OpKernelContext* ctx, int64 max_number_of_entries) {
ResourceMgr* rm;
TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm));
return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries);
}
/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
const XlaDevice::Metadata* metadata;
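The practical effect of the ResourceMgrArena above (a behavioral sketch; the platform name is an assumption for illustration): XRT ops running on different GPU devices of the same platform now resolve to a single process-wide ResourceMgr instead of each device's own ctx->resource_manager(), which is what allows one XRTCompilationCache to be shared by all of them.

// Behavioral sketch only; ResourceMgrArena sits in an anonymous namespace and
// is reached via XRTGenericDeviceAccessor::GetResourceManager(). Using "CUDA"
// as the platform name here is an assumption.
ResourceMgr* rm_gpu0 = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
ResourceMgr* rm_gpu1 = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
CHECK_EQ(rm_gpu0, rm_gpu1);  // one manager, hence one shared compilation cache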

View File

@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor {
public:
static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
static xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
OpKernelContext* ctx, int64 max_number_of_entries);
// We use a ScopedRef pattern here even though it's not strictly necessary,
// just so that templated uses of this and the TPU accessor class will be as
// similar as possible.