Add ComputationPlacer to assign device ids for replicated model-parallel computations.

PiperOrigin-RevId: 159056198
Authored by HyoukJoong Lee on 2017-06-14 18:46:32 -07:00; committed by TensorFlower Gardener
parent 0fa19543aa
commit 7d3497a639
12 changed files with 489 additions and 134 deletions

tensorflow/compiler/xla/executable_run_options.cc

@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
return execution_profile_;
}
+ ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
+ DeviceAssignment* device_assignment) {
+ device_assignment_ = device_assignment;
+ return *this;
+ }
+ DeviceAssignment* ExecutableRunOptions::device_assignment() const {
+ return device_assignment_;
+ }
} // namespace xla

tensorflow/compiler/xla/executable_run_options.h

@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {
class DeviceMemoryAllocator;
+ class DeviceAssignment;
class ExecutionProfile;
// Class containing options for running a LocalExecutable.
@ -79,9 +80,14 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile() const;
ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
+ ExecutableRunOptions& set_device_assignment(
+ DeviceAssignment* device_assignment);
+ DeviceAssignment* device_assignment() const;
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
+ DeviceAssignment* device_assignment_ = nullptr;
perftools::gputools::Stream* stream_ = nullptr;
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
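The new setter and getter let a caller hand a DeviceAssignment to an executable at run time. A minimal sketch of that plumbing is below; it assumes a DeviceAssignment produced elsewhere (for example by ComputationPlacer::AssignDevices) and owned by the caller, and is illustrative rather than code from this change.

// Sketch (not part of this change): attach a device assignment to the options
// used for one execution. The options do not own the assignment, so the caller
// must keep it alive until the execution finishes.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/platform/logging.h"

xla::ExecutableRunOptions MakeRunOptions(xla::DeviceAssignment* assignment) {
  xla::ExecutableRunOptions options;
  options.set_device_assignment(assignment);
  CHECK_EQ(options.device_assignment(), assignment);
  return options;
}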

tensorflow/compiler/xla/service/BUILD

@ -332,6 +332,7 @@ cc_library(
hdrs = ["backend.h"], hdrs = ["backend.h"],
deps = [ deps = [
":compiler", ":compiler",
":computation_placer",
":device_memory_allocator", ":device_memory_allocator",
":platform_util", ":platform_util",
":pool", ":pool",
@ -950,6 +951,26 @@ cc_test(
],
)
cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
hdrs = ["computation_placer.h"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
alwayslink = True, # Contains per-platform computation placer registration
)
cc_library(
name = "generic_transfer_manager",
srcs = ["generic_transfer_manager.cc"],

tensorflow/compiler/xla/service/backend.cc

@ -96,9 +96,11 @@ struct Backend::EigenThreadPoolWrapper {
PlatformUtil::GetStreamExecutors(platform));
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
- std::unique_ptr<Backend> backend(
- new Backend(replica_count, platform, compiler, stream_executors,
- transfer_manager, options.intra_op_parallelism_threads()));
+ TF_ASSIGN_OR_RETURN(auto computation_placer,
+ ComputationPlacer::GetForPlatform(platform));
+ std::unique_ptr<Backend> backend(new Backend(
+ replica_count, platform, compiler, stream_executors, transfer_manager,
+ computation_placer, options.intra_op_parallelism_threads()));
return std::move(backend);
}
@ -135,10 +137,12 @@ Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager, int intra_op_parallelism_threads)
+ TransferManager* transfer_manager, ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
+ computation_placer_(computation_placer),
replica_count_(replica_count) {
// The given set of stream executors set may include invalid executors.
for (se::StreamExecutor* exec : stream_executors) {
@ -179,36 +183,6 @@ int Backend::default_device_ordinal() const {
return default_stream_executor()->device_ordinal();
}
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
int device_ordinal) const {
if (stream_executors_[device_ordinal] == nullptr) {
return InvalidArgument("device %s not supported by XLA service",
device_name(device_ordinal).c_str());
}
// Find replica_count_ stream executors starting from the given device
// ordinal.
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (se::StreamExecutor* exec : stream_executors_) {
CHECK(exec != nullptr);
if (exec->device_ordinal() >= device_ordinal) {
replicas.push_back(exec);
if (replicas.size() >= replica_count_) {
return replicas;
}
}
}
return InvalidArgument(
"Not enough devices for replicas for the device ordinal %d",
device_ordinal);
}
std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
CHECK_GE(stream_executors_.size(), replica_count_);
return Replicas(default_device_ordinal()).ValueOrDie();
}
tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
}

tensorflow/compiler/xla/service/backend.h

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h"
+ #include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -40,6 +41,8 @@ struct ThreadPoolDevice;
namespace xla {
// Options to configure the backend when it is created.
+ //
+ // TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
public:
// Set the platform backing the backend, or nullptr for the default platform.
@ -92,11 +95,15 @@ class Backend {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
+ ComputationPlacer* computation_placer() const { return computation_placer_; }
// Returns the number of devices of the platform type which are visible. Not
// all of these devices may be usable by XLA.
int device_count() const { return stream_executors_.size(); }
+ // Returns the number of replicas when replication is enabled.
+ int replica_count() const { return replica_count_; }
// Returns the device ordinal number of the default device.
int default_device_ordinal() const;
@ -107,24 +114,13 @@ class Backend {
return stream_executors_;
}
- // Returns the replicas for the default stream executor.
- //
- // When the number of replicas is R, the first R stream executors are assigned
- // to the replicas of the default stream executor.
- std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
- // Returns the replicas for the given device_ordinal. The given device ordinal
- // is considered to be the first device ordinal among the replicas. Returns an
- // error status if the stream executor for the given given device ordinal does
- // not exist or if there are not enough stream executors for the replicas.
- StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
- int device_ordinal) const;
- // Return the stream executor for the given device ordinal.
+ // Returns the stream executor for the given device ordinal.
StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
int device_ordinal) const;
- // Return the stream executor for the default device ordinal.
+ // Returns the stream executor for the default device ordinal. This stream
+ // executor can only be used when the number of computations is 1 (replication
+ // can be > 1).
perftools::gputools::StreamExecutor* default_stream_executor() const {
CHECK(!stream_executors_.empty());
return stream_executors_[0];
@ -178,13 +174,16 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
- TransferManager* transfer_manager, int intra_op_parallelism_threads);
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
perftools::gputools::Platform* platform_;
Compiler* compiler_;
TransferManager* transfer_manager_;
+ ComputationPlacer* computation_placer_;
int64 replica_count_ = -1;
// Vector of stream executors. stream_executors_[0] is the default executor.

tensorflow/compiler/xla/service/computation_placer.cc

@ -0,0 +1,151 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace se = ::perftools::gputools;
namespace xla {
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
proto->set_replica_count(replica_count());
proto->set_computation_count(computation_count());
for (int computation = 0; computation < computation_count(); ++computation) {
DeviceAssignmentProto::ComputationDevice* computation_device =
proto->add_computation_devices();
for (int replica = 0; replica < replica_count(); ++replica) {
computation_device->add_replica_device_ids((*this)(replica, computation));
}
}
return Status::OK();
}
/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
const DeviceAssignmentProto& proto) {
TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
TF_RET_CHECK(computation_device.replica_device_ids_size() ==
proto.replica_count());
for (int replica = 0; replica < proto.replica_count(); ++replica) {
assignment(replica, computation) =
computation_device.replica_device_ids(replica);
}
}
return std::move(assignment);
}
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
int replica_count,
int computation_count) {
TF_RET_CHECK(replica < replica_count);
TF_RET_CHECK(computation < computation_count);
return computation * replica_count + replica;
}
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
int replica_count, int computation_count) {
DeviceAssignment assignment(replica_count, computation_count);
for (int replica = 0; replica < replica_count; ++replica) {
for (int computation = 0; computation < computation_count; ++computation) {
TF_ASSIGN_OR_RETURN(
int device_id,
DeviceId(replica, computation, replica_count, computation_count));
assignment(replica, computation) = device_id;
}
}
return std::move(assignment);
}
/* static */ void ComputationPlacer::RegisterComputationPlacer(
se::Platform::Id platform_id,
ComputationPlacerCreationFunction creation_function) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
auto* computation_placers = GetPlatformComputationPlacers();
CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
}
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
const se::Platform* platform) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
auto* computation_placers = GetPlatformComputationPlacers();
auto it = computation_placers->find(platform->id());
if (it == computation_placers->end()) {
return NotFound(
"could not find registered computation placer for platform %s -- check "
"target linkage",
platform->Name().c_str());
}
if (it->second.placer == nullptr) {
// Lazily create the computation placer the first time it is needed.
it->second.placer = (*it->second.creation_function)();
}
return it->second.placer.get();
}
/* static */ tensorflow::mutex*
ComputationPlacer::platform_computation_placer_mutex() {
static tensorflow::mutex* m = new tensorflow::mutex;
return m;
}
/* static */ std::map<perftools::gputools::Platform::Id,
ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
static auto* r =
new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
return r;
}
} // namespace xla
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
return xla::MakeUnique<xla::ComputationPlacer>();
}
static bool InitModule() {
xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
&CreateComputationPlacer);
xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
&CreateComputationPlacer);
return true;
}
static bool module_initialized = InitModule();
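The default placer is purely arithmetic: device_id = computation * replica_count + replica. The sketch below shows the assignment this yields for 2 replicas and 3 computations; it assumes the default ComputationPlacer registered above and is illustrative only, not code from this change.

// Sketch (not part of this change): default placement for replica_count = 2
// and computation_count = 3. With device_id = computation * replica_count +
// replica, the matrix indexed as (replica, computation) is
//   replica 0: 0 2 4
//   replica 1: 1 3 5
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/platform/logging.h"

void ShowDefaultPlacement() {
  xla::ComputationPlacer placer;
  xla::DeviceAssignment assignment =
      placer.AssignDevices(/*replica_count=*/2, /*computation_count=*/3)
          .ValueOrDie();
  for (int replica = 0; replica < assignment.replica_count(); ++replica) {
    for (int computation = 0; computation < assignment.computation_count();
         ++computation) {
      LOG(INFO) << "replica " << replica << ", computation " << computation
                << " -> device " << assignment(replica, computation);
    }
  }
}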

tensorflow/compiler/xla/service/computation_placer.h

@ -0,0 +1,113 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#include <map>
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Class that represents the device assignment for a set of XLA replicated
// computations. For R replicas and C computations, R * C devices are required
// to execute the computation in parallel. The assigned device ids can be accessed
// by assignment(replica, computation).
class DeviceAssignment : public Array2D<int> {
public:
DeviceAssignment() {}
DeviceAssignment(int replica_count, int computation_count)
: Array2D<int>(replica_count, computation_count, -1) {
CHECK_GT(replica_count, 0);
CHECK_GT(computation_count, 0);
}
int replica_count() const { return height(); }
int computation_count() const { return width(); }
// Protocol buffer serialization and deserialization.
Status Serialize(DeviceAssignmentProto* proto) const;
static StatusOr<DeviceAssignment> Deserialize(
const DeviceAssignmentProto& proto);
};
// A generic implementation of the XLA computation placer, which assigns device
// ids to a set of replicated computations.
class ComputationPlacer {
public:
ComputationPlacer() {}
virtual ~ComputationPlacer() {}
// Returns the device id assigned to the given replica and computation
// instance for [replica_count x computation_count] setup. The returned device
// id must match the assignment from PlaceReplicatedComputation().
virtual StatusOr<int> DeviceId(int replica, int computation,
int replica_count, int computation_count);
// Returns the device ids assigned to a set of replicated computations, given
// the number of replicas and the number of computations.
virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
int computation_count);
using ComputationPlacerCreationFunction =
std::unique_ptr<ComputationPlacer> (*)();
// Registers a computation placer creation function for a particular platform.
static void RegisterComputationPlacer(
perftools::gputools::Platform::Id platform_id,
ComputationPlacerCreationFunction creation_function);
// Returns the computation placer singleton pointer if it is available for the
// given platform, or an error status if it is not.
static StatusOr<ComputationPlacer*> GetForPlatform(
const perftools::gputools::Platform* platform);
private:
// Routine that returns the mutex that guards the platform-to-computation
// placer map. Done as a routine to ensure correct initialization ordering,
// since RegisterComputationPlacer can be called during program initialization
// time.
static tensorflow::mutex* platform_computation_placer_mutex();
// State kept for each kind of ComputationPlacer. Registration functions set
// up creation_function, and then we use that to lazily create "placer" the
// first time GetForPlatform is invoked for a particular id.
struct State {
std::unique_ptr<ComputationPlacer> placer;
ComputationPlacerCreationFunction creation_function = nullptr;
};
// Map from platform kind to computation placer singleton.
static std::map<perftools::gputools::Platform::Id, State>*
GetPlatformComputationPlacers();
perftools::gputools::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
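Serialize and Deserialize exist so an assignment computed in one process can travel inside a DeviceAssignmentProto to wherever the executable actually runs. A minimal round-trip sketch, assuming the default placer; illustrative only, not code from this change.

// Sketch (not part of this change): DeviceAssignment -> proto -> DeviceAssignment.
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

xla::DeviceAssignment RoundTrip() {
  xla::ComputationPlacer placer;
  xla::DeviceAssignment original =
      placer.AssignDevices(/*replica_count=*/2, /*computation_count=*/2)
          .ValueOrDie();
  xla::DeviceAssignmentProto proto;
  TF_CHECK_OK(original.Serialize(&proto));
  xla::DeviceAssignment restored =
      xla::DeviceAssignment::Deserialize(proto).ValueOrDie();
  // The restored assignment holds the same device ids as the original.
  CHECK_EQ(restored(1, 1), original(1, 1));
  return restored;
}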

tensorflow/compiler/xla/service/local_service.cc

@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
module_config->set_has_hybrid_result(has_hybrid_result);
- module_config->set_replica_count(execute_backend_->Replicas().size());
+ module_config->set_replica_count(execute_backend_->replica_count());
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
if (flags->xla_hlo_profile) {
module_config->enable_hlo_profiling(true);

tensorflow/compiler/xla/service/service.cc

@ -325,7 +325,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
module_config->enable_hlo_profiling(true);
}
- module_config->set_replica_count(backend->Replicas().size());
+ module_config->set_replica_count(backend->replica_count());
module_config->set_seed(execution_options.seed());
module_config->set_debug_options(execution_options.debug_options());
@ -495,47 +495,55 @@ Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
- Backend* backend,
- tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
+ Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags) {
- // TODO(b/33943292): Support for replication when using multiple computations.
- TF_RET_CHECK(backend->Replicas().size() == 1);
- // Set up streams.
+ // Streams where the computation are launched, so we can wait on the streams
+ // to complete.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : executors) {
- TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
- backend->BorrowStream(executor));
- streams.push_back(std::move(stream));
- }
- // Set up run options.
- std::vector<ServiceExecutableRunOptions> run_options;
- for (const Pool<se::Stream>::SmartPtr& stream : streams) {
- ExecutableRunOptions options;
- options.set_stream(stream.get());
- options.set_allocator(backend->memory_allocator());
- options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
- options.set_intra_op_thread_pool(
- backend->eigen_intra_op_thread_pool_device());
- run_options.emplace_back(options, backend->StreamBorrower());
- }
- // Asynchronously launch all executables.
+ // Global data handles for the computation results, one for each computation.
std::vector<GlobalDataHandle> result_handles;
- for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
- i < executables.size(); i++) {
- TF_ASSIGN_OR_RETURN(
- perftools::gputools::DeviceMemoryBase result,
- executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
- result_handles.push_back(allocation_tracker_.Register(
- backend, executors[i]->device_ordinal(), result,
- executables[i]->result_shape(), result_tags[i]));
+ TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+ backend->computation_placer()->AssignDevices(
+ backend->replica_count(), executables.size()));
+ for (int64 i = 0; i < executables.size(); i++) {
+ // Stream executors for the replicas of the current computation.
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
+ for (int64 replica = 0; replica < replicas.size(); ++replica) {
+ TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ backend->BorrowStream(replicas[replica]));
+ streams.push_back(std::move(stream));
+ // Set up run options.
+ ExecutableRunOptions options;
+ options.set_stream(streams.back().get());
+ options.set_allocator(backend->memory_allocator());
+ options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ backend->eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
+ ServiceExecutableRunOptions run_options(options,
+ backend->StreamBorrower());
+ // Asynchronously launch the computation.
+ TF_ASSIGN_OR_RETURN(
+ perftools::gputools::DeviceMemoryBase result,
+ executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
+ // All replicas share the same device address for the result allocation,
+ // so only one of the replicas need to register the result handle.
+ if (replica == 0) {
+ result_handles.push_back(allocation_tracker_.Register(
+ backend, replicas[0]->device_ordinal(), result,
+ executables[i]->result_shape(), result_tags[i]));
+ }
+ }
}
// Wait for all executions to complete.
- for (int64 i = 0; i < result_handles.size(); ++i) {
+ for (int64 i = 0; i < streams.size(); ++i) {
if (!streams[i]->BlockHostUntilDone()) {
return InternalError("failed to complete execution for stream %lld", i);
}
@ -550,17 +558,22 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
arguments,
Backend* backend, perftools::gputools::StreamExecutor* executor,
const string& result_tag, ExecutionProfile* profile) {
- TF_RET_CHECK(!backend->Replicas().empty());
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : backend->Replicas()) {
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*backend, SingleComputationDeviceHandle()));
+ TF_RET_CHECK(!replicas.empty());
+ for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
+ TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
+ backend->computation_placer()->AssignDevices(
+ backend->replica_count(), /*computation_count=*/1));
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
for (const Pool<se::Stream>::SmartPtr& stream : streams) {
@ -570,19 +583,20 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
run_options.emplace_back(options, backend->StreamBorrower(),
backend->inter_op_thread_pool());
}
perftools::gputools::DeviceMemoryBase result;
- if (backend->Replicas().size() == 1) {
+ if (backend->replica_count() == 1) {
TF_ASSIGN_OR_RETURN(
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
&run_options[0], profile, arguments));
} else {
std::vector<
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
- repeated_arguments(backend->Replicas().size(), arguments);
+ repeated_arguments(backend->replica_count(), arguments);
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
run_options, repeated_arguments));
@ -610,25 +624,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
std::vector<VersionedComputationHandle> versioned_handles;
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
std::vector<string> computation_names;
+ std::vector<DeviceHandle> device_handles;
- if (arg->requests_size() > execute_backend_->stream_executors().size()) {
+ if (arg->requests_size() * execute_backend_->replica_count() >
+ execute_backend_->device_count()) {
return FailedPrecondition(
"there are not enough stream executors to execute %d computations",
arg->requests_size());
}
for (int64 i = 0; i < arg->requests_size(); ++i) {
- // Get the stream executor on which the computation will run. Select the
- // specific device if requested, otherwise select the i'th device from the
- // list of available stream executors.
- se::StreamExecutor* executor;
- if (arg->requests(i).has_device_handle()) {
- executor =
- execute_backend_
- ->stream_executors()[arg->requests(i).device_handle().handle()];
- } else {
- executor = execute_backend_->stream_executors()[i];
- }
+ // Get the stream executor for the i'th computation. This stream executor
+ // is one of the executors to run the replicated computation.
+ if (!arg->requests(i).has_device_handle()) {
+ return FailedPrecondition(
+ "device handles must be given to execute parallel computations");
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, arg->requests(i).device_handle()));
+ se::StreamExecutor* executor = replicas[0];
CHECK(executor != nullptr);
// Resolve the UserComputation object associated with the requested
@ -673,6 +688,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
module_configs.push_back(std::move(module_config));
computation_names.push_back(user_computation->name());
executors.push_back(executor);
+ device_handles.push_back(arg->requests(i).device_handle());
}
// Build the user computations into HloModules and compile to generate the
@ -692,7 +708,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDataHandle> outputs,
ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
- execute_backend_.get(), executors,
+ execute_backend_.get(), device_handles,
computation_names));
for (const GlobalDataHandle& output : outputs) {
ExecuteResponse response;
@ -706,10 +722,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) {
- const int64 available_device_count =
- execute_backend_->stream_executors().size();
- const int64 replicas = execute_backend_->Replicas().size();
- if (available_device_count < arg->device_count() * replicas) {
+ const int64 available_device_count = execute_backend_->device_count();
+ const int64 replica_count = execute_backend_->replica_count();
+ if (replica_count <= 0) {
+ return FailedPrecondition("Replica count must be a positive integer");
+ }
+ if (available_device_count < arg->device_count() * replica_count) {
return ResourceExhausted(
"Requested device count (%lld) exceeds the number of available devices "
"on the target (%lld)",
@ -718,8 +736,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
for (int64 i = 0; i < arg->device_count(); ++i) {
DeviceHandle device_handle;
- device_handle.set_handle(
- execute_backend_->stream_executors()[i * replicas]->device_ordinal());
+ device_handle.set_handle(i);
+ device_handle.set_device_count(arg->device_count());
*result->add_device_handles() = device_handle;
}
@ -841,11 +859,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
execute_backend_->default_stream_executor(),
&profile));
- TF_RET_CHECK(!execute_backend_->Replicas().empty());
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
+ SingleComputationDeviceHandle()));
+ TF_RET_CHECK(!replicas.empty());
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
+ for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
execute_backend_->BorrowStream(executor));
streams.push_back(std::move(stream));
@ -927,19 +948,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
- if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
+ if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
// TODO(b/32990684): Tuple transfers to host end up allocating further
// buffers - implement that correctly.
return Unimplemented(
"Tuple transfers to the device not supported with replication.");
}
- se::StreamExecutor* stream_executor;
+ std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
} else {
- stream_executor = execute_backend_->default_stream_executor();
+ TF_ASSIGN_OR_RETURN(
+ replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
}
// Allocate memory on the device, using the stream executor. The size of the
@ -950,14 +972,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
execute_backend_->memory_allocator()->Allocate(
- stream_executor->device_ordinal(), allocation_size));
+ replicas[0]->device_ordinal(), allocation_size));
*result->mutable_data() = allocation_tracker_.Register(
- execute_backend_.get(), stream_executor->device_ordinal(), allocation,
- shape, StrCat("TransferToServer literal of size ", allocation_size));
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- stream_executor->device_ordinal()));
+ execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape,
+ StrCat("TransferToServer literal of size ", allocation_size));
for (se::StreamExecutor* executor : replicas) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
@ -968,7 +988,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) {
- const int64 replica_count = execute_backend_->Replicas().size();
+ const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"%s",
@ -980,11 +1000,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
- executor = execute_backend_->Replicas()[arg->replica_id()];
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+ executor = replicas[arg->replica_id()];
}
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
@ -994,7 +1017,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::TransferFromOutfeed(
const TransferFromOutfeedRequest* arg,
TransferFromOutfeedResponse* result) {
- const int64 replica_count = execute_backend_->Replicas().size();
+ const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@ -1004,11 +1027,14 @@ tensorflow::Status Service::TransferFromOutfeed(
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
- arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(auto replicas,
+ Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
- executor = execute_backend_->Replicas()[arg->replica_id()];
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ Replicas(*execute_backend_, SingleComputationDeviceHandle()));
+ executor = replicas[arg->replica_id()];
}
Literal literal;
@ -1387,4 +1413,28 @@ tensorflow::Status Service::LoadComputationSnapshot(
return tensorflow::Status::OK();
}
DeviceHandle Service::SingleComputationDeviceHandle() const {
DeviceHandle device_handle;
device_handle.set_handle(0);
device_handle.set_device_count(1);
return device_handle;
}
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
const Backend& backend, const DeviceHandle& device_handle) const {
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (int replica = 0; replica < backend.replica_count(); ++replica) {
// From the computation placer, find out the device ids of the replicas for
// the given device handle.
TF_ASSIGN_OR_RETURN(
int device_ordinal,
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
backend.replica_count(),
device_handle.device_count()));
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
replicas.push_back(executor);
}
return replicas;
}
} // namespace xla
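Replicas above turns a virtual DeviceHandle into one executor per replica by querying DeviceId(replica, handle, replica_count, device_count). With the default placer and replica_count = 2, device_count = 3, handle 0 maps to device ordinals {0, 1}, handle 1 to {2, 3}, and handle 2 to {4, 5}. A short sketch of that mapping, illustrative only:

// Sketch (not part of this change): ordinals the default placer picks for each
// (device handle, replica) pair; this mirrors Service::Replicas above.
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/platform/logging.h"

void ShowReplicaOrdinals(int replica_count, int device_count) {
  xla::ComputationPlacer placer;
  for (int handle = 0; handle < device_count; ++handle) {
    for (int replica = 0; replica < replica_count; ++replica) {
      const int ordinal =
          placer.DeviceId(replica, handle, replica_count, device_count)
              .ValueOrDie();
      LOG(INFO) << "handle " << handle << ", replica " << replica
                << " -> device ordinal " << ordinal;
    }
  }
}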

tensorflow/compiler/xla/service/service.h

@ -319,8 +319,7 @@ class Service : public ServiceInterface {
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
Backend* backend,
- tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
- executors,
+ tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags);
// Returns an HLO dumper for use in the compiler (it refers to flags
@ -346,6 +345,16 @@ class Service : public ServiceInterface {
tensorflow::Status ValidateResultShapeWithLayout(
const Shape& shape_with_layout, const Shape& result_shape) const;
// Returns the stream executors assigned to the replicas represented by the
// given device handle. Each device_handle is a virtual replicated device that
// represents a set of physical devices for the replicas.
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
DeviceHandle SingleComputationDeviceHandle() const;
// Tracks computations built via the API.
ComputationTracker computation_tracker_;

tensorflow/compiler/xla/tests/BUILD

@ -99,6 +99,7 @@ cc_library(
"//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_execution_profile",
@ -196,6 +197,7 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",
@ -746,6 +748,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",

tensorflow/compiler/xla/xla_data.proto

@ -255,11 +255,15 @@ message ComputationDataHandle {
int64 handle = 1;
}
- // Handle given to a user that represents a device to execute a computation.
- // When replication is enabled, the device handle represents the device for the
- // replica id 0.
+ // Handle given to a user that represents a replicated virtual device. Each
+ // replicated device represents N physical devices for execution where N is the
+ // number of replicas.
message DeviceHandle {
int64 handle = 1;
+ // The number of model-parallel virtual devices that communicate via XLA
+ // Send/Recv instructions.
+ int64 device_count = 2;
}
// Handle given to a user to represent a channel between two computations
@ -269,6 +273,21 @@ message ChannelHandle {
int64 handle = 1;
}
// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
int32 replica_count = 1;
int32 computation_count = 2;
// Each logical computation runs on replica_count physical devices.
// ComputationDevice represents the device ids assigned to the replicas.
message ComputationDevice {
repeated int32 replica_device_ids = 1;
}
repeated ComputationDevice computation_devices = 3;
}
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
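For reference, this message is what DeviceAssignment::Serialize fills in. Below is a hand-built example for the default 2-replica, 2-computation placement, using the generated C++ proto API; it is a sketch only, and the device ids follow device_id = computation * replica_count + replica.

// Sketch (not part of this change): DeviceAssignmentProto for 2 replicas x 2
// computations under the default placement.
#include "tensorflow/compiler/xla/xla_data.pb.h"

xla::DeviceAssignmentProto MakeDefaultAssignmentProto() {
  xla::DeviceAssignmentProto proto;
  proto.set_replica_count(2);
  proto.set_computation_count(2);
  // Computation 0 runs its two replicas on devices 0 and 1.
  auto* computation0 = proto.add_computation_devices();
  computation0->add_replica_device_ids(0);
  computation0->add_replica_device_ids(1);
  // Computation 1 runs its two replicas on devices 2 and 3.
  auto* computation1 = proto.add_computation_devices();
  computation1->add_replica_device_ids(2);
  computation1->add_replica_device_ids(3);
  return proto;
}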