Use DialectConversion framework in legalize_tf pass instead of GreedyPatternRewriter
In legalize_tf pass, we have patterns for TF to HLO lowerings as well as TF to TF lowerings. In case of multiple applicable patterns, lowering to TF prematurely may lose some optimization opportunities. Switching to DialectConversion makes sure that direct lowering from TF to HLO are preferred. Fixed tests to directly test the pass output and not the canonicalized output that the greedy rewriter produces. PiperOrigin-RevId: 269125599
This commit is contained in:
parent
2179d24c5e
commit
ab17c2f07c
@ -150,8 +150,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
// error reporting system. Report a generic error if pass manager failed
|
||||
// without emitting a diagnostic.
|
||||
mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());
|
||||
mlir::xla_hlo::legalizeTF(module_op);
|
||||
if (!error_handler.ok()) {
|
||||
mlir::LogicalResult result = mlir::xla_hlo::legalizeTF(module_op);
|
||||
if (failed(result)) {
|
||||
return error_handler.Combine(
|
||||
errors::Internal("MLIR TF to XLA legalization failed"));
|
||||
}
|
||||
|
||||
@ -94,6 +94,7 @@ cc_library(
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -321,7 +321,7 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK-LABEL: func @relu
|
||||
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK-NEXT: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0> : tensor<1xi32>}
|
||||
// CHECK-NEXT: xla_hlo.max %arg0, %[[ZERO]] : tensor<1xi32>
|
||||
// CHECK-NEXT: xla_hlo.max %[[ZERO]], %arg0 : tensor<1xi32>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %0: tensor<1xi32>
|
||||
}
|
||||
@ -342,10 +342,9 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK-LABEL: func @simple_softmax
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
|
||||
func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
|
||||
// CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
|
||||
// Verify reduce op for max computation and its body.
|
||||
// CHECK: %[[NEG_INF:.*]] = "xla_hlo.constant"() {value = dense<0xFF800000> : tensor<f32>}
|
||||
// CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]])
|
||||
// CHECK: xla_hlo.max
|
||||
// CHECK: "xla_hlo.return"
|
||||
@ -353,14 +352,17 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
|
||||
// CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]])
|
||||
// CHECK: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
|
||||
// Verify reduce op for summation and its body.
|
||||
// CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]])
|
||||
// CHECK: %[[ZERO:.*]] = "xla_hlo.constant"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
// CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]])
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: "xla_hlo.return"
|
||||
// CHECK: {dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
|
||||
// return %[[RESULT]]
|
||||
|
||||
%0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
|
||||
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
@ -369,31 +370,30 @@ class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
|
||||
} // end namespace xla
|
||||
} // end namespace mlir
|
||||
|
||||
void mlir::xla_hlo::legalizeTF(Operation *op) {
|
||||
LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) {
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
// Add lowering patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
xla::populateWithGenerated(op->getContext(), &patterns);
|
||||
xla::populateWithGenerated(context, &patterns);
|
||||
|
||||
// Add patterns that lower some of the high level TensorFlow ops to lower
|
||||
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
|
||||
// here for lowering to HLO.
|
||||
//
|
||||
// TODO(b/140964075): Switch to DialectConversion to avoid premature lowering
|
||||
// to lower level TensorFlow ops if we actually want to target the higher
|
||||
// level TensorFlow op directly.
|
||||
mlir::TF::PopulateLoweringTFPatterns(op->getContext(), &patterns);
|
||||
mlir::TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
|
||||
patterns.insert<mlir::xla::ConvertMaxPoolOp>(op->getContext());
|
||||
patterns.insert<mlir::xla::ConvertSoftmaxOp>(op->getContext());
|
||||
|
||||
// Recursively applies rewrite patterns to nested operations.
|
||||
applyPatternsGreedily(op, patterns);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<XlaHloDialect>();
|
||||
|
||||
return applyPartialConversion(op, target, patterns);
|
||||
}
|
||||
|
||||
/// Performs the lowering to XLA dialect.
|
||||
void LegalizeTF::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
mlir::xla_hlo::legalizeTF(func);
|
||||
if (failed(mlir::xla_hlo::legalizeTF(getFunction()))) signalPassFailure();
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
|
||||
@ -119,11 +119,15 @@ def GetHLOAxisFromTFAxis : NativeCodeCall<
|
||||
def HasRankedFirstOperand
|
||||
: Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
|
||||
|
||||
// Here, we convert from TensorFlow axis format to HLO axis format which
|
||||
// This pattern converts TensorFlow axis format to HLO axis format which
|
||||
// doesn't wrap around like TensorFlow and is always positive. For this
|
||||
// conversion, use the first input to get inputs rank. Other inputs need not be
|
||||
// ranked.
|
||||
def : Pat<(TF_ConcatV2Op $inputs, (HLO_ConstOp OneElementAttr:$axis), $unused),
|
||||
// Defining op for `axis` is TensorFlow constant op in the pattern as during
|
||||
// the conversion, original Concat op operands still refers to the old ops even
|
||||
// if HLO constant op is introduced as an replacement for the TensorFlow
|
||||
// Constant op.
|
||||
def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused),
|
||||
(HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxis $axis, $inputs)),
|
||||
[(HasRankedFirstOperand $inputs)]>;
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -36,7 +37,7 @@ std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass();
|
||||
|
||||
/// Converts the provided Operation as well as all nested operations into XLA
|
||||
/// dialect using the conversion patterns registered by the XLA dialect.
|
||||
void legalizeTF(Operation* op);
|
||||
LogicalResult legalizeTF(Operation* op);
|
||||
|
||||
/// Lowers XLA control flow ops to the Standard dialect.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeControlFlowPass();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user