Add ComputationPlacer to assign device ids for replicated model-parallel computations.

PiperOrigin-RevId: 159056198
This commit is contained in:
HyoukJoong Lee 2017-06-14 18:46:32 -07:00 committed by TensorFlower Gardener
parent 0fa19543aa
commit 7d3497a639
12 changed files with 489 additions and 134 deletions

View File

@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
return execution_profile_;
}
ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
DeviceAssignment* device_assignment) {
device_assignment_ = device_assignment;
return *this;
}
DeviceAssignment* ExecutableRunOptions::device_assignment() const {
return device_assignment_;
}
} // namespace xla

View File

@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {
class DeviceMemoryAllocator;
class DeviceAssignment;
class ExecutionProfile;
// Class containing options for running a LocalExecutable.
@ -79,9 +80,14 @@ class ExecutableRunOptions {
ExecutionProfile* execution_profile() const;
ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
ExecutableRunOptions& set_device_assignment(
DeviceAssignment* device_assignment);
DeviceAssignment* device_assignment() const;
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
DeviceAssignment* device_assignment_ = nullptr;
perftools::gputools::Stream* stream_ = nullptr;
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;

View File

@ -332,6 +332,7 @@ cc_library(
hdrs = ["backend.h"],
deps = [
":compiler",
":computation_placer",
":device_memory_allocator",
":platform_util",
":pool",
@ -950,6 +951,26 @@ cc_test(
],
)
cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
hdrs = ["computation_placer.h"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
alwayslink = True, # Contains per-platform computation placer registration
)
cc_library(
name = "generic_transfer_manager",
srcs = ["generic_transfer_manager.cc"],

View File

@ -96,9 +96,11 @@ struct Backend::EigenThreadPoolWrapper {
PlatformUtil::GetStreamExecutors(platform));
TF_ASSIGN_OR_RETURN(auto transfer_manager,
TransferManager::GetForPlatform(platform));
std::unique_ptr<Backend> backend(
new Backend(replica_count, platform, compiler, stream_executors,
transfer_manager, options.intra_op_parallelism_threads()));
TF_ASSIGN_OR_RETURN(auto computation_placer,
ComputationPlacer::GetForPlatform(platform));
std::unique_ptr<Backend> backend(new Backend(
replica_count, platform, compiler, stream_executors, transfer_manager,
computation_placer, options.intra_op_parallelism_threads()));
return std::move(backend);
}
@ -135,10 +137,12 @@ Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
TransferManager* transfer_manager, int intra_op_parallelism_threads)
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
computation_placer_(computation_placer),
replica_count_(replica_count) {
// The given set of stream executors set may include invalid executors.
for (se::StreamExecutor* exec : stream_executors) {
@ -179,36 +183,6 @@ int Backend::default_device_ordinal() const {
return default_stream_executor()->device_ordinal();
}
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
int device_ordinal) const {
if (stream_executors_[device_ordinal] == nullptr) {
return InvalidArgument("device %s not supported by XLA service",
device_name(device_ordinal).c_str());
}
// Find replica_count_ stream executors starting from the given device
// ordinal.
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (se::StreamExecutor* exec : stream_executors_) {
CHECK(exec != nullptr);
if (exec->device_ordinal() >= device_ordinal) {
replicas.push_back(exec);
if (replicas.size() >= replica_count_) {
return replicas;
}
}
}
return InvalidArgument(
"Not enough devices for replicas for the device ordinal %d",
device_ordinal);
}
std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
CHECK_GE(stream_executors_.size(), replica_count_);
return Replicas(default_device_ordinal()).ValueOrDie();
}
tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
return inter_op_thread_pool_.get();
}

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@ -40,6 +41,8 @@ struct ThreadPoolDevice;
namespace xla {
// Options to configure the backend when it is created.
//
// TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
public:
// Set the platform backing the backend, or nullptr for the default platform.
@ -92,11 +95,15 @@ class Backend {
return memory_allocator_.get();
}
TransferManager* transfer_manager() const { return transfer_manager_; }
ComputationPlacer* computation_placer() const { return computation_placer_; }
// Returns the number of devices of the platform type which are visible. Not
// all of these devices may be usable by XLA.
int device_count() const { return stream_executors_.size(); }
// Returns the number of replicas when replication is enabled.
int replica_count() const { return replica_count_; }
// Returns the device ordinal number of the default device.
int default_device_ordinal() const;
@ -107,24 +114,13 @@ class Backend {
return stream_executors_;
}
// Returns the replicas for the default stream executor.
//
// When the number of replicas is R, the first R stream executors are assigned
// to the replicas of the default stream executor.
std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
// Returns the replicas for the given device_ordinal. The given device ordinal
// is considered to be the first device ordinal among the replicas. Returns an
// error status if the stream executor for the given given device ordinal does
// not exist or if there are not enough stream executors for the replicas.
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
int device_ordinal) const;
// Return the stream executor for the given device ordinal.
// Returns the stream executor for the given device ordinal.
StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
int device_ordinal) const;
// Return the stream executor for the default device ordinal.
// Returns the stream executor for the default device ordinal. This stream
// executor can only be used when the number of computations is 1 (replication
// can be > 1).
perftools::gputools::StreamExecutor* default_stream_executor() const {
CHECK(!stream_executors_.empty());
return stream_executors_[0];
@ -178,13 +174,16 @@ class Backend {
Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
TransferManager* transfer_manager, int intra_op_parallelism_threads);
TransferManager* transfer_manager,
ComputationPlacer* computation_placer,
int intra_op_parallelism_threads);
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
perftools::gputools::Platform* platform_;
Compiler* compiler_;
TransferManager* transfer_manager_;
ComputationPlacer* computation_placer_;
int64 replica_count_ = -1;
// Vector of stream executors. stream_executors_[0] is the default executor.

