From 824045be4ead9fc94848549ac2edcd9db0585b2d Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 26 Feb 2019 17:42:13 +0530 Subject: [PATCH 1/5] Lite: Concatenation Op Refactored --- tensorflow/lite/kernels/concatenation.cc | 52 ++--------- tensorflow/lite/kernels/concatenation_test.cc | 89 ++++++++++++++++++- 2 files changed, 97 insertions(+), 44 deletions(-) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 76d906fa6de..b9fbcc81baf 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -32,12 +32,6 @@ namespace ops { namespace builtin { namespace concatenation { -// This file has two implementation of Concatenation. -enum KernelType { - kReference, - kGenericOptimized, -}; - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -54,7 +48,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ahentz): These are limitations of our implementation that could be // removed with a bit of effort. - TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || @@ -100,7 +93,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -140,25 +132,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (output->type) { // Already know in/outtypes are same. case kTfLiteFloat32: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, float); - } else { - TF_LITE_CONCATENATION(optimized_ops, float); - } + TF_LITE_CONCATENATION(reference_ops, float); break; case kTfLiteInt32: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int32); - } else { - TF_LITE_CONCATENATION(optimized_ops, int32); - } + TF_LITE_CONCATENATION(reference_ops, int32); break; case kTfLiteUInt8: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION_QUANTIZED(reference_ops); - } else { - TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); - } + TF_LITE_CONCATENATION_QUANTIZED(reference_ops); break; case kTfLiteInt8: { if (kernel_type == kReference) { @@ -168,16 +148,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } break; case kTfLiteInt64: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int64_t); - } else { - TF_LITE_CONCATENATION(optimized_ops, int64_t); - } + TF_LITE_CONCATENATION(reference_ops, int64_t); break; default: - context->ReportError(context, - "Only float32 and uint8 are currently supported."); + context->ReportError(context, "Type '%s' is not supported currently.", + TfLiteTypeGetName(output->type)); return kTfLiteError; } @@ -192,23 +168,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace concatenation TfLiteRegistration* Register_CONCATENATION_REF() { - static TfLiteRegistration r = { - nullptr, nullptr, concatenation::Prepare, - concatenation::Eval}; - return &r; -} - -TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { - static TfLiteRegistration r = { - nullptr, nullptr, concatenation::Prepare, - concatenation::Eval}; + static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; return &r; } TfLiteRegistration* Register_CONCATENATION() { - // TODO(ahentz): It turns out the two versions of Concatenation are almost - // identical, so we should consider removing one. - return Register_CONCATENATION_GENERIC_OPT(); + return Register_CONCATENATION_REF(); } } // namespace builtin diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index dab77d612dc..d753279b9f7 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" @@ -101,6 +101,16 @@ TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); } +TEST(ConcatenationOpTest, FiveDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2, 1, 3}}, /*axis=*/2, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); +} + TEST(ConcatenationOpTest, OneTrivialInput) { ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, /*num_inputs=*/1); @@ -265,6 +275,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) { })); } +TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits::max()}, + /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f}))); +} + +TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {1}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + QuantizedConcatenationOpModel m0_negative( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/-2, + /*num_inputs=*/2); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); + m0_negative.Invoke(); + EXPECT_THAT(m0_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + QuantizedConcatenationOpModel m1( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); + + QuantizedConcatenationOpModel m1_negative( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/-1, + /*num_inputs=*/2); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); + m1_negative.Invoke(); + EXPECT_THAT(m1_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + } // namespace } // namespace tflite From d4d9627956a14dfea0a8c51c49e5ebccd6df0870 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Mon, 11 Mar 2019 10:42:08 +0530 Subject: [PATCH 2/5] [1]Review Comment handled --- tensorflow/lite/kernels/concatenation.cc | 11 ++--- tensorflow/lite/kernels/concatenation_test.cc | 48 ++++++++----------- 2 files changed, 23 insertions(+), 36 deletions(-) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index b9fbcc81baf..5adcdb00ebc 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -48,6 +48,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ahentz): These are limitations of our implementation that could be // removed with a bit of effort. + TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || @@ -140,13 +141,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: TF_LITE_CONCATENATION_QUANTIZED(reference_ops); break; - case kTfLiteInt8: { - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int8_t); - } else { - TF_LITE_CONCATENATION(optimized_ops, int8_t); - } - } break; + case kTfLiteInt8: + TF_LITE_CONCATENATION(reference_ops, int8_t); + break; case kTfLiteInt64: TF_LITE_CONCATENATION(reference_ops, int64_t); break; diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index d753279b9f7..4cdbb56f030 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" @@ -101,16 +101,6 @@ TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); } -TEST(ConcatenationOpTest, FiveDimensionalOneInput) { - ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2, 1, 3}}, /*axis=*/2, - /*num_inputs=*/1); - m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, - 11.0f, 12.0f}); - m0.Invoke(); - EXPECT_THAT(m0.GetOutput(), - ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); -} - TEST(ConcatenationOpTest, OneTrivialInput) { ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, /*num_inputs=*/1); @@ -280,9 +270,9 @@ TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) { {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits::max()}, /*axis=*/1, /*num_inputs=*/1); - m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); m0.Invoke(); - EXPECT_THAT(m0.GetOutput(), + EXPECT_THAT(m0.GetOutput(), ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f}))); } @@ -291,9 +281,9 @@ TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) { {TensorType_UINT8, {1}, 0, std::numeric_limits::max()}, /*axis=*/0, /*num_inputs=*/1); - m0.SetInput(0, {5.0f}); + m0.SetInput(0, {5.0f}); m0.Invoke(); - EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); } TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) { @@ -301,9 +291,9 @@ TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) { {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, /*axis=*/0, /*num_inputs=*/1); - m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); m0.Invoke(); - EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) { @@ -315,40 +305,40 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) { {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, /*axis=*/0, /*num_inputs=*/2); - m0.SetInput(0, tensor0); - m0.SetInput(1, tensor1); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); m0.Invoke(); - EXPECT_THAT(m0.GetOutput(), + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); QuantizedConcatenationOpModel m0_negative( {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, /*axis=*/-2, /*num_inputs=*/2); - m0_negative.SetInput(0, tensor0); - m0_negative.SetInput(1, tensor1); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); m0_negative.Invoke(); - EXPECT_THAT(m0_negative.GetOutput(), + EXPECT_THAT(m0_negative.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); QuantizedConcatenationOpModel m1( {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, /*axis=*/1, /*num_inputs=*/2); - m1.SetInput(0, tensor0); - m1.SetInput(1, tensor1); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); m1.Invoke(); - EXPECT_THAT(m1.GetOutput(), + EXPECT_THAT(m1.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); QuantizedConcatenationOpModel m1_negative( {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, /*axis=*/-1, /*num_inputs=*/2); - m1_negative.SetInput(0, tensor0); - m1_negative.SetInput(1, tensor1); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); m1_negative.Invoke(); - EXPECT_THAT(m1_negative.GetOutput(), + EXPECT_THAT(m1_negative.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } From 6410f05c96319ef16d1d4a4a6862e89bafec39ec Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 28 Mar 2019 20:49:48 +0530 Subject: [PATCH 3/5] Unused Macro removed --- tensorflow/lite/kernels/concatenation.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 5adcdb00ebc..ac118c7c514 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -160,8 +160,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -#undef TF_LITE_MACRO_DISPATCH - } // namespace concatenation TfLiteRegistration* Register_CONCATENATION_REF() { From 79122e21c808d1afca6be3e6c6e84fa7f2da9820 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Fri, 29 Mar 2019 11:53:22 +0530 Subject: [PATCH 4/5] [3] Review comments handled --- tensorflow/lite/kernels/concatenation.cc | 83 ++++++++++++++++-------- 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index ac118c7c514..4e5a2d956d1 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -32,6 +32,12 @@ namespace ops { namespace builtin { namespace concatenation { +// This file has two implementation of Concatenation. +enum KernelType { + kReference, + kGenericOptimized, +}; + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -94,6 +100,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -105,47 +112,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // allocate and populate these during Prepare(). // TODO(ycling): Activation function parameter is ignored. For now we dont have // a model with a Concatenation with fused activation function. -#define TF_LITE_CONCATENATION(type, scalar) \ - { \ - VectorOfTensors all_inputs(*context, *node->inputs); \ - tflite::ConcatenationParams op_params; \ - op_params.axis = axis; \ - op_params.inputs_count = node->inputs->size; \ - type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \ - GetTensorShape(output), \ - GetTensorData(output)); \ - } - -#define TF_LITE_CONCATENATION_QUANTIZED(type) \ +#define TF_LITE_CONCATENATION(scalar) \ { \ - VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ + VectorOfTensors all_inputs(*context, *node->inputs); \ tflite::ConcatenationParams op_params; \ op_params.axis = axis; \ - op_params.input_zeropoint = all_inputs.zero_point(); \ - op_params.input_scale = all_inputs.scale(); \ op_params.inputs_count = node->inputs->size; \ - op_params.output_zeropoint = output->params.zero_point; \ - op_params.output_scale = output->params.scale; \ - type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \ + if (kernel_type == kReference) { \ + reference_ops::Concatenation(op_params, all_inputs.shapes(), \ all_inputs.data(), GetTensorShape(output), \ - GetTensorData(output)); \ + GetTensorData(output)); \ + } else { \ + optimized_ops::Concatenation(op_params, all_inputs.shapes(), \ + all_inputs.data(), GetTensorShape(output), \ + GetTensorData(output)); \ + } \ + } + +#define TF_LITE_CONCATENATION_QUANTIZED \ + { \ + VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ + tflite::ConcatenationParams op_params; \ + op_params.axis = axis; \ + op_params.input_zeropoint = all_inputs.zero_point(); \ + op_params.input_scale = all_inputs.scale(); \ + op_params.inputs_count = node->inputs->size; \ + op_params.output_zeropoint = output->params.zero_point; \ + op_params.output_scale = output->params.scale; \ + if (kernel_type == kReference) { \ + reference_ops::ConcatenationWithScaling( \ + op_params, all_inputs.shapes(), all_inputs.data(), \ + GetTensorShape(output), GetTensorData(output)); \ + } else { \ + optimized_ops::ConcatenationWithScaling( \ + op_params, all_inputs.shapes(), all_inputs.data(), \ + GetTensorShape(output), GetTensorData(output)); \ + } \ } switch (output->type) { // Already know in/outtypes are same. case kTfLiteFloat32: - TF_LITE_CONCATENATION(reference_ops, float); + TF_LITE_CONCATENATION(float); break; case kTfLiteInt32: - TF_LITE_CONCATENATION(reference_ops, int32); + TF_LITE_CONCATENATION(int32); break; case kTfLiteUInt8: - TF_LITE_CONCATENATION_QUANTIZED(reference_ops); + TF_LITE_CONCATENATION_QUANTIZED; break; case kTfLiteInt8: - TF_LITE_CONCATENATION(reference_ops, int8_t); + TF_LITE_CONCATENATION(int8_t); break; case kTfLiteInt64: - TF_LITE_CONCATENATION(reference_ops, int64_t); + TF_LITE_CONCATENATION(int64_t); break; default: @@ -163,13 +182,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace concatenation TfLiteRegistration* Register_CONCATENATION_REF() { - static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare, - concatenation::Eval}; + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; return &r; } TfLiteRegistration* Register_CONCATENATION() { - return Register_CONCATENATION_REF(); + // TODO(ahentz): It turns out the two versions of Concatenation are almost + // identical, so we should consider removing one. + return Register_CONCATENATION_GENERIC_OPT(); } } // namespace builtin From c3c83b2d465a91c535bd54269ba33a2442d7d36b Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 25 Apr 2019 15:45:19 +0530 Subject: [PATCH 5/5] [4] Review comments handled --- tensorflow/lite/kernels/concatenation.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 4e5a2d956d1..f17e955ba2e 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -129,7 +129,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } \ } -#define TF_LITE_CONCATENATION_QUANTIZED \ +#define TF_LITE_CONCATENATION_QUANTIZED() \ { \ VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ tflite::ConcatenationParams op_params; \ @@ -158,7 +158,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION(int32); break; case kTfLiteUInt8: - TF_LITE_CONCATENATION_QUANTIZED; + TF_LITE_CONCATENATION_QUANTIZED(); break; case kTfLiteInt8: TF_LITE_CONCATENATION(int8_t); @@ -179,6 +179,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +#undef TF_LITE_MACRO_DISPATCH + } // namespace concatenation TfLiteRegistration* Register_CONCATENATION_REF() {