Make use of same_shape and tensorflow abi knowledge propagation passes.
PiperOrigin-RevId: 341586499 Change-Id: Ifa6177fba122d53375b47ec69fa8a401d51582ac
This commit is contained in:
parent
e897624fac
commit
2c05a4a796
@ -198,32 +198,57 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
llvm::ArrayRef<std::string> architectures,
|
||||
bool generate_fatbin) {
|
||||
// Runs a pass pipeline over `module` consisting of two passes nested on
// mlir::FuncOp: CreatePropagateShapeKnowledgeToKernels followed by
// CreatePropagateTfAbiKnowledgeToKernels.
//
// Returns Status::OK() if the pipeline run succeeds; otherwise an
// InternalError with the message
// "Amending LLVMIR with static knowledge failed.".
Status AmendKernelLLVMIRWithStaticKnowledge(mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
// NOTE(review): helper defined outside this view; presumably configures the
// pass manager from TensorFlow / command-line options -- confirm.
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// Both passes are registered per function (nested on mlir::FuncOp), in this
// order.
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::transforms::CreatePropagateShapeKnowledgeToKernels());
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::kernel_gen::transforms::CreatePropagateTfAbiKnowledgeToKernels());
|
||||
|
||||
// Translate the MLIR pass-manager result into a Status.
return failed(pm.run(module))
|
||||
? InternalError("Amending LLVMIR with static knowledge failed.")
|
||||
: Status::OK();
|
||||
}
|
||||
|
||||
// Builds and runs a pass pipeline over `module` whose kernel-side passes are
// nested on mlir::gpu::GPUModuleOp.
//
// Parameters:
//   module               - the module the pipeline is run on.
//   gpu_binary_only      - when true, an extra pass nested on LLVMFuncOp
//                          (CreatePropagateTensorFlowABIKnowledgePass) is
//                          added, parameterized by `same_shape`.
//   same_shape           - forwarded to
//                          CreatePropagateTensorFlowABIKnowledgePass.
//   gpu_binary_attr_name - forwarded to CreateGpuKernelToBlobPass.
//   architectures        - forwarded to CreateGpuKernelToBlobPass.
//   generate_fatbin      - forwarded to CreateGpuKernelToBlobPass.
//
// Returns Status::OK() on success or an InternalError on failure.
//
// NOTE(review): this span appears to come from a rendered diff in which
// removed and added lines are interleaved without +/- markers. As written
// there are two unconditional `return` statements, so the second one
// ("Generating device code failed.") would be unreachable, and the
// `if (!gpu_binary_only)` section repeats the passes that
// LowerHostSideToFinalForm adds. Presumably the `if (!gpu_binary_only)`
// section and the first return are pre-change lines -- confirm against the
// actual file before relying on this text.
Status GenerateDeviceCode(mlir::ModuleOp module, bool gpu_binary_only,
|
||||
llvm::ArrayRef<uint32_t> same_shape,
|
||||
llvm::StringRef gpu_binary_attr_name,
|
||||
llvm::ArrayRef<std::string> architectures,
|
||||
bool generate_fatbin) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
// NOTE(review): helper defined outside this view; presumably configures the
// pass manager from TensorFlow / command-line options -- confirm.
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// Passes added through `kernel_pm` run on each gpu::GPUModuleOp inside the
// module.
auto& kernel_pm = pm.nest<mlir::gpu::GPUModuleOp>();
|
||||
// TODO(herhut): Remove this.
|
||||
if (gpu_binary_only) {
|
||||
// Grab the original signature from the single function.
|
||||
kernel_pm.addNestedPass<mlir::LLVM::LLVMFuncOp>(
|
||||
mlir::kernel_gen::transforms::CreatePropagateTensorFlowABIKnowledgePass(
|
||||
same_shape));
|
||||
}
|
||||
// Remove debug information to ensure we do not create debug PTX.
|
||||
kernel_pm.addPass(mlir::createStripDebugInfoPass());
|
||||
kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
|
||||
gpu_binary_attr_name, architectures, generate_fatbin));
|
||||
|
||||
// Module-level (not GPU-module-nested) passes, added only when more than the
// GPU binary is requested.
if (!gpu_binary_only) {
|
||||
pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
}
|
||||
return failed(pm.run(module)) ? InternalError("Lowering to LLVM IR failed.")
|
||||
: Status::OK();
|
||||
// NOTE(review): unreachable as shown, because the return above is
// unconditional.
return failed(pm.run(module))
|
||||
? InternalError("Generating device code failed.")
|
||||
: Status::OK();
|
||||
}
|
||||
|
||||
// Runs a module-level pass pipeline over `module` made of, in order:
// CreateTFKernelToLLVMPass, the canonicalizer pass, and the CSE pass.
//
// Returns Status::OK() if the pipeline run succeeds; otherwise an
// InternalError with the message "Final lowering of host side failed.".
Status LowerHostSideToFinalForm(mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
// NOTE(review): helper defined outside this view; presumably configures the
// pass manager from TensorFlow / command-line options -- confirm.
applyTensorflowAndCLOptions(pm);
|
||||
|
||||
// Lowering pass first, then the two cleanup passes (canonicalizer, CSE).
pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
|
||||
// Translate the MLIR pass-manager result into a Status.
return failed(pm.run(module))
|
||||
? InternalError("Final lowering of host side failed.")
|
||||
: Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -249,9 +274,13 @@ StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
|
||||
#elif GOOGLE_CUDA
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
#endif
|
||||
TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
|
||||
kGpuBinaryAttrName, architectures,
|
||||
generate_fatbin));
|
||||
TF_RETURN_IF_ERROR(AmendKernelLLVMIRWithStaticKnowledge(module.get()));
|
||||
TF_RETURN_IF_ERROR(GenerateDeviceCode(module.get(), gpu_binary_only,
|
||||
same_shape, kGpuBinaryAttrName,
|
||||
architectures, generate_fatbin));
|
||||
if (!gpu_binary_only) {
|
||||
TF_RETURN_IF_ERROR(LowerHostSideToFinalForm(module.get()));
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75
|
||||
|
||||
func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> attributes {tf_entry} {
|
||||
%0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user