TFLM: Update MicroOptionalDebuger to use standard NodeAndRegistration struct.
Also a) hook this method into a unit test just to make sure the code works. b) tighten the interface to readonly. PiperOrigin-RevId: 288912437 Change-Id: I3aee7c638322983d69f8ac23be7c7c1f7ab5ddcd
This commit is contained in:
parent
3f25c2ed0f
commit
42e2f63a12
@ -233,43 +233,4 @@ TfLiteTensor* MicroInterpreter::tensor(size_t index) {
|
||||
return &context_.tensors[index];
|
||||
}
|
||||
|
||||
struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration(
|
||||
int node_index) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
struct pairTfLiteNodeAndRegistration tfNodeRegiPair;
|
||||
auto opcodes = model_->operator_codes();
|
||||
{
|
||||
const auto* op = operators_->Get(node_index);
|
||||
size_t index = op->opcode_index();
|
||||
if (index < 0 || index >= opcodes->size()) {
|
||||
error_reporter_->Report("Missing registration for opcode_index %d\n",
|
||||
index);
|
||||
}
|
||||
auto opcode = (*opcodes)[index];
|
||||
const TfLiteRegistration* registration = nullptr;
|
||||
status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
|
||||
®istration);
|
||||
if (status != kTfLiteOk) {
|
||||
error_reporter_->Report("Missing registration for opcode_index %d\n",
|
||||
index);
|
||||
}
|
||||
if (registration == nullptr) {
|
||||
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
|
||||
}
|
||||
|
||||
// Disregard const qualifier to workaround with existing API.
|
||||
TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
|
||||
reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
|
||||
TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
|
||||
reinterpret_cast<const TfLiteIntArray*>(op->outputs()));
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
tfNodeRegiPair.node = node;
|
||||
tfNodeRegiPair.registration = registration;
|
||||
}
|
||||
return tfNodeRegiPair;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -94,7 +94,11 @@ class MicroInterpreter {
|
||||
ErrorReporter* error_reporter() { return error_reporter_; }
|
||||
|
||||
size_t operators_size() const { return operators_->size(); }
|
||||
struct pairTfLiteNodeAndRegistration node_and_registration(int node_index);
|
||||
|
||||
// For debugging only.
|
||||
const NodeAndRegistration node_and_registration(int node_index) const {
|
||||
return node_and_registrations_[node_index];
|
||||
}
|
||||
|
||||
private:
|
||||
void CorrectTensorEndianness(TfLiteTensor* tensorCorr);
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/micro/micro_interpreter.h"
|
||||
|
||||
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
@ -102,6 +103,9 @@ TF_LITE_MICRO_TEST(TestInterpreter) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
|
||||
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
|
||||
|
||||
// Just to make sure that this method works.
|
||||
tflite::PrintInterpreterState(&interpreter);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
||||
|
@ -121,7 +121,7 @@ void PrintInterpreterState(MicroInterpreter* interpreter) {
|
||||
|
||||
for (size_t node_index = 0; node_index < interpreter->operators_size();
|
||||
node_index++) {
|
||||
struct pairTfLiteNodeAndRegistration node_and_reg =
|
||||
const NodeAndRegistration node_and_reg =
|
||||
interpreter->node_and_registration(static_cast<int>(node_index));
|
||||
const TfLiteNode& node = node_and_reg.node;
|
||||
const TfLiteRegistration* reg = node_and_reg.registration;
|
||||
|
@ -21,20 +21,7 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
// Prints a dump of what tensors and what nodes are in the interpreter.
|
||||
class MicroInterpreter;
|
||||
void PrintInterpreterState(MicroInterpreter* interpreter);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
struct pairTfLiteNodeAndRegistration {
|
||||
TfLiteNode node;
|
||||
const TfLiteRegistration* registration;
|
||||
};
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_MICRO_OPTIONAL_DEBUG_TOOLS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user