Replace the TF_LITE_SUB macro in the sub kernel with a templated function.

PiperOrigin-RevId: 308368792
Change-Id: Ic14ace5afa30a5d32d99ec72144995666a6a8f6f
This commit is contained in:
Mirko Visontai 2020-04-24 19:31:51 -07:00 committed by TensorFlower Gardener
parent 4440c68d02
commit 3b9d8dd789

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
@ -219,51 +221,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return context->ResizeTensor(context, output, output_size);
}
// Performs an element-wise subtraction of `input1` and `input2` into
// `output`, dispatching at compile time (via `kernel_type`) to either the
// reference or the optimized kernels, and at run time to the broadcasting
// variant when `requires_broadcast` is set.
//
// `data_type` is the element type shared by both inputs and the output.
// The fused activation from `params` is folded into clamping bounds that
// the kernels apply to every output element.
template <KernelType kernel_type, typename data_type>
void EvalSubImpl(TfLiteContext* context, TfLiteNode* node,
                 TfLiteSubParams* params, const OpData* data,
                 const TfLiteTensor* input1, const TfLiteTensor* input2,
                 bool requires_broadcast, TfLiteTensor* output) {
  // Turn the fused activation (none/relu/relu6/...) into min/max bounds.
  data_type activation_min, activation_max;
  CalculateActivationRange(params->activation, &activation_min,
                           &activation_max);

  tflite::ArithmeticParams arith_params;
  SetActivationParams(activation_min, activation_max, &arith_params);

  switch (kernel_type) {
    case kReference:
      if (requires_broadcast) {
        reference_ops::BroadcastSubSlow(
            arith_params, GetTensorShape(input1),
            GetTensorData<data_type>(input1), GetTensorShape(input2),
            GetTensorData<data_type>(input2), GetTensorShape(output),
            GetTensorData<data_type>(output));
      } else {
        reference_ops::SubWithActivation(
            arith_params, GetTensorShape(input1),
            GetTensorData<data_type>(input1), GetTensorShape(input2),
            GetTensorData<data_type>(input2), GetTensorShape(output),
            GetTensorData<data_type>(output));
      }
      break;
    // Both optimized flavors share the same entry points here; NEON
    // selection happens inside optimized_ops.
    case kGenericOptimized:
    case kNeonOptimized:
      if (requires_broadcast) {
        optimized_ops::BroadcastSubSlow(
            arith_params, GetTensorShape(input1),
            GetTensorData<data_type>(input1), GetTensorShape(input2),
            GetTensorData<data_type>(input2), GetTensorShape(output),
            GetTensorData<data_type>(output));
      } else {
        optimized_ops::SubWithActivation(
            arith_params, GetTensorShape(input1),
            GetTensorData<data_type>(input1), GetTensorShape(input2),
            GetTensorData<data_type>(input2), GetTensorShape(output),
            GetTensorData<data_type>(output));
      }
      break;
  }
}
template <KernelType kernel_type> template <KernelType kernel_type>
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpData* data, const TfLiteTensor* input1, const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) { const TfLiteTensor* input2, TfLiteTensor* output) {
#define TF_LITE_SUB(type, opname, data_type) \ const bool requires_broadcast = data->requires_broadcast;
data_type output_activation_min, output_activation_max; \ switch (output->type) {
CalculateActivationRange(params->activation, &output_activation_min, \ case kTfLiteInt32:
&output_activation_max); \ EvalSubImpl<kernel_type, int32_t>(context, node, params, data, input1,
tflite::ArithmeticParams op_params; \ input2, requires_broadcast, output);
SetActivationParams(output_activation_min, output_activation_max, \ break;
&op_params); \ case kTfLiteFloat32:
type::opname(op_params, GetTensorShape(input1), \ EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
GetTensorData<data_type>(input1), GetTensorShape(input2), \ input2, requires_broadcast, output);
GetTensorData<data_type>(input2), GetTensorShape(output), \ break;
GetTensorData<data_type>(output)) default:
if (output->type == kTfLiteInt32) { TF_LITE_KERNEL_LOG(context, "output type %d is not supported.",
if (kernel_type == kReference) { output->type);
if (data->requires_broadcast) {
TF_LITE_SUB(reference_ops, BroadcastSubSlow, int32_t);
} else {
TF_LITE_SUB(reference_ops, SubWithActivation, int32_t);
} }
} else {
if (data->requires_broadcast) {
TF_LITE_SUB(optimized_ops, BroadcastSubSlow, int32_t);
} else {
TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t);
}
}
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
TF_LITE_SUB(reference_ops, BroadcastSubSlow, float);
} else {
TF_LITE_SUB(reference_ops, SubWithActivation, float);
}
} else {
if (data->requires_broadcast) {
TF_LITE_SUB(optimized_ops, BroadcastSubSlow, float);
} else {
TF_LITE_SUB(optimized_ops, SubWithActivation, float);
}
}
}
#undef TF_LITE_SUB
} }
template <KernelType kernel_type> template <KernelType kernel_type>