diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 648aeef36da..a1fa728b16e 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -150,8 +150,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
   // error reporting system. Report a generic error if pass manager failed
   // without emitting a diagnostic.
   mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());
-  mlir::xla_hlo::legalizeTF(module_op);
-  if (!error_handler.ok()) {
+  mlir::LogicalResult result = mlir::xla_hlo::legalizeTF(module_op);
+  if (failed(result)) {
     return error_handler.Combine(
         errors::Internal("MLIR TF to XLA legalization failed"));
   }
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index bcede30ea73..b5fb23ea9c2 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -94,6 +94,7 @@ cc_library(
         "@local_config_mlir//:Pass",
         "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
+        "@local_config_mlir//:Transforms",
    ],
    alwayslink = 1,
)
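Aside (not part of the patch): the compile_mlir_util.cc hunk above switches failure detection from "did the diagnostic handler record anything?" to the explicit mlir::LogicalResult now returned by legalizeTF. A minimal sketch of the resulting calling convention, assuming TF's existing StatusScopedDiagnosticHandler helper from error_util.h; `LegalizeOrStatus` is a hypothetical wrapper name:

```cpp
#include "mlir/IR/Module.h"              // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical wrapper illustrating the new contract: failure is reported
// through LogicalResult, while diagnostics still flow into the handler.
tensorflow::Status LegalizeOrStatus(mlir::ModuleOp module_op) {
  mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());
  if (mlir::failed(mlir::xla_hlo::legalizeTF(module_op))) {
    // Combine() folds any captured diagnostics into the returned status and
    // falls back to this generic error if the pass failed silently.
    return error_handler.Combine(
        tensorflow::errors::Internal("MLIR TF to XLA legalization failed"));
  }
  return error_handler.ConsumeStatus();
}
```

Previously, a pattern failure that emitted no diagnostic would slip past `error_handler.ok()`; returning a LogicalResult makes such silent failures detectable.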
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 2ce1fdac717..75163bfbb8c 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -321,7 +321,7 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
 // CHECK-LABEL: func @relu
 func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
   // CHECK-NEXT: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0> : tensor<1xi32>}
-  // CHECK-NEXT: xla_hlo.max %arg0, %[[ZERO]] : tensor<1xi32>
+  // CHECK-NEXT: xla_hlo.max %[[ZERO]], %arg0 : tensor<1xi32>
   %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
   return %0: tensor<1xi32>
 }
@@ -342,10 +342,9 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
 // CHECK-LABEL: func @simple_softmax
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
 func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
-  // CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
 
   // Verify reduce op for max computation and its body.
+  // CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
   // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]])
   // CHECK: xla_hlo.max
   // CHECK: "xla_hlo.return"
@@ -353,14 +352,17 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
   // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
   // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]])
+  // CHECK: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
 
   // Verify reduce op for summation and its body.
-  // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]])
+  // CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]])
   // CHECK: xla_hlo.add
   // CHECK: "xla_hlo.return"
   // CHECK: {dimensions = dense<1> : tensor<1xi64>}
+  // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
 
-  // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
   // return %[[RESULT]]
 
   %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 8606519fcd2..c00f314fad1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Transforms/DialectConversion.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
@@ -369,31 +370,30 @@ class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
 }  // end namespace xla
 }  // end namespace mlir
 
-void mlir::xla_hlo::legalizeTF(Operation *op) {
+LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) {
+  MLIRContext *context = op->getContext();
+
   // Add lowering patterns to the list.
   OwningRewritePatternList patterns;
-  xla::populateWithGenerated(op->getContext(), &patterns);
+  xla::populateWithGenerated(context, &patterns);
 
   // Add patterns that lower some of the high level TensorFlow ops to lower
   // level TensorFlow ops. So, we don't have to target all the TensorFlow ops
   // here for lowering to HLO.
-  //
-  // TODO(b/140964075): Switch to DialectConversion to avoid premature lowering
-  // to lower level TensorFlow ops if we actually want to target the higher
-  // level TensorFlow op directly.
-  mlir::TF::PopulateLoweringTFPatterns(op->getContext(), &patterns);
+  mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
   patterns.insert<mlir::xla::ConvertMaxPoolOp>(op->getContext());
   patterns.insert<mlir::xla::ConvertSoftmaxOp>(op->getContext());
 
-  // Recursively applies rewrite patterns to nested operations.
-  applyPatternsGreedily(op, patterns);
+  ConversionTarget target(*context);
+  target.addLegalDialect<XlaHloDialect>();
+
+  return applyPartialConversion(op, target, patterns);
 }
 
 /// Performs the lowering to XLA dialect.
 void LegalizeTF::runOnFunction() {
-  auto func = getFunction();
-  mlir::xla_hlo::legalizeTF(func);
+  if (failed(mlir::xla_hlo::legalizeTF(getFunction()))) signalPassFailure();
 }
 
 static PassRegistration<LegalizeTF> pass(
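Aside (not part of the patch): the legalize_tf.cc hunk above is the heart of this change, replacing greedy pattern application with the DialectConversion framework and thereby resolving the deleted TODO(b/140964075). A sketch of the idiom under the MLIR API contemporary with this change (circa 2019; spellings have since evolved); `LegalizeWithTarget` is a hypothetical name:

```cpp
#include "mlir/Transforms/DialectConversion.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

// Hypothetical helper showing the ConversionTarget-driven flow used above.
static mlir::LogicalResult LegalizeWithTarget(
    mlir::Operation *op, mlir::OwningRewritePatternList &patterns) {
  mlir::ConversionTarget target(*op->getContext());
  // Ops already in the XLA HLO dialect are legal, i.e. fixed points that the
  // driver never asks patterns to rewrite again.
  target.addLegalDialect<mlir::xla_hlo::XlaHloDialect>();
  // Partial (rather than full) conversion: ops with no declared legality,
  // such as standard-dialect ops, are simply left in place, so this pass
  // composes with later lowerings; the conversion fails when an op that a
  // pattern rewrote cannot be brought into a legal state.
  return mlir::applyPartialConversion(op, target, patterns);
}
```

Unlike applyPatternsGreedily, which silently stops when no pattern applies, applyPartialConversion returns a LogicalResult, which is what lets legalizeTF propagate failure to the pass and ultimately to compile_mlir_util.cc.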
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 5ae81711f71..67746fd8352 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -119,11 +119,15 @@ def GetHLOAxisFromTFAxis : NativeCodeCall<
 def HasRankedFirstOperand
     : Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
 
 // Here, we convert from TensorFlow axis format to HLO axis format which
 // doesn't wrap around like TensorFlow and is always positive. For this
 // conversion, use the first input to get inputs rank. Other inputs need not be
 // ranked.
-def : Pat<(TF_ConcatV2Op $inputs, (HLO_ConstOp OneElementAttr:$axis), $unused),
+// The defining op for `axis` is matched as a TensorFlow constant op because,
+// during the conversion, the original Concat op operands still refer to the
+// old ops even if an HLO constant op has been introduced as a replacement
+// for the TensorFlow constant op.
+def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused),
           (HLO_ConcatenateOp $inputs,
              (GetHLOAxisFromTFAxis $axis, $inputs)),
           [(HasRankedFirstOperand $inputs)]>;
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index 7545a558926..12370b20557 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
+#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 
 namespace mlir {
 
@@ -36,7 +37,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass();
 
 /// Converts the provided Operation as well as all nested operations into XLA
 /// dialect using the conversion patterns registered by the XLA dialect.
-void legalizeTF(Operation* op);
+LogicalResult legalizeTF(Operation* op);
 
 /// Lowers XLA control flow ops to the Standard dialect.
 std::unique_ptr<OpPassBase<FuncOp>> createLegalizeControlFlowPass();
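Aside (not part of the patch): the HLO_ConstOp-to-TF_ConstOp change in legalize_tf_patterns.td above falls out of the conversion framework's snapshot semantics. A hand-written equivalent of what that DRR pattern observes mid-conversion; `ConcatAxisIsConstant` is a hypothetical name, and accessor spellings follow the MLIR/TF-dialect API of this era:

```cpp
#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// While a dialect conversion is in flight, operands of a not-yet-rewritten
// op still point at their original defining ops; the xla_hlo.constant that
// replaces a tf.Const only becomes visible once the conversion commits. So
// the pattern must match tf.Const here, not xla_hlo.constant.
static bool ConcatAxisIsConstant(mlir::TF::ConcatV2Op op) {
  mlir::DenseIntElementsAttr axis;
  return mlir::matchPattern(op.axis(), mlir::m_Constant(&axis));
}
```

Under the previous greedy driver, replacements were materialized immediately, which is why the old pattern could (and had to) match HLO_ConstOp instead.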