TFLM: Update MicroOptionalDebuger to use standard NodeAndRegistration struct.
Also a) hook this method into a unit test just to make sure the code works. b) tighten the interface to readonly. PiperOrigin-RevId: 288912437 Change-Id: I3aee7c638322983d69f8ac23be7c7c1f7ab5ddcd
This commit is contained in:
parent
3f25c2ed0f
commit
42e2f63a12
@ -233,43 +233,4 @@ TfLiteTensor* MicroInterpreter::tensor(size_t index) {
|
|||||||
return &context_.tensors[index];
|
return &context_.tensors[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration(
|
|
||||||
int node_index) {
|
|
||||||
TfLiteStatus status = kTfLiteOk;
|
|
||||||
struct pairTfLiteNodeAndRegistration tfNodeRegiPair;
|
|
||||||
auto opcodes = model_->operator_codes();
|
|
||||||
{
|
|
||||||
const auto* op = operators_->Get(node_index);
|
|
||||||
size_t index = op->opcode_index();
|
|
||||||
if (index < 0 || index >= opcodes->size()) {
|
|
||||||
error_reporter_->Report("Missing registration for opcode_index %d\n",
|
|
||||||
index);
|
|
||||||
}
|
|
||||||
auto opcode = (*opcodes)[index];
|
|
||||||
const TfLiteRegistration* registration = nullptr;
|
|
||||||
status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
|
|
||||||
®istration);
|
|
||||||
if (status != kTfLiteOk) {
|
|
||||||
error_reporter_->Report("Missing registration for opcode_index %d\n",
|
|
||||||
index);
|
|
||||||
}
|
|
||||||
if (registration == nullptr) {
|
|
||||||
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disregard const qualifier to workaround with existing API.
|
|
||||||
TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
|
|
||||||
reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
|
|
||||||
TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
|
|
||||||
reinterpret_cast<const TfLiteIntArray*>(op->outputs()));
|
|
||||||
|
|
||||||
TfLiteNode node;
|
|
||||||
node.inputs = inputs_array;
|
|
||||||
node.outputs = outputs_array;
|
|
||||||
tfNodeRegiPair.node = node;
|
|
||||||
tfNodeRegiPair.registration = registration;
|
|
||||||
}
|
|
||||||
return tfNodeRegiPair;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -94,7 +94,11 @@ class MicroInterpreter {
|
|||||||
ErrorReporter* error_reporter() { return error_reporter_; }
|
ErrorReporter* error_reporter() { return error_reporter_; }
|
||||||
|
|
||||||
size_t operators_size() const { return operators_->size(); }
|
size_t operators_size() const { return operators_->size(); }
|
||||||
struct pairTfLiteNodeAndRegistration node_and_registration(int node_index);
|
|
||||||
|
// For debugging only.
|
||||||
|
const NodeAndRegistration node_and_registration(int node_index) const {
|
||||||
|
return node_and_registrations_[node_index];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void CorrectTensorEndianness(TfLiteTensor* tensorCorr);
|
void CorrectTensorEndianness(TfLiteTensor* tensorCorr);
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/micro/micro_interpreter.h"
|
#include "tensorflow/lite/micro/micro_interpreter.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
|
||||||
#include "tensorflow/lite/micro/test_helpers.h"
|
#include "tensorflow/lite/micro/test_helpers.h"
|
||||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||||
|
|
||||||
@ -102,6 +103,9 @@ TF_LITE_MICRO_TEST(TestInterpreter) {
|
|||||||
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
|
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
|
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
|
||||||
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
|
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
|
||||||
|
|
||||||
|
// Just to make sure that this method works.
|
||||||
|
tflite::PrintInterpreterState(&interpreter);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TESTS_END
|
TF_LITE_MICRO_TESTS_END
|
||||||
|
@ -121,7 +121,7 @@ void PrintInterpreterState(MicroInterpreter* interpreter) {
|
|||||||
|
|
||||||
for (size_t node_index = 0; node_index < interpreter->operators_size();
|
for (size_t node_index = 0; node_index < interpreter->operators_size();
|
||||||
node_index++) {
|
node_index++) {
|
||||||
struct pairTfLiteNodeAndRegistration node_and_reg =
|
const NodeAndRegistration node_and_reg =
|
||||||
interpreter->node_and_registration(static_cast<int>(node_index));
|
interpreter->node_and_registration(static_cast<int>(node_index));
|
||||||
const TfLiteNode& node = node_and_reg.node;
|
const TfLiteNode& node = node_and_reg.node;
|
||||||
const TfLiteRegistration* reg = node_and_reg.registration;
|
const TfLiteRegistration* reg = node_and_reg.registration;
|
||||||
|
@ -21,20 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
// Prints a dump of what tensors and what nodes are in the interpreter.
|
// Prints a dump of what tensors and what nodes are in the interpreter.
|
||||||
class MicroInterpreter;
|
|
||||||
void PrintInterpreterState(MicroInterpreter* interpreter);
|
void PrintInterpreterState(MicroInterpreter* interpreter);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif // __cplusplus
|
|
||||||
struct pairTfLiteNodeAndRegistration {
|
|
||||||
TfLiteNode node;
|
|
||||||
const TfLiteRegistration* registration;
|
|
||||||
};
|
|
||||||
#ifdef __cplusplus
|
|
||||||
} // extern "C"
|
|
||||||
#endif // __cplusplus
|
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_MICRO_MICRO_OPTIONAL_DEBUG_TOOLS_H_
|
#endif // TENSORFLOW_LITE_MICRO_MICRO_OPTIONAL_DEBUG_TOOLS_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user