diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 67f3a6c1df4..33d5b6f1d4d 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
   return execution_profile_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
+    DeviceAssignment* device_assignment) {
+  device_assignment_ = device_assignment;
+  return *this;
+}
+
+DeviceAssignment* ExecutableRunOptions::device_assignment() const {
+  return device_assignment_;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 03f2d016ad0..deb3ddb203d 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -40,6 +40,7 @@ struct ThreadPoolDevice;
 namespace xla {
 
 class DeviceMemoryAllocator;
+class DeviceAssignment;
 class ExecutionProfile;
 
 // Class containing options for running a LocalExecutable.
@@ -79,9 +80,14 @@ class ExecutableRunOptions {
   ExecutionProfile* execution_profile() const;
   ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
 
+  ExecutableRunOptions& set_device_assignment(
+      DeviceAssignment* device_assignment);
+  DeviceAssignment* device_assignment() const;
+
  private:
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
+  DeviceAssignment* device_assignment_ = nullptr;
   perftools::gputools::Stream* stream_ = nullptr;
   tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index acaf9eaafb2..9a041e3f412 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -332,6 +332,7 @@ cc_library(
    hdrs = ["backend.h"],
    deps = [
        ":compiler",
+        ":computation_placer",
        ":device_memory_allocator",
        ":platform_util",
        ":pool",
@@ -950,6 +951,26 @@ cc_test(
    ],
 )
 
+cc_library(
+    name = "computation_placer",
+    srcs = ["computation_placer.cc"],
+    hdrs = ["computation_placer.h"],
+    deps = [
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+    alwayslink = True,  # Contains per-platform computation placer registration
+)
+
 cc_library(
    name = "generic_transfer_manager",
    srcs = ["generic_transfer_manager.cc"],
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 66d54ad3802..83f954789c9 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -96,9 +96,11 @@ struct Backend::EigenThreadPoolWrapper {
                       PlatformUtil::GetStreamExecutors(platform));
   TF_ASSIGN_OR_RETURN(auto transfer_manager,
                       TransferManager::GetForPlatform(platform));
-  std::unique_ptr<Backend> backend(
-      new Backend(replica_count, platform, compiler, stream_executors,
-                  transfer_manager, options.intra_op_parallelism_threads()));
+  TF_ASSIGN_OR_RETURN(auto computation_placer,
+                      ComputationPlacer::GetForPlatform(platform));
+  std::unique_ptr<Backend> backend(new Backend(
+      replica_count, platform, compiler, stream_executors, transfer_manager,
+      computation_placer, options.intra_op_parallelism_threads()));
   return std::move(backend);
 }
 
@@ -135,10 +137,12 @@ Backend::Backend(
     int64 replica_count, perftools::gputools::Platform* platform,
     Compiler* compiler,
     tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
-    TransferManager* transfer_manager, int intra_op_parallelism_threads)
+    TransferManager* transfer_manager, ComputationPlacer* computation_placer,
+    int intra_op_parallelism_threads)
     : platform_(platform),
       compiler_(compiler),
       transfer_manager_(transfer_manager),
+      computation_placer_(computation_placer),
       replica_count_(replica_count) {
   // The given set of stream executors set may include invalid executors.
   for (se::StreamExecutor* exec : stream_executors) {
@@ -179,36 +183,6 @@ int Backend::default_device_ordinal() const {
   return default_stream_executor()->device_ordinal();
 }
 
-StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
-    int device_ordinal) const {
-  if (stream_executors_[device_ordinal] == nullptr) {
-    return InvalidArgument("device %s not supported by XLA service",
-                           device_name(device_ordinal).c_str());
-  }
-
-  // Find replica_count_ stream executors starting from the given device
-  // ordinal.
-  std::vector<perftools::gputools::StreamExecutor*> replicas;
-  for (se::StreamExecutor* exec : stream_executors_) {
-    CHECK(exec != nullptr);
-    if (exec->device_ordinal() >= device_ordinal) {
-      replicas.push_back(exec);
-      if (replicas.size() >= replica_count_) {
-        return replicas;
-      }
-    }
-  }
-
-  return InvalidArgument(
-      "Not enough devices for replicas for the device ordinal %d",
-      device_ordinal);
-}
-
-std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
-  CHECK_GE(stream_executors_.size(), replica_count_);
-  return Replicas(default_device_ordinal()).ValueOrDie();
-}
-
 tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
   return inter_op_thread_pool_.get();
 }
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index e0b15dc43f2..3ead274cde8 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/pool.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -40,6 +41,8 @@ struct ThreadPoolDevice;
 namespace xla {
 
 // Options to configure the backend when it is created.
+//
+// TODO(b/62588571): Remove the notion of replicas from the backend.
 class BackendOptions {
  public:
   // Set the platform backing the backend, or nullptr for the default platform.
@@ -92,11 +95,15 @@ class Backend {
     return memory_allocator_.get();
   }
   TransferManager* transfer_manager() const { return transfer_manager_; }
+  ComputationPlacer* computation_placer() const { return computation_placer_; }
 
   // Returns the number of devices of the platform type which are visible. Not
   // all of these devices may be usable by XLA.
   int device_count() const { return stream_executors_.size(); }
 
+  // Returns the number of replicas when replication is enabled.
+  int replica_count() const { return replica_count_; }
+
   // Returns the device ordinal number of the default device.
   int default_device_ordinal() const;
 
@@ -107,24 +114,13 @@ class Backend {
     return stream_executors_;
   }
 
-  // Returns the replicas for the default stream executor.
-  //
-  // When the number of replicas is R, the first R stream executors are assigned
-  // to the replicas of the default stream executor.
-  std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
-
-  // Returns the replicas for the given device_ordinal. The given device ordinal
-  // is considered to be the first device ordinal among the replicas. Returns an
-  // error status if the stream executor for the given given device ordinal does
-  // not exist or if there are not enough stream executors for the replicas.
-  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
-      int device_ordinal) const;
-
-  // Return the stream executor for the given device ordinal.
+  // Returns the stream executor for the given device ordinal.
   StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
       int device_ordinal) const;
 
-  // Return the stream executor for the default device ordinal.
+  // Returns the stream executor for the default device ordinal. This stream
+  // executor can only be used when the number of computations is 1 (replication
+  // can be > 1).
   perftools::gputools::StreamExecutor* default_stream_executor() const {
     CHECK(!stream_executors_.empty());
     return stream_executors_[0];
@@ -178,13 +174,16 @@ class Backend {
           Compiler* compiler,
           tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
               stream_executors,
-          TransferManager* transfer_manager, int intra_op_parallelism_threads);
+          TransferManager* transfer_manager,
+          ComputationPlacer* computation_placer,
+          int intra_op_parallelism_threads);
   Backend(const Backend&) = delete;
   Backend& operator=(const Backend&) = delete;
 
   perftools::gputools::Platform* platform_;
   Compiler* compiler_;
   TransferManager* transfer_manager_;
+  ComputationPlacer* computation_placer_;
   int64 replica_count_ = -1;
 
   // Vector of stream executors. stream_executors_[0] is the default executor.
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
new file mode 100644
index 00000000000..cdf277581f4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
+  proto->set_replica_count(replica_count());
+  proto->set_computation_count(computation_count());
+  for (int computation = 0; computation < computation_count(); ++computation) {
+    DeviceAssignmentProto::ComputationDevice* computation_device =
+        proto->add_computation_devices();
+    for (int replica = 0; replica < replica_count(); ++replica) {
+      computation_device->add_replica_device_ids((*this)(replica, computation));
+    }
+  }
+  return Status::OK();
+}
+
+/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
+    const DeviceAssignmentProto& proto) {
+  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
+  DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
+  for (int computation = 0; computation < proto.computation_count();
+       ++computation) {
+    const auto& computation_device = proto.computation_devices(computation);
+    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
+                 proto.replica_count());
+    for (int replica = 0; replica < proto.replica_count(); ++replica) {
+      assignment(replica, computation) =
+          computation_device.replica_device_ids(replica);
+    }
+  }
+  return std::move(assignment);
+}
+
+StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
+                                          int replica_count,
+                                          int computation_count) {
+  TF_RET_CHECK(replica < replica_count);
+  TF_RET_CHECK(computation < computation_count);
+
+  return computation * replica_count + replica;
+}
+
+StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
+    int replica_count, int computation_count) {
+  DeviceAssignment assignment(replica_count, computation_count);
+  for (int replica = 0; replica < replica_count; ++replica) {
+    for (int computation = 0; computation < computation_count; ++computation) {
+      TF_ASSIGN_OR_RETURN(
+          int device_id,
+          DeviceId(replica, computation, replica_count, computation_count));
+      assignment(replica, computation) = device_id;
+    }
+  }
+  return std::move(assignment);
+}
+
+/* static */ void ComputationPlacer::RegisterComputationPlacer(
+    se::Platform::Id platform_id,
+    ComputationPlacerCreationFunction creation_function) {
+  tensorflow::mutex_lock lock(
+      *ComputationPlacer::platform_computation_placer_mutex());
+  auto* computation_placers = GetPlatformComputationPlacers();
+  CHECK(computation_placers->find(platform_id) == computation_placers->end());
+  (*computation_placers)[platform_id].creation_function = creation_function;
+}
+
+/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
+    const se::Platform* platform) {
+  tensorflow::mutex_lock lock(
+      *ComputationPlacer::platform_computation_placer_mutex());
+  auto* computation_placers = GetPlatformComputationPlacers();
+
+  auto it = computation_placers->find(platform->id());
+  if (it == computation_placers->end()) {
+    return NotFound(
+        "could not find registered computation placer for platform %s -- check "
+        "target linkage",
+        platform->Name().c_str());
+  }
+
+  if (it->second.placer == nullptr) {
+    // Lazily create the computation placer the first time it is needed.
+    it->second.placer = (*it->second.creation_function)();
+  }
+
+  return it->second.placer.get();
+}
+
+/* static */ tensorflow::mutex*
+ComputationPlacer::platform_computation_placer_mutex() {
+  static tensorflow::mutex* m = new tensorflow::mutex;
+  return m;
+}
+
+/* static */ std::map<perftools::gputools::Platform::Id,
+                      ComputationPlacer::State>*
+ComputationPlacer::GetPlatformComputationPlacers() {
+  static auto* r =
+      new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
+  return r;
+}
+
+}  // namespace xla
+
+static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
+  return xla::MakeUnique<xla::ComputationPlacer>();
+}
+
+static bool InitModule() {
+  xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
+                                                    &CreateComputationPlacer);
+  xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
+                                                    &CreateComputationPlacer);
+  return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h
new file mode 100644
index 00000000000..4d26d6bb85f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_placer.h
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class that represents the device assignment for a set of XLA replicated
+// computations. For R replicas and C computations, R * C devices are required
+// to execute the computation in parallel. The assigned device ids can be
+// accessed by assignment(replica, computation).
+class DeviceAssignment : public Array2D<int> {
+ public:
+  DeviceAssignment() {}
+  DeviceAssignment(int replica_count, int computation_count)
+      : Array2D<int>(replica_count, computation_count, -1) {
+    CHECK_GT(replica_count, 0);
+    CHECK_GT(computation_count, 0);
+  }
+
+  int replica_count() const { return height(); }
+  int computation_count() const { return width(); }
+
+  // Protocol buffer serialization and deserialization.
+  Status Serialize(DeviceAssignmentProto* proto) const;
+  static StatusOr<DeviceAssignment> Deserialize(
+      const DeviceAssignmentProto& proto);
+};
+
+// A generic implementation of the XLA computation placer, which assigns device
+// ids to a set of replicated computations.
+class ComputationPlacer {
+ public:
+  ComputationPlacer() {}
+  virtual ~ComputationPlacer() {}
+
+  // Returns the device id assigned to the given replica and computation
+  // instance for [replica_count x computation_count] setup. The returned device
+  // id must match the assignment from PlaceReplicatedComputation().
+  virtual StatusOr<int> DeviceId(int replica, int computation,
+                                 int replica_count, int computation_count);
+
+  // Returns the device ids assigned to a set of replicated computations, given
+  // the number of replicas and the number of computations.
+  virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
+                                                   int computation_count);
+
+  using ComputationPlacerCreationFunction =
+      std::unique_ptr<ComputationPlacer> (*)();
+
+  // Registers a computation placer creation function for a particular platform.
+  static void RegisterComputationPlacer(
+      perftools::gputools::Platform::Id platform_id,
+      ComputationPlacerCreationFunction creation_function);
+
+  // Returns the computation placer singleton pointer if it is available for the
+  // given platform, or an error status if it is not.
+  static StatusOr<ComputationPlacer*> GetForPlatform(
+      const perftools::gputools::Platform* platform);
+
+ private:
+  // Routine that returns the mutex that guards the platform-to-computation
+  // placer map. Done as a routine to ensure correct initialization ordering,
+  // since RegisterComputationPlacer can be called during program initialization
+  // time.
+  static tensorflow::mutex* platform_computation_placer_mutex();
+
+  // State kept for each kind of ComputationPlacer. Registration functions set
+  // up creation_function, and then we use that to lazily create "placer" the
+  // first time GetForPlatform is invoked for a particular id.
+  struct State {
+    std::unique_ptr<ComputationPlacer> placer;
+    ComputationPlacerCreationFunction creation_function = nullptr;
+  };
+
+  // Map from platform kind to computation placer singleton.
+  static std::map<perftools::gputools::Platform::Id, State>*
+  GetPlatformComputationPlacers();
+
+  perftools::gputools::Platform::Id platform_id_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 131c2ee87b0..748124ebc70 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
   // Construct computation layout from the argument layouts.
   auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
   module_config->set_has_hybrid_result(has_hybrid_result);
-  module_config->set_replica_count(execute_backend_->Replicas().size());
+  module_config->set_replica_count(execute_backend_->replica_count());
   legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
   if (flags->xla_hlo_profile) {
     module_config->enable_hlo_profiling(true);
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 79bf679c5fd..5812d3e4874 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -325,7 +325,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  module_config->set_replica_count(backend->Replicas().size());
+  module_config->set_replica_count(backend->replica_count());
   module_config->set_seed(execution_options.seed());
   module_config->set_debug_options(execution_options.debug_options());
 
@@ -495,47 +495,55 @@ Service::ExecuteParallelAndRegisterResult(
     tensorflow::gtl::ArraySlice<
         std::vector<perftools::gputools::DeviceMemoryBase>>
         arguments,
-    Backend* backend,
-    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
+    Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
     tensorflow::gtl::ArraySlice<string> result_tags) {
-  // TODO(b/33943292): Support for replication when using multiple computations.
-  TF_RET_CHECK(backend->Replicas().size() == 1);
-
-  // Set up streams.
+  // Streams where the computations are launched, so we can wait on the streams
+  // to complete.
   std::vector<Pool<se::Stream>::SmartPtr> streams;
-  for (se::StreamExecutor* executor : executors) {
-    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
-                        backend->BorrowStream(executor));
-    streams.push_back(std::move(stream));
-  }
-
-  // Set up run options.
-  std::vector<ServiceExecutableRunOptions> run_options;
-  for (const Pool<se::Stream>::SmartPtr& stream : streams) {
-    ExecutableRunOptions options;
-    options.set_stream(stream.get());
-    options.set_allocator(backend->memory_allocator());
-    options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
-    options.set_intra_op_thread_pool(
-        backend->eigen_intra_op_thread_pool_device());
-    run_options.emplace_back(options, backend->StreamBorrower());
-  }
-
-  // Asynchronously launch all executables.
+  // Global data handles for the computation results, one for each computation.
   std::vector<GlobalDataHandle> result_handles;
-  for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
-       i < executables.size(); i++) {
-    TF_ASSIGN_OR_RETURN(
-        perftools::gputools::DeviceMemoryBase result,
-        executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
-    result_handles.push_back(allocation_tracker_.Register(
-        backend, executors[i]->device_ordinal(), result,
-        executables[i]->result_shape(), result_tags[i]));
+
+  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+                      backend->computation_placer()->AssignDevices(
+                          backend->replica_count(), executables.size()));
+
+  for (int64 i = 0; i < executables.size(); i++) {
+    // Stream executors for the replicas of the current computation.
+    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
+    for (int64 replica = 0; replica < replicas.size(); ++replica) {
+      TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+                          backend->BorrowStream(replicas[replica]));
+      streams.push_back(std::move(stream));
+
+      // Set up run options.
+      ExecutableRunOptions options;
+      options.set_stream(streams.back().get());
+      options.set_allocator(backend->memory_allocator());
+      options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
+      options.set_intra_op_thread_pool(
+          backend->eigen_intra_op_thread_pool_device());
+      options.set_device_assignment(&device_assignment);
+      ServiceExecutableRunOptions run_options(options,
+                                              backend->StreamBorrower());
+
+      // Asynchronously launch the computation.
+      TF_ASSIGN_OR_RETURN(
+          perftools::gputools::DeviceMemoryBase result,
+          executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
+
+      // All replicas share the same device address for the result allocation,
+      // so only one of the replicas needs to register the result handle.
+      if (replica == 0) {
+        result_handles.push_back(allocation_tracker_.Register(
+            backend, replicas[0]->device_ordinal(), result,
+            executables[i]->result_shape(), result_tags[i]));
+      }
+    }
   }
 
   // Wait for all executions to complete.
-  for (int64 i = 0; i < result_handles.size(); ++i) {
+  for (int64 i = 0; i < streams.size(); ++i) {
     if (!streams[i]->BlockHostUntilDone()) {
       return InternalError("failed to complete execution for stream %lld", i);
     }
@@ -550,17 +558,22 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
         arguments,
     Backend* backend, perftools::gputools::StreamExecutor* executor,
     const string& result_tag, ExecutionProfile* profile) {
-  TF_RET_CHECK(!backend->Replicas().empty());
-
   // Set up streams.
   std::vector<Pool<se::Stream>::SmartPtr> streams;
 
-  for (se::StreamExecutor* executor : backend->Replicas()) {
+  TF_ASSIGN_OR_RETURN(auto replicas,
+                      Replicas(*backend, SingleComputationDeviceHandle()));
+  TF_RET_CHECK(!replicas.empty());
+  for (se::StreamExecutor* executor : replicas) {
     TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                         backend->BorrowStream(executor));
     streams.push_back(std::move(stream));
   }
 
+  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+                      backend->computation_placer()->AssignDevices(
+                          backend->replica_count(), /*computation_count=*/1));
+
   // Set up run options.
   std::vector<ServiceExecutableRunOptions> run_options;
   for (const Pool<se::Stream>::SmartPtr& stream : streams) {
@@ -570,19 +583,20 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
     options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
     options.set_intra_op_thread_pool(
         backend->eigen_intra_op_thread_pool_device());
+    options.set_device_assignment(&device_assignment);
     run_options.emplace_back(options, backend->StreamBorrower(),
                              backend->inter_op_thread_pool());
   }
 
   perftools::gputools::DeviceMemoryBase result;
-  if (backend->Replicas().size() == 1) {
+  if (backend->replica_count() == 1) {
     TF_ASSIGN_OR_RETURN(
         result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
                     &run_options[0], profile, arguments));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
-        repeated_arguments(backend->Replicas().size(), arguments);
+        repeated_arguments(backend->replica_count(), arguments);
 
     TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                           run_options, repeated_arguments));
@@ -610,25 +624,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
   std::vector<VersionedComputationHandle> versioned_handles;
   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
   std::vector<string> computation_names;
+  std::vector<DeviceHandle> device_handles;
 
-  if (arg->requests_size() > execute_backend_->stream_executors().size()) {
+  if (arg->requests_size() * execute_backend_->replica_count() >
+      execute_backend_->device_count()) {
     return FailedPrecondition(
         "there are not enough stream executors to execute %d computations",
        arg->requests_size());
   }
 
   for (int64 i = 0; i < arg->requests_size(); ++i) {
-    // Get the stream executor on which the computation will run. Select the
-    // specific device if requested, otherwise select the i'th device from the
-    // list of available stream executors.
-    se::StreamExecutor* executor;
-    if (arg->requests(i).has_device_handle()) {
-      executor =
-          execute_backend_
-              ->stream_executors()[arg->requests(i).device_handle().handle()];
-    } else {
-      executor = execute_backend_->stream_executors()[i];
+    // Get the stream executor for the i'th computation. This stream executor
+    // is one of the executors to run the replicated computation.
+    if (!arg->requests(i).has_device_handle()) {
+      return FailedPrecondition(
+          "device handles must be given to execute parallel computations");
     }
+    TF_ASSIGN_OR_RETURN(
+        auto replicas,
+        Replicas(*execute_backend_, arg->requests(i).device_handle()));
+    se::StreamExecutor* executor = replicas[0];
     CHECK(executor != nullptr);
 
     // Resolve the UserComputation object associated with the requested
@@ -673,6 +688,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     module_configs.push_back(std::move(module_config));
     computation_names.push_back(user_computation->name());
     executors.push_back(executor);
+    device_handles.push_back(arg->requests(i).device_handle());
   }
 
   // Build the user computations into HloModules and compile to generate the
@@ -692,7 +708,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
   TF_ASSIGN_OR_RETURN(
       std::vector<GlobalDataHandle> outputs,
       ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
-                                       execute_backend_.get(), executors,
+                                       execute_backend_.get(), device_handles,
                                       computation_names));
   for (const GlobalDataHandle& output : outputs) {
     ExecuteResponse response;
@@ -706,10 +722,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
 
 tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                              GetDeviceHandlesResponse* result) {
-  const int64 available_device_count =
-      execute_backend_->stream_executors().size();
-  const int64 replicas = execute_backend_->Replicas().size();
-  if (available_device_count < arg->device_count() * replicas) {
+  const int64 available_device_count = execute_backend_->device_count();
+  const int64 replica_count = execute_backend_->replica_count();
+  if (replica_count <= 0) {
+    return FailedPrecondition("Replica count must be a positive integer");
+  }
+  if (available_device_count < arg->device_count() * replica_count) {
     return ResourceExhausted(
         "Requested device count (%lld) exceeds the number of available devices "
         "on the target (%lld)",
@@ -718,8 +736,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
 
   for (int64 i = 0; i < arg->device_count(); ++i) {
     DeviceHandle device_handle;
-    device_handle.set_handle(
-        execute_backend_->stream_executors()[i * replicas]->device_ordinal());
+    device_handle.set_handle(i);
+    device_handle.set_device_count(arg->device_count());
     *result->add_device_handles() = device_handle;
   }
 
@@ -841,11 +859,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
                           execute_backend_->default_stream_executor(),
                           &profile));
 
-  TF_RET_CHECK(!execute_backend_->Replicas().empty());
+  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
+                                              SingleComputationDeviceHandle()));
+  TF_RET_CHECK(!replicas.empty());
+
+  // Set up streams.
   std::vector<Pool<se::Stream>::SmartPtr> streams;
-  for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
+  for (se::StreamExecutor* executor : replicas) {
     TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                         execute_backend_->BorrowStream(executor));
     streams.push_back(std::move(stream));
@@ -927,19 +948,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
   Literal literal = Literal(arg->literal());
   const Shape& shape = literal.shape();
 
-  if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
+  if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
     // TODO(b/32990684): Tuple transfers to host end up allocating further
     // buffers - implement that correctly.
     return Unimplemented(
         "Tuple transfers to the device not supported with replication.");
   }
 
-  se::StreamExecutor* stream_executor;
+  std::vector<se::StreamExecutor*> replicas;
   if (arg->has_device_handle()) {
-    TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
-                                             arg->device_handle().handle()));
+    TF_ASSIGN_OR_RETURN(replicas,
+                        Replicas(*execute_backend_, arg->device_handle()));
   } else {
-    stream_executor = execute_backend_->default_stream_executor();
+    TF_ASSIGN_OR_RETURN(
+        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
   }
 
   // Allocate memory on the device, using the stream executor. The size of the
@@ -950,14 +972,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
 
   TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
                       execute_backend_->memory_allocator()->Allocate(
-                          stream_executor->device_ordinal(), allocation_size));
+                          replicas[0]->device_ordinal(), allocation_size));
 
   *result->mutable_data() = allocation_tracker_.Register(
-      execute_backend_.get(), stream_executor->device_ordinal(), allocation,
-      shape, StrCat("TransferToServer literal of size ", allocation_size));
+      execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape,
+      StrCat("TransferToServer literal of size ", allocation_size));
 
-  TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
-                                         stream_executor->device_ordinal()));
   for (se::StreamExecutor* executor : replicas) {
     TF_RETURN_IF_ERROR(
         execute_backend_->transfer_manager()->TransferLiteralToDevice(
@@ -968,7 +988,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
 
 tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                              TransferToInfeedResponse* result) {
-  const int64 replica_count = execute_backend_->Replicas().size();
+  const int64 replica_count = execute_backend_->replica_count();
   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
     return FailedPrecondition(
         "%s",
@@ -980,11 +1000,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
 
   se::StreamExecutor* executor;
   if (arg->has_device_handle()) {
-    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
-                                           arg->device_handle().handle()));
+    TF_ASSIGN_OR_RETURN(auto replicas,
+                        Replicas(*execute_backend_, arg->device_handle()));
     executor = replicas[arg->replica_id()];
   } else {
-    executor = execute_backend_->Replicas()[arg->replica_id()];
+    TF_ASSIGN_OR_RETURN(
+        auto replicas,
+        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+    executor = replicas[arg->replica_id()];
   }
 
   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
@@ -994,7 +1017,7 @@ tensorflow::Status Service::TransferFromOutfeed(
     const TransferFromOutfeedRequest* arg, TransferFromOutfeedResponse* result) {
-  const int64 replica_count = execute_backend_->Replicas().size();
+  const int64 replica_count = execute_backend_->replica_count();
   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
     return FailedPrecondition(
         "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@@ -1004,11 +1027,14 @@ tensorflow::Status Service::TransferFromOutfeed(
 
   se::StreamExecutor* executor;
   if (arg->has_device_handle()) {
-    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
-                                           arg->device_handle().handle()));
+    TF_ASSIGN_OR_RETURN(auto replicas,
+                        Replicas(*execute_backend_, arg->device_handle()));
     executor = replicas[arg->replica_id()];
   } else {
-    executor = execute_backend_->Replicas()[arg->replica_id()];
+    TF_ASSIGN_OR_RETURN(
+        auto replicas,
+        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+    executor = replicas[arg->replica_id()];
   }
 
   Literal literal;
@@ -1387,4 +1413,28 @@ tensorflow::Status Service::LoadComputationSnapshot(
   return tensorflow::Status::OK();
 }
 
+DeviceHandle Service::SingleComputationDeviceHandle() const {
+  DeviceHandle device_handle;
+  device_handle.set_handle(0);
+  device_handle.set_device_count(1);
+  return device_handle;
+}
+
+StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
+    const Backend& backend, const DeviceHandle& device_handle) const {
+  std::vector<perftools::gputools::StreamExecutor*> replicas;
+  for (int replica = 0; replica < backend.replica_count(); ++replica) {
+    // From the computation placer, find out the device ids of the replicas for
+    // the given device handle.
+    TF_ASSIGN_OR_RETURN(
+        int device_ordinal,
+        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
+                                               backend.replica_count(),
+                                               device_handle.device_count()));
+    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
+    replicas.push_back(executor);
+  }
+  return replicas;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index abd1281bdd0..81aa0d2e8e2 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -319,8 +319,7 @@ class Service : public ServiceInterface {
           std::vector<perftools::gputools::DeviceMemoryBase>>
           arguments,
       Backend* backend,
-      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
-          executors,
+      tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
       tensorflow::gtl::ArraySlice<string> result_tags);
 
   // Returns an HLO dumper for use in the compiler (it refers to flags
@@ -346,6 +345,16 @@ class Service : public ServiceInterface {
   tensorflow::Status ValidateResultShapeWithLayout(
       const Shape& shape_with_layout, const Shape& result_shape) const;
 
+  // Returns the stream executors assigned to the replicas represented by the
+  // given device handle. Each device_handle is a virtual replicated device that
+  // represents a set of physical devices for the replicas.
+  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
+      const Backend& backend, const DeviceHandle& device_handle) const;
+
+  // Returns the device handle that represents the replicated device for a
+  // single computation that is not model-parallelized.
+  DeviceHandle SingleComputationDeviceHandle() const;
+
   // Tracks computations built via the API.
   ComputationTracker computation_tracker_;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5fa515b26fb..b81116c1001 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -99,6 +99,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:computation_layout",
+        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_execution_profile",
@@ -196,6 +197,7 @@ cc_library(
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:local_service",
        "//tensorflow/compiler/xla/service:platform_util",
@@ -746,6 +748,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:local_service",
        "//tensorflow/compiler/xla/service:platform_util",
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 633d16c4c32..c8fd31d0ad1 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -255,11 +255,15 @@ message ComputationDataHandle {
   int64 handle = 1;
 }
 
-// Handle given to a user that represents a device to execute a computation.
-// When replication is enabled, the device handle represents the device for the
-// replica id 0.
+// Handle given to a user that represents a replicated virtual device. Each
+// replicated device represents N physical devices for execution where N is the
+// number of replicas.
 message DeviceHandle {
   int64 handle = 1;
+
+  // The number of model-parallel virtual devices that communicate via XLA
+  // Send/Recv instructions.
+  int64 device_count = 2;
 }
 
 // Handle given to a user to represent a channel between two computations
@@ -269,6 +273,21 @@ message ChannelHandle {
   int64 handle = 1;
 }
 
+// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
+// represents the device ids assigned to a set of replicated computations.
+// See xla::DeviceAssignment class comment for more details.
+message DeviceAssignmentProto {
+  int32 replica_count = 1;
+  int32 computation_count = 2;
+
+  // Each logical computation runs on replica_count physical devices.
+  // ComputationDevice represents the device ids assigned to the replicas.
+  message ComputationDevice {
+    repeated int32 replica_device_ids = 1;
+  }
+  repeated ComputationDevice computation_devices = 3;
+}
+
 // Literals are used when the server and client need to exchange materialized
 // data / results. Literals are also used to describe constants used in
 // computations.
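
For reviewers, the following standalone sketch (not part of the patch) shows how the pieces introduced above compose. It assumes an XLA build environment with the new computation_placer target linked in; the main() harness, the 2-replica / 3-computation sizes, and the printed values are illustrative only and rely on the generic placement policy registered in computation_placer.cc (device_id = computation * replica_count + replica).

// computation_placer_example.cc -- illustrative sketch only, not part of this change.
#include <iostream>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  xla::ComputationPlacer placer;
  constexpr int kReplicaCount = 2;      // R
  constexpr int kComputationCount = 3;  // C

  // AssignDevices fills an R x C DeviceAssignment with device ids, using
  // DeviceId(replica, computation, R, C) = computation * R + replica.
  xla::StatusOr<xla::DeviceAssignment> assignment_or =
      placer.AssignDevices(kReplicaCount, kComputationCount);
  TF_CHECK_OK(assignment_or.status());
  xla::DeviceAssignment& assignment = assignment_or.ValueOrDie();

  for (int computation = 0; computation < kComputationCount; ++computation) {
    for (int replica = 0; replica < kReplicaCount; ++replica) {
      // With the generic placer, computation 0 runs on devices {0, 1},
      // computation 1 on {2, 3}, and computation 2 on {4, 5}.
      std::cout << "replica " << replica << ", computation " << computation
                << " -> device " << assignment(replica, computation) << "\n";
    }
  }

  // The assignment round-trips through the new DeviceAssignmentProto, which is
  // how it can cross the client/service boundary.
  xla::DeviceAssignmentProto proto;
  TF_CHECK_OK(assignment.Serialize(&proto));
  xla::StatusOr<xla::DeviceAssignment> restored_or =
      xla::DeviceAssignment::Deserialize(proto);
  TF_CHECK_OK(restored_or.status());
  CHECK_EQ(restored_or.ValueOrDie()(1, 2), assignment(1, 2));

  // Executions see the placement through the new ExecutableRunOptions field,
  // mirroring what Service::ExecuteParallelAndRegisterResult does per stream.
  xla::ExecutableRunOptions run_options;
  run_options.set_device_assignment(&assignment);
  CHECK_EQ(run_options.device_assignment(), &assignment);
  return 0;
}

Against the generic placer this prints device ids 0 through 5 in computation-major order, which is exactly the layout the service code above hands to each replica's ExecutableRunOptions via set_device_assignment.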