Replace the TF_LITE_SUB macro in the Sub kernel with a templated helper function (EvalSubImpl).
PiperOrigin-RevId: 308368792 Change-Id: Ic14ace5afa30a5d32d99ec72144995666a6a8f6f
This commit is contained in:
parent
4440c68d02
commit
3b9d8dd789
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/sub.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
@ -219,51 +221,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
template <KernelType kernel_type, typename data_type>
|
||||
void EvalSubImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteSubParams* params, const OpData* data,
|
||||
const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
bool requires_broadcast, TfLiteTensor* output) {
|
||||
data_type output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
tflite::ArithmeticParams op_params;
|
||||
SetActivationParams(output_activation_min, output_activation_max, &op_params);
|
||||
|
||||
switch (kernel_type) {
|
||||
case kReference:
|
||||
if (requires_broadcast) {
|
||||
reference_ops::BroadcastSubSlow(
|
||||
op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
|
||||
GetTensorShape(input2), GetTensorData<data_type>(input2),
|
||||
GetTensorShape(output), GetTensorData<data_type>(output));
|
||||
} else {
|
||||
reference_ops::SubWithActivation(
|
||||
op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
|
||||
GetTensorShape(input2), GetTensorData<data_type>(input2),
|
||||
GetTensorShape(output), GetTensorData<data_type>(output));
|
||||
}
|
||||
break;
|
||||
case kGenericOptimized:
|
||||
case kNeonOptimized:
|
||||
if (requires_broadcast) {
|
||||
optimized_ops::BroadcastSubSlow(
|
||||
op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
|
||||
GetTensorShape(input2), GetTensorData<data_type>(input2),
|
||||
GetTensorShape(output), GetTensorData<data_type>(output));
|
||||
} else {
|
||||
optimized_ops::SubWithActivation(
|
||||
op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
|
||||
GetTensorShape(input2), GetTensorData<data_type>(input2),
|
||||
GetTensorShape(output), GetTensorData<data_type>(output));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
|
||||
const OpData* data, const TfLiteTensor* input1,
|
||||
const TfLiteTensor* input2, TfLiteTensor* output) {
|
||||
#define TF_LITE_SUB(type, opname, data_type) \
|
||||
data_type output_activation_min, output_activation_max; \
|
||||
CalculateActivationRange(params->activation, &output_activation_min, \
|
||||
&output_activation_max); \
|
||||
tflite::ArithmeticParams op_params; \
|
||||
SetActivationParams(output_activation_min, output_activation_max, \
|
||||
&op_params); \
|
||||
type::opname(op_params, GetTensorShape(input1), \
|
||||
GetTensorData<data_type>(input1), GetTensorShape(input2), \
|
||||
GetTensorData<data_type>(input2), GetTensorShape(output), \
|
||||
GetTensorData<data_type>(output))
|
||||
if (output->type == kTfLiteInt32) {
|
||||
if (kernel_type == kReference) {
|
||||
if (data->requires_broadcast) {
|
||||
TF_LITE_SUB(reference_ops, BroadcastSubSlow, int32_t);
|
||||
} else {
|
||||
TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
|
||||
}
|
||||
} else {
|
||||
if (data->requires_broadcast) {
|
||||
TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int32_t);
|
||||
} else {
|
||||
TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
|
||||
}
|
||||
}
|
||||
} else if (output->type == kTfLiteFloat32) {
|
||||
if (kernel_type == kReference) {
|
||||
if (data->requires_broadcast) {
|
||||
TF_LITE_SUB(reference_ops, BroadcastSubSlow, float);
|
||||
} else {
|
||||
TF_LITE_SUB(reference_ops, SubWithActivation, float);
|
||||
}
|
||||
} else {
|
||||
if (data->requires_broadcast) {
|
||||
TF_LITE_SUB(optimized_ops, BroadcastSubSlow, float);
|
||||
} else {
|
||||
TF_LITE_SUB(optimized_ops, SubWithActivation, float);
|
||||
}
|
||||
}
|
||||
const bool requires_broadcast = data->requires_broadcast;
|
||||
switch (output->type) {
|
||||
case kTfLiteInt32:
|
||||
EvalSubImpl<kernel_type, int32_t>(context, node, params, data, input1,
|
||||
input2, requires_broadcast, output);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
|
||||
input2, requires_broadcast, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "output type %d is not supported.",
|
||||
output->type);
|
||||
}
|
||||
#undef TF_LITE_SUB
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
|
Loading…
x
Reference in New Issue
Block a user