diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 43ee0fdd820..8ae8c418d5d 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -50,6 +50,7 @@ class RunId {
  public:
   // Creates a new, unique RunId.
   RunId();
+  explicit RunId(int64 value) : data_(value) {}
 
   RunId(const RunId&) = default;
   RunId& operator=(const RunId&) = default;
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index 1bcd8561e61..ba6e6a093d6 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -158,7 +158,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
     argument_layout_ptrs[i] = &argument_layouts[i];
   }
   xla::ExecutableBuildOptions build_options;
-  build_options.set_device_ordinal(client->default_device_ordinal());
+  build_options.set_device_ordinal(device_ref.device_ordinal());
   build_options.set_num_replicas(num_replicas);
   build_options.set_result_layout(xla::Shape(config.program_shape().result()));
   build_options.set_device_allocator(device_ref.backend()->memory_allocator());
@@ -206,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
 
   // Process-wide cache of XLA executables.
-  auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0);
+  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+      ctx, /*max_number_of_entries=*/0);
   OP_REQUIRES_OK(ctx, cache_or.status());
   auto cache = cache_or.ConsumeValueOrDie();
 
@@ -259,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
   auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
 
-  ResourceMgr* rm;
-  OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
-
   // Process-wide cache of XLA executables.
-  XRTCompilationCache* cache;
-  OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
-                          rm->default_container(),
-                          kXRTCompilationCacheResourceName, &cache));
-  core::ScopedUnref cache_unref(cache);
+  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+      ctx, /*max_number_of_entries=*/0);
+  OP_REQUIRES_OK(ctx, cache_or.status());
+  auto cache = cache_or.ConsumeValueOrDie();
 
   const Tensor& keys_tensor = ctx->input(0);
   auto flat_keys = keys_tensor.flat<int64>();
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
index b641f333e8b..d39b37387f2 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -149,13 +149,17 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
 xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
     OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   VLOG(2) << "Executing computation.";
   xla::ExecutableRunOptions run_options;
   run_options.set_stream(stream);
   run_options.set_allocator(device_ref->backend()->memory_allocator());
   run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
   run_options.set_rng_seed(rng_seed);
+  if (config.run_id() != 0) {
+    run_options.set_run_id(xla::RunId(config.run_id()));
+  }
   if (executable->executable()
           ->module_config()
           .has_static_device_assignment()) {
@@ -164,8 +168,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
   }
   xla::GpuExecutableRunOptions gpu_options;
   std::vector<xla::GlobalDeviceId> gpu_global_ids;
-  if (replica_id >= 0) {
-    gpu_global_ids.emplace_back(replica_id);
+  if (config.local_replica_mapping_size() > 0) {
+    gpu_global_ids.reserve(config.local_replica_mapping_size());
+    for (auto& gid : config.local_replica_mapping()) {
+      gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid));
+    }
     gpu_options.set_gpu_global_device_ids(gpu_global_ids);
   }
   std::shared_ptr<NcclUniqueIdFactory> nccl_factory = GetNcclUniqueIdFactory();
@@ -222,10 +229,11 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     OpKernelContext* context, XRTMemoryManager* memory_manager,
     XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable, const InputBuffers& input_buffers,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   auto runfn = [&]() {
     return RunExecutable(context, device_ref, executable, input_buffers, stream,
-                         rng_seed, replica_id);
+                         rng_seed, config);
   };
 
   // We pass zero as requested_free_size as there is no simple way to get the
@@ -241,14 +249,15 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> ExecuteComputation(
     XRTGenericDeviceAccessor::ScopedRef* device_ref,
     xla::LocalExecutable* executable,
     const std::vector<InputCoords>& input_coords, bool release_inputs,
-    se::Stream* stream, int rng_seed, int replica_id) {
+    se::Stream* stream, int rng_seed,
+    const xrt::CommonExecutionConfig& config) {
   XRTMemoryManager::WorkingSet working_set(memory_manager);
   TF_ASSIGN_OR_RETURN(InputBuffers input_buffers,
                       GetInputBuffers(&working_set, device_ref->backend(),
                                       input_coords, release_inputs));
   return ExecuteComputation(context, memory_manager.get(), device_ref,
                             executable, input_buffers, stream, rng_seed,
-                            replica_id);
+                            config);
 }
 
 // XRTExecuteOp
@@ -297,8 +306,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
   bool release_inputs = config_proto.release_input_handles();
   bool release_compilation = config_proto.release_compilation_handle();
 
-  TF_ASSIGN_OR_RETURN(
-      auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
+  TF_ASSIGN_OR_RETURN(auto cache,
+                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+                          context, /*max_number_of_entries=*/0));
   // We are guaranteed that the underlying device object won't be deleted out
   // from under us, while the ScopedRef is live.
   class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -330,7 +340,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) {
       RefPtr<XRTTupleAllocation> output_tuple,
       ExecuteComputation(context, memory_manager, &device_ref, executable,
                          input_coords, release_inputs, stream, rng_seed,
-                         config_proto.replica_id()));
+                         config_proto.common_config()));
 
   return CreateExecuteOutput(context, memory_manager.get(),
                              std::move(output_tuple),
@@ -379,8 +389,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
   xrt::XRTChainedExecuteConfig config;
   TF_RET_CHECK(ParseFromTString(execution_config.scalar<tstring>()(), &config));
 
-  TF_ASSIGN_OR_RETURN(
-      auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0));
+  TF_ASSIGN_OR_RETURN(auto cache,
+                      XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+                          context, /*max_number_of_entries=*/0));
   // We are guaranteed that the underlying device object won't be deleted out
   // from under us, while the ScopedRef is live.
   class XRTGenericDeviceAccessor::ScopedRef device_ref;
@@ -408,7 +419,7 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
 
     return ExecuteComputation(context, memory_manager.get(), &device_ref,
                               executable, input_buffers, stream, rng_seed,
-                              config.replica_id());
+                              config.common_config());
   };
 
   return ExecuteChained(context, memory_manager, device_ref.backend(),
diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto
index 1cbd851f7ef..9a351732c4b 100644
--- a/tensorflow/compiler/xrt/xrt.proto
+++ b/tensorflow/compiler/xrt/xrt.proto
@@ -111,6 +111,17 @@ message XLATupleNode {
   repeated XLATupleNode tuples = 3;
 }
 
+message CommonExecutionConfig {
+  // The replica index this execute is driving.
+  int32 replica_id = 1;
+  // Maps local device ordinals to global replica IDs:
+  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
+  repeated int32 local_replica_mapping = 2;
+  // The execution run ID used to correlate different XRT execute operations
+  // happening in parallel from different threads.
+  int64 run_id = 3;
+}
+
 // Options for an XLA execution.
 message XRTExecutionConfig {
   // Local device to run on. This is present because the execute Op
@@ -133,8 +144,9 @@ message XRTExecutionConfig {
   // a single tuple allocation the execution will return a vector of
   // allocations, one for each of the first-level elements of the result tuple.
   bool return_exploded_tuple = 7;
-  // The replica index this execute is driving.
-  int32 replica_id = 8;
+  reserved 8;
+  // The common configuration for XRT execute operations.
+  CommonExecutionConfig common_config = 9;
 }
 
 message XRTChainedExecuteConfig {
@@ -145,8 +157,9 @@ message XRTChainedExecuteConfig {
   // Optional key to disambiguate between executions. This is only needed if
   // multiple host send/recvs may be outstanding concurrently with executions.
   string execution_instance_key = 3;
-  // The replica index this execute is driving.
-  int32 replica_id = 4;
+  reserved 4;
+  // The common configuration for XRT execute operations.
+  CommonExecutionConfig common_config = 5;
 }
 
 // A single chained execute operation. An operation can either be a device data
diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc
index 1b5557d556d..46954572c5d 100644
--- a/tensorflow/compiler/xrt/xrt_device.cc
+++ b/tensorflow/compiler/xrt/xrt_device.cc
@@ -17,19 +17,56 @@ limitations under the License.
 
 #include "tensorflow/compiler/xrt/xrt_device.h"
 
+#include <map>
+
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
 
 namespace tensorflow {
+namespace {
+
+class ResourceMgrArena {
+ public:
+  static ResourceMgrArena* Get() {
+    static ResourceMgrArena* arena = new ResourceMgrArena();
+    return arena;
+  }
+
+  ResourceMgr* GetResourceMgr(const std::string& platform_name) {
+    mutex_lock lock(mutex_);
+    auto it = resource_managers_.find(platform_name);
+    if (it == resource_managers_.end()) {
+      it = resource_managers_.emplace(platform_name, new ResourceMgr()).first;
+    }
+    return it->second;
+  }
+
+ private:
+  mutex mutex_;
+  std::map<std::string, ResourceMgr*> resource_managers_;
+};
+
+}  // namespace
 
 /*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
     OpKernelContext* ctx, ResourceMgr** rm) {
-  *rm = ctx->resource_manager();
+  const XlaDevice::Metadata* metadata;
+  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+  *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name());
   return Status::OK();
 }
 
+/* static */ xla::StatusOr<RefPtr<XRTCompilationCache>>
+XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
+    OpKernelContext* ctx, int64 max_number_of_entries) {
+  ResourceMgr* rm;
+  TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm));
+  return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries);
+}
+
 /*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
     OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
   const XlaDevice::Metadata* metadata;
diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h
index 5ebee7641f0..02fab315830 100644
--- a/tensorflow/compiler/xrt/xrt_device.h
+++ b/tensorflow/compiler/xrt/xrt_device.h
@@ -19,6 +19,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
 
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 
@@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor {
  public:
   static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
 
+  static xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
+      OpKernelContext* ctx, int64 max_number_of_entries);
+
   // We use a ScopedRef pattern here even though it's not strictly necessary,
   // just so that templated uses of this and the TPU accessor class will be as
   // similar as possible.
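
For illustration, below is a minimal client-side sketch of how the new CommonExecutionConfig might be populated when building the serialized XRTExecutionConfig fed to the execute op. The helper name BuildExecutionConfig and the generated header path "tensorflow/compiler/xrt/xrt.pb.h" are assumptions for the example, not part of this change; the accessors are the standard protoc-generated ones for the fields added above.

#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/compiler/xrt/xrt.pb.h"

// Hypothetical helper: fills in the new CommonExecutionConfig submessage and
// returns the serialized XRTExecutionConfig passed to the XRTExecute op.
std::string BuildExecutionConfig(int32_t replica_id, int64_t run_id,
                                 const std::vector<int32_t>& global_replica_ids) {
  xrt::XRTExecutionConfig config;
  config.set_release_input_handles(true);

  xrt::CommonExecutionConfig* common = config.mutable_common_config();
  // Replica driven by this execute call (previously the replica_id = 8 field).
  common->set_replica_id(replica_id);
  // Correlates the parallel executes of one replicated step; the kernel only
  // forwards it to xla::RunId when it is non-zero.
  common->set_run_id(run_id);
  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID, used to
  // build the xla::GlobalDeviceId list for GpuExecutableRunOptions.
  for (int32_t gid : global_replica_ids) {
    common->add_local_replica_mapping(gid);
  }
  return config.SerializeAsString();
}

The same submessage applies to XRTChainedExecuteConfig, where field 4 is likewise reserved in favor of common_config = 5.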