Configure LLVMDialect and extract pointer size.
PiperOrigin-RevId: 258361644
commit 7e8efb0fb9
parent b542d67952
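In short, the change configures the LLVM dialect's module for the NVPTX target and derives the pointer size from its data layout instead of hard-coding 8. A minimal standalone sketch of that flow, assuming the pre-LLVM-monorepo MLIR API used here (LLVMDialect owns an llvm::Module); the helper name and the literal triple/data-layout strings are illustrative stand-ins for gpu::nvptx::kTargetTriple and gpu::nvptx::kDataLayout:

#include "llvm/IR/Module.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/LLVMIR/LLVMDialect.h"  // TF:local_config_mlir

// Sketch: configure the dialect-owned LLVM module for NVPTX and read the
// pointer size (in bytes) back from the resulting data layout.
int64_t PointerSizeForNvptx(mlir::MLIRContext* context) {
  auto* dialect = context->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
  llvm::Module& module = dialect->getLLVMModule();
  module.setTargetTriple("nvptx64-nvidia-cuda");  // stand-in for kTargetTriple
  module.setDataLayout("e-i64:64-i128:128-v16:16-v32:32-n16:32:64");  // stand-in for kDataLayout
  return module.getDataLayout().getPointerSize();  // 8 on nvptx64
}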
tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -39,8 +39,13 @@ cc_library(
    deps = [
        ":failover_compiler",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service/gpu:gpu_constants",
        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
        "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
        "//tensorflow/compiler/xla/service/gpu:target_constants",
        "//tensorflow/core:lib",
        "@local_config_mlir//:IR",
        "@local_config_mlir//:LLVMDialect",
    ],
    alwayslink = True,  # Contains compiler registration
)
tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
@@ -15,13 +15,33 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"

#include "mlir/LLVMIR/LLVMDialect.h"  // TF:local_config_mlir
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace mlir {

using ::mlir::MLIRContext;
using ::mlir::LLVM::LLVMDialect;

namespace {
int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) {
  LLVMDialect* dialect = context->getRegisteredDialect<LLVMDialect>();
  llvm::Module& module = dialect->getLLVMModule();
  module.setTargetTriple(gpu::nvptx::kTargetTriple);
  module.setDataLayout(gpu::nvptx::kDataLayout);
  return module.getDataLayout().getPointerSize();
}
}  // namespace

MlirCompiler::MlirCompiler()
    : pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {}

se::Platform::Id MlirCompiler::PlatformId() const {
  return stream_executor::cuda::kCudaPlatformId;
}
tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_

+#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/xla/service/compiler.h"

 namespace xla {
@@ -26,7 +27,7 @@ namespace mlir {
 // generation of a thunk suitable for XLAs runtime.
 class MlirCompiler : public Compiler {
  public:
-  MlirCompiler() {}
+  MlirCompiler();

   se::Platform::Id PlatformId() const override;

@@ -58,12 +59,15 @@ class MlirCompiler : public Compiler {
       const AotCompilationOptions& options) override;

   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
-    // TODO(herhut): Get this from the LLVMDialect in MLIR.
-    int64 pointer_size = 8;
+    int64 pointer_size = pointer_size_;
     return [pointer_size](const Shape& shape) {
       return ShapeUtil::ByteSizeOf(shape, pointer_size);
     };
   }

+ private:
+  ::mlir::MLIRContext context_;
+  int64 pointer_size_;
 };

 }  // namespace mlir
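For reference, the effect on callers: the lambda returned by ShapeSizeBytesFunction() now closes over the pointer width read from the LLVM data layout rather than a hard-coded 8. A hypothetical usage sketch (the Example function and the f32[16,16] shape are illustrative, not part of this change):

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Any xla::Compiler exposes ShapeSizeBytesFunction(); for MlirCompiler the
// captured pointer size now comes from the NVPTX data layout.
void Example(const xla::Compiler& compiler) {
  auto shape_size = compiler.ShapeSizeBytesFunction();
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {16, 16});
  int64_t bytes = shape_size(shape);  // 16 * 16 * 4 = 1024 bytes.
  // The pointer width only matters for tuple-shaped buffers, whose top-level
  // representation is a table of pointers to the element buffers.
  (void)bytes;
}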