Rollforward with build fix to MLIR TPU compiler.

Also renamed some methods to avoid a "hides overloaded virtual function" compilation error, which only appears in the "Builder" analysis in Critique. See b/11765370.
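
For background, the warning fires when a derived class declares a new overload of an inherited virtual method: the derived declaration hides the base-class overloads unless they are re-exposed with a using-declaration or the new method gets a distinct name. A minimal sketch of the pattern (hypothetical Base/Derived/Module/Group types, not the XLA classes):

// Sketch only: standalone illustration of -Woverloaded-virtual, which the
// "Builder" analysis reports as "hides overloaded virtual function".
struct Module {};
struct Group {};

struct Base {
  virtual ~Base() = default;
  virtual void Run(Module* module) = 0;
};

struct Derived : Base {
  void Run(Module* module) override {}

  // Declaring `virtual void Run(Group* group);` here would hide Base::Run and
  // trigger the warning. Two common fixes:
  //   1) re-expose the base overloads with `using Base::Run;`, or
  //   2) give the group variant a distinct name, which is the approach taken
  //      in this change (RunHloPassesOnModuleGroup, RunBackendOnModuleGroup).
  virtual void RunOnGroup(Group* group) {}
};

int main() {
  Derived d;
  Module m;
  Group g;
  d.Run(&m);         // resolves cleanly; no hidden-overload warning
  d.RunOnGroup(&g);  // the renamed group variant
  return 0;
}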

*** Original change description ***

Automated rollback of commit 51f0eb5849be0f9ce20e5eb8370158088711f19d

PiperOrigin-RevId: 216914046
Mark Heffernan 2018-10-12 13:39:20 -07:00 committed by TensorFlower Gardener
parent 0608a3f0db
commit fe52c06bd6
18 changed files with 122 additions and 44 deletions


@@ -860,6 +860,7 @@ cc_library(
         ":executable",
         ":hlo",
         ":hlo_module_config",
+        ":hlo_module_group",
         ":logical_buffer",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",


@@ -103,8 +103,10 @@ CompileOnlyService::CompileAheadOfTime(
     hlo_modules.push_back(std::move(hlo_module));
   }

-  return compiler_->CompileAheadOfTime(std::move(hlo_modules), options,
-                                       metadata);
+  return compiler_->CompileAheadOfTime(
+      absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
+                                        absl::MakeSpan(hlo_modules)),
+      options, metadata);
 }

 }  // namespace xla


@@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo,
 // Define a default version where metadata is not used.
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 Compiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     const AotCompilationOptions& options,
     std::unique_ptr<AotCompilationMetadata>* metadata) {
   if (metadata != nullptr) {
@@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime(
         "Populating AotCompilationMetadata is not implemented on this "
         "compiler.");
   }
-  return CompileAheadOfTime(std::move(modules), options);
+  return CompileAheadOfTime(std::move(module_group), options);
 }

 /* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*


@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -135,6 +136,12 @@ class Compiler {
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
       DeviceMemoryAllocator* device_allocator) = 0;

+  // Optimizes an HLO module group, a set of modules which run concurrently on
+  // multiple devices, potentially communicating data between the modules.
+  virtual Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) = 0;
+
   // Compiles the HLO module for execution on a device given by the executor,
   // and returns an executable object or an error status. No HLO passes are
   // applied to module. Generally a module should be passed through RunHloPasses
@@ -145,12 +152,18 @@ class Compiler {
   // (not just type of device) indicated by the executor.
   //
   // device_allocator is optional; see RunHloPasses.
+  //
+  // Use the overload below to compile computations that run in parallel.
   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
       DeviceMemoryAllocator* device_allocator) = 0;

+  // Compiles a set of HLO modules that can run in parallel, potentially
+  // communicating data between the modules.
+  virtual StatusOr<std::vector<std::unique_ptr<Executable>>>
+  RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) = 0;
+
   // Compiles a set of HLO modules that can run in parallel, potentially
   // communicating data between the modules, and returns a corresponding
   // sequence of executable objects.
@@ -160,7 +173,7 @@ class Compiler {
   // TODO(b/68666782): Remove this method after adding support for multiple
   // modules to RunHloPasses and RunBackends.
   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
       DeviceMemoryAllocator* device_allocator) = 0;
@@ -184,16 +197,16 @@ class Compiler {
   ComputeDefaultBackendConfig(const HloInstruction& hlo,
                               se::StreamExecutor* executor) const;

-  // Compiles the HLO module for ahead-of-time execution. This is intended for
-  // use in static compilation.
+  // Compiles the HLO module group for ahead-of-time execution. This is
+  // intended for use in static compilation.
   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options) = 0;

   // Similar to CompileAheadOfTime above but AotCompilationMetadata
   // has an argument that can be populated during compilation.
   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options,
                      std::unique_ptr<AotCompilationMetadata>* metadata);
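
The hunks above replace std::vector<std::unique_ptr<HloModule>> with std::unique_ptr<HloModuleGroup> throughout the Compiler interface, so callers now wrap their modules in a group before handing them to the compiler. A minimal sketch of the resulting call-site pattern (the CompileModulesTogether helper and the "sketch_group" name are illustrative only, not part of this change; it assumes an XLA build environment):

#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"

namespace xla {

// Illustrative helper: bundle standalone modules into an HloModuleGroup and
// compile them together through the updated Compiler interface.
StatusOr<std::vector<std::unique_ptr<Executable>>> CompileModulesTogether(
    Compiler* compiler, std::vector<std::unique_ptr<HloModule>> modules,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs) {
  auto module_group = absl::make_unique<HloModuleGroup>("sketch_group");
  for (auto& module : modules) {
    module_group->push_back(std::move(module));
  }
  // device_allocator is optional, per the comment on RunHloPasses.
  return compiler->Compile(std::move(module_group), std::move(stream_execs),
                           /*device_allocator=*/nullptr);
}

}  // namespace xla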


@@ -676,9 +676,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
 }

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                 const AotCompilationOptions& aot_options) {
-  TF_RET_CHECK(!modules.empty());
+  TF_RET_CHECK(!module_group->empty());
+  std::vector<std::unique_ptr<HloModule>> modules =
+      module_group->ConsumeModules();
+
   std::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions,
                  modules[0]->config());


@@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler {
       DeviceMemoryAllocator* device_allocator) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& options) override;

   se::Platform::Id PlatformId() const override;


@@ -825,9 +825,8 @@ std::vector<uint8> NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx,
 }

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-NVPTXCompiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> module,
-    const AotCompilationOptions& options) {
+NVPTXCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
+                                  const AotCompilationOptions& options) {
   return Unimplemented(
       "not yet implemented: NVPTXCompiler::CompileAheadOfTime");
 }


@@ -59,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler {
       DeviceMemoryAllocator* device_allocator) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      AotCompilationOptions const& options) override;

   se::Platform::Id PlatformId() const override;


@@ -17,9 +17,8 @@ limitations under the License.

 namespace xla {

-HloModuleGroup::HloModuleGroup(absl::string_view name,
-                               std::unique_ptr<HloModule> module)
-    : name_(name) {
+HloModuleGroup::HloModuleGroup(std::unique_ptr<HloModule> module)
+    : name_(module->name()) {
   push_back(std::move(module));
 }


@@ -35,7 +35,7 @@ class HloModuleGroup {
   explicit HloModuleGroup(absl::string_view name) : name_(name) {}

   // Construct a module group containing a single module.
-  HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+  explicit HloModuleGroup(std::unique_ptr<HloModule> module);

   // Construct a module group containing any number of modules.
   HloModuleGroup(absl::string_view name,
@@ -50,11 +50,16 @@ class HloModuleGroup {
   // Add a module to the back of vector of modules in the group.
   void push_back(std::unique_ptr<HloModule> module);

+  // Replaces the existing module at the given index with the given module. The
+  // existing module is discarded.
+  void ReplaceModule(int index, std::unique_ptr<HloModule> module);
+
   // Moves all modules from the group into the returned vector. After this
   // method runs, the module group will be empty.
   std::vector<std::unique_ptr<HloModule>> ConsumeModules();

   string name() const { return name_; }
   string ToString() const;

   // Serialize the module group to/from a proto.
@@ -63,6 +68,12 @@ class HloModuleGroup {
       const HloModuleGroupProto& proto,
       absl::Span<const HloModuleConfig> module_configs);

+  // Returns the number of modules in the module group.
+  int size() const { return modules_.size(); }
+
+  // Returns true if there are no modules in the module group.
+  bool empty() const { return modules_.empty(); }
+
  private:
   string name_;
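
A short sketch exercising the HloModuleGroup additions shown above, the single-module constructor and the new size()/empty() accessors (the ModuleGroupSketch function is illustrative only and assumes an XLA build environment):

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void ModuleGroupSketch(std::unique_ptr<HloModule> a,
                       std::unique_ptr<HloModule> b) {
  // The new single-module constructor takes the group name from the module,
  // so no separate name argument is needed.
  HloModuleGroup group(std::move(a));
  group.push_back(std::move(b));

  // size() and empty() let callers such as CpuCompiler::CompileAheadOfTime
  // query the group directly instead of a raw vector.
  CHECK(!group.empty());
  CHECK_EQ(group.size(), 2);

  // Backends that still want a plain vector can take the modules back,
  // as LLVMCompiler::Compile does after this change.
  std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
  CHECK(group.empty());
}

}  // namespace xla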


@@ -46,7 +46,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(text));
-  HloModuleGroup group(TestName(), std::move(module));
+  HloModuleGroup group(std::move(module));

   EXPECT_EQ(group.modules().size(), 1);
   EXPECT_THAT(


@@ -57,6 +57,12 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
   return std::move(hlo_module);
 }

+Status InterpreterCompiler::RunHloPassesOnModuleGroup(
+    HloModuleGroup* module_group, se::StreamExecutor* executor,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented("Module group compilation not supported on Interpreter");
+}
+
 StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
     std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
     DeviceMemoryAllocator* /*device_allocator*/) {
@@ -76,17 +82,26 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
   return std::move(executable);
 }

+StatusOr<std::vector<std::unique_ptr<Executable>>>
+InterpreterCompiler::RunBackendOnModuleGroup(
+    std::unique_ptr<HloModuleGroup> module_group,
+    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Module group compilation is not supported on Interpreter.");
+}
+
 StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
-    std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
+    std::unique_ptr<HloModuleGroup> /*module_group*/,
     std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/,
     DeviceMemoryAllocator* /*device_allocator*/) {
-  return tensorflow::errors::Unimplemented(
-      "Compilation of multiple HLO modules is not supported on Interpreter.");
+  return Unimplemented(
+      "Module group compilation is not supported on Interpreter.");
 }

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 InterpreterCompiler::CompileAheadOfTime(
-    std::vector<std::unique_ptr<HloModule>> hlo_modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     const AotCompilationOptions& aot_options) {
   return tensorflow::errors::InvalidArgument(
       "AOT compilation not supported on Interpreter");


@@ -46,18 +46,25 @@ class InterpreterCompiler : public Compiler {
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
       DeviceMemoryAllocator* device_allocator) override;
+  Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) override;

   StatusOr<std::unique_ptr<Executable>> RunBackend(
       std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
       DeviceMemoryAllocator* device_allocator) override;
+  StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;

   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> hlo_modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
       DeviceMemoryAllocator* device_allocator) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
+  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      const AotCompilationOptions& aot_options) override;

   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;


@@ -21,8 +21,24 @@ limitations under the License.
 #endif

 namespace xla {

+Status LLVMCompiler::RunHloPassesOnModuleGroup(
+    HloModuleGroup* module_group, se::StreamExecutor* executor,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Model partitioning not implemented for the CPU/GPU compilers!");
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>>
+LLVMCompiler::RunBackendOnModuleGroup(
+    std::unique_ptr<HloModuleGroup> module_group,
+    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+    DeviceMemoryAllocator* device_allocator) {
+  return Unimplemented(
+      "Model partitioning not implemented for the CPU/GPU compilers!");
+}
+
 StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
-    std::vector<std::unique_ptr<HloModule>> modules,
+    std::unique_ptr<HloModuleGroup> module_group,
     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
     DeviceMemoryAllocator* device_allocator) {
   // Tensorflow tries to enable the following behaviors in all its threads:
@@ -38,6 +54,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
   tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals;

   std::vector<std::unique_ptr<Executable>> result;
+  std::vector<std::unique_ptr<HloModule>> modules =
+      module_group->ConsumeModules();
   for (size_t i = 0; i < modules.size(); i++) {
     if (stream_execs[i].size() != 1) {
       return Unimplemented(


@@ -69,8 +69,17 @@ class LLVMCompiler : public Compiler {
   using Compiler::RunBackend;
   using Compiler::RunHloPasses;

+  Status RunHloPassesOnModuleGroup(
+      HloModuleGroup* module_group, se::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator) override;
+
+  StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
+      std::unique_ptr<HloModuleGroup> module_group,
+      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
+      DeviceMemoryAllocator* device_allocator) override;
+
   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> modules,
+      std::unique_ptr<HloModuleGroup> module_group,
       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
       DeviceMemoryAllocator* device_allocator) override;


@@ -341,18 +341,19 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   }
   CHECK_EQ(module_protos.size(), module_configs.size());

-  std::vector<std::unique_ptr<HloModule>> modules;
+  auto module_group =
+      absl::make_unique<HloModuleGroup>(module_protos[0]->name());
   for (int64 i = 0; i < module_protos.size(); ++i) {
     const HloModuleProto* proto = module_protos[i];
     const HloModuleConfig& config = *module_configs[i];
     TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
-    modules.push_back(std::move(module));
+    module_group->push_back(std::move(module));
   }

   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<Executable>> executables,
-      backend->compiler()->Compile(std::move(modules), std::move(executors),
-                                   device_allocator));
+      backend->compiler()->Compile(std::move(module_group),
+                                   std::move(executors), device_allocator));

   for (size_t i = 0; i < module_protos.size(); ++i) {
     if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {


@@ -32,11 +32,10 @@ StatusOr<std::unique_ptr<AotCompilationResult>>
 CodegenTestBase::CompileToAotCompilationResult(
     std::unique_ptr<HloModule> hlo_module,
     const AotCompilationOptions& options) {
-  std::vector<std::unique_ptr<HloModule>> hlo_modules;
-  hlo_modules.push_back(std::move(hlo_module));
+  auto module_group = absl::make_unique<HloModuleGroup>(std::move(hlo_module));
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<AotCompilationResult>> results,
-      backend().compiler()->CompileAheadOfTime(std::move(hlo_modules),
+      backend().compiler()->CompileAheadOfTime(std::move(module_group),
                                                options));
   return std::move(results.front());
 }


@@ -93,15 +93,16 @@ class LLVMCompilerTest : public ::testing::Test {
     std::unique_ptr<HloModule> hlo_module = CreateNewModule();
     hlo_module->AddEntryComputation(builder.Build());

-    std::vector<std::unique_ptr<HloModule>> modules;
-    modules.push_back(hlo_module->Clone());
-    modules.push_back(std::move(hlo_module));
+    auto module_group = absl::make_unique<HloModuleGroup>("test_module_group");
+    module_group->push_back(hlo_module->Clone());
+    module_group->push_back(std::move(hlo_module));

     std::vector<std::vector<se::StreamExecutor *>> executors;
     executors.push_back({backend_->default_stream_executor()});
     executors.push_back({backend_->default_stream_executor()});

-    EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors),
-                                   /*device_allocator=*/nullptr));
+    EXPECT_IS_OK(compiler->Compile(std::move(module_group),
+                                   std::move(executors),
+                                   /*device_allocator=*/nullptr));
   }
@@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) {
   TestCompilerHooks(&compiler);
 }

-TEST_F(CpuCompilerTest, MultiModuleCompilation) {
+TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
   cpu::CpuCompiler compiler;
   TestMultiModuleCompilation(&compiler);
 }

-TEST_F(GpuCompilerTest, MultModuleCompilation) {
+TEST_F(GpuCompilerTest, NVPTXMultiModuleCompilation) {
   gpu::NVPTXCompiler compiler;
   TestMultiModuleCompilation(&compiler);
 }