Run SCCP pass post TensorFlow shape inference pass
The TensorFlow shape inference pass no longer materializes constants found through shape propagation, for performance reasons, but some legalization-to-HLO patterns rely on such constants. See the added test for an example: without the SCCP pass, the constant computed in the callee function is never propagated into the main function. PiperOrigin-RevId: 359012814 Change-Id: I9bffaf0fb823dd54e9f9c44102fde50f99d7761d
This commit is contained in:
parent
bc9079ea41
commit
ee6fa8155e
tensorflow/compiler/mlir/tensorflow
@ -0,0 +1,26 @@
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

licenses(["notice"])

# Creates one lit test target per test file in this directory (one per
# .mlir and .pbtxt file), all driven by MLIR's run_lit.sh.
glob_lit_tests(
    data = [
        ":test_utilities",
    ],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = [
        "mlir",
        "pbtxt",
    ],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        # tf-opt runs the passes under test; FileCheck/not verify the output.
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
    ],
)
|
@ -0,0 +1,26 @@
|
||||
// RUN: tf-opt -tf-to-hlo-pipeline %s | FileCheck %s

// Verifies that constants generated post shape inference are propagated.
// Shape inference does not materialize the constants it computes internally,
// so a later SCCP pass must propagate the constant @get_shape result into
// @main for legalization to succeed — this test checks that propagation.
module attributes {tf.versions = {producer = 179 : i32}} {

  // CHECK-LABEL: func @main
  func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi64> {
    %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
    %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<?x?xf32>

    // The callee's shape result is expected to be folded to the constant
    // shape of %1 (i.e. [10, 19]) and propagated back into @main.
    // CHECK: %[[RESULT:.*]] = mhlo.constant dense<[10, 19]>
    %2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @get_shape} : (tensor<?x?xf32>) -> (tensor<?xi64>)

    // CHECK: return %[[RESULT]]
    return %2 : tensor<?xi64>
  }

  // CHECK-LABEL: func @get_shape
  func @get_shape(%arg0 : tensor<*xi64>) -> tensor<?xi64> {
    %0 = "tf.Shape"(%arg0) : (tensor<*xi64>) -> tensor<?xi64>
    return %0 : tensor<?xi64>
  }

}
|
||||
|
@ -284,12 +284,22 @@ void CreateConvertMlirToXlaHloPipeline(
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
// The SCCP pass performs constant propagation across the IR, which, for
// example, propagates constant arguments into callee functions.
// TODO(hinsu): Investigate if we really need SCCP pass before shape inference
// and can do with just one pass after the shape inference.
pm.addPass(mlir::createSCCPPass());
// Guarantee all functions have one use, which enables shape inference.
pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
// Run shape inference pass before tensorlist decomposition to get buffer
// shape of uninitialized TensorLists.
pm.addPass(mlir::TF::CreateTFShapeInferencePass());

// Run SCCP pass again as the availability of shapes may open up new
// opportunities for constant propagation. Note that the shape inference pass
// doesn't materialize new constants even if those are computed internally for
// the purpose of shape inference. These constants might be required by the
// legalization passes.
pm.addPass(mlir::createSCCPPass());

pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
|
||||
|
Loading…
Reference in New Issue
Block a user