The interpreter doesn't call it if a nullptr is presented in the registration. This can slightly reduce code size. PiperOrigin-RevId: 304284594 Change-Id: I7fa092aa9cf5142ee30982d07655971fd1626630
127 lines
5.0 KiB
C++
127 lines
5.0 KiB
C++
/******************************************************************************
|
|
* Copyright (C) 2019 Cadence Design Systems, Inc.
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining
|
|
* a copy of this software and associated documentation files (the
|
|
* "Software"), to use this Software with Cadence processor cores only and
|
|
* not with any other processors and platforms, subject to
|
|
* the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be included
|
|
* in all copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
******************************************************************************/
|
|
|
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
|
|
|
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
|
#include "tensorflow/lite/c/common.h"
|
|
#include "tensorflow/lite/kernels/internal/common.h"
|
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
|
#include "tensorflow/lite/kernels/op_macros.h"
|
|
#include "xtensa_tf_micro_common.h"
|
|
|
|
namespace tflite {
|
|
namespace ops {
|
|
namespace micro {
|
|
namespace activations {
|
|
|
|
constexpr int kInputTensor = 0;
|
|
constexpr int kOutputTensor = 0;
|
|
|
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
|
|
if (input->type == kTfLiteFloat32) {
|
|
switch (output->type) {
|
|
case kTfLiteFloat32: {
|
|
int err;
|
|
const float* inp_data_ptr;
|
|
float* out_data_ptr;
|
|
const RuntimeShape& input_shape = GetTensorShape(input);
|
|
const RuntimeShape& output_shape = GetTensorShape(output);
|
|
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
|
|
|
inp_data_ptr = GetTensorData<float>(input);
|
|
out_data_ptr = GetTensorData<float>(output);
|
|
|
|
err = xa_nn_vec_sigmoid_f32_f32(out_data_ptr, inp_data_ptr, flat_size);
|
|
|
|
CHECK_ERR_HIFI_NNLIB_KER(err, "xa_nn_vec_sigmoid_f32_f32 failed");
|
|
return kTfLiteOk;
|
|
}
|
|
default:
|
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
|
TfLiteTypeGetName(input->type),
|
|
TfLiteTypeGetName(output->type));
|
|
return kTfLiteError;
|
|
}
|
|
} else if (input->type == kTfLiteInt8) {
|
|
switch (output->type) {
|
|
case kTfLiteInt8: {
|
|
reference_ops::Logistic(
|
|
GetTensorShape(input), GetTensorData<int8_t>(input),
|
|
input->params.scale, input->params.zero_point,
|
|
GetTensorShape(output), GetTensorData<int8_t>(output),
|
|
output->params.scale, output->params.zero_point);
|
|
return kTfLiteOk;
|
|
}
|
|
default:
|
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
|
TfLiteTypeGetName(input->type),
|
|
TfLiteTypeGetName(output->type));
|
|
return kTfLiteError;
|
|
}
|
|
} else {
|
|
// (b/141211002): Also support other data types once we have supported
|
|
// temporary tensors in TFLM.
|
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
|
TfLiteTypeGetName(input->type),
|
|
TfLiteTypeGetName(output->type));
|
|
return kTfLiteError;
|
|
}
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
} // namespace activations
|
|
|
|
TfLiteRegistration* Register_LOGISTIC() {
|
|
static TfLiteRegistration r = {/*init=*/nullptr,
|
|
/*free=*/nullptr,
|
|
/*prepare=*/nullptr,
|
|
/*invoke=*/activations::Eval,
|
|
/*profiling_string=*/nullptr,
|
|
/*builtin_code=*/0,
|
|
/*custom_name=*/nullptr,
|
|
/*version=*/0};
|
|
return &r;
|
|
}
|
|
} // namespace micro
|
|
} // namespace ops
|
|
} // namespace tflite
|