[MLIR][KERNEL_GEN] Drop test from TF Kernel conversion passes.
These passes will be used in prod, so they should not be called test passes. PiperOrigin-RevId: 330929659 Change-Id: Ifd2d48862563d5d0e1ed829fc6b929c5a97f74fd
This commit is contained in:
parent
7519f88d21
commit
7021444afe
@ -139,7 +139,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
|
||||
// Embed TF Framework ops.
|
||||
if (!gpu_binary_only) {
|
||||
pm.addPass(mlir::kernel_gen::tf_framework::createEmbedTFFrameworkPass());
|
||||
pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass());
|
||||
}
|
||||
|
||||
// Some basic cleanup.
|
||||
@ -190,12 +190,10 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
gpu_binary_attr_name, compute_capability));
|
||||
|
||||
if (!gpu_binary_only) {
|
||||
pm.addPass(mlir::kernel_gen::tf_framework::
|
||||
createTestTFFrameworkLegalizeToLLVMPass());
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
}
|
||||
|
||||
return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.")
|
||||
: Status::OK();
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@ cc_library(
|
||||
"bufferize_pass.cc",
|
||||
"embed_tf_framework_pass.cc",
|
||||
"shape_to_descriptors_pass.cc",
|
||||
"tf_framework_legalize_to_llvm_pass.cc",
|
||||
"tf_kernel_to_llvm_pass.cc",
|
||||
],
|
||||
hdrs = ["passes.h"],
|
||||
deps = [
|
||||
|
||||
@ -72,7 +72,7 @@ class EmbedTFFrameworkPass
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp> > createEmbedTFFrameworkPass() {
|
||||
std::unique_ptr<OperationPass<ModuleOp> > CreateEmbedTFFrameworkPass() {
|
||||
return std::make_unique<EmbedTFFrameworkPass>();
|
||||
}
|
||||
|
||||
|
||||
@ -25,20 +25,19 @@ namespace mlir {
|
||||
namespace kernel_gen {
|
||||
namespace tf_framework {
|
||||
|
||||
// Test pass for applying TF Framework -> LLVM patterns.
|
||||
std::unique_ptr<OperationPass<ModuleOp> >
|
||||
createTestTFFrameworkLegalizeToLLVMPass();
|
||||
|
||||
// Pass to replace some of the Standard ops with TF Framework ops.
|
||||
// * adds tf_framework::OpKernelContextType argument to the function
|
||||
// * std.alloc becomes tf_framework.alloc_raw
|
||||
// * std.dealloc becomes tf_framework.dealloc_raw
|
||||
std::unique_ptr<OperationPass<ModuleOp> > createEmbedTFFrameworkPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp> > CreateEmbedTFFrameworkPass();
|
||||
|
||||
} // namespace tf_framework
|
||||
|
||||
namespace transforms {
|
||||
|
||||
// Pass for applying LLVM legalization patterns.
|
||||
std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass();
|
||||
|
||||
// Pass to tranform shape computations in shape dialect to standard and scf
|
||||
// using memref descriptors.
|
||||
std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
|
||||
|
||||
@ -18,23 +18,22 @@ limitations under the License.
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TestTFFrameworkLegalizeToLLVMPass
|
||||
: Pass<"test-tf-framework-legalize-to-llvm", "ModuleOp"> {
|
||||
let summary = "Test pass for applying TF Framework -> LLVM patterns.";
|
||||
let constructor = "tf_framework::createTestTFFrameworkLegalizeToLLVMPass()";
|
||||
def TFKernelToLLVMPass : Pass<"tf-kernel-to-llvm", "ModuleOp"> {
|
||||
let summary = "Pass for applying LLVM legalization patterns.";
|
||||
let constructor = "transforms::CreateTFKernelToLLVMPass()";
|
||||
}
|
||||
|
||||
def EmbedTFFrameworkPass : Pass<"embed-tf-framework", "ModuleOp"> {
|
||||
let summary = "Pass to embed TF Framework for allocation and error reporting";
|
||||
let constructor = "tf_framework::createEmbedTFFrameworkPass()";
|
||||
let constructor = "tf_framework::CreateEmbedTFFrameworkPass()";
|
||||
}
|
||||
|
||||
def ShapeToDescriptorsPass : Pass<"test-shape-to-descriptors", "ModuleOp"> {
|
||||
def ShapeToDescriptorsPass : Pass<"shape-to-descriptors", "ModuleOp"> {
|
||||
let summary = "Pass to transform shape computations to descriptors";
|
||||
let constructor = "transforms::CreateShapeToDescriptorsPass()";
|
||||
}
|
||||
|
||||
def BufferizePass : Pass<"test-bufferize", "ModuleOp"> {
|
||||
def BufferizePass : Pass<"bufferize", "ModuleOp"> {
|
||||
let summary = "Pass to transform operations on values to buffer based ones";
|
||||
let constructor = "transforms::CreateBufferizePass()";
|
||||
}
|
||||
|
||||
@ -27,14 +27,13 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
namespace kernel_gen {
|
||||
namespace tf_framework {
|
||||
namespace transforms {
|
||||
namespace {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
|
||||
|
||||
class TestTFFrameworkToLLVMPass
|
||||
: public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
|
||||
class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
@ -52,7 +51,8 @@ class TestTFFrameworkToLLVMPass
|
||||
// Populate patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
populateStdToLLVMConversionPatterns(type_converter, patterns);
|
||||
PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, &patterns);
|
||||
tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
|
||||
&patterns);
|
||||
populateGpuToLLVMConversionPatterns(type_converter, patterns, "gpu.binary");
|
||||
lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns);
|
||||
|
||||
@ -71,11 +71,10 @@ class TestTFFrameworkToLLVMPass
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp> >
|
||||
createTestTFFrameworkLegalizeToLLVMPass() {
|
||||
return std::make_unique<TestTFFrameworkToLLVMPass>();
|
||||
std::unique_ptr<OperationPass<ModuleOp> > CreateTFKernelToLLVMPass() {
|
||||
return std::make_unique<TFKernelToLLVMPass>();
|
||||
}
|
||||
|
||||
} // namespace tf_framework
|
||||
} // namespace transforms
|
||||
} // namespace kernel_gen
|
||||
} // namespace mlir
|
||||
Loading…
x
Reference in New Issue
Block a user