Replace the MACRO in sub with a templated function.
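
For context, this applies a standard C++ cleanup: a function-like macro that
stamps out the same arithmetic body for several data types is replaced by a
function template, so each instantiation is type-checked at one definition,
steppable in a debugger, and reported by name in compiler errors. A minimal
standalone sketch of the before/after pattern (illustrative only; EVAL_OP,
EvalImpl, and main are hypothetical names, not part of the TFLite code):

#include <cstdint>
#include <cstdio>

// Before: the macro expands the same loop body for each data type at the
// call site; the expansion is invisible to debuggers, and type errors
// surface at every use rather than at one definition.
#define EVAL_OP(data_type, lhs, rhs, out, n)                \
  for (int i = 0; i < (n); ++i) {                           \
    (out)[i] = static_cast<data_type>((lhs)[i] - (rhs)[i]); \
  }

// After: a function template with identical behavior. One named,
// type-checked definition; the compiler instantiates it per type.
template <typename data_type>
void EvalImpl(const data_type* lhs, const data_type* rhs, data_type* out,
              int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = lhs[i] - rhs[i];
  }
}

int main() {
  const int32_t a[] = {5, 7}, b[] = {1, 2};
  int32_t c[2];
  EVAL_OP(int32_t, a, b, c, 2);        // macro version
  EvalImpl<int32_t>(a, b, c, 2);       // template version, same result
  std::printf("%d %d\n", c[0], c[1]);  // prints: 4 5
  return 0;
}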

PiperOrigin-RevId: 308368792
Change-Id: Ic14ace5afa30a5d32d99ec72144995666a6a8f6f
Author: Mirko Visontai
Date: 2020-04-24 19:31:51 -07:00
Committed by: TensorFlower Gardener
Parent: 4440c68d02
Commit: 3b9d8dd789

tensorflow/lite/kernels/sub.cc

@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/kernels/internal/reference/sub.h"
+#include <limits>
 #include "tensorflow/lite/c/builtin_op_data.h"
@@ -219,51 +221,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+template <KernelType kernel_type, typename data_type>
+void EvalSubImpl(TfLiteContext* context, TfLiteNode* node,
+                 TfLiteSubParams* params, const OpData* data,
+                 const TfLiteTensor* input1, const TfLiteTensor* input2,
+                 bool requires_broadcast, TfLiteTensor* output) {
+  data_type output_activation_min, output_activation_max;
+  CalculateActivationRange(params->activation, &output_activation_min,
+                           &output_activation_max);
+  tflite::ArithmeticParams op_params;
+  SetActivationParams(output_activation_min, output_activation_max,
+                      &op_params);
+
+  switch (kernel_type) {
+    case kReference:
+      if (requires_broadcast) {
+        reference_ops::BroadcastSubSlow(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      } else {
+        reference_ops::SubWithActivation(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      }
+      break;
+    case kGenericOptimized:
+    case kNeonOptimized:
+      if (requires_broadcast) {
+        optimized_ops::BroadcastSubSlow(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      } else {
+        optimized_ops::SubWithActivation(
+            op_params, GetTensorShape(input1), GetTensorData<data_type>(input1),
+            GetTensorShape(input2), GetTensorData<data_type>(input2),
+            GetTensorShape(output), GetTensorData<data_type>(output));
+      }
+      break;
+  }
+}
+
 template <KernelType kernel_type>
 void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
              const OpData* data, const TfLiteTensor* input1,
              const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_SUB(type, opname, data_type)                             \
-  data_type output_activation_min, output_activation_max;                \
-  CalculateActivationRange(params->activation, &output_activation_min,   \
-                           &output_activation_max);                      \
-  tflite::ArithmeticParams op_params;                                    \
-  SetActivationParams(output_activation_min, output_activation_max,      \
-                      &op_params);                                       \
-  type::opname(op_params, GetTensorShape(input1),                        \
-               GetTensorData<data_type>(input1), GetTensorShape(input2), \
-               GetTensorData<data_type>(input2), GetTensorShape(output), \
-               GetTensorData<data_type>(output))
-  if (output->type == kTfLiteInt32) {
-    if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(reference_ops, BroadcastSubSlow, int32_t);
-      } else {
-        TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
-      }
-    } else {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int32_t);
-      } else {
-        TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
-      }
-    }
-  } else if (output->type == kTfLiteFloat32) {
-    if (kernel_type == kReference) {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(reference_ops, BroadcastSubSlow, float);
-      } else {
-        TF_LITE_SUB(reference_ops, SubWithActivation, float);
-      }
-    } else {
-      if (data->requires_broadcast) {
-        TF_LITE_SUB(optimized_ops, BroadcastSubSlow, float);
-      } else {
-        TF_LITE_SUB(optimized_ops, SubWithActivation, float);
-      }
-    }
-  }
-#undef TF_LITE_SUB
+  const bool requires_broadcast = data->requires_broadcast;
+
+  switch (output->type) {
+    case kTfLiteInt32:
+      EvalSubImpl<kernel_type, int32_t>(context, node, params, data, input1,
+                                        input2, requires_broadcast, output);
+      break;
+    case kTfLiteFloat32:
+      EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
+                                      input2, requires_broadcast, output);
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "output type %d is not supported.",
+                         output->type);
+  }
 }
 
 template <KernelType kernel_type>
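
Not shown in this excerpt: the kernel_type template argument is fixed at op
registration time, so the switch over kernel_type inside EvalSubImpl is
resolved at compile time per instantiation; only the output-type switch in
EvalSub runs at inference time. A sketch of the registration pattern that
binds kernel_type (simplified to the shape such registrations take in these
kernels; the exact function bodies are illustrative, not part of this diff):

// Each Register_* variant instantiates the Eval path with a different
// KernelType, so the kernel flavor is chosen once, at registration.
TfLiteRegistration* Register_SUB_REF() {
  static TfLiteRegistration r = {sub::Init, sub::Free, sub::Prepare,
                                 sub::Eval<sub::kReference>};
  return &r;
}

TfLiteRegistration* Register_SUB_GENERIC_OPT() {
  static TfLiteRegistration r = {sub::Init, sub::Free, sub::Prepare,
                                 sub::Eval<sub::kGenericOptimized>};
  return &r;
}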