Refactor test to enable support of other variants of LSTM.

PiperOrigin-RevId: 282605030
Change-Id: I407fce1f3a419d6180cf494c7387dc6cbf8389dd
This commit is contained in:
Jian Li 2019-11-26 11:25:58 -08:00 committed by TensorFlower Gardener
parent c4e8dc612f
commit d29ae17e5a

View File

@ -30,23 +30,30 @@ TEST(LstmPreprocess, Add2Tensors) {
// Create a model with 1 lstm layer.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
auto tensor = absl::make_unique<TensorT>();
auto buffer = absl::make_unique<tflite::BufferT>();
auto lstm_op_code = absl::make_unique<OperatorCodeT>();
auto lstm_op = absl::make_unique<OperatorT>();
tensor->name = "lstm_tensor0";
tensor->shape = {2, 3, 4};
tensor->type = TensorType_FLOAT32;
lstm_op_code->builtin_code = BuiltinOperator_LSTM;
lstm_op_code->version = 2;
lstm_op->opcode_index = 0;
lstm_op->inputs = {0};
lstm_op->outputs = {0};
lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
lstm_op->outputs = {24};
model->subgraphs.push_back(std::move(subgraph));
for (int i = 0; i < lstm_op->inputs.size(); ++i) {
const int index = lstm_op->inputs[i];
if (index == -1) {
continue;
}
auto tensor = absl::make_unique<TensorT>();
tensor->name = "lstm_tensor" + std::to_string(index);
tensor->shape = {2, 3, 4};
tensor->type = TensorType_FLOAT32;
model->subgraphs[0]->tensors.push_back(std::move(tensor));
}
model->subgraphs[0]->operators.push_back(std::move(lstm_op));
model->subgraphs[0]->tensors.push_back(std::move(tensor));
model->operator_codes.push_back(std::move(lstm_op_code));
model->buffers.push_back(std::move(buffer));
@ -58,21 +65,24 @@ TEST(LstmPreprocess, Add2Tensors) {
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
EXPECT_THAT(
model->subgraphs[0]->operators[0]->inputs,
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
ElementsAreArray({0}));
ElementsAreArray({24}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
ElementsAreArray({1, 2, 3, 4, 5}));
ElementsAreArray({21, 22, 23, 24, 25}));
// Call AddIntemediateTensorsToFusedOp again and expect no change in model.
tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
@ -81,21 +91,24 @@ TEST(LstmPreprocess, Add2Tensors) {
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
EXPECT_THAT(
model->subgraphs[0]->operators[0]->inputs,
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
ElementsAreArray({0}));
ElementsAreArray({24}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
ElementsAreArray({1, 2, 3, 4, 5}));
ElementsAreArray({21, 22, 23, 24, 25}));
}
} // namespace