In lower tensorlist pass, if allow_tensorlist_pass_through is true, then emit debug logs (since this is not a failure), otherwise emit standard logs (which is considered as a conversion failure) so that end users will see the error message from the converter stack trace.

PiperOrigin-RevId: 358292147
Change-Id: I3c4aa367d09e6fbe4543e66121371a4a4e6d311f
This commit is contained in:
Haoliang Zhang 2021-02-18 16:34:35 -08:00 committed by TensorFlower Gardener
parent 351fd5e844
commit 50ea65ffda

View File

@ -168,6 +168,22 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size); start_position, slice_size);
} }
template <typename OpT>
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
public:
explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
bool allow_tensorlist_pass_through)
: OpConversionPattern<OpT>::OpConversionPattern(context),
allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {}
protected:
// This flag will control the behavior of error emitting during rewrite:
// 1) If it's true, then patterns will only emit errors during debug or
// tracing mode. 2) If it's false, then patterns will emit standard errors
// when there is a rewrite failure.
bool allow_tensorlist_pass_through_;
};
// Converts tf.Const containing variant of type TensorList to a tensor of // Converts tf.Const containing variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensor in the list is // primitive element types. Each of the individual tensor in the list is
// converted to an ElementsAttr and then those are packed together using // converted to an ElementsAttr and then those are packed together using
@ -313,8 +329,9 @@ struct ConvertTensorListSetItem
// to generate an equivalent raw tensor. Derived classes are required to // to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method. // override GetNumElements method.
template <typename OpT> template <typename OpT>
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> { struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
using OpConversionPattern<OpT>::OpConversionPattern; using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
// Create and return a 1-d tensor with exactly one element equal to the number // Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with. // of list elements to initialize the output tensor list with.
@ -331,11 +348,13 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) { dtype.isInteger(32) || dtype.isInteger(64))) {
return rewriter.notifyMatchFailure( llvm::Twine error_info =
op,
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite " "integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass"); "transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
} }
Value element_shape = operands[0]; Value element_shape = operands[0];
@ -389,10 +408,12 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (element_shape_acquired) break; if (element_shape_acquired) break;
} }
if (!element_shape_acquired) { if (!element_shape_acquired) {
return rewriter.notifyMatchFailure( llvm::Twine error_info =
op,
"requires element_shape to be 1D tensor during TF Lite " "requires element_shape to be 1D tensor during TF Lite "
"transformation pass"); "transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
} }
} }
} }
@ -480,8 +501,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
struct ConvertTensorListReserve struct ConvertTensorListReserve
: public ConvertTensorListInitOp<TF::TensorListReserveOp> { : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
explicit ConvertTensorListReserve(MLIRContext *context) explicit ConvertTensorListReserve(MLIRContext *context,
: ConvertTensorListInitOp(context) {} bool allow_tensorlist_pass_through)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands, Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override { PatternRewriter *rewriter) const override {
@ -503,8 +525,9 @@ struct ConvertTensorListReserve
// have a different behavior compared to TensorFlow in case of errors. // have a different behavior compared to TensorFlow in case of errors.
struct ConvertEmptyTensorList struct ConvertEmptyTensorList
: public ConvertTensorListInitOp<TF::EmptyTensorListOp> { : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
explicit ConvertEmptyTensorList(MLIRContext *context) explicit ConvertEmptyTensorList(MLIRContext *context,
: ConvertTensorListInitOp(context) {} bool allow_tensorlist_pass_through)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands, Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override { PatternRewriter *rewriter) const override {
@ -1005,12 +1028,13 @@ void LowerStaticTensorListPass::runOnOperation() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
populateWithGenerated(context, patterns); populateWithGenerated(context, patterns);
patterns.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity, patterns.insert<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
ConvertTensorListGetItem, ConvertTensorListLength, ConvertTensorListLength, ConvertTensorListPushBack,
ConvertTensorListPushBack, ConvertTensorListReserve,
ConvertTensorListSetItem, ConvertTensorListStack, ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>( ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
context); context);
patterns.insert<ConvertEmptyTensorList, ConvertTensorListReserve>(
context, allow_tensorlist_pass_through);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {
if (!allow_tensorlist_pass_through) { if (!allow_tensorlist_pass_through) {