Generalize circular buffer operator.
PiperOrigin-RevId: 343587123 Change-Id: Ifaea8ffe1e067b4119e529053aa13874c87c984b
This commit is contained in:
parent
0957e53d4a
commit
05ae49f2f6
@ -65,45 +65,49 @@ struct OpData {
|
|||||||
int cycles_max;
|
int cycles_max;
|
||||||
};
|
};
|
||||||
|
|
||||||
// These constants represent constants specific to the music detect model.
|
|
||||||
// They exist until (b/132070898) is fixed.
|
|
||||||
constexpr int kMaxOpDataSize = 7;
|
|
||||||
int op_data_counter = 0;
|
|
||||||
OpData op_data_array[kMaxOpDataSize];
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||||
|
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
TF_LITE_ENSURE(context, input != nullptr);
|
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
TF_LITE_ENSURE(context, output != nullptr);
|
|
||||||
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
|
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
TF_LITE_ENSURE(context, input != nullptr);
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
TF_LITE_ENSURE(context, output != nullptr);
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
TF_LITE_ENSURE_EQ(context, 1, output->dims->data[0]);
|
TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
|
||||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[0]);
|
|
||||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
|
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
|
||||||
TF_LITE_ENSURE_EQ(context, 1, output->dims->data[2]);
|
TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
|
||||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]);
|
|
||||||
TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
|
TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
|
||||||
|
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||||
|
|
||||||
// The circular buffer custom operator currently only supports int8_t.
|
// The circular buffer custom operator currently only supports int8.
|
||||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||||
|
|
||||||
// TODO(b/132070898): Use statically slotted OpData structures until a
|
// The last circular buffer layer simply accumulates outputs, and does not run
|
||||||
// scratch memory API is ready.
|
// periodically.
|
||||||
TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize);
|
|
||||||
OpData* op_data = &op_data_array[op_data_counter++];
|
|
||||||
// The last circular buffer layer (length 5) simply accumulates outputs, and
|
|
||||||
// does not run periodically.
|
|
||||||
// TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
|
// TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
|
||||||
if (output->dims->data[1] == 5) {
|
static int cb_prepare_count = 0;
|
||||||
|
cb_prepare_count++;
|
||||||
|
// These checks specifically work for the only two streaming models supported
|
||||||
|
// on TFLM. They use the shape of the output tensor along with the layer
|
||||||
|
// number to determine if the circular buffer period should be 1 or 2.
|
||||||
|
|
||||||
|
// These models are outlined int the following documents:
|
||||||
|
// https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
|
||||||
|
// https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
|
||||||
|
if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
|
||||||
|
(cb_prepare_count == 5 && output->dims->data[2] == 2 &&
|
||||||
|
output->dims->data[3] == 96)) {
|
||||||
op_data->cycles_max = 1;
|
op_data->cycles_max = 1;
|
||||||
|
cb_prepare_count = 0;
|
||||||
} else {
|
} else {
|
||||||
op_data->cycles_max = 2;
|
op_data->cycles_max = 2;
|
||||||
}
|
}
|
||||||
@ -127,10 +131,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteEvalTensor* output =
|
TfLiteEvalTensor* output =
|
||||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
int num_slots = output->dims->data[1];
|
int num_slots = output->dims->data[1];
|
||||||
int depth = output->dims->data[3];
|
int depth = output->dims->data[2] * output->dims->data[3];
|
||||||
|
|
||||||
if (input->type == kTfLiteInt8) {
|
if (input->type == kTfLiteInt8) {
|
||||||
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
|
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
|
||||||
@ -148,12 +153,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return static_cast<TfLiteStatus>(kTfLiteAbort);
|
return static_cast<TfLiteStatus>(kTfLiteAbort);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If prepare is ever called more than one time (for example, when testing the
|
|
||||||
// ambient model, the interpreter is created a few times), this op data
|
|
||||||
// counter needs to be reset so that future instances do not overrun this op
|
|
||||||
// data array.
|
|
||||||
op_data_counter = 0;
|
|
||||||
|
|
||||||
data->cycles_until_run = data->cycles_max;
|
data->cycles_until_run = data->cycles_max;
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
@ -162,8 +161,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
} // namespace circular_buffer
|
} // namespace circular_buffer
|
||||||
|
|
||||||
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
|
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
|
||||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
|
||||||
/*free=*/circular_buffer::Free,
|
/*free=*/nullptr,
|
||||||
/*prepare=*/circular_buffer::Prepare,
|
/*prepare=*/circular_buffer::Prepare,
|
||||||
/*invoke=*/circular_buffer::Eval,
|
/*invoke=*/circular_buffer::Eval,
|
||||||
/*profiling_string=*/nullptr,
|
/*profiling_string=*/nullptr,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user