[TFLite] LSTM projection bias should be F32 tensor
PiperOrigin-RevId: 268960320
This commit is contained in:
parent
3b5b75e304
commit
9dc630468f
@ -236,7 +236,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
|
||||
void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
|
||||
proj_bias_ = !projection_type_
|
||||
? none_
|
||||
: CreateI32SplatConst(&builder_, {n_output_}, 0,
|
||||
: CreateF32SplatConst(&builder_, {n_output_}, 0,
|
||||
fused_func_op_.getLoc());
|
||||
}
|
||||
|
||||
|
@ -135,6 +135,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
|
||||
EXPECT_FALSE(return_op->getOperand(1)->getType().isa<NoneType>());
|
||||
// input layer norm is None
|
||||
EXPECT_TRUE(return_op->getOperand(20)->getType().isa<NoneType>());
|
||||
// proj_bias is F32
|
||||
EXPECT_TRUE(return_op->getOperand(17)
|
||||
->getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.isF32());
|
||||
|
||||
EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1);
|
||||
auto output_types = fused_lstm_func_.getType().getResults();
|
||||
|
Loading…
x
Reference in New Issue
Block a user