From 469de83a9c485563bfda0006f0dbc0673ee75f14 Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 6 May 2020 17:10:15 -0700 Subject: [PATCH] Emit error messages for all missing legalizations in TF to XLA full legalization pass. A full legalization conversion stops after the first failed conversion encountered. For building the TF to XLA bridge, it is useful for this pass to continue through and emit information about all of the missing ops. Instead, use the Partial conversion mode to get the full set of operations that are not legalizable. The "full" conversion succeeds if this set is empty. This does not change the behavior when the full legalization pass succeeds. However, if the conversion fails, the outputted error message is now much more useful. For the sake of demonstrating what this might look like with a large model, I've run this on Transformer with the Unary op lowerings removed. Resulting error message output: Before this change: ``` Compilation failure: MLIR TF to XLA legalization failed-:64:11: error: failed to legalize operation 'tf.Rsqrt' -:64:11: note: see current operation: %37 = "tf.Rsqrt"(%33) : (tensor) -> tensor ``` After this change (default case): ``` Compilation failure: MLIR TF to XLA legalization failed-:4:3: error: The following operations cannot be legalized: tf.Rsqrt (count: 217); tf.SoftmaxCrossEntropyWithLogits (count: 1); tf.Sqrt (count: 370). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc. -:4:3: error: Emitting more detail about one op that failed to legalize... -:251:12: error: 'tf.Rsqrt' op is not legalizable -:251:12: note: see current operation: %224 = "tf.Rsqrt"(%220) : (tensor) -> tensor ``` After this change (verbose case, with logging set to 1): ``` Compilation failure: MLIR TF to XLA legalization failed-:4:3: error: The following operations cannot be legalized: tf.Rsqrt (count: 217); tf.SoftmaxCrossEntropyWithLogits (count: 1); tf.Sqrt (count: 370). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc. -:4:3: error: Emitting more detail about one of each type of op that failed to legalize... -:1769:13: error: 'tf.Rsqrt' op is not legalizable -:1769:13: note: see current operation: %1742 = "tf.Rsqrt"(%1738) : (tensor) -> tensor -:3308:24: error: 'tf.SoftmaxCrossEntropyWithLogits' op is not legalizable -:3308:24: note: see current operation: %loss, %backprop = "tf.SoftmaxCrossEntropyWithLogits"(%3495, %3503) : (tensor<768x33708xf32>, tensor<768x33708xf32>) -> (tensor<768xf32>, tensor<768x33708xf32>) -:6944:13: error: 'tf.Sqrt' op is not legalizable -:6944:13: note: see current operation: %7319 = "tf.Sqrt"(%7318) : (tensor) -> tensor ``` PiperOrigin-RevId: 310258485 Change-Id: Id6f8709c2548e7ded9fb6fe690c9d17e6c6d394f --- .../tests/legalize-tf-full-conversion.mlir | 21 ++++++- .../mlir/xla/transforms/legalize_tf.cc | 57 ++++++++++++++++++- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir index d2b4d269fef..0660af4ed1c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir @@ -1,22 +1,24 @@ // RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics +// expected-error@below{{The following operations cannot be legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} func @tf_executor_graph_op() { - // expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}} tf_executor.graph { %0 = tf_executor.island { + // expected-error@+1 {{'tf.NoOp' op is not legalizable}} "tf.NoOp"() {} : () -> () tf_executor.yield } tf_executor.fetch } return - } // ----- +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // expected-error@+1 {{failed to legalize operation 'tf.OpA'}} + // expected-error@+1 {{'tf.OpA' op is not legalizable}} %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } + +// ----- + +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} +func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // expected-error@+1 {{'tf.OpA' op is not legalizable}} + %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %2: tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index fb03c9b82e5..de808bc8ad2 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project @@ -4785,6 +4786,51 @@ class ConvertQrOp : public OpRewritePattern { } }; +// Emits debug information which includes the number of ops of each type which +// failed to legalize. +void EmitLegalizationErrors(Operation *op, + const DenseSet &nonlegalized_ops) { + // Track the legalization failures by mapping op name to information about + // that failure: the number of unlegalized occurances of the op, and one + // example operation that failed. + std::map> op_name_to_error_info; + DenseSet error_ops; + for (Operation *nonlegalized_op : nonlegalized_ops) { + // Increment count of this legalization failure. + StringRef op_name = nonlegalized_op->getName().getStringRef(); + // If this emplace is successful, it's the first time we've encountered + // this op type. Initialize count to 0 so that after increment, it is 1. + auto insertion_result = op_name_to_error_info.emplace( + op_name, std::make_pair(0, nonlegalized_op)); + ++insertion_result.first->second.first; + } + std::vector error_messages; + error_messages.reserve(op_name_to_error_info.size()); + for (const auto &op_info : op_name_to_error_info) { + error_messages.push_back( + llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first)); + } + Location loc = op->getLoc(); + emitError(loc) << "The following operations cannot be legalized: " + << llvm::join(error_messages, "; ") + << ". These legalization failure(s) may be due to missing TF " + "to HLO lowerings and/or unsupported attributes, etc."; + // Emit more information about the missing ops. This error message + // contains useful details beyond the op name (input and output shapes, + // attributes, etc.). + if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) { + emitError(loc) + << "Emitting more detail about one op that failed to legalize..."; + } else if (VLOG_IS_ON(1)) { + emitError(loc) << "Emitting more detail about one of each type of op " + "that failed to legalize..."; + } + for (const auto &op_info : op_name_to_error_info) { + op_info.second.second->emitOpError() << "is not legalizable"; + if (!VLOG_IS_ON(1)) break; + } +} + // Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) @@ -4841,7 +4887,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { if (!allow_partial_conversion) { // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. target.addLegalOp(); - return applyFullConversion(op, target, patterns); + DenseSet nonlegalized_ops; + LogicalResult result = applyPartialConversion( + op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops); + // In order to enforce that the conversion result is fully converted, + // fail if there are any nonlegalized ops in the set. + if (failed(result) || !nonlegalized_ops.empty()) { + EmitLegalizationErrors(op, nonlegalized_ops); + return failure(); + } + return result; } return applyPartialConversion(op, target, patterns);