diff --git a/tensorflow/lite/kernels/concatenation.cc b/tensorflow/lite/kernels/concatenation.cc index 3ef4404a8c6..870a10abfc4 100644 --- a/tensorflow/lite/kernels/concatenation.cc +++ b/tensorflow/lite/kernels/concatenation.cc @@ -111,72 +111,64 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // allocate and populate these during Prepare(). // TODO(ycling): Activation function parameter is ignored. For now we dont have // a model with a Concatenation with fused activation function. -#define TF_LITE_CONCATENATION(type, scalar) \ - { \ - VectorOfTensors all_inputs(*context, *node->inputs); \ - tflite::ConcatenationParams op_params; \ - op_params.axis = axis; \ - op_params.inputs_count = node->inputs->size; \ - type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \ - GetTensorShape(output), \ - GetTensorData(output)); \ - } - -#define TF_LITE_CONCATENATION_QUANTIZED(type) \ +#define TF_LITE_CONCATENATION(scalar) \ { \ - VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ + VectorOfTensors all_inputs(*context, *node->inputs); \ tflite::ConcatenationParams op_params; \ op_params.axis = axis; \ - op_params.input_zeropoint = all_inputs.zero_point(); \ - op_params.input_scale = all_inputs.scale(); \ op_params.inputs_count = node->inputs->size; \ - op_params.output_zeropoint = output->params.zero_point; \ - op_params.output_scale = output->params.scale; \ - type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \ + if (kernel_type == kReference) { \ + reference_ops::Concatenation(op_params, all_inputs.shapes(), \ all_inputs.data(), GetTensorShape(output), \ - GetTensorData(output)); \ + GetTensorData(output)); \ + } else { \ + optimized_ops::Concatenation(op_params, all_inputs.shapes(), \ + all_inputs.data(), GetTensorShape(output), \ + GetTensorData(output)); \ + } \ + } + +#define TF_LITE_CONCATENATION_QUANTIZED() \ + { \ + VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ + tflite::ConcatenationParams op_params; \ + op_params.axis = axis; \ + op_params.input_zeropoint = all_inputs.zero_point(); \ + op_params.input_scale = all_inputs.scale(); \ + op_params.inputs_count = node->inputs->size; \ + op_params.output_zeropoint = output->params.zero_point; \ + op_params.output_scale = output->params.scale; \ + if (kernel_type == kReference) { \ + reference_ops::ConcatenationWithScaling( \ + op_params, all_inputs.shapes(), all_inputs.data(), \ + GetTensorShape(output), GetTensorData(output)); \ + } else { \ + optimized_ops::ConcatenationWithScaling( \ + op_params, all_inputs.shapes(), all_inputs.data(), \ + GetTensorShape(output), GetTensorData(output)); \ + } \ } switch (output->type) { // Already know in/outtypes are same. case kTfLiteFloat32: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, float); - } else { - TF_LITE_CONCATENATION(optimized_ops, float); - } + TF_LITE_CONCATENATION(float); break; case kTfLiteInt32: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int32); - } else { - TF_LITE_CONCATENATION(optimized_ops, int32); - } + TF_LITE_CONCATENATION(int32); break; case kTfLiteUInt8: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION_QUANTIZED(reference_ops); - } else { - TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); - } + TF_LITE_CONCATENATION_QUANTIZED(); + break; + case kTfLiteInt8: + TF_LITE_CONCATENATION(int8_t); break; - case kTfLiteInt8: { - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int8_t); - } else { - TF_LITE_CONCATENATION(optimized_ops, int8_t); - } - } break; case kTfLiteInt64: - if (kernel_type == kReference) { - TF_LITE_CONCATENATION(reference_ops, int64_t); - } else { - TF_LITE_CONCATENATION(optimized_ops, int64_t); - } + TF_LITE_CONCATENATION(int64_t); break; default: - context->ReportError(context, - "Only float32 and uint8 are currently supported."); + context->ReportError(context, "Type '%s' is not supported currently.", + TfLiteTypeGetName(output->type)); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/concatenation_test.cc b/tensorflow/lite/kernels/concatenation_test.cc index f419a28aacb..f3eb4ab995c 100644 --- a/tensorflow/lite/kernels/concatenation_test.cc +++ b/tensorflow/lite/kernels/concatenation_test.cc @@ -327,6 +327,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) { })); } +TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits::max()}, + /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f}))); +} + +TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {1}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) { + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + QuantizedConcatenationOpModel m0( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + QuantizedConcatenationOpModel m0_negative( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/-2, + /*num_inputs=*/2); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); + m0_negative.Invoke(); + EXPECT_THAT(m0_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + QuantizedConcatenationOpModel m1( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); + + QuantizedConcatenationOpModel m1_negative( + {TensorType_UINT8, {2, 3}, 0, std::numeric_limits::max()}, + /*axis=*/-1, + /*num_inputs=*/2); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); + m1_negative.Invoke(); + EXPECT_THAT(m1_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + } // namespace } // namespace tflite