From 934df8dcea0c176314a52d0062dcca08638bb52d Mon Sep 17 00:00:00 2001
From: Advait Jain
Date: Mon, 22 Jun 2020 22:29:15 -0700
Subject: [PATCH] Remove portable_optimized depthwise_conv implementation.

PR #34999 surfaced inconsistencies between the reference and
portable_optimized implementations, and we are also slowly transitioning to
CMSIS-NN for the optimized implementations for ARM Cortex-M targets.

There may still be an issue with the reference depthwise conv with
dilation > 1, but that would be a separate bug.

PiperOrigin-RevId: 317802639
Change-Id: Ic8c9acffb060bb3ec5f802a2045ac6ac8b9f2233
---
 tensorflow/lite/micro/BUILD                    |  24 -
 tensorflow/lite/micro/kernels/BUILD            |  81 ---
 .../portable_optimized/depthwise_conv.cc       | 515 ------
 .../lite/micro/tools/ci_build/test_mbed.sh     |   2 +-
 .../make/targets/apollo3evb_makefile.inc       |   3 -
 5 files changed, 1 insertion(+), 624 deletions(-)
 delete mode 100644 tensorflow/lite/micro/kernels/portable_optimized/depthwise_conv.cc

diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD
index f63d9778634..bdfa0c909db 100644
--- a/tensorflow/lite/micro/BUILD
+++ b/tensorflow/lite/micro/BUILD
@@ -102,30 +102,6 @@ cc_library(
     ],
 )
 
