[XLA] Move replica_count out of Backend.
This is an intermediate step in making replica_count more explicitly programmable (rather than set by a flag). There is no reason for the Backend to hold replica_count; it was only holding it as a container with no additional semantics.

PiperOrigin-RevId: 159626528
This commit is contained in:
parent
4c13564c9c
commit
35af7113de
@ -342,7 +342,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
|
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
@ -386,6 +385,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla:xla_proto",
|
"//tensorflow/compiler/xla:xla_proto",
|
||||||
|
"//tensorflow/compiler/xla/legacy_flags:backend_flags",
|
||||||
"//tensorflow/compiler/xla/legacy_flags:service_flags",
|
"//tensorflow/compiler/xla/legacy_flags:service_flags",
|
||||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
|
|||||||
return platform_;
|
return platform_;
|
||||||
}
|
}
|
||||||
|
|
||||||
BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
|
|
||||||
number_of_replicas_ = number_of_replicas;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
|
|
||||||
|
|
||||||
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
|
BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
|
||||||
int num_threads) {
|
int num_threads) {
|
||||||
intra_op_parallelism_threads_ = num_threads;
|
intra_op_parallelism_threads_ = num_threads;
|
||||||
@ -85,11 +77,6 @@ struct Backend::EigenThreadPoolWrapper {
|
|||||||
|
|
||||||
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
|
/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
|
||||||
const BackendOptions& options) {
|
const BackendOptions& options) {
|
||||||
int64 replica_count = options.number_of_replicas();
|
|
||||||
if (replica_count == -1) {
|
|
||||||
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
|
|
||||||
replica_count = flags->xla_replicas;
|
|
||||||
}
|
|
||||||
perftools::gputools::Platform* platform = options.platform();
|
perftools::gputools::Platform* platform = options.platform();
|
||||||
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
||||||
TF_ASSIGN_OR_RETURN(auto stream_executors,
|
TF_ASSIGN_OR_RETURN(auto stream_executors,
|
||||||
@ -98,9 +85,9 @@ struct Backend::EigenThreadPoolWrapper {
|
|||||||
TransferManager::GetForPlatform(platform));
|
TransferManager::GetForPlatform(platform));
|
||||||
TF_ASSIGN_OR_RETURN(auto computation_placer,
|
TF_ASSIGN_OR_RETURN(auto computation_placer,
|
||||||
ComputationPlacer::GetForPlatform(platform));
|
ComputationPlacer::GetForPlatform(platform));
|
||||||
std::unique_ptr<Backend> backend(new Backend(
|
std::unique_ptr<Backend> backend(
|
||||||
replica_count, platform, compiler, stream_executors, transfer_manager,
|
new Backend(platform, compiler, stream_executors, transfer_manager,
|
||||||
computation_placer, options.intra_op_parallelism_threads()));
|
computation_placer, options.intra_op_parallelism_threads()));
|
||||||
return std::move(backend);
|
return std::move(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,36 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Backend::Backend(
|
Backend::Backend(
|
||||||
int64 replica_count, perftools::gputools::Platform* platform,
|
perftools::gputools::Platform* platform, Compiler* compiler,
|
||||||
Compiler* compiler,
|
|
||||||
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
|
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
|
||||||
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
|
TransferManager* transfer_manager, ComputationPlacer* computation_placer,
|
||||||
int intra_op_parallelism_threads)
|
int intra_op_parallelism_threads)
|
||||||
: platform_(platform),
|
: platform_(platform),
|
||||||
compiler_(compiler),
|
compiler_(compiler),
|
||||||
transfer_manager_(transfer_manager),
|
transfer_manager_(transfer_manager),
|
||||||
computation_placer_(computation_placer),
|
computation_placer_(computation_placer) {
|
||||||
replica_count_(replica_count) {
|
|
||||||
// The given set of stream executors set may include invalid executors.
|
// The given set of stream executors set may include invalid executors.
|
||||||
for (se::StreamExecutor* exec : stream_executors) {
|
for (se::StreamExecutor* exec : stream_executors) {
|
||||||
if (exec != nullptr) {
|
if (exec != nullptr) {
|
||||||
stream_executors_.push_back(exec);
|
stream_executors_.push_back(exec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
|
|
||||||
|
|
||||||
// Create a memory allocator for the valid stream executors.
|
// Create a memory allocator for the valid stream executors.
|
||||||
memory_allocator_ =
|
memory_allocator_ =
|
||||||
MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
|
MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
|
||||||
|
|
||||||
// First check that there are some non-null stream executors to avoid issuing
|
|
||||||
// an error mentioning replicas in the common case of requesting just 1
|
|
||||||
// replica, which means no replication.
|
|
||||||
CHECK(!stream_executors_.empty())
|
CHECK(!stream_executors_.empty())
|
||||||
<< "Service found no devices for backend " << platform_->Name() << '.';
|
<< "Service found no devices for backend " << platform_->Name() << '.';
|
||||||
CHECK_GE(stream_executors_.size(), replica_count)
|
|
||||||
<< "Requested more replicas than there are devices for backend "
|
|
||||||
<< platform_->Name() << '.';
|
|
||||||
|
|
||||||
if (platform->id() == se::host::kHostPlatformId) {
|
if (platform->id() == se::host::kHostPlatformId) {
|
||||||
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
|
inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
|
||||||
|
@ -41,20 +41,12 @@ struct ThreadPoolDevice;
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Options to configure the backend when it is created.
|
// Options to configure the backend when it is created.
|
||||||
//
|
|
||||||
// TODO(b/62588571): Remove the notion of replicas from the backend.
|
|
||||||
class BackendOptions {
|
class BackendOptions {
|
||||||
public:
|
public:
|
||||||
// Set the platform backing the backend, or nullptr for the default platform.
|
// Set the platform backing the backend, or nullptr for the default platform.
|
||||||
BackendOptions& set_platform(perftools::gputools::Platform* platform);
|
BackendOptions& set_platform(perftools::gputools::Platform* platform);
|
||||||
perftools::gputools::Platform* platform() const;
|
perftools::gputools::Platform* platform() const;
|
||||||
|
|
||||||
// Set the number of replicas to use when compiling replicated
|
|
||||||
// programs. The default is -1 meaning that the value is read from
|
|
||||||
// the xla_replicas flag.
|
|
||||||
BackendOptions& set_number_of_replicas(int number_of_replicas);
|
|
||||||
int number_of_replicas() const;
|
|
||||||
|
|
||||||
// Sets the thread pool size for parallel execution of an individual operator.
|
// Sets the thread pool size for parallel execution of an individual operator.
|
||||||
// The default value of -1 will result in initializing the thread pool with
|
// The default value of -1 will result in initializing the thread pool with
|
||||||
// the number of threads equal to the number of cores in the system.
|
// the number of threads equal to the number of cores in the system.
|
||||||
@ -63,7 +55,6 @@ class BackendOptions {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
perftools::gputools::Platform* platform_ = nullptr;
|
perftools::gputools::Platform* platform_ = nullptr;
|
||||||
int number_of_replicas_ = -1;
|
|
||||||
int intra_op_parallelism_threads_ = -1;
|
int intra_op_parallelism_threads_ = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -77,8 +68,7 @@ class Backend {
|
|||||||
public:
|
public:
|
||||||
using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
|
using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
|
||||||
|
|
||||||
// Creates a new backend for the given platform with the given number of
|
// Creates a new backend.
|
||||||
// replicas.
|
|
||||||
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
|
static StatusOr<std::unique_ptr<Backend>> CreateBackend(
|
||||||
const BackendOptions& options);
|
const BackendOptions& options);
|
||||||
|
|
||||||
@ -101,9 +91,6 @@ class Backend {
|
|||||||
// all of these devices may be usable by XLA.
|
// all of these devices may be usable by XLA.
|
||||||
int device_count() const { return stream_executors_.size(); }
|
int device_count() const { return stream_executors_.size(); }
|
||||||
|
|
||||||
// Returns the number of replicas when replication is enabled.
|
|
||||||
int replica_count() const { return replica_count_; }
|
|
||||||
|
|
||||||
// Returns the device ordinal number of the default device.
|
// Returns the device ordinal number of the default device.
|
||||||
int default_device_ordinal() const;
|
int default_device_ordinal() const;
|
||||||
|
|
||||||
@ -170,8 +157,7 @@ class Backend {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
struct EigenThreadPoolWrapper;
|
struct EigenThreadPoolWrapper;
|
||||||
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
Backend(perftools::gputools::Platform* platform, Compiler* compiler,
|
||||||
Compiler* compiler,
|
|
||||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||||
stream_executors,
|
stream_executors,
|
||||||
TransferManager* transfer_manager,
|
TransferManager* transfer_manager,
|
||||||
@ -184,7 +170,6 @@ class Backend {
|
|||||||
Compiler* compiler_;
|
Compiler* compiler_;
|
||||||
TransferManager* transfer_manager_;
|
TransferManager* transfer_manager_;
|
||||||
ComputationPlacer* computation_placer_;
|
ComputationPlacer* computation_placer_;
|
||||||
int64 replica_count_ = -1;
|
|
||||||
|
|
||||||
// Vector of stream executors. stream_executors_[0] is the default executor.
|
// Vector of stream executors. stream_executors_[0] is the default executor.
|
||||||
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
|
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
|
||||||
|
@ -52,14 +52,16 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||||
CreateComputeConstantBackend());
|
CreateComputeConstantBackend());
|
||||||
std::unique_ptr<CompileOnlyService> service(
|
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
|
||||||
new CompileOnlyService(compiler, std::move(compute_constant_backend)));
|
options, compiler, std::move(compute_constant_backend)));
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
CompileOnlyService::CompileOnlyService(
|
CompileOnlyService::CompileOnlyService(
|
||||||
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
|
const ServiceOptions& options, Compiler* compiler,
|
||||||
: Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
|
std::unique_ptr<Backend> compute_constant_backend)
|
||||||
|
: Service(options, /*backend=*/nullptr,
|
||||||
|
std::move(compute_constant_backend)),
|
||||||
compiler_(compiler) {
|
compiler_(compiler) {
|
||||||
runs_in_client_process_ = true;
|
runs_in_client_process_ = true;
|
||||||
}
|
}
|
||||||
|
@ -103,7 +103,8 @@ class CompileOnlyService : public Service {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
explicit CompileOnlyService(
|
explicit CompileOnlyService(
|
||||||
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
|
const ServiceOptions& options, Compiler* compiler,
|
||||||
|
std::unique_ptr<Backend> compute_constant_backend);
|
||||||
CompileOnlyService(const CompileOnlyService&) = delete;
|
CompileOnlyService(const CompileOnlyService&) = delete;
|
||||||
void operator=(const CompileOnlyService&) = delete;
|
void operator=(const CompileOnlyService&) = delete;
|
||||||
|
|
||||||
|
@ -63,7 +63,6 @@ namespace xla {
|
|||||||
|
|
||||||
BackendOptions backend_options;
|
BackendOptions backend_options;
|
||||||
backend_options.set_platform(platform)
|
backend_options.set_platform(platform)
|
||||||
.set_number_of_replicas(options.number_of_replicas())
|
|
||||||
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
|
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
||||||
Backend::CreateBackend(backend_options));
|
Backend::CreateBackend(backend_options));
|
||||||
@ -71,13 +70,15 @@ namespace xla {
|
|||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||||
CreateComputeConstantBackend());
|
CreateComputeConstantBackend());
|
||||||
std::unique_ptr<LocalService> service(new LocalService(
|
std::unique_ptr<LocalService> service(new LocalService(
|
||||||
std::move(backend), std::move(compute_constant_backend)));
|
options, std::move(backend), std::move(compute_constant_backend)));
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalService::LocalService(std::unique_ptr<Backend> execute_backend,
|
LocalService::LocalService(const ServiceOptions& options,
|
||||||
|
std::unique_ptr<Backend> execute_backend,
|
||||||
std::unique_ptr<Backend> compute_constant_backend)
|
std::unique_ptr<Backend> compute_constant_backend)
|
||||||
: Service(std::move(execute_backend), std::move(compute_constant_backend)) {
|
: Service(options, std::move(execute_backend),
|
||||||
|
std::move(compute_constant_backend)) {
|
||||||
runs_in_client_process_ = true;
|
runs_in_client_process_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,7 +154,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
|||||||
// Construct computation layout from the argument layouts.
|
// Construct computation layout from the argument layouts.
|
||||||
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
|
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
|
||||||
module_config->set_has_hybrid_result(has_hybrid_result);
|
module_config->set_has_hybrid_result(has_hybrid_result);
|
||||||
module_config->set_replica_count(execute_backend_->replica_count());
|
module_config->set_replica_count(options_.number_of_replicas());
|
||||||
module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
|
module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
|
||||||
if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
|
if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
|
||||||
module_config->set_intra_op_parallelism_threads(
|
module_config->set_intra_op_parallelism_threads(
|
||||||
|
@ -60,7 +60,8 @@ class LocalService : public Service {
|
|||||||
const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
|
const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit LocalService(std::unique_ptr<Backend> backend,
|
explicit LocalService(const ServiceOptions& options,
|
||||||
|
std::unique_ptr<Backend> backend,
|
||||||
std::unique_ptr<Backend> compute_constant_backend);
|
std::unique_ptr<Backend> compute_constant_backend);
|
||||||
LocalService(const LocalService&) = delete;
|
LocalService(const LocalService&) = delete;
|
||||||
void operator=(const LocalService&) = delete;
|
void operator=(const LocalService&) = delete;
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
|
||||||
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
|
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
|
||||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const {
|
|||||||
}
|
}
|
||||||
BackendOptions backend_options;
|
BackendOptions backend_options;
|
||||||
backend_options.set_platform(platform);
|
backend_options.set_platform(platform);
|
||||||
backend_options.set_number_of_replicas(options.number_of_replicas());
|
|
||||||
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||||
CreateComputeConstantBackend());
|
CreateComputeConstantBackend());
|
||||||
std::unique_ptr<Service> service(new Service(
|
std::unique_ptr<Service> service(
|
||||||
std::move(execute_backend), std::move(compute_constant_backend)));
|
new Service(options, std::move(execute_backend),
|
||||||
|
std::move(compute_constant_backend)));
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +160,6 @@ Service::CreateComputeConstantBackend() {
|
|||||||
if (platform->id() == se::host::kHostPlatformId) {
|
if (platform->id() == se::host::kHostPlatformId) {
|
||||||
BackendOptions backend_options;
|
BackendOptions backend_options;
|
||||||
backend_options.set_platform(platform);
|
backend_options.set_platform(platform);
|
||||||
backend_options.set_number_of_replicas(1);
|
|
||||||
return Backend::CreateBackend(backend_options);
|
return Backend::CreateBackend(backend_options);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,11 +172,24 @@ Service::CreateComputeConstantBackend() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Service::Service(std::unique_ptr<Backend> execute_backend,
|
Service::Service(const ServiceOptions& options,
|
||||||
|
std::unique_ptr<Backend> execute_backend,
|
||||||
std::unique_ptr<Backend> compute_constant_backend)
|
std::unique_ptr<Backend> compute_constant_backend)
|
||||||
: execute_backend_(std::move(execute_backend)),
|
: options_(options),
|
||||||
|
execute_backend_(std::move(execute_backend)),
|
||||||
compute_constant_backend_(std::move(compute_constant_backend)) {
|
compute_constant_backend_(std::move(compute_constant_backend)) {
|
||||||
|
// TODO(b/32648682): this flag / options update dance will go away once we
|
||||||
|
// pass the replica count explicitly to the service.
|
||||||
|
if (options_.number_of_replicas() < 0) {
|
||||||
|
legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
|
||||||
|
options_.set_number_of_replicas(flags->xla_replicas);
|
||||||
|
}
|
||||||
|
|
||||||
if (execute_backend_) {
|
if (execute_backend_) {
|
||||||
|
if (execute_backend_->device_count() > 0) {
|
||||||
|
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
|
||||||
|
<< "Requested more replicas than there are devices.";
|
||||||
|
}
|
||||||
LOG(INFO) << Printf(
|
LOG(INFO) << Printf(
|
||||||
"XLA service %p executing computations on platform %s. Devices:", this,
|
"XLA service %p executing computations on platform %s. Devices:", this,
|
||||||
execute_backend_->platform()->Name().c_str());
|
execute_backend_->platform()->Name().c_str());
|
||||||
@ -325,7 +339,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
module_config->enable_hlo_profiling(true);
|
module_config->enable_hlo_profiling(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
module_config->set_replica_count(backend->replica_count());
|
module_config->set_replica_count(options_.number_of_replicas());
|
||||||
module_config->set_seed(execution_options.seed());
|
module_config->set_seed(execution_options.seed());
|
||||||
module_config->set_debug_options(execution_options.debug_options());
|
module_config->set_debug_options(execution_options.debug_options());
|
||||||
|
|
||||||
@ -506,7 +520,7 @@ Service::ExecuteParallelAndRegisterResult(
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||||
backend->computation_placer()->AssignDevices(
|
backend->computation_placer()->AssignDevices(
|
||||||
backend->replica_count(), executables.size()));
|
options_.number_of_replicas(), executables.size()));
|
||||||
|
|
||||||
for (int64 i = 0; i < executables.size(); i++) {
|
for (int64 i = 0; i < executables.size(); i++) {
|
||||||
// Stream executors for the replicas of the current computation.
|
// Stream executors for the replicas of the current computation.
|
||||||
@ -572,7 +586,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||||
backend->computation_placer()->AssignDevices(
|
backend->computation_placer()->AssignDevices(
|
||||||
backend->replica_count(), /*computation_count=*/1));
|
options_.number_of_replicas(),
|
||||||
|
/*computation_count=*/1));
|
||||||
|
|
||||||
// Set up run options.
|
// Set up run options.
|
||||||
std::vector<ServiceExecutableRunOptions> run_options;
|
std::vector<ServiceExecutableRunOptions> run_options;
|
||||||
@ -589,14 +604,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
|
|||||||
}
|
}
|
||||||
|
|
||||||
perftools::gputools::DeviceMemoryBase result;
|
perftools::gputools::DeviceMemoryBase result;
|
||||||
if (backend->replica_count() == 1) {
|
if (options_.number_of_replicas() == 1) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
|
result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
|
||||||
&run_options[0], profile, arguments));
|
&run_options[0], profile, arguments));
|
||||||
} else {
|
} else {
|
||||||
std::vector<
|
std::vector<
|
||||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
|
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
|
||||||
repeated_arguments(backend->replica_count(), arguments);
|
repeated_arguments(options_.number_of_replicas(), arguments);
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
|
TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
|
||||||
run_options, repeated_arguments));
|
run_options, repeated_arguments));
|
||||||
@ -626,7 +641,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
|||||||
std::vector<string> computation_names;
|
std::vector<string> computation_names;
|
||||||
std::vector<DeviceHandle> device_handles;
|
std::vector<DeviceHandle> device_handles;
|
||||||
|
|
||||||
if (arg->requests_size() * execute_backend_->replica_count() >
|
if (arg->requests_size() * options_.number_of_replicas() >
|
||||||
execute_backend_->device_count()) {
|
execute_backend_->device_count()) {
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"there are not enough stream executors to execute %d computations",
|
"there are not enough stream executors to execute %d computations",
|
||||||
@ -723,7 +738,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
|
|||||||
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
|
||||||
GetDeviceHandlesResponse* result) {
|
GetDeviceHandlesResponse* result) {
|
||||||
const int64 available_device_count = execute_backend_->device_count();
|
const int64 available_device_count = execute_backend_->device_count();
|
||||||
const int64 replica_count = execute_backend_->replica_count();
|
const int64 replica_count = options_.number_of_replicas();
|
||||||
if (replica_count <= 0) {
|
if (replica_count <= 0) {
|
||||||
return FailedPrecondition("Replica count must be a positive integer");
|
return FailedPrecondition("Replica count must be a positive integer");
|
||||||
}
|
}
|
||||||
@ -948,7 +963,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
|||||||
Literal literal = Literal(arg->literal());
|
Literal literal = Literal(arg->literal());
|
||||||
const Shape& shape = literal.shape();
|
const Shape& shape = literal.shape();
|
||||||
|
|
||||||
if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
|
if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) {
|
||||||
// TODO(b/32990684): Tuple transfers to host end up allocating further
|
// TODO(b/32990684): Tuple transfers to host end up allocating further
|
||||||
// buffers - implement that correctly.
|
// buffers - implement that correctly.
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
@ -988,7 +1003,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
|
|||||||
|
|
||||||
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
||||||
TransferToInfeedResponse* result) {
|
TransferToInfeedResponse* result) {
|
||||||
const int64 replica_count = execute_backend_->replica_count();
|
const int64 replica_count = options_.number_of_replicas();
|
||||||
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"%s",
|
"%s",
|
||||||
@ -1017,7 +1032,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
|||||||
tensorflow::Status Service::TransferFromOutfeed(
|
tensorflow::Status Service::TransferFromOutfeed(
|
||||||
const TransferFromOutfeedRequest* arg,
|
const TransferFromOutfeedRequest* arg,
|
||||||
TransferFromOutfeedResponse* result) {
|
TransferFromOutfeedResponse* result) {
|
||||||
const int64 replica_count = execute_backend_->replica_count();
|
const int64 replica_count = options_.number_of_replicas();
|
||||||
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
|
||||||
return FailedPrecondition(
|
return FailedPrecondition(
|
||||||
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
|
"The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
|
||||||
@ -1428,13 +1443,13 @@ DeviceHandle Service::SingleComputationDeviceHandle() const {
|
|||||||
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
|
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
|
||||||
const Backend& backend, const DeviceHandle& device_handle) const {
|
const Backend& backend, const DeviceHandle& device_handle) const {
|
||||||
std::vector<perftools::gputools::StreamExecutor*> replicas;
|
std::vector<perftools::gputools::StreamExecutor*> replicas;
|
||||||
for (int replica = 0; replica < backend.replica_count(); ++replica) {
|
for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
|
||||||
// From the computation placer, find out the device ids of the replicas for
|
// From the computation placer, find out the device ids of the replicas for
|
||||||
// the given device handle.
|
// the given device handle.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
int device_ordinal,
|
int device_ordinal,
|
||||||
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
|
backend.computation_placer()->DeviceId(replica, device_handle.handle(),
|
||||||
backend.replica_count(),
|
options_.number_of_replicas(),
|
||||||
device_handle.device_count()));
|
device_handle.device_count()));
|
||||||
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
|
TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
|
||||||
replicas.push_back(executor);
|
replicas.push_back(executor);
|
||||||
|
@ -248,7 +248,7 @@ class Service : public ServiceInterface {
|
|||||||
|
|
||||||
// The constructor is private. Use the NewService factory to create new
|
// The constructor is private. Use the NewService factory to create new
|
||||||
// service objects.
|
// service objects.
|
||||||
Service(std::unique_ptr<Backend> backend,
|
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
|
||||||
std::unique_ptr<Backend> compute_constant_backend);
|
std::unique_ptr<Backend> compute_constant_backend);
|
||||||
|
|
||||||
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
||||||
@ -355,6 +355,8 @@ class Service : public ServiceInterface {
|
|||||||
// single computation that is not model-parallelized.
|
// single computation that is not model-parallelized.
|
||||||
DeviceHandle SingleComputationDeviceHandle() const;
|
DeviceHandle SingleComputationDeviceHandle() const;
|
||||||
|
|
||||||
|
ServiceOptions options_;
|
||||||
|
|
||||||
// Tracks computations built via the API.
|
// Tracks computations built via the API.
|
||||||
ComputationTracker computation_tracker_;
|
ComputationTracker computation_tracker_;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user