Merge pull request #26134 from ANSHUMAN87:concat-refactor
PiperOrigin-RevId: 246571812
This commit is contained in:
commit
23f0cac181
@ -111,18 +111,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// allocate and populate these during Prepare().
|
// allocate and populate these during Prepare().
|
||||||
// TODO(ycling): Activation function parameter is ignored. For now we dont have
|
// TODO(ycling): Activation function parameter is ignored. For now we dont have
|
||||||
// a model with a Concatenation with fused activation function.
|
// a model with a Concatenation with fused activation function.
|
||||||
#define TF_LITE_CONCATENATION(type, scalar) \
|
#define TF_LITE_CONCATENATION(scalar) \
|
||||||
{ \
|
{ \
|
||||||
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
|
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
|
||||||
tflite::ConcatenationParams op_params; \
|
tflite::ConcatenationParams op_params; \
|
||||||
op_params.axis = axis; \
|
op_params.axis = axis; \
|
||||||
op_params.inputs_count = node->inputs->size; \
|
op_params.inputs_count = node->inputs->size; \
|
||||||
type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
|
if (kernel_type == kReference) { \
|
||||||
GetTensorShape(output), \
|
reference_ops::Concatenation(op_params, all_inputs.shapes(), \
|
||||||
|
all_inputs.data(), GetTensorShape(output), \
|
||||||
GetTensorData<scalar>(output)); \
|
GetTensorData<scalar>(output)); \
|
||||||
|
} else { \
|
||||||
|
optimized_ops::Concatenation(op_params, all_inputs.shapes(), \
|
||||||
|
all_inputs.data(), GetTensorShape(output), \
|
||||||
|
GetTensorData<scalar>(output)); \
|
||||||
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define TF_LITE_CONCATENATION_QUANTIZED(type) \
|
#define TF_LITE_CONCATENATION_QUANTIZED() \
|
||||||
{ \
|
{ \
|
||||||
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
|
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
|
||||||
tflite::ConcatenationParams op_params; \
|
tflite::ConcatenationParams op_params; \
|
||||||
@ -132,51 +138,37 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
op_params.inputs_count = node->inputs->size; \
|
op_params.inputs_count = node->inputs->size; \
|
||||||
op_params.output_zeropoint = output->params.zero_point; \
|
op_params.output_zeropoint = output->params.zero_point; \
|
||||||
op_params.output_scale = output->params.scale; \
|
op_params.output_scale = output->params.scale; \
|
||||||
type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
|
if (kernel_type == kReference) { \
|
||||||
all_inputs.data(), GetTensorShape(output), \
|
reference_ops::ConcatenationWithScaling( \
|
||||||
GetTensorData<uint8>(output)); \
|
op_params, all_inputs.shapes(), all_inputs.data(), \
|
||||||
|
GetTensorShape(output), GetTensorData<uint8>(output)); \
|
||||||
|
} else { \
|
||||||
|
optimized_ops::ConcatenationWithScaling( \
|
||||||
|
op_params, all_inputs.shapes(), all_inputs.data(), \
|
||||||
|
GetTensorShape(output), GetTensorData<uint8>(output)); \
|
||||||
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (output->type) { // Already know in/outtypes are same.
|
switch (output->type) { // Already know in/outtypes are same.
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
if (kernel_type == kReference) {
|
TF_LITE_CONCATENATION(float);
|
||||||
TF_LITE_CONCATENATION(reference_ops, float);
|
|
||||||
} else {
|
|
||||||
TF_LITE_CONCATENATION(optimized_ops, float);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
if (kernel_type == kReference) {
|
TF_LITE_CONCATENATION(int32);
|
||||||
TF_LITE_CONCATENATION(reference_ops, int32);
|
|
||||||
} else {
|
|
||||||
TF_LITE_CONCATENATION(optimized_ops, int32);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
if (kernel_type == kReference) {
|
TF_LITE_CONCATENATION_QUANTIZED();
|
||||||
TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
|
break;
|
||||||
} else {
|
case kTfLiteInt8:
|
||||||
TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
|
TF_LITE_CONCATENATION(int8_t);
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt8: {
|
|
||||||
if (kernel_type == kReference) {
|
|
||||||
TF_LITE_CONCATENATION(reference_ops, int8_t);
|
|
||||||
} else {
|
|
||||||
TF_LITE_CONCATENATION(optimized_ops, int8_t);
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
if (kernel_type == kReference) {
|
TF_LITE_CONCATENATION(int64_t);
|
||||||
TF_LITE_CONCATENATION(reference_ops, int64_t);
|
|
||||||
} else {
|
|
||||||
TF_LITE_CONCATENATION(optimized_ops, int64_t);
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
context->ReportError(context,
|
context->ReportError(context, "Type '%s' is not supported currently.",
|
||||||
"Only float32 and uint8 are currently supported.");
|
TfLiteTypeGetName(output->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,6 +327,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) {
|
||||||
|
QuantizedConcatenationOpModel m0(
|
||||||
|
{TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/1,
|
||||||
|
/*num_inputs=*/1);
|
||||||
|
m0.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
|
||||||
|
m0.Invoke();
|
||||||
|
EXPECT_THAT(m0.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) {
|
||||||
|
QuantizedConcatenationOpModel m0(
|
||||||
|
{TensorType_UINT8, {1}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/0,
|
||||||
|
/*num_inputs=*/1);
|
||||||
|
m0.SetInput<uint8_t>(0, {5.0f});
|
||||||
|
m0.Invoke();
|
||||||
|
EXPECT_THAT(m0.GetOutput<uint8_t>(), ::testing::ElementsAre(5));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) {
|
||||||
|
QuantizedConcatenationOpModel m0(
|
||||||
|
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/0,
|
||||||
|
/*num_inputs=*/1);
|
||||||
|
m0.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
|
||||||
|
m0.Invoke();
|
||||||
|
EXPECT_THAT(m0.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
|
||||||
|
// We will concatenate two tensors along different dimensions.
|
||||||
|
auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||||
|
auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
|
||||||
|
|
||||||
|
QuantizedConcatenationOpModel m0(
|
||||||
|
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/0,
|
||||||
|
/*num_inputs=*/2);
|
||||||
|
m0.SetInput<uint8_t>(0, tensor0);
|
||||||
|
m0.SetInput<uint8_t>(1, tensor1);
|
||||||
|
m0.Invoke();
|
||||||
|
EXPECT_THAT(m0.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
||||||
|
|
||||||
|
QuantizedConcatenationOpModel m0_negative(
|
||||||
|
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/-2,
|
||||||
|
/*num_inputs=*/2);
|
||||||
|
m0_negative.SetInput<uint8_t>(0, tensor0);
|
||||||
|
m0_negative.SetInput<uint8_t>(1, tensor1);
|
||||||
|
m0_negative.Invoke();
|
||||||
|
EXPECT_THAT(m0_negative.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
||||||
|
|
||||||
|
QuantizedConcatenationOpModel m1(
|
||||||
|
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/1,
|
||||||
|
/*num_inputs=*/2);
|
||||||
|
m1.SetInput<uint8_t>(0, tensor0);
|
||||||
|
m1.SetInput<uint8_t>(1, tensor1);
|
||||||
|
m1.Invoke();
|
||||||
|
EXPECT_THAT(m1.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
|
||||||
|
|
||||||
|
QuantizedConcatenationOpModel m1_negative(
|
||||||
|
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
|
||||||
|
/*axis=*/-1,
|
||||||
|
/*num_inputs=*/2);
|
||||||
|
m1_negative.SetInput<uint8_t>(0, tensor0);
|
||||||
|
m1_negative.SetInput<uint8_t>(1, tensor1);
|
||||||
|
m1_negative.Invoke();
|
||||||
|
EXPECT_THAT(m1_negative.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user