Update some tflite conversions to tracked transformations
PiperOrigin-RevId: 347694794 Change-Id: I4ea06a87ffd180ede4d5e94ea6e68eac0189ada7
This commit is contained in:
parent
2cc7706804
commit
5ef17067d1
@ -466,12 +466,11 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
|
||||
|
||||
auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
|
||||
Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
|
||||
op->getLoc(), result_types, inputs, attributes);
|
||||
|
||||
// Rewire the output.
|
||||
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -525,12 +524,11 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
|
||||
|
||||
auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
|
||||
Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
|
||||
op->getLoc(), result_types, inputs, attributes);
|
||||
|
||||
// Rewire the output.
|
||||
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.replaceOp(op, {nullptr, rnn_result});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user