diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc
index 2e2efec127e..115a4fe208c 100644
--- a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h"
 
 namespace tflite {
 namespace ops {
@@ -111,12 +110,59 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
 }
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  void* raw;
+  context->AllocatePersistentBuffer(
+      context, sizeof(int), &raw);
+  return raw;
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+#if defined(__ARM_FEATURE_DSP)
+  OpData data;
+  int32_t buf_size;
+
+  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_shape = GetTensorShape(input);
+
+  const int input_depth = input_shape.Dims(3);
+  const int input_width = input->dims->data[2];
+  const int input_height = input->dims->data[1];
+  const int filter_width = filter->dims->data[2];
+  const int filter_height = filter->dims->data[1];
+  const int output_width = output->dims->data[2];
+  const int output_height = output->dims->data[1];
+
+  int* buffer_idx = reinterpret_cast<int*>(node->user_data);
+
+  TF_LITE_ENSURE_STATUS(CalculateOpData(
+      context, node, params, input_width, input_height, filter_width,
+      filter_height, output_width, output_height, input->type, &data));
+
+  if (data.padding.width == 0 &&
+      data.padding.height == 0 &&
(input_depth % 4 == 0) &&
+      params->stride_width == 1 &&
+      params->stride_height == 1 && filter_width == 1 && filter_height == 1) {
+    buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(input_depth);
+  }
+  else
+  {
+    buf_size = arm_convolve_s8_get_buffer_size(input_depth, filter_width, filter_height);
+  }
+
+  node->user_data = buffer_idx;
+  if (buf_size > 0) {
+    context->RequestScratchBufferInArena(context, buf_size, buffer_idx);
+  } else {
+    *buffer_idx = -1;
+  }
+#endif
   return kTfLiteOk;
 }
 
