From 50ea65ffdab19e4ffaf90b3b762bb5756b5e5809 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Thu, 18 Feb 2021 16:34:35 -0800 Subject: [PATCH] In lower tensorlist pass, if `allow_tensorlist_pass_through` is true, then emit debug logs (since this is not a failure), otherwise emit standard logs (which is considered as a conversion failure) so that end users will see the error message from the converter stack trace. PiperOrigin-RevId: 358292147 Change-Id: I3c4aa367d09e6fbe4543e66121371a4a4e6d311f --- .../transforms/lower_static_tensor_list.cc | 54 +++++++++++++------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index de9d88495fb..23682ec33ff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -168,6 +168,22 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list, start_position, slice_size); } +template +class TensorListOpConverterBase : public OpConversionPattern { + public: + explicit TensorListOpConverterBase(MLIRContext *context, + bool allow_tensorlist_pass_through) + : OpConversionPattern::OpConversionPattern(context), + allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {} + + protected: + // This flag will control the behavior of error emitting during rewrite: + // 1) If it's true, then patterns will only emit errors during debug or + // tracing mode. 2) If it's false, then patterns will emit standard errors + // when there is a rewrite failure. + bool allow_tensorlist_pass_through_; +}; + // Converts tf.Const containing variant of type TensorList to a tensor of // primitive element types. Each of the individual tensor in the list is // converted to an ElementsAttr and then those are packed together using @@ -313,8 +329,9 @@ struct ConvertTensorListSetItem // to generate an equivalent raw tensor. Derived classes are required to // override GetNumElements method. template -struct ConvertTensorListInitOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertTensorListInitOp : public TensorListOpConverterBase { + using TensorListOpConverterBase::TensorListOpConverterBase; + using TensorListOpConverterBase::allow_tensorlist_pass_through_; // Create and return a 1-d tensor with exactly one element equal to the number // of list elements to initialize the output tensor list with. @@ -331,11 +348,13 @@ struct ConvertTensorListInitOp : public OpConversionPattern { if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || dtype.isInteger(32) || dtype.isInteger(64))) { - return rewriter.notifyMatchFailure( - op, + llvm::Twine error_info = "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " - "transformation pass"); + "transformation pass"; + return allow_tensorlist_pass_through_ + ? rewriter.notifyMatchFailure(op, error_info) + : op.emitOpError(error_info); } Value element_shape = operands[0]; @@ -389,10 +408,12 @@ struct ConvertTensorListInitOp : public OpConversionPattern { if (element_shape_acquired) break; } if (!element_shape_acquired) { - return rewriter.notifyMatchFailure( - op, + llvm::Twine error_info = "requires element_shape to be 1D tensor during TF Lite " - "transformation pass"); + "transformation pass"; + return allow_tensorlist_pass_through_ + ? rewriter.notifyMatchFailure(op, error_info) + : op.emitOpError(error_info); } } } @@ -480,8 +501,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern { struct ConvertTensorListReserve : public ConvertTensorListInitOp { - explicit ConvertTensorListReserve(MLIRContext *context) - : ConvertTensorListInitOp(context) {} + explicit ConvertTensorListReserve(MLIRContext *context, + bool allow_tensorlist_pass_through) + : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {} Value GetNumElements(TF::TensorListReserveOp op, ArrayRef operands, PatternRewriter *rewriter) const override { @@ -503,8 +525,9 @@ struct ConvertTensorListReserve // have a different behavior compared to TensorFlow in case of errors. struct ConvertEmptyTensorList : public ConvertTensorListInitOp { - explicit ConvertEmptyTensorList(MLIRContext *context) - : ConvertTensorListInitOp(context) {} + explicit ConvertEmptyTensorList(MLIRContext *context, + bool allow_tensorlist_pass_through) + : ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {} Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef operands, PatternRewriter *rewriter) const override { @@ -1005,12 +1028,13 @@ void LowerStaticTensorListPass::runOnOperation() { OwningRewritePatternList patterns; populateWithGenerated(context, patterns); - patterns.insert( context); + patterns.insert( + context, allow_tensorlist_pass_through); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { if (!allow_tensorlist_pass_through) {