Rollforward with build fix to MLIR TPU compiler.
Also renamed some methods to avoid the "hides overloaded virtual function" compilation error, which only appears in the "Builder" analysis in Critique. See b/11765370. *** Original change description *** Automated rollback of commit 51f0eb5849be0f9ce20e5eb8370158088711f19d PiperOrigin-RevId: 216914046
This commit is contained in:
parent
0608a3f0db
commit
fe52c06bd6
@ -860,6 +860,7 @@ cc_library(
|
|||||||
":executable",
|
":executable",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_module_config",
|
":hlo_module_config",
|
||||||
|
":hlo_module_group",
|
||||||
":logical_buffer",
|
":logical_buffer",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
|
@ -103,8 +103,10 @@ CompileOnlyService::CompileAheadOfTime(
|
|||||||
hlo_modules.push_back(std::move(hlo_module));
|
hlo_modules.push_back(std::move(hlo_module));
|
||||||
}
|
}
|
||||||
|
|
||||||
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options,
|
return compiler_->CompileAheadOfTime(
|
||||||
metadata);
|
absl::make_unique<HloModuleGroup>(hlo_modules[0]->name(),
|
||||||
|
absl::MakeSpan(hlo_modules)),
|
||||||
|
options, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo,
|
|||||||
// Define a default version where metadata is not used.
|
// Define a default version where metadata is not used.
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
Compiler::CompileAheadOfTime(
|
Compiler::CompileAheadOfTime(
|
||||||
std::vector<std::unique_ptr<HloModule>> modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& options,
|
const AotCompilationOptions& options,
|
||||||
std::unique_ptr<AotCompilationMetadata>* metadata) {
|
std::unique_ptr<AotCompilationMetadata>* metadata) {
|
||||||
if (metadata != nullptr) {
|
if (metadata != nullptr) {
|
||||||
@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime(
|
|||||||
"Populating AotCompilationMetadata is not implemented on this "
|
"Populating AotCompilationMetadata is not implemented on this "
|
||||||
"compiler.");
|
"compiler.");
|
||||||
}
|
}
|
||||||
return CompileAheadOfTime(std::move(modules), options);
|
return CompileAheadOfTime(std::move(module_group), options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
|
/* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
||||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
@ -135,6 +136,12 @@ class Compiler {
|
|||||||
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
|
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
|
||||||
DeviceMemoryAllocator* device_allocator) = 0;
|
DeviceMemoryAllocator* device_allocator) = 0;
|
||||||
|
|
||||||
|
// Optimizes a HLO module group, a set of modules which run concurrently on
|
||||||
|
// multiple devices potentially communicating data between the modules.
|
||||||
|
virtual Status RunHloPassesOnModuleGroup(
|
||||||
|
HloModuleGroup* module_group, se::StreamExecutor* executor,
|
||||||
|
DeviceMemoryAllocator* device_allocator) = 0;
|
||||||
|
|
||||||
// Compiles the HLO module for execution on a device given by the executor,
|
// Compiles the HLO module for execution on a device given by the executor,
|
||||||
// and returns an executable object or an error status. No HLO passes are
|
// and returns an executable object or an error status. No HLO passes are
|
||||||
// applied to module. Generally a module should be passed through RunHloPasses
|
// applied to module. Generally a module should be passed through RunHloPasses
|
||||||
@ -145,12 +152,18 @@ class Compiler {
|
|||||||
// (not just type of device) indicated by the executor.
|
// (not just type of device) indicated by the executor.
|
||||||
//
|
//
|
||||||
// device_allocator is optional; see RunHloPasses.
|
// device_allocator is optional; see RunHloPasses.
|
||||||
//
|
|
||||||
// Use the overload below to compile computations that run in parallel.
|
|
||||||
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
|
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||||
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
|
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
|
||||||
DeviceMemoryAllocator* device_allocator) = 0;
|
DeviceMemoryAllocator* device_allocator) = 0;
|
||||||
|
|
||||||
|
// Compiles a set of HLO modules that can run in parallel, potentially
|
||||||
|
// communicating data between the modules.
|
||||||
|
virtual StatusOr<std::vector<std::unique_ptr<Executable>>>
|
||||||
|
RunBackendOnModuleGroup(
|
||||||
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
|
DeviceMemoryAllocator* device_allocator) = 0;
|
||||||
|
|
||||||
// Compiles a set of HLO modules that can run in parallel, potentially
|
// Compiles a set of HLO modules that can run in parallel, potentially
|
||||||
// communicating data between the modules, and returns a corresponding
|
// communicating data between the modules, and returns a corresponding
|
||||||
// sequence of executable objects.
|
// sequence of executable objects.
|
||||||
@ -160,7 +173,7 @@ class Compiler {
|
|||||||
// TODO(b/68666782): Remove this method after adding support for multiple
|
// TODO(b/68666782): Remove this method after adding support for multiple
|
||||||
// modules to RunHloPasses and RunBackends.
|
// modules to RunHloPasses and RunBackends.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||||
std::vector<std::unique_ptr<HloModule>> modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
DeviceMemoryAllocator* device_allocator) = 0;
|
DeviceMemoryAllocator* device_allocator) = 0;
|
||||||
|
|
||||||
@ -184,16 +197,16 @@ class Compiler {
|
|||||||
ComputeDefaultBackendConfig(const HloInstruction& hlo,
|
ComputeDefaultBackendConfig(const HloInstruction& hlo,
|
||||||
se::StreamExecutor* executor) const;
|
se::StreamExecutor* executor) const;
|
||||||
|
|
||||||
// Compiles the HLO module for ahead-of-time execution. This is intended for
|
// Compiles the HLO module group for ahead-of-time execution. This is
|
||||||
// use in static compilation.
|
// intended for use in static compilation.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& options) = 0;
|
const AotCompilationOptions& options) = 0;
|
||||||
|
|
||||||
// Similar to CompileAheadOfTime above but AotCompilationMetadata
|
// Similar to CompileAheadOfTime above but AotCompilationMetadata
|
||||||
// has an argument that can be populated during compilation.
|
// has an argument that can be populated during compilation.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& options,
|
const AotCompilationOptions& options,
|
||||||
std::unique_ptr<AotCompilationMetadata>* metadata);
|
std::unique_ptr<AotCompilationMetadata>* metadata);
|
||||||
|
|
||||||
|
@ -676,9 +676,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& aot_options) {
|
const AotCompilationOptions& aot_options) {
|
||||||
TF_RET_CHECK(!modules.empty());
|
TF_RET_CHECK(!module_group->empty());
|
||||||
|
std::vector<std::unique_ptr<HloModule>> modules =
|
||||||
|
module_group->ConsumeModules();
|
||||||
|
|
||||||
std::call_once(llvm_command_line_options_initialized,
|
std::call_once(llvm_command_line_options_initialized,
|
||||||
&llvm_ir::InitializeLLVMCommandLineOptions,
|
&llvm_ir::InitializeLLVMCommandLineOptions,
|
||||||
modules[0]->config());
|
modules[0]->config());
|
||||||
|
@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler {
|
|||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& options) override;
|
const AotCompilationOptions& options) override;
|
||||||
|
|
||||||
se::Platform::Id PlatformId() const override;
|
se::Platform::Id PlatformId() const override;
|
||||||
|
@ -825,9 +825,8 @@ std::vector<uint8> NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
NVPTXCompiler::CompileAheadOfTime(
|
NVPTXCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
std::vector<std::unique_ptr<HloModule>> module,
|
const AotCompilationOptions& options) {
|
||||||
const AotCompilationOptions& options) {
|
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"not yet implemented: NVPTXCompiler::CompileAheadOfTime");
|
"not yet implemented: NVPTXCompiler::CompileAheadOfTime");
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler {
|
|||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> module,
|
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
AotCompilationOptions const& options) override;
|
AotCompilationOptions const& options) override;
|
||||||
|
|
||||||
se::Platform::Id PlatformId() const override;
|
se::Platform::Id PlatformId() const override;
|
||||||
|
@ -17,9 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
HloModuleGroup::HloModuleGroup(absl::string_view name,
|
HloModuleGroup::HloModuleGroup(std::unique_ptr<HloModule> module)
|
||||||
std::unique_ptr<HloModule> module)
|
: name_(module->name()) {
|
||||||
: name_(name) {
|
|
||||||
push_back(std::move(module));
|
push_back(std::move(module));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ class HloModuleGroup {
|
|||||||
explicit HloModuleGroup(absl::string_view name) : name_(name) {}
|
explicit HloModuleGroup(absl::string_view name) : name_(name) {}
|
||||||
|
|
||||||
// Construct a module group containing a single module.
|
// Construct a module group containing a single module.
|
||||||
HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
|
explicit HloModuleGroup(std::unique_ptr<HloModule> module);
|
||||||
|
|
||||||
// Construct a module group containing any number of modules.
|
// Construct a module group containing any number of modules.
|
||||||
HloModuleGroup(absl::string_view name,
|
HloModuleGroup(absl::string_view name,
|
||||||
@ -50,11 +50,16 @@ class HloModuleGroup {
|
|||||||
// Add a module to the back of vector of modules in the group.
|
// Add a module to the back of vector of modules in the group.
|
||||||
void push_back(std::unique_ptr<HloModule> module);
|
void push_back(std::unique_ptr<HloModule> module);
|
||||||
|
|
||||||
|
// Replaces the existing module at the given index with the given module. The
|
||||||
|
// existing module is discarded.
|
||||||
|
void ReplaceModule(int index, std::unique_ptr<HloModule> module);
|
||||||
|
|
||||||
// Moves all modules from the group into the returned vector. After this
|
// Moves all modules from the group into the returned vector. After this
|
||||||
// method runs, the module group will be empty.
|
// method runs, the module group will be empty.
|
||||||
std::vector<std::unique_ptr<HloModule>> ConsumeModules();
|
std::vector<std::unique_ptr<HloModule>> ConsumeModules();
|
||||||
|
|
||||||
string name() const { return name_; }
|
string name() const { return name_; }
|
||||||
|
|
||||||
string ToString() const;
|
string ToString() const;
|
||||||
|
|
||||||
// Serialize the module group to/from a proto.
|
// Serialize the module group to/from a proto.
|
||||||
@ -63,6 +68,12 @@ class HloModuleGroup {
|
|||||||
const HloModuleGroupProto& proto,
|
const HloModuleGroupProto& proto,
|
||||||
absl::Span<const HloModuleConfig> module_configs);
|
absl::Span<const HloModuleConfig> module_configs);
|
||||||
|
|
||||||
|
// Returns the number of modules in the module group.
|
||||||
|
int size() const { return modules_.size(); }
|
||||||
|
|
||||||
|
// Returns true if there are no modules in the module group.
|
||||||
|
bool empty() const { return modules_.empty(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string name_;
|
string name_;
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
|
|||||||
)";
|
)";
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
ParseHloString(text));
|
ParseHloString(text));
|
||||||
HloModuleGroup group(TestName(), std::move(module));
|
HloModuleGroup group(std::move(module));
|
||||||
|
|
||||||
EXPECT_EQ(group.modules().size(), 1);
|
EXPECT_EQ(group.modules().size(), 1);
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -57,6 +57,12 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
|
|||||||
return std::move(hlo_module);
|
return std::move(hlo_module);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status InterpreterCompiler::RunHloPassesOnModuleGroup(
|
||||||
|
HloModuleGroup* module_group, se::StreamExecutor* executor,
|
||||||
|
DeviceMemoryAllocator* device_allocator) {
|
||||||
|
return Unimplemented("Module group compilation not supported on Interpreter");
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
||||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||||
DeviceMemoryAllocator* /*device_allocator*/) {
|
DeviceMemoryAllocator* /*device_allocator*/) {
|
||||||
@ -76,17 +82,26 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
|
|||||||
return std::move(executable);
|
return std::move(executable);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<Executable>>>
|
||||||
|
InterpreterCompiler::RunBackendOnModuleGroup(
|
||||||
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
|
DeviceMemoryAllocator* device_allocator) {
|
||||||
|
return Unimplemented(
|
||||||
|
"Module group compilation is not supported on Interpreter.");
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
|
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
|
||||||
std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
|
std::unique_ptr<HloModuleGroup> /*module_group*/,
|
||||||
std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/,
|
std::vector<std::vector<se::StreamExecutor*>> /*stream_execs*/,
|
||||||
DeviceMemoryAllocator* /*device_allocator*/) {
|
DeviceMemoryAllocator* /*device_allocator*/) {
|
||||||
return tensorflow::errors::Unimplemented(
|
return Unimplemented(
|
||||||
"Compilation of multiple HLO modules is not supported on Interpreter.");
|
"Module group compilation is not supported on Interpreter.");
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
InterpreterCompiler::CompileAheadOfTime(
|
InterpreterCompiler::CompileAheadOfTime(
|
||||||
std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& aot_options) {
|
const AotCompilationOptions& aot_options) {
|
||||||
return tensorflow::errors::InvalidArgument(
|
return tensorflow::errors::InvalidArgument(
|
||||||
"AOT compilation not supported on Interpreter");
|
"AOT compilation not supported on Interpreter");
|
||||||
|
@ -46,18 +46,25 @@ class InterpreterCompiler : public Compiler {
|
|||||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
Status RunHloPassesOnModuleGroup(
|
||||||
|
HloModuleGroup* module_group, se::StreamExecutor* executor,
|
||||||
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
StatusOr<std::unique_ptr<Executable>> RunBackend(
|
||||||
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
|
||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
|
||||||
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||||
std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
|
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
|
||||||
const AotCompilationOptions& aot_options) override;
|
const AotCompilationOptions& aot_options) override;
|
||||||
|
|
||||||
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
|
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
|
||||||
|
@ -21,8 +21,24 @@ limitations under the License.
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
Status LLVMCompiler::RunHloPassesOnModuleGroup(
|
||||||
|
HloModuleGroup* module_group, se::StreamExecutor* executor,
|
||||||
|
DeviceMemoryAllocator* device_allocator) {
|
||||||
|
return Unimplemented(
|
||||||
|
"Model partitioning not implemented for the CPU/GPU compilers!");
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<Executable>>>
|
||||||
|
LLVMCompiler::RunBackendOnModuleGroup(
|
||||||
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
|
DeviceMemoryAllocator* device_allocator) {
|
||||||
|
return Unimplemented(
|
||||||
|
"Model partitioning not implemented for the CPU/GPU compilers!");
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
|
StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
|
||||||
std::vector<std::unique_ptr<HloModule>> modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||||
DeviceMemoryAllocator* device_allocator) {
|
DeviceMemoryAllocator* device_allocator) {
|
||||||
// Tensorflow tries to enable the following behaviors in all its threads:
|
// Tensorflow tries to enable the following behaviors in all its threads:
|
||||||
@ -38,6 +54,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
|
|||||||
tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals;
|
tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals;
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Executable>> result;
|
std::vector<std::unique_ptr<Executable>> result;
|
||||||
|
std::vector<std::unique_ptr<HloModule>> modules =
|
||||||
|
module_group->ConsumeModules();
|
||||||
for (size_t i = 0; i < modules.size(); i++) {
|
for (size_t i = 0; i < modules.size(); i++) {
|
||||||
if (stream_execs[i].size() != 1) {
|
if (stream_execs[i].size() != 1) {
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
|
@ -69,8 +69,17 @@ class LLVMCompiler : public Compiler {
|
|||||||
using Compiler::RunBackend;
|
using Compiler::RunBackend;
|
||||||
using Compiler::RunHloPasses;
|
using Compiler::RunHloPasses;
|
||||||
|
|
||||||
|
Status RunHloPassesOnModuleGroup(
|
||||||
|
HloModuleGroup* module_group, se::StreamExecutor* executor,
|
||||||
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<Executable>>> RunBackendOnModuleGroup(
|
||||||
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
|
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
|
||||||
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
|
||||||
std::vector<std::unique_ptr<HloModule>> modules,
|
std::unique_ptr<HloModuleGroup> module_group,
|
||||||
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
|
||||||
DeviceMemoryAllocator* device_allocator) override;
|
DeviceMemoryAllocator* device_allocator) override;
|
||||||
|
|
||||||
|
@ -341,18 +341,19 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
|||||||
}
|
}
|
||||||
|
|
||||||
CHECK_EQ(module_protos.size(), module_configs.size());
|
CHECK_EQ(module_protos.size(), module_configs.size());
|
||||||
std::vector<std::unique_ptr<HloModule>> modules;
|
auto module_group =
|
||||||
|
absl::make_unique<HloModuleGroup>(module_protos[0]->name());
|
||||||
for (int64 i = 0; i < module_protos.size(); ++i) {
|
for (int64 i = 0; i < module_protos.size(); ++i) {
|
||||||
const HloModuleProto* proto = module_protos[i];
|
const HloModuleProto* proto = module_protos[i];
|
||||||
const HloModuleConfig& config = *module_configs[i];
|
const HloModuleConfig& config = *module_configs[i];
|
||||||
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
|
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
|
||||||
modules.push_back(std::move(module));
|
module_group->push_back(std::move(module));
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::unique_ptr<Executable>> executables,
|
std::vector<std::unique_ptr<Executable>> executables,
|
||||||
backend->compiler()->Compile(std::move(modules), std::move(executors),
|
backend->compiler()->Compile(std::move(module_group),
|
||||||
device_allocator));
|
std::move(executors), device_allocator));
|
||||||
|
|
||||||
for (size_t i = 0; i < module_protos.size(); ++i) {
|
for (size_t i = 0; i < module_protos.size(); ++i) {
|
||||||
if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
|
if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
|
||||||
|
@ -32,11 +32,10 @@ StatusOr<std::unique_ptr<AotCompilationResult>>
|
|||||||
CodegenTestBase::CompileToAotCompilationResult(
|
CodegenTestBase::CompileToAotCompilationResult(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
const AotCompilationOptions& options) {
|
const AotCompilationOptions& options) {
|
||||||
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
auto module_group = absl::make_unique<HloModuleGroup>(std::move(hlo_module));
|
||||||
hlo_modules.push_back(std::move(hlo_module));
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<std::unique_ptr<AotCompilationResult>> results,
|
std::vector<std::unique_ptr<AotCompilationResult>> results,
|
||||||
backend().compiler()->CompileAheadOfTime(std::move(hlo_modules),
|
backend().compiler()->CompileAheadOfTime(std::move(module_group),
|
||||||
options));
|
options));
|
||||||
return std::move(results.front());
|
return std::move(results.front());
|
||||||
}
|
}
|
||||||
|
@ -93,15 +93,16 @@ class LLVMCompilerTest : public ::testing::Test {
|
|||||||
std::unique_ptr<HloModule> hlo_module = CreateNewModule();
|
std::unique_ptr<HloModule> hlo_module = CreateNewModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
std::vector<std::unique_ptr<HloModule>> modules;
|
auto module_group = absl::make_unique<HloModuleGroup>("test_module_group");
|
||||||
modules.push_back(hlo_module->Clone());
|
module_group->push_back(hlo_module->Clone());
|
||||||
modules.push_back(std::move(hlo_module));
|
module_group->push_back(std::move(hlo_module));
|
||||||
|
|
||||||
std::vector<std::vector<se::StreamExecutor *>> executors;
|
std::vector<std::vector<se::StreamExecutor *>> executors;
|
||||||
executors.push_back({backend_->default_stream_executor()});
|
executors.push_back({backend_->default_stream_executor()});
|
||||||
executors.push_back({backend_->default_stream_executor()});
|
executors.push_back({backend_->default_stream_executor()});
|
||||||
|
|
||||||
EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors),
|
EXPECT_IS_OK(compiler->Compile(std::move(module_group),
|
||||||
|
std::move(executors),
|
||||||
/*device_allocator=*/nullptr));
|
/*device_allocator=*/nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) {
|
|||||||
TestCompilerHooks(&compiler);
|
TestCompilerHooks(&compiler);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CpuCompilerTest, MultiModuleCompilation) {
|
TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
|
||||||
cpu::CpuCompiler compiler;
|
cpu::CpuCompiler compiler;
|
||||||
TestMultiModuleCompilation(&compiler);
|
TestMultiModuleCompilation(&compiler);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuCompilerTest, MultModuleCompilation) {
|
TEST_F(GpuCompilerTest, NVPTXMultiModuleCompilation) {
|
||||||
gpu::NVPTXCompiler compiler;
|
gpu::NVPTXCompiler compiler;
|
||||||
TestMultiModuleCompilation(&compiler);
|
TestMultiModuleCompilation(&compiler);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user