Integrate speech commands results over time to give more accurate predictions

PiperOrigin-RevId: 226948751
Pete Warden 2018-12-26 13:28:29 -08:00 committed by TensorFlower Gardener
parent 83cb1f1c5e
commit 6d92ee85a8
15 changed files with 612 additions and 52 deletions
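In outline: main.cc now timestamps each pass through its loop, hands the previous and current times to FeatureProvider (whose caching state moves from file-scope globals into the class), and feeds every inference output into a new RecognizeCommands class. RecognizeCommands keeps recent results in a fixed-size queue, averages the per-category scores over a sliding window, and only reports a command once the averaged score clears a detection threshold. The test_utils.h helper also moves from kernels/ to testing/ so the new unit test can reuse it.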

tensorflow/lite/experimental/micro/examples/micro_speech/BUILD

@@ -176,7 +176,6 @@ cc_library(
":audio_provider",
":model_settings",
":preprocessor_reference",
":timer",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
],
@@ -191,7 +190,6 @@ tflite_micro_cc_test(
":audio_provider",
":feature_provider",
":model_settings",
":timer",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
@@ -221,6 +219,34 @@ tflite_micro_cc_test(
],
)
+cc_library(
+name = "recognize_commands",
+srcs = [
+"recognize_commands.cc",
+],
+hdrs = [
+"recognize_commands.h",
+],
+deps = [
+":model_settings",
+"//tensorflow/lite/c:c_api_internal",
+"//tensorflow/lite/experimental/micro:micro_framework",
+],
+)
+tflite_micro_cc_test(
+name = "recognize_commands_test",
+srcs = [
+"recognize_commands_test.cc",
+],
+deps = [
+":recognize_commands",
+"//tensorflow/lite/c:c_api_internal",
+"//tensorflow/lite/experimental/micro:micro_framework",
+"//tensorflow/lite/experimental/micro/testing:micro_test",
+],
+)
cc_binary(
name = "micro_speech",
srcs = [
@@ -232,6 +258,7 @@ cc_binary(
":features_test_data",
":model_settings",
":preprocessor_reference",
":recognize_commands",
":timer",
":tiny_conv_model_data",
"//tensorflow/lite:schema_fbs_version",

tensorflow/lite/experimental/micro/tools/make/Makefile

@@ -91,7 +91,6 @@ FEATURE_PROVIDER_TEST_SRCS := \
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc \
-tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
ALL_SRCS += $(FEATURE_PROVIDER_TEST_SRCS)
@@ -128,6 +127,26 @@ timer_test_bin: $(TIMER_TEST_BINARY).bin
test_timer: $(TIMER_TEST_BINARY)
$(TEST_SCRIPT) $(TIMER_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
+# Tests the command recognizer module.
+RECOGNIZE_COMMANDS_TEST_SRCS := \
+tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc \
+tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc \
+tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
+ALL_SRCS += $(RECOGNIZE_COMMANDS_TEST_SRCS)
+RECOGNIZE_COMMANDS_TEST_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(RECOGNIZE_COMMANDS_TEST_SRCS))))
+RECOGNIZE_COMMANDS_TEST_BINARY := $(BINDIR)recognize_commands_test
+ALL_BINARIES += $(RECOGNIZE_COMMANDS_TEST_BINARY)
+$(RECOGNIZE_COMMANDS_TEST_BINARY): $(RECOGNIZE_COMMANDS_TEST_OBJS) $(MICROLITE_LIB_PATH)
+@mkdir -p $(dir $@)
+$(CXX) $(CXXFLAGS) $(INCLUDES) \
+-o $(RECOGNIZE_COMMANDS_TEST_BINARY) $(RECOGNIZE_COMMANDS_TEST_OBJS) \
+$(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
+recognize_commands_test: $(RECOGNIZE_COMMANDS_TEST_BINARY)
+recognize_commands_test_bin: $(RECOGNIZE_COMMANDS_TEST_BINARY).bin
+test_recognize_commands: $(RECOGNIZE_COMMANDS_TEST_BINARY)
+$(TEST_SCRIPT) $(RECOGNIZE_COMMANDS_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
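Assuming the usual entry point for the micro targets, the new test should be runnable the same way as the others, for example with make -f tensorflow/lite/experimental/micro/tools/make/Makefile test_recognize_commands, which passes only when the binary prints '~~~ALL TESTS PASSED~~~'.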
# Builds a standalone speech command recognizer binary.
MICRO_SPEECH_SRCS := \
tensorflow/lite/experimental/micro/examples/micro_speech/main.cc \
@@ -138,7 +157,8 @@ tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/no_features_data.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/yes_features_data.cc \
-tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
+tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc \
+tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc
ALL_SRCS += $(MICRO_SPEECH_SRCS)
MICRO_SPEECH_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_SRCS))))

tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc

@@ -18,20 +18,11 @@ limitations under the License.
#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
namespace {
// Stores the timestamp for the previous fetch of audio data, so that we can
// avoid recalculating all the features from scratch if some earlier timeslices
// are still present.
int32_t g_last_time_in_ms = 0;
// Make sure we don't try to use cached information if this is the first call
// into the provider.
bool g_is_first_run = true;
} // namespace
FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
-: feature_size_(feature_size), feature_data_(feature_data) {
+: feature_size_(feature_size),
+feature_data_(feature_data),
+is_first_run_(true) {
// Initialize the feature data to default values.
for (int n = 0; n < feature_size_; ++n) {
feature_data_[n] = 0;
@@ -41,24 +32,23 @@ FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
FeatureProvider::~FeatureProvider() {}
TfLiteStatus FeatureProvider::PopulateFeatureData(
-tflite::ErrorReporter* error_reporter, int* how_many_new_slices) {
+tflite::ErrorReporter* error_reporter, int32_t last_time_in_ms,
+int32_t time_in_ms, int* how_many_new_slices) {
if (feature_size_ != kFeatureElementCount) {
error_reporter->Report("Requested feature_data_ size %d doesn't match %d",
feature_size_, kFeatureElementCount);
return kTfLiteError;
}
-const int32_t time_in_ms = TimeInMilliseconds();
// Quantize the time into steps as long as each window stride, so we can
// figure out which audio data we need to fetch.
-const int last_step = (g_last_time_in_ms / kFeatureSliceStrideMs);
+const int last_step = (last_time_in_ms / kFeatureSliceStrideMs);
const int current_step = (time_in_ms / kFeatureSliceStrideMs);
-g_last_time_in_ms = time_in_ms;
int slices_needed = current_step - last_step;
// If this is the first call, make sure we don't use any cached information.
-if (g_is_first_run) {
-g_is_first_run = false;
+if (is_first_run_) {
+is_first_run_ = false;
slices_needed = kFeatureSliceCount;
}
if (slices_needed > kFeatureSliceCount) {
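To make the stride arithmetic concrete (with an illustrative 20 ms stride; the real value is kFeatureSliceStrideMs in model_settings.h): last_time_in_ms = 240 quantizes to step 12 and time_in_ms = 300 to step 15, so three new slices of features are fetched; on the first call the whole kFeatureSliceCount window is regenerated instead.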

tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h

@@ -38,11 +38,15 @@ class FeatureProvider {
// Fills the feature data with information from audio inputs, and returns how
// many feature slices were updated.
TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
+int32_t last_time_in_ms, int32_t time_in_ms,
int* how_many_new_slices);
private:
int feature_size_;
uint8_t* feature_data_;
+// Make sure we don't try to use cached information if this is the first call
+// into the provider.
+bool is_first_run_;
};
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_

tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc

@@ -30,7 +30,8 @@ TF_LITE_MICRO_TEST(TestFeatureProvider) {
int how_many_new_slices = 0;
TfLiteStatus populate_status = feature_provider.PopulateFeatureData(
error_reporter, &how_many_new_slices);
error_reporter, /* last_time_in_ms= */ 0, /* time_in_ms= */ 10000,
&how_many_new_slices);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status);
TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices);
}

tensorflow/lite/experimental/micro/examples/micro_speech/main.cc

@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
@@ -68,16 +70,21 @@ int main(int argc, char* argv[]) {
FeatureProvider feature_provider(kFeatureElementCount,
model_input->data.uint8);
+RecognizeCommands recognizer(error_reporter);
+int32_t previous_time = 0;
// Keep reading and analysing audio data in an infinite loop.
while (true) {
// Fetch the spectrogram for the current time.
+const int32_t current_time = TimeInMilliseconds();
int how_many_new_slices = 0;
TfLiteStatus feature_status = feature_provider.PopulateFeatureData(
error_reporter, &how_many_new_slices);
error_reporter, previous_time, current_time, &how_many_new_slices);
if (feature_status != kTfLiteOk) {
error_reporter->Report("Feature generation failed");
return 1;
}
+previous_time = current_time;
// If no new audio samples have been received since last time, don't bother
// running the network model.
if (how_many_new_slices == 0) {
@@ -105,7 +112,19 @@ int main(int argc, char* argv[]) {
}
}
error_reporter->Report("Heard %s", kCategoryLabels[top_category_index]);
const char* found_command = nullptr;
uint8_t score = 0;
bool is_new_command = false;
TfLiteStatus process_status = recognizer.ProcessLatestResults(
output, current_time, &found_command, &score, &is_new_command);
if (process_status != kTfLiteOk) {
error_reporter->Report(
"RecognizeCommands::ProcessLatestResults() failed");
return 1;
}
if (is_new_command) {
error_reporter->Report("Heard %s (%d)", found_command, score);
}
}
return 0;
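Note the plumbing here: previous_time is advanced only after PopulateFeatureData succeeds, so the provider's caching now depends entirely on timestamps supplied by the caller rather than on the file-scope globals deleted from feature_provider.cc above, and TimeInMilliseconds() is called exactly once per loop iteration.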

tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.cc (new file)

@@ -0,0 +1,139 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
#include <limits>
RecognizeCommands::RecognizeCommands(tflite::ErrorReporter* error_reporter,
int32_t average_window_duration_ms,
uint8_t detection_threshold,
int32_t suppression_ms,
int32_t minimum_count)
: error_reporter_(error_reporter),
average_window_duration_ms_(average_window_duration_ms),
detection_threshold_(detection_threshold),
suppression_ms_(suppression_ms),
minimum_count_(minimum_count),
previous_results_(error_reporter) {
// Point at the canonical label so later pointer comparisons against
// kCategoryLabels entries behave reliably.
previous_top_label_ = kCategoryLabels[0];
previous_top_label_time_ = 0;
}
TfLiteStatus RecognizeCommands::ProcessLatestResults(
const TfLiteTensor* latest_results, const int32_t current_time_ms,
const char** found_command, uint8_t* score, bool* is_new_command) {
if ((latest_results->dims->size != 2) ||
(latest_results->dims->data[0] != 1) ||
(latest_results->dims->data[1] != kCategoryCount)) {
error_reporter_->Report(
"The results for recognition should contain %d elements, but there are "
"%d in an %d-dimensional shape",
kCategoryCount, latest_results->dims->data[1],
latest_results->dims->size);
return kTfLiteError;
}
if (latest_results->type != kTfLiteUInt8) {
error_reporter_->Report(
"The results for recognition should be uint8 elements, but are %d",
latest_results->type);
return kTfLiteError;
}
if ((!previous_results_.empty()) &&
(current_time_ms < previous_results_.front().time_)) {
error_reporter_->Report(
"Results must be fed in increasing time order, but received a "
"timestamp of %d that was earlier than the previous one of %d",
current_time_ms, previous_results_.front().time_);
return kTfLiteError;
}
// Add the latest results to the head of the queue.
previous_results_.push_back({current_time_ms, latest_results->data.uint8});
// Prune any earlier results that are too old for the averaging window.
const int64_t time_limit = current_time_ms - average_window_duration_ms_;
while ((!previous_results_.empty()) &&
previous_results_.front().time_ < time_limit) {
previous_results_.pop_front();
}
// If there are too few results, assume the result will be unreliable and
// bail.
const int64_t how_many_results = previous_results_.size();
const int64_t earliest_time = previous_results_.front().time_;
const int64_t samples_duration = current_time_ms - earliest_time;
if ((how_many_results < minimum_count_) ||
(samples_duration < (average_window_duration_ms_ / 4))) {
*found_command = previous_top_label_;
*score = 0;
*is_new_command = false;
return kTfLiteOk;
}
// Calculate the average score across all the results in the window.
int32_t average_scores[kCategoryCount];
for (int offset = 0; offset < previous_results_.size(); ++offset) {
PreviousResultsQueue::Result previous_result =
previous_results_.from_front(offset);
const uint8_t* scores = previous_result.scores_;
for (int i = 0; i < kCategoryCount; ++i) {
if (offset == 0) {
average_scores[i] = scores[i];
} else {
average_scores[i] += scores[i];
}
}
}
for (int i = 0; i < kCategoryCount; ++i) {
average_scores[i] /= how_many_results;
}
// Find the current highest scoring category.
int current_top_index = 0;
int32_t current_top_score = 0;
for (int i = 0; i < kCategoryCount; ++i) {
if (average_scores[i] > current_top_score) {
current_top_score = average_scores[i];
current_top_index = i;
}
}
const char* current_top_label = kCategoryLabels[current_top_index];
// If we've recently had another label trigger, assume one that occurs too
// soon afterwards is a bad result.
int64_t time_since_last_top;
if ((previous_top_label_ == kCategoryLabels[0]) ||
(previous_top_label_time_ == std::numeric_limits<int32_t>::min())) {
time_since_last_top = std::numeric_limits<int32_t>::max();
} else {
time_since_last_top = current_time_ms - previous_top_label_time_;
}
if ((current_top_score > detection_threshold_) &&
(current_top_label != previous_top_label_) &&
(time_since_last_top > suppression_ms_)) {
previous_top_label_ = current_top_label;
previous_top_label_time_ = current_time_ms;
*is_new_command = true;
} else {
*is_new_command = false;
}
*found_command = current_top_label;
*score = current_top_score;
return kTfLiteOk;
}
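Putting numbers on the defaults declared in recognize_commands.h (1000 ms window, threshold 51, 500 ms suppression, minimum count 3): with results arriving every 100 ms, nothing can trigger until the window holds at least three results spanning at least a quarter of the window duration, and a category then fires once its windowed average exceeds 51, roughly 20% of the uint8 score range. Because the three trigger conditions are ANDed together, a label must also differ from the previous top label and arrive more than suppression_ms after it, so a command that stays on top keeps returning is_new_command = false until another label wins the window.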

tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h (new file)

@@ -0,0 +1,158 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
#include <cstdint>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
// Partial implementation of std::deque, just providing the functionality
// that's needed to keep a record of previous neural network results over a
// short time period, so they can be averaged together to produce a more
// accurate overall prediction. This doesn't use any dynamic memory allocation,
// so it's a good fit for microcontroller applications, but it does mean
// there's a hard limit on the number of results that can be stored.
class PreviousResultsQueue {
public:
PreviousResultsQueue(tflite::ErrorReporter* error_reporter)
: error_reporter_(error_reporter), front_index_(0), size_(0) {}
// Data structure that holds an inference result, and the time when it
// was recorded.
struct Result {
Result() : time_(0), scores_() {}
Result(int32_t time, uint8_t* scores) : time_(time) {
for (int i = 0; i < kCategoryCount; ++i) {
scores_[i] = scores[i];
}
}
int32_t time_;
uint8_t scores_[kCategoryCount];
};
int size() { return size_; }
bool empty() { return size_ == 0; }
Result& front() { return results_[front_index_]; }
Result& back() {
int back_index = front_index_ + (size_ - 1);
if (back_index >= kMaxResults) {
back_index -= kMaxResults;
}
return results_[back_index];
}
void push_back(const Result& entry) {
if (size() >= kMaxResults) {
error_reporter_->Report(
"Couldn't push_back latest result, too many already!");
return;
}
size_ += 1;
back() = entry;
}
Result pop_front() {
if (size() <= 0) {
error_reporter_->Report("Couldn't pop_front result, none present!");
return Result();
}
Result result = front();
front_index_ += 1;
if (front_index_ >= kMaxResults) {
front_index_ = 0;
}
size_ -= 1;
return result;
}
// Most of the functions mirror std::deque's interface, but this is a helper
// that makes it easy to iterate through the contents of the queue.
Result& from_front(int offset) {
if ((offset < 0) || (offset >= size_)) {
error_reporter_->Report("Attempt to read beyond the end of the queue!");
offset = size_ - 1;
}
int index = front_index_ + offset;
if (index >= kMaxResults) {
index -= kMaxResults;
}
return results_[index];
}
private:
tflite::ErrorReporter* error_reporter_;
static constexpr int kMaxResults = 50;
Result results_[kMaxResults];
int front_index_;
int size_;
};
// This class is designed to apply a very primitive decoding model on top of the
// instantaneous results from running an audio recognition model on a single
// window of samples. It applies smoothing over time so that noisy individual
// label scores are averaged, increasing the confidence that apparent matches
// are real.
// To use it, create an instance with the configuration you want, and then
// feed results from running a TensorFlow model into the processing method.
// The timestamp for each subsequent call should be later than the previous
// one, since the class is designed to process a stream of data over time.
class RecognizeCommands {
public:
// The category labels come from kCategoryLabels in model_settings.h, one for
// each element of the model's output scores. The window duration controls
// the smoothing: longer durations give higher confidence that the results
// are correct, but may miss some commands. The detection threshold has a
// similar effect, with high values increasing precision at the cost of
// recall. The minimum count controls how many results need to be in the
// averaging window before it's treated as reliable, which prevents erroneous
// results while the window is first being populated. The suppression
// argument disables further recognitions for a set time after one has been
// triggered, which can help reduce spurious recognitions.
explicit RecognizeCommands(tflite::ErrorReporter* error_reporter,
int32_t average_window_duration_ms = 1000,
uint8_t detection_threshold = 51,
int32_t suppression_ms = 500,
int32_t minimum_count = 3);
// Call this with the results of running a model on sample data.
TfLiteStatus ProcessLatestResults(const TfLiteTensor* latest_results,
const int32_t current_time_ms,
const char** found_command, uint8_t* score,
bool* is_new_command);
private:
// Configuration
tflite::ErrorReporter* error_reporter_;
int32_t average_window_duration_ms_;
uint8_t detection_threshold_;
int32_t suppression_ms_;
int32_t minimum_count_;
// Working variables
PreviousResultsQueue previous_results_;
const char* previous_top_label_;
int32_t previous_top_label_time_;
};
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_RECOGNIZE_COMMANDS_H_
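A quick sketch of the wrap-around behaviour of the fixed-size ring buffer may help; this is illustrative only, assuming the four-category model from model_settings.h and the kMaxResults = 50 defined above (the timestamps are arbitrary):

tflite::MicroErrorReporter micro_error_reporter;
PreviousResultsQueue queue(&micro_error_reporter);
uint8_t scores[kCategoryCount] = {0, 0, 255, 0};
// Fill every one of the kMaxResults slots.
for (int t = 0; t < 50; ++t) {
  queue.push_back({t * 100, scores});
}
queue.pop_front();                // front_index_ advances from 0 to 1
queue.push_back({5000, scores});  // back index 1 + 49 wraps around to slot 0
// from_front(49) now reads the newest entry, stored in results_[0].

The storage never moves; only front_index_ and size_ change, which is why push_back and pop_front are constant-time without any heap allocation.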

tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands_test.cc (new file)

@@ -0,0 +1,207 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/micro/examples/micro_speech/recognize_commands.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
PreviousResultsQueue queue(error_reporter);
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
uint8_t scores_a[4] = {0, 0, 0, 1};
queue.push_back({0, scores_a});
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
TF_LITE_MICRO_EXPECT_EQ(0, queue.back().time_);
uint8_t scores_b[4] = {0, 0, 1, 0};
queue.push_back({1, scores_b});
TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_);
TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);
PreviousResultsQueue::Result pop_result = queue.pop_front();
TF_LITE_MICRO_EXPECT_EQ(0, pop_result.time_);
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_);
uint8_t scores_c[4] = {0, 1, 0, 0};
queue.push_back({2, scores_c});
TF_LITE_MICRO_EXPECT_EQ(2, queue.size());
TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_);
TF_LITE_MICRO_EXPECT_EQ(2, queue.back().time_);
}
TF_LITE_MICRO_TEST(PreviousResultsQueuePushPop) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
PreviousResultsQueue queue(error_reporter);
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
for (int i = 0; i < 123; ++i) {
uint8_t scores[4] = {0, 0, 0, 1};
queue.push_back({i, scores});
TF_LITE_MICRO_EXPECT_EQ(1, queue.size());
TF_LITE_MICRO_EXPECT_EQ(i, queue.front().time_);
TF_LITE_MICRO_EXPECT_EQ(i, queue.back().time_);
PreviousResultsQueue::Result pop_result = queue.pop_front();
TF_LITE_MICRO_EXPECT_EQ(i, pop_result.time_);
TF_LITE_MICRO_EXPECT_EQ(0, queue.size());
}
}
TF_LITE_MICRO_TEST(RecognizeCommandsTestBasic) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
RecognizeCommands recognize_commands(error_reporter);
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
{255, 0, 0, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
"input_tensor", 0.0f, 128.0f);
const char* found_command;
uint8_t score;
bool is_new_command;
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&results, 0, &found_command, &score, &is_new_command));
}
TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor(
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
"input_tensor", 0.0f, 128.0f);
bool has_found_new_command = false;
const char* new_command = nullptr;
for (int i = 0; i < 10; ++i) {
const char* found_command;
uint8_t score;
bool is_new_command;
int32_t current_time_ms = 0 + (i * 100);
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&yes_results, current_time_ms, &found_command, &score,
&is_new_command));
if (is_new_command) {
TF_LITE_MICRO_EXPECT(!has_found_new_command);
has_found_new_command = true;
new_command = found_command;
}
}
TF_LITE_MICRO_EXPECT(has_found_new_command);
TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command));
TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor(
{0, 0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
"input_tensor", 0.0f, 128.0f);
has_found_new_command = false;
new_command = "";
uint8_t score;
for (int i = 0; i < 10; ++i) {
const char* found_command;
bool is_new_command;
int32_t current_time_ms = 1000 + (i * 100);
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&no_results, current_time_ms, &found_command, &score,
&is_new_command));
if (is_new_command) {
TF_LITE_MICRO_EXPECT(!has_found_new_command);
has_found_new_command = true;
new_command = found_command;
}
}
TF_LITE_MICRO_EXPECT(has_found_new_command);
TF_LITE_MICRO_EXPECT_EQ(231, score);
TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("no", new_command));
}
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor(
{0, 0, 255}, tflite::testing::IntArrayFromInitializer({2, 1, 3}),
"input_tensor", 0.0f, 128.0f);
const char* found_command;
uint8_t score;
bool is_new_command;
TF_LITE_MICRO_EXPECT_NE(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&bad_results, 0, &found_command, &score, &is_new_command));
}
TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputTimes) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
"input_tensor", 0.0f, 128.0f);
const char* found_command;
uint8_t score;
bool is_new_command;
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&results, 100, &found_command, &score, &is_new_command));
TF_LITE_MICRO_EXPECT_NE(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&results, 0, &found_command, &score, &is_new_command));
}
TF_LITE_MICRO_TEST(RecognizeCommandsTestTooFewInputs) {
tflite::MicroErrorReporter micro_error_reporter;
tflite::ErrorReporter* error_reporter = &micro_error_reporter;
RecognizeCommands recognize_commands(error_reporter, 1000, 51);
TfLiteTensor results = tflite::testing::CreateQuantizedTensor(
{0, 0, 255, 0}, tflite::testing::IntArrayFromInitializer({2, 1, 4}),
"input_tensor", 0.0f, 128.0f);
const char* found_command;
uint8_t score;
bool is_new_command;
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, recognize_commands.ProcessLatestResults(
&results, 100, &found_command, &score, &is_new_command));
TF_LITE_MICRO_EXPECT_EQ(0, score);
TF_LITE_MICRO_EXPECT_EQ(false, is_new_command);
}
TF_LITE_MICRO_TESTS_END
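A note on the test tensors: tflite::testing::IntArrayFromInitializer({2, 1, 4}) builds dims whose first element is the rank, so these are rank-2 [1, kCategoryCount] tensors, exactly what the shape check in ProcessLatestResults() expects; the {2, 1, 3} variant in RecognizeCommandsTestBadInputLength deliberately fails that check.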

tensorflow/lite/experimental/micro/kernels/BUILD

@@ -48,22 +48,6 @@ cc_library(
],
)
-cc_library(
-name = "test_utils",
-srcs = [
-],
-hdrs = [
-"test_utils.h",
-],
-copts = tflite_copts(),
-deps = [
-"//tensorflow/lite/c:c_api_internal",
-"//tensorflow/lite/core/api",
-"//tensorflow/lite/experimental/micro:micro_framework",
-"//tensorflow/lite/experimental/micro/testing:micro_test",
-],
-)
tflite_micro_cc_test(
name = "depthwise_conv_test",
srcs = [
@@ -71,7 +55,6 @@ tflite_micro_cc_test(
],
deps = [
":all_ops_resolver",
":test_utils",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
@@ -85,7 +68,6 @@ tflite_micro_cc_test(
],
deps = [
":all_ops_resolver",
":test_utils",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
@@ -99,7 +81,6 @@ tflite_micro_cc_test(
],
deps = [
":all_ops_resolver",
":test_utils",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",

tensorflow/lite/experimental/micro/kernels/depthwise_conv_test.cc

@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
namespace tflite {
namespace testing {

tensorflow/lite/experimental/micro/kernels/fully_connected_test.cc

@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
namespace tflite {
namespace testing {

tensorflow/lite/experimental/micro/kernels/softmax_test.cc

@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
namespace tflite {
namespace testing {

tensorflow/lite/experimental/micro/testing/BUILD

@@ -10,8 +10,10 @@ cc_library(
name = "micro_test",
hdrs = [
"micro_test.h",
"test_utils.h",
],
deps = [
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
],
)

tensorflow/lite/experimental/micro/testing/test_utils.h (moved from tensorflow/lite/experimental/micro/kernels/test_utils.h)

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
#include <cstdarg>
#include <initializer_list>
@@ -21,8 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/experimental/micro/kernels/test_utils.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
namespace tflite {
@@ -164,7 +163,20 @@ inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
}
+// Do a simple string comparison for testing purposes, without requiring the
+// standard C library.
+inline int TestStrcmp(const char* a, const char* b) {
+if ((a == nullptr) || (b == nullptr)) {
+return -1;
+}
+while ((*a != 0) && (*a == *b)) {
+a++;
+b++;
+}
+return *(const unsigned char*)a - *(const unsigned char*)b;
+}
} // namespace testing
} // namespace tflite
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
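TestStrcmp() follows strcmp's contract of returning 0 on equality (and, as a testing convenience, -1 when either argument is null), which is why recognize_commands_test.cc can assert TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command)) above.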