Register Unidirectional_sequence_lstm logging op in calibrator.
PiperOrigin-RevId: 337529861 Change-Id: Ie4d13f8066cf6a75f2baab66e016336a95302c93
This commit is contained in:
parent
0c4416e3c2
commit
54081bae79
tensorflow/lite
testdata
tools/optimize
BIN
tensorflow/lite/testdata/unidirectional_sequence_lstm.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/unidirectional_sequence_lstm.bin
vendored
Normal file
Binary file not shown.
@ -65,6 +65,7 @@ tf_cc_test(
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/lstm.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
"//tensorflow/lite:testdata/unidirectional_sequence_lstm.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
|
@ -461,10 +461,9 @@ struct OpData {
|
||||
// Resize the output, state tensors based on the sizes of the input tensors.
|
||||
// Allocate a temporary scratch tensor. Also check that the sizes of the input
|
||||
// tensors match each other.
|
||||
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node,
|
||||
LSTMType lstm_type, Logger* logger,
|
||||
ErrorReporter* error_reporter) {
|
||||
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetInputSafe(context, node,
|
||||
@ -578,6 +577,31 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
intermediate_tensor_indexes[i] = node->intermediates->data[i];
|
||||
}
|
||||
|
||||
TfLiteLSTMParams lstm_params;
|
||||
bool time_major = true;
|
||||
switch (lstm_type) {
|
||||
case LSTMType::kLSTM: {
|
||||
lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
|
||||
time_major = true;
|
||||
break;
|
||||
}
|
||||
case LSTMType::kUnidirectionalSequenceLSTM: {
|
||||
const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
|
||||
node->builtin_data);
|
||||
// Copy out the LSTM specific params so they can be passed in the
|
||||
// function.
|
||||
lstm_params.activation = params->activation;
|
||||
lstm_params.cell_clip = params->cell_clip;
|
||||
lstm_params.proj_clip = params->proj_clip;
|
||||
lstm_params.asymmetric_quantize_inputs =
|
||||
params->asymmetric_quantize_inputs;
|
||||
time_major = params->time_major;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
switch (input_to_output_weights->type) {
|
||||
case kTfLiteFloat32: {
|
||||
return EvalCalibration(
|
||||
@ -594,9 +618,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
/*aux_input_to_cell_weights=*/nullptr,
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_gate_bias, output_gate_bias,
|
||||
projection_weights, projection_bias, params,
|
||||
projection_weights, projection_bias, &lstm_params,
|
||||
/*forward_sequence=*/true,
|
||||
/*time_major=*/true,
|
||||
/*time_major=*/time_major,
|
||||
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
|
||||
logger, intermediate_tensor_indexes, error_reporter);
|
||||
}
|
||||
@ -613,7 +637,14 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
||||
Logger* logger,
|
||||
ErrorReporter* error_reporter) {
|
||||
return lstm_eval(context, node, logger, error_reporter);
|
||||
return lstm_eval(context, node, LSTMType::kLSTM, logger, error_reporter);
|
||||
}
|
||||
|
||||
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
|
||||
TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
ErrorReporter* error_reporter) {
|
||||
return lstm_eval(context, node, LSTMType::kUnidirectionalSequenceLSTM, logger,
|
||||
error_reporter);
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
|
@ -23,9 +23,18 @@ namespace optimize {
|
||||
namespace calibration {
|
||||
namespace builtin {
|
||||
|
||||
enum class LSTMType {
|
||||
kLSTM,
|
||||
kUnidirectionalSequenceLSTM,
|
||||
};
|
||||
|
||||
TfLiteStatus lstm_logging_kernel(TfLiteContext* context, TfLiteNode* node,
|
||||
Logger* logger, ErrorReporter* error_reporter);
|
||||
|
||||
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
|
||||
TfLiteContext* context, TfLiteNode* node, Logger* logger,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace calibration
|
||||
} // namespace optimize
|
||||
|
@ -174,13 +174,17 @@ GlobalCalibratorRegistry* GetCalibratorRegistry() {
|
||||
// TODO(jianlijianli): extend this to support multiple recipe for the same
|
||||
// model.
|
||||
logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
const int lstm_number_input = 24;
|
||||
if (node->inputs->size == lstm_number_input) {
|
||||
// LSTM Op.
|
||||
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
|
||||
TfLiteNode* node,
|
||||
int builtin_op_code) {
|
||||
switch (builtin_op_code) {
|
||||
case BuiltinOperator_LSTM:
|
||||
return tflite::optimize::calibration::builtin::lstm_logging_kernel;
|
||||
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
|
||||
return tflite::optimize::calibration::builtin::
|
||||
unidirectional_sequence_lstm_logging_kernel;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
|
||||
@ -203,7 +207,9 @@ TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_STATUS(logger->LogTensorValue(
|
||||
i, tensor.data.f, tensor.bytes / sizeof(float), error_reporter));
|
||||
}
|
||||
auto kernel_invoke_intermediate = GetLoggingEvalFunc(context, node);
|
||||
auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
|
||||
auto kernel_invoke_intermediate =
|
||||
GetLoggingEvalFunc(context, node, builtin_op_code);
|
||||
TfLiteStatus status;
|
||||
if (kernel_invoke_intermediate == nullptr) {
|
||||
status = kernel_invoke(context, node);
|
||||
|
@ -283,7 +283,7 @@ TEST(CalibratorTest, LSTM) {
|
||||
auto status = BuildLoggingInterpreter(*flatbuffer_model,
|
||||
ops::builtin::BuiltinOpResolver{},
|
||||
&interpreter, &reader);
|
||||
EXPECT_EQ(kTfLiteOk, status);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
auto readonly_model = flatbuffer_model->GetModel();
|
||||
tflite::ModelT model;
|
||||
@ -294,24 +294,17 @@ TEST(CalibratorTest, LSTM) {
|
||||
status = interpreter->AllocateTensors();
|
||||
|
||||
EXPECT_EQ(kTfLiteOk, status);
|
||||
const std::vector<float> lstm_input = {
|
||||
0.3, 0.2, 0.9, 0.8, 0.1, //
|
||||
0.1, 0.5, 0.2, 0.4, 0.2, //
|
||||
0.6, 0.9, 0.2, 0.5, 0.7, //
|
||||
};
|
||||
const std::vector<float> lstm_input = {0.3, 0.2};
|
||||
int input_tensor_idx = interpreter->inputs()[0];
|
||||
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
|
||||
for (size_t j = 0; j < lstm_input.size(); j++) {
|
||||
tensor->data.f[j] = lstm_input[j];
|
||||
}
|
||||
|
||||
// Invoke with update == true.
|
||||
status = interpreter->Invoke();
|
||||
ASSERT_EQ(kTfLiteOk, status);
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
|
||||
status = reader->GetTensorStatsAsMap(&stats);
|
||||
EXPECT_EQ(kTfLiteOk, status);
|
||||
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
|
||||
|
||||
// Check the results.
|
||||
const float eps = 1e-6f;
|
||||
@ -344,6 +337,66 @@ TEST(CalibratorTest, LSTM) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CalibratorTest, UnidirectionalSequenceLSTM) {
|
||||
auto flatbuffer_model = ReadModel("unidirectional_sequence_lstm.bin");
|
||||
ASSERT_TRUE(flatbuffer_model);
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
std::unique_ptr<CalibrationReader> reader;
|
||||
auto status = BuildLoggingInterpreter(*flatbuffer_model,
|
||||
ops::builtin::BuiltinOpResolver{},
|
||||
&interpreter, &reader);
|
||||
EXPECT_EQ(kTfLiteOk, status);
|
||||
|
||||
auto readonly_model = flatbuffer_model->GetModel();
|
||||
tflite::ModelT model;
|
||||
readonly_model->UnPackTo(&model);
|
||||
|
||||
ASSERT_TRUE(interpreter);
|
||||
ASSERT_TRUE(reader);
|
||||
EXPECT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
const std::vector<float> lstm_input = {0.3, 0.2, 0.9, 0.8};
|
||||
int input_tensor_idx = interpreter->inputs()[0];
|
||||
TfLiteTensor* tensor = interpreter->tensor(input_tensor_idx);
|
||||
for (size_t j = 0; j < lstm_input.size(); j++) {
|
||||
tensor->data.f[j] = lstm_input[j];
|
||||
}
|
||||
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
absl::flat_hash_map<int, CalibrationReader::CalibrationStats> stats;
|
||||
EXPECT_EQ(reader->GetTensorStatsAsMap(&stats), kTfLiteOk);
|
||||
|
||||
// Check the results.
|
||||
const float eps = 1e-6f;
|
||||
const std::unordered_map<int, CalibrationReader::CalibrationStats>
|
||||
expected_calibration_result = {
|
||||
// Input.
|
||||
{0, {0.200000, 0.900000}},
|
||||
// State.
|
||||
{18, {0.000000, 0.520999}},
|
||||
// State.
|
||||
{19, {0.000000, 0.711364}},
|
||||
// Output.
|
||||
{24, {0.247992, 0.520999}},
|
||||
// Intemediate_0.
|
||||
{25, {0.080045, 0.824241}},
|
||||
// Intemediate_1.
|
||||
{26, {0.080045, 0.824241}},
|
||||
// Intemediate_2.
|
||||
{27, {0.080045, 0.824241}},
|
||||
// Intemediate_3.
|
||||
{28, {0.080045, 0.824241}},
|
||||
// Intemediate_4.
|
||||
{29, {0.000000, 0.413618}},
|
||||
};
|
||||
EXPECT_EQ(expected_calibration_result.size(), stats.size());
|
||||
for (const auto& e : stats) {
|
||||
auto expected_result = expected_calibration_result.at(e.first);
|
||||
EXPECT_NEAR(e.second.min, expected_result.min, eps);
|
||||
EXPECT_NEAR(e.second.max, expected_result.max, eps);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace calibration
|
||||
} // namespace optimize
|
||||
|
@ -44,7 +44,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index,
|
||||
model->subgraphs.at(subgraph_index)->operators[op_index].get();
|
||||
op_variant.op_code =
|
||||
GetBuiltinCode(model->operator_codes[op->opcode_index].get());
|
||||
if (op_variant.op_code == BuiltinOperator_LSTM) {
|
||||
if (op_variant.op_code == BuiltinOperator_LSTM ||
|
||||
op_variant.op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
|
||||
if (op->inputs.size() == 5) {
|
||||
// The 5 input ("basic") LSTM is not supported in this tooling (yet).
|
||||
op_variant.is_quantizable = false;
|
||||
@ -230,7 +231,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
||||
property.version = 2;
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_LSTM: {
|
||||
case BuiltinOperator_LSTM:
|
||||
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
|
||||
if (!op_variant.is_quantizable) {
|
||||
// Early exist for 5 input LSTM.
|
||||
// It is not supported in this tooling yet.
|
||||
|
Loading…
Reference in New Issue
Block a user