View File

@ -0,0 +1,151 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace se = ::perftools::gputools;
namespace xla {
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
proto->set_replica_count(replica_count());
proto->set_computation_count(computation_count());
for (int computation = 0; computation < computation_count(); ++computation) {
DeviceAssignmentProto::ComputationDevice* computation_device =
proto->add_computation_devices();
for (int replica = 0; replica < replica_count(); ++replica) {
computation_device->add_replica_device_ids((*this)(replica, computation));
}
}
return Status::OK();
}
/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
const DeviceAssignmentProto& proto) {
TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
TF_RET_CHECK(computation_device.replica_device_ids_size() ==
proto.replica_count());
for (int replica = 0; replica < proto.replica_count(); ++replica) {
assignment(replica, computation) =
computation_device.replica_device_ids(replica);
}
}
return std::move(assignment);
}
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
int replica_count,
int computation_count) {
TF_RET_CHECK(replica < replica_count);
TF_RET_CHECK(computation < computation_count);
return computation * replica_count + replica;
}
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
int replica_count, int computation_count) {
DeviceAssignment assignment(replica_count, computation_count);
for (int replica = 0; replica < replica_count; ++replica) {
for (int computation = 0; computation < computation_count; ++computation) {
TF_ASSIGN_OR_RETURN(
int device_id,
DeviceId(replica, computation, replica_count, computation_count));
assignment(replica, computation) = device_id;
}
}
return std::move(assignment);
}
/* static */ void ComputationPlacer::RegisterComputationPlacer(
se::Platform::Id platform_id,
ComputationPlacerCreationFunction creation_function) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
auto* computation_placers = GetPlatformComputationPlacers();
CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
}
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
const se::Platform* platform) {
tensorflow::mutex_lock lock(
*ComputationPlacer::platform_computation_placer_mutex());
auto* computation_placers = GetPlatformComputationPlacers();
auto it = computation_placers->find(platform->id());
if (it == computation_placers->end()) {
return NotFound(
"could not find registered computation placer for platform %s -- check "
"target linkage",
platform->Name().c_str());
}
if (it->second.placer == nullptr) {
// Lazily create the computation placer the first time it is needed.
it->second.placer = (*it->second.creation_function)();
}
return it->second.placer.get();
}
/* static */ tensorflow::mutex*
ComputationPlacer::platform_computation_placer_mutex() {
static tensorflow::mutex* m = new tensorflow::mutex;
return m;
}
/* static */ std::map<perftools::gputools::Platform::Id,
ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
static auto* r =
new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
return r;
}
} // namespace xla
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
return xla::MakeUnique<xla::ComputationPlacer>();
}
static bool InitModule() {
xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
&CreateComputationPlacer);
xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
&CreateComputationPlacer);
return true;
}
static bool module_initialized = InitModule();

