From fe52c06bd6e0a380487fc8e2dec3831377403384 Mon Sep 17 00:00:00 2001
From: Mark Heffernan <meheff@google.com>
Date: Fri, 12 Oct 2018 13:39:20 -0700
Subject: [PATCH] Rollforward with build fix to MLIR TPU compiler.

Also renamed some methods to avoid a "hides overloaded virtual function"
compilation error which only appears in the "Builder" analysis in Critique.
See b/11765370.

*** Original change description ***

Automated rollback of commit 51f0eb5849be0f9ce20e5eb8370158088711f19d

PiperOrigin-RevId: 216914046
---
 tensorflow/compiler/xla/service/BUILD         |  1 +
 .../xla/service/compile_only_service.cc       |  6 +++--
 tensorflow/compiler/xla/service/compiler.cc   |  4 +--
 tensorflow/compiler/xla/service/compiler.h    | 27 ++++++++++++++-----
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  7 +++--
 .../compiler/xla/service/cpu/cpu_compiler.h   |  2 +-
 .../xla/service/gpu/nvptx_compiler.cc         |  5 ++--
 .../compiler/xla/service/gpu/nvptx_compiler.h |  2 +-
 .../compiler/xla/service/hlo_module_group.cc  |  5 ++--
 .../compiler/xla/service/hlo_module_group.h   | 13 ++++++++-
 .../xla/service/hlo_module_group_test.cc      |  2 +-
 .../xla/service/interpreter/compiler.cc       | 23 +++++++++++++---
 .../xla/service/interpreter/compiler.h        | 11 ++++++--
 .../compiler/xla/service/llvm_compiler.cc     | 20 +++++++++++++-
 .../compiler/xla/service/llvm_compiler.h      | 11 +++++++-
 tensorflow/compiler/xla/service/service.cc    |  9 ++++---
 .../compiler/xla/tests/codegen_test_base.cc   |  5 ++--
 .../compiler/xla/tests/llvm_compiler_test.cc  | 13 ++++-----
 18 files changed, 122 insertions(+), 44 deletions(-)

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 6c3b9764b70..7d03eba800f 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -860,6 +860,7 @@ cc_library(
         ":executable",
         ":hlo",
         ":hlo_module_config",
+        ":hlo_module_group",
         ":logical_buffer",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index bd5045b9b91..c9b0e4c08c3 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -103,8 +103,10 @@ CompileOnlyService::CompileAheadOfTime(
     hlo_modules.push_back(std::move(hlo_module));
   }
 
-  return compiler_->CompileAheadOfTime(std::move(hlo_modules), options,
-                                       metadata);
+  return compiler_->CompileAheadOfTime(
+      absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
+                                        absl::MakeSpan(hlo_modules)),
+      options, metadata);
 }
 
 }  // namespace xla
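The caller-side migration above is mechanical: build an HloModuleGroup from the
module vector, then hand the group to the compiler. A minimal sketch of the
before/after call shape (variable names are illustrative, not from this change):

    // Before: the compiler took the bare vector of modules.
    // compiler->CompileAheadOfTime(std::move(hlo_modules), options, metadata);

    // After: wrap the vector in a group named after the first module.
    auto group = absl::make_unique<HloModuleGroup>(
        hlo_modules[0]->name(), absl::MakeSpan(hlo_modules));
    compiler->CompileAheadOfTime(std::move(group), options, metadata);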
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 687ecafe0c3..80c630c6201 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo,
 // Define a default version where metadata is not used.
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 Compiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     const AotCompilationOptions& options,
     std::unique_ptr<AotCompilationMetadata>* metadata) {
   if (metadata != nullptr) {
@@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime(
         "Populating AotCompilationMetadata is not implemented on this "
         "compiler.");
   }
-  return CompileAheadOfTime(std::move(modules), options);
+  return CompileAheadOfTime(std::move(module_group), options);
 }
 
 /* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 1fdda31c34a..9ab179303b3 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -135,6 +136,12 @@ class Compiler {
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
       DeviceMemoryAllocator* device_allocator) = 0;
 
+  // Optimizes an HLO module group, a set of modules that run concurrently on
+  // multiple devices, potentially communicating data between the modules.
+  virtual Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
+
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
   // applied to module. Generally a module should be passed through RunHloPasses
@@ -145,12 +152,18 @@ class Compiler {
   // (not just type of device) indicated by the executor.
   //
   // device_allocator is optional; see RunHloPasses.
-  //
-  // Use the overload below to compile computations that run in parallel.
   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
       DeviceMemoryAllocator* device_allocator) = 0;
 
+  // Compiles a set of HLO modules that can run in parallel, potentially
+  // communicating data between the modules.
+  virtual StatusOr<std::vector<std::unique_ptr<Executable>>>
+  RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) = 0;
+
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
@@ -160,7 +173,7 @@ class Compiler {
   // TODO(b/68666782): Remove this method after adding support for multiple
   // modules to RunHloPasses and RunBackends.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
       DeviceMemoryAllocator* device_allocator) = 0;
 
@@ -184,16 +197,16 @@ class Compiler {
   ComputeDefaultBackendConfig(const HloInstruction& hlo,
                               se::StreamExecutor* executor) const;
 
-  // Compiles the HLO module for ahead-of-time execution. This is intended for
-  // use in static compilation.
+  // Compiles the HLO module group for ahead-of-time execution. This is
+  // intended for use in static compilation.
   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options) = 0;
 
   // Similar to CompileAheadOfTime above but AotCompilationMetadata
   // has an argument that can be populated during compilation.
   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options,
                      std::unique_ptr<AotCompilationMetadata>* metadata);
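The two new pure-virtual methods are meant to compose the same way the
single-module RunHloPasses/RunBackend pair does. A sketch of a driver, assuming
a backend that actually implements the group methods (CompileGroup and its
error handling are hypothetical, not part of this change):

    StatusOr<std::vector<std::unique_ptr<Executable>>> CompileGroup(
        Compiler* compiler, std::unique_ptr<HloModuleGroup> group,
        std::vector<std::vector<se::StreamExecutor*>> executors,
        DeviceMemoryAllocator* allocator) {
      // Optimize all modules in the group together, so cross-module
      // communication can be taken into account...
      TF_RETURN_IF_ERROR(compiler->RunHloPassesOnModuleGroup(
          group.get(), executors[0][0], allocator));
      // ...then lower the group to one executable per module.
      return compiler->RunBackendOnModuleGroup(
          std::move(group), std::move(executors), allocator);
    }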
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 68c715a086a..da01c0caf2a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -676,9 +676,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
 }
 
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                 const AotCompilationOptions& aot_options) {
-  TF_RET_CHECK(!modules.empty());
+  TF_RET_CHECK(!module_group->empty());
+  std::vector<std::unique_ptr<HloModule>> modules =
+      module_group->ConsumeModules();
+
   std::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions,
                  modules[0]->config());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index f2af923782d..c67307548dd 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler {
       DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options) override;
 
   se::Platform::Id PlatformId() const override;
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 829d1499bc8..791d414c915 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -825,9 +825,8 @@ std::vector<uint8> NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx,
 }
 
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-NVPTXCompiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> module,
-    const AotCompilationOptions& options) {
+NVPTXCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
+                                  const AotCompilationOptions& options) {
   return Unimplemented(
       "not yet implemented: NVPTXCompiler::CompileAheadOfTime");
 }
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index c4a0b727cd3..f79ae2990ae 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -59,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler {
       DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      AotCompilationOptions const& options) override;
 
   se::Platform::Id PlatformId() const override;
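CpuCompiler's pattern above, emptying the group with ConsumeModules() before
compiling each module individually, is the simplest way for a backend with no
cross-module support to adopt the new signature. In sketch form:

    // Take ownership of all modules; the group is empty afterwards.
    std::vector<std::unique_ptr<HloModule>> modules =
        module_group->ConsumeModules();
    for (const auto& module : modules) {
      // Compile each module on its own, as before the API change.
    }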
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
index f9b56ef4643..8999ac9f324 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -17,9 +17,8 @@ limitations under the License.
 
 namespace xla {
 
-HloModuleGroup::HloModuleGroup(absl::string_view name,
-                               std::unique_ptr<HloModule> module)
-    : name_(name) {
+HloModuleGroup::HloModuleGroup(std::unique_ptr<HloModule> module)
+    : name_(module->name()) {
   push_back(std::move(module));
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
index 7338be8b9c5..7c39cf17815 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -35,7 +35,7 @@ class HloModuleGroup {
   explicit HloModuleGroup(absl::string_view name) : name_(name) {}
 
   // Construct a module group containing a single module.
-  HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+  explicit HloModuleGroup(std::unique_ptr<HloModule> module);
 
   // Construct a module group containing any number of modules.
   HloModuleGroup(absl::string_view name,
@@ -50,11 +50,16 @@ class HloModuleGroup {
   // Add a module to the back of vector of modules in the group.
   void push_back(std::unique_ptr<HloModule> module);
 
+  // Replaces the existing module at the given index with the given module. The
+  // existing module is discarded.
+  void ReplaceModule(int index, std::unique_ptr<HloModule> module);
+
   // Moves all modules from the group into the returned vector. After this
   // method runs, the module group will be empty.
   std::vector<std::unique_ptr<HloModule>> ConsumeModules();
 
   string name() const { return name_; }
+
   string ToString() const;
 
   // Serialize the module group to/from a proto.
@@ -63,6 +68,12 @@ class HloModuleGroup {
       const HloModuleGroupProto& proto,
       absl::Span<const HloModuleConfig> module_configs);
 
+  // Returns the number of modules in the module group.
+  int size() const { return modules_.size(); }
+
+  // Returns true if there are no modules in the module group.
+  bool empty() const { return modules_.empty(); }
+
  private:
   string name_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
index b7b12cb72b8..5a9a86af564 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -46,7 +46,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(text));
-  HloModuleGroup group(TestName(), std::move(module));
+  HloModuleGroup group(std::move(module));
 
   EXPECT_EQ(group.modules().size(), 1);
   EXPECT_THAT(
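Taken together, the additions give HloModuleGroup a small vector-like surface.
A usage sketch (module_a/module_b/module_c stand in for modules obtained from,
e.g., ParseHloString):

    HloModuleGroup group("my_group");      // empty, named group
    group.push_back(std::move(module_a));
    group.push_back(std::move(module_b));
    CHECK_EQ(group.size(), 2);
    CHECK(!group.empty());
    group.ReplaceModule(0, std::move(module_c));  // old module is discarded
    // Drain the group; afterwards group.empty() is true.
    std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();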
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 7c79eb7d791..26643667c86 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -57,6 +57,12 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
   return std::move(hlo_module);
 }
 
+Status InterpreterCompiler::RunHloPassesOnModuleGroup(
+    HloModuleGroup* module_group, se::StreamExecutor* executor,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented("Module group compilation not supported on Interpreter");
+}
+
 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
     DeviceMemoryAllocator* /*device_allocator*/) {
@@ -76,17 +82,26 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
   return std::move(executable);
 }
 
+StatusOr<std::vector<std::unique_ptr<Executable>>>
+InterpreterCompiler::RunBackendOnModuleGroup(
+    std::unique_ptr<HloModuleGroup> module_group,
+    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Module group compilation is not supported on Interpreter.");
+}
+
 StatusOr<std::vector<std::unique_ptr<Executable>>>
 InterpreterCompiler::Compile(
-    std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
+    std::unique_ptr<HloModuleGroup> /*module_group*/,
    std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/,
     DeviceMemoryAllocator* /*device_allocator*/) {
-  return tensorflow::errors::Unimplemented(
-      "Compilation of multiple HLO modules is not supported on Interpreter.");
+  return Unimplemented(
+      "Module group compilation is not supported on Interpreter.");
 }
 
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 InterpreterCompiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> hlo_modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     const AotCompilationOptions& aot_options) {
   return tensorflow::errors::InvalidArgument(
       "AOT compilation not supported on Interpreter");
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h
index e90ae3e8185..d8cb32c0beb 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.h
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.h
@@ -46,18 +46,25 @@ class InterpreterCompiler : public Compiler {
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
       DeviceMemoryAllocator* device_allocator) override;
+  Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
       DeviceMemoryAllocator* device_allocator) override;
+  StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> hlo_modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
       DeviceMemoryAllocator* device_allocator) override;
 
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& aot_options) override;
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc
index b17c9d50450..d287aa4ec7b 100644
--- a/tensorflow/compiler/xla/service/llvm_compiler.cc
+++ b/tensorflow/compiler/xla/service/llvm_compiler.cc
@@ -21,8 +21,24 @@ limitations under the License.
 #endif
 
 namespace xla {
+Status LLVMCompiler::RunHloPassesOnModuleGroup(
+    HloModuleGroup* module_group, se::StreamExecutor* executor,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Model partitioning not implemented for the CPU/GPU compilers!");
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>>
+LLVMCompiler::RunBackendOnModuleGroup(
+    std::unique_ptr<HloModuleGroup> module_group,
+    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Model partitioning not implemented for the CPU/GPU compilers!");
+}
+
 StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
-    std::vector<std::unique_ptr<HloModule>> modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
     DeviceMemoryAllocator* device_allocator) {
   // Tensorflow tries to enable the following behaviors in all its threads:
@@ -38,6 +54,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
   tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals;
 
   std::vector<std::unique_ptr<Executable>> result;
+  std::vector<std::unique_ptr<HloModule>> modules =
+      module_group->ConsumeModules();
   for (size_t i = 0; i < modules.size(); i++) {
     if (stream_execs[i].size() != 1) {
       return Unimplemented(
diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h
index f1c623508c5..86abd5da018 100644
--- a/tensorflow/compiler/xla/service/llvm_compiler.h
+++ b/tensorflow/compiler/xla/service/llvm_compiler.h
@@ -69,8 +69,17 @@ class LLVMCompiler : public Compiler {
   using Compiler::RunBackend;
   using Compiler::RunHloPasses;
 
+  Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) override;
+
+  StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
+
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
       DeviceMemoryAllocator* device_allocator) override;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index d290c0eb5df..cb6a9e6707d 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -341,18 +341,19 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   }
   CHECK_EQ(module_protos.size(), module_configs.size());
 
-  std::vector<std::unique_ptr<HloModule>> modules;
+  auto module_group =
+      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
   for (int64 i = 0; i < module_protos.size(); ++i) {
     const HloModuleProto* proto = module_protos[i];
     const HloModuleConfig& config = *module_configs[i];
     TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
-    modules.push_back(std::move(module));
+    module_group->push_back(std::move(module));
   }
 
   TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<Executable>> executables,
-      backend->compiler()->Compile(std::move(modules), std::move(executors),
-                                   device_allocator));
+      backend->compiler()->Compile(std::move(module_group),
+                                   std::move(executors), device_allocator));
 
   for (size_t i = 0; i < module_protos.size(); ++i) {
     if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
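Note how LLVMCompiler::Compile pairs the i-th module of the group with the i-th
executor list, and each list must hold exactly one executor. A sketch of the
expected call setup (module and executor variables are placeholders):

    auto group = absl::make_unique<HloModuleGroup>("two_modules");
    group->push_back(std::move(module_0));
    group->push_back(std::move(module_1));
    std::vector<std::vector<se::StreamExecutor*>> executors = {
        {stream_exec_0},   // runs module_0
        {stream_exec_1}};  // runs module_1
    auto executables = compiler->Compile(std::move(group), std::move(executors),
                                         /*device_allocator=*/nullptr);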
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc
index 022641394f1..fbebe040873 100644
--- a/tensorflow/compiler/xla/tests/codegen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc
@@ -32,11 +32,10 @@ StatusOr<std::unique_ptr<AotCompilationResult>>
 CodegenTestBase::CompileToAotCompilationResult(
     std::unique_ptr<HloModule> hlo_module,
     const AotCompilationOptions& options) {
-  std::vector<std::unique_ptr<HloModule>> hlo_modules;
-  hlo_modules.push_back(std::move(hlo_module));
+  auto module_group =
+      absl::make_unique<HloModuleGroup>(std::move(hlo_module));
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<AotCompilationResult>> results,
-      backend().compiler()->CompileAheadOfTime(std::move(hlo_modules),
+      backend().compiler()->CompileAheadOfTime(std::move(module_group),
                                                options));
   return std::move(results.front());
 }
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 8d658695576..c622b295094 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -93,15 +93,16 @@ class LLVMCompilerTest : public ::testing::Test {
     std::unique_ptr<HloModule> hlo_module = CreateNewModule();
     hlo_module->AddEntryComputation(builder.Build());
 
-    std::vector<std::unique_ptr<HloModule>> modules;
-    modules.push_back(hlo_module->Clone());
-    modules.push_back(std::move(hlo_module));
+    auto module_group = absl::make_unique<HloModuleGroup>("test_module_group");
+    module_group->push_back(hlo_module->Clone());
+    module_group->push_back(std::move(hlo_module));
 
     std::vector<std::vector<se::StreamExecutor*>> executors;
     executors.push_back({backend_->default_stream_executor()});
     executors.push_back({backend_->default_stream_executor()});
 
-    EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors),
+    EXPECT_IS_OK(compiler->Compile(std::move(module_group),
+                                   std::move(executors),
                                    /*device_allocator=*/nullptr));
   }
 
@@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) {
   TestCompilerHooks(&compiler);
 }
 
-TEST_F(CpuCompilerTest, MultiModuleCompilation) {
+TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
   cpu::CpuCompiler compiler;
   TestMultiModuleCompilation(&compiler);
 }
 
-TEST_F(GpuCompilerTest, MultModuleCompilation) {
+TEST_F(GpuCompilerTest, NVPTXMultiModuleCompilation) {
   gpu::NVPTXCompiler compiler;
   TestMultiModuleCompilation(&compiler);
 }