@@ -200,15 +246,16 @@ TfLiteStatus EvalQuantizedPerChannel(
   const int output_width = output_shape.Dims(2);
 
   int16_t* buf = nullptr;
+  auto* buffer_idx = reinterpret_cast<int*>(node->user_data);
+  if (*buffer_idx > -1) {
+    void *raw = context->GetScratchBuffer(context, *buffer_idx);
+    buf = reinterpret_cast<int16_t*>(raw);
+  }
+
   if (op_params.padding_values.width == 0 &&
       op_params.padding_values.height == 0 && (input_depth % 4 == 0) &&
       (output_depth % 2 == 0) && op_params.stride_width == 1 &&
       op_params.stride_height == 1 && filter_width == 1 && filter_height == 1) {
-    const int32_t buf_size =
-        arm_convolve_1x1_s8_fast_get_buffer_size(input_depth);
-    if (get_cmsis_scratch_buffer(context, &buf, buf_size) != kTfLiteOk) {
-      return kTfLiteError;
-    }
     if (arm_convolve_1x1_s8_fast(
             GetTensorData<int8_t>(input), input_width, input_height,
             input_depth, batches, GetTensorData<int8_t>(filter), output_depth,
@@ -222,11 +269,6 @@ TfLiteStatus EvalQuantizedPerChannel(
       return kTfLiteError;
     }
   } else {
-    const int32_t buf_size = arm_convolve_s8_get_buffer_size(
-        input_depth, filter_width, filter_height);
-    if (get_cmsis_scratch_buffer(context, &buf, buf_size) != kTfLiteOk) {
-      return kTfLiteError;
-    }
     if (arm_convolve_s8(
             GetTensorData<int8_t>(input), input_width, input_height,
             input_depth, batches, GetTensorData<int8_t>(filter), output_depth,
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc
index f5543b85cb9..c1563b235ea 100644
---
a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc
@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
-#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h"
 
 namespace tflite {
 namespace ops {
@@ -99,12 +98,40 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
 }  // namespace
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  void* raw;
+  context->AllocatePersistentBuffer(
+      context, sizeof(int), &raw);
+  return raw;
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+#if defined(__ARM_FEATURE_DSP)
+  auto* params = reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+
+  const int filter_width = SizeOfDimension(filter, 2);
+  const int filter_height = SizeOfDimension(filter, 1);
+
+  RuntimeShape input_shape = GetTensorShape(input);
+  const int input_depth = input_shape.Dims(3);
+
+  int* buffer_idx = reinterpret_cast<int*>(node->user_data);
+
+  *buffer_idx = -1;
+  node->user_data = buffer_idx;
+
+  if (params->depth_multiplier == 1) {
+    const int32_t buf_size = arm_depthwise_conv_s8_opt_get_buffer_size(input_depth, filter_width, filter_height);
+
+    if (buf_size > 0) {
+      context->RequestScratchBufferInArena(context, buf_size, buffer_idx);
+    }
+  }
+#endif
   return kTfLiteOk;
 }
 
@@ -174,10 +201,12 @@ TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
   if (op_params.depth_multiplier == 1) {
     int16_t* buf = nullptr;
-    const int32_t buf_size = arm_depthwise_conv_s8_opt_get_buffer_size(
-        input_depth, filter_width, filter_height);
-    TF_LITE_ENSURE_OK(context,
-
get_cmsis_scratch_buffer(context, &buf, buf_size));
+    auto* buffer_idx = reinterpret_cast<int*>(node->user_data);
+    if (*buffer_idx > -1) {
+      void *raw = context->GetScratchBuffer(context, *buffer_idx);
+      buf = reinterpret_cast<int16_t*>(raw);
+    }
+
     TF_LITE_ENSURE_EQ(
         context,
         arm_depthwise_conv_s8_opt(
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
index 20980d726c6..b7d3e542afa 100644
--- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc
@@ -23,7 +23,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h"
 
 namespace tflite {
 namespace ops {
@@ -73,14 +72,33 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
 }  // namespace
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  void* raw;
+  context->AllocatePersistentBuffer(
+      context, sizeof(int), &raw);
+  return raw;
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  // todo: call AllocateTemporaryTensor() instead of using
-  // get_cmsis_scratch_buffer()
+#if defined(__ARM_FEATURE_DSP)
+  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+
+  RuntimeShape filter_shape = GetTensorShape(filter);
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+
+  const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(accum_depth);
+
+  int* buffer_idx = reinterpret_cast<int*>(node->user_data);
+
+  node->user_data = buffer_idx;
+  if (buf_size > 0) {
+    context->RequestScratchBufferInArena(context, buf_size, buffer_idx);
+  } else {
+    *buffer_idx = -1;
+  }
+#endif
   return kTfLiteOk;
 }
 
@@ -97,9 +115,14 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
   const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
 
 #if defined(__ARM_FEATURE_DSP)
-  const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(accum_depth);
   int16_t* buf = nullptr;
-  TF_LITE_ENSURE_OK(context, get_cmsis_scratch_buffer(context, &buf, buf_size));
+
+  auto* buffer_idx = reinterpret_cast<int*>(node->user_data);
+  if (*buffer_idx > -1) {
+    void *raw = context->GetScratchBuffer(context, *buffer_idx);
+    buf = reinterpret_cast<int16_t*>(raw);
+  }
+
   TF_LITE_ENSURE_EQ(
       context,
       arm_fully_connected_s8(
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc
index 54dcf64118e..d66423ab3e0 100644
--- a/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc
+++ b/tensorflow/lite/micro/kernels/cmsis-nn/pooling.cc
@@ -16,7 +16,6 @@ limitations under the License.
 
 // These are headers from the ARM CMSIS-NN library.
#include "arm_nnfunctions.h"  // NOLINT
-#include "scratch_buffer.h"  // NOLINT
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -128,10 +127,13 @@ TfLiteStatus AverageEvalInt8(TfLiteContext* context, const TfLiteNode* node,
   const int padding_width = data->padding.width;
 
   int16_t* scratch_buffer = nullptr;
-  int32_t buffer_size = arm_avgpool_s8_get_buffer_size(output_width, depth);
-  TF_LITE_ENSURE_OK(
-      context, get_cmsis_scratch_buffer(context, &scratch_buffer, buffer_size));
+  auto* buffer_idx = reinterpret_cast<int*>(node->user_data);
+
+  if (*buffer_idx > -1) {
+    void *raw = context->GetScratchBuffer(context, *buffer_idx);
+    scratch_buffer = reinterpret_cast<int16_t*>(raw);
+  }
 
   TF_LITE_ENSURE_EQ(
       context,
@@ -207,12 +209,39 @@ void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
 }  // namespace
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  void* raw;
+  context->AllocatePersistentBuffer(
+      context, sizeof(int), &raw);
+  return raw;
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+#if defined(__ARM_FEATURE_DSP)
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_shape = GetTensorShape(input);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+
+  RuntimeShape output_shape = GetTensorShape(output);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+  const int output_width = output_shape.Dims(2);
+
+  const int32_t buffer_size = arm_avgpool_s8_get_buffer_size(output_width, depth);
+
+  int* buffer_idx = reinterpret_cast<int*>(node->user_data);
+
+  node->user_data = buffer_idx;
+  if (buffer_size > 0) {
+
context->RequestScratchBufferInArena(context, buffer_size, buffer_idx);
+  } else {
+    *buffer_idx = -1;
+  }
+#endif
   return kTfLiteOk;
 }
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc b/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc
deleted file mode 100644
index e15a1416aeb..00000000000
--- a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "scratch_buffer.h"
-
-// todo: remove this function once context->AllocateTemporaryTensor() is
-// implemented.
-
-// This buffer is used by CMSIS-NN optimized operator implementations.
-// SCRATCH_BUFFER_BYTES bytes is chosen empirically. It needs to be large
-// enough to hold the biggest buffer needed by all CMSIS-NN operators in the
-// network.
-// note: buffer must be 32-bit aligned for SIMD
-#define SCRATCH_BUFFER_BYTES 13000
-
-TfLiteStatus get_cmsis_scratch_buffer(TfLiteContext* context, int16_t** buf,
-                                      int32_t buf_size_bytes) {
-  __attribute__((aligned(
-      4))) static int16_t cmsis_scratch_buffer[SCRATCH_BUFFER_BYTES / 2] = {0};
-
-  TF_LITE_ENSURE(context, buf_size_bytes <= SCRATCH_BUFFER_BYTES);
-  *buf = cmsis_scratch_buffer;
-  return kTfLiteOk;
-}
diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h b/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h
deleted file mode 100644
index ba63cdfe90b..00000000000
--- a/tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_
-#define TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_
-
-#include "tensorflow/lite/c/common.h"
-
-// todo: remove this function once context->AllocateTemporaryTensor() is
-// implemented.
-TfLiteStatus get_cmsis_scratch_buffer(TfLiteContext* context, int16_t** buf,
-                                      int32_t buf_size);
-
-#endif  // TENSORFLOW_LITE_MICRO_KERNELS_CMSIS_NN_SCRATCH_BUFFER_H_
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc
index b4d6e505650..cfd87089a84 100644
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc
@@ -21,13 +21,6 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),)
   THIRD_PARTY_CC_HDRS += \
   $(call recursive_find,$(CMSIS_PATH)/CMSIS/Core/Include,*.h)
 
-  # todo: remove the two lines below once context->AllocateTemporaryTensor()
-  # is implemented.
-  MICROLITE_CC_HDRS += \
-    tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.h
-  MICROLITE_CC_SRCS += \
-    tensorflow/lite/micro/kernels/cmsis-nn/scratch_buffer.cc
-
   INCLUDES += -I$(CMSIS_PATH)/CMSIS/Core/Include \
               -I$(CMSIS_PATH)/CMSIS/NN/Include \
               -I$(CMSIS_PATH)/CMSIS/DSP/Include
diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
index 539f4528d06..29e2143286c 100644
--- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc
@@ -76,6 +76,8 @@ ifeq ($(TARGET), stm32f4)
     tensorflow/lite/micro/kernels/dequantize_test.cc \
     tensorflow/lite/micro/kernels/unpack_test.cc \
     tensorflow/lite/micro/kernels/split_test.cc \
+    tensorflow/lite/micro/kernels/conv_test.cc \
+    tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
     tensorflow/lite/micro/simple_tensor_allocator_test.cc
   MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
   EXCLUDED_EXAMPLE_TESTS := \