diff --git a/tensorflow/lite/experimental/micro/kernels/BUILD b/tensorflow/lite/experimental/micro/kernels/BUILD
index d9d271e922a..ae71b1f3e69 100644
--- a/tensorflow/lite/experimental/micro/kernels/BUILD
+++ b/tensorflow/lite/experimental/micro/kernels/BUILD
@@ -28,6 +28,7 @@ cc_library(
         "logical.cc",
         "logistic.cc",
         "maximum_minimum.cc",
+        "mul.cc",
        "neg.cc",
         "pack.cc",
         "pooling.cc",
@@ -89,6 +90,7 @@ cc_library(
         "logical.cc",
         "logistic.cc",
         "maximum_minimum.cc",
+        "mul.cc",
         "neg.cc",
         "pack.cc",
         "pooling.cc",
@@ -321,6 +323,19 @@ tflite_micro_cc_test(
     ],
 )
 
+tflite_micro_cc_test(
+    name = "mul_test",
+    srcs = [
+        "mul_test.cc",
+    ],
+    deps = [
+        ":all_ops_resolver",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/experimental/micro:micro_framework",
+        "//tensorflow/lite/experimental/micro/testing:micro_test",
+    ],
+)
+
 tflite_micro_cc_test(
     name = "arg_min_max_test",
     srcs = [
diff --git a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
index be38e8fe21c..7a0f77908a0 100644
--- a/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
+++ b/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -62,6 +62,7 @@ AllOpsResolver::AllOpsResolver() {
   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
   AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+  AddBuiltin(BuiltinOperator_MUL, Register_MUL());
   AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(), 1, 4);
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(), 1, 4);
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
diff --git a/tensorflow/lite/experimental/micro/kernels/micro_ops.h b/tensorflow/lite/experimental/micro/kernels/micro_ops.h
index a4518db6cd3..1fc45171461 100644
--- a/tensorflow/lite/experimental/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/experimental/micro/kernels/micro_ops.h
@@ -54,6 +54,7 @@ TfLiteRegistration* Register_LOGISTIC();
 TfLiteRegistration* Register_MAXIMUM();
 TfLiteRegistration* Register_MAX_POOL_2D();
 TfLiteRegistration* Register_MINIMUM();
+TfLiteRegistration* Register_MUL();
 TfLiteRegistration* Register_NEG();
 TfLiteRegistration* Register_NOT_EQUAL();
 TfLiteRegistration* Register_PACK();
diff --git a/tensorflow/lite/experimental/micro/kernels/mul.cc b/tensorflow/lite/experimental/micro/kernels/mul.cc
new file mode 100644
index 00000000000..cbd6251b0d1
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/mul.cc
@@ -0,0 +1,181 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/mul.h"
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
+#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace mul {
+
+constexpr int kInput1Tensor = 0;
+constexpr int kInput2Tensor = 1;
+constexpr int kOutputTensor = 0;
+
+struct OpData {
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+
+  int32_t output_multiplier;
+  int output_shift;
+};
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+                             TfLiteMulParams* params, OpData* data) {
+  const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
+  const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+  if (output->type == kTfLiteUInt8) {
+    CalculateActivationRangeUint8(params->activation, output,
+                                  &data->output_activation_min,
+                                  &data->output_activation_max);
+  } else if (output->type == kTfLiteInt8) {
+    CalculateActivationRangeInt8(params->activation, output,
+                                 &data->output_activation_min,
+                                 &data->output_activation_max);
+  }
+
+  double real_multiplier =
+      input1->params.scale * input2->params.scale / output->params.scale;
+  QuantizeMultiplier(real_multiplier, &data->output_multiplier,
+                     &data->output_shift);
+
+  return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                   TfLiteMulParams* params, OpData* data,
+                   const TfLiteTensor* input1, const TfLiteTensor* input2,
+                   TfLiteTensor* output) {
+  if (output->type == kTfLiteInt8 || output->type == kTfLiteUInt8) {
+    tflite::ArithmeticParams op_params;
+    SetActivationParams(data->output_activation_min,
+                        data->output_activation_max, &op_params);
+    op_params.input1_offset = -input1->params.zero_point;
+    op_params.input2_offset = -input2->params.zero_point;
+    op_params.output_offset = output->params.zero_point;
+    op_params.output_multiplier = data->output_multiplier;
+    op_params.output_shift = data->output_shift;
+    bool need_broadcast = reference_ops::ProcessBroadcastShapes(
+        GetTensorShape(input1), GetTensorShape(input2), &op_params);
+
+#define TF_LITE_MUL(type, opname, dtype)                             \
+  type::opname(op_params, GetTensorShape(input1),                    \
+               GetTensorData<dtype>(input1), GetTensorShape(input2), \
+               GetTensorData<dtype>(input2), GetTensorShape(output), \
+               GetTensorData<dtype>(output));
+
+    if (output->type == kTfLiteInt8) {
+      if (need_broadcast) {
+        TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t);
+      } else {
+        TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
+      }
+    } else if (output->type == kTfLiteUInt8) {
+      if (need_broadcast) {
+        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t);
+      } else {
+        TF_LITE_MUL(reference_ops, Mul, uint8_t);
+      }
+    }
+#undef TF_LITE_MUL
+  }
+}
+
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+               TfLiteMulParams* params, OpData* data,
+               const TfLiteTensor* input1, const TfLiteTensor* input2,
+               TfLiteTensor* output) {
+  float output_activation_min, output_activation_max;
+  CalculateActivationRange(params->activation, &output_activation_min,
+                           &output_activation_max);
+  tflite::ArithmeticParams op_params;
+  SetActivationParams(output_activation_min, output_activation_max,
+                      &op_params);
+
+  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
+      GetTensorShape(input1), GetTensorShape(input2), &op_params);
+#define TF_LITE_MUL(opname)                                                   \
+  reference_ops::opname(op_params, GetTensorShape(input1),                    \
+                        GetTensorData<float>(input1), GetTensorShape(input2), \
+                        GetTensorData<float>(input2), GetTensorShape(output), \
+                        GetTensorData<float>(output));
+
+  if (need_broadcast) {
+    TF_LITE_MUL(BroadcastMul4DSlow);
+  } else {
+    TF_LITE_MUL(Mul);
+  }
+#undef TF_LITE_MUL
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+  OpData data;
+
+  const TfLiteTensor* input1 = GetInput(context, node, kInput1Tensor);
+  const TfLiteTensor* input2 = GetInput(context, node, kInput2Tensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  CalculateOpData(context, node, params, &data);
+
+  switch (input1->type) {
+    case kTfLiteUInt8:
+    case kTfLiteInt8:
+      EvalQuantized(context, node, params, &data, input1, input2, output);
+      break;
+    case kTfLiteFloat32:
+      EvalFloat(context, node, params, &data, input1, input2, output);
+      break;
+    default:
+      context->ReportError(context, "Type %d not currently supported.",
+                           input1->type);
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+}  // namespace mul
+
+TfLiteRegistration* Register_MUL() {
+  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval};
+  return &r;
+}
+
+}  // namespace micro
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/micro/kernels/mul_test.cc b/tensorflow/lite/experimental/micro/kernels/mul_test.cc
new file mode 100644
index 00000000000..2498bc4801b
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/kernels/mul_test.cc
@@ -0,0 +1,423 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestMulFloat(std::initializer_list<int> input1_dims_data,
+                  std::initializer_list<float> input1_data,
+                  std::initializer_list<int> input2_dims_data,
+                  std::initializer_list<float> input2_data,
+                  std::initializer_list<int> output_dims_data,
+                  std::initializer_list<float> expected_output_data,
+                  float* output_data, TfLiteFusedActivation activation) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateFloatTensor(input1_data, input1_dims, "input1_tensor"),
+      CreateFloatTensor(input2_data, input2_dims, "input2_tensor"),
+      CreateFloatTensor(output_data, output_dims, "output_tensor"),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_MUL, 1);
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteMulParams builtin_data = {
+      .activation = activation,
+  };
+
+  const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, init_data, init_data_size);
+  }
+
+  int inputs_array_data[] = {2, 0, 1};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 2};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.user_data = user_data;
+  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  for (int i = 0; i < output_dims_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              1e-5f);
+  }
+}
+
+template <typename T>
+void TestMulQuantized(std::initializer_list<int> input1_dims_data,
+                      std::initializer_list<T> input1_data,
+                      std::initializer_list<int> input2_dims_data,
+                      std::initializer_list<T> input2_data,
+                      const float input_min, const float input_max,
+                      std::initializer_list<int> output_dims_data,
+                      const float output_min, const float output_max,
+                      std::initializer_list<T> expected_output_data,
+                      T* output_data, TfLiteFusedActivation activation,
+                      int error_tolerance) {
+  TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  ::tflite::ops::micro::AllOpsResolver resolver;
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input1_data, input1_dims, "input1_tensor",
+                            input_min, input_max),
+      CreateQuantizedTensor(input2_data, input2_dims, "input2_tensor",
+                            input_min, input_max),
+      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+                            output_min, output_max),
+  };
+
+  TfLiteContext context;
+  PopulateContext(tensors, tensors_size, &context);
+  const TfLiteRegistration* registration =
+      resolver.FindOp(tflite::BuiltinOperator_MUL, 1);
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+  TfLiteMulParams builtin_data = {
+      .activation = activation,
+  };
+
+  const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+  size_t init_data_size = 0;
+  void* user_data = nullptr;
+  if (registration->init) {
+    user_data = registration->init(&context, init_data, init_data_size);
+  }
+
+  int inputs_array_data[] = {2, 0, 1};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 2};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  TfLiteNode node;
+  node.inputs = inputs_array;
+  node.outputs = outputs_array;
+  node.user_data = user_data;
+  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+  node.custom_initial_data = nullptr;
+  node.custom_initial_data_size = 0;
+  node.delegate = nullptr;
+
+  if (registration->prepare) {
+    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+  }
+
+  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+
+  for (int i = 0; i < output_dims_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+                              error_tolerance);
+  }
+}
+
+}  // namespace
+
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(Int8NoActivation) {
+  using tflite::testing::F2QS;
+  const float input_min = -1;
+  const float input_max = 1;
+  const float output_min = -1;
+  const float output_max = 1;
+
+  int8_t output_data[4];
+  tflite::testing::TestMulQuantized({4, 1, 2, 2, 1},  // input1 dims
+                                    {
+                                        F2QS(-0.8, input_min, input_max),
+                                        F2QS(0.2, input_min, input_max),
+                                        F2QS(0.9, input_min, input_max),
+                                        F2QS(0.7, input_min, input_max),
+                                    },  // input1 data
+                                    {4, 1, 2, 2, 1},  // input2 dims
+                                    {
+                                        F2QS(0.6, input_min, input_max),
+                                        F2QS(0.4, input_min, input_max),
+                                        F2QS(0.9, input_min, input_max),
+                                        F2QS(0.8, input_min, input_max),
+                                    },  // input2 data
+                                    input_min, input_max,
+                                    {4, 1, 2, 2, 1},  // output dims
+                                    output_min, output_max,
+                                    {
+                                        F2QS(-0.48, output_min, output_max),
+                                        F2QS(0.08, output_min, output_max),
+                                        F2QS(0.81, output_min, output_max),
+                                        F2QS(0.56, output_min, output_max),
+                                    },  // expected output data
+                                    output_data, kTfLiteActNone, 1);
+}
+
+TF_LITE_MICRO_TEST(Int8NoActivationLargeMultiplier) {
+  using tflite::testing::F2QS;
+  const float input_min = -100;
+  const float input_max = 100;
+  const float output_min = -10;
+  const float output_max = 10;
+
+  int8_t output_data[4];
+  tflite::testing::TestMulQuantized(
+      {4, 1, 2, 2, 1},
+      {
+          F2QS(-4, input_min, input_max),
+          F2QS(2, input_min, input_max),
+          F2QS(3, input_min, input_max),
+          F2QS(1, input_min, input_max),
+      },
+      {4, 1, 2, 2, 1},
+      {
+          F2QS(-1, input_min, input_max),
+          F2QS(-3, input_min, input_max),
+          F2QS(4, input_min, input_max),
+          F2QS(2, input_min, input_max),
+      },
+      input_min, input_max, {4, 1, 2, 2, 1}, output_min, output_max,
+      {
+          F2QS(4, output_min, output_max),
+          F2QS(-6, output_min, output_max),
+          F2QS(12, output_min, output_max),
+          F2QS(2, output_min, output_max),
+      },
+      // In TensorFlow Lite, this test has a max allowed error of 1.4f.
+      // A difference of 1.4 in floating point corresponds to 18 quantized
+      // units for the output min/max [-10, 10].
+      output_data, kTfLiteActNone, 18);
+}
+
+TF_LITE_MICRO_TEST(Int8NoActivationBroadcast) {
+  using tflite::testing::F2QS;
+  const float input_min = -3.0;
+  const float input_max = 3.0;
+  const float output_min = -3.0;
+  const float output_max = 3.0;
+
+  int8_t output_data[6];
+  tflite::testing::TestMulQuantized({4, 1, 3, 1, 2},  // input1 shape
+                                    {
+                                        F2QS(-2.0, input_min, input_max),
+                                        F2QS(0.2, input_min, input_max),
+                                        F2QS(0.7, input_min, input_max),
+                                        F2QS(0.8, input_min, input_max),
+                                        F2QS(1.1, input_min, input_max),
+                                        F2QS(2.0, input_min, input_max),
+                                    },  // input1 data
+                                    {1, 1},  // input2 shape
+                                    {
+                                        F2QS(0.1, input_min, input_max),
+                                    },  // input2 data
+                                    input_min, input_max,
+                                    {4, 1, 3, 1, 2},  // output shape
+                                    output_min, output_max,
+                                    {
+                                        F2QS(-0.2, output_min, output_max),
+                                        F2QS(0.02, output_min, output_max),
+                                        F2QS(0.07, output_min, output_max),
+                                        F2QS(0.08, output_min, output_max),
+                                        F2QS(0.11, output_min, output_max),
+                                        F2QS(0.2, output_min, output_max),
+                                    },  // expected output data
+                                    output_data, kTfLiteActNone, 1);
+}
+
+TF_LITE_MICRO_TEST(UInt8NoActivation) {
+  using tflite::testing::F2Q;
+  const float input_min = -1;
+  const float input_max = 1;
+  const float output_min = -1;
+  const float output_max = 1;
+
+  uint8_t output_data[4];
+  tflite::testing::TestMulQuantized({4, 1, 2, 2, 1},  // input1 dims
+                                    {
+                                        F2Q(-0.8, input_min, input_max),
+                                        F2Q(0.2, input_min, input_max),
+                                        F2Q(0.9, input_min, input_max),
+                                        F2Q(0.7, input_min, input_max),
+                                    },  // input1 data
+                                    {4, 1, 2, 2, 1},  // input2 dims
+                                    {
+                                        F2Q(0.6, input_min, input_max),
+                                        F2Q(0.4, input_min, input_max),
+                                        F2Q(0.9, input_min, input_max),
+                                        F2Q(0.8, input_min, input_max),
+                                    },  // input2 data
+                                    input_min, input_max,
+                                    {4, 1, 2, 2, 1},  // output dims
+                                    output_min, output_max,
+                                    {
+                                        F2Q(-0.48, output_min, output_max),
+                                        F2Q(0.08, output_min, output_max),
+                                        F2Q(0.81, output_min, output_max),
+                                        F2Q(0.56, output_min, output_max),
+                                    },  // expected output data
+                                    output_data, kTfLiteActNone, 1);
+}
+
+TF_LITE_MICRO_TEST(UInt8NoActivationLargeMultiplier) {
+  using tflite::testing::F2Q;
+  const float input_min = -100;
+  const float input_max = 100;
+  const float output_min = -10;
+  const float output_max = 10;
+
+  uint8_t output_data[4];
+  tflite::testing::TestMulQuantized(
+      {4, 1, 2, 2, 1},
+      {
+          F2Q(-4, input_min, input_max),
+          F2Q(2, input_min, input_max),
+          F2Q(3, input_min, input_max),
+          F2Q(1, input_min, input_max),
+      },
+      {4, 1, 2, 2, 1},
+      {
+          F2Q(-1, input_min, input_max),
+          F2Q(-3, input_min, input_max),
+          F2Q(4, input_min, input_max),
+          F2Q(2, input_min, input_max),
+      },
+      input_min, input_max, {4, 1, 2, 2, 1}, output_min, output_max,
+      {
+          F2Q(4, output_min, output_max),
+          F2Q(-6, output_min, output_max),
+          F2Q(12, output_min, output_max),
+          F2Q(2, output_min, output_max),
+      },
+      // In TensorFlow Lite, this test has a max allowed error of 1.4f.
+      // A difference of 1.4 in floating point corresponds to 18 quantized
+      // units for the output min/max [-10, 10].
+      output_data, kTfLiteActNone, 18);
+}
+
+TF_LITE_MICRO_TEST(UInt8NoActivationBroadcast) {
+  using tflite::testing::F2Q;
+  const float input_min = -3.0;
+  const float input_max = 3.0;
+  const float output_min = -3.0;
+  const float output_max = 3.0;
+
+  uint8_t output_data[6];
+  tflite::testing::TestMulQuantized({4, 1, 3, 1, 2},  // input1 shape
+                                    {
+                                        F2Q(-2.0, input_min, input_max),
+                                        F2Q(0.2, input_min, input_max),
+                                        F2Q(0.7, input_min, input_max),
+                                        F2Q(0.8, input_min, input_max),
+                                        F2Q(1.1, input_min, input_max),
+                                        F2Q(2.0, input_min, input_max),
+                                    },  // input1 data
+                                    {1, 1},  // input2 shape
+                                    {
+                                        F2Q(0.1, input_min, input_max),
+                                    },  // input2 data
+                                    input_min, input_max,
+                                    {4, 1, 3, 1, 2},  // output shape
+                                    output_min, output_max,
+                                    {
+                                        F2Q(-0.2, output_min, output_max),
+                                        F2Q(0.02, output_min, output_max),
+                                        F2Q(0.07, output_min, output_max),
+                                        F2Q(0.08, output_min, output_max),
+                                        F2Q(0.11, output_min, output_max),
+                                        F2Q(0.2, output_min, output_max),
+                                    },  // expected output data
+                                    output_data, kTfLiteActNone, 1);
+}
+
+TF_LITE_MICRO_TEST(FloatNoActivation) {
+  float output_data[4];
+  tflite::testing::TestMulFloat(
+      {4, 1, 2, 2, 1},          // input1 shape
+      {-2.0, 0.2, 0.7, 0.8},    // input1 data
+      {4, 1, 2, 2, 1},          // input2 shape
+      {0.1, 0.2, 0.3, 0.5},     // input2 data
+      {4, 1, 2, 2, 1},          // output shape
+      {-0.2, 0.04, 0.21, 0.4},  // expected output data
+      output_data, kTfLiteActNone);
+}
+
+TF_LITE_MICRO_TEST(FloatRelu) {
+  float output_data[4];
+  tflite::testing::TestMulFloat(
+      {4, 1, 2, 2, 1},          // input1 shape
+      {-2.0, 0.2, 0.7, 0.8},    // input1 data
+      {4, 1, 2, 2, 1},          // input2 shape
+      {0.1, 0.2, 0.3, 0.5},     // input2 data
+      {4, 1, 2, 2, 1},          // output shape
+      {-0.2, 0.04, 0.21, 0.4},  // expected output data
+      output_data, kTfLiteActRelu1);
+}
+
+TF_LITE_MICRO_TEST(FloatBroadcast) {
+  float output_data[6];
+  tflite::testing::TestMulFloat(
+      {4, 1, 3, 1, 2},                      // input1 shape
+      {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0},      // input1 data
+      {1, 1},                               // input2 shape
+      {0.1},                                // input2 data
+      {4, 1, 3, 1, 2},                      // output shape
+      {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2},  // expected output data
+      output_data, kTfLiteActNone);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile
index 9fbbaf507d7..0822b29b6b8 100644
--- a/tensorflow/lite/experimental/micro/tools/make/Makefile
+++ b/tensorflow/lite/experimental/micro/tools/make/Makefile
@@ -133,7 +133,9 @@ tensorflow/lite/kernels/internal/reference/integer_ops/add.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/conv.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
+tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
 tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
+tensorflow/lite/kernels/internal/reference/mul.h \
 tensorflow/lite/kernels/internal/reference/neg.h \
 tensorflow/lite/kernels/internal/reference/pooling.h \
 tensorflow/lite/kernels/internal/reference/prelu.h \
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 50b93666809..57cec39c2ac 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -415,6 +415,7 @@ cc_library(
         "reference/integer_ops/tanh.h",
         "reference/logistic.h",
         "reference/maximum_minimum.h",
+        "reference/mul.h",
         "reference/neg.h",
         "reference/non_max_suppression.h",
         "reference/pooling.h",
@@ -472,6 +473,7 @@ cc_library(
"reference/legacy_reference_ops.h", "reference/logistic.h", "reference/maximum_minimum.h", + "reference/mul.h", "reference/neg.h", "reference/pooling.h", "reference/prelu.h", diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h index f054d07f9c6..3eece043998 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mul.h @@ -16,7 +16,6 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_ #include "fixedpoint/fixedpoint.h" -#include "profiling/instrumentation.h" #include "tensorflow/lite/kernels/internal/common.h" namespace tflite { @@ -46,7 +45,6 @@ inline void Mul(const ArithmeticParams& params, const RuntimeShape& output_shape, int8_t* output_data) { TFLITE_DCHECK_LE(params.quantized_activation_min, params.quantized_activation_max); - gemmlowp::ScopedProfilingLabel label("Mul/8bit"); const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape); @@ -58,7 +56,6 @@ inline void Mul(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int16* input1_data, const RuntimeShape& input2_shape, const int16* input2_data, const RuntimeShape& output_shape, int8_t* output_data) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16Int8"); int32 output_offset = params.output_offset; int32 output_activation_min = params.quantized_activation_min; int32 output_activation_max = params.quantized_activation_max; @@ -90,8 +87,6 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params, const int8_t* input2_data, const RuntimeShape& output_shape, int8_t* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit"); - NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; // The input shapes are extended as part of NdArrayDesc initialization. diff --git a/tensorflow/lite/kernels/internal/reference/mul.h b/tensorflow/lite/kernels/internal/reference/mul.h new file mode 100644 index 00000000000..54e947db9ca --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/mul.h @@ -0,0 +1,166 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { + +namespace reference_ops { + +// Element-wise mul that can often be used for inner loop of broadcast Mul as +// well as the non-broadcast Mul. 
+inline void MulElementwise(int size, const ArithmeticParams& params,
+                           const uint8* input1_data, const uint8* input2_data,
+                           uint8* output_data) {
+  for (int i = 0; i < size; ++i) {
+    const int32 input1_val = params.input1_offset + input1_data[i];
+    const int32 input2_val = params.input2_offset + input2_data[i];
+    const int32 unclamped_result =
+        params.output_offset +
+        MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                      params.output_multiplier,
+                                      params.output_shift);
+    const int32 clamped_output =
+        std::min(params.quantized_activation_max,
+                 std::max(params.quantized_activation_min, unclamped_result));
+    output_data[i] = static_cast<uint8>(clamped_output);
+  }
+}
+
+template <typename T>
+inline void Mul(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const T* input1_data,
+                const RuntimeShape& input2_shape, const T* input2_data,
+                const RuntimeShape& output_shape, T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] * input2_data[i], output_activation_min,
+        output_activation_max);
+  }
+}
+
+inline void Mul(const ArithmeticParams& params,
+                const RuntimeShape& input1_shape, const uint8* input1_data,
+                const RuntimeShape& input2_shape, const uint8* input2_data,
+                const RuntimeShape& output_shape, uint8* output_data) {
+  TFLITE_DCHECK_LE(params.quantized_activation_min,
+                   params.quantized_activation_max);
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+
+  MulElementwise(flat_size, params, input1_data, input2_data, output_data);
+}
+
+inline void BroadcastMul4DSlow(const ArithmeticParams& params,
+                               const RuntimeShape& input1_shape,
+                               const uint8* input1_data,
+                               const RuntimeShape& input2_shape,
+                               const uint8* input2_data,
+                               const RuntimeShape& output_shape,
+                               uint8* output_data) {
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+                                      &desc2);
+  const RuntimeShape extended_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
+
+  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
+    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
+      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
+        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
+          const int32 input1_val =
+              params.input1_offset +
+              input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+          const int32 input2_val =
+              params.input2_offset +
+              input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+          const int32 unclamped_result =
+              params.output_offset +
+              MultiplyByQuantizedMultiplier(input1_val * input2_val,
+                                            params.output_multiplier,
+                                            params.output_shift);
+          const int32 clamped_output = std::min(
+              params.quantized_activation_max,
+              std::max(params.quantized_activation_min, unclamped_result));
+          output_data[Offset(extended_output_shape, b, y, x, c)] =
+              static_cast<uint8>(clamped_output);
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void BroadcastMul4DSlow(const ArithmeticParams& params,
+                        const RuntimeShape& unextended_input1_shape,
+                        const T* input1_data,
+                        const RuntimeShape& unextended_input2_shape,
+                        const T* input2_data,
+                        const RuntimeShape& unextended_output_shape,
+                        T* output_data) {
+  T output_activation_min;
+  T output_activation_max;
+  GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  // In Tensorflow, the dimensions are canonically named (batch_number, row,
+  // col, channel), with extents (batches, height, width, depth), with the
+  // trailing dimension changing most rapidly (channels has the smallest stride,
+  // typically 1 element).
+  //
+  // In generated C code, we store arrays with the dimensions reversed. The
+  // first dimension has smallest stride.
+  //
+  // We name our variables by their Tensorflow convention, but generate C code
+  // nesting loops such that the innermost loop has the smallest stride for the
+  // best cache behavior.
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          output_data[Offset(output_shape, b, y, x, c)] =
+              ActivationFunctionWithMinMax(
+                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
+                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
+                  output_activation_min, output_activation_max);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_MUL_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 8a0ab56a689..f67e167ecb7 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
+#include "tensorflow/lite/kernels/internal/reference/mul.h"
 #include "tensorflow/lite/kernels/internal/reference/neg.h"
 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
@@ -358,24 +359,6 @@ inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
   }
 }
 
-template <typename T>
-inline void Mul(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const T* input1_data,
-                const RuntimeShape& input2_shape, const T* input2_data,
-                const RuntimeShape& output_shape, T* output_data) {
-  T output_activation_min;
-  T output_activation_max;
-  GetActivationParams(params, &output_activation_min, &output_activation_max);
-
-  const int flat_size =
-      MatchingElementsSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = ActivationFunctionWithMinMax(
-        input1_data[i] * input2_data[i], output_activation_min,
-        output_activation_max);
-  }
-}
-
 // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
 // dimensionality if the runtime code does a single loop over one dimension
 // that handles broadcasting as the base case. The code generator would then
@@ -384,89 +367,6 @@ inline void Mul(const ArithmeticParams& params,
 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
 // reference_ops.h.
-template <typename T>
-void BroadcastMul4DSlow(const ArithmeticParams& params,
-                        const RuntimeShape& unextended_input1_shape,
-                        const T* input1_data,
-                        const RuntimeShape& unextended_input2_shape,
-                        const T* input2_data,
-                        const RuntimeShape& unextended_output_shape,
-                        T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow");
-  T output_activation_min;
-  T output_activation_max;
-  GetActivationParams(params, &output_activation_min, &output_activation_max);
-
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  // In Tensorflow, the dimensions are canonically named (batch_number, row,
-  // col, channel), with extents (batches, height, width, depth), with the
-  // trailing dimension changing most rapidly (channels has the smallest stride,
-  // typically 1 element).
-  //
-  // In generated C code, we store arrays with the dimensions reversed. The
-  // first dimension has smallest stride.
-  //
-  // We name our variables by their Tensorflow convention, but generate C code
-  // nesting loops such that the innermost loop has the smallest stride for the
-  // best cache behavior.
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          output_data[Offset(output_shape, b, y, x, c)] =
-              ActivationFunctionWithMinMax(
-                  input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
-                      input2_data[SubscriptToIndex(desc2, b, y, x, c)],
-                  output_activation_min, output_activation_max);
-        }
-      }
-    }
-  }
-}
-
-// Element-wise mul that can often be used for inner loop of broadcast Mul as
-// well as the non-broadcast Mul.
-inline void MulElementwise(int size, const ArithmeticParams& params,
-                           const uint8* input1_data, const uint8* input2_data,
-                           uint8* output_data) {
-  for (int i = 0; i < size; ++i) {
-    const int32 input1_val = params.input1_offset + input1_data[i];
-    const int32 input2_val = params.input2_offset + input2_data[i];
-    const int32 unclamped_result =
-        params.output_offset +
-        MultiplyByQuantizedMultiplier(input1_val * input2_val,
-                                      params.output_multiplier,
-                                      params.output_shift);
-    const int32 clamped_output =
-        std::min(params.quantized_activation_max,
-                 std::max(params.quantized_activation_min, unclamped_result));
-    output_data[i] = static_cast<uint8>(clamped_output);
-  }
-}
-
-inline void Mul(const ArithmeticParams& params,
-                const RuntimeShape& input1_shape, const uint8* input1_data,
-                const RuntimeShape& input2_shape, const uint8* input2_data,
-                const RuntimeShape& output_shape, uint8* output_data) {
-  TFLITE_DCHECK_LE(params.quantized_activation_min,
-                   params.quantized_activation_max);
-  gemmlowp::ScopedProfilingLabel label("Mul/8bit");
-  const int flat_size =
-      MatchingElementsSize(input1_shape, input2_shape, output_shape);
-
-  MulElementwise(flat_size, params, input1_data, input2_data, output_data);
-}
-
 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                  const RuntimeShape& unswitched_input1_shape,
                                  const uint8* unswitched_input1_data,
@@ -519,48 +419,6 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
   }
 }
 
-inline void BroadcastMul4DSlow(const ArithmeticParams& params,
-                               const RuntimeShape& input1_shape,
-                               const uint8* input1_data,
-                               const RuntimeShape& input2_shape,
-                               const uint8* input2_data,
-                               const RuntimeShape& output_shape,
-                               uint8* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit");
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  // The input shapes are extended as part of NdArrayDesc initialization.
-  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
-                                      &desc2);
-  const RuntimeShape extended_output_shape =
-      RuntimeShape::ExtendedShape(4, output_shape);
-
-  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
-    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
-      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
-        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
-          const int32 input1_val =
-              params.input1_offset +
-              input1_data[SubscriptToIndex(desc1, b, y, x, c)];
-          const int32 input2_val =
-              params.input2_offset +
-              input2_data[SubscriptToIndex(desc2, b, y, x, c)];
-          const int32 unclamped_result =
-              params.output_offset +
-              MultiplyByQuantizedMultiplier(input1_val * input2_val,
-                                            params.output_multiplier,
-                                            params.output_shift);
-          const int32 clamped_output = std::min(
-              params.quantized_activation_max,
-              std::max(params.quantized_activation_min, unclamped_result));
-          output_data[Offset(extended_output_shape, b, y, x, c)] =
-              static_cast<uint8>(clamped_output);
-        }
-      }
-    }
-  }
-}
 inline void Mul(const ArithmeticParams& params,
                 const RuntimeShape& input1_shape, const int16* input1_data,
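
Reviewer note (not part of the patch): the quantized path can be sanity-checked on a host with no TFLite dependencies. The standalone sketch below re-implements, in plain C++, the requantization scheme that CalculateOpData and MulElementwise use: fold the tensor scales into one real_multiplier, split it into a Q31 mantissa plus a power-of-two shift, then apply it to the integer product of the zero-point-adjusted inputs. The "Sketch" helper names, the example scale/zero point, and the round-half-up rounding are assumptions for illustration only; the real kernels round half away from zero through gemmlowp's fixed-point routines.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for tflite::QuantizeMultiplier: split a positive real
// multiplier into a Q31 mantissa and a power-of-two exponent.
void QuantizeMultiplierSketch(double real_multiplier, int32_t* quantized,
                              int* shift) {
  const double mantissa = std::frexp(real_multiplier, shift);
  int64_t q = std::llround(mantissa * (int64_t{1} << 31));
  if (q == (int64_t{1} << 31)) {  // Mantissa rounded up to 1.0; renormalize.
    q /= 2;
    ++(*shift);
  }
  *quantized = static_cast<int32_t>(q);
}

// Hypothetical stand-in for MultiplyByQuantizedMultiplier: computes
// round(x * real_multiplier) in integer arithmetic. Rounds half up rather
// than half away from zero, and assumes real_multiplier < 1 so total_shift
// stays positive. Relies on arithmetic right shift of negative values.
int32_t MultiplyByQuantizedMultiplierSketch(int32_t x,
                                            int32_t quantized_multiplier,
                                            int shift) {
  const int64_t prod = static_cast<int64_t>(x) * quantized_multiplier;
  const int total_shift = 31 - shift;
  return static_cast<int32_t>(
      (prod + (int64_t{1} << (total_shift - 1))) >> total_shift);
}

int main() {
  // Assumed example quantization: every tensor spans [-1, 1] over uint8,
  // i.e. scale = 2/255 with zero_point = 128.
  const double scale = 2.0 / 255.0;
  const int32_t zero_point = 128;

  // What CalculateOpData does once per node.
  const double real_multiplier = scale * scale / scale;
  int32_t output_multiplier;
  int output_shift;
  QuantizeMultiplierSketch(real_multiplier, &output_multiplier, &output_shift);

  // What MulElementwise does per element: offset, multiply, rescale, clamp.
  const uint8_t q1 = 26;   // roughly -0.8 under this quantization
  const uint8_t q2 = 204;  // roughly  0.6
  const int32_t acc = (q1 - zero_point) * (q2 - zero_point);
  const int32_t raw = zero_point + MultiplyByQuantizedMultiplierSketch(
                                       acc, output_multiplier, output_shift);
  const int32_t out = std::min<int32_t>(255, std::max<int32_t>(0, raw));

  std::printf("quantized output: %d (~%.3f, expecting roughly -0.48)\n", out,
              (out - zero_point) * scale);
  return 0;
}

With these inputs the sketch lands within one quantized unit of the F2Q(-0.48, output_min, output_max) value that the UInt8NoActivation test expects, which is exactly the kind of rounding skew the tests' error_tolerance argument is there to absorb.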