Generalize circular buffer operator.
PiperOrigin-RevId: 343587123 Change-Id: Ifaea8ffe1e067b4119e529053aa13874c87c984b
This commit is contained in:
parent
0957e53d4a
commit
05ae49f2f6
@ -65,45 +65,49 @@ struct OpData {
|
||||
int cycles_max;
|
||||
};
|
||||
|
||||
// These constants represent constants specific to the music detect model.
|
||||
// They exist until (b/132070898) is fixed.
|
||||
constexpr int kMaxOpDataSize = 7;
|
||||
int op_data_counter = 0;
|
||||
OpData op_data_array[kMaxOpDataSize];
|
||||
|
||||
} // namespace
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) { op_data_counter = 0; }
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, 1, output->dims->data[0]);
|
||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[0]);
|
||||
TF_LITE_ENSURE_EQ(context, input->dims->data[0], output->dims->data[0]);
|
||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]);
|
||||
TF_LITE_ENSURE_EQ(context, 1, output->dims->data[2]);
|
||||
TF_LITE_ENSURE_EQ(context, 1, input->dims->data[2]);
|
||||
TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]);
|
||||
TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]);
|
||||
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
// The circular buffer custom operator currently only supports int8_t.
|
||||
// The circular buffer custom operator currently only supports int8.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
|
||||
|
||||
// TODO(b/132070898): Use statically slotted OpData structures until a
|
||||
// scratch memory API is ready.
|
||||
TFLITE_DCHECK_LE(op_data_counter, kMaxOpDataSize);
|
||||
OpData* op_data = &op_data_array[op_data_counter++];
|
||||
// The last circular buffer layer (length 5) simply accumulates outputs, and
|
||||
// does not run periodically.
|
||||
// The last circular buffer layer simply accumulates outputs, and does not run
|
||||
// periodically.
|
||||
// TODO(b/150001379): Move this special case logic to the tflite flatbuffer.
|
||||
if (output->dims->data[1] == 5) {
|
||||
static int cb_prepare_count = 0;
|
||||
cb_prepare_count++;
|
||||
// These checks specifically work for the only two streaming models supported
|
||||
// on TFLM. They use the shape of the output tensor along with the layer
|
||||
// number to determine if the circular buffer period should be 1 or 2.
|
||||
|
||||
// These models are outlined int the following documents:
|
||||
// https://docs.google.com/document/d/1lc_G2ZFhjiKFo02UHjBaljye1xsL0EkfybkaVELEE3Q/edit?usp=sharing
|
||||
// https://docs.google.com/document/d/1pGc42PuWyrk-Jy1-9qeqtggvsmHr1ifz8Lmqfpr2rKA/edit?usp=sharing
|
||||
if (output->dims->data[1] == 5 || output->dims->data[1] == 13 ||
|
||||
(cb_prepare_count == 5 && output->dims->data[2] == 2 &&
|
||||
output->dims->data[3] == 96)) {
|
||||
op_data->cycles_max = 1;
|
||||
cb_prepare_count = 0;
|
||||
} else {
|
||||
op_data->cycles_max = 2;
|
||||
}
|
||||
@ -127,10 +131,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
int num_slots = output->dims->data[1];
|
||||
int depth = output->dims->data[3];
|
||||
int depth = output->dims->data[2] * output->dims->data[3];
|
||||
|
||||
if (input->type == kTfLiteInt8) {
|
||||
EvalInt8(tflite::micro::GetTensorData<int8_t>(input), num_slots, depth,
|
||||
@ -148,12 +153,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return static_cast<TfLiteStatus>(kTfLiteAbort);
|
||||
}
|
||||
|
||||
// If prepare is ever called more than one time (for example, when testing the
|
||||
// ambient model, the interpreter is created a few times), this op data
|
||||
// counter needs to be reset so that future instances do not overrun this op
|
||||
// data array.
|
||||
op_data_counter = 0;
|
||||
|
||||
data->cycles_until_run = data->cycles_max;
|
||||
|
||||
return kTfLiteOk;
|
||||
@ -162,8 +161,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace circular_buffer
|
||||
|
||||
TfLiteRegistration* Register_CIRCULAR_BUFFER() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/circular_buffer::Free,
|
||||
static TfLiteRegistration r = {/*init=*/circular_buffer::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/circular_buffer::Prepare,
|
||||
/*invoke=*/circular_buffer::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
|
Loading…
x
Reference in New Issue
Block a user