diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 3f52c373f42..d98101bd4cb 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -236,7 +236,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() { void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() { proj_bias_ = !projection_type_ ? none_ - : CreateI32SplatConst(&builder_, {n_output_}, 0, + : CreateF32SplatConst(&builder_, {n_output_}, 0, fused_func_op_.getLoc()); } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 072a057556e..56d6ab1f8ab 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -135,6 +135,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { EXPECT_FALSE(return_op->getOperand(1)->getType().isa()); // input layer norm is None EXPECT_TRUE(return_op->getOperand(20)->getType().isa()); + // proj_bias is F32 + EXPECT_TRUE(return_op->getOperand(17) + ->getType() + .cast() + .getElementType() + .isF32()); EXPECT_EQ(fused_lstm_func_.getType().getNumResults(), 1); auto output_types = fused_lstm_func_.getType().getResults();