[XLA] Move replica_count out of Backend.

This is an intermediate step in making replica_count more explicitly
programmable (rather than set by a flag). There is no reason for the Backend to
hold replica_count - it was only holding it as a container with no additional
semantics.

PiperOrigin-RevId: 159626528
Author: Eli Bendersky (committed by TensorFlower Gardener)
Date:   2017-06-20 15:48:13 -07:00
Commit: 35af7113de (parent 4c13564c9c)
9 changed files with 60 additions and 77 deletions
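For illustration, a rough sketch of how a caller expresses the replica count after this change. It is not part of the commit, the helper name is hypothetical, and it only uses the accessors visible in the hunks below (ServiceOptions::set_number_of_replicas, BackendOptions::set_platform, Backend::CreateBackend):

// Sketch only: the replica count is now carried by ServiceOptions and resolved
// by the Service, while the Backend is built with no replica information.
#include <memory>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<std::unique_ptr<xla::Backend>> MakeBackendForReplicas(
    perftools::gputools::Platform* platform, xla::ServiceOptions* options) {
  // Replica count is programmed explicitly instead of coming from the
  // --xla_replicas flag. Service::Service (see service.cc below) still falls
  // back to the flag when number_of_replicas() is left negative, and checks
  // the count against Backend::device_count().
  options->set_number_of_replicas(2);

  // The Backend itself no longer knows anything about replicas.
  xla::BackendOptions backend_options;
  backend_options.set_platform(platform);
  return xla::Backend::CreateBackend(backend_options);
}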

tensorflow/compiler/xla/service/BUILD

@@ -342,7 +342,6 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla/legacy_flags:backend_flags",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
@@ -386,6 +385,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla:xla_proto",
+        "//tensorflow/compiler/xla/legacy_flags:backend_flags",
         "//tensorflow/compiler/xla/legacy_flags:service_flags",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/core:lib",

tensorflow/compiler/xla/service/backend.cc

@@ -22,7 +22,6 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const {
   return platform_;
 }
 
-BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) {
-  number_of_replicas_ = number_of_replicas;
-  return *this;
-}
-
-int BackendOptions::number_of_replicas() const { return number_of_replicas_; }
-
 BackendOptions& BackendOptions::set_intra_op_parallelism_threads(
     int num_threads) {
   intra_op_parallelism_threads_ = num_threads;
@@ -85,11 +77,6 @@ struct Backend::EigenThreadPoolWrapper {
 /* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
     const BackendOptions& options) {
-  int64 replica_count = options.number_of_replicas();
-  if (replica_count == -1) {
-    legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
-    replica_count = flags->xla_replicas;
-  }
   perftools::gputools::Platform* platform = options.platform();
   TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
   TF_ASSIGN_OR_RETURN(auto stream_executors,
@@ -98,9 +85,9 @@ struct Backend::EigenThreadPoolWrapper {
                       TransferManager::GetForPlatform(platform));
   TF_ASSIGN_OR_RETURN(auto computation_placer,
                       ComputationPlacer::GetForPlatform(platform));
-  std::unique_ptr<Backend> backend(new Backend(
-      replica_count, platform, compiler, stream_executors, transfer_manager,
-      computation_placer, options.intra_op_parallelism_threads()));
+  std::unique_ptr<Backend> backend(
+      new Backend(platform, compiler, stream_executors, transfer_manager,
+                  computation_placer, options.intra_op_parallelism_threads()));
   return std::move(backend);
 }
@@ -134,36 +121,25 @@ StatusOr<Backend::StreamPtr> Backend::BorrowStream(
 }
 
 Backend::Backend(
-    int64 replica_count, perftools::gputools::Platform* platform,
-    Compiler* compiler,
+    perftools::gputools::Platform* platform, Compiler* compiler,
     tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
     TransferManager* transfer_manager, ComputationPlacer* computation_placer,
     int intra_op_parallelism_threads)
     : platform_(platform),
       compiler_(compiler),
       transfer_manager_(transfer_manager),
-      computation_placer_(computation_placer),
-      replica_count_(replica_count) {
+      computation_placer_(computation_placer) {
   // The given set of stream executors set may include invalid executors.
   for (se::StreamExecutor* exec : stream_executors) {
     if (exec != nullptr) {
       stream_executors_.push_back(exec);
     }
   }
-  CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
-
   // Create a memory allocator for the valid stream executors.
   memory_allocator_ =
       MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
-
-  // First check that there are some non-null stream executors to avoid issuing
-  // an error mentioning replicas in the common case of requesting just 1
-  // replica, which means no replication.
   CHECK(!stream_executors_.empty())
       << "Service found no devices for backend " << platform_->Name() << '.';
-  CHECK_GE(stream_executors_.size(), replica_count)
-      << "Requested more replicas than there are devices for backend "
-      << platform_->Name() << '.';
 
   if (platform->id() == se::host::kHostPlatformId) {
     inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(

tensorflow/compiler/xla/service/backend.h

@@ -41,20 +41,12 @@ struct ThreadPoolDevice;
 namespace xla {
 
 // Options to configure the backend when it is created.
-//
-// TODO(b/62588571): Remove the notion of replicas from the backend.
 class BackendOptions {
  public:
   // Set the platform backing the backend, or nullptr for the default platform.
   BackendOptions& set_platform(perftools::gputools::Platform* platform);
   perftools::gputools::Platform* platform() const;
 
-  // Set the number of replicas to use when compiling replicated
-  // programs. The default is -1 meaning that the value is read from
-  // the xla_replicas flag.
-  BackendOptions& set_number_of_replicas(int number_of_replicas);
-  int number_of_replicas() const;
-
   // Sets the thread pool size for parallel execution of an individual operator.
   // The default value of -1 will result in initializing the thread pool with
   // the number of threads equal to the number of cores in the system.
@@ -63,7 +55,6 @@ class BackendOptions {
  private:
   perftools::gputools::Platform* platform_ = nullptr;
-  int number_of_replicas_ = -1;
   int intra_op_parallelism_threads_ = -1;
 };
@@ -77,8 +68,7 @@ class Backend {
  public:
   using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
 
-  // Creates a new backend for the given platform with the given number of
-  // replicas.
+  // Creates a new backend.
   static StatusOr<std::unique_ptr<Backend>> CreateBackend(
       const BackendOptions& options);
@@ -101,9 +91,6 @@ class Backend {
   // all of these devices may be usable by XLA.
   int device_count() const { return stream_executors_.size(); }
 
-  // Returns the number of replicas when replication is enabled.
-  int replica_count() const { return replica_count_; }
-
   // Returns the device ordinal number of the default device.
   int default_device_ordinal() const;
@@ -170,8 +157,7 @@ class Backend {
  private:
   struct EigenThreadPoolWrapper;
 
-  Backend(int64 replica_count, perftools::gputools::Platform* platform,
-          Compiler* compiler,
+  Backend(perftools::gputools::Platform* platform, Compiler* compiler,
           tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
               stream_executors,
           TransferManager* transfer_manager,
@@ -184,7 +170,6 @@ class Backend {
   Compiler* compiler_;
   TransferManager* transfer_manager_;
   ComputationPlacer* computation_placer_;
-  int64 replica_count_ = -1;
 
   // Vector of stream executors. stream_executors_[0] is the default executor.
   std::vector<perftools::gputools::StreamExecutor*> stream_executors_;

tensorflow/compiler/xla/service/compile_only_service.cc

@@ -52,14 +52,16 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
                       CreateComputeConstantBackend());
-  std::unique_ptr<CompileOnlyService> service(
-      new CompileOnlyService(compiler, std::move(compute_constant_backend)));
+  std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
+      options, compiler, std::move(compute_constant_backend)));
   return std::move(service);
 }
 
 CompileOnlyService::CompileOnlyService(
-    Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
-    : Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
+    const ServiceOptions& options, Compiler* compiler,
+    std::unique_ptr<Backend> compute_constant_backend)
+    : Service(options, /*backend=*/nullptr,
+              std::move(compute_constant_backend)),
       compiler_(compiler) {
   runs_in_client_process_ = true;
 }

tensorflow/compiler/xla/service/compile_only_service.h

@@ -103,7 +103,8 @@ class CompileOnlyService : public Service {
  private:
   explicit CompileOnlyService(
-      Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
+      const ServiceOptions& options, Compiler* compiler,
+      std::unique_ptr<Backend> compute_constant_backend);
   CompileOnlyService(const CompileOnlyService&) = delete;
   void operator=(const CompileOnlyService&) = delete;

tensorflow/compiler/xla/service/local_service.cc

@@ -63,7 +63,6 @@ namespace xla {
   BackendOptions backend_options;
   backend_options.set_platform(platform)
-      .set_number_of_replicas(options.number_of_replicas())
       .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
                       Backend::CreateBackend(backend_options));
@@ -71,13 +70,15 @@ namespace xla {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
                       CreateComputeConstantBackend());
   std::unique_ptr<LocalService> service(new LocalService(
-      std::move(backend), std::move(compute_constant_backend)));
+      options, std::move(backend), std::move(compute_constant_backend)));
   return std::move(service);
 }
 
-LocalService::LocalService(std::unique_ptr<Backend> execute_backend,
+LocalService::LocalService(const ServiceOptions& options,
+                           std::unique_ptr<Backend> execute_backend,
                            std::unique_ptr<Backend> compute_constant_backend)
-    : Service(std::move(execute_backend), std::move(compute_constant_backend)) {
+    : Service(options, std::move(execute_backend),
+              std::move(compute_constant_backend)) {
   runs_in_client_process_ = true;
 }
@@ -153,7 +154,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
   // Construct computation layout from the argument layouts.
   auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
   module_config->set_has_hybrid_result(has_hybrid_result);
-  module_config->set_replica_count(execute_backend_->replica_count());
+  module_config->set_replica_count(options_.number_of_replicas());
   module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
   if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
     module_config->set_intra_op_parallelism_threads(

tensorflow/compiler/xla/service/local_service.h

@@ -60,7 +60,8 @@ class LocalService : public Service {
       const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
 
  private:
-  explicit LocalService(std::unique_ptr<Backend> backend,
+  explicit LocalService(const ServiceOptions& options,
+                        std::unique_ptr<Backend> backend,
                         std::unique_ptr<Backend> compute_constant_backend);
   LocalService(const LocalService&) = delete;
   void operator=(const LocalService&) = delete;

tensorflow/compiler/xla/service/service.cc

@@ -21,6 +21,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
 #include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
@@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const {
   }
   BackendOptions backend_options;
   backend_options.set_platform(platform);
-  backend_options.set_number_of_replicas(options.number_of_replicas());
   TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
                       CreateComputeConstantBackend());
-  std::unique_ptr<Service> service(new Service(
-      std::move(execute_backend), std::move(compute_constant_backend)));
+  std::unique_ptr<Service> service(
+      new Service(options, std::move(execute_backend),
+                  std::move(compute_constant_backend)));
   return std::move(service);
 }
@@ -158,7 +160,6 @@ Service::CreateComputeConstantBackend() {
     if (platform->id() == se::host::kHostPlatformId) {
       BackendOptions backend_options;
       backend_options.set_platform(platform);
-      backend_options.set_number_of_replicas(1);
       return Backend::CreateBackend(backend_options);
     }
   }
@@ -171,11 +172,24 @@ Service::CreateComputeConstantBackend() {
 };
 }
 
-Service::Service(std::unique_ptr<Backend> execute_backend,
+Service::Service(const ServiceOptions& options,
+                 std::unique_ptr<Backend> execute_backend,
                  std::unique_ptr<Backend> compute_constant_backend)
-    : execute_backend_(std::move(execute_backend)),
+    : options_(options),
+      execute_backend_(std::move(execute_backend)),
       compute_constant_backend_(std::move(compute_constant_backend)) {
+  // TODO(b/32648682): this flag / options update dance will go away once we
+  // pass the replica count explicitly to the service.
+  if (options_.number_of_replicas() < 0) {
+    legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
+    options_.set_number_of_replicas(flags->xla_replicas);
+  }
+
   if (execute_backend_) {
+    if (execute_backend_->device_count() > 0) {
+      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
+          << "Requested more replicas than there are devices.";
+    }
     LOG(INFO) << Printf(
         "XLA service %p executing computations on platform %s. Devices:", this,
         execute_backend_->platform()->Name().c_str());
@@ -325,7 +339,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  module_config->set_replica_count(backend->replica_count());
+  module_config->set_replica_count(options_.number_of_replicas());
   module_config->set_seed(execution_options.seed());
   module_config->set_debug_options(execution_options.debug_options());
@@ -506,7 +520,7 @@ Service::ExecuteParallelAndRegisterResult(
   TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                       backend->computation_placer()->AssignDevices(
-                          backend->replica_count(), executables.size()));
+                          options_.number_of_replicas(), executables.size()));
 
   for (int64 i = 0; i < executables.size(); i++) {
     // Stream executors for the replicas of the current computation.
@@ -572,7 +586,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                       backend->computation_placer()->AssignDevices(
-                          backend->replica_count(), /*computation_count=*/1));
+                          options_.number_of_replicas(),
+                          /*computation_count=*/1));
 
   // Set up run options.
   std::vector<ServiceExecutableRunOptions> run_options;
@@ -589,14 +604,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   }
 
   perftools::gputools::DeviceMemoryBase result;
-  if (backend->replica_count() == 1) {
+  if (options_.number_of_replicas() == 1) {
     TF_ASSIGN_OR_RETURN(
         result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
                     &run_options[0], profile, arguments));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
-        repeated_arguments(backend->replica_count(), arguments);
+        repeated_arguments(options_.number_of_replicas(), arguments);
 
     TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                           run_options, repeated_arguments));
@@ -626,7 +641,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
   std::vector<string> computation_names;
   std::vector<DeviceHandle> device_handles;
 
-  if (arg->requests_size() * execute_backend_->replica_count() >
+  if (arg->requests_size() * options_.number_of_replicas() >
       execute_backend_->device_count()) {
     return FailedPrecondition(
         "there are not enough stream executors to execute %d computations",
@@ -723,7 +738,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
 tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                              GetDeviceHandlesResponse* result) {
   const int64 available_device_count = execute_backend_->device_count();
-  const int64 replica_count = execute_backend_->replica_count();
+  const int64 replica_count = options_.number_of_replicas();
   if (replica_count <= 0) {
     return FailedPrecondition("Replica count must be a positive integer");
   }
@@ -948,7 +963,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
   Literal literal = Literal(arg->literal());
   const Shape& shape = literal.shape();
 
-  if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
+  if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) {
     // TODO(b/32990684): Tuple transfers to host end up allocating further
     // buffers - implement that correctly.
     return Unimplemented(
@@ -988,7 +1003,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
 tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                              TransferToInfeedResponse* result) {
-  const int64 replica_count = execute_backend_->replica_count();
+  const int64 replica_count = options_.number_of_replicas();
   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
     return FailedPrecondition(
         "%s",
@@ -1017,7 +1032,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
 tensorflow::Status Service::TransferFromOutfeed(
     const TransferFromOutfeedRequest* arg,
     TransferFromOutfeedResponse* result) {
-  const int64 replica_count = execute_backend_->replica_count();
+  const int64 replica_count = options_.number_of_replicas();
   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
     return FailedPrecondition(
         "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
@@ -1428,13 +1443,13 @@ DeviceHandle Service::SingleComputationDeviceHandle() const {
 StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
     const Backend& backend, const DeviceHandle& device_handle) const {
   std::vector<perftools::gputools::StreamExecutor*> replicas;
-  for (int replica = 0; replica < backend.replica_count(); ++replica) {
+  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
     // From the computation placer, find out the device ids of the replicas for
     // the given device handle.
     TF_ASSIGN_OR_RETURN(
         int device_ordinal,
         backend.computation_placer()->DeviceId(replica, device_handle.handle(),
-                                               backend.replica_count(),
+                                               options_.number_of_replicas(),
                                                device_handle.device_count()));
     TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
     replicas.push_back(executor);

tensorflow/compiler/xla/service/service.h

@@ -248,7 +248,7 @@ class Service : public ServiceInterface {
   // The constructor is private. Use the NewService factory to create new
   // service objects.
-  Service(std::unique_ptr<Backend> backend,
+  Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
           std::unique_ptr<Backend> compute_constant_backend);
 
   static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
@@ -355,6 +355,8 @@ class Service : public ServiceInterface {
   // single computation that is not model-parallelized.
   DeviceHandle SingleComputationDeviceHandle() const;
 
+  ServiceOptions options_;
+
   // Tracks computations built via the API.
   ComputationTracker computation_tracker_;