Register Unidirectional_sequence_lstm logging op in calibrator.
PiperOrigin-RevId: 337529861 Change-Id: Ie4d13f8066cf6a75f2baab66e016336a95302c93
This commit is contained in:
parent
0c4416e3c2
commit
54081bae79
tensorflow/lite
testdata
tools/optimize
BIN
tensorflow/lite/testdata/unidirectional_sequence_lstm.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/unidirectional_sequence_lstm.bin
vendored
Normal file
Binary file not shown.
@ -65,6 +65,7 @@ tf_cc_test(
|
|||||||
data = [
|
data = [
|
||||||
"//tensorflow/lite:testdata/lstm.bin",
|
"//tensorflow/lite:testdata/lstm.bin",
|
||||||
"//tensorflow/lite:testdata/multi_add.bin",
|
"//tensorflow/lite:testdata/multi_add.bin",
|
||||||
|
"//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"tflite_not_portable_android",
|
"tflite_not_portable_android",
|
||||||
|
@ -461,10 +461,9 @@ struct OpData {
|
|||||||
// Resize the output, state tensors based on the sizes of the input tensors.
|
// Resize the output, state tensors based on the sizes of the input tensors.
|
||||||
// Allocate a temporary scratch tensor. Also check that the sizes of the input
|
// Allocate a temporary scratch tensor. Also check that the sizes of the input
|
||||||
// tensors match each other.
|
// tensors match each other.
|
||||||
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
|
||||||
|
LSTMType lstm_type, Logger* logger,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
|
|
||||||
|
|
||||||
const TfLiteTensor* input;
|
const TfLiteTensor* input;
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, GetInputSafe(context, node,
|
context, GetInputSafe(context, node,
|
||||||
@ -578,6 +577,31 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
|||||||
intermediate_tensor_indexes[i] = node->intermediates->data[i];
|
intermediate_tensor_indexes[i] = node->intermediates->data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteLSTMParams lstm_params;
|
||||||
|
bool time_major = true;
|
||||||
|
switch (lstm_type) {
|
||||||
|
case LSTMType::kLSTM: {
|
||||||
|
lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
|
||||||
|
time_major = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LSTMType::kUnidirectionalSequenceLSTM: {
|
||||||
|
const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
|
||||||
|
node->builtin_data);
|
||||||
|
// Copy out the LSTM specific params so they can be passed in the
|
||||||
|
// function.
|
||||||
|
lstm_params.activation = params->activation;
|
||||||
|
lstm_params.cell_clip = params->cell_clip;
|
||||||
|
lstm_params.proj_clip = params->proj_clip;
|
||||||
|
lstm_params.asymmetric_quantize_inputs =
|
||||||
|
params->asymmetric_quantize_inputs;
|
||||||
|
time_major = params->time_major;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
switch (input_to_output_weights->type) {
|
switch (input_to_output_weights->type) {
|
||||||
case kTfLiteFloat32: {
|
case kTfLiteFloat32: {
|
||||||
return EvalCalibration(
|
return EvalCalibration(
|
||||||
@ -594,9 +618,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
|||||||
/*aux_input_to_cell_weights=*/nullptr,
|
/*aux_input_to_cell_weights=*/nullptr,
|
||||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||||
projection_weights, projection_bias, params,
|
projection_weights, projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/true,
|
/*forward_sequence=*/true,
|
||||||
/*time_major=*/true,
|
/*time_major=*/time_major,
|
||||||
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
|
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
|
||||||
logger, intermediate_tensor_indexes, error_reporter);
|
logger, intermediate_tensor_indexes, error_reporter);
|
||||||
}
|
}
|
||||||
@ -613,7 +637,14 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
|||||||
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
||||||
Logger* logger,
|
Logger* logger,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
return lstm_eval(context, node, logger, error_reporter);
|
return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
|
||||||
|
TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||||
|
ErrorReporter* error_reporter) {
|
||||||
|
return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
|
||||||
|
error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
|
@ -23,9 +23,18 @@ namespace optimize {
|
|||||||
namespace calibration {
|
namespace calibration {
|
||||||
namespace builtin {
|
namespace builtin {
|
||||||
|
|
||||||
|
enum class LSTMType {
|
||||||
|
kLSTM,
|
||||||
|
kUnidirectionalSequenceLSTM,
|
||||||
|
};
|
||||||
|
|
||||||
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
||||||
Logger* logger, ErrorReporter* error_reporter);
|
Logger* logger, ErrorReporter* error_reporter);
|
||||||
|
|
||||||
|
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
|
||||||
|
TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||||
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
} // namespace calibration
|
} // namespace calibration
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
|
@ -174,13 +174,17 @@ GlobalCalibratorRegistry* GetCalibratorRegistry() {
|
|||||||
// TODO(jianlijianli): extend this to support multiple recipe for the same
|
// TODO(jianlijianli): extend this to support multiple recipe for the same
|
||||||
// model.
|
// model.
|
||||||
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
|
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
|
||||||
TfLiteNode* node) {
|
TfLiteNode* node,
|
||||||
const int lstm_number_input = 24;
|
int builtin_op_code) {
|
||||||
if (node->inputs->size == lstm_number_input) {
|
switch (builtin_op_code) {
|
||||||
// LSTM Op.
|
case BuiltinOperator_LSTM:
|
||||||
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
|
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
|
||||||
|
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
|
||||||
|
return tflite::optimize::calibration::builtin::
|
||||||
|
unidirectional_sequence_lstm_logging_kernel;
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
|
// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
|
||||||
@ -203,7 +207,9 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
|
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
|
||||||
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
|
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
|
||||||
}
|
}
|
||||||
auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
|
auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
|
||||||
|
auto kernel_invoke_intermediate =
|
||||||
|
GetLoggingEvalFunc(context, node, builtin_op_code);
|
||||||
TfLiteStatus status;
|
TfLiteStatus status;
|
||||||
if (kernel_invoke_intermediate == nullptr) {
|
if (kernel_invoke_intermediate == nullptr) {
|
||||||
status = kernel_invoke(context, node);
|
status = kernel_invoke(context, node);
|
||||||
|
@ -283,7 +283,7 @@ TEST(CalibratorTest, LSTM) {
|
|||||||
auto status = BuildLoggingInterpreter(*flatbuffer_model,
|
auto status = BuildLoggingInterpreter(*flatbuffer_model,
|
||||||
ops::builtin::BuiltinOpResolver{},
|
ops::builtin::BuiltinOpResolver{},
|
||||||
&interpreter, &reader);
|
&interpreter, &reader);
|
||||||
EXPECT_EQ(kTfLiteOk, status);
|
EXPECT_EQ(status, kTfLiteOk);
|
||||||
|
|
||||||
auto readonly_model = flatbuffer_model->GetModel();
|
auto readonly_model = flatbuffer_model->GetModel();
|
||||||
tflite::ModelT model;
|
tflite::ModelT model;
|
||||||
@ -294,24 +294,17 @@ TEST(CalibratorTest, LSTM) {
|
|||||||
status = interpreter->AllocateTensors();
|
status = interpreter->AllocateTensors();
|
||||||
|
|
||||||
EXPECT_EQ(kTfLiteOk, status);
|
EXPECT_EQ(kTfLiteOk, status);
|
||||||
const std::vector<float> lstm_input = {
|
const std::vector<float> lstm_input = {0.3, 0.2};
|
||||||
0.3, 0.2, 0.9, 0.8, 0.1, //
|
|
||||||
0.1, 0.5, 0.2, 0.4, 0.2, //
|
|
||||||
0.6, 0.9, 0.2, 0.5, 0.7, //
|
|
||||||
};
|
|
||||||
int input_tensor_idx = interpreter->inputs()[0];
|
int input_tensor_idx = interpreter->inputs()[0];
|
||||||
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
|
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
|
||||||
for (size_t j = 0; j < lstm_input.size(); j++) {
|
for (size_t j = 0; j < lstm_input.size(); j++) {
|
||||||
tensor->data.f[j] = lstm_input[j];
|
tensor->data.f[j] = lstm_input[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke with update == true.
|
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||||
status = interpreter->Invoke();
|
|
||||||
ASSERT_EQ(kTfLiteOk, status);
|
|
||||||
|
|
||||||
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
|
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
|
||||||
status = reader->GetTensorStatsAsMap(&stats);
|
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
|
||||||
EXPECT_EQ(kTfLiteOk, status);
|
|
||||||
|
|
||||||
// Check the results.
|
// Check the results.
|
||||||
const float eps = 1e-6f;
|
const float eps = 1e-6f;
|
||||||
@ -344,6 +337,66 @@ TEST(CalibratorTest, LSTM) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CalibratorTest, UnidirectionalSequenceLSTM) {
|
||||||
|
auto flatbuffer_model = ReadModel("unidirectional_sequence_lstm.bin");
|
||||||
|
ASSERT_TRUE(flatbuffer_model);
|
||||||
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
|
std::unique_ptr<CalibrationReader> reader;
|
||||||
|
auto status = BuildLoggingInterpreter(*flatbuffer_model,
|
||||||
|
ops::builtin::BuiltinOpResolver{},
|
||||||
|
&interpreter, &reader);
|
||||||
|
EXPECT_EQ(kTfLiteOk, status);
|
||||||
|
|
||||||
|
auto readonly_model = flatbuffer_model->GetModel();
|
||||||
|
tflite::ModelT model;
|
||||||
|
readonly_model->UnPackTo(&model);
|
||||||
|
|
||||||
|
ASSERT_TRUE(interpreter);
|
||||||
|
ASSERT_TRUE(reader);
|
||||||
|
EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||||
|
const std::vector<float> lstm_input = {0.3, 0.2, 0.9, 0.8};
|
||||||
|
int input_tensor_idx = interpreter->inputs()[0];
|
||||||
|
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
|
||||||
|
for (size_t j = 0; j < lstm_input.size(); j++) {
|
||||||
|
tensor->data.f[j] = lstm_input[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||||
|
|
||||||
|
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
|
||||||
|
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
|
||||||
|
|
||||||
|
// Check the results.
|
||||||
|
const float eps = 1e-6f;
|
||||||
|
const std::unordered_map<int, CalibrationReader::CalibrationStats>
|
||||||
|
expected_calibration_result = {
|
||||||
|
// Input.
|
||||||
|
{0, {0.200000, 0.900000}},
|
||||||
|
// State.
|
||||||
|
{18, {0.000000, 0.520999}},
|
||||||
|
// State.
|
||||||
|
{19, {0.000000, 0.711364}},
|
||||||
|
// Output.
|
||||||
|
{24, {0.247992, 0.520999}},
|
||||||
|
// Intemediate_0.
|
||||||
|
{25, {0.080045, 0.824241}},
|
||||||
|
// Intemediate_1.
|
||||||
|
{26, {0.080045, 0.824241}},
|
||||||
|
// Intemediate_2.
|
||||||
|
{27, {0.080045, 0.824241}},
|
||||||
|
// Intemediate_3.
|
||||||
|
{28, {0.080045, 0.824241}},
|
||||||
|
// Intemediate_4.
|
||||||
|
{29, {0.000000, 0.413618}},
|
||||||
|
};
|
||||||
|
EXPECT_EQ(expected_calibration_result.size(), stats.size());
|
||||||
|
for (const auto& e : stats) {
|
||||||
|
auto expected_result = expected_calibration_result.at(e.first);
|
||||||
|
EXPECT_NEAR(e.second.min, expected_result.min, eps);
|
||||||
|
EXPECT_NEAR(e.second.max, expected_result.max, eps);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace calibration
|
} // namespace calibration
|
||||||
} // namespace optimize
|
} // namespace optimize
|
||||||
|
@ -44,7 +44,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
|
|||||||
model->subgraphs.at(subgraph_index)->operators[op_index].get();
|
model->subgraphs.at(subgraph_index)->operators[op_index].get();
|
||||||
op_variant.op_code =
|
op_variant.op_code =
|
||||||
GetBuiltinCode(model->operator_codes[op->opcode_index].get());
|
GetBuiltinCode(model->operator_codes[op->opcode_index].get());
|
||||||
if (op_variant.op_code == BuiltinOperator_LSTM) {
|
if (op_variant.op_code == BuiltinOperator_LSTM ||
|
||||||
|
op_variant.op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
|
||||||
if (op->inputs.size() == 5) {
|
if (op->inputs.size() == 5) {
|
||||||
// The 5 input ("basic") LSTM is not supported in this tooling (yet).
|
// The 5 input ("basic") LSTM is not supported in this tooling (yet).
|
||||||
op_variant.is_quantizable = false;
|
op_variant.is_quantizable = false;
|
||||||
@ -230,7 +231,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
|||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_LSTM: {
|
case BuiltinOperator_LSTM:
|
||||||
|
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
|
||||||
if (!op_variant.is_quantizable) {
|
if (!op_variant.is_quantizable) {
|
||||||
// Early exist for 5 input LSTM.
|
// Early exist for 5 input LSTM.
|
||||||
// It is not supported in this tooling yet.
|
// It is not supported in this tooling yet.
|
||||||
|
Loading…
Reference in New Issue
Block a user