Update some tflite conversions to tracked transformations

PiperOrigin-RevId: 347694794
Change-Id: I4ea06a87ffd180ede4d5e94ea6e68eac0189ada7
This commit is contained in:
Tres Popp 2020-12-15 14:35:34 -08:00 committed by TensorFlower Gardener
parent 2cc7706804
commit 5ef17067d1

View File

@ -466,12 +466,11 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
attributes.push_back(
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
auto lstm_op = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
Value lstm_result = rewriter.create<TFL::UnidirectionalSequenceLSTMOp>(
op->getLoc(), result_types, inputs, attributes);
// Rewire the output.
op->getResult(2).replaceAllUsesWith(lstm_op.getResult());
rewriter.eraseOp(op);
rewriter.replaceOp(op, {nullptr, nullptr, lstm_result});
return success();
}
};
@ -525,12 +524,11 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
attributes.push_back(
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
Value rnn_result = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
op->getLoc(), result_types, inputs, attributes);
// Rewire the output.
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
rewriter.eraseOp(op);
rewriter.replaceOp(op, {nullptr, rnn_result});
return success();
}