Support calibration of models with 8bit matmul output.

PiperOrigin-RevId: 339873117
Change-Id: I40734689dac29e3466240f9f22d11d84171d9eae
Jian Li 2020-10-30 08:15:29 -07:00 committed by TensorFlower Gardener
parent dfa308ca5d
commit 35b14d52e7
9 changed files with 960 additions and 2 deletions
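For context, the end-to-end flow exercised by the new CustomLSTM calibrator test looks roughly like the sketch below. This is a minimal sketch based on the test code added in this commit (includes, status checks, and input filling are elided), not a separate API.

// Sketch of the calibration flow, following calibrator_test.cc below.
using tflite::optimize::calibration::BuildLoggingInterpreter;
using tflite::optimize::calibration::CalibrationReader;

auto model = tflite::FlatBufferModel::BuildFromFile(
    "tensorflow/lite/testdata/custom_lstm.bin");
std::unique_ptr<tflite::Interpreter> interpreter;
std::unique_ptr<CalibrationReader> reader;
BuildLoggingInterpreter(*model, tflite::ops::builtin::BuiltinOpResolver{},
                        &interpreter, &reader);
interpreter->AllocateTensors();
// ... write the float input tensor(s), then interpreter->Invoke() ...
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
reader->GetTensorStatsAsMap(&stats);  // per-tensor min/max, including the 12 LSTM intermediates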

BIN  tensorflow/lite/testdata/custom_lstm.bin (binary file, not shown)

View File

@@ -88,6 +88,33 @@ tf_cc_test(
],
)
tf_cc_test(
name = "quantization_wrapper_utils_custom_test",
srcs = [
"quantization_wrapper_utils.cc",
"quantization_wrapper_utils.h",
"quantization_wrapper_utils_custom_test.cc",
],
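# Compile quantization_wrapper_utils with the custom LSTM path enabled so
# that 12 intermediate tensors are added for 8-bit matmul calibration.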
defines = [
"TFLITE_CUSTOM_LSTM",
],
tags = [
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
deps = [
":operator_property",
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
cc_library(
name = "quantization_wrapper",
srcs = ["quantization_wrapper.cc"],

View File

@@ -29,6 +29,26 @@ cc_library(
],
)
cc_library(
name = "custom_logging_op",
srcs = ["custom_logging_ops/lstm.cc"],
hdrs = ["custom_logging_ops/lstm.h"],
copts = tflite_copts(),
deps = [
":calibration_logger",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:lstm_shared",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:reference",
"//tensorflow/lite/kernels/internal:tensor_utils",
],
)
cc_library(
name = "calibrator_lib",
srcs = ["calibrator.cc"],
@@ -39,6 +59,7 @@ cc_library(
":calibration_common",
":calibration_logger",
":calibration_reader",
":custom_logging_op",
":logging_op",
":logging_op_resolver",
"//tensorflow/lite:framework",
@@ -63,6 +84,7 @@ tf_cc_test(
"--test_model_file=$(location //tensorflow/lite:testdata/multi_add.bin)",
],
data = [
"//tensorflow/lite:testdata/custom_lstm.bin",
"//tensorflow/lite:testdata/lstm.bin",
"//tensorflow/lite:testdata/multi_add.bin",
"//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin",

View File

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
@@ -177,8 +178,12 @@ logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
TfLiteNode* node,
int builtin_op_code) {
switch (builtin_op_code) {
case BuiltinOperator_LSTM:
case BuiltinOperator_LSTM: {
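// LSTM nodes that were prepared with 12 intermediate tensors take the
// custom logging kernel added in this change, which records the float
// matmul output of each gate so it can later be quantized to 8 bits.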
if (node->intermediates->size == 12) {
return tflite::optimize::calibration::custom::lstm_logging_kernel;
}
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
return tflite::optimize::calibration::builtin::
unidirectional_sequence_lstm_logging_kernel;

View File

@@ -397,6 +397,80 @@ TEST(CalibratorTest, UnidirectionalSequenceLSTM) {
}
}
TEST(CalibratorTest, CustomLSTM) {
auto flatbuffer_model = ReadModel("custom_lstm.bin");
ASSERT_TRUE(flatbuffer_model);
std::unique_ptr<Interpreter> interpreter;
std::unique_ptr<CalibrationReader> reader;
auto status = BuildLoggingInterpreter(*flatbuffer_model,
ops::builtin::BuiltinOpResolver{},
&interpreter, &reader);
EXPECT_EQ(kTfLiteOk, status);
auto readonly_model = flatbuffer_model->GetModel();
tflite::ModelT model;
readonly_model->UnPackTo(&model);
ASSERT_TRUE(interpreter);
ASSERT_TRUE(reader);
EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
const std::vector<float> lstm_input = {0.3, 0.2, 0.9, 0.8};
int input_tensor_idx = interpreter->inputs()[0];
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
for (size_t j = 0; j < lstm_input.size(); j++) {
tensor->data.f[j] = lstm_input[j];
}
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
// Check the results.
const float eps = 1e-6f;
const std::unordered_map<int, CalibrationReader::CalibrationStats>
expected_calibration_result = {
// input.
{0, {0.200000, 0.300000}},
// state.
{18, {0.000000, 0.468415}},
// state.
{19, {0.000000, 0.424349}},
// output.
{24, {0.265968, 0.468415}},
// intermediate 0.
{25, {0.080045, 0.170588}},
// intermediate 1.
{26, {0.080045, 0.170588}},
// intermediate 2.
{27, {0.000000, 0.000000}},
// intermediate 3.
{28, {0.080045, 0.170588}},
// intermediate 4.
{29, {0.080045, 0.170588}},
// intermediate 5.
{30, {0.000000, 0.000000}},
// intermediate 6.
{31, {0.080045, 0.170588}},
// intermediate 7.
{32, {0.080045, 0.170588}},
// intermediate 8.
{33, {0.000000, 0.000000}},
// intermediate 9.
{34, {0.080045, 0.170588}},
// intermediate 10.
{35, {0.080045, 0.170588}},
// intermediate 11.
{36, {0.000000, 0.000000}},
};
EXPECT_EQ(expected_calibration_result.size(), stats.size());
for (const auto& e : stats) {
auto expected_result = expected_calibration_result.at(e.first);
EXPECT_NEAR(e.second.min, expected_result.min, eps);
EXPECT_NEAR(e.second.max, expected_result.max, eps);
}
}
} // namespace
} // namespace calibration
} // namespace optimize

View File

@@ -0,0 +1,649 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
#include <algorithm>
#include <cstdio>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
namespace tflite {
namespace optimize {
namespace calibration {
namespace custom {
namespace {
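// Float LSTM step mirroring the reference kernel, extended to log the
// per-gate matmul outputs and pre-activation values (the 12 intermediate
// tensors) through the calibration Logger.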
inline void LstmStepWithAuxInput(
const float* input_ptr, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr, Logger* logger,
const std::vector<int>& intemediate_tensor_indexes,
ErrorReporter* error_reporter) {
// Since we have already checked that the weights are either all present or
// all absent, we can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Initialize the scratch buffers with the gate biases for a regular LSTM, or
// with zeros for a layer-norm LSTM.
if (use_layer_norm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
input_gate_scratch);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
n_cell, n_input, input_ptr,
n_batch, cell_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
output_gate_scratch);
{
// Log the input-weight matmul result of each gate for calibration.
if (!use_cifg) {
logger->LogTensorValue(intemediate_tensor_indexes[1], input_gate_scratch,
n_cell * n_batch, error_reporter);
}
logger->LogTensorValue(intemediate_tensor_indexes[4], forget_gate_scratch,
n_cell * n_batch, error_reporter);
logger->LogTensorValue(intemediate_tensor_indexes[7], cell_scratch,
n_cell * n_batch, error_reporter);
logger->LogTensorValue(intemediate_tensor_indexes[10], output_gate_scratch,
n_cell * n_batch, error_reporter);
}
// If auxiliary input is available, compute aux_input_weight * aux_input.
if (aux_input_ptr != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, input_gate_scratch);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, cell_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, output_gate_scratch);
}
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, input_gate_scratch);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, cell_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch);
{
// Log the recurrent-weight matmul contribution of each gate for calibration.
if (!use_cifg) {
std::vector<float> temp_input(n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, temp_input.data());
logger->LogTensorValue(intemediate_tensor_indexes[2], temp_input.data(),
n_cell * n_batch, error_reporter);
}
std::vector<float> temp_forget(n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, temp_forget.data());
logger->LogTensorValue(intemediate_tensor_indexes[5], temp_forget.data(),
n_cell * n_batch, error_reporter);
std::vector<float> temp_cell(n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, temp_cell.data());
logger->LogTensorValue(intemediate_tensor_indexes[8], temp_cell.data(),
n_cell * n_batch, error_reporter);
std::vector<float> temp_output(n_batch * n_cell);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, temp_output.data());
logger->LogTensorValue(intemediate_tensor_indexes[11], temp_output.data(),
n_cell * n_batch, error_reporter);
}
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
// For each batch and cell: update forget gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[3], forget_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[6], cell_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::CwiseClipping(cell_state_ptr, n_batch * n_cell,
params->cell_clip);
}
// For each batch and cell: update the output gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[9], output_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
// For each batch: update the projection and output_state. Note that since
// the output batch rows may not be contiguous (output_batch_leading_dim !=
// n_output), we unroll batched operations.
if (use_projection_weight) {
if (use_projection_bias) {
for (int k = 0; k < n_batch; k++) {
std::copy_n(projection_bias_ptr, n_output,
output_ptr + k * output_batch_leading_dim);
}
} else {
for (int k = 0; k < n_batch; k++) {
std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
}
}
for (int k = 0; k < n_batch; k++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
output_gate_scratch + k * n_cell,
/*n_batch=*/1, output_ptr + k * output_batch_leading_dim);
if (params->proj_clip > 0.0) {
tensor_utils::CwiseClipping(output_ptr + k * output_batch_leading_dim,
n_output, params->proj_clip);
}
}
} else {
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_gate_scratch + k * n_output, n_output,
output_ptr + k * output_batch_leading_dim);
}
}
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
output_state_ptr + k * n_output);
}
}
TfLiteStatus EvalFloat(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, Logger* logger,
const std::vector<int>& intemediate_tensor_indexes,
ErrorReporter* error_reporter) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
int max_time, n_batch;
if (input->dims->size == 3) {
max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
} else {
max_time = 1;
n_batch = input->dims->data[0];
}
const int n_input = input->dims->data[input->dims->size - 1];
const int aux_input_size =
(aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Since we have already checked that the weights are either all present or
// all absent, we can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
// Index the scratch buffer pointers into the global scratch buffer.
float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer_ptr;
forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer_ptr;
cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
}
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
if (time_major) {
// Loop through the sequence.
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
const float* aux_input_ptr = nullptr;
if (aux_input) {
aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
}
float* output_ptr_time =
GetTensorData<float>(output) + t_rel * output_step + output_offset;
LstmStepWithAuxInput(
input_ptr, GetTensorData<float>(input_to_input_weights),
GetTensorData<float>(input_to_forget_weights),
GetTensorData<float>(input_to_cell_weights),
GetTensorData<float>(input_to_output_weights), aux_input_ptr,
GetTensorData<float>(aux_input_to_input_weights),
GetTensorData<float>(aux_input_to_forget_weights),
GetTensorData<float>(aux_input_to_cell_weights),
GetTensorData<float>(aux_input_to_output_weights),
GetTensorData<float>(recurrent_to_input_weights),
GetTensorData<float>(recurrent_to_forget_weights),
GetTensorData<float>(recurrent_to_cell_weights),
GetTensorData<float>(recurrent_to_output_weights),
GetTensorData<float>(cell_to_input_weights),
GetTensorData<float>(cell_to_forget_weights),
GetTensorData<float>(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
output_ptr_time, logger, intemediate_tensor_indexes, error_reporter);
}
} else {
for (int b = 0; b < n_batch; b++) {
const int input_step = n_input;
const int output_step = output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const int time_offset = b * max_time + t_rel;
const float* input_ptr =
GetTensorData<float>(input) + time_offset * input_step;
const float* aux_input_ptr = nullptr;
if (aux_input) {
aux_input_ptr =
GetTensorData<float>(aux_input) + time_offset * input_step;
}
float* output_ptr = GetTensorData<float>(output) +
time_offset * output_step + output_offset;
// Offset the {activation,cell}_state pointers to the right batch.
float* activation_state_ptr = GetTensorData<float>(activation_state) +
b * output_batch_leading_dim;
float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
// Offset the scratch pointers to the right batch.
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
float* cell_scratch_ptr = cell_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepWithAuxInput(
input_ptr, GetTensorData<float>(input_to_input_weights),
GetTensorData<float>(input_to_forget_weights),
GetTensorData<float>(input_to_cell_weights),
GetTensorData<float>(input_to_output_weights), aux_input_ptr,
GetTensorData<float>(aux_input_to_input_weights),
GetTensorData<float>(aux_input_to_forget_weights),
GetTensorData<float>(aux_input_to_cell_weights),
GetTensorData<float>(aux_input_to_output_weights),
GetTensorData<float>(recurrent_to_input_weights),
GetTensorData<float>(recurrent_to_forget_weights),
GetTensorData<float>(recurrent_to_cell_weights),
GetTensorData<float>(recurrent_to_output_weights),
GetTensorData<float>(cell_to_input_weights),
GetTensorData<float>(cell_to_forget_weights),
GetTensorData<float>(cell_to_output_weights),
GetTensorData<float>(input_layer_norm_coefficients),
GetTensorData<float>(forget_layer_norm_coefficients),
GetTensorData<float>(cell_layer_norm_coefficients),
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
output_ptr, logger, intemediate_tensor_indexes, error_reporter);
}
}
}
return kTfLiteOk;
}
struct OpData {
// Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
// inputs).
// Please note the 20-input full kernel is deprecated and only kept
// here for backward compatibility.
TfLiteLSTMKernelType kernel_type;
// Whether the LSTM uses layer normalization.
bool use_layer_norm;
// These fields are only used by the full kernel.
int scratch_tensor_index;
};
// Gathers the LSTM tensors from the node and runs the float evaluation while
// logging the intermediate values needed for calibration.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input =
GetInput(context, node, ops::builtin::lstm::full::kInputTensor);
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights = GetInput(
context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node,
ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node,
ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
context, node,
ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
context, node,
ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);
const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_bias =
GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
const TfLiteTensor* projection_bias = GetOptionalInputTensor(
context, node, ops::builtin::lstm::full::kProjectionBiasTensor);
// Index the scratch buffer pointers into the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* activation_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TfLiteTensor* output =
GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);
std::vector<int> intemediate_tensor_indexes(node->intermediates->size);
for (int i = 0; i < node->intermediates->size; ++i) {
intemediate_tensor_indexes[i] = node->intermediates->data[i];
}
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return EvalFloat(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output, logger, intemediate_tensor_indexes, error_reporter);
}
case kTfLiteUInt8:
case kTfLiteInt8:
default:
printf("Error. Only float model can be calibrated\n");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger,
ErrorReporter* error_reporter) {
return lstm_eval(context, node, logger, error_reporter);
}
} // namespace custom
} // namespace calibration
} // namespace optimize
} // namespace tflite

View File

@@ -0,0 +1,34 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CUSTOM_LOGGING_OPS_LSTM_H_
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CUSTOM_LOGGING_OPS_LSTM_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
namespace tflite {
namespace optimize {
namespace calibration {
namespace custom {
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger, ErrorReporter* error_reporter);
} // namespace custom
} // namespace calibration
} // namespace optimize
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_CALIBRATION_CUSTOM_LOGGING_OPS_LSTM_H_

View File

@@ -24,6 +24,12 @@ namespace tflite {
namespace optimize {
namespace {
#ifdef TFLITE_CUSTOM_LSTM
constexpr bool kUseCustomLSTM = true;
#else
constexpr bool kUseCustomLSTM = false;
#endif
void MakeTensor(const string& name, std::unique_ptr<TensorT>* tensor) {
TensorT* tensor_raw = new TensorT;
tensor_raw->name = name;
@@ -90,7 +96,10 @@ TfLiteStatus AddIntermediateTensorsToFusedOp(
}
// Add tensors.
const int next_tensor_index = subgraph->tensors.size();
const int num_intermediates = property.intermediates.size();
int num_intermediates = property.intermediates.size();
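// The custom LSTM calibration kernel logs three values per gate (input
// matmul, recurrent matmul, and the accumulated pre-activation), so it
// needs 12 intermediate tensors rather than the ones listed in
// operator_property.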
if (kUseCustomLSTM) {
num_intermediates = 12;
}
for (int i = 0; i < num_intermediates; ++i) {
std::unique_ptr<TensorT> intermediate_tensor;
auto name = CreateTensorName(op_idx, i);

View File

@@ -0,0 +1,138 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/tools/optimize/quantization_wrapper_utils.h"
namespace tflite {
namespace optimize {
namespace {
using ::testing::ElementsAreArray;
TEST(LstmPreprocess, Add2Tensors) {
// Create a model with 1 lstm layer.
auto model = absl::make_unique<ModelT>();
auto subgraph = absl::make_unique<tflite::SubGraphT>();
auto buffer = absl::make_unique<tflite::BufferT>();
auto lstm_op_code = absl::make_unique<OperatorCodeT>();
auto lstm_op = absl::make_unique<OperatorT>();
lstm_op_code->builtin_code = BuiltinOperator_LSTM;
lstm_op_code->deprecated_builtin_code =
static_cast<int8_t>(BuiltinOperator_LSTM);
lstm_op_code->version = 2;
lstm_op->opcode_index = 0;
lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
lstm_op->outputs = {24};
model->subgraphs.push_back(std::move(subgraph));
for (int i = 0; i < lstm_op->inputs.size(); ++i) {
const int index = lstm_op->inputs[i];
if (index == -1) {
continue;
}
auto tensor = absl::make_unique<TensorT>();
tensor->name = "lstm_tensor" + std::to_string(index);
tensor->shape = {2, 3, 4};
tensor->type = TensorType_FLOAT32;
model->subgraphs[0]->tensors.push_back(std::move(tensor));
}
model->subgraphs[0]->operators.push_back(std::move(lstm_op));
model->operator_codes.push_back(std::move(lstm_op_code));
model->buffers.push_back(std::move(buffer));
// Add the intermediate tensors.
flatbuffers::FlatBufferBuilder builder;
tflite::optimize::AddIntermediateTensorsToFusedOp(&builder, model.get());
// Verify results.
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 33);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()),
BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
EXPECT_EQ(model->subgraphs[0]->tensors[26]->name, "intermediate_0_5");
EXPECT_EQ(model->subgraphs[0]->tensors[27]->name, "intermediate_0_6");
EXPECT_EQ(model->subgraphs[0]->tensors[28]->name, "intermediate_0_7");
EXPECT_EQ(model->subgraphs[0]->tensors[29]->name, "intermediate_0_8");
EXPECT_EQ(model->subgraphs[0]->tensors[30]->name, "intermediate_0_9");
EXPECT_EQ(model->subgraphs[0]->tensors[31]->name, "intermediate_0_10");
EXPECT_EQ(model->subgraphs[0]->tensors[32]->name, "intermediate_0_11");
EXPECT_THAT(
model->subgraphs[0]->operators[0]->inputs,
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
ElementsAreArray({24}));
EXPECT_THAT(
model->subgraphs[0]->operators[0]->intermediates,
ElementsAreArray({21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}));
// Call AddIntermediateTensorsToFusedOp again and expect no change in the model.
tflite::optimize::AddIntermediateTensorsToFusedOp(&builder, model.get());
// Verify results.
EXPECT_EQ(model->operator_codes.size(), 1);
EXPECT_EQ(model->subgraphs.size(), 1);
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 33);
EXPECT_EQ(model->buffers.size(), 1);
EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()),
BuiltinOperator_LSTM);
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0");
EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1");
EXPECT_EQ(model->subgraphs[0]->tensors[23]->name, "intermediate_0_2");
EXPECT_EQ(model->subgraphs[0]->tensors[24]->name, "intermediate_0_3");
EXPECT_EQ(model->subgraphs[0]->tensors[25]->name, "intermediate_0_4");
EXPECT_EQ(model->subgraphs[0]->tensors[26]->name, "intermediate_0_5");
EXPECT_EQ(model->subgraphs[0]->tensors[27]->name, "intermediate_0_6");
EXPECT_EQ(model->subgraphs[0]->tensors[28]->name, "intermediate_0_7");
EXPECT_EQ(model->subgraphs[0]->tensors[29]->name, "intermediate_0_8");
EXPECT_EQ(model->subgraphs[0]->tensors[30]->name, "intermediate_0_9");
EXPECT_EQ(model->subgraphs[0]->tensors[31]->name, "intermediate_0_10");
EXPECT_EQ(model->subgraphs[0]->tensors[32]->name, "intermediate_0_11");
EXPECT_THAT(
model->subgraphs[0]->operators[0]->inputs,
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1,
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}));
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
ElementsAreArray({24}));
EXPECT_THAT(
model->subgraphs[0]->operators[0]->intermediates,
ElementsAreArray({21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}));
}
} // namespace
} // namespace optimize
} // namespace tflite
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}