[XLA] Move replica_count out of Backend.
This is an intermediate step in making replica_count more explicitly programmable (rather than set by a flag). There is no reason for the Backend to hold replica_count - it was only holding it as a container with no additional semantics. PiperOrigin-RevId: 159626528
This commit is contained in:
parent
4c13564c9c
commit
35af7113de
@ -342,7 +342,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
@ -386,6 +385,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:service_flags",
|
||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
|
||||
return platform_;
|
||||
}
|
||||
|
||||
BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
|
||||
number_of_replicas_ = number_of_replicas;
|
||||
return *this;
|
||||
}
|
||||
|
||||
int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
|
||||
|
||||
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
|
||||
int num_threads) {
|
||||
intra_op_parallelism_threads_ = num_threads;
|
||||
@ -85,11 +77,6 @@ struct Backend::EigenThreadPoolWrapper {
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
|
||||
const BackendOptions& options) {
|
||||
int64 replica_count = options.number_of_replicas();
|
||||
if (replica_count == -1) {
|
||||
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
|
||||
replica_count = flags->xla_replicas;
|
||||
}
|
||||
perftools::gputools::Platform* platform = options.platform();
|
||||
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
||||
TF_ASSIGN_OR_RETURN(auto stream_executors,
|
||||
@ -98,9 +85,9 @@ struct Backend::EigenThreadPoolWrapper {
|
||||
TransferManager::GetForPlatform(platform));
|
||||
TF_ASSIGN_OR_RETURN(auto computation_placer,
|
||||
ComputationPlacer::GetForPlatform(platform));
|
||||
std::unique_ptr<Backend> backend(new Backend(
|
||||
replica_count, platform, compiler, stream_executors, transfer_manager,
|
||||
computation_placer, options.intra_op_parallelism_threads()));
|
||||
std::unique_ptr<Backend> backend(
|
||||
new Backend(platform, compiler, stream_executors, transfer_manager,
|
||||
computation_placer, options.intra_op_parallelism_threads()));
|
||||
return std::move(backend);
|
||||
}
|
||||
|
||||
@ -134,36 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
|
||||
}
|
||||
|
||||
Backend::Backend(
|
||||
int64 replica_count, perftools::gputools::Platform* platform,
|
||||
Compiler* compiler,
|
||||
perftools::gputools::Platform* platform, Compiler* compiler,
|
||||
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
|
||||
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
|
||||
int intra_op_parallelism_threads)
|
||||
: platform_(platform),
|
||||
compiler_(compiler),
|
||||
transfer_manager_(transfer_manager),
|
||||
computation_placer_(computation_placer),
|
||||
replica_count_(replica_count) {
|
||||
computation_placer_(computation_placer) {
|
||||
// The given set of stream executors set may include invalid executors.
|
||||
for (se::StreamExecutor* exec : stream_executors) {
|
||||
if (exec != nullptr) {
|
||||
stream_executors_.push_back(exec);
|
||||
}
|
||||
}
|
||||
CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
|
||||
|
||||
// Create a memory allocator for the valid stream executors.
|
||||
memory_allocator_ =
|
||||
MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
|
||||
|
||||
// First check that there are some non-null stream executors to avoid issuing
|
||||
// an error mentioning replicas in the common case of requesting just 1
|
||||
// replica, which means no replication.
|
||||
CHECK(!stream_executors_.empty())
|
||||
<< "Service found no devices for backend " << platform_->Name() << '.';
|
||||
CHECK_GE(stream_executors_.size(), replica_count)
|
||||
<< "Requested more replicas than there are devices for backend "
|
||||
<< platform_->Name() << '.';
|
||||
|
||||
if (platform->id() == se::host::kHostPlatformId) {
|
||||
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
|
||||
|
@ -41,20 +41,12 @@ struct ThreadPoolDevice;
|
||||
namespace xla {
|
||||
|
||||
// Options to configure the backend when it is created.
|
||||
//
|
||||
// TODO(b/62588571): Remove the notion of replicas from the backend.
|
||||
class BackendOptions {
|
||||
public:
|
||||
// Set the platform backing the backend, or nullptr for the default platform.
|
||||
BackendOptions& set_platform(perftools::gputools::Platform* platform);
|
||||
perftools::gputools::Platform* platform() const;
|
||||
|
||||
// Set the number of replicas to use when compiling replicated
|
||||
// programs. The default is -1 meaning that the value is read from
|
||||
// the xla_replicas flag.
|
||||
BackendOptions& set_number_of_replicas(int number_of_replicas);
|
||||
int number_of_replicas() const;
|
||||
|
||||
// Sets the thread pool size for parallel execution of an individual operator.
|
||||
// The default value of -1 will result in initializing the thread pool with
|
||||
// the number of threads equal to the number of cores in the system.
|
||||
@ -63,7 +55,6 @@ class BackendOptions {
|
||||
|
||||
private:
|
||||
perftools::gputools::Platform* platform_ = nullptr;
|
||||
int number_of_replicas_ = -1;
|
||||
int intra_op_parallelism_threads_ = -1;
|
||||
};
|
||||
|
||||
@ -77,8 +68,7 @@ class Backend {
|
||||
public:
|
||||
using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
|
||||
|
||||
// Creates a new backend for the given platform with the given number of
|
||||
// replicas.
|
||||
// Creates a new backend.
|
||||
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
|
||||
const BackendOptions& options);
|
||||
|
||||
@ -101,9 +91,6 @@ class Backend {
|
||||
// all of these devices may be usable by XLA.
|
||||
int device_count() const { return stream_executors_.size(); }
|
||||
|
||||
// Returns the number of replicas when replication is enabled.
|
||||
int replica_count() const { return replica_count_; }
|
||||
|
||||
// Returns the device ordinal number of the default device.
|
||||
int default_device_ordinal() const;
|
||||
|
||||
@ -170,8 +157,7 @@ class Backend {
|
||||
|
||||
private:
|
||||
struct EigenThreadPoolWrapper;
|
||||
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
||||
Compiler* compiler,
|
||||
Backend(perftools::gputools::Platform* platform, Compiler* compiler,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
stream_executors,
|
||||
TransferManager* transfer_manager,
|
||||
@ -184,7 +170,6 @@ class Backend {
|
||||
Compiler* compiler_;
|
||||
TransferManager* transfer_manager_;
|
||||
ComputationPlacer* computation_placer_;
|
||||
int64 replica_count_ = -1;
|
||||
|
||||
// Vector of stream executors. stream_executors_[0] is the default executor.
|
||||
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
|
||||
|
@ -52,14 +52,16 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<CompileOnlyService> service(
|
||||
new CompileOnlyService(compiler, std::move(compute_constant_backend)));
|
||||
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
|
||||
options, compiler, std::move(compute_constant_backend)));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
CompileOnlyService::CompileOnlyService(
|
||||
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
|
||||
const ServiceOptions& options, Compiler* compiler,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(options, /*backend=*/nullptr,
|
||||
std::move(compute_constant_backend)),
|
||||
compiler_(compiler) {
|
||||
runs_in_client_process_ = true;
|
||||
}
|
||||
|
@ -103,7 +103,8 @@ class CompileOnlyService : public Service {
|
||||
|
||||
private:
|
||||
explicit CompileOnlyService(
|
||||
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
|
||||
const ServiceOptions& options, Compiler* compiler,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
CompileOnlyService(const CompileOnlyService&) = delete;
|
||||
void operator=(const CompileOnlyService&) = delete;
|
||||
|
||||
|
@ -63,7 +63,6 @@ namespace xla {
|
||||
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform)
|
||||
.set_number_of_replicas(options.number_of_replicas())
|
||||
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
||||
Backend::CreateBackend(backend_options));
|
||||
@ -71,13 +70,15 @@ namespace xla {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<LocalService> service(new LocalService(
|
||||
std::move(backend), std::move(compute_constant_backend)));
|
||||
options, std::move(backend), std::move(compute_constant_backend)));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
LocalService::LocalService(std::unique_ptr<Backend> execute_backend,
|
||||
LocalService::LocalService(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> execute_backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(std::move(execute_backend), std::move(compute_constant_backend)) {
|
||||
: Service(options, std::move(execute_backend),
|
||||
std::move(compute_constant_backend)) {
|
||||
runs_in_client_process_ = true;
|
||||
}
|
||||
|
||||
@ -153,7 +154,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
||||
// Construct computation layout from the argument layouts.
|
||||
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
|
||||
module_config->set_has_hybrid_result(has_hybrid_result);
|
||||
module_config->set_replica_count(execute_backend_->replica_count());
|
||||
module_config->set_replica_count(options_.number_of_replicas());
|
||||
module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
|
||||
if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
|
||||
module_config->set_intra_op_parallelism_threads(
|
||||
|
@ -60,7 +60,8 @@ class LocalService : public Service {
|
||||
const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
|
||||
|
||||
private:
|
||||
explicit LocalService(std::unique_ptr<Backend> backend,
|
||||
explicit LocalService(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
LocalService(const LocalService&) = delete;
|
||||
void operator=(const LocalService&) = delete;
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const {
|
||||
}
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform);
|
||||
backend_options.set_number_of_replicas(options.number_of_replicas());
|
||||
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<Service> service(new Service(
|
||||
std::move(execute_backend), std::move(compute_constant_backend)));
|
||||
std::unique_ptr<Service> service(
|
||||
new Service(options, std::move(execute_backend),
|
||||
std::move(compute_constant_backend)));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
@ -158,7 +160,6 @@ Service::CreateComputeConstantBackend() {
|
||||
if (platform->id() == se::host::kHostPlatformId) {
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform);
|
||||
backend_options.set_number_of_replicas(1);
|
||||
return Backend::CreateBackend(backend_options);
|
||||
}
|
||||
}
|
||||
@ -171,11 +172,24 @@ Service::CreateComputeConstantBackend() {
|
||||
};
|
||||
}
|
||||
|
||||
Service::Service(std::unique_ptr<Backend> execute_backend,
|
||||
Service::Service(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> execute_backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: execute_backend_(std::move(execute_backend)),
|
||||
: options_(options),
|
||||
execute_backend_(std::move(execute_backend)),
|
||||
compute_constant_backend_(std::move(compute_constant_backend)) {
|
||||
// TODO(b/32648682): this flag / options update dance will go away once we
|
||||
// pass the replica count explicitly to the service.
|
||||
if (options_.number_of_replicas() < 0) {
|
||||
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
|
||||
options_.set_number_of_replicas(flags->xla_replicas);
|
||||
}
|
||||
|
||||
if (execute_backend_) {
|
||||
if (execute_backend_->device_count() > 0) {
|
||||
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
|
||||
<< "Requested more replicas than there are devices.";
|
||||
}
|
||||
LOG(INFO) << Printf(
|
||||
"XLA service %p executing computations on platform %s. Devices:", this,
|
||||
execute_backend_->platform()->Name().c_str());
|
||||
@ -325,7 +339,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
module_config->enable_hlo_profiling(true);
|
||||
}
|
||||
|
||||
module_config->set_replica_count(backend->replica_count());
|
||||
module_config->set_replica_count(options_.number_of_replicas());
|
||||
module_config->set_seed(execution_options.seed());
|
||||
module_config->set_debug_options(execution_options.debug_options());
|
||||
|
||||
@ -506,7 +520,7 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
backend->computation_placer()->AssignDevices(
|
||||
backend->replica_count(), executables.size()));
|
||||
options_.number_of_replicas(), executables.size()));
|
||||
|
||||
for (int64 i = 0; i < executables.size(); i++) {
|
||||
// Stream executors for the replicas of the current computation.
|
||||
@ -572,7 +586,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
backend->computation_placer()->AssignDevices(
|
||||
backend->replica_count(), /*computation_count=*/1));
|
||||
options_.number_of_replicas(),
|
||||
/*computation_count=*/1));
|
||||
|
||||
// Set up run options.
|
||||
std::vector<ServiceExecutableRunOptions> run_options;
|
||||
@ -589,14 +604,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
|
||||
}
|
||||
|
||||
perftools::gputools::DeviceMemoryBase result;
|
||||
if (backend->replica_count() == 1) {
|
||||
if (options_.number_of_replicas() == 1) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
|
||||
&run_options[0], profile, arguments));
|
||||
} else {
|
||||
std::vector<
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
|
||||
repeated_arguments(backend->replica_count(), arguments);
|
||||
repeated_arguments(options_.number_of_replicas(), arguments);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
|
||||
run_options, repeated_arguments));
|
||||
@ -626,7 +641,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
||||
std::vector<string> computation_names;
|
||||
std::vector<DeviceHandle> device_handles;
|
||||
|
||||
if (arg->requests_size() * execute_backend_->replica_count() >
|
||||
if (arg->requests_size() * options_.number_of_replicas() >
|
||||
execute_backend_->device_count()) {
|
||||
return FailedPrecondition(
|
||||
"there are not enough stream executors to execute %d computations",
|
||||
@ -723,7 +738,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
||||
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
||||
GetDeviceHandlesResponse* result) {
|
||||
const int64 available_device_count = execute_backend_->device_count();
|
||||
const int64 replica_count = execute_backend_->replica_count();
|
||||
const int64 replica_count = options_.number_of_replicas();
|
||||
if (replica_count <= 0) {
|
||||
return FailedPrecondition("Replica count must be a positive integer");
|
||||
}
|
||||
@ -948,7 +963,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
||||
Literal literal = Literal(arg->literal());
|
||||
const Shape& shape = literal.shape();
|
||||
|
||||
if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
|
||||
if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) {
|
||||
// TODO(b/32990684): Tuple transfers to host end up allocating further
|
||||
// buffers - implement that correctly.
|
||||
return Unimplemented(
|
||||
@ -988,7 +1003,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
||||
|
||||
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
||||
TransferToInfeedResponse* result) {
|
||||
const int64 replica_count = execute_backend_->replica_count();
|
||||
const int64 replica_count = options_.number_of_replicas();
|
||||
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
||||
return FailedPrecondition(
|
||||
"%s",
|
||||
@ -1017,7 +1032,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
||||
tensorflow::Status Service::TransferFromOutfeed(
|
||||
const TransferFromOutfeedRequest* arg,
|
||||
TransferFromOutfeedResponse* result) {
|
||||
const int64 replica_count = execute_backend_->replica_count();
|
||||
const int64 replica_count = options_.number_of_replicas();
|
||||
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
||||
return FailedPrecondition(
|
||||
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
|
||||
@ -1428,13 +1443,13 @@ DeviceHandle Service::SingleComputationDeviceHandle() const {
|
||||
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
|
||||
const Backend& backend, const DeviceHandle& device_handle) const {
|
||||
std::vector<perftools::gputools::StreamExecutor*> replicas;
|
||||
for (int replica = 0; replica < backend.replica_count(); ++replica) {
|
||||
for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
|
||||
// From the computation placer, find out the device ids of the replicas for
|
||||
// the given device handle.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
int device_ordinal,
|
||||
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
|
||||
backend.replica_count(),
|
||||
options_.number_of_replicas(),
|
||||
device_handle.device_count()));
|
||||
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
|
||||
replicas.push_back(executor);
|
||||
|
@ -248,7 +248,7 @@ class Service : public ServiceInterface {
|
||||
|
||||
// The constructor is private. Use the NewService factory to create new
|
||||
// service objects.
|
||||
Service(std::unique_ptr<Backend> backend,
|
||||
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
|
||||
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
||||
@ -355,6 +355,8 @@ class Service : public ServiceInterface {
|
||||
// single computation that is not model-parallelized.
|
||||
DeviceHandle SingleComputationDeviceHandle() const;
|
||||
|
||||
ServiceOptions options_;
|
||||
|
||||
// Tracks computations built via the API.
|
||||
ComputationTracker computation_tracker_;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user