Add TF2XLA fallback patterns to LegalizeTF pass
Now LegalizeTF pass can optionally apply TF2XLA fallback patterns. ConvertMLIRToXlaComputation now uses this instead of a separate TF2XLA fallback pass which has following advantages: - ops which need TF -> TF lowering before a fallback pattern can be applied can now be legalized, previously they couldn't - saves intermediate canonicalization and shape inference passes - more flexible control over order in which patterns should be applied PiperOrigin-RevId: 325536262 Change-Id: I6f5d42fe889e5d9404ac558b23ca2e4d6277226f
This commit is contained in:
parent
6b65afa420
commit
00dbf072db
@ -312,29 +312,25 @@ Status ConvertMLIRToXlaComputation(
|
||||
// inside PromoteResourcesToArgs.
|
||||
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
|
||||
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
|
||||
/*tf2xla_fallback_device_type=*/device_type));
|
||||
for (auto& target_pass : custom_legalization_passes) {
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
|
||||
}
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
||||
// Leverage tf2xla kernels for ops that didn't get lowered in the previous
|
||||
// legalization pass.
|
||||
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
||||
// Run shape inference pass to propagate shapes through tensor_cast operations
|
||||
// from static to dynamic shapes. This could be generated if the shape
|
||||
// inference was originally missing in a TF op but the corresponding HLO op
|
||||
// had static shape after lowering.
|
||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
||||
// Run LegalizeTFPass again because the previous legalization passes can
|
||||
// expose more graph pruning and canonicalization opportunities that are
|
||||
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
|
||||
// invocation.
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
|
||||
/*tf2xla_fallback_device_type=*/device_type));
|
||||
// In order to export to XLA, we must sink constants to control flow regions,
|
||||
// since XLA uses functional control flow.
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(
|
||||
|
@ -56,6 +56,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":type_to_shape",
|
||||
":xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
|
||||
|
@ -0,0 +1,50 @@
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
|
||||
// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
|
||||
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s
|
||||
|
||||
// We run this test four times:
|
||||
// 1) Legalize without using TF2XLA fallback (ops cannot be legalized).
|
||||
// 2) Use fallback with a device that supports all ops (ops can be legalized).
|
||||
// 3) Use fallback with unspecified device (ops cannot be legalized).
|
||||
// 4) Use fallback with specified but unsupported device (ops cannot be legalized).
|
||||
//
|
||||
// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
|
||||
// produce remarks that don't occur for 1) and 2) and there is no way to check
|
||||
// the remarks only for 3) and 4) (except using two files).
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
|
||||
// CHECK-LABEL: non_max_suppression_v4
|
||||
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
|
||||
%max_size = mhlo.constant dense<2> : tensor<i32>
|
||||
// NO_FALLBACK: tf.NonMaxSuppressionV4
|
||||
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4
|
||||
// UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
|
||||
// UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
|
||||
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
|
||||
return %0#0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mirror_pad
|
||||
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
|
||||
%0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
|
||||
// NO_FALLBACK: tf.MirrorPad
|
||||
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad
|
||||
// UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad
|
||||
// UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad
|
||||
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
|
||||
return %1 : tensor<4x7xcomplex<f64>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: atan2
|
||||
func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
|
||||
// NO_FALLBACK: tf.Atan2
|
||||
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
|
||||
// UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
|
||||
// UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
|
||||
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
|
||||
return %0: tensor<4x4x4xf32>
|
||||
}
|
||||
|
||||
}
|
@ -71,9 +71,14 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
public:
|
||||
LegalizeTF() = default;
|
||||
LegalizeTF(const LegalizeTF &) {}
|
||||
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
|
||||
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||
allow_partial_conversion_ = allow_partial_conversion;
|
||||
legalize_chlo_ = legalize_chlo;
|
||||
use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
|
||||
if (tf2xla_fallback_device_type.hasValue()) {
|
||||
device_type_ = tf2xla_fallback_device_type.getValue().str();
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs the lowering to XLA dialect.
|
||||
@ -89,6 +94,17 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
llvm::cl::desc(
|
||||
"Also legalizes intermediate chlo ops to hlo (default true)"),
|
||||
llvm::cl::init(true)};
|
||||
Option<bool> use_tf2xla_fallback_{
|
||||
*this, "use-tf2xla-fallback",
|
||||
llvm::cl::desc(
|
||||
"Also use TF2XLA fallback for legalization (default false)"),
|
||||
llvm::cl::init(false)};
|
||||
Option<std::string> device_type_{
|
||||
*this, "device-type",
|
||||
llvm::cl::desc(
|
||||
"The device type used by TF2XLA fallback. Must be specified if "
|
||||
"use-tf2xla-fallback is true, otherwise not used."),
|
||||
llvm::cl::init("INVALID_DEVICE_TYPE")};
|
||||
};
|
||||
|
||||
/// Returns if the given TF data format string is the default format.
|
||||
@ -5746,9 +5762,14 @@ void EmitLegalizationErrors(Operation *op,
|
||||
|
||||
// Performs the lowering to XLA dialect.
|
||||
void LegalizeTF::runOnFunction() {
|
||||
if (failed(
|
||||
legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
|
||||
if (use_tf2xla_fallback_) {
|
||||
tf2xla_fallback_device_type = device_type_;
|
||||
}
|
||||
if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
|
||||
legalize_chlo_, tf2xla_fallback_device_type))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
@ -5758,14 +5779,29 @@ static PassRegistration<LegalizeTF> pass(
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
|
||||
|
||||
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
bool legalize_chlo) {
|
||||
LogicalResult legalizeTF(
|
||||
Operation *op, bool allow_partial_conversion, bool legalize_chlo,
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
// Add lowering patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
// Note that the `OperationConverter` orders patterns lexicographically by:
|
||||
// 1) Ascending legalization depth (i.e., minimum number of patterns necessary
|
||||
// to arrive at conversion target).
|
||||
// 2) Descending pattern benefit.
|
||||
// 3) Order of patterns in `OwningRewritePatternList`.
|
||||
|
||||
// Add TF->HLO legalization patterns.
|
||||
PopulateLegalizeTfPatterns(context, &patterns);
|
||||
|
||||
// Add TF->HLO legalization patterns via TF2XLA fallback.
|
||||
if (tf2xla_fallback_device_type.hasValue()) {
|
||||
PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
|
||||
patterns);
|
||||
}
|
||||
|
||||
// Add TF->TF lowering patterns.
|
||||
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
|
||||
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
|
||||
// CHLO first.
|
||||
if (legalize_chlo) {
|
||||
@ -5805,11 +5841,6 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
void PopulateLegalizeTfPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
|
||||
// Add patterns that lower some of the high level TensorFlow ops to lower
|
||||
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
|
||||
// here for lowering to HLO.
|
||||
TF::PopulateLoweringTFPatterns(context, patterns);
|
||||
patterns->insert<
|
||||
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
|
||||
ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
|
||||
@ -5838,8 +5869,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||
bool allow_partial_conversion, bool legalize_chlo) {
|
||||
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
|
||||
bool allow_partial_conversion, bool legalize_chlo,
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
|
||||
tf2xla_fallback_device_type);
|
||||
}
|
||||
|
||||
} // end namespace mhlo
|
||||
|
@ -528,8 +528,7 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
// global device type for all TensorFlow ops.
|
||||
Option<std::string> device_type_{
|
||||
*this, "device-type",
|
||||
llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
|
||||
"Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")};
|
||||
llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
|
||||
};
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
|
@ -36,8 +36,13 @@ namespace mhlo {
|
||||
|
||||
/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
|
||||
/// false, emits an error if there is any operation that can't be legalized.
|
||||
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
|
||||
/// patterns from TF2XLA fallback for provided device type (see
|
||||
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
|
||||
/// used.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||
bool allow_partial_conversion = false, bool legalize_chlo = true);
|
||||
bool allow_partial_conversion = false, bool legalize_chlo = true,
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
|
||||
|
||||
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
|
||||
/// specified device type.
|
||||
@ -63,8 +68,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
|
||||
/// dialect using the conversion patterns registered by the HLO dialect. When
|
||||
/// allow_partial_conversion is false, emits an error if there is any operation
|
||||
/// that can't be legalized.
|
||||
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
|
||||
bool legalize_chlo = true);
|
||||
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
|
||||
/// patterns from TF2XLA fallback for provided device type (see
|
||||
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
|
||||
/// used.
|
||||
LogicalResult legalizeTF(
|
||||
Operation* op, bool allow_partial_conversion = false,
|
||||
bool legalize_chlo = true,
|
||||
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
|
||||
|
||||
// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication
|
||||
// ops.
|
||||
|
Loading…
x
Reference in New Issue
Block a user