Use DialectConversion framework in legalize_tf pass instead of GreedyPatternRewriter

The legalize_tf pass has patterns for TF to HLO lowerings as well as TF to TF lowerings. When multiple patterns apply, lowering to TF prematurely may lose optimization opportunities. Switching to DialectConversion ensures that direct lowerings from TF to HLO are preferred.
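
A minimal sketch of the change in driver, using only calls that appear in this commit; `op` is the operation being legalized and `buildPatterns` is a hypothetical stand-in for the pattern population done in legalize_tf.cc:

  MLIRContext *context = op->getContext();
  OwningRewritePatternList patterns;
  buildPatterns(context, &patterns);  // hypothetical: populateWithGenerated etc.

  // Before: the greedy rewriter applies any matching pattern immediately, so a
  // TF->TF lowering can fire even when a direct TF->HLO pattern also matches.
  applyPatternsGreedily(op, patterns);

  // After: only HLO ops are declared legal, so the conversion driver keeps
  // rewriting until the result is legal, and direct TF->HLO lowerings are
  // preferred over premature TF->TF lowerings (as described above).
  ConversionTarget target(*context);
  target.addLegalDialect<XlaHloDialect>();
  LogicalResult result = applyPartialConversion(op, target, patterns);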

Fixed tests to check the pass output directly rather than the canonicalized output that the greedy rewriter used to produce.

PiperOrigin-RevId: 269125599
Smit Hinsu 2019-09-14 18:28:47 -07:00 committed by TensorFlower Gardener
parent 2179d24c5e
commit ab17c2f07c
6 changed files with 29 additions and 21 deletions

View File

@@ -150,8 +150,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// error reporting system. Report a generic error if pass manager failed
// without emitting a diagnostic.
mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());
mlir::xla_hlo::legalizeTF(module_op);
if (!error_handler.ok()) {
mlir::LogicalResult result = mlir::xla_hlo::legalizeTF(module_op);
if (failed(result)) {
return error_handler.Combine(
errors::Internal("MLIR TF to XLA legalization failed"));
}

View File

@@ -94,6 +94,7 @@ cc_library(
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
],
alwayslink = 1,
)

View File

@@ -321,7 +321,7 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK-LABEL: func @relu
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-NEXT: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0> : tensor<1xi32>}
// CHECK-NEXT: xla_hlo.max %arg0, %[[ZERO]] : tensor<1xi32>
// CHECK-NEXT: xla_hlo.max %[[ZERO]], %arg0 : tensor<1xi32>
%0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
return %0: tensor<1xi32>
}
@@ -342,10 +342,9 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-LABEL: func @simple_softmax
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
// CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
// Verify reduce op for max computation and its body.
// CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
// CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]])
// CHECK: xla_hlo.max
// CHECK: "xla_hlo.return"
@@ -353,14 +352,17 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]])
// CHECK: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// Verify reduce op for summation and its body.
// CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]])
// CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]])
// CHECK: xla_hlo.add
// CHECK: "xla_hlo.return"
// CHECK: {dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// return %[[RESULT]]
%0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
@@ -369,31 +370,30 @@ class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
} // end namespace xla
} // end namespace mlir
void mlir::xla_hlo::legalizeTF(Operation *op) {
LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) {
MLIRContext *context = op->getContext();
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
xla::populateWithGenerated(op->getContext(), &patterns);
xla::populateWithGenerated(context, &patterns);
// Add patterns that lower some of the high level TensorFlow ops to lower
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
// here for lowering to HLO.
//
// TODO(b/140964075): Switch to DialectConversion to avoid premature lowering
// to lower level TensorFlow ops if we actually want to target the higher
// level TensorFlow op directly.
mlir::TF::PopulateLoweringTFPatterns(op->getContext(), &patterns);
mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
patterns.insert<mlir::xla::ConvertMaxPoolOp>(op->getContext());
patterns.insert<mlir::xla::ConvertSoftmaxOp>(op->getContext());
// Recursively applies rewrite patterns to nested operations.
applyPatternsGreedily(op, patterns);
ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();
return applyPartialConversion(op, target, patterns);
}
/// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
auto func = getFunction();
mlir::xla_hlo::legalizeTF(func);
if (failed(mlir::xla_hlo::legalizeTF(getFunction()))) signalPassFailure();
}
static PassRegistration<LegalizeTF> pass(
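
A side note on the new driver (my reading of the DialectConversion API, not part of this change): applyPartialConversion only requires ops that the target explicitly marks illegal to be legalized, so with XlaHloDialect as the only legal dialect, TF ops that no pattern can handle should simply be left in place rather than failing the pass. Making unconverted TF ops a hard error would need an explicit illegality marking, roughly as below (assuming ConversionTarget::addIllegalDialect is available in this MLIR revision):

  ConversionTarget target(*context);
  target.addLegalDialect<XlaHloDialect>();
  // Hypothetical, not in this commit: treat every remaining TF op as an error
  // instead of letting the partial conversion leave it untouched.
  target.addIllegalDialect<TF::TensorFlowDialect>();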

View File

@@ -119,11 +119,15 @@ def GetHLOAxisFromTFAxis : NativeCodeCall<
def HasRankedFirstOperand
: Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
// Here, we convert from TensorFlow axis format to HLO axis format which
// This pattern converts TensorFlow axis format to HLO axis format which
// doesn't wrap around like TensorFlow and is always positive. For this
// conversion, use the first input to get inputs rank. Other inputs need not be
// ranked.
def : Pat<(TF_ConcatV2Op $inputs, (HLO_ConstOp OneElementAttr:$axis), $unused),
// The defining op for `axis` in the pattern is the TensorFlow constant op
// because, during the conversion, the original Concat op operands still refer
// to the old ops even if an HLO constant op is introduced as a replacement for
// the TensorFlow Constant op.
def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused),
(HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxis $axis, $inputs)),
[(HasRankedFirstOperand $inputs)]>;

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
namespace mlir {
@@ -36,7 +37,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass();
/// Converts the provided Operation as well as all nested operations into XLA
/// dialect using the conversion patterns registered by the XLA dialect.
void legalizeTF(Operation* op);
LogicalResult legalizeTF(Operation* op);
/// Lowers XLA control flow ops to the Standard dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeControlFlowPass();