Add TF2XLA fallback patterns to LegalizeTF pass
The LegalizeTF pass can now optionally apply TF2XLA fallback patterns. ConvertMLIRToXlaComputation now uses this instead of a separate TF2XLA fallback pass, which has the following advantages: (1) ops that need a TF -> TF lowering before a fallback pattern can be applied can now be legalized — previously they could not; (2) it saves intermediate canonicalization and shape-inference passes; (3) it gives more flexible control over the order in which patterns are applied. PiperOrigin-RevId: 325536262 Change-Id: I6f5d42fe889e5d9404ac558b23ca2e4d6277226f
This commit is contained in:
parent
6b65afa420
commit
00dbf072db
@ -312,29 +312,25 @@ Status ConvertMLIRToXlaComputation(
|
|||||||
// inside PromoteResourcesToArgs.
|
// inside PromoteResourcesToArgs.
|
||||||
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
|
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
|
||||||
|
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
|
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||||
|
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
|
||||||
|
/*tf2xla_fallback_device_type=*/device_type));
|
||||||
for (auto& target_pass : custom_legalization_passes) {
|
for (auto& target_pass : custom_legalization_passes) {
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
|
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
|
||||||
}
|
}
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
|
||||||
|
|
||||||
// Leverage tf2xla kernels for ops that didn't get lowered in the previous
|
|
||||||
// legalization pass.
|
|
||||||
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
|
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
|
||||||
|
|
||||||
// Run shape inference pass to propagate shapes through tensor_cast operations
|
// Run shape inference pass to propagate shapes through tensor_cast operations
|
||||||
// from static to dynamic shapes. This could be generated if the shape
|
// from static to dynamic shapes. This could be generated if the shape
|
||||||
// inference was originally missing in a TF op but the corresponding HLO op
|
// inference was originally missing in a TF op but the corresponding HLO op
|
||||||
// had static shape after lowering.
|
// had static shape after lowering.
|
||||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||||
|
|
||||||
// Run LegalizeTFPass again because the previous legalization passes can
|
// Run LegalizeTFPass again because the previous legalization passes can
|
||||||
// expose more graph pruning and canonicalization opportunities that are
|
// expose more graph pruning and canonicalization opportunities that are
|
||||||
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
|
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
|
||||||
// invocation.
|
// invocation.
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
|
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
|
||||||
|
/*allow_partial_conversion=*/false, /*legalize_chlo=*/true,
|
||||||
|
/*tf2xla_fallback_device_type=*/device_type));
|
||||||
// In order to export to XLA, we must sink constants to control flow regions,
|
// In order to export to XLA, we must sink constants to control flow regions,
|
||||||
// since XLA uses functional control flow.
|
// since XLA uses functional control flow.
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(
|
tf2xla.addNestedPass<mlir::FuncOp>(
|
||||||
|
@ -56,6 +56,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":type_to_shape",
|
":type_to_shape",
|
||||||
|
":xla_legalize_tf_with_tf2xla",
|
||||||
"//tensorflow/compiler/mlir/hlo",
|
"//tensorflow/compiler/mlir/hlo",
|
||||||
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
|
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
|
||||||
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
|
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s
|
||||||
|
// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s
|
||||||
|
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s
|
||||||
|
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s
|
||||||
|
|
||||||
|
// We run this test four times:
|
||||||
|
// 1) Legalize without using TF2XLA fallback (ops cannot be legalized).
|
||||||
|
// 2) Use fallback with a device that supports all ops (ops can be legalized).
|
||||||
|
// 3) Use fallback with unspecified device (ops cannot be legalized).
|
||||||
|
// 4) Use fallback with specified but unsupported device (ops cannot be legalized).
|
||||||
|
//
|
||||||
|
// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases
|
||||||
|
// produce remarks that don't occur for 1) and 2) and there is no way to check
|
||||||
|
// the remarks only for 3) and 4) (except using two files).
|
||||||
|
|
||||||
|
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||||
|
|
||||||
|
// CHECK-LABEL: non_max_suppression_v4
|
||||||
|
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
|
||||||
|
%max_size = mhlo.constant dense<2> : tensor<i32>
|
||||||
|
// NO_FALLBACK: tf.NonMaxSuppressionV4
|
||||||
|
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4
|
||||||
|
// UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
|
||||||
|
// UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4
|
||||||
|
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
|
||||||
|
return %0#0 : tensor<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: mirror_pad
|
||||||
|
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
|
||||||
|
%0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
|
||||||
|
// NO_FALLBACK: tf.MirrorPad
|
||||||
|
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad
|
||||||
|
// UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad
|
||||||
|
// UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad
|
||||||
|
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
|
||||||
|
return %1 : tensor<4x7xcomplex<f64>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: atan2
|
||||||
|
func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
|
||||||
|
// NO_FALLBACK: tf.Atan2
|
||||||
|
// SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2
|
||||||
|
// UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2
|
||||||
|
// UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2
|
||||||
|
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
|
||||||
|
return %0: tensor<4x4x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -71,9 +71,14 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
|||||||
public:
|
public:
|
||||||
LegalizeTF() = default;
|
LegalizeTF() = default;
|
||||||
LegalizeTF(const LegalizeTF &) {}
|
LegalizeTF(const LegalizeTF &) {}
|
||||||
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) {
|
explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
|
||||||
|
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||||
allow_partial_conversion_ = allow_partial_conversion;
|
allow_partial_conversion_ = allow_partial_conversion;
|
||||||
legalize_chlo_ = legalize_chlo;
|
legalize_chlo_ = legalize_chlo;
|
||||||
|
use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
|
||||||
|
if (tf2xla_fallback_device_type.hasValue()) {
|
||||||
|
device_type_ = tf2xla_fallback_device_type.getValue().str();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Performs the lowering to XLA dialect.
|
/// Performs the lowering to XLA dialect.
|
||||||
@ -89,6 +94,17 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
|||||||
llvm::cl::desc(
|
llvm::cl::desc(
|
||||||
"Also legalizes intermediate chlo ops to hlo (default true)"),
|
"Also legalizes intermediate chlo ops to hlo (default true)"),
|
||||||
llvm::cl::init(true)};
|
llvm::cl::init(true)};
|
||||||
|
Option<bool> use_tf2xla_fallback_{
|
||||||
|
*this, "use-tf2xla-fallback",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Also use TF2XLA fallback for legalization (default false)"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
Option<std::string> device_type_{
|
||||||
|
*this, "device-type",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"The device type used by TF2XLA fallback. Must be specified if "
|
||||||
|
"use-tf2xla-fallback is true, otherwise not used."),
|
||||||
|
llvm::cl::init("INVALID_DEVICE_TYPE")};
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Returns if the given TF data format string is the default format.
|
/// Returns if the given TF data format string is the default format.
|
||||||
@ -5746,9 +5762,14 @@ void EmitLegalizationErrors(Operation *op,
|
|||||||
|
|
||||||
// Performs the lowering to XLA dialect.
|
// Performs the lowering to XLA dialect.
|
||||||
void LegalizeTF::runOnFunction() {
|
void LegalizeTF::runOnFunction() {
|
||||||
if (failed(
|
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
|
||||||
legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_)))
|
if (use_tf2xla_fallback_) {
|
||||||
|
tf2xla_fallback_device_type = device_type_;
|
||||||
|
}
|
||||||
|
if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
|
||||||
|
legalize_chlo_, tf2xla_fallback_device_type))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<LegalizeTF> pass(
|
static PassRegistration<LegalizeTF> pass(
|
||||||
@ -5758,14 +5779,29 @@ static PassRegistration<LegalizeTF> pass(
|
|||||||
|
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
|
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
|
||||||
|
|
||||||
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
LogicalResult legalizeTF(
|
||||||
bool legalize_chlo) {
|
Operation *op, bool allow_partial_conversion, bool legalize_chlo,
|
||||||
|
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
|
|
||||||
// Add lowering patterns to the list.
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
// Note that the `OperationConverter` orders patterns lexicographically by:
|
||||||
|
// 1) Ascending legalization depth (i.e., minimum number of patterns necessary
|
||||||
|
// to arrive at conversion target).
|
||||||
|
// 2) Descending pattern benefit.
|
||||||
|
// 3) Order of patterns in `OwningRewritePatternList`.
|
||||||
|
|
||||||
|
// Add TF->HLO legalization patterns.
|
||||||
PopulateLegalizeTfPatterns(context, &patterns);
|
PopulateLegalizeTfPatterns(context, &patterns);
|
||||||
|
|
||||||
|
// Add TF->HLO legalization patterns via TF2XLA fallback.
|
||||||
|
if (tf2xla_fallback_device_type.hasValue()) {
|
||||||
|
PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
|
||||||
|
patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add TF->TF lowering patterns.
|
||||||
|
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||||
|
|
||||||
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
|
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
|
||||||
// CHLO first.
|
// CHLO first.
|
||||||
if (legalize_chlo) {
|
if (legalize_chlo) {
|
||||||
@ -5805,11 +5841,6 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
|||||||
void PopulateLegalizeTfPatterns(MLIRContext *context,
|
void PopulateLegalizeTfPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
populateWithGenerated(context, patterns);
|
populateWithGenerated(context, patterns);
|
||||||
|
|
||||||
// Add patterns that lower some of the high level TensorFlow ops to lower
|
|
||||||
// level TensorFlow ops. So, we don't have to target all the TensorFlow ops
|
|
||||||
// here for lowering to HLO.
|
|
||||||
TF::PopulateLoweringTFPatterns(context, patterns);
|
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
|
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op,
|
||||||
ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
|
ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp,
|
||||||
@ -5838,8 +5869,10 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||||
bool allow_partial_conversion, bool legalize_chlo) {
|
bool allow_partial_conversion, bool legalize_chlo,
|
||||||
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
|
llvm::Optional<StringRef> tf2xla_fallback_device_type) {
|
||||||
|
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
|
||||||
|
tf2xla_fallback_device_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace mhlo
|
} // end namespace mhlo
|
||||||
|
@ -528,8 +528,7 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
|||||||
// global device type for all TensorFlow ops.
|
// global device type for all TensorFlow ops.
|
||||||
Option<std::string> device_type_{
|
Option<std::string> device_type_{
|
||||||
*this, "device-type",
|
*this, "device-type",
|
||||||
llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
|
llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
|
||||||
"Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static PassRegistration<LegalizeTF> pass(
|
static PassRegistration<LegalizeTF> pass(
|
||||||
|
@ -36,8 +36,13 @@ namespace mhlo {
|
|||||||
|
|
||||||
/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
|
/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is
|
||||||
/// false, emits an error if there is any operation that can't be legalized.
|
/// false, emits an error if there is any operation that can't be legalized.
|
||||||
|
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
|
||||||
|
/// patterns from TF2XLA fallback for provided device type (see
|
||||||
|
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
|
||||||
|
/// used.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||||
bool allow_partial_conversion = false, bool legalize_chlo = true);
|
bool allow_partial_conversion = false, bool legalize_chlo = true,
|
||||||
|
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
|
||||||
|
|
||||||
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
|
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
|
||||||
/// specified device type.
|
/// specified device type.
|
||||||
@ -63,8 +68,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
|
|||||||
/// dialect using the conversion patterns registered by the HLO dialect. When
|
/// dialect using the conversion patterns registered by the HLO dialect. When
|
||||||
/// allow_partial_conversion is false, emits an error if there is any operation
|
/// allow_partial_conversion is false, emits an error if there is any operation
|
||||||
/// that can't be legalized.
|
/// that can't be legalized.
|
||||||
LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false,
|
/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization
|
||||||
bool legalize_chlo = true);
|
/// patterns from TF2XLA fallback for provided device type (see
|
||||||
|
/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not
|
||||||
|
/// used.
|
||||||
|
LogicalResult legalizeTF(
|
||||||
|
Operation* op, bool allow_partial_conversion = false,
|
||||||
|
bool legalize_chlo = true,
|
||||||
|
llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None);
|
||||||
|
|
||||||
// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication
|
// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication
|
||||||
// ops.
|
// ops.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user