Explicitly register the tensorflow dialect in cubin_creator.

The other used dialects are registered automatically by depending on
AllPassesAndDialects.

PiperOrigin-RevId: 314511426
Change-Id: I5beced8e299fde22fd62f6f2b15404e60115f1ee
This commit is contained in:
Adrian Kuegel 2020-06-03 05:12:41 -07:00 committed by TensorFlower Gardener
parent 2cbb43e1fc
commit d0c150d2e6
2 changed files with 12 additions and 1 deletions

View File

@ -21,6 +21,7 @@ cc_library(
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TargetNVVMIR",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -42,6 +43,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/NVVMIR.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
@ -216,14 +218,22 @@ Status PropagateStaticShapeKnowledgeToKernel(
}
return Status::OK();
}
void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
} // namespace
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
context.allowUnregisteredDialects(); // TODO(b/152572127)
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));