diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index fb880d56417..72ca402427e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -39,8 +39,13 @@ cc_library( deps = [ ":failover_compiler", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_constants", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl", + "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/core:lib", + "@local_config_mlir//:IR", + "@local_config_mlir//:LLVMDialect", ], alwayslink = True, # Contains compiler registration ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index 8360bfc3da2..5421a3ae093 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -15,13 +15,33 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "mlir/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace mlir { +using ::mlir::MLIRContext; +using ::mlir::LLVM::LLVMDialect; + +namespace { +int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { + LLVMDialect* dialect = context->getRegisteredDialect(); + llvm::Module& module = dialect->getLLVMModule(); + module.setTargetTriple(gpu::nvptx::kTargetTriple); + module.setDataLayout(gpu::nvptx::kDataLayout); + return module.getDataLayout().getPointerSize(); +} +} // namespace + +MlirCompiler::MlirCompiler() + : pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {} + se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index ee0c372bb00..f02164c4d24 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "tensorflow/compiler/xla/service/compiler.h" namespace xla { @@ -26,7 +27,7 @@ namespace mlir { // generation of a think suitable for XLAs runtime. class MlirCompiler : public Compiler { public: - MlirCompiler() {} + MlirCompiler(); se::Platform::Id PlatformId() const override; @@ -58,12 +59,15 @@ class MlirCompiler : public Compiler { const AotCompilationOptions& options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - // TODO(herhut): Get this from the LLVMDialect in MLIR. - int64 pointer_size = 8; + int64 pointer_size = pointer_size_; return [pointer_size](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, pointer_size); }; } + + private: + ::mlir::MLIRContext context_; + int64 pointer_size_; }; } // namespace mlir