TFLM: Update MicroOptionalDebuger to use standard NodeAndRegistration struct.

Also
a) hook this method into a unit test just to make sure the code works.
b) tighten the interface to readonly.

PiperOrigin-RevId: 288912437
Change-Id: I3aee7c638322983d69f8ac23be7c7c1f7ab5ddcd
This commit is contained in:
Tiezhen WANG 2020-01-09 09:31:36 -08:00 committed by TensorFlower Gardener
parent 3f25c2ed0f
commit 42e2f63a12
5 changed files with 10 additions and 54 deletions

View File

@ -233,43 +233,4 @@ TfLiteTensor* MicroInterpreter::tensor(size_t index) {
return &context_.tensors[index];
}
struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration(
int node_index) {
TfLiteStatus status = kTfLiteOk;
struct pairTfLiteNodeAndRegistration tfNodeRegiPair;
auto opcodes = model_->operator_codes();
{
const auto* op = operators_->Get(node_index);
size_t index = op->opcode_index();
if (index < 0 || index >= opcodes->size()) {
error_reporter_->Report("Missing registration for opcode_index %d\n",
index);
}
auto opcode = (*opcodes)[index];
const TfLiteRegistration* registration = nullptr;
status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
&registration);
if (status != kTfLiteOk) {
error_reporter_->Report("Missing registration for opcode_index %d\n",
index);
}
if (registration == nullptr) {
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
}
// Disregard const qualifier to workaround with existing API.
TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
reinterpret_cast<const TfLiteIntArray*>(op->outputs()));
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
tfNodeRegiPair.node = node;
tfNodeRegiPair.registration = registration;
}
return tfNodeRegiPair;
}
} // namespace tflite

View File

@ -94,7 +94,11 @@ class MicroInterpreter {
ErrorReporter* error_reporter() { return error_reporter_; }
size_t operators_size() const { return operators_->size(); }
struct pairTfLiteNodeAndRegistration node_and_registration(int node_index);
// For debugging only.
const NodeAndRegistration node_and_registration(int node_index) const {
return node_and_registrations_[node_index];
}
private:
void CorrectTensorEndianness(TfLiteTensor* tensorCorr);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
@ -102,6 +103,9 @@ TF_LITE_MICRO_TEST(TestInterpreter) {
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
// Just to make sure that this method works.
tflite::PrintInterpreterState(&interpreter);
}
TF_LITE_MICRO_TESTS_END

View File

@ -121,7 +121,7 @@ void PrintInterpreterState(MicroInterpreter* interpreter) {
for (size_t node_index = 0; node_index < interpreter->operators_size();
node_index++) {
struct pairTfLiteNodeAndRegistration node_and_reg =
const NodeAndRegistration node_and_reg =
interpreter->node_and_registration(static_cast<int>(node_index));
const TfLiteNode& node = node_and_reg.node;
const TfLiteRegistration* reg = node_and_reg.registration;

View File

@ -21,20 +21,7 @@ limitations under the License.
namespace tflite {
// Prints a dump of what tensors and what nodes are in the interpreter.
class MicroInterpreter;
void PrintInterpreterState(MicroInterpreter* interpreter);
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
struct pairTfLiteNodeAndRegistration {
TfLiteNode node;
const TfLiteRegistration* registration;
};
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_OPTIONAL_DEBUG_TOOLS_H_