diff --git a/tensorflow/lite/kernels/sub.cc b/tensorflow/lite/kernels/sub.cc
index 55a91acf1b5..1b04143d222 100644
--- a/tensorflow/lite/kernels/sub.cc
+++ b/tensorflow/lite/kernels/sub.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/kernels/internal/reference/sub.h"
+
 #include <limits>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
@@ -219,51 +221,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+template <KernelType kernel_type, typename data_type>
+void EvalSubImpl(TfLiteContext* context, TfLiteNode* node,
+                 TfLiteSubParams* params, const OpData* data,
+                 const TfLiteTensor* input1, const TfLiteTensor* input2,
+                 bool requires_broadcast, TfLiteTensor* output) {
+  data_type output_activation_min, output_activation_max;
+  CalculateActivationRange(params->activation, &output_activation_min,
+                           &output_activation_max);
+  tflite::ArithmeticParams op_params;
+  SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+  switch (kernel_type) {
+    case kReference:
+      if (requires_broadcast) {
+        reference_ops::BroadcastSubSlow(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      } else {
+        reference_ops::SubWithActivation(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      }
+      break;
+    case kGenericOptimized:
+    case kNeonOptimized:
+      if (requires_broadcast) {
+        optimized_ops::BroadcastSubSlow(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      } else {
+        optimized_ops::SubWithActivation(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      }
+      break;
+  }
+}
+
 template <KernelType kernel_type>
 void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_SUB(type, opname, data_type)                             \
-  data_type output_activation_min, output_activation_max;                \
-  CalculateActivationRange(params->activation, &output_activation_min,   \
-                           &output_activation_max);                      \
-  tflite::ArithmeticParams op_params;                                    \
-  SetActivationParams(output_activation_min, output_activation_max,      \
-                      &op_params);                                       \
-  type::opname(op_params, GetTensorShape(input1),                        \
-               GetTensorData<data_type>(input1), GetTensorShape(input2), \
-               GetTensorData<data_type>(input2), GetTensorShape(output), \
-               GetTensorData<data_type>(output))
-  if (output->type == kTfLiteInt32) {
-    if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(reference_ops, BroadcastSubSlow, int32_t);
-      } else {
-        TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
-      }
-    } else {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int32_t);
-      } else {
-        TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
-      }
-    }
-  } else if (output->type == kTfLiteFloat32) {
-    if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(reference_ops, BroadcastSubSlow, float);
-      } else {
-        TF_LITE_SUB(reference_ops, SubWithActivation, float);
-      }
-    } else {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(optimized_ops, BroadcastSubSlow, float);
-      } else {
-        TF_LITE_SUB(optimized_ops, SubWithActivation, float);
-      }
-    }
+  const bool requires_broadcast = data->requires_broadcast;
+  switch (output->type) {
+    case kTfLiteInt32:
+      EvalSubImpl<kernel_type, int32_t>(context, node, params, data, input1,
+                                        input2, requires_broadcast, output);
+      break;
+    case kTfLiteFloat32:
+      EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
+                                      input2, requires_broadcast, output);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "output type %d is not supported.",
+                         output->type);
   }
-#undef TF_LITE_SUB
 }
 
 template <KernelType kernel_type>