Add TF2XLA fallback patterns to LegalizeTF pass

The LegalizeTF pass can now optionally apply TF2XLA fallback patterns.
ConvertMLIRToXlaComputation uses this instead of a separate TF2XLA fallback
pass, which has the following advantages:
- ops that need a TF -> TF lowering before a fallback pattern can apply can
  now be legalized; previously they could not be
- the intermediate canonicalization and shape inference passes are no longer
  needed
- the order in which patterns are applied can be controlled more flexibly
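
For illustration, a minimal sketch of enabling the fallback when building a
pass pipeline (an existing mlir::PassManager `pm` and a `device_type` string
are assumed; the call mirrors the ConvertMLIRToXlaComputation change below):

  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));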

PiperOrigin-RevId: 325536262
Change-Id: I6f5d42fe889e5d9404ac558b23ca2e4d6277226f
Author: Michael Gester (2020-08-07 17:02:03 -07:00); committed by TensorFlower Gardener
parent 6b65afa420
commit 00dbf072db
6 changed files with 119 additions and 29 deletions


@@ -312,29 +312,25 @@ Status ConvertMLIRToXlaComputation(
   // inside PromoteResourcesToArgs.
   tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
-  tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
+  tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
+      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
+      /*tf2xla_fallback_device_type=*/device_type));
   for (auto& target_pass : custom_legalization_passes) {
     tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
   }
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
-  // Leverage tf2xla kernels for ops that didn't get lowered in the previous
-  // legalization pass.
-  tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
-  tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   // Run shape inference pass to propagate shapes through tensor_cast operations
   // from static to dynamic shapes. This could be generated if the shape
   // inference was originally missing in a TF op but the corresponding HLO op
   // had static shape after lowering.
   tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
   // Run LegalizeTFPass again because the previous legalization passes can
   // expose more graph pruning and canonicalization opportunities that are
   // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
   // invocation.
-  tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
+  tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
+      /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
+      /*tf2xla_fallback_device_type=*/device_type));
   // In order to export to XLA, we must sink constants to control flow regions,
   // since XLA uses functional control flow.
   tf2xla.addNestedPass<mlir::FuncOp>(


@@ -56,6 +56,7 @@ cc_library(
     ],
     deps = [
         ":type_to_shape",
+        ":xla_legalize_tf_with_tf2xla",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
         "//tensorflow/compiler/mlir/hlo:convert_op_folder",


@@ -0,0 +1,50 @@
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
+// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
+// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s
+
+// We run this test four times:
+// 1) Legalize without using TF2XLA fallback (ops cannot be legalized).
+// 2) Use fallback with a device that supports all ops (ops can be legalized).
+// 3) Use fallback with unspecified device (ops cannot be legalized).
+// 4) Use fallback with specified but unsupported device (ops cannot be legalized).
+//
+// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
+// produce remarks that don't occur for 1) and 2) and there is no way to check
+// the remarks only for 3) and 4) (except using two files).
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+
+// CHECK-LABEL: non_max_suppression_v4
+func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
+  %max_size = mhlo.constant dense<2> : tensor<i32>
+  // NO_FALLBACK: tf.NonMaxSuppressionV4
+  // SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4
+  // UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
+  // UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
+  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
+  return %0#0 : tensor<2xi32>
+}
+
+// CHECK-LABEL: mirror_pad
+func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
+  %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
+  // NO_FALLBACK: tf.MirrorPad
+  // SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad
+  // UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad
+  // UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad
+  %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
+  return %1 : tensor<4x7xcomplex<f64>>
+}
+
+// CHECK-LABEL: atan2
+func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
+  // NO_FALLBACK: tf.Atan2
+  // SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
+  // UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
+  // UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
+  %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
+  return %0: tensor<4x4x4xf32>
+}
+
+}
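
For reference, the first two RUN configurations above correspond roughly to
the following pass constructions (a sketch; only the XLA_CPU_JIT device type
string comes from the test itself):

  // 1) No TF2XLA fallback: tf.NonMaxSuppressionV4, tf.MirrorPad, and tf.Atan2
  //    survive legalization (the NO_FALLBACK checks).
  mlir::mhlo::createLegalizeTFPass(/*allow_partial_conversion=*/true,
                                   /*legalize_chlo=*/true,
                                   /*tf2xla_fallback_device_type=*/llvm::None);
  // 2) Fallback with a supported device: the same ops are fully legalized
  //    (the SUPPORTED_FALLBACK_DEVICE checks).
  mlir::mhlo::createLegalizeTFPass(/*allow_partial_conversion=*/false,
                                   /*legalize_chlo=*/true,
                                   /*tf2xla_fallback_device_type=*/"XLA_CPU_JIT");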


@@ -71,9 +71,14 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  public:
   LegalizeTF() = default;
   LegalizeTF(const LegalizeTF &) {}
-  explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
+  explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
+                      llvm::Optional<StringRef> tf2xla_fallback_device_type) {
     allow_partial_conversion_ = allow_partial_conversion;
     legalize_chlo_ = legalize_chlo;
+    use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
+    if (tf2xla_fallback_device_type.hasValue()) {
+      device_type_ = tf2xla_fallback_device_type.getValue().str();
+    }
   }

   /// Performs the lowering to XLA dialect.
@@ -89,6 +94,17 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
       llvm::cl::desc(
           "Also legalizes intermediate chlo ops to hlo (default true)"),
       llvm::cl::init(true)};
+  Option<bool> use_tf2xla_fallback_{
+      *this, "use-tf2xla-fallback",
+      llvm::cl::desc(
+          "Also use TF2XLA fallback for legalization (default false)"),
+      llvm::cl::init(false)};
+  Option<std::string> device_type_{
+      *this, "device-type",
+      llvm::cl::desc(
+          "The device type used by TF2XLA fallback. Must be specified if "
+          "use-tf2xla-fallback is true, otherwise not used."),
+      llvm::cl::init("INVALID_DEVICE_TYPE")};
 };

 /// Returns if the given TF data format string is the default format.
@@ -5746,9 +5762,14 @@ void EmitLegalizationErrors(Operation *op,

 // Performs the lowering to XLA dialect.
 void LegalizeTF::runOnFunction() {
-  if (failed(
-          legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
+  llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
+  if (use_tf2xla_fallback_) {
+    tf2xla_fallback_device_type = device_type_;
+  }
+  if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
+                        legalize_chlo_, tf2xla_fallback_device_type))) {
     signalPassFailure();
+  }
 }

 static PassRegistration<LegalizeTF> pass(
@@ -5758,14 +5779,29 @@ static PassRegistration<LegalizeTF> pass(

 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"

-LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
-                         bool legalize_chlo) {
+LogicalResult legalizeTF(
+    Operation *op, bool allow_partial_conversion, bool legalize_chlo,
+    llvm::Optional<StringRef> tf2xla_fallback_device_type) {
   MLIRContext *context = op->getContext();
-
-  // Add lowering patterns to the list.
   OwningRewritePatternList patterns;
+  // Note that the `OperationConverter` orders patterns lexicographically by:
+  // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
+  //    to arrive at conversion target).
+  // 2) Descending pattern benefit.
+  // 3) Order of patterns in `OwningRewritePatternList`.
+
+  // Add TF->HLO legalization patterns.
   PopulateLegalizeTfPatterns(context, &patterns);
+
+  // Add TF->HLO legalization patterns via TF2XLA fallback.
+  if (tf2xla_fallback_device_type.hasValue()) {
+    PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
+                                         patterns);
+  }
+
+  // Add TF->TF lowering patterns.
+  TF::PopulateLoweringTFPatterns(context, &patterns);
+
   // Populate with CHLO->HLO lowerings to account for TF ops legalized to
   // CHLO first.
   if (legalize_chlo) {
@@ -5805,11 +5841,6 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
 void PopulateLegalizeTfPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
   populateWithGenerated(context, patterns);
-  // Add patterns that lower some of the high level TensorFlow ops to lower
-  // level TensorFlow ops. So, we don't have to target all the TensorFlow ops
-  // here for lowering to HLO.
-  TF::PopulateLoweringTFPatterns(context, patterns);
   patterns->insert<
       ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
       ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
@@ -5838,8 +5869,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
 }

 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
-    bool allow_partial_conversion, bool legalize_chlo) {
-  return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
+    bool allow_partial_conversion, bool legalize_chlo,
+    llvm::Optional<StringRef> tf2xla_fallback_device_type) {
+  return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
+                                      tf2xla_fallback_device_type);
 }

 }  // end namespace mhlo


@@ -528,8 +528,7 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
   // global device type for all TensorFlow ops.
   Option<std::string> device_type_{
       *this, "device-type",
-      llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
-                     "Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")};
+      llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
 };

 static PassRegistration<LegalizeTF> pass(


@@ -36,8 +36,13 @@ namespace mhlo {

 /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
 /// false, emits an error if there is any operation that can't be legalized.
+/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
+/// patterns from TF2XLA fallback for provided device type (see
+/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
+/// used.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
-    bool allow_partial_conversion = false, bool legalize_chlo = true);
+    bool allow_partial_conversion = false, bool legalize_chlo = true,
+    llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);

 /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
 /// specified device type.
@@ -63,8 +68,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
 /// dialect using the conversion patterns registered by the HLO dialect. When
 /// allow_partial_conversion is false, emits an error if there is any operation
 /// that can't be legalized.
-LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
-                         bool legalize_chlo = true);
+/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
+/// patterns from TF2XLA fallback for provided device type (see
+/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
+/// used.
+LogicalResult legalizeTF(
+    Operation* op, bool allow_partial_conversion = false,
+    bool legalize_chlo = true,
+    llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);

 // Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication
 // ops.
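
Finally, a sketch of a hypothetical caller using the extended legalizeTF entry
point directly, outside a pass pipeline (the wrapper name LegalizeWithFallback
is invented for illustration; required headers are assumed to be included, and
the FuncOp-to-Operation* conversion matches the getFunction() call in the diff
above):

  // Legalize one function with the TF2XLA fallback enabled; fails if any op
  // cannot be legalized, since allow_partial_conversion is false.
  mlir::LogicalResult LegalizeWithFallback(mlir::FuncOp func) {
    return mlir::mhlo::legalizeTF(
        func, /*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
        /*tf2xla_fallback_device_type=*/llvm::StringRef("XLA_CPU_JIT"));
  }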