Merge pull request #26134 from ANSHUMAN87:concat-refactor

PiperOrigin-RevId: 246571812
This commit is contained in:
TensorFlower Gardener 2019-05-03 14:56:39 -07:00
commit 23f0cac181
2 changed files with 117 additions and 48 deletions

View File

@ -111,18 +111,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have
// a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(type, scalar) \
#define TF_LITE_CONCATENATION(scalar) \
{ \
VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
tflite::ConcatenationParams op_params; \
op_params.axis = axis; \
op_params.inputs_count = node->inputs->size; \
type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
GetTensorShape(output), \
if (kernel_type == kReference) { \
reference_ops::Concatenation(op_params, all_inputs.shapes(), \
all_inputs.data(), GetTensorShape(output), \
GetTensorData<scalar>(output)); \
} else { \
optimized_ops::Concatenation(op_params, all_inputs.shapes(), \
all_inputs.data(), GetTensorShape(output), \
GetTensorData<scalar>(output)); \
} \
}
#define TF_LITE_CONCATENATION_QUANTIZED(type) \
#define TF_LITE_CONCATENATION_QUANTIZED() \
{ \
VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
tflite::ConcatenationParams op_params; \
@ -132,51 +138,37 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
op_params.inputs_count = node->inputs->size; \
op_params.output_zeropoint = output->params.zero_point; \
op_params.output_scale = output->params.scale; \
type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
all_inputs.data(), GetTensorShape(output), \
GetTensorData<uint8>(output)); \
if (kernel_type == kReference) { \
reference_ops::ConcatenationWithScaling( \
op_params, all_inputs.shapes(), all_inputs.data(), \
GetTensorShape(output), GetTensorData<uint8>(output)); \
} else { \
optimized_ops::ConcatenationWithScaling( \
op_params, all_inputs.shapes(), all_inputs.data(), \
GetTensorShape(output), GetTensorData<uint8>(output)); \
} \
}
switch (output->type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION(reference_ops, float);
} else {
TF_LITE_CONCATENATION(optimized_ops, float);
}
TF_LITE_CONCATENATION(float);
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION(reference_ops, int32);
} else {
TF_LITE_CONCATENATION(optimized_ops, int32);
}
TF_LITE_CONCATENATION(int32);
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
} else {
TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
}
TF_LITE_CONCATENATION_QUANTIZED();
break;
case kTfLiteInt8:
TF_LITE_CONCATENATION(int8_t);
break;
case kTfLiteInt8: {
if (kernel_type == kReference) {
TF_LITE_CONCATENATION(reference_ops, int8_t);
} else {
TF_LITE_CONCATENATION(optimized_ops, int8_t);
}
} break;
case kTfLiteInt64:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION(reference_ops, int64_t);
} else {
TF_LITE_CONCATENATION(optimized_ops, int64_t);
}
TF_LITE_CONCATENATION(int64_t);
break;
default:
context->ReportError(context,
"Only float32 and uint8 are currently supported.");
context->ReportError(context, "Type '%s' is not supported currently.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}

View File

@ -327,6 +327,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
}));
}
TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) {
QuantizedConcatenationOpModel m0(
{TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/1,
/*num_inputs=*/1);
m0.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
m0.Invoke();
EXPECT_THAT(m0.GetOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f})));
}
TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) {
QuantizedConcatenationOpModel m0(
{TensorType_UINT8, {1}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/0,
/*num_inputs=*/1);
m0.SetInput<uint8_t>(0, {5.0f});
m0.Invoke();
EXPECT_THAT(m0.GetOutput<uint8_t>(), ::testing::ElementsAre(5));
}
TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) {
QuantizedConcatenationOpModel m0(
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/0,
/*num_inputs=*/1);
m0.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
m0.Invoke();
EXPECT_THAT(m0.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
// We will concatenate two tensors along different dimensions.
auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
QuantizedConcatenationOpModel m0(
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/0,
/*num_inputs=*/2);
m0.SetInput<uint8_t>(0, tensor0);
m0.SetInput<uint8_t>(1, tensor1);
m0.Invoke();
EXPECT_THAT(m0.GetOutput<uint8_t>(),
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
QuantizedConcatenationOpModel m0_negative(
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/-2,
/*num_inputs=*/2);
m0_negative.SetInput<uint8_t>(0, tensor0);
m0_negative.SetInput<uint8_t>(1, tensor1);
m0_negative.Invoke();
EXPECT_THAT(m0_negative.GetOutput<uint8_t>(),
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
QuantizedConcatenationOpModel m1(
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/1,
/*num_inputs=*/2);
m1.SetInput<uint8_t>(0, tensor0);
m1.SetInput<uint8_t>(1, tensor1);
m1.Invoke();
EXPECT_THAT(m1.GetOutput<uint8_t>(),
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
QuantizedConcatenationOpModel m1_negative(
{TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
/*axis=*/-1,
/*num_inputs=*/2);
m1_negative.SetInput<uint8_t>(0, tensor0);
m1_negative.SetInput<uint8_t>(1, tensor1);
m1_negative.Invoke();
EXPECT_THAT(m1_negative.GetOutput<uint8_t>(),
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
} // namespace
} // namespace tflite