Merge pull request #26134 from ANSHUMAN87:concat-refactor
PiperOrigin-RevId: 246571812
commit 23f0cac181
@@ -111,72 +111,64 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // allocate and populate these during Prepare().
   // TODO(ycling): Activation function parameter is ignored. For now we dont have
   // a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar)                                 \
-  {                                                                         \
-    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);            \
-    tflite::ConcatenationParams op_params;                                  \
-    op_params.axis = axis;                                                  \
-    op_params.inputs_count = node->inputs->size;                            \
-    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(),  \
-                        GetTensorShape(output),                             \
-                        GetTensorData<scalar>(output));                     \
-  }
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type)                                 \
-  {                                                                           \
-    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);             \
-    tflite::ConcatenationParams op_params;                                    \
-    op_params.axis = axis;                                                    \
-    op_params.input_zeropoint = all_inputs.zero_point();                      \
-    op_params.input_scale = all_inputs.scale();                               \
-    op_params.inputs_count = node->inputs->size;                              \
-    op_params.output_zeropoint = output->params.zero_point;                   \
-    op_params.output_scale = output->params.scale;                            \
-    type::ConcatenationWithScaling(op_params, all_inputs.shapes(),            \
-                                   all_inputs.data(), GetTensorShape(output), \
-                                   GetTensorData<uint8>(output));             \
-  }
+#define TF_LITE_CONCATENATION(scalar)                                         \
+  {                                                                           \
+    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
+    tflite::ConcatenationParams op_params;                                    \
+    op_params.axis = axis;                                                    \
+    op_params.inputs_count = node->inputs->size;                              \
+    if (kernel_type == kReference) {                                          \
+      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output));            \
+    } else {                                                                  \
+      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output));            \
+    }                                                                         \
+  }
+
+#define TF_LITE_CONCATENATION_QUANTIZED()                                 \
+  {                                                                       \
+    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);         \
+    tflite::ConcatenationParams op_params;                                \
+    op_params.axis = axis;                                                \
+    op_params.input_zeropoint = all_inputs.zero_point();                  \
+    op_params.input_scale = all_inputs.scale();                           \
+    op_params.inputs_count = node->inputs->size;                          \
+    op_params.output_zeropoint = output->params.zero_point;               \
+    op_params.output_scale = output->params.scale;                        \
+    if (kernel_type == kReference) {                                      \
+      reference_ops::ConcatenationWithScaling(                            \
+          op_params, all_inputs.shapes(), all_inputs.data(),              \
+          GetTensorShape(output), GetTensorData<uint8>(output));          \
+    } else {                                                              \
+      optimized_ops::ConcatenationWithScaling(                            \
+          op_params, all_inputs.shapes(), all_inputs.data(),              \
+          GetTensorShape(output), GetTensorData<uint8>(output));          \
+    }                                                                     \
+  }
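Note: the if (kernel_type == kReference) branch inside the rewritten macros is resolved at compile time, because Eval is templated on the kernel type. A minimal sketch of the assumed surrounding declarations in concatenation.cc (not part of this hunk):

// Assumed context: the kernel-type enum and the templated Eval that make
// `kernel_type` a compile-time constant inside the macros above.
enum KernelType {
  kReference,
  kGenericOptimized,
};

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // ... axis validation and output lookup, then the macros above ...
}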
 
   switch (output->type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, float);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, float);
-      }
+      TF_LITE_CONCATENATION(float);
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int32);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int32);
-      }
+      TF_LITE_CONCATENATION(int32);
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
-      } else {
-        TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
-      }
+      TF_LITE_CONCATENATION_QUANTIZED();
       break;
-    case kTfLiteInt8: {
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int8_t);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int8_t);
-      }
-    } break;
+    case kTfLiteInt8:
+      TF_LITE_CONCATENATION(int8_t);
+      break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int64_t);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int64_t);
-      }
+      TF_LITE_CONCATENATION(int64_t);
       break;
     default:
-      context->ReportError(context,
-                           "Only float32 and uint8 are currently supported.");
+      context->ReportError(context, "Type '%s' is not supported currently.",
+                           TfLiteTypeGetName(output->type));
       return kTfLiteError;
   }
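Note: each case can collapse to a single macro invocation because the kernel variant is chosen once, at registration. A hedged sketch of the usual TFLite builtin-kernel registration pattern (names follow that convention; not part of this hunk):

// Assumed registration entry points; TfLiteRegistration's leading fields
// are {init, free, prepare, invoke}. Each entry point pins kernel_type, so
// the branch inside TF_LITE_CONCATENATION folds away after instantiation.
TfLiteRegistration* Register_CONCATENATION_REF() {
  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval<kReference>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
  static TfLiteRegistration r = {nullptr, nullptr, Prepare,
                                 Eval<kGenericOptimized>};
  return &r;
}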
@@ -327,6 +327,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
   }));
 }
 
+TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/1,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f})));
+}
+
+TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {1}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {5.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(), ::testing::ElementsAre(5));
+}
+
+TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
+  // We will concatenate two tensors along different dimensions.
+  auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/2);
+  m0.SetInput<uint8_t>(0, tensor0);
+  m0.SetInput<uint8_t>(1, tensor1);
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m0_negative(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/-2,
+      /*num_inputs=*/2);
+  m0_negative.SetInput<uint8_t>(0, tensor0);
+  m0_negative.SetInput<uint8_t>(1, tensor1);
+  m0_negative.Invoke();
+  EXPECT_THAT(m0_negative.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m1(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/1,
+      /*num_inputs=*/2);
+  m1.SetInput<uint8_t>(0, tensor0);
+  m1.SetInput<uint8_t>(1, tensor1);
+  m1.Invoke();
+  EXPECT_THAT(m1.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m1_negative(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/-1,
+      /*num_inputs=*/2);
+  m1_negative.SetInput<uint8_t>(0, tensor0);
+  m1_negative.SetInput<uint8_t>(1, tensor1);
+  m1_negative.Invoke();
+  EXPECT_THAT(m1_negative.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
 }  // namespace
 }  // namespace tflite
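Note: these tests are named "NonQuantized" even though they go through QuantizedConcatenationOpModel because a {min=0, max=255} uint8 tensor degenerates to identity quantization, so float inputs come back out as the same integers. A quick sketch of the arithmetic (standard asymmetric quantization; the helper below is hypothetical, not code from this PR):

#include <cmath>
#include <cstdint>

// Hypothetical helper mirroring the affine scheme the tests rely on:
// q = round(x / scale) + zero_point.
uint8_t Quantize(float x, float scale, int zero_point) {
  return static_cast<uint8_t>(std::lround(x / scale) + zero_point);
}
// With min = 0 and max = 255: scale = (255 - 0) / 255 = 1.0 and
// zero_point = 0, so Quantize(5.0f, 1.0f, 0) == 5, exactly the identity
// mapping the expectations above compare against.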