Make XLA GPU and XLA MLIR GPU emitters share the same HLO optimization passes.

For now, the MLIR backend expects the same HLO input as the GPU backend does, so we need to run the same required passes. We also use the same HLO-level optimizations so that the two backends produce comparable HLO.

PiperOrigin-RevId: 258349785
parent 66b97d20ba
commit 2ba24e8d89
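
For context, the shared passes are grouped into HloPassPipeline objects, as visible in the first hunk below. A hedged sketch of that idiom (the pass list here is illustrative; the real lists live in the optimization function shown in the diff):

    // Illustrative only: demonstrates the HloPassPipeline idiom used by the
    // shared OptimizeHloModule; the actual passes are those in the diff below.
    #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
    #include "tensorflow/compiler/xla/service/hlo_verifier.h"

    namespace xla {

    Status RunIllustrativePipeline(HloModule* module) {
      HloPassPipeline pipeline("optimization");
      // Check module invariants before and after every pass in the pipeline.
      pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                                /*allow_mixed_precision=*/false);
      // One representative optimization pass; the real pipeline adds many more.
      pipeline.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions());
      // Run() reports whether anything changed; only the status matters here.
      return pipeline.Run(module).status();
    }

    }  // namespace xla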
@@ -165,9 +165,69 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
   return ".";
 }
 
+absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
+                                        const HloInstruction* operand,
+                                        const ShapeIndex& user_index) {
+  // Share the bias buffer with the parent instruction.
+  if (IsCublasGemm(*user)) {
+    if (user->operand_count() == 3 && user->operand(2) == operand) {
+      return true;
+    }
+  }
+  // The operand of cholesky can be shared with the first output.
+  if (user->opcode() == HloOpcode::kCustomCall &&
+      user->custom_call_target() == kCusolverCholeskyCallTarget) {
+    return user_index.size() == 1 && user_index[0] == 0;
+  }
+  return absl::nullopt;
+}
+
+// Prints a warning if the ptx->sass JIT in the driver has known bugs.
+//
+// Using such a driver is only a problem if we fail to use ptxas to compile
+// our ptx and have to use the driver instead, so you should only call this
+// function if we're going to use the driver JIT.
+//
+// Only prints a warning the first time it's called.
+void WarnIfBadDriverJITVersion() {
+  static std::once_flag run_once;
+  std::call_once(run_once, [] {
+    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
+    if (!version_or_status.ok()) {
+      LOG(WARNING) << "Couldn't read CUDA driver version.";
+      return;
+    }
+    se::cuda::DriverVersion version = version_or_status.ValueOrDie();
+
+    // The following versions of the driver JIT miscompile some address
+    // calculations with large offsets (e.g. "load ptr + large_constant"),
+    // b/70245379:
+    //
+    //  - 384.x before 384.108
+    //  - 387.x before 387.40
+    //  - 390.x before 390.10.
+    //
+    // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the
+    // fix for the "large multioutput fusions" miscompile, b/111107644.
+    if (version < std::make_tuple(396, 20, 0)) {
+      LOG(WARNING)
+          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
+          << se::cuda::DriverVersionToString(version)
+          << ", which is older than 396.20.0. These versions are known to "
+             "miscompile XLA code, leading to incorrect results or "
+             "invalid-address errors.\nXLA only uses the driver JIT if it "
+             "cannot find ptxas; you don't need to update your driver if "
+             "you can point XLA to ptxas 9.2.88 or newer.";
+    }
+  });
+}
+
+}  // namespace
+
 // Runs optimization passes on the given HLO module.
-Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
-                         se::DeviceMemoryAllocator* device_allocator) {
+Status impl::OptimizeHloModule(HloModule* hlo_module,
+                               se::StreamExecutor* stream_exec,
+                               se::DeviceMemoryAllocator* device_allocator) {
   {
     HloPassPipeline pipeline("optimization");
     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
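
The CanShareBufferHint added above returns a three-way answer. The contract assumed here (this diff does not show the consumer): true means the backend knows the user's output buffer may alias the operand's, false forbids aliasing, and absl::nullopt defers to the generic buffer-sharing analysis. A self-contained sketch of that assumed contract, with HLO types replaced by plain values:

    #include <iostream>
    #include "absl/types/optional.h"

    // Combines a backend-specific hint with the generic analysis result:
    // a present hint wins; nullopt falls through to the analysis.
    bool MayShareBuffer(absl::optional<bool> backend_hint, bool generic_analysis) {
      return backend_hint.value_or(generic_analysis);
    }

    int main() {
      std::cout << MayShareBuffer(/*backend_hint=*/true, false) << "\n";  // 1
      std::cout << MayShareBuffer(absl::nullopt, false) << "\n";          // 0
    }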
@@ -402,26 +462,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   return Status::OK();
 }
 
-absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
-                                        const HloInstruction* operand,
-                                        const ShapeIndex& user_index) {
-  // Share the bias buffer with the parent instruction.
-  if (IsCublasGemm(*user)) {
-    if (user->operand_count() == 3 && user->operand(2) == operand) {
-      return true;
-    }
-  }
-  // The operand of cholesky can be shared with the first output.
-  if (user->opcode() == HloOpcode::kCustomCall &&
-      user->custom_call_target() == kCusolverCholeskyCallTarget) {
-    return user_index.size() == 1 && user_index[0] == 0;
-  }
-  return absl::nullopt;
-}
-
 // Modifies the given HLO module so that it will be accepted by IrEmitter.
 // Unlike optimization passes, the passes are necessary for correctness.
-Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
+Status impl::PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
   // In some cases, we have to place the result of an instruction in a temporary
   // buffer. For instance, the buffer that holds an external parameter is
   // assumed immutable at this point, and should not be reused for output
@@ -452,48 +495,6 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
   return pipeline.Run(hlo_module).status();
 }
 
-// Prints a warning if the ptx->sass JIT in the driver has known bugs.
-//
-// Using such a driver is only a problem if we fail to use ptxas to compile
-// our ptx and have to use the driver instead, so you should only call this
-// function if we're going to use the driver JIT.
-//
-// Only prints a warning the first time it's called.
-void WarnIfBadDriverJITVersion() {
-  static std::once_flag run_once;
-  std::call_once(run_once, [] {
-    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
-    if (!version_or_status.ok()) {
-      LOG(WARNING) << "Couldn't read CUDA driver version.";
-      return;
-    }
-    se::cuda::DriverVersion version = version_or_status.ValueOrDie();
-
-    // The following versions of the driver JIT miscompile some address
-    // calculations with large offsets (e.g. "load ptr + large_constant"),
-    // b/70245379:
-    //
-    //  - 384.x before 384.108
-    //  - 387.x before 387.40
-    //  - 390.x before 390.10.
-    //
-    // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the
-    // fix for the "large multioutput fusions" miscompile, b/111107644.
-    if (version < std::make_tuple(396, 20, 0)) {
-      LOG(WARNING)
-          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
-          << se::cuda::DriverVersionToString(version)
-          << ", which is older than 396.20.0. These versions are known to "
-             "miscompile XLA code, leading to incorrect results or "
-             "invalid-address errors.\nXLA only uses the driver JIT if it "
-             "cannot find ptxas; you don't need to update your driver if "
-             "you can point XLA to ptxas 9.2.88 or newer.";
-    }
-  });
-}
-
-}  // namespace
-
 NVPTXCompiler::NVPTXCompiler()
     : pointer_size_(llvm::DataLayout(nvptx::kDataLayout)
                         .getPointerSize(0 /* default address space */)) {}
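
The version gate in the block above (the same code now lives in the first hunk) leans on std::tuple's lexicographic operator<, which orders dotted versions correctly without manual digit juggling. A minimal standalone illustration with a hypothetical driver version:

    #include <iostream>
    #include <tuple>

    int main() {
      // Hypothetical driver version 390.48.0. Tuples compare lexicographically,
      // so (390, 48, 0) < (396, 20, 0) even though 48 > 20 in the middle slot.
      std::tuple<int, int, int> version{390, 48, 0};
      if (version < std::make_tuple(396, 20, 0)) {
        std::cout << "driver JIT predates 396.20.0; emit the warning\n";
      }
      return 0;
    }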
@@ -507,9 +508,9 @@ StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses(
       [&] { return absl::StrCat("HLO Transforms:", module->name()); },
       tensorflow::profiler::TraceMeLevel::kInfo);
   TF_RETURN_IF_ERROR(
-      OptimizeHloModule(module.get(), stream_exec, device_allocator));
+      impl::OptimizeHloModule(module.get(), stream_exec, device_allocator));
 
-  TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
+  TF_RETURN_IF_ERROR(impl::PrepareHloModuleForIrEmitting(module.get()));
 
   return std::move(module);
 }
@@ -38,6 +38,17 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+// Temporarily expose the optimization pipeline for the GPU backend for reuse
+// in the MLIR GPU backend.
+// TODO(b/137624192): Remove once MLIR backend uses tailored optimizations.
+namespace impl {
+
+Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
+                         se::DeviceMemoryAllocator* device_allocator);
+Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
+
+}  // namespace impl
+
 // The GPU compiler generates efficient GPU executables.
 class NVPTXCompiler : public LLVMCompiler {
  public:
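
Exposing the pipeline through a small, TODO-marked impl namespace (rather than, say, subclassing NVPTXCompiler) keeps the temporary surface easy to delete later. A hypothetical caller sketch (PrepareForCustomEmitter is an invented name; the mlir_compiler.cc hunk further below does exactly this for the MLIR backend):

    #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Any target that depends on the new nvptx_compiler_impl library can run
    // the shared pipeline without registering the full NVPTXCompiler.
    xla::Status PrepareForCustomEmitter(
        xla::HloModule* module, stream_executor::StreamExecutor* stream_exec,
        stream_executor::DeviceMemoryAllocator* allocator) {
      // Same optimizations the GPU backend runs...
      TF_RETURN_IF_ERROR(
          xla::gpu::impl::OptimizeHloModule(module, stream_exec, allocator));
      // ...then the correctness passes required before IR emission.
      return xla::gpu::impl::PrepareHloModuleForIrEmitting(module);
    }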
@@ -40,6 +40,7 @@ cc_library(
         ":failover_compiler",
         "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
         "//tensorflow/core:lib",
     ],
     alwayslink = True,  # Contains compiler registration
 )
@@ -17,6 +17,7 @@ limitations under the License.
 
+#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
 #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
 namespace mlir {
@@ -28,7 +29,15 @@ se::Platform::Id MlirCompiler::PlatformId() const {
 StatusOr<std::unique_ptr<HloModule>> MlirCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
-  return Unimplemented("Not yet implemented in MLIR compiler");
+  // Until we find a reason to do something different, run the same passes
+  // that the normal GPU backend runs.
+  TF_RETURN_IF_ERROR(xla::gpu::impl::OptimizeHloModule(
+      module.get(), stream_exec, device_allocator));
+
+  TF_RETURN_IF_ERROR(
+      xla::gpu::impl::PrepareHloModuleForIrEmitting(module.get()));
+
+  return std::move(module);
 }
 
 StatusOr<std::unique_ptr<Executable>> MlirCompiler::RunBackend(
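
RunHloPasses above now reads as straight-line code because each step early-returns on failure. A simplified sketch of what the TF_RETURN_IF_ERROR idiom does (the real macro in tensorflow/core/lib/core/errors.h is more elaborate; RETURN_IF_ERROR_SKETCH and Step are invented names):

    #include "tensorflow/core/lib/core/status.h"

    // Evaluate the expression once; early-return its Status if it is not OK.
    #define RETURN_IF_ERROR_SKETCH(expr)             \
      do {                                           \
        const ::tensorflow::Status _status = (expr); \
        if (!_status.ok()) return _status;           \
      } while (false)

    tensorflow::Status Step();  // some fallible step (assumed to exist)

    tensorflow::Status RunAll() {
      RETURN_IF_ERROR_SKETCH(Step());  // stops here if Step() fails
      return tensorflow::Status::OK();
    }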