View File

@ -0,0 +1,113 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#include <map>
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Class that represents the device assignment for a set of XLA replicated
// computations. For R replicas and C computations, R * C devices are required
// execute the computation in parallel. The assigned device ids can be accessed
// by assignment(replica, computation).
class DeviceAssignment : public Array2D<int> {
public:
DeviceAssignment() {}
DeviceAssignment(int replica_count, int computation_count)
: Array2D<int>(replica_count, computation_count, -1) {
CHECK_GT(replica_count, 0);
CHECK_GT(computation_count, 0);
}
int replica_count() const { return height(); }
int computation_count() const { return width(); }
// Protocol buffer serialization and deserialization.
Status Serialize(DeviceAssignmentProto* proto) const;
static StatusOr<DeviceAssignment> Deserialize(
const DeviceAssignmentProto& proto);
};
// A generic implementation of the XLA computation placer, which assigns device
// ids to a set of replicated computations.
class ComputationPlacer {
public:
ComputationPlacer() {}
virtual ~ComputationPlacer() {}
// Returns the device id assigned to the given replica and computation
// instance for [replica_count x computation_count] setup. The returned device
// id must match the assignement from PlaceReplicatedComputation().
virtual StatusOr<int> DeviceId(int replica, int computation,
int replica_count, int computation_count);
// Returns the device ids assigned to a set of replicated computations, given
// the number of replicas and the number of computations.
virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
int computation_count);
using ComputationPlacerCreationFunction =
std::unique_ptr<ComputationPlacer> (*)();
// Registers a computation placer creation function for a particular platform.
static void RegisterComputationPlacer(
perftools::gputools::Platform::Id platform_id,
ComputationPlacerCreationFunction creation_function);
// Returns the computation placer singleton pointer if it is available for the
// given platform, or an error status if it is not.
static StatusOr<ComputationPlacer*> GetForPlatform(
const perftools::gputools::Platform* platform);
private:
// Routine that returns the mutex that guards the platform-to-computation
// placer map. Done as a routine to ensure correct initialization ordering,
// since RegisterComputationPlacer can be called during program initialization
// time.
static tensorflow::mutex* platform_computation_placer_mutex();
// State kept for each kind of ComputationPlacer. Registration functions set
// up creation_function, and then we use that to lazily create "placer" the
// first time GetForPlatform is invoked for a particular id.
struct State {
std::unique_ptr<ComputationPlacer> placer;
ComputationPlacerCreationFunction creation_function = nullptr;
};
// Map from platform kind to computation placer singleton.
static std::map<perftools::gputools::Platform::Id, State>*
GetPlatformComputationPlacers();
perftools::gputools::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_

View File

