- Fix PjRt GPU Client to allow num_partitions > 1
- Fix XLA GPU backend to allow num_partitions > 1

PiperOrigin-RevId: 354998179
Change-Id: Ief04993252d80e3edf04a91f4523edd25d8b3102
commit 761dc221ce
parent c455215395
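
What this enables, as a rough sketch: a caller of the PJRT GPU client can now request a partitioned executable instead of tripping the old TF_RET_CHECK. The helper below is illustrative only (the function name and the choice of 2 partitions are assumptions, not part of this commit); it uses the existing xla::CompileOptions / ExecutableBuildOptions API.

// Illustrative sketch, not from this commit: compile a computation with
// num_partitions > 1 through a PJRT GPU client.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::unique_ptr<xla::PjRtExecutable>> CompilePartitioned(
    xla::PjRtClient* client, const xla::XlaComputation& computation) {
  xla::CompileOptions options;
  options.executable_build_options.set_num_replicas(1);
  options.executable_build_options.set_num_partitions(2);  // > 1 now allowed
  // GetDefaultDeviceAssignment no longer CHECK-fails on num_partitions != 1;
  // for the multi-partition case it falls through to the base-class default.
  return client->Compile(computation, options);
}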
Files changed under tensorflow/compiler/xla:
@@ -46,9 +46,7 @@ class GpuClient : public xla::PjRtStreamExecutorClient {
 xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
-  // XLA:GPU does not support multiple partitions yet.
-  TF_RET_CHECK(num_partitions == 1) << num_partitions;
-  if (num_replicas <= addressable_devices().size()) {
+  if (num_partitions == 1 && num_replicas <= addressable_devices().size()) {
     xla::DeviceAssignment assignment(num_replicas, 1);
     for (int i = 0; i < num_replicas; ++i) {
       assignment(i, 0) = addressable_devices().at(i)->id();
     }
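
With the early TF_RET_CHECK gone, the inline assignment is built only for the single-partition case; anything else now falls through to the base-class default, which delegates to the ComputationPlacer. A sketch of the shape of the result, assuming four local GPUs (device ids are assumed, not from this commit):

#include "tensorflow/compiler/xla/service/computation_placer.h"

// Sketch only: the (num_replicas x num_partitions) grid a 2x2 default
// assignment would hold on four local GPUs.
xla::DeviceAssignment MakeExampleAssignment() {
  xla::DeviceAssignment assignment(/*replica_count=*/2,
                                   /*computation_count=*/2);
  assignment(0, 0) = 0;  // replica 0, partition 0
  assignment(0, 1) = 1;  // replica 0, partition 1
  assignment(1, 0) = 2;  // replica 1, partition 0
  assignment(1, 1) = 3;  // replica 1, partition 1
  return assignment;
}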
@@ -104,6 +104,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+#include "tensorflow/compiler/xla/service/llvm_compiler.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/map_inliner.h"
@@ -186,6 +187,19 @@ CpuCompiler::CpuCompiler() {
   (void)llvm_initialized;
 }
 
+StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
+    std::unique_ptr<HloModuleGroup> module_group,
+    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
+    const CompileOptions& options) {
+  for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
+    if (se_vector.size() != 1) {
+      return Unimplemented(
+          "Model partitioning not implemented for the CPU compiler");
+    }
+  }
+  return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
+}
+
 /* static */ void CpuCompiler::InitializeLLVMTarget() {
   // Initialize LLVM's MC layer for the native target.
   llvm::InitializeNativeTarget();
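
The effect, sketched below with assumed names (exec0, exec1, and the call site are illustrative, not from this commit): the CPU compiler rejects a partitioned module group up front, before any HLO passes run, while single-executor groups delegate to LLVMCompiler::Compile unchanged.

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

namespace se = stream_executor;

// Hypothetical call site: two executors for one module means a partitioned
// group, which the CPU override now rejects itself.
xla::StatusOr<std::vector<std::unique_ptr<xla::Executable>>> TryPartitioned(
    xla::cpu::CpuCompiler* cpu_compiler,
    std::unique_ptr<xla::HloModuleGroup> module_group,
    se::StreamExecutor* exec0, se::StreamExecutor* exec1) {
  std::vector<std::vector<se::StreamExecutor*>> stream_execs = {{exec0, exec1}};
  // Returns Unimplemented("Model partitioning not implemented for the CPU
  // compiler") without touching the modules.
  return cpu_compiler->Compile(std::move(module_group), stream_execs,
                               xla::Compiler::CompileOptions());
}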
@@ -125,12 +125,10 @@ class CpuCompiler : public LLVMCompiler {
   CpuCompiler();
   ~CpuCompiler() override {}
 
-  // Bring in
-  // StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-  //     std::vector<std::unique_ptr<HloModule>> modules,
-  //     std::vector<std::vector<se::StreamExecutor*>>
-  //     stream_execs)
-  using LLVMCompiler::Compile;
+  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_execs,
+      const CompileOptions& options) override;
 
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
+
 #include "tensorflow/core/platform/denormal.h"
 
 #ifdef __FAST_MATH__
@@ -41,11 +42,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
   std::vector<std::unique_ptr<HloModule>> modules =
       module_group->ConsumeModules();
   for (size_t i = 0; i < modules.size(); i++) {
-    if (stream_execs[i].size() != 1) {
-      return Unimplemented(
-          "Model partitioning not implemented for the CPU/GPU compilers!");
-    }
-
     TF_ASSIGN_OR_RETURN(modules[i],
                         RunHloPasses(std::move(modules[i]), stream_execs[i][0],
                                      options.device_allocator));
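
Why the check moves rather than simply vanishing: LLVMCompiler::Compile is shared by the CPU and GPU backends, so keeping the blanket rejection here would keep blocking partitioned GPU compiles. The base class now unconditionally uses stream_execs[i][0] per module, so each backend that cannot partition owns the guard, as CpuCompiler::Compile above does. A minimal sketch of that invariant (wording assumed, not code from this commit):

// Illustrative guard only: the invariant the base class now relies on,
// since it always takes the first executor of each per-module list.
for (const std::vector<se::StreamExecutor*>& per_module : stream_execs) {
  CHECK_EQ(per_module.size(), 1)
      << "backends that cannot partition must reject multi-executor groups";
}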