Register Unidirectional_sequence_lstm logging op in calibrator.

PiperOrigin-RevId: 337529861
Change-Id: Ie4d13f8066cf6a75f2baab66e016336a95302c93
A. Unique TensorFlower 2020-10-16 10:19:35 -07:00 committed by TensorFlower Gardener
parent 0c4416e3c2
commit 54081bae79
7 changed files with 128 additions and 26 deletions

Binary file not shown.


@ -65,6 +65,7 @@ tf_cc_test(
data = [
"//tensorflow/lite:testdata/lstm.bin",
"//tensorflow/lite:testdata/multi_add.bin",
"//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin",
],
tags = [
"tflite_not_portable_android",


@ -461,10 +461,9 @@ struct OpData {
// Resize the output, state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
LSTMType lstm_type, Logger* logger,
ErrorReporter* error_reporter) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(
context, GetInputSafe(context, node,
@ -578,6 +577,31 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
intermediate_tensor_indexes[i] = node->intermediates->data[i];
}
TfLiteLSTMParams lstm_params;
bool time_major = true;
switch (lstm_type) {
case LSTMType::kLSTM: {
lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
time_major = true;
break;
}
case LSTMType::kUnidirectionalSequenceLSTM: {
const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
node->builtin_data);
// Copy out the LSTM-specific params so they can be passed to the
// function.
lstm_params.activation = params->activation;
lstm_params.cell_clip = params->cell_clip;
lstm_params.proj_clip = params->proj_clip;
lstm_params.asymmetric_quantize_inputs =
params->asymmetric_quantize_inputs;
time_major = params->time_major;
break;
}
default:
return kTfLiteError;
}
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return EvalCalibration(
@ -594,9 +618,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, params,
projection_weights, projection_bias, &lstm_params,
/*forward_sequence=*/true,
/*time_major=*/true,
/*time_major=*/time_major,
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
logger, intermediate_tensor_indexes, error_reporter);
}
@ -613,7 +637,14 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger,
ErrorReporter* error_reporter) {
return lstm_eval(context, node, logger, error_reporter);
return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
}
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter) {
return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
error_reporter);
}
} // namespace builtin
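
The new switch above normalizes both ops' builtin_data into a single TfLiteLSTMParams plus a time_major flag, so one lstm_eval path serves both BuiltinOperator_LSTM and BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM. Below is a minimal standalone sketch of that normalization, assuming only the public builtin_op_data.h structs; NormalizedLstmParams and ExtractLstmParams are hypothetical names, not part of the calibrator.

// Minimal sketch of the params normalization; the struct and helper names are
// hypothetical, not part of the calibrator.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

namespace {

enum class LSTMType {
  kLSTM,
  kUnidirectionalSequenceLSTM,
};

struct NormalizedLstmParams {
  TfLiteLSTMParams lstm_params;
  bool time_major;
};

// Maps either builtin_data layout onto plain LSTM params plus the time_major
// flag, mirroring the switch in lstm_eval.
TfLiteStatus ExtractLstmParams(LSTMType lstm_type, void* builtin_data,
                               NormalizedLstmParams* out) {
  switch (lstm_type) {
    case LSTMType::kLSTM:
      out->lstm_params = *static_cast<TfLiteLSTMParams*>(builtin_data);
      out->time_major = true;  // The builtin LSTM op is always time-major.
      return kTfLiteOk;
    case LSTMType::kUnidirectionalSequenceLSTM: {
      const auto* params =
          static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(builtin_data);
      out->lstm_params.activation = params->activation;
      out->lstm_params.cell_clip = params->cell_clip;
      out->lstm_params.proj_clip = params->proj_clip;
      out->lstm_params.asymmetric_quantize_inputs =
          params->asymmetric_quantize_inputs;
      out->time_major = params->time_major;
      return kTfLiteOk;
    }
  }
  return kTfLiteError;
}

}  // namespace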


@ -23,9 +23,18 @@ namespace optimize {
namespace calibration {
namespace builtin {
enum class LSTMType {
kLSTM,
kUnidirectionalSequenceLSTM,
};
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
Logger* logger, ErrorReporter* error_reporter);
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
TfLiteContext* context, TfLiteNode* node, Logger* logger,
ErrorReporter* error_reporter);
} // namespace builtin
} // namespace calibration
} // namespace optimize


@ -174,13 +174,17 @@ GlobalCalibratorRegistry* GetCalibratorRegistry() {
// TODO(jianlijianli): extend this to support multiple recipes for the same
// model.
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
TfLiteNode* node) {
const int lstm_number_input = 24;
if (node->inputs->size == lstm_number_input) {
// LSTM Op.
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
TfLiteNode* node,
int builtin_op_code) {
switch (builtin_op_code) {
case BuiltinOperator_LSTM:
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
return tflite::optimize::calibration::builtin::
unidirectional_sequence_lstm_logging_kernel;
default:
return nullptr;
}
return nullptr;
}
// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
@ -203,7 +207,9 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
}
auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
auto kernel_invoke_intermediate =
GetLoggingEvalFunc(context, node, builtin_op_code);
TfLiteStatus status;
if (kernel_invoke_intermediate == nullptr) {
status = kernel_invoke(context, node);
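
Kernel selection now keys on the node's recorded builtin op code rather than the old heuristic of checking for 24 inputs, and any op without a dedicated logging kernel falls back to its own invoke. Here is a self-contained sketch of that dispatch pattern, assuming only the schema's BuiltinOperator enum; SelectLoggingKernel and the stub kernels are illustrative names, not the calibrator's API.

// Standalone sketch of op-code based dispatch; the stub kernels here just
// return kTfLiteOk, whereas the real logging kernels record tensor min/max.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {

typedef TfLiteStatus (*LoggingKernelFn)(TfLiteContext*, TfLiteNode*);

TfLiteStatus LstmLoggingStub(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteStatus UniSeqLstmLoggingStub(TfLiteContext*, TfLiteNode*) {
  return kTfLiteOk;
}

// Selects the logging kernel from the builtin op code; a nullptr result means
// the caller should fall back to the original kernel invoke.
LoggingKernelFn SelectLoggingKernel(tflite::BuiltinOperator op_code) {
  switch (op_code) {
    case tflite::BuiltinOperator_LSTM:
      return LstmLoggingStub;
    case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
      return UniSeqLstmLoggingStub;
    default:
      return nullptr;
  }
}

}  // namespace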


@ -283,7 +283,7 @@ TEST(CalibratorTest, LSTM) {
auto status = BuildLoggingInterpreter(*flatbuffer_model,
ops::builtin::BuiltinOpResolver{},
&interpreter, &reader);
EXPECT_EQ(kTfLiteOk, status);
EXPECT_EQ(status, kTfLiteOk);
auto readonly_model = flatbuffer_model->GetModel();
tflite::ModelT model;
@ -294,24 +294,17 @@ TEST(CalibratorTest, LSTM) {
status = interpreter->AllocateTensors();
EXPECT_EQ(kTfLiteOk, status);
const std::vector<float> lstm_input = {
0.3, 0.2, 0.9, 0.8, 0.1, //
0.1, 0.5, 0.2, 0.4, 0.2, //
0.6, 0.9, 0.2, 0.5, 0.7, //
};
const std::vector<float> lstm_input = {0.3, 0.2};
int input_tensor_idx = interpreter->inputs()[0];
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
for (size_t j = 0; j < lstm_input.size(); j++) {
tensor->data.f[j] = lstm_input[j];
}
// Invoke with update == true.
status = interpreter->Invoke();
ASSERT_EQ(kTfLiteOk, status);
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
status = reader->GetTensorStatsAsMap(&stats);
EXPECT_EQ(kTfLiteOk, status);
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
// Check the results.
const float eps = 1e-6f;
@ -344,6 +337,66 @@ TEST(CalibratorTest, LSTM) {
}
}
TEST(CalibratorTest, UnidirectionalSequenceLSTM) {
auto flatbuffer_model = ReadModel("unidirectional_sequence_lstm.bin");
ASSERT_TRUE(flatbuffer_model);
std::unique_ptr<Interpreter> interpreter;
std::unique_ptr<CalibrationReader> reader;
auto status = BuildLoggingInterpreter(*flatbuffer_model,
ops::builtin::BuiltinOpResolver{},
&interpreter, &reader);
EXPECT_EQ(kTfLiteOk, status);
auto readonly_model = flatbuffer_model->GetModel();
tflite::ModelT model;
readonly_model->UnPackTo(&model);
ASSERT_TRUE(interpreter);
ASSERT_TRUE(reader);
EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
const std::vector<float> lstm_input = {0.3, 0.2, 0.9, 0.8};
int input_tensor_idx = interpreter->inputs()[0];
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
for (size_t j = 0; j < lstm_input.size(); j++) {
tensor->data.f[j] = lstm_input[j];
}
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
// Check the results.
const float eps = 1e-6f;
const std::unordered_map<int, CalibrationReader::CalibrationStats>
expected_calibration_result = {
// Input.
{0, {0.200000, 0.900000}},
// State.
{18, {0.000000, 0.520999}},
// State.
{19, {0.000000, 0.711364}},
// Output.
{24, {0.247992, 0.520999}},
// Intermediate_0.
{25, {0.080045, 0.824241}},
// Intermediate_1.
{26, {0.080045, 0.824241}},
// Intermediate_2.
{27, {0.080045, 0.824241}},
// Intermediate_3.
{28, {0.080045, 0.824241}},
// Intermediate_4.
{29, {0.000000, 0.413618}},
};
EXPECT_EQ(expected_calibration_result.size(), stats.size());
for (const auto& e : stats) {
auto expected_result = expected_calibration_result.at(e.first);
EXPECT_NEAR(e.second.min, expected_result.min, eps);
EXPECT_NEAR(e.second.max, expected_result.max, eps);
}
}
} // namespace
} // namespace calibration
} // namespace optimize
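
The new test mirrors the existing LSTM test, using the unidirectional_sequence_lstm.bin model. Outside the test harness the same flow can be driven directly; the sketch below is assembled from the calls exercised above (BuildLoggingInterpreter, AllocateTensors, Invoke, GetTensorStatsAsMap), with a placeholder model path and minimal error handling.

// Usage sketch of the calibration flow; the model path is a placeholder.
#include <cstdio>
#include <memory>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibration/calibrator.h"

int main() {
  namespace calib = tflite::optimize::calibration;
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "unidirectional_sequence_lstm.bin");  // Placeholder path.
  if (!model) return 1;

  std::unique_ptr<tflite::Interpreter> interpreter;
  std::unique_ptr<calib::CalibrationReader> reader;
  if (calib::BuildLoggingInterpreter(*model,
                                     tflite::ops::builtin::BuiltinOpResolver{},
                                     &interpreter, &reader) != kTfLiteOk ||
      interpreter->AllocateTensors() != kTfLiteOk) {
    return 1;
  }

  // Feed one representative batch and run, so the logging kernels record
  // per-tensor min/max.
  const float sample[] = {0.3f, 0.2f, 0.9f, 0.8f};
  TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
  for (int i = 0; i < 4; ++i) input->data.f[i] = sample[i];
  if (interpreter->Invoke() != kTfLiteOk) return 1;

  absl::flat_hash_map<int, calib::CalibrationReader::CalibrationStats> stats;
  if (reader->GetTensorStatsAsMap(&stats) != kTfLiteOk) return 1;
  for (const auto& entry : stats) {
    std::printf("tensor %d: min=%f max=%f\n", entry.first, entry.second.min,
                entry.second.max);
  }
  return 0;
}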


@ -44,7 +44,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
model->subgraphs.at(subgraph_index)->operators[op_index].get();
op_variant.op_code =
GetBuiltinCode(model->operator_codes[op->opcode_index].get());
if (op_variant.op_code == BuiltinOperator_LSTM) {
if (op_variant.op_code == BuiltinOperator_LSTM ||
op_variant.op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
if (op->inputs.size() == 5) {
// The 5-input ("basic") LSTM is not supported in this tooling (yet).
op_variant.is_quantizable = false;
@ -230,7 +231,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.version = 2;
break;
}
case BuiltinOperator_LSTM: {
case BuiltinOperator_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
if (!op_variant.is_quantizable) {
// Early exit for 5-input LSTM.
// It is not supported in this tooling yet.
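
The operator_property change lets UNIDIRECTIONAL_SEQUENCE_LSTM reuse the existing LSTM quantization recipe while still skipping the 5-input ("basic") variant. Below is a standalone restatement of that guard, assuming only the schema's BuiltinOperator enum; IsQuantizableLstmVariant is a hypothetical helper, not part of operator_property.

// Restatement of the quantizability guard; the helper name is hypothetical.
#include "tensorflow/lite/schema/schema_generated.h"

namespace {

// Both LSTM flavours share the full LSTM recipe; the 5-input "basic" LSTM
// remains unsupported by the quantization tooling.
bool IsQuantizableLstmVariant(tflite::BuiltinOperator op_code,
                              int num_inputs) {
  const bool is_lstm_like =
      op_code == tflite::BuiltinOperator_LSTM ||
      op_code == tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM;
  return is_lstm_like && num_inputs != 5;
}

}  // namespace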