@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
module_config->set_has_hybrid_result(has_hybrid_result);
module_config->set_replica_count(execute_backend_->Replicas().size());
module_config->set_replica_count(execute_backend_->replica_count());
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
if (flags->xla_hlo_profile) {
module_config->enable_hlo_profiling(true);

View File

@ -325,7 +325,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
module_config->enable_hlo_profiling(true);
}
module_config->set_replica_count(backend->Replicas().size());
module_config->set_replica_count(backend->replica_count());
module_config->set_seed(execution_options.seed());
module_config->set_debug_options(execution_options.debug_options());
@ -495,47 +495,55 @@ Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
Backend* backend,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags) {
// TODO(b/33943292): Support for replication when using multiple computations.
TF_RET_CHECK(backend->Replicas().size() == 1);
// Set up streams.
// Streams where the computation are launched, so we can wait on the streams
// to complete.
std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : executors) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
for (const Pool<se::Stream>::SmartPtr& stream : streams) {
ExecutableRunOptions options;
options.set_stream(stream.get());
options.set_allocator(backend->memory_allocator());
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
run_options.emplace_back(options, backend->StreamBorrower());
}
// Asynchronously launch all executables.
// Global data handles for the computation results, one for each computation.
std::vector<GlobalDataHandle> result_handles;
for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
i < executables.size(); i++) {
TF_ASSIGN_OR_RETURN(
perftools::gputools::DeviceMemoryBase result,
executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
result_handles.push_back(allocation_tracker_.Register(
backend, executors[i]->device_ordinal(), result,
executables[i]->result_shape(), result_tags[i]));
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
backend->computation_placer()->AssignDevices(
backend->replica_count(), executables.size()));
for (int64 i = 0; i < executables.size(); i++) {
// Stream executors for the replicas of the current computation.
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
for (int64 replica = 0; replica < replicas.size(); ++replica) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(replicas[replica]));
streams.push_back(std::move(stream));
// Set up run options.
ExecutableRunOptions options;
options.set_stream(streams.back().get());
options.set_allocator(backend->memory_allocator());
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment);
ServiceExecutableRunOptions run_options(options,
backend->StreamBorrower());
// Asynchronously launch the computation.
TF_ASSIGN_OR_RETURN(
perftools::gputools::DeviceMemoryBase result,
executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));
// All replicas share the same device address for the result allocation,
// so only one of the replicas need to register the result handle.
if (replica == 0) {
result_handles.push_back(allocation_tracker_.Register(
backend, replicas[0]->device_ordinal(), result,
executables[i]->result_shape(), result_tags[i]));
}
}
}
// Wait for all executions to complete.
for (int64 i = 0; i < result_handles.size(); ++i) {
for (int64 i = 0; i < streams.size(); ++i) {
if (!streams[i]->BlockHostUntilDone()) {
return InternalError("failed to complete execution for stream %lld", i);
}
@ -550,17 +558,22 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
arguments,
Backend* backend, perftools::gputools::StreamExecutor* executor,
const string& result_tag, ExecutionProfile* profile) {
TF_RET_CHECK(!backend->Replicas().empty());
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : backend->Replicas()) {
TF_ASSIGN_OR_RETURN(auto replicas,
Replicas(*backend, SingleComputationDeviceHandle()));
TF_RET_CHECK(!replicas.empty());
for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
backend->computation_placer()->AssignDevices(
backend->replica_count(), /*computation_count=*/1));
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
for (const Pool<se::Stream>::SmartPtr& stream : streams) {
@ -570,19 +583,20 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
options.set_intra_op_thread_pool(
backend->eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment);
run_options.emplace_back(options, backend->StreamBorrower(),
backend->inter_op_thread_pool());
}
perftools::gputools::DeviceMemoryBase result;
if (backend->Replicas().size() == 1) {
if (backend->replica_count() == 1) {
TF_ASSIGN_OR_RETURN(
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
&run_options[0], profile, arguments));
} else {
std::vector<
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
repeated_arguments(backend->Replicas().size(), arguments);
repeated_arguments(backend->replica_count(), arguments);
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
run_options, repeated_arguments));
@ -610,25 +624,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
std::vector<VersionedComputationHandle> versioned_handles;
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
std::vector<string> computation_names;
std::vector<DeviceHandle> device_handles;
if (arg->requests_size() > execute_backend_->stream_executors().size()) {
if (arg->requests_size() * execute_backend_->replica_count() >
execute_backend_->device_count()) {
return FailedPrecondition(
"there are not enough stream executors to execute %d computations",
arg->requests_size());
}
for (int64 i = 0; i < arg->requests_size(); ++i) {
// Get the stream executor on which the computation will run. Select the
// specific device if requested, otherwise select the i'th device from the
// list of available stream executors.
se::StreamExecutor* executor;
if (arg->requests(i).has_device_handle()) {
executor =
execute_backend_
->stream_executors()[arg->requests(i).device_handle().handle()];
} else {
executor = execute_backend_->stream_executors()[i];
// Get the stream executor for the i'th computation. This stream executor
// is one of the executors to run the replicated computation.
if (!arg->requests(i).has_device_handle()) {
return FailedPrecondition(
"device handles must be given to execute parallel computations");
}
TF_ASSIGN_OR_RETURN(
auto replicas,
Replicas(*execute_backend_, arg->requests(i).device_handle()));
se::StreamExecutor* executor = replicas[0];
CHECK(executor != nullptr);
// Resolve the UserComputation object associated with the requested
@ -673,6 +688,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
module_configs.push_back(std::move(module_config));
computation_names.push_back(user_computation->name());
executors.push_back(executor);
device_handles.push_back(arg->requests(i).device_handle());
}
// Build the user computations into HloModules and compile to generate the
@ -692,7 +708,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDataHandle> outputs,
ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
execute_backend_.get(), executors,
execute_backend_.get(), device_handles,
computation_names));
for (const GlobalDataHandle& output : outputs) {
ExecuteResponse response;
@ -706,10 +722,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) {
const int64 available_device_count =
execute_backend_->stream_executors().size();
const int64 replicas = execute_backend_->Replicas().size();
if (available_device_count < arg->device_count() * replicas) {
const int64 available_device_count = execute_backend_->device_count();
const int64 replica_count = execute_backend_->replica_count();
if (replica_count <= 0) {
return FailedPrecondition("Replica count must be a positive integer");
}
if (available_device_count < arg->device_count() * replica_count) {
return ResourceExhausted(
"Requested device count (%lld) exceeds the number of available devices "
"on the target (%lld)",
@ -718,8 +736,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
for (int64 i = 0; i < arg->device_count(); ++i) {
DeviceHandle device_handle;
device_handle.set_handle(
execute_backend_->stream_executors()[i * replicas]->device_ordinal());
device_handle.set_handle(i);
device_handle.set_device_count(arg->device_count());
*result->add_device_handles() = device_handle;
}
@ -841,11 +859,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
execute_backend_->default_stream_executor(),
&profile));
TF_RET_CHECK(!execute_backend_->Replicas().empty());
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
TF_RET_CHECK(!replicas.empty());
// Set up streams.
std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
execute_backend_->BorrowStream(executor));
streams.push_back(std::move(stream));
@ -927,19 +948,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
// TODO(b/32990684): Tuple transfers to host end up allocating further
// buffers - implement that correctly.
return Unimplemented(
"Tuple transfers to the device not supported with replication.");
}
se::StreamExecutor* stream_executor;
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
arg->device_handle().handle()));
TF_ASSIGN_OR_RETURN(replicas,
Replicas(*execute_backend_, arg->device_handle()));
} else {
stream_executor = execute_backend_->default_stream_executor();
TF_ASSIGN_OR_RETURN(
replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
}
// Allocate memory on the device, using the stream executor. The size of the
@ -950,14 +972,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
execute_backend_->memory_allocator()->Allocate(
stream_executor->device_ordinal(), allocation_size));
replicas[0]->device_ordinal(), allocation_size));
*result->mutable_data() = allocation_tracker_.Register(
execute_backend_.get(), stream_executor->device_ordinal(), allocation,
shape, StrCat("TransferToServer literal of size ", allocation_size));
execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape,
StrCat("TransferToServer literal of size ", allocation_size));
TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
stream_executor->device_ordinal()));
for (se::StreamExecutor* executor : replicas) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
@ -968,7 +988,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) {
const int64 replica_count = execute_backend_->Replicas().size();
const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"%s",
@ -980,11 +1000,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
arg->device_handle().handle()));
TF_ASSIGN_OR_RETURN(auto replicas,
Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
executor = execute_backend_->Replicas()[arg->replica_id()];
TF_ASSIGN_OR_RETURN(
auto replicas,
Replicas(*execute_backend_, SingleComputationDeviceHandle()));
executor = replicas[arg->replica_id()];
}
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
@ -994,7 +1017,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::TransferFromOutfeed(
const TransferFromOutfeedRequest* arg,
TransferFromOutfeedResponse* result) {
const int64 replica_count = execute_backend_->Replicas().size();
const int64 replica_count = execute_backend_->replica_count();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@ -1004,11 +1027,14 @@ tensorflow::Status Service::TransferFromOutfeed(
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
arg->device_handle().handle()));
TF_ASSIGN_OR_RETURN(auto replicas,
Replicas(*execute_backend_, arg->device_handle()));
executor = replicas[arg->replica_id()];
} else {
executor = execute_backend_->Replicas()[arg->replica_id()];
TF_ASSIGN_OR_RETURN(
auto replicas,
Replicas(*execute_backend_, SingleComputationDeviceHandle()));
executor = replicas[arg->replica_id()];
}
Literal literal;
@ -1387,4 +1413,28 @@ tensorflow::Status Service::LoadComputationSnapshot(
return tensorflow::Status::OK();
}
DeviceHandle Service::SingleComputationDeviceHandle() const {
DeviceHandle device_handle;
device_handle.set_handle(0);
device_handle.set_device_count(1);
return device_handle;
}
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
const Backend& backend, const DeviceHandle& device_handle) const {
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (int replica = 0; replica < backend.replica_count(); ++replica) {
// From the computation placer, find out the device ids of the replicas for
// the given device handle.
TF_ASSIGN_OR_RETURN(
int device_ordinal,
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
backend.replica_count(),
device_handle.device_count()));
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
replicas.push_back(executor);
}
return replicas;
}
} // namespace xla

