Add TF2XLA fallback patterns to LegalizeTF pass

Now LegalizeTF pass can optionally apply TF2XLA fallback patterns.
ConvertMLIRToXlaComputation now uses this instead of a separate TF2XLA fallback
pass which has following advantages:
- ops which need TF -> TF lowering before a fallback pattern can be applied can
  now be legalized, previously they couldn't
- saves intermediate canonicalization and shape inference passes
- more flexible control over order in which patterns should be applied

PiperOrigin-RevId: 325536262
Change-Id: I6f5d42fe889e5d9404ac558b23ca2e4d6277226f
This commit is contained in:
Michael Gester 2020-08-07 17:02:03 -07:00 committed by TensorFlower Gardener
parent 6b65afa420
commit 00dbf072db
6 changed files with 119 additions and 29 deletions

View File

@ -312,29 +312,25 @@ Status ConvertMLIRToXlaComputation(
// inside PromoteResourcesToArgs.
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type));
for (auto& target_pass : custom_legalization_passes) {
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
}
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
// Leverage tf2xla kernels for ops that didn't get lowered in the previous
// legalization pass.
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
// Run shape inference pass to propagate shapes through tensor_cast operations
// from static to dynamic shapes. This could be generated if the shape
// inference was originally missing in a TF op but the corresponding HLO op
// had static shape after lowering.
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
// Run LegalizeTFPass again because the previous legalization passes can
// expose more graph pruning and canonicalization opportunities that are
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
// invocation.
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type));
// In order to export to XLA, we must sink constants to control flow regions,
// since XLA uses functional control flow.
tf2xla.addNestedPass<mlir::FuncOp>(

View File

@ -56,6 +56,7 @@ cc_library(
],
deps = [
":type_to_shape",
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
"//tensorflow/compiler/mlir/hlo:convert_op_folder",

View File

@ -0,0 +1,50 @@
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s
// We run this test four times:
// 1) Legalize without using TF2XLA fallback (ops cannot be legalized).
// 2) Use fallback with a device that supports all ops (ops can be legalized).
// 3) Use fallback with unspecified device (ops cannot be legalized).
// 4) Use fallback with specified but unsupported device (ops cannot be legalized).
//
// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
// produce remarks that don't occur for 1) and 2) and there is no way to check
// the remarks only for 3) and 4) (except using two files).
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: non_max_suppression_v4
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
%max_size = mhlo.constant dense<2> : tensor<i32>
// NO_FALLBACK: tf.NonMaxSuppressionV4
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4
// UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
// UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
}
// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// NO_FALLBACK: tf.MirrorPad
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad
// UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad
// UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
}
// CHECK-LABEL: atan2
func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
// NO_FALLBACK: tf.Atan2
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
// UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
// UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
return %0: tensor<4x4x4xf32>
}
}

View File

@ -71,9 +71,14 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF &) {}
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
allow_partial_conversion_ = allow_partial_conversion;
legalize_chlo_ = legalize_chlo;
use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
if (tf2xla_fallback_device_type.hasValue()) {
device_type_ = tf2xla_fallback_device_type.getValue().str();
}
}
/// Performs the lowering to XLA dialect.
@ -89,6 +94,17 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
llvm::cl::desc(
"Also legalizes intermediate chlo ops to hlo (default true)"),
llvm::cl::init(true)};
Option<bool> use_tf2xla_fallback_{
*this, "use-tf2xla-fallback",
llvm::cl::desc(
"Also use TF2XLA fallback for legalization (default false)"),
llvm::cl::init(false)};
Option<std::string> device_type_{
*this, "device-type",
llvm::cl::desc(
"The device type used by TF2XLA fallback. Must be specified if "
"use-tf2xla-fallback is true, otherwise not used."),
llvm::cl::init("INVALID_DEVICE_TYPE")};
};
/// Returns if the given TF data format string is the default format.
@ -5746,9 +5762,14 @@ void EmitLegalizationErrors(Operation *op,
// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
if (failed(
legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
if (use_tf2xla_fallback_) {
tf2xla_fallback_device_type = device_type_;
}
if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
legalize_chlo_, tf2xla_fallback_device_type))) {
signalPassFailure();
}
}
static PassRegistration<LegalizeTF> pass(
@ -5758,14 +5779,29 @@ static PassRegistration<LegalizeTF> pass(
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
bool legalize_chlo) {
LogicalResult legalizeTF(
Operation *op, bool allow_partial_conversion, bool legalize_chlo,
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
MLIRContext *context = op->getContext();
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
// Note that the `OperationConverter` orders patterns lexicographically by:
// 1) Ascending legalization depth (i.e., minimum number of patterns necessary
// to arrive at conversion target).
// 2) Descending pattern benefit.
// 3) Order of patterns in `OwningRewritePatternList`.
// Add TF->HLO legalization patterns.
PopulateLegalizeTfPatterns(context, &patterns);
// Add TF->HLO legalization patterns via TF2XLA fallback.
if (tf2xla_fallback_device_type.hasValue()) {
PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
patterns);
}
// Add TF->TF lowering patterns.
TF::PopulateLoweringTFPatterns(context, &patterns);
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
// CHLO first.
if (legalize_chlo) {
@ -5805,11 +5841,6 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
void PopulateLegalizeTfPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns);
// Add patterns that lower some of the high level TensorFlow ops to lower
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
// here for lowering to HLO.
TF::PopulateLoweringTFPatterns(context, patterns);
patterns->insert<
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
@ -5838,8 +5869,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
bool allow_partial_conversion, bool legalize_chlo) {
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
bool allow_partial_conversion, bool legalize_chlo,
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
tf2xla_fallback_device_type);
}
} // end namespace mhlo

View File

@ -528,8 +528,7 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
// global device type for all TensorFlow ops.
Option<std::string> device_type_{
*this, "device-type",
llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
"Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")};
llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
};
static PassRegistration<LegalizeTF> pass(

View File

@ -36,8 +36,13 @@ namespace mhlo {
/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
/// false, emits an error if there is any operation that can't be legalized.
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
/// patterns from TF2XLA fallback for provided device type (see
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
/// used.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
bool allow_partial_conversion = false, bool legalize_chlo = true);
bool allow_partial_conversion = false, bool legalize_chlo = true,
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
/// specified device type.
@ -63,8 +68,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
/// dialect using the conversion patterns registered by the HLO dialect. When
/// allow_partial_conversion is false, emits an error if there is any operation
/// that can't be legalized.
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
bool legalize_chlo = true);
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
/// patterns from TF2XLA fallback for provided device type (see
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
/// used.
LogicalResult legalizeTF(
Operation* op, bool allow_partial_conversion = false,
bool legalize_chlo = true,
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication
// ops.