Emit error messages for all missing legalizations in TF to XLA full legalization pass.

A full legalization conversion stops after the first failed conversion encountered. For building the TF to XLA bridge, it is useful for this pass to continue through and emit information about all of the missing ops. Instead, use the Partial conversion mode to get the full set of operations that are not legalizable. The "full" conversion succeeds if this set is empty.

This does not change the behavior when the full legalization pass succeeds. However, if the conversion fails, the outputted error message is now much more useful.

For the sake of demonstrating what this might look like with a large model, I've run this on Transformer with the Unary op lowerings removed. Resulting error message output:

Before this change:
```
Compilation failure: MLIR TF to XLA legalization failed-:64:11: error: failed to legalize operation 'tf.Rsqrt'
-:64:11: note: see current operation: %37 = "tf.Rsqrt"(%33) : (tensor<f32>) -> tensor<f32>
```

After this change (default case):
```
Compilation failure: MLIR TF to XLA legalization failed-:4:3: error: The following operations cannot be legalized: tf.Rsqrt (count: 217); tf.SoftmaxCrossEntropyWithLogits (count: 1); tf.Sqrt (count: 370). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.
-:4:3: error: Emitting more detail about one op that failed to legalize...
-:251:12: error: 'tf.Rsqrt' op is not legalizable
-:251:12: note: see current operation: %224 = "tf.Rsqrt"(%220) : (tensor<f32>) -> tensor<f32>
```

After this change (verbose case, with logging set to 1):
```
Compilation failure: MLIR TF to XLA legalization failed-:4:3: error: The following operations cannot be legalized: tf.Rsqrt (count: 217); tf.SoftmaxCrossEntropyWithLogits (count: 1); tf.Sqrt (count: 370). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.
-:4:3: error: Emitting more detail about one of each type of op that failed to legalize...
-:1769:13: error: 'tf.Rsqrt' op is not legalizable
-:1769:13: note: see current operation: %1742 = "tf.Rsqrt"(%1738) : (tensor<f32>) -> tensor<f32>
-:3308:24: error: 'tf.SoftmaxCrossEntropyWithLogits' op is not legalizable
-:3308:24: note: see current operation: %loss, %backprop = "tf.SoftmaxCrossEntropyWithLogits"(%3495, %3503) : (tensor<768x33708xf32>, tensor<768x33708xf32>) -> (tensor<768xf32>, tensor<768x33708xf32>)
-:6944:13: error: 'tf.Sqrt' op is not legalizable
-:6944:13: note: see current operation: %7319 = "tf.Sqrt"(%7318) : (tensor<f32>) -> tensor<f32>
```
PiperOrigin-RevId: 310258485
Change-Id: Id6f8709c2548e7ded9fb6fe690c9d17e6c6d394f
This commit is contained in:
Lucy Fox 2020-05-06 17:10:15 -07:00 committed by TensorFlower Gardener
parent cff8cf4fa1
commit 469de83a9c
2 changed files with 74 additions and 4 deletions

View File

@ -1,22 +1,24 @@
// RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics
// expected-error@below{{The following operations cannot be legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
// expected-error@below{{Emitting more detail about one op that failed to legalize...}}
func @tf_executor_graph_op() {
// expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}}
tf_executor.graph {
%0 = tf_executor.island {
// expected-error@+1 {{'tf.NoOp' op is not legalizable}}
"tf.NoOp"() {} : () -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}
// -----
// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// expected-error@+1 {{failed to legalize operation 'tf.OpA'}}
// expected-error@+1 {{'tf.OpA' op is not legalizable}}
%0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// -----
// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}}
// expected-error@below{{Emitting more detail about one op that failed to legalize...}}
func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// expected-error@+1 {{'tf.OpA' op is not legalizable}}
%0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %2: tensor<2xi32>
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
@ -4785,6 +4786,51 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
const DenseSet<Operation *> &nonlegalized_ops) {
// Track the legalization failures by mapping op name to information about
// that failure: the number of unlegalized occurances of the op, and one
// example operation that failed.
std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
DenseSet<Operation *> error_ops;
for (Operation *nonlegalized_op : nonlegalized_ops) {
// Increment count of this legalization failure.
StringRef op_name = nonlegalized_op->getName().getStringRef();
// If this emplace is successful, it's the first time we've encountered
// this op type. Initialize count to 0 so that after increment, it is 1.
auto insertion_result = op_name_to_error_info.emplace(
op_name, std::make_pair(0, nonlegalized_op));
++insertion_result.first->second.first;
}
std::vector<std::string> error_messages;
error_messages.reserve(op_name_to_error_info.size());
for (const auto &op_info : op_name_to_error_info) {
error_messages.push_back(
llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
}
Location loc = op->getLoc();
emitError(loc) << "The following operations cannot be legalized: "
<< llvm::join(error_messages, "; ")
<< ". These legalization failure(s) may be due to missing TF "
"to HLO lowerings and/or unsupported attributes, etc.";
// Emit more information about the missing ops. This error message
// contains useful details beyond the op name (input and output shapes,
// attributes, etc.).
if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
emitError(loc)
<< "Emitting more detail about one op that failed to legalize...";
} else if (VLOG_IS_ON(1)) {
emitError(loc) << "Emitting more detail about one of each type of op "
"that failed to legalize...";
}
for (const auto &op_info : op_name_to_error_info) {
op_info.second.second->emitOpError() << "is not legalizable";
if (!VLOG_IS_ON(1)) break;
}
}
// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
@ -4841,7 +4887,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
if (!allow_partial_conversion) {
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
return applyFullConversion(op, target, patterns);
DenseSet<Operation *> nonlegalized_ops;
LogicalResult result = applyPartialConversion(
op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
// In order to enforce that the conversion result is fully converted,
// fail if there are any nonlegalized ops in the set.
if (failed(result) || !nonlegalized_ops.empty()) {
EmitLegalizationErrors(op, nonlegalized_ops);
return failure();
}
return result;
}
return applyPartialConversion(op, target, patterns);