Enable the XRT compilation cache to be shared among multiple GPU devices.
Allow XRT on GPU to work with multi-threaded replication, where a single process sees all the available devices.

PiperOrigin-RevId: 310376508
Change-Id: I25715feaf74ceca421ba8939405f58a0bf68ee59
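As a rough illustration of the multi-threaded replication model this change targets, here is a hedged sketch (not part of the commit): it assumes a hypothetical per-device helper RunReplica() and uses the CommonExecutionConfig message introduced below, with one thread per local GPU in a single process.

// Hypothetical driver sketch: one thread per local GPU in a single process,
// all replicas sharing one run_id and one local-to-global replica mapping.
#include <cstdint>
#include <thread>
#include <vector>

#include "tensorflow/compiler/xrt/xrt.pb.h"

// Assumed helper (not in this commit) that issues XRTExecute on the given
// device ordinal with the supplied per-replica configuration.
void RunReplica(int device_ordinal, const xrt::CommonExecutionConfig& config);

void RunAllReplicas(int num_devices, int64_t run_id) {
  xrt::CommonExecutionConfig common;
  common.set_run_id(run_id);
  for (int i = 0; i < num_devices; ++i) {
    // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
    common.add_local_replica_mapping(i);
  }
  std::vector<std::thread> threads;
  threads.reserve(num_devices);
  for (int ordinal = 0; ordinal < num_devices; ++ordinal) {
    xrt::CommonExecutionConfig per_replica = common;
    per_replica.set_replica_id(ordinal);
    threads.emplace_back(
        [ordinal, per_replica] { RunReplica(ordinal, per_replica); });
  }
  for (auto& t : threads) t.join();
}

Each thread reuses the same run_id and local_replica_mapping, while replica_id identifies which replica that thread is driving.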
commit 70e9708e23 (parent 1d4b4a6706)
@@ -50,6 +50,7 @@ class RunId {
  public:
   // Creates a new, unique RunId.
   RunId();
+  explicit RunId(int64 value) : data_(value) {}
 
   RunId(const RunId&) = default;
   RunId& operator=(const RunId&) = default;
@@ -158,7 +158,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
     argument_layout_ptrs[i] = &argument_layouts[i];
   }
   xla::ExecutableBuildOptions build_options;
-  build_options.set_device_ordinal(client->default_device_ordinal());
+  build_options.set_device_ordinal(device_ref.device_ordinal());
   build_options.set_num_replicas(num_replicas);
   build_options.set_result_layout(xla::Shape(config.program_shape().result()));
   build_options.set_device_allocator(device_ref.backend()->memory_allocator());
@@ -206,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
 
   // Process-wide cache of XLA executables.
-  auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
+  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+      ctx, /*max_number_of_entries=*/0);
   OP_REQUIRES_OK(ctx, cache_or.status());
   auto cache = cache_or.ConsumeValueOrDie();
 
@@ -259,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
   auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
 
-  ResourceMgr* rm;
-  OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
-
   // Process-wide cache of XLA executables.
-  XRTCompilationCache* cache;
-  OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
-                          rm->default_container(),
-                          kXRTCompilationCacheResourceName, &cache));
-  core::ScopedUnref cache_unref(cache);
+  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+      ctx, /*max_number_of_entries=*/0);
+  OP_REQUIRES_OK(ctx, cache_or.status());
+  auto cache = cache_or.ConsumeValueOrDie();
 
   const Tensor& keys_tensor = ctx->input(0);
   auto flat_keys = keys_tensor.flat<int64>();
@@ -149,13 +149,17 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
 xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
     OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(device_ref->backend()->memory_allocator());
   run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
   run_options.set_rng_seed(rng_seed);
+  if (config.run_id() != 0) {
+    run_options.set_run_id(xla::RunId(config.run_id()));
+  }
   if (executable->executable()
           ->module_config()
           .has_static_device_assignment()) {
@@ -164,8 +168,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
   }
   xla::GpuExecutableRunOptions gpu_options;
   std::vector<xla::GlobalDeviceId> gpu_global_ids;
-  if (replica_id >= 0) {
-    gpu_global_ids.emplace_back(replica_id);
+  if (config.local_replica_mapping_size() > 0) {
+    gpu_global_ids.reserve(config.local_replica_mapping_size());
+    for (auto& gid : config.local_replica_mapping()) {
+      gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
+    }
     gpu_options.set_gpu_global_device_ids(gpu_global_ids);
   }
   std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
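A worked example of the new mapping (values are illustrative, not from this commit): with local_replica_mapping = {4, 5}, local device ordinal 0 drives global replica 4 and ordinal 1 drives global replica 5, and RunExecutable forwards those IDs to the GPU executable run options.

// Illustrative fragment only; the mapping values are made up.
xrt::CommonExecutionConfig config;
config.add_local_replica_mapping(4);  // local device ordinal 0 -> global replica 4
config.add_local_replica_mapping(5);  // local device ordinal 1 -> global replica 5
// RunExecutable then builds gpu_global_ids = {GlobalDeviceId(4), GlobalDeviceId(5)}
// and passes it via gpu_options.set_gpu_global_device_ids(gpu_global_ids).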
@@ -222,10 +229,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     OpKernelContext* context, XRTMemoryManager* memory_manager,
     XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   auto runfn = [&]() {
     return RunExecutable(context, device_ref, executable, input_buffers, stream,
-                         rng_seed, replica_id);
+                         rng_seed, config);
   };
 
   // We pass zero as requested_free_size as there is no simple way to get the
@@ -241,14 +249,15 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable,
     const std::vector<InputCoords>& input_coords, bool release_inputs,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   XRTMemoryManager::WorkingSet working_set(memory_manager);
   TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
                       GetInputBuffers(&working_set, device_ref->backend(),
                                       input_coords, release_inputs));
   return ExecuteComputation(context, memory_manager.get(), device_ref,
                             executable, input_buffers, stream, rng_seed,
-                            replica_id);
+                            config);
 }
 
 // XRTExecuteOp
@@ -297,8 +306,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
   bool release_inputs = config_proto.release_input_handles();
   bool release_compilation = config_proto.release_compilation_handle();
 
-  TF_ASSIGN_OR_RETURN(
-      auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
+  TF_ASSIGN_OR_RETURN(auto cache,
+                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+                          context, /*max_number_of_entries=*/0));
   // We are guaranteed that the underlying device object won't be deleted out
   // from under us, while the ScopedRef is live.
   class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -330,7 +340,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
       RefPtr<XRTTupleAllocation> output_tuple,
       ExecuteComputation(context, memory_manager, &device_ref, executable,
                          input_coords, release_inputs, stream, rng_seed,
-                         config_proto.replica_id()));
+                         config_proto.common_config()));
 
   return CreateExecuteOutput(context, memory_manager.get(),
                              std::move(output_tuple),
@@ -379,8 +389,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
   xrt::XRTChainedExecuteConfig config;
   TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));
 
-  TF_ASSIGN_OR_RETURN(
-      auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
+  TF_ASSIGN_OR_RETURN(auto cache,
+                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+                          context, /*max_number_of_entries=*/0));
   // We are guaranteed that the underlying device object won't be deleted out
   // from under us, while the ScopedRef is live.
   class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -408,7 +419,7 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
 
     return ExecuteComputation(context, memory_manager.get(), &device_ref,
                               executable, input_buffers, stream, rng_seed,
-                              config.replica_id());
+                              config.common_config());
   };
 
   return ExecuteChained(context, memory_manager, device_ref.backend(),
@@ -111,6 +111,17 @@ message XLATupleNode {
   repeated XLATupleNode tuples = 3;
 }
 
+message CommonExecutionConfig {
+  // The replica index this execute is driving.
+  int32 replica_id = 1;
+  // Mapping local device ordinals to global replica IDs.
+  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
+  repeated int32 local_replica_mapping = 2;
+  // The execution run ID used to correlate different XRT execute operations
+  // happening in parallel from different threads.
+  int64 run_id = 3;
+}
+
 // Options for an XLA execution.
 message XRTExecutionConfig {
   // Local device to run on. This is present because the execute Op
@@ -133,8 +144,9 @@ message XRTExecutionConfig {
   // a single tuple allocation the execution will return a vector of
   // allocations, one for each of the first-level elements of the result tuple.
   bool return_exploded_tuple = 7;
-  // The replica index this execute is driving.
-  int32 replica_id = 8;
+  reserved 8;
+  // The common configuration for XRT execute operations.
+  CommonExecutionConfig common_config = 9;
 }
 
 message XRTChainedExecuteConfig {
@@ -145,8 +157,9 @@ message XRTChainedExecuteConfig {
   // Optional key to disambiguate between executions. This is only needed if
   // multiple host send/recvs may be outstanding concurrently with executions.
   string execution_instance_key = 3;
-  // The replica index this execute is driving.
-  int32 replica_id = 4;
+  reserved 4;
+  // The common configuration for XRT execute operations.
+  CommonExecutionConfig common_config = 5;
 }
 
 // A single chained execute operation. An operation can either be a device data
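A brief sketch of how a client might fill the new proto fields through the generated C++ protobuf API (the field values here are illustrative, not taken from this commit):

xrt::CommonExecutionConfig common;
common.set_replica_id(0);
common.add_local_replica_mapping(0);
common.add_local_replica_mapping(1);
common.set_run_id(42);  // non-zero, so RunExecutable sets xla::RunId

// The same common_config can be embedded in either execute config.
xrt::XRTExecutionConfig exec_config;
*exec_config.mutable_common_config() = common;

xrt::XRTChainedExecuteConfig chained_config;
*chained_config.mutable_common_config() = common;

Note that the old replica_id field numbers (8 in XRTExecutionConfig, 4 in XRTChainedExecuteConfig) are reserved rather than reused, so stale serialized configs cannot silently alias the new fields.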
@@ -17,19 +17,56 @@ limitations under the License.
 
 #include "tensorflow/compiler/xrt/xrt_device.h"
 
+#include <map>
+
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
 
 namespace tensorflow {
+namespace {
+
+class ResourceMgrArena {
+ public:
+  static ResourceMgrArena* Get() {
+    static ResourceMgrArena* arena = new ResourceMgrArena();
+    return arena;
+  }
+
+  ResourceMgr* GetResourceMgr(const std::string& platform_name) {
+    mutex_lock lock(mutex_);
+    auto it = resource_managers_.find(platform_name);
+    if (it == resource_managers_.end()) {
+      it = resource_managers_.emplace(platform_name, new ResourceMgr()).first;
+    }
+    return it->second;
+  }
+
+ private:
+  mutex mutex_;
+  std::map<std::string, ResourceMgr*> resource_managers_;
+};
+
+}  // namespace
 
 /*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
     OpKernelContext* ctx, ResourceMgr** rm) {
-  *rm = ctx->resource_manager();
+  const XlaDevice::Metadata* metadata;
+  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+  *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name());
   return Status::OK();
 }
 
+/* static */ xla::StatusOr<RefPtr<XRTCompilationCache>>
+XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+    OpKernelContext* ctx, int64 max_number_of_entries) {
+  ResourceMgr* rm;
+  TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm));
+  return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries);
+}
+
 /*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
     OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
   const XlaDevice::Metadata* metadata;
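The net effect is that GetResourceManager now returns a per-platform, process-wide ResourceMgr instead of the per-device one from the OpKernelContext, so every device ordinal of the same platform shares one XRTCompilationCache. A minimal sketch of that property (illustrative only: ResourceMgrArena sits in an anonymous namespace inside xrt_device.cc, and the platform name string here is an assumption):

// Two lookups with the same platform name return the same ResourceMgr,
// so kernels running on different GPU ordinals share one compilation cache.
ResourceMgr* rm0 = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
ResourceMgr* rm1 = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
CHECK_EQ(rm0, rm1);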
@@ -19,6 +19,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
 
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 
@@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor {
  public:
   static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
 
+  static xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
+      OpKernelContext* ctx, int64 max_number_of_entries);
+
   // We use a ScopedRef pattern here even though it's not strictly necessary,
   // just so that templated uses of this and the TPU accessor class will be as
   // similar as possible.