Explicitly register the tensorflow dialect in cubin_creator.
The other used dialects are registered automatically by depending on AllPassesAndDialects. PiperOrigin-RevId: 314511426 Change-Id: I5beced8e299fde22fd62f6f2b15404e60115f1ee
This commit is contained in:
parent
2cbb43e1fc
commit
d0c150d2e6
@ -21,6 +21,7 @@ cc_library(
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TargetNVVMIR",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -42,6 +43,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Target/NVVMIR.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
|
||||
@ -216,14 +218,22 @@ Status PropagateStaticShapeKnowledgeToKernel(
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
|
||||
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
|
||||
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::ArrayRef<uint32_t> unroll_factors) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext context;
|
||||
context.allowUnregisteredDialects(); // TODO(b/152572127)
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
|
||||
|
Loading…
Reference in New Issue
Block a user