[XLA] Move replica_count out of Backend.

This is an intermediate step in making replica_count more explicitly
programmable (rather than set by a flag). There is no reason for the Backend to
hold replica_count - it was only holding it as a container with no additional
semantics.

PiperOrigin-RevId: 159626528
This commit is contained in:
Eli Bendersky 2017-06-20 15:48:13 -07:00 committed by TensorFlower Gardener
parent 4c13564c9c
commit 35af7113de
9 changed files with 60 additions and 77 deletions

View File

@ -342,7 +342,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@ -386,6 +385,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:lib",

View File

@ -22,7 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
return platform_;
}
BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
number_of_replicas_ = number_of_replicas;
return *this;
}
int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
int num_threads) {
intra_op_parallelism_threads_ = num_threads;
@ -85,11 +77,6 @@ struct Backend::EigenThreadPoolWrapper {
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
const BackendOptions& options) {
int64 replica_count = options.number_of_replicas();
if (replica_count == -1) {
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
replica_count = flags->xla_replicas;
}
perftools::gputools::Platform* platform = options.platform();
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto stream_executors,
@ -98,9 +85,9 @@ struct Backend::EigenThreadPoolWrapper {
TransferManager::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(auto computation_placer,
ComputationPlacer::GetForPlatform(platform));
std::unique_ptr<Backend> backend(new Backend(
replica_count, platform, compiler, stream_executors, transfer_manager,
computation_placer, options.intra_op_parallelism_threads()));
std::unique_ptr<Backend> backend(
new Backend(platform, compiler, stream_executors, transfer_manager,
computation_placer, options.intra_op_parallelism_threads()));
return std::move(backend);
}
@ -134,36 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
}
Backend::Backend(
int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
perftools::gputools::Platform* platform, Compiler* compiler,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
computation_placer_(computation_placer),
replica_count_(replica_count) {
computation_placer_(computation_placer) {
// The given set of stream executors set may include invalid executors.
for (se::StreamExecutor* exec : stream_executors) {
if (exec != nullptr) {
stream_executors_.push_back(exec);
}
}
CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
// Create a memory allocator for the valid stream executors.
memory_allocator_ =
MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
// First check that there are some non-null stream executors to avoid issuing
// an error mentioning replicas in the common case of requesting just 1
// replica, which means no replication.
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';
CHECK_GE(stream_executors_.size(), replica_count)
<< "Requested more replicas than there are devices for backend "
<< platform_->Name() << '.';
if (platform->id() == se::host::kHostPlatformId) {
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(

View File

@ -41,20 +41,12 @@ struct ThreadPoolDevice;
namespace xla {
// Options to configure the backend when it is created.
//
// TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
public:
// Set the platform backing the backend, or nullptr for the default platform.
BackendOptions& set_platform(perftools::gputools::Platform* platform);
perftools::gputools::Platform* platform() const;
// Set the number of replicas to use when compiling replicated
// programs. The default is -1 meaning that the value is read from
// the xla_replicas flag.
BackendOptions& set_number_of_replicas(int number_of_replicas);
int number_of_replicas() const;
// Sets the thread pool size for parallel execution of an individual operator.
// The default value of -1 will result in initializing the thread pool with
// the number of threads equal to the number of cores in the system.
@ -63,7 +55,6 @@ class BackendOptions {
private:
perftools::gputools::Platform* platform_ = nullptr;
int number_of_replicas_ = -1;
int intra_op_parallelism_threads_ = -1;
};
@ -77,8 +68,7 @@ class Backend {
public:
using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
// Creates a new backend for the given platform with the given number of
// replicas.
// Creates a new backend.
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
const BackendOptions& options);
@ -101,9 +91,6 @@ class Backend {
// all of these devices may be usable by XLA.
int device_count() const { return stream_executors_.size(); }
// Returns the number of replicas when replication is enabled.
int replica_count() const { return replica_count_; }
// Returns the device ordinal number of the default device.
int default_device_ordinal() const;
@ -170,8 +157,7 @@ class Backend {
private:
struct EigenThreadPoolWrapper;
Backend(int64 replica_count, perftools::gputools::Platform* platform,
Compiler* compiler,
Backend(perftools::gputools::Platform* platform, Compiler* compiler,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors,
TransferManager* transfer_manager,
@ -184,7 +170,6 @@ class Backend {
Compiler* compiler_;
TransferManager* transfer_manager_;
ComputationPlacer* computation_placer_;
int64 replica_count_ = -1;
// Vector of stream executors. stream_executors_[0] is the default executor.
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;

View File

@ -52,14 +52,16 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<CompileOnlyService> service(
new CompileOnlyService(compiler, std::move(compute_constant_backend)));
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
options, compiler, std::move(compute_constant_backend)));
return std::move(service);
}
CompileOnlyService::CompileOnlyService(
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
: Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
const ServiceOptions& options, Compiler* compiler,
std::unique_ptr<Backend> compute_constant_backend)
: Service(options, /*backend=*/nullptr,
std::move(compute_constant_backend)),
compiler_(compiler) {
runs_in_client_process_ = true;
}

View File

@ -103,7 +103,8 @@ class CompileOnlyService : public Service {
private:
explicit CompileOnlyService(
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
const ServiceOptions& options, Compiler* compiler,
std::unique_ptr<Backend> compute_constant_backend);
CompileOnlyService(const CompileOnlyService&) = delete;
void operator=(const CompileOnlyService&) = delete;

View File

@ -63,7 +63,6 @@ namespace xla {
BackendOptions backend_options;
backend_options.set_platform(platform)
.set_number_of_replicas(options.number_of_replicas())
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
Backend::CreateBackend(backend_options));
@ -71,13 +70,15 @@ namespace xla {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<LocalService> service(new LocalService(
std::move(backend), std::move(compute_constant_backend)));
options, std::move(backend), std::move(compute_constant_backend)));
return std::move(service);
}
LocalService::LocalService(std::unique_ptr<Backend> execute_backend,
LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend,
std::unique_ptr<Backend> compute_constant_backend)
: Service(std::move(execute_backend), std::move(compute_constant_backend)) {
: Service(options, std::move(execute_backend),
std::move(compute_constant_backend)) {
runs_in_client_process_ = true;
}
@ -153,7 +154,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
module_config->set_has_hybrid_result(has_hybrid_result);
module_config->set_replica_count(execute_backend_->replica_count());
module_config->set_replica_count(options_.number_of_replicas());
module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
module_config->set_intra_op_parallelism_threads(

View File

@ -60,7 +60,8 @@ class LocalService : public Service {
const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
private:
explicit LocalService(std::unique_ptr<Backend> backend,
explicit LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> backend,
std::unique_ptr<Backend> compute_constant_backend);
LocalService(const LocalService&) = delete;
void operator=(const LocalService&) = delete;

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const {
}
BackendOptions backend_options;
backend_options.set_platform(platform);
backend_options.set_number_of_replicas(options.number_of_replicas());
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<Service> service(new Service(
std::move(execute_backend), std::move(compute_constant_backend)));
std::unique_ptr<Service> service(
new Service(options, std::move(execute_backend),
std::move(compute_constant_backend)));
return std::move(service);
}
@ -158,7 +160,6 @@ Service::CreateComputeConstantBackend() {
if (platform->id() == se::host::kHostPlatformId) {
BackendOptions backend_options;
backend_options.set_platform(platform);
backend_options.set_number_of_replicas(1);
return Backend::CreateBackend(backend_options);
}
}
@ -171,11 +172,24 @@ Service::CreateComputeConstantBackend() {
};
}
Service::Service(std::unique_ptr<Backend> execute_backend,
Service::Service(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend,
std::unique_ptr<Backend> compute_constant_backend)
: execute_backend_(std::move(execute_backend)),
: options_(options),
execute_backend_(std::move(execute_backend)),
compute_constant_backend_(std::move(compute_constant_backend)) {
// TODO(b/32648682): this flag / options update dance will go away once we
// pass the replica count explicitly to the service.
if (options_.number_of_replicas() < 0) {
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
options_.set_number_of_replicas(flags->xla_replicas);
}
if (execute_backend_) {
if (execute_backend_->device_count() > 0) {
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
<< "Requested more replicas than there are devices.";
}
LOG(INFO) << Printf(
"XLA service %p executing computations on platform %s. Devices:", this,
execute_backend_->platform()->Name().c_str());
@ -325,7 +339,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
module_config->enable_hlo_profiling(true);
}
module_config->set_replica_count(backend->replica_count());
module_config->set_replica_count(options_.number_of_replicas());
module_config->set_seed(execution_options.seed());
module_config->set_debug_options(execution_options.debug_options());
@ -506,7 +520,7 @@ Service::ExecuteParallelAndRegisterResult(
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
backend->computation_placer()->AssignDevices(
backend->replica_count(), executables.size()));
options_.number_of_replicas(), executables.size()));
for (int64 i = 0; i < executables.size(); i++) {
// Stream executors for the replicas of the current computation.
@ -572,7 +586,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
backend->computation_placer()->AssignDevices(
backend->replica_count(), /*computation_count=*/1));
options_.number_of_replicas(),
/*computation_count=*/1));
// Set up run options.
std::vector<ServiceExecutableRunOptions> run_options;
@ -589,14 +604,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
}
perftools::gputools::DeviceMemoryBase result;
if (backend->replica_count() == 1) {
if (options_.number_of_replicas() == 1) {
TF_ASSIGN_OR_RETURN(
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
&run_options[0], profile, arguments));
} else {
std::vector<
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
repeated_arguments(backend->replica_count(), arguments);
repeated_arguments(options_.number_of_replicas(), arguments);
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
run_options, repeated_arguments));
@ -626,7 +641,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
std::vector<string> computation_names;
std::vector<DeviceHandle> device_handles;
if (arg->requests_size() * execute_backend_->replica_count() >
if (arg->requests_size() * options_.number_of_replicas() >
execute_backend_->device_count()) {
return FailedPrecondition(
"there are not enough stream executors to execute %d computations",
@ -723,7 +738,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) {
const int64 available_device_count = execute_backend_->device_count();
const int64 replica_count = execute_backend_->replica_count();
const int64 replica_count = options_.number_of_replicas();
if (replica_count <= 0) {
return FailedPrecondition("Replica count must be a positive integer");
}
@ -948,7 +963,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
Literal literal = Literal(arg->literal());
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) {
// TODO(b/32990684): Tuple transfers to host end up allocating further
// buffers - implement that correctly.
return Unimplemented(
@ -988,7 +1003,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) {
const int64 replica_count = execute_backend_->replica_count();
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"%s",
@ -1017,7 +1032,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::TransferFromOutfeed(
const TransferFromOutfeedRequest* arg,
TransferFromOutfeedResponse* result) {
const int64 replica_count = execute_backend_->replica_count();
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@ -1428,13 +1443,13 @@ DeviceHandle Service::SingleComputationDeviceHandle() const {
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
const Backend& backend, const DeviceHandle& device_handle) const {
std::vector<perftools::gputools::StreamExecutor*> replicas;
for (int replica = 0; replica < backend.replica_count(); ++replica) {
for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
// From the computation placer, find out the device ids of the replicas for
// the given device handle.
TF_ASSIGN_OR_RETURN(
int device_ordinal,
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
backend.replica_count(),
options_.number_of_replicas(),
device_handle.device_count()));
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
replicas.push_back(executor);

View File

@ -248,7 +248,7 @@ class Service : public ServiceInterface {
// The constructor is private. Use the NewService factory to create new
// service objects.
Service(std::unique_ptr<Backend> backend,
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
std::unique_ptr<Backend> compute_constant_backend);
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
@ -355,6 +355,8 @@ class Service : public ServiceInterface {
// single computation that is not model-parallelized.
DeviceHandle SingleComputationDeviceHandle() const;
ServiceOptions options_;
// Tracks computations built via the API.
ComputationTracker computation_tracker_;