Add ComputationPlacer to assign device ids for replicated model-parallel computations.
PiperOrigin-RevId: 159056198
commit 7d3497a639 (parent 0fa19543aa)
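At a glance, the new flow is: a ComputationPlacer produces a DeviceAssignment (a replica x computation grid of device ids), and that assignment is handed to ExecutableRunOptions for the launch. A minimal sketch of the wiring, illustrative only and not part of the commit (the executable_run_options.h include path is assumed):

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"

void ExampleAttachDeviceAssignment() {
  // Default placer: device id = computation * replica_count + replica.
  xla::ComputationPlacer placer;
  // 2 replicas of 3 model-parallel computations -> a 2x3 grid of device ids.
  xla::DeviceAssignment assignment =
      placer.AssignDevices(/*replica_count=*/2, /*computation_count=*/3)
          .ValueOrDie();

  xla::ExecutableRunOptions options;
  // The options store a raw pointer; the assignment must outlive the run.
  options.set_device_assignment(&assignment);
}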
@@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const {
  return execution_profile_;
}

ExecutableRunOptions& ExecutableRunOptions::set_device_assignment(
    DeviceAssignment* device_assignment) {
  device_assignment_ = device_assignment;
  return *this;
}

DeviceAssignment* ExecutableRunOptions::device_assignment() const {
  return device_assignment_;
}

}  // namespace xla

@@ -40,6 +40,7 @@ struct ThreadPoolDevice;
namespace xla {

class DeviceMemoryAllocator;
class DeviceAssignment;
class ExecutionProfile;

// Class containing options for running a LocalExecutable.
@@ -79,9 +80,14 @@ class ExecutableRunOptions {
  ExecutionProfile* execution_profile() const;
  ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);

  ExecutableRunOptions& set_device_assignment(
      DeviceAssignment* device_assignment);
  DeviceAssignment* device_assignment() const;

 private:
  DeviceMemoryAllocator* allocator_ = nullptr;
  int device_ordinal_ = -1;
  DeviceAssignment* device_assignment_ = nullptr;
  perftools::gputools::Stream* stream_ = nullptr;
  tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
  const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;

@@ -332,6 +332,7 @@ cc_library(
    hdrs = ["backend.h"],
    deps = [
        ":compiler",
        ":computation_placer",
        ":device_memory_allocator",
        ":platform_util",
        ":pool",
@@ -950,6 +951,26 @@ cc_test(
    ],
)

cc_library(
    name = "computation_placer",
    srcs = ["computation_placer.cc"],
    hdrs = ["computation_placer.h"],
    deps = [
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
    alwayslink = True,  # Contains per-platform computation placer registration
)

cc_library(
    name = "generic_transfer_manager",
    srcs = ["generic_transfer_manager.cc"],

@@ -96,9 +96,11 @@ struct Backend::EigenThreadPoolWrapper {
      PlatformUtil::GetStreamExecutors(platform));
  TF_ASSIGN_OR_RETURN(auto transfer_manager,
                      TransferManager::GetForPlatform(platform));
  std::unique_ptr<Backend> backend(
      new Backend(replica_count, platform, compiler, stream_executors,
                  transfer_manager, options.intra_op_parallelism_threads()));
  TF_ASSIGN_OR_RETURN(auto computation_placer,
                      ComputationPlacer::GetForPlatform(platform));
  std::unique_ptr<Backend> backend(new Backend(
      replica_count, platform, compiler, stream_executors, transfer_manager,
      computation_placer, options.intra_op_parallelism_threads()));
  return std::move(backend);
}

@@ -135,10 +137,12 @@ Backend::Backend(
    int64 replica_count, perftools::gputools::Platform* platform,
    Compiler* compiler,
    tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
    TransferManager* transfer_manager, int intra_op_parallelism_threads)
    TransferManager* transfer_manager, ComputationPlacer* computation_placer,
    int intra_op_parallelism_threads)
    : platform_(platform),
      compiler_(compiler),
      transfer_manager_(transfer_manager),
      computation_placer_(computation_placer),
      replica_count_(replica_count) {
  // The given set of stream executors may include invalid executors.
  for (se::StreamExecutor* exec : stream_executors) {
@@ -179,36 +183,6 @@ int Backend::default_device_ordinal() const {
  return default_stream_executor()->device_ordinal();
}

StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
    int device_ordinal) const {
  if (stream_executors_[device_ordinal] == nullptr) {
    return InvalidArgument("device %s not supported by XLA service",
                           device_name(device_ordinal).c_str());
  }

  // Find replica_count_ stream executors starting from the given device
  // ordinal.
  std::vector<perftools::gputools::StreamExecutor*> replicas;
  for (se::StreamExecutor* exec : stream_executors_) {
    CHECK(exec != nullptr);
    if (exec->device_ordinal() >= device_ordinal) {
      replicas.push_back(exec);
      if (replicas.size() >= replica_count_) {
        return replicas;
      }
    }
  }

  return InvalidArgument(
      "Not enough devices for replicas for the device ordinal %d",
      device_ordinal);
}

std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
  CHECK_GE(stream_executors_.size(), replica_count_);
  return Replicas(default_device_ordinal()).ValueOrDie();
}

tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
  return inter_op_thread_pool_.get();
}

@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -40,6 +41,8 @@ struct ThreadPoolDevice;
namespace xla {

// Options to configure the backend when it is created.
//
// TODO(b/62588571): Remove the notion of replicas from the backend.
class BackendOptions {
 public:
  // Set the platform backing the backend, or nullptr for the default platform.
@@ -92,11 +95,15 @@ class Backend {
    return memory_allocator_.get();
  }
  TransferManager* transfer_manager() const { return transfer_manager_; }
  ComputationPlacer* computation_placer() const { return computation_placer_; }

  // Returns the number of devices of the platform type which are visible. Not
  // all of these devices may be usable by XLA.
  int device_count() const { return stream_executors_.size(); }

  // Returns the number of replicas when replication is enabled.
  int replica_count() const { return replica_count_; }

  // Returns the device ordinal number of the default device.
  int default_device_ordinal() const;

@@ -107,24 +114,13 @@ class Backend {
    return stream_executors_;
  }

  // Returns the replicas for the default stream executor.
  //
  // When the number of replicas is R, the first R stream executors are assigned
  // to the replicas of the default stream executor.
  std::vector<perftools::gputools::StreamExecutor*> Replicas() const;

  // Returns the replicas for the given device_ordinal. The given device ordinal
  // is considered to be the first device ordinal among the replicas. Returns an
  // error status if the stream executor for the given device ordinal does
  // not exist or if there are not enough stream executors for the replicas.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      int device_ordinal) const;

  // Return the stream executor for the given device ordinal.
  // Returns the stream executor for the given device ordinal.
  StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
      int device_ordinal) const;

  // Return the stream executor for the default device ordinal.
  // Returns the stream executor for the default device ordinal. This stream
  // executor can only be used when the number of computations is 1 (replication
  // can be > 1).
  perftools::gputools::StreamExecutor* default_stream_executor() const {
    CHECK(!stream_executors_.empty());
    return stream_executors_[0];
@@ -178,13 +174,16 @@ class Backend {
          Compiler* compiler,
          tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
              stream_executors,
          TransferManager* transfer_manager, int intra_op_parallelism_threads);
          TransferManager* transfer_manager,
          ComputationPlacer* computation_placer,
          int intra_op_parallelism_threads);
  Backend(const Backend&) = delete;
  Backend& operator=(const Backend&) = delete;

  perftools::gputools::Platform* platform_;
  Compiler* compiler_;
  TransferManager* transfer_manager_;
  ComputationPlacer* computation_placer_;
  int64 replica_count_ = -1;

  // Vector of stream executors. stream_executors_[0] is the default executor.
tensorflow/compiler/xla/service/computation_placer.cc (new file, 151 lines)
@@ -0,0 +1,151 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {

Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return Status::OK();
}

/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
    const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      assignment(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      *ComputationPlacer::platform_computation_placer_mutex());
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      *ComputationPlacer::platform_computation_placer_mutex());
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- check "
        "target linkage",
        platform->Name().c_str());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

/* static */ tensorflow::mutex*
ComputationPlacer::platform_computation_placer_mutex() {
  static tensorflow::mutex* m = new tensorflow::mutex;
  return m;
}

/* static */ std::map<perftools::gputools::Platform::Id,
                      ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r =
      new std::map<perftools::gputools::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return xla::MakeUnique<xla::ComputationPlacer>();
}

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId,
                                                    &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId,
                                                    &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();
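For reference, the default mapping above (device id = computation * replica_count + replica) means AssignDevices(2, 3) fills the grid as follows, indexed as assignment(replica, computation):

                 computation 0   computation 1   computation 2
    replica 0          0               2               4
    replica 1          1               3               5

so the replicas of a given computation occupy consecutive device ordinals.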
tensorflow/compiler/xla/service/computation_placer.h (new file, 113 lines)
@@ -0,0 +1,113 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_

#include <map>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Class that represents the device assignment for a set of XLA replicated
// computations. For R replicas and C computations, R * C devices are required
// to execute the computation in parallel. The assigned device ids can be
// accessed by assignment(replica, computation).
class DeviceAssignment : public Array2D<int> {
 public:
  DeviceAssignment() {}
  DeviceAssignment(int replica_count, int computation_count)
      : Array2D<int>(replica_count, computation_count, -1) {
    CHECK_GT(replica_count, 0);
    CHECK_GT(computation_count, 0);
  }

  int replica_count() const { return height(); }
  int computation_count() const { return width(); }

  // Protocol buffer serialization and deserialization.
  Status Serialize(DeviceAssignmentProto* proto) const;
  static StatusOr<DeviceAssignment> Deserialize(
      const DeviceAssignmentProto& proto);
};

// A generic implementation of the XLA computation placer, which assigns device
// ids to a set of replicated computations.
class ComputationPlacer {
 public:
  ComputationPlacer() {}
  virtual ~ComputationPlacer() {}

  // Returns the device id assigned to the given replica and computation
  // instance for a [replica_count x computation_count] setup. The returned
  // device id must match the assignment from PlaceReplicatedComputation().
  virtual StatusOr<int> DeviceId(int replica, int computation,
                                 int replica_count, int computation_count);

  // Returns the device ids assigned to a set of replicated computations, given
  // the number of replicas and the number of computations.
  virtual StatusOr<DeviceAssignment> AssignDevices(int replica_count,
                                                   int computation_count);

  using ComputationPlacerCreationFunction =
      std::unique_ptr<ComputationPlacer> (*)();

  // Registers a computation placer creation function for a particular platform.
  static void RegisterComputationPlacer(
      perftools::gputools::Platform::Id platform_id,
      ComputationPlacerCreationFunction creation_function);

  // Returns the computation placer singleton pointer if it is available for the
  // given platform, or an error status if it is not.
  static StatusOr<ComputationPlacer*> GetForPlatform(
      const perftools::gputools::Platform* platform);

 private:
  // Routine that returns the mutex that guards the platform-to-computation
  // placer map. Done as a routine to ensure correct initialization ordering,
  // since RegisterComputationPlacer can be called during program initialization
  // time.
  static tensorflow::mutex* platform_computation_placer_mutex();

  // State kept for each kind of ComputationPlacer. Registration functions set
  // up creation_function, and then we use that to lazily create "placer" the
  // first time GetForPlatform is invoked for a particular id.
  struct State {
    std::unique_ptr<ComputationPlacer> placer;
    ComputationPlacerCreationFunction creation_function = nullptr;
  };

  // Map from platform kind to computation placer singleton.
  static std::map<perftools::gputools::Platform::Id, State>*
  GetPlatformComputationPlacers();

  perftools::gputools::Platform::Id platform_id_;

  TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_
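A minimal round-trip sketch of the serialization API declared above; illustrative only, not part of this commit, and the chosen device ids are arbitrary:

#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/platform/logging.h"

void ExampleDeviceAssignmentRoundTrip() {
  xla::DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/2);
  assignment(0, 0) = 0;  // replica 0 of computation 0
  assignment(1, 0) = 1;  // replica 1 of computation 0
  assignment(0, 1) = 2;  // replica 0 of computation 1
  assignment(1, 1) = 3;  // replica 1 of computation 1

  xla::DeviceAssignmentProto proto;
  TF_CHECK_OK(assignment.Serialize(&proto));

  // Deserialize returns StatusOr<DeviceAssignment>; the round trip preserves
  // the replica/computation counts and every device id.
  xla::DeviceAssignment restored =
      xla::DeviceAssignment::Deserialize(proto).ValueOrDie();
  CHECK_EQ(restored(1, 1), 3);
}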
@@ -152,7 +152,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
  // Construct computation layout from the argument layouts.
  auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
  module_config->set_has_hybrid_result(has_hybrid_result);
  module_config->set_replica_count(execute_backend_->Replicas().size());
  module_config->set_replica_count(execute_backend_->replica_count());
  legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
  if (flags->xla_hlo_profile) {
    module_config->enable_hlo_profiling(true);

@@ -325,7 +325,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    module_config->enable_hlo_profiling(true);
  }

  module_config->set_replica_count(backend->Replicas().size());
  module_config->set_replica_count(backend->replica_count());
  module_config->set_seed(execution_options.seed());
  module_config->set_debug_options(execution_options.debug_options());

@@ -495,47 +495,55 @@ Service::ExecuteParallelAndRegisterResult(
    tensorflow::gtl::ArraySlice<
        std::vector<perftools::gputools::DeviceMemoryBase>>
        arguments,
    Backend* backend,
    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
    Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
    tensorflow::gtl::ArraySlice<string> result_tags) {
  // TODO(b/33943292): Support for replication when using multiple computations.
  TF_RET_CHECK(backend->Replicas().size() == 1);

  // Set up streams.
  // Streams where the computations are launched, so we can wait on the streams
  // to complete.
  std::vector<Pool<se::Stream>::SmartPtr> streams;

  for (se::StreamExecutor* executor : executors) {
    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const Pool<se::Stream>::SmartPtr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_allocator(backend->memory_allocator());
    options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    run_options.emplace_back(options, backend->StreamBorrower());
  }

  // Asynchronously launch all executables.
  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;
  for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0;
       i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(
        perftools::gputools::DeviceMemoryBase result,
        executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
    result_handles.push_back(allocation_tracker_.Register(
        backend, executors[i]->device_ordinal(), result,
        executables[i]->result_shape(), result_tags[i]));

  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      backend->computation_placer()->AssignDevices(
                          backend->replica_count(), executables.size()));

  for (int64 i = 0; i < executables.size(); i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    for (int64 replica = 0; replica < replicas.size(); ++replica) {
      TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(
          perftools::gputools::DeviceMemoryBase result,
          executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i]));

      // All replicas share the same device address for the result allocation,
      // so only one of the replicas needs to register the result handle.
      if (replica == 0) {
        result_handles.push_back(allocation_tracker_.Register(
            backend, replicas[0]->device_ordinal(), result,
            executables[i]->result_shape(), result_tags[i]));
      }
    }
  }

  // Wait for all executions to complete.
  for (int64 i = 0; i < result_handles.size(); ++i) {
  for (int64 i = 0; i < streams.size(); ++i) {
    if (!streams[i]->BlockHostUntilDone()) {
      return InternalError("failed to complete execution for stream %lld", i);
    }

@@ -550,17 +558,22 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
        arguments,
    Backend* backend, perftools::gputools::StreamExecutor* executor,
    const string& result_tag, ExecutionProfile* profile) {
  TF_RET_CHECK(!backend->Replicas().empty());

  // Set up streams.
  std::vector<Pool<se::Stream>::SmartPtr> streams;

  for (se::StreamExecutor* executor : backend->Replicas()) {
  TF_ASSIGN_OR_RETURN(auto replicas,
                      Replicas(*backend, SingleComputationDeviceHandle()));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      backend->computation_placer()->AssignDevices(
                          backend->replica_count(), /*computation_count=*/1));

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  for (const Pool<se::Stream>::SmartPtr& stream : streams) {

@@ -570,19 +583,20 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    run_options.emplace_back(options, backend->StreamBorrower(),
                             backend->inter_op_thread_pool());
  }

  perftools::gputools::DeviceMemoryBase result;
  if (backend->Replicas().size() == 1) {
  if (backend->replica_count() == 1) {
    TF_ASSIGN_OR_RETURN(
        result, executable->ExecuteOnStreamWrapper<se::DeviceMemoryBase>(
                    &run_options[0], profile, arguments));
  } else {
    std::vector<
        tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
        repeated_arguments(backend->Replicas().size(), arguments);
        repeated_arguments(backend->replica_count(), arguments);

    TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                          run_options, repeated_arguments));

@@ -610,25 +624,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
  std::vector<VersionedComputationHandle> versioned_handles;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<string> computation_names;
  std::vector<DeviceHandle> device_handles;

  if (arg->requests_size() > execute_backend_->stream_executors().size()) {
  if (arg->requests_size() * execute_backend_->replica_count() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        arg->requests_size());
  }

  for (int64 i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor on which the computation will run. Select the
    // specific device if requested, otherwise select the i'th device from the
    // list of available stream executors.
    se::StreamExecutor* executor;
    if (arg->requests(i).has_device_handle()) {
      executor =
          execute_backend_
              ->stream_executors()[arg->requests(i).device_handle().handle()];
    } else {
      executor = execute_backend_->stream_executors()[i];
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    if (!arg->requests(i).has_device_handle()) {
      return FailedPrecondition(
          "device handles must be given to execute parallel computations");
    }
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, arg->requests(i).device_handle()));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);

    // Resolve the UserComputation object associated with the requested

@@ -673,6 +688,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
    module_configs.push_back(std::move(module_config));
    computation_names.push_back(user_computation->name());
    executors.push_back(executor);
    device_handles.push_back(arg->requests(i).device_handle());
  }

  // Build the user computations into HloModules and compile to generate the

@@ -692,7 +708,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> outputs,
      ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
                                       execute_backend_.get(), executors,
                                       execute_backend_.get(), device_handles,
                                       computation_names));
  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;

@@ -706,10 +722,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,

tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                             GetDeviceHandlesResponse* result) {
  const int64 available_device_count =
      execute_backend_->stream_executors().size();
  const int64 replicas = execute_backend_->Replicas().size();
  if (available_device_count < arg->device_count() * replicas) {
  const int64 available_device_count = execute_backend_->device_count();
  const int64 replica_count = execute_backend_->replica_count();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested device count (%lld) exceeds the number of available devices "
        "on the target (%lld)",

@@ -718,8 +736,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,

  for (int64 i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(
        execute_backend_->stream_executors()[i * replicas]->device_ordinal());
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

@@ -841,11 +859,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
                          execute_backend_->default_stream_executor(),
                          &profile));

  TF_RET_CHECK(!execute_backend_->Replicas().empty());
  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_RET_CHECK(!replicas.empty());

  // Set up streams.
  std::vector<Pool<se::Stream>::SmartPtr> streams;

  for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                        execute_backend_->BorrowStream(executor));
    streams.push_back(std::move(stream));

@@ -927,19 +948,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
  Literal literal = Literal(arg->literal());
  const Shape& shape = literal.shape();

  if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
  if (ShapeUtil::IsTuple(shape) && execute_backend_->replica_count() > 1) {
    // TODO(b/32990684): Tuple transfers to host end up allocating further
    // buffers - implement that correctly.
    return Unimplemented(
        "Tuple transfers to the device not supported with replication.");
  }

  se::StreamExecutor* stream_executor;
  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
                                             arg->device_handle().handle()));
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    stream_executor = execute_backend_->default_stream_executor();
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory on the device, using the stream executor. The size of the

@@ -950,14 +972,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,

  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
                      execute_backend_->memory_allocator()->Allocate(
                          stream_executor->device_ordinal(), allocation_size));
                          replicas[0]->device_ordinal(), allocation_size));

  *result->mutable_data() = allocation_tracker_.Register(
      execute_backend_.get(), stream_executor->device_ordinal(), allocation,
      shape, StrCat("TransferToServer literal of size ", allocation_size));
      execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape,
      StrCat("TransferToServer literal of size ", allocation_size));

  TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
                                         stream_executor->device_ordinal()));
  for (se::StreamExecutor* executor : replicas) {
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(

@@ -968,7 +988,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,

tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                             TransferToInfeedResponse* result) {
  const int64 replica_count = execute_backend_->Replicas().size();
  const int64 replica_count = execute_backend_->replica_count();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",

@@ -980,11 +1000,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
                                           arg->device_handle().handle()));
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    executor = execute_backend_->Replicas()[arg->replica_id()];
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(

@@ -994,7 +1017,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::TransferFromOutfeed(
    const TransferFromOutfeedRequest* arg,
    TransferFromOutfeedResponse* result) {
  const int64 replica_count = execute_backend_->Replicas().size();
  const int64 replica_count = execute_backend_->replica_count();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "

@@ -1004,11 +1027,14 @@ tensorflow::Status Service::TransferFromOutfeed(

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
                                           arg->device_handle().handle()));
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    executor = execute_backend_->Replicas()[arg->replica_id()];
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  Literal literal;

@@ -1387,4 +1413,28 @@ tensorflow::Status Service::LoadComputationSnapshot(
  return tensorflow::Status::OK();
}

DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<perftools::gputools::StreamExecutor*> replicas;
  for (int replica = 0; replica < backend.replica_count(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               backend.replica_count(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla
@@ -319,8 +319,7 @@ class Service : public ServiceInterface {
          std::vector<perftools::gputools::DeviceMemoryBase>>
          arguments,
      Backend* backend,
      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
          executors,
      tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
      tensorflow::gtl::ArraySlice<string> result_tags);

  // Returns an HLO dumper for use in the compiler (it refers to flags

@@ -346,6 +345,16 @@ class Service : public ServiceInterface {
  tensorflow::Status ValidateResultShapeWithLayout(
      const Shape& shape_with_layout, const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device that
  // represents a set of physical devices for the replicas.
  StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;

  // Tracks computations built via the API.
  ComputationTracker computation_tracker_;
@@ -99,6 +99,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_execution_profile",
@@ -196,6 +197,7 @@ cc_library(
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:local_service",
        "//tensorflow/compiler/xla/service:platform_util",
@@ -746,6 +748,7 @@ xla_test(
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:local_service",
        "//tensorflow/compiler/xla/service:platform_util",
@@ -255,11 +255,15 @@ message ComputationDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a device to execute a computation.
// When replication is enabled, the device handle represents the device for the
// replica id 0.
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
@@ -269,6 +273,21 @@ message ChannelHandle {
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.