In lower tensorlist pass, if allow_tensorlist_pass_through is true, then emit debug logs (since this is not a failure), otherwise emit standard logs (which is considered as a conversion failure) so that end users will see the error message from the converter stack trace.

PiperOrigin-RevId: 358292147
Change-Id: I3c4aa367d09e6fbe4543e66121371a4a4e6d311f
This commit is contained in:
Haoliang Zhang 2021-02-18 16:34:35 -08:00 committed by TensorFlower Gardener
parent 351fd5e844
commit 50ea65ffda

View File

@ -168,6 +168,22 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size);
}
template <typename OpT>
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
public:
explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
bool allow_tensorlist_pass_through)
: OpConversionPattern<OpT>::OpConversionPattern(context),
allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {}
protected:
// This flag will control the behavior of error emitting during rewrite:
// 1) If it's true, then patterns will only emit errors during debug or
// tracing mode. 2) If it's false, then patterns will emit standard errors
// when there is a rewrite failure.
bool allow_tensorlist_pass_through_;
};
// Converts tf.Const containing variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensor in the list is
// converted to an ElementsAttr and then those are packed together using
@ -313,8 +329,9 @@ struct ConvertTensorListSetItem
// to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method.
template <typename OpT>
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
using OpConversionPattern<OpT>::OpConversionPattern;
struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
@ -331,11 +348,13 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
return rewriter.notifyMatchFailure(
op,
llvm::Twine error_info =
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
"transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
Value element_shape = operands[0];
@ -389,10 +408,12 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (element_shape_acquired) break;
}
if (!element_shape_acquired) {
return rewriter.notifyMatchFailure(
op,
llvm::Twine error_info =
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
"transformation pass";
return allow_tensorlist_pass_through_
? rewriter.notifyMatchFailure(op, error_info)
: op.emitOpError(error_info);
}
}
}
@ -480,8 +501,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
struct ConvertTensorListReserve
: public ConvertTensorListInitOp<TF::TensorListReserveOp> {
explicit ConvertTensorListReserve(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
explicit ConvertTensorListReserve(MLIRContext *context,
bool allow_tensorlist_pass_through)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override {
@ -503,8 +525,9 @@ struct ConvertTensorListReserve
// have a different behavior compared to TensorFlow in case of errors.
struct ConvertEmptyTensorList
: public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
explicit ConvertEmptyTensorList(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
explicit ConvertEmptyTensorList(MLIRContext *context,
bool allow_tensorlist_pass_through)
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
PatternRewriter *rewriter) const override {
@ -1005,12 +1028,13 @@ void LowerStaticTensorListPass::runOnOperation() {
OwningRewritePatternList patterns;
populateWithGenerated(context, patterns);
patterns.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
ConvertTensorListGetItem, ConvertTensorListLength,
ConvertTensorListPushBack, ConvertTensorListReserve,
patterns.insert<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
ConvertTensorListLength, ConvertTensorListPushBack,
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
context);
patterns.insert<ConvertEmptyTensorList, ConvertTensorListReserve>(
context, allow_tensorlist_pass_through);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
if (!allow_tensorlist_pass_through) {