From ab17c2f07cfeb1d1956108afbcd543709d489f55 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sat, 14 Sep 2019 18:28:47 -0700 Subject: [PATCH] Use DialectConversion framework in legalize_tf pass instead of GreedyPatternRewriter In legalize_tf pass, we have patterns for TF to HLO lowerings as well as TF to TF lowerings. In case of multiple applicable patterns, lowering to TF prematurely may lose some optimization opportunities. Switching to DialectConversion makes sure that direct lowering from TF to HLO are preferred. Fixed tests to directly test the pass output and not the canonicalized output that the greedy rewriter produces. PiperOrigin-RevId: 269125599 --- .../tensorflow/utils/compile_mlir_util.cc | 4 ++-- tensorflow/compiler/mlir/xla/BUILD | 1 + .../compiler/mlir/xla/tests/legalize-tf.mlir | 12 +++++----- .../mlir/xla/transforms/legalize_tf.cc | 22 +++++++++---------- .../xla/transforms/legalize_tf_patterns.td | 8 +++++-- .../compiler/mlir/xla/transforms/passes.h | 3 ++- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 648aeef36da..a1fa728b16e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -150,8 +150,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, // error reporting system. Report a generic error if pass manager failed // without emitting a diagnostic. mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext()); - mlir::xla_hlo::legalizeTF(module_op); - if (!error_handler.ok()) { + mlir::LogicalResult result = mlir::xla_hlo::legalizeTF(module_op); + if (failed(result)) { return error_handler.Combine( errors::Internal("MLIR TF to XLA legalization failed")); } diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index bcede30ea73..b5fb23ea9c2 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -94,6 +94,7 @@ cc_library( "@local_config_mlir//:Pass", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 2ce1fdac717..75163bfbb8c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -321,7 +321,7 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-NEXT: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0> : tensor<1xi32>} - // CHECK-NEXT: xla_hlo.max %arg0, %[[ZERO]] : tensor<1xi32> + // CHECK-NEXT: xla_hlo.max %[[ZERO]], %arg0 : tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -342,10 +342,9 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @simple_softmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor} - // CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor} // Verify reduce op for max computation and its body. + // CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor} // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]]) // CHECK: xla_hlo.max // CHECK: "xla_hlo.return" @@ -353,14 +352,17 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]]) + // CHECK: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // Verify reduce op for summation and its body. - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]]) + // CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor} + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) // CHECK: xla_hlo.add // CHECK: "xla_hlo.return" // CHECK: {dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 8606519fcd2..c00f314fad1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" @@ -369,31 +370,30 @@ class ConvertSoftmaxOp : public OpRewritePattern { } // end namespace xla } // end namespace mlir -void mlir::xla_hlo::legalizeTF(Operation *op) { +LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) { + MLIRContext *context = op->getContext(); + // Add lowering patterns to the list. OwningRewritePatternList patterns; - xla::populateWithGenerated(op->getContext(), &patterns); + xla::populateWithGenerated(context, &patterns); // Add patterns that lower some of the high level TensorFlow ops to lower // level TensorFlow ops. So, we don't have to target all the TensorFlow ops // here for lowering to HLO. - // - // TODO(b/140964075): Switch to DialectConversion to avoid premature lowering - // to lower level TensorFlow ops if we actually want to target the higher - // level TensorFlow op directly. - mlir::TF::PopulateLoweringTFPatterns(op->getContext(), &patterns); + mlir::TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert(op->getContext()); patterns.insert(op->getContext()); - // Recursively applies rewrite patterns to nested operations. - applyPatternsGreedily(op, patterns); + ConversionTarget target(*context); + target.addLegalDialect(); + + return applyPartialConversion(op, target, patterns); } /// Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - auto func = getFunction(); - mlir::xla_hlo::legalizeTF(func); + if (failed(mlir::xla_hlo::legalizeTF(getFunction()))) signalPassFailure(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 5ae81711f71..67746fd8352 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -119,11 +119,15 @@ def GetHLOAxisFromTFAxis : NativeCodeCall< def HasRankedFirstOperand : ConstraintgetType().isa()">>; -// Here, we convert from TensorFlow axis format to HLO axis format which +// This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. For this // conversion, use the first input to get inputs rank. Other inputs need not be // ranked. -def : Pat<(TF_ConcatV2Op $inputs, (HLO_ConstOp OneElementAttr:$axis), $unused), +// Defining op for `axis` is TensorFlow constant op in the pattern as during +// the conversion, original Concat op operands still refers to the old ops even +// if HLO constant op is introduced as an replacement for the TensorFlow +// Constant op. +def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused), (HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxis $axis, $inputs)), [(HasRankedFirstOperand $inputs)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 7545a558926..12370b20557 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir namespace mlir { @@ -36,7 +37,7 @@ std::unique_ptr> createLegalizeTFPass(); /// Converts the provided Operation as well as all nested operations into XLA /// dialect using the conversion patterns registered by the XLA dialect. -void legalizeTF(Operation* op); +LogicalResult legalizeTF(Operation* op); /// Lowers XLA control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass();