Configure LLVMDialect and extract pointer size.
PiperOrigin-RevId: 258361644
This commit is contained in:
parent
b542d67952
commit
7e8efb0fb9
@ -39,8 +39,13 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":failover_compiler",
|
":failover_compiler",
|
||||||
"//tensorflow/compiler/xla/service:compiler",
|
"//tensorflow/compiler/xla/service:compiler",
|
||||||
|
"//tensorflow/compiler/xla/service/gpu:gpu_constants",
|
||||||
|
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
|
||||||
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
|
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
|
||||||
|
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@local_config_mlir//:IR",
|
||||||
|
"@local_config_mlir//:LLVMDialect",
|
||||||
],
|
],
|
||||||
alwayslink = True, # Contains compiler registration
|
alwayslink = True, # Contains compiler registration
|
||||||
)
|
)
|
||||||
|
@ -15,13 +15,33 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
|
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
|
||||||
|
|
||||||
|
#include "mlir/LLVMIR/LLVMDialect.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
|
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
|
||||||
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
|
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
using ::mlir::MLIRContext;
|
||||||
|
using ::mlir::LLVM::LLVMDialect;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) {
|
||||||
|
LLVMDialect* dialect = context->getRegisteredDialect<LLVMDialect>();
|
||||||
|
llvm::Module& module = dialect->getLLVMModule();
|
||||||
|
module.setTargetTriple(gpu::nvptx::kTargetTriple);
|
||||||
|
module.setDataLayout(gpu::nvptx::kDataLayout);
|
||||||
|
return module.getDataLayout().getPointerSize();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MlirCompiler::MlirCompiler()
|
||||||
|
: pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {}
|
||||||
|
|
||||||
se::Platform::Id MlirCompiler::PlatformId() const {
|
se::Platform::Id MlirCompiler::PlatformId() const {
|
||||||
return stream_executor::cuda::kCudaPlatformId;
|
return stream_executor::cuda::kCudaPlatformId;
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -26,7 +27,7 @@ namespace mlir {
|
|||||||
// generation of a think suitable for XLAs runtime.
|
// generation of a think suitable for XLAs runtime.
|
||||||
class MlirCompiler : public Compiler {
|
class MlirCompiler : public Compiler {
|
||||||
public:
|
public:
|
||||||
MlirCompiler() {}
|
MlirCompiler();
|
||||||
|
|
||||||
se::Platform::Id PlatformId() const override;
|
se::Platform::Id PlatformId() const override;
|
||||||
|
|
||||||
@ -58,12 +59,15 @@ class MlirCompiler : public Compiler {
|
|||||||
const AotCompilationOptions& options) override;
|
const AotCompilationOptions& options) override;
|
||||||
|
|
||||||
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
|
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
|
||||||
// TODO(herhut): Get this from the LLVMDialect in MLIR.
|
int64 pointer_size = pointer_size_;
|
||||||
int64 pointer_size = 8;
|
|
||||||
return [pointer_size](const Shape& shape) {
|
return [pointer_size](const Shape& shape) {
|
||||||
return ShapeUtil::ByteSizeOf(shape, pointer_size);
|
return ShapeUtil::ByteSizeOf(shape, pointer_size);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
::mlir::MLIRContext context_;
|
||||||
|
int64 pointer_size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
Loading…
x
Reference in New Issue
Block a user