[TFLite] - Support to stack layers in babelfish rnnt.
Add a module pass to stack the various layers in encoder/decoder in RNN-T. This pass will rely on future function inlining, DCE and canonicalization to ensure the module contains the appropriate IR elements that can then be legalized and exported to TFLite flatbuffer. PiperOrigin-RevId: 269723812
This commit is contained in:
parent
ac7d0914c6
commit
8ab4a0a7bb
tensorflow/compiler/mlir/lite/utils
@ -313,7 +313,9 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
|
||||
fused_func_op_.setAttr("tf._implements",
|
||||
builder_.getStringAttr(GetCompositeOpName()));
|
||||
}
|
||||
SmallVector<int64_t, 2> output_shape{1, n_output_};
|
||||
// SmallVector<int64_t, 2> output_shape{1, n_output_};
|
||||
// TODO(b/141026710):
|
||||
SmallVector<int64_t, 2> output_shape{1, -1};
|
||||
auto input_types = fused_func_op_.getType().getInputs();
|
||||
auto output_type = builder_.getTensorType(
|
||||
output_shape,
|
||||
@ -333,7 +335,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
|
||||
GenerateFusedOpOperands();
|
||||
|
||||
// Create the fused LSTM op.
|
||||
SmallVector<int64_t, 2> output_shape = {1, n_output_};
|
||||
SmallVector<int64_t, 2> output_shape = {1, -1};
|
||||
auto result_type = builder_.getTensorType(
|
||||
output_shape,
|
||||
input_->getType().cast<RankedTensorType>().getElementType());
|
||||
|
@ -144,7 +144,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
|
||||
|
||||
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
|
||||
auto output_types = fused_lstm_func_.getType().getResults();
|
||||
SmallVector<int64_t, 2> output_shape{1, 2};
|
||||
SmallVector<int64_t, 2> output_shape{1, -1};
|
||||
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
|
||||
output_shape.size());
|
||||
for (int i = 0; i < output_shape.size(); i++) {
|
||||
@ -209,7 +209,7 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
|
||||
|
||||
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
|
||||
auto output_types = fused_lstm_func_.getType().getResults();
|
||||
SmallVector<int64_t, 2> output_shape{1, 2};
|
||||
SmallVector<int64_t, 2> output_shape{1, -1};
|
||||
EXPECT_EQ(output_types[0].cast<RankedTensorType>().getShape().size(),
|
||||
output_shape.size());
|
||||
for (int i = 0; i < output_shape.size(); i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user