[TFLite] LSTM projection bias should be F32 tensor

PiperOrigin-RevId: 268960320
This commit is contained in:
Ashwin Murthy 2019-09-13 12:42:05 -07:00 committed by TensorFlower Gardener
parent 3b5b75e304
commit 9dc630468f
2 changed files with 7 additions and 1 deletions

View File

@ -236,7 +236,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
proj_bias_ = !projection_type_
? none_
: CreateI32SplatConst(&builder_, {n_output_}, 0,
: CreateF32SplatConst(&builder_, {n_output_}, 0,
fused_func_op_.getLoc());
}

View File

@ -135,6 +135,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
// input layer norm is None
EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
// proj_bias is F32
EXPECT_TRUE(return_op->getOperand(17)
->getType()
.cast<RankedTensorType>()
.getElementType()
.isF32());
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_lstm_func_.getType().getResults();