View File

@ -319,8 +319,7 @@ class Service : public ServiceInterface {
std::vector<perftools::gputools::DeviceMemoryBase>>
arguments,
Backend* backend,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors,
tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
tensorflow::gtl::ArraySlice<string> result_tags);
// Returns an HLO dumper for use in the compiler (it refers to flags
@ -346,6 +345,16 @@ class Service : public ServiceInterface {
tensorflow::Status ValidateResultShapeWithLayout(
const Shape& shape_with_layout, const Shape& result_shape) const;
// Returns the stream executors assigned to the replicas represented by the
// given device handle. Each device_handle is a virtual replicated device that
// represents a set of physical devices for the replicas.
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
DeviceHandle SingleComputationDeviceHandle() const;
// Tracks computations built via the API.
ComputationTracker computation_tracker_;

View File

@ -99,6 +99,7 @@ cc_library(
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
@ -196,6 +197,7 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
@ -746,6 +748,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",

View File

@ -255,11 +255,15 @@ message ComputationDataHandle {
int64 handle = 1;
}
// Handle given to a user that represents a device to execute a computation.
// When replication is enabled, the device handle represents the device for the
// replica id 0.
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
int64 handle = 1;
// The number of model-parallel virtual devices that communicate via XLA
// Send/Recv instructions.
int64 device_count = 2;
}
// Handle given to a user to represent a channel between two computations
@ -269,6 +273,21 @@ message ChannelHandle {
int64 handle = 1;
}
// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
int32 replica_count = 1;
int32 computation_count = 2;
// Each logical computation runs on replica_count physical devices.
// ComputationDevice represents the device ids assinged to the replicas.
message ComputationDevice {
repeated int32 replica_device_ids = 1;
}
repeated ComputationDevice computation_devices = 3;
}
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.