[TFLite] - Support to stack layers in babelfish rnnt.

Add a module pass to stack the various layers in encoder/decoder in RNN-T. This pass will rely on future function inlining, DCE and canonicalization to ensure the module contains the appropriate IR elements that can then be legalized and exported to TFLite flatbuffer.

PiperOrigin-RevId: 269723812
This commit is contained in:
Ashwin Murthy 2019-09-17 21:57:37 -07:00 committed by TensorFlower Gardener
parent ac7d0914c6
commit 8ab4a0a7bb
2 changed files with 6 additions and 4 deletions

View File

@ -313,7 +313,9 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
fused_func_op_.setAttr("tf._implements",
builder_.getStringAttr(GetCompositeOpName()));
}
SmallVector<int64_t, 2> output_shape{1, n_output_};
// SmallVector<int64_t, 2> output_shape{1, n_output_};
// TODO(b/141026710):
SmallVector<int64_t, 2> output_shape{1, -1};
auto input_types = fused_func_op_.getType().getInputs();
auto output_type = builder_.getTensorType(
output_shape,
@ -333,7 +335,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
GenerateFusedOpOperands();
// Create the fused LSTM op.
SmallVector<int64_t, 2> output_shape = {1, n_output_};
SmallVector<int64_t, 2> output_shape = {1, -1};
auto result_type = builder_.getTensorType(
output_shape,
input_->getType().cast<RankedTensorType>().getElementType());

View File

@ -144,7 +144,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
SmallVector<int64_t, 2> output_shape{1, -1};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {
@ -209,7 +209,7 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();
SmallVector<int64_t, 2> output_shape{1, 2};
SmallVector<int64_t, 2> output_shape{1, -1};
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
output_shape.size());
for (int i = 0; i < output_shape.size(); i++) {