In lower tensorlist pass, if allow_tensorlist_pass_through
is true, then emit debug logs (since this is not a failure), otherwise emit standard logs (which is considered as a conversion failure) so that end users will see the error message from the converter stack trace.
PiperOrigin-RevId: 358292147 Change-Id: I3c4aa367d09e6fbe4543e66121371a4a4e6d311f
This commit is contained in:
parent
351fd5e844
commit
50ea65ffda
@ -168,6 +168,22 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
|
||||
start_position, slice_size);
|
||||
}
|
||||
|
||||
template <typename OpT>
|
||||
class TensorListOpConverterBase : public OpConversionPattern<OpT> {
|
||||
public:
|
||||
explicit TensorListOpConverterBase<OpT>(MLIRContext *context,
|
||||
bool allow_tensorlist_pass_through)
|
||||
: OpConversionPattern<OpT>::OpConversionPattern(context),
|
||||
allow_tensorlist_pass_through_(allow_tensorlist_pass_through) {}
|
||||
|
||||
protected:
|
||||
// This flag will control the behavior of error emitting during rewrite:
|
||||
// 1) If it's true, then patterns will only emit errors during debug or
|
||||
// tracing mode. 2) If it's false, then patterns will emit standard errors
|
||||
// when there is a rewrite failure.
|
||||
bool allow_tensorlist_pass_through_;
|
||||
};
|
||||
|
||||
// Converts tf.Const containing variant of type TensorList to a tensor of
|
||||
// primitive element types. Each of the individual tensor in the list is
|
||||
// converted to an ElementsAttr and then those are packed together using
|
||||
@ -313,8 +329,9 @@ struct ConvertTensorListSetItem
|
||||
// to generate an equivalent raw tensor. Derived classes are required to
|
||||
// override GetNumElements method.
|
||||
template <typename OpT>
|
||||
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
using OpConversionPattern<OpT>::OpConversionPattern;
|
||||
struct ConvertTensorListInitOp : public TensorListOpConverterBase<OpT> {
|
||||
using TensorListOpConverterBase<OpT>::TensorListOpConverterBase;
|
||||
using TensorListOpConverterBase<OpT>::allow_tensorlist_pass_through_;
|
||||
|
||||
// Create and return a 1-d tensor with exactly one element equal to the number
|
||||
// of list elements to initialize the output tensor list with.
|
||||
@ -331,11 +348,13 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
llvm::Twine error_info =
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
"transformation pass");
|
||||
"transformation pass";
|
||||
return allow_tensorlist_pass_through_
|
||||
? rewriter.notifyMatchFailure(op, error_info)
|
||||
: op.emitOpError(error_info);
|
||||
}
|
||||
|
||||
Value element_shape = operands[0];
|
||||
@ -389,10 +408,12 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
if (element_shape_acquired) break;
|
||||
}
|
||||
if (!element_shape_acquired) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
llvm::Twine error_info =
|
||||
"requires element_shape to be 1D tensor during TF Lite "
|
||||
"transformation pass");
|
||||
"transformation pass";
|
||||
return allow_tensorlist_pass_through_
|
||||
? rewriter.notifyMatchFailure(op, error_info)
|
||||
: op.emitOpError(error_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -480,8 +501,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
|
||||
struct ConvertTensorListReserve
|
||||
: public ConvertTensorListInitOp<TF::TensorListReserveOp> {
|
||||
explicit ConvertTensorListReserve(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
explicit ConvertTensorListReserve(MLIRContext *context,
|
||||
bool allow_tensorlist_pass_through)
|
||||
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
|
||||
|
||||
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
@ -503,8 +525,9 @@ struct ConvertTensorListReserve
|
||||
// have a different behavior compared to TensorFlow in case of errors.
|
||||
struct ConvertEmptyTensorList
|
||||
: public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
|
||||
explicit ConvertEmptyTensorList(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
explicit ConvertEmptyTensorList(MLIRContext *context,
|
||||
bool allow_tensorlist_pass_through)
|
||||
: ConvertTensorListInitOp(context, allow_tensorlist_pass_through) {}
|
||||
|
||||
Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
@ -1005,12 +1028,13 @@ void LowerStaticTensorListPass::runOnOperation() {
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(context, patterns);
|
||||
patterns.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
|
||||
ConvertTensorListGetItem, ConvertTensorListLength,
|
||||
ConvertTensorListPushBack, ConvertTensorListReserve,
|
||||
patterns.insert<ConvertConst, ConvertIdentity, ConvertTensorListGetItem,
|
||||
ConvertTensorListLength, ConvertTensorListPushBack,
|
||||
ConvertTensorListSetItem, ConvertTensorListStack,
|
||||
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
|
||||
context);
|
||||
patterns.insert<ConvertEmptyTensorList, ConvertTensorListReserve>(
|
||||
context, allow_tensorlist_pass_through);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
if (!allow_tensorlist_pass_through) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user