diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 9002dfcb188..c8941f03cab 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -233,43 +233,4 @@ TfLiteTensor* MicroInterpreter::tensor(size_t index) { return &context_.tensors[index]; } -struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration( - int node_index) { - TfLiteStatus status = kTfLiteOk; - struct pairTfLiteNodeAndRegistration tfNodeRegiPair; - auto opcodes = model_->operator_codes(); - { - const auto* op = operators_->Get(node_index); - size_t index = op->opcode_index(); - if (index < 0 || index >= opcodes->size()) { - error_reporter_->Report("Missing registration for opcode_index %d\n", - index); - } - auto opcode = (*opcodes)[index]; - const TfLiteRegistration* registration = nullptr; - status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, - ®istration); - if (status != kTfLiteOk) { - error_reporter_->Report("Missing registration for opcode_index %d\n", - index); - } - if (registration == nullptr) { - error_reporter_->Report("Skipping op for opcode_index %d\n", index); - } - - // Disregard const qualifier to workaround with existing API. - TfLiteIntArray* inputs_array = const_cast( - reinterpret_cast(op->inputs())); - TfLiteIntArray* outputs_array = const_cast( - reinterpret_cast(op->outputs())); - - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - tfNodeRegiPair.node = node; - tfNodeRegiPair.registration = registration; - } - return tfNodeRegiPair; -} - } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 5f6a2295e9d..4c15853e298 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -94,7 +94,11 @@ class MicroInterpreter { ErrorReporter* error_reporter() { return error_reporter_; } size_t operators_size() const { return operators_->size(); } - struct pairTfLiteNodeAndRegistration node_and_registration(int node_index); + + // For debugging only. + const NodeAndRegistration node_and_registration(int node_index) const { + return node_and_registrations_[node_index]; + } private: void CorrectTensorEndianness(TfLiteTensor* tensorCorr); diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index 338074685e5..f4983b5593b 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/micro/micro_optional_debug_tools.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -102,6 +103,9 @@ TF_LITE_MICRO_TEST(TestInterpreter) { TF_LITE_MICRO_EXPECT_EQ(4, output->bytes); TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32); TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]); + + // Just to make sure that this method works. + tflite::PrintInterpreterState(&interpreter); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc index 31a31ec90b8..bc69eb55315 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.cc +++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc @@ -121,7 +121,7 @@ void PrintInterpreterState(MicroInterpreter* interpreter) { for (size_t node_index = 0; node_index < interpreter->operators_size(); node_index++) { - struct pairTfLiteNodeAndRegistration node_and_reg = + const NodeAndRegistration node_and_reg = interpreter->node_and_registration(static_cast(node_index)); const TfLiteNode& node = node_and_reg.node; const TfLiteRegistration* reg = node_and_reg.registration; diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.h b/tensorflow/lite/micro/micro_optional_debug_tools.h index 70fe6f899da..ae96b62ab3c 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.h +++ b/tensorflow/lite/micro/micro_optional_debug_tools.h @@ -21,20 +21,7 @@ limitations under the License. namespace tflite { // Prints a dump of what tensors and what nodes are in the interpreter. -class MicroInterpreter; void PrintInterpreterState(MicroInterpreter* interpreter); - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus -struct pairTfLiteNodeAndRegistration { - TfLiteNode node; - const TfLiteRegistration* registration; -}; -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_MICRO_OPTIONAL_DEBUG_TOOLS_H_