-# TODO(b/144176795): This target should really be handled differently so that we
-# do not have a fork in the build graph. The bug has some initial ideas.
-cc_library(
-    name = "portable_optimized_op_resolver",
-    srcs = [
-        "all_ops_resolver.cc",
-        "micro_mutable_op_resolver.h",
-        "micro_op_resolver.h",
-    ],
-    hdrs = [
-        "all_ops_resolver.h",
-    ],
-    copts = micro_copts(),
-    deps = [
-        ":micro_compatibility",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/core/api",
-        "//tensorflow/lite/kernels:op_macros",
-        "//tensorflow/lite/kernels/internal:compatibility",
-        "//tensorflow/lite/micro/kernels:portable_optimized_micro_ops",
-        "//tensorflow/lite/schema:schema_fbs",
-    ],
-)
-
 cc_library(
     name = "debug_log",
     srcs = [
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index c7fa19b8cea..0fd0be4e3a4 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -20,7 +20,6 @@ package_group(
     packages = ["//tensorflow/lite/micro"],
 )
 
-# LINT.IfChange(micro_ops)
 cc_library(
     name = "micro_ops",
     srcs = [
@@ -106,73 +105,6 @@ cc_library(
         ],
     }),
 )
-# LINT.ThenChange(//tensorflow/lite/micro/kernels/BUILD:portable_optimized_micro_ops)
-
-# LINT.IfChange(portable_optimized_micro_ops)
-cc_library(
-    name = "portable_optimized_micro_ops",
-    srcs = [
-        "activations.cc",
-        "add.cc",
-        "arg_min_max.cc",
-        "ceil.cc",
-        "circular_buffer.cc",
-        "comparisons.cc",
-        "concatenation.cc",
-        "conv.cc",
-        "dequantize.cc",
-        "elementwise.cc",
-        "ethosu.cc",
-        "floor.cc",
-        "fully_connected.cc",
-        "l2norm.cc",
-        "logical.cc",
-        "logistic.cc",
-        "maximum_minimum.cc",
-        "mul.cc",
-        "neg.cc",
-        "pack.cc",
-        "pad.cc",
-        "pooling.cc",
-        "portable_optimized/depthwise_conv.cc",
-        "prelu.cc",
-        "quantize.cc",
-        "reduce.cc",
-        "reshape.cc",
-        "resize_nearest_neighbor.cc",
-        "round.cc",
-        "softmax.cc",
-        "split.cc",
-        "strided_slice.cc",
-        "sub.cc",
-        "svdf.cc",
-        "tanh.cc",
-        "unpack.cc",
-    ],
-    hdrs = ["micro_ops.h"],
-    copts = micro_copts(),
-    visibility = [
-        # Needed for micro:portable_optimized_ops_resolver but visibility can not be
-        # finer-grained than a package.
-        ":micro_top_level",
-    ],
-    deps = [
-        ":activation_utils",
-        ":micro_utils",
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels:kernel_util",
-        "//tensorflow/lite/kernels:op_macros",
-        "//tensorflow/lite/kernels:padding",
-        "//tensorflow/lite/kernels/internal:common",
-        "//tensorflow/lite/kernels/internal:compatibility",
-        "//tensorflow/lite/kernels/internal:quantization_util",
-        "//tensorflow/lite/kernels/internal:reference_base",
-        "//tensorflow/lite/kernels/internal:tensor",
-        "//tensorflow/lite/kernels/internal:types",
-        "//tensorflow/lite/micro:micro_utils",
-    ],
-)
-# LINT.ThenChange(//tensorflow/lite/micro/kernels/BUILD:micro_ops)
 
 test_suite(
     name = "all_tests",
@@ -214,19 +146,6 @@ tflite_micro_cc_test(
     ],
 )
 
-tflite_micro_cc_test(
-    name = "portable_optimized_depthwise_conv_test",
-    srcs = [
-        "depthwise_conv_test.cc",
-    ],
-    deps = [
-        "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels/internal:tensor",
-        "//tensorflow/lite/micro:portable_optimized_op_resolver",
-        "//tensorflow/lite/micro/testing:micro_test",
-    ],
-)
-
 tflite_micro_cc_test(
     name = "fully_connected_test",
     srcs = [
diff --git a/tensorflow/lite/micro/kernels/portable_optimized/depthwise_conv.cc b/tensorflow/lite/micro/kernels/portable_optimized/depthwise_conv.cc
deleted file mode 100644
index 9fb8f2e32cc..00000000000
--- a/tensorflow/lite/micro/kernels/portable_optimized/depthwise_conv.cc
+++ /dev/null
@@ -1,515 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
-
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/kernels/internal/common.h"
-#include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
-#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
-#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/kernels/padding.h"
-
-namespace tflite {
-namespace ops {
-namespace micro {
-namespace depthwise_conv {
-namespace {
-
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-constexpr int kMaxChannels = 256;
-
-// Depthwise conv is quantized along dimension 3:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kDepthwiseConvQuantizedDimension = 3;
-
-// Size of the cached buffer we'll be using to hold reordered weights.
-constexpr int kReshapedFilterDataSize = 1 * 1024;
-
-struct OpData {
-  TfLitePaddingValues padding;
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Per channel output multiplier and shift.
-  int32_t per_channel_output_multiplier[kMaxChannels];
-  int32_t per_channel_output_shift[kMaxChannels];
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension];
-
-    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
-  }
-  return kTfLiteOk;
-}
-
-// Specialized implementation of the depthwise convolution operation designed
-// to work with the particular filter width of eight used by the default micro
-// speech sample code. It uses 1KB of RAM to hold reordered weight parameters,
-// converted from TFLite's NHWC format to NCHW format, and expressed as signed
-// eight bit integers, rather than unsigned. Care must be taken when calling
-// this not to use it for more than one node since there's only a single static
-// buffer holding the weights. You should use this implementation if depthwise
-// convolutions are a performance bottleneck, you have a layer that meets the
-// parameter requirements, and the extra RAM usage and additional code size are
-// not an issue.
-static inline void DepthwiseConvOptimizedForFilterWidthEight(
-    TfLiteContext* context, const DepthwiseParams& params,
-    const RuntimeShape& input_shape, const uint8* input_data,
-    const RuntimeShape& filter_shape, const uint8* filter_data,
-    const RuntimeShape& bias_shape, const int32* bias_data,
-    const RuntimeShape& output_shape, uint8* output_data) {
-  const int stride_width = params.stride_width;
-  const int stride_height = params.stride_height;
-  const int pad_width = params.padding_values.width;
-  const int pad_height = params.padding_values.height;
-  const int depth_multiplier = params.depth_multiplier;
-  const int32 output_activation_min = params.quantized_activation_min;
-  const int32 output_activation_max = params.quantized_activation_max;
-  const int32 input_offset = params.input_offset;
-  const int32 filter_offset = params.weights_offset;
-  const int32 output_offset = params.output_offset;
-  const int32 output_multiplier = params.output_multiplier;
-  const int output_shift = params.output_shift;
-  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
-  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
-  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
-  const int input_height = input_shape.Dims(1);
-  const int input_width = input_shape.Dims(2);
-  const int input_depth = input_shape.Dims(3);
-  const int filter_height = filter_shape.Dims(1);
-  const int filter_width = filter_shape.Dims(2);
-  const int output_height = output_shape.Dims(1);
-  const int output_width = output_shape.Dims(2);
-  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
-  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
-
-  static int16_t reshaped_filter_data[kReshapedFilterDataSize];
-  const int needed_size =
-      output_depth * filter_width * filter_height * input_depth;
-  if (needed_size > kReshapedFilterDataSize) {
-    TF_LITE_KERNEL_LOG(
-        context,
-        "Size too large for reshaped weight buffer (%d needed, %d available)",
-        needed_size, kReshapedFilterDataSize);
-    return;
-  }
-
-  RuntimeShape reshaped_filter_shape;
-  reshaped_filter_shape.BuildFrom(
-      {1, output_depth, filter_height, filter_width});
-
-  // If this is the first time through, repack the weights into a cached buffer
-  // so that they can be accessed sequentially.
-  static bool is_reshaped_filter_initialized = false;
-  if (!is_reshaped_filter_initialized) {
-    for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-      for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-        for (int oc = 0; oc < output_depth; ++oc) {
-          const uint8* current_filter =
-              filter_data + Offset(filter_shape, 0, filter_y, filter_x, oc);
-          int16_t* reshaped_filter =
-              reshaped_filter_data +
-              Offset(reshaped_filter_shape, 0, oc, filter_y, filter_x);
-          *reshaped_filter =
-              static_cast<int16_t>(*current_filter) + filter_offset;
-        }
-      }
-    }
-    is_reshaped_filter_initialized = true;
-  }
-
-  for (int b = 0; b < batches; ++b) {
-    for (int out_y = 0; out_y < output_height; ++out_y) {
-      for (int out_x = 0; out_x < output_width; ++out_x) {
-        for (int ic = 0; ic < input_depth; ++ic) {
-          for (int m = 0; m < depth_multiplier; m++) {
-            const int oc = m + ic * depth_multiplier;
-            const int in_x_origin = (out_x * stride_width) - pad_width;
-            const int in_y_origin = (out_y * stride_height) - pad_height;
-            int32 acc = 0;
-            int in_y_start = in_y_origin;
-            int filter_y_start = 0;
-            if (in_y_origin < 0) {
-              in_y_start = 0;
-              filter_y_start = 0 - in_y_origin;
-            }
-            int filter_y_end = filter_height;
-            if ((in_y_origin + filter_height) >= input_height) {
-              filter_y_end -= (in_y_origin + filter_height) - input_height;
-            }
-            int in_y = in_y_start;
-            int in_x_start = in_x_origin;
-            int filter_x_start = 0;
-            bool is_out_of_x_bounds = false;
-            if (in_x_origin < 0) {
-              in_x_start = 0;
-              filter_x_start = 0 - in_x_origin;
-              is_out_of_x_bounds = true;
-            }
-            int filter_x_end = filter_width;
-            if ((in_x_origin + filter_width) >= input_width) {
-              filter_x_end -= (in_x_origin + filter_width) - input_width;
-              is_out_of_x_bounds = true;
-            }
-            for (int filter_y = filter_y_start; filter_y < filter_y_end;
-                 ++filter_y, ++in_y) {
-              const uint8* current_input =
-                  input_data + Offset(input_shape, b, in_y, in_x_start, ic);
-              if ((filter_width == 8) && !is_out_of_x_bounds) {
-                int16* current_filter =
-                    reshaped_filter_data + Offset(reshaped_filter_shape, 0, oc,
-                                                  filter_y, filter_x_start);
-                const uint32_t input_vals0 =
-                    *reinterpret_cast<const uint32_t*>(current_input);
-                current_input += 4;
-                const int32_t filter_vals0 =
-                    *reinterpret_cast<const int32_t*>(current_filter);
-                current_filter += 2;
-                const uint8 input_val0 = input_vals0 & 0xff;
-                const int16 filter_val0 = filter_vals0 & 0xffff;
-                acc += filter_val0 * input_val0;
-                const uint8 input_val1 = (input_vals0 >> 8) & 0xff;
-                const int16 filter_val1 = (filter_vals0 >> 16) & 0xffff;
-                acc += filter_val1 * input_val1;
-
-                const int32_t filter_vals1 =
-                    *reinterpret_cast<const int32_t*>(current_filter);
-                current_filter += 2;
-                const uint8 input_val2 = (input_vals0 >> 16) & 0xff;
-                const int16 filter_val2 = filter_vals1 & 0xffff;
-                acc += filter_val2 * input_val2;
-                const uint8 input_val3 = (input_vals0 >> 24) & 0xff;
-                const int16 filter_val3 = (filter_vals1 >> 16) & 0xffff;
-                acc += filter_val3 * input_val3;
-
-                const uint32_t input_vals1 =
-                    *reinterpret_cast<const uint32_t*>(current_input);
-                const int32_t filter_vals2 =
-                    *reinterpret_cast<const int32_t*>(current_filter);
-                current_filter += 2;
-                const uint8 input_val4 = input_vals1 & 0xff;
-                const int16 filter_val4 = filter_vals2 & 0xffff;
-                acc += filter_val4 * input_val4;
-                const uint8 input_val5 = (input_vals1 >> 8) & 0xff;
-                const int16 filter_val5 = (filter_vals2 >> 16) & 0xffff;
-                acc += filter_val5 * input_val5;
-
-                const int32_t filter_vals3 =
-                    *reinterpret_cast<const int32_t*>(current_filter);
-                const uint8 input_val6 = (input_vals1 >> 16) & 0xff;
-                const int16 filter_val6 = filter_vals3 & 0xffff;
-                acc += filter_val6 * input_val6;
-                const uint8 input_val7 = (input_vals1 >> 24) & 0xff;
-                const int16 filter_val7 = (filter_vals3 >> 16) & 0xffff;
-                acc += filter_val7 * input_val7;
-              } else {
-                const uint8* current_filter =
-                    filter_data +
-                    Offset(filter_shape, 0, filter_y, filter_x_start, oc);
-                for (int filter_x = filter_x_start; filter_x < filter_x_end;
-                     ++filter_x) {
-                  int32 input_val = *current_input;
-                  current_input += input_depth;
-                  int32 filter_val = *current_filter;
-                  current_filter += output_depth;
-                  acc +=
-                      (filter_val + filter_offset) * (input_val + input_offset);
-                }
-              }
-            }
-            if (bias_data) {
-              acc += bias_data[oc];
-            }
-            acc = reference_ops::depthwise_conv::DepthwiseConvRound<
-                DepthwiseConvOutputRounding::kAwayFromZero>(
-                acc, output_multiplier, output_shift);
-            acc += output_offset;
-            acc = std::max(acc, output_activation_min);
-            acc = std::min(acc, output_activation_max);
-            output_data[Offset(output_shape, b, out_y, out_x, oc)] =
-                static_cast<uint8>(acc);
-          }
-        }
-      }
-    }
-  }
-}  // namespace
-
-}  // namespace
-
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteDepthwiseConvParams* params, OpData* data,
-               const TfLiteTensor* input, const TfLiteTensor* filter,
-               const TfLiteTensor* bias, TfLiteTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-
-  tflite::DepthwiseParams op_params;
-  // Padding type is ignored, but still set.
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data->padding.width;
-  op_params.padding_values.height = data->padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = 1;
-  op_params.dilation_height_factor = 1;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-
-  tflite::reference_ops::DepthwiseConv(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
-      GetTensorShape(filter), GetTensorData<float>(filter),
-      GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
-      GetTensorData<float>(output));
-}
-
-// TODO(njeff): Optimize for int8 like we do for uint8.
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteDepthwiseConvParams* params, OpData* data,
-                             const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output) {
-  DepthwiseParams op_params;
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data->padding.width;
-  op_params.padding_values.height = data->padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.input_offset = -input->params.zero_point;
-  op_params.weights_offset = 0;
-  op_params.output_offset = output->params.zero_point;
-  // TODO(b/130439627): Use calculated value for clamping.
-  op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
-  op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
-
-  reference_integer_ops::DepthwiseConvPerChannel(
-      op_params, data->per_channel_output_multiplier,
-      data->per_channel_output_shift, GetTensorShape(input),
-      GetTensorData<int8>(input), GetTensorShape(filter),
-      GetTensorData<int8>(filter), GetTensorShape(bias),
-      GetTensorData<int32>(bias), GetTensorShape(output),
-      GetTensorData<int8>(output));
-}
-
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                   TfLiteDepthwiseConvParams* params, OpData* data,
-                   const TfLiteTensor* input, const TfLiteTensor* filter,
-                   const TfLiteTensor* bias, TfLiteTensor* output) {
-  const int32_t input_offset = -input->params.zero_point;
-  const int32_t filter_offset = -filter->params.zero_point;
-  const int32_t output_offset = output->params.zero_point;
-
-  tflite::DepthwiseParams op_params;
-  // Padding type is ignored, but still set.
-  op_params.padding_type = PaddingType::kSame;
-  op_params.padding_values.width = data->padding.width;
-  op_params.padding_values.height = data->padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = 1;
-  op_params.dilation_height_factor = 1;
-  op_params.depth_multiplier = params->depth_multiplier;
-  op_params.quantized_activation_min = data->output_activation_min;
-  op_params.quantized_activation_max = data->output_activation_max;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data->output_multiplier;
-  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
-  op_params.output_shift = -data->output_shift;
-
-  // Figure out if we can use the optimized path for this set of parameters.
-  const int filter_width = GetTensorShape(filter).Dims(2);
-  const int input_depth = GetTensorShape(input).Dims(3);
-  const int output_depth = GetTensorShape(filter).Dims(3);
-  const int filter_height = GetTensorShape(filter).Dims(1);
-  const int needed_size =
-      output_depth * filter_width * filter_height * input_depth;
-  bool use_optimized_path = false;
-  if ((filter_width == 8) && (input_offset == 0) && (input_depth == 1) &&
-      (needed_size <= kReshapedFilterDataSize)) {
-    // FIXME(petewarden) - We need a more robust way of handling this, ideally
-    // with an allocation mechanism available through the context API.
-    // Use the address of the node as a proxy for its identity, since we need
-    // to ensure the weight values are consistent between calls, and there's
-    // no easy way to do that quickly other than relying on the identity of
-    // the owning node.
-    static TfLiteNode* initialized_node_address = node;
-    if (initialized_node_address == node) {
-      use_optimized_path = true;
-    } else {
-      static bool has_warned = false;
-      if (!has_warned) {
-        TF_LITE_KERNEL_LOG(
-            context,
-            "Multiple depthwise conv ops match optimization parameters, but "
-            "only the first will use the fast path, because there's only one "
-            "RAM cache available");
-        has_warned = true;
-      }
-    }
-  }
-  if (use_optimized_path) {
-    DepthwiseConvOptimizedForFilterWidthEight(
-        context, op_params, GetTensorShape(input),
-        GetTensorData<uint8_t>(input), GetTensorShape(filter),
-        GetTensorData<uint8_t>(filter), GetTensorShape(bias),
-        GetTensorData<int32_t>(bias), GetTensorShape(output),
-        GetTensorData<uint8_t>(output));
-  } else {
-    tflite::reference_ops::DepthwiseConv(
-        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
-        GetTensorShape(bias), GetTensorData<int32_t>(bias),
-        GetTensorShape(output), GetTensorData<uint8_t>(output));
-  }
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params =
-      reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  const TfLiteTensor* bias =
-      (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
-
-  const TfLiteType data_type = input->type;
-  int width = SizeOfDimension(input, 2);
-  int height = SizeOfDimension(input, 1);
-  int filter_width = SizeOfDimension(filter, 2);
-  int filter_height = SizeOfDimension(filter, 1);
-  int out_width = ComputeOutSize(params->padding, width, filter_width,
-                                 params->stride_width);
-  int out_height = ComputeOutSize(params->padding, height, filter_height,
-                                  params->stride_height);
-  OpData data;
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        reinterpret_cast<TfLiteAffineQuantization*>(
-            filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-    TF_LITE_ENSURE(
-        context, affine_quantization->scale->size == 1 ||
-                     affine_quantization->scale->size ==
-                         filter->dims->data[kDepthwiseConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
-                                        filter_width, filter_height, out_width,
-                                        out_height, data_type, &data));
-
-  // TODO(aselle): Consider whether float conv and quantized conv should be
-  // separate ops to avoid dispatch overhead here.
-  switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteFloat32:
-      EvalFloat(context, node, params, &data, input, filter, bias, output);
-      break;
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias,
-                              output);
-      break;
-    case kTfLiteUInt8:
-      EvalQuantized(context, node, params, &data, input, filter, bias, output);
-      break;
-    default:
-      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
-                         TfLiteTypeGetName(input->type), input->type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-
-}  // namespace depthwise_conv
-
-TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
-  static TfLiteRegistration r = {/*init=*/nullptr,
-                                 /*free=*/nullptr,
-                                 /*prepare=*/nullptr,
-                                 /*invoke=*/depthwise_conv::Eval,
-                                 /*profiling_string=*/nullptr,
-                                 /*builtin_code=*/0,
-                                 /*custom_name=*/nullptr,
-                                 /*version=*/0};
-  return &r;
-}
-
-}  // namespace micro
-}  // namespace ops
-}  // namespace tflite
diff --git a/tensorflow/lite/micro/tools/ci_build/test_mbed.sh b/tensorflow/lite/micro/tools/ci_build/test_mbed.sh
index a4d47009c93..fa4506fa6b8 100755
--- a/tensorflow/lite/micro/tools/ci_build/test_mbed.sh
+++ b/tensorflow/lite/micro/tools/ci_build/test_mbed.sh
@@ -49,7 +49,7 @@ fi
 
 make -f tensorflow/lite/micro/tools/make/Makefile \
   TARGET=${TARGET} \
-  TAGS="portable_optimized disco_f746ng" \
+  TAGS="disco_f746ng" \
   ${PROJECTS}
 
 readable_run tensorflow/lite/micro/tools/ci_build/install_mbed_cli.sh
diff --git a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc
index 7d2a0e65b97..dc7a689daed 100644
--- a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc
@@ -24,9 +24,6 @@ ifeq ($(TARGET),$(filter $(TARGET),\
   $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST)/$(SF_BSPS_DEST): $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST)
 endif
 
-  # Use the faster depthwise conv implementation.
-  ALL_TAGS += portable_optimized
-
   PLATFORM_FLAGS = \
     -DPART_apollo3 \
     -DAM_PACKAGE_BGA \