Add 5D support for the TFLite Maximum/Minimum ops

PiperOrigin-RevId: 302632680
Change-Id: I1d14cc1afe01888d731e3e68a398f2907e18f174
Author: Thai Nguyen
Date: 2020-03-24 03:34:58 -07:00
Committed by: TensorFlower Gardener
Parent: 99e754b3a1
Commit: ee940c2bea
14 changed files with 156 additions and 62 deletions


@@ -1637,8 +1637,7 @@ def TFL_MaxUnpooling2DOp :
 }

 def TFL_MaximumOp : TFL_Op<"maximum", [
-    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
-    TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
+    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
   let summary = "Max operator";
   let description = [{
     Element-wise max operation.

@@ -1836,8 +1835,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
 }

 def TFL_MinimumOp : TFL_Op<"minimum", [
-    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale,
-    TFL_OperandHasRankLessThan<0, 4>, TFL_OperandHasRankLessThan<1, 4>]> {
+    ResultsBroadcastableShape, NoSideEffect, Commutative, SameOperandsAndResultsScale]> {
   let summary = "Min operator";
   let description = [{
     Element-wise min operation.


@@ -2061,7 +2061,7 @@ bool NNAPIDelegateKernel::Validate(
     } break;
     case kTfLiteBuiltinMaximum:
     case kTfLiteBuiltinMinimum: {
-      ExpectMaxOpVersion(version, 2, &val_ctx);
+      ExpectMaxOpVersion(version, 3, &val_ctx);
       ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI12,
                                  &val_ctx);
       const auto input_type = context->tensors[node->inputs->data[0]].type;


@@ -2147,9 +2147,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
                               const T* input2_data, const Dims<4>& input2_dims,
                               T* output_data, const Dims<4>& output_dims,
                               Op op) {
-  MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
-                                DimsToShape(input2_dims), input2_data,
-                                DimsToShape(output_dims), output_data, op);
+  MaximumMinimumBroadcastSlow(DimsToShape(input1_dims), input1_data,
+                              DimsToShape(input2_dims), input2_data,
+                              DimsToShape(output_dims), output_data, op);
 }

 template <typename T1, typename T2, typename T3>


@@ -21,37 +21,40 @@ limitations under the License.
 namespace tflite {
 namespace reference_ops {

-template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
-                                   const T* input1_data,
-                                   const RuntimeShape& unextended_input2_shape,
-                                   const T* input2_data,
-                                   const RuntimeShape& unextended_output_shape,
-                                   T* output_data, Op op) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          auto out_idx = Offset(output_shape, b, y, x, c);
-          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
-          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
-          auto in1_val = input1_data[in1_idx];
-          auto in2_val = input2_data[in2_idx];
-          output_data[out_idx] = op(in1_val, in2_val);
-        }
-      }
-    }
-  }
+template <typename T, typename Op, int N = 5>
+void MaximumMinimumBroadcastSlow(const RuntimeShape& unextended_input1_shape,
+                                 const T* input1_data,
+                                 const RuntimeShape& unextended_input2_shape,
+                                 const T* input2_data,
+                                 const RuntimeShape& unextended_output_shape,
+                                 T* output_data, Op op) {
+  // Uses element-wise calculation if broadcast is not required.
+  if (unextended_input1_shape == unextended_input2_shape) {
+    const int flat_size =
+        MatchingElementsSize(unextended_input1_shape, unextended_input2_shape,
+                             unextended_output_shape);
+    for (int i = 0; i < flat_size; ++i) {
+      output_data[i] = op(input1_data[i], input2_data[i]);
+    }
+  } else {
+    TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
+    TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
+    TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
+
+    NdArrayDesc<N> desc1;
+    NdArrayDesc<N> desc2;
+    NdArrayDesc<N> output_desc;
+    NdArrayDescsForElementwiseBroadcast(
+        unextended_input1_shape, unextended_input2_shape, &desc1, &desc2);
+    CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
+                   &output_desc);
+
+    auto maxmin_func = [&](int indexes[N]) {
+      output_data[SubscriptToIndex(output_desc, indexes)] =
+          op(input1_data[SubscriptToIndex(desc1, indexes)],
+             input2_data[SubscriptToIndex(desc2, indexes)]);
+    };
+    NDOpsHelper<N>(output_desc, maxmin_func);
+  }
 }
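The new kernel replaces the four fixed loops with NdArrayDesc<N>/NDOpsHelper<N>, whose key idea is that a broadcast (size-1) dimension gets stride 0, so every output coordinate reads element 0 along it. For illustration only, a self-contained sketch of that stride trick, using made-up names (BroadcastDesc, MakeDesc) rather than the real TFLite helpers, and the shapes from the 5-D float test added later in this commit:

// Standalone sketch (not TFLite code) of the stride trick behind
// NdArrayDesc<N>/NDOpsHelper<N>.
#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int N = 5;  // matches the new default template parameter

struct BroadcastDesc {
  int extents[N];
  int strides[N];
};

// Left-pads `shape` with 1s to rank N and computes row-major strides,
// zeroing the stride of any size-1 (broadcast) dimension.
BroadcastDesc MakeDesc(const std::vector<int>& shape) {
  BroadcastDesc d;
  const int pad = N - static_cast<int>(shape.size());
  for (int i = 0; i < N; ++i) d.extents[i] = (i < pad) ? 1 : shape[i - pad];
  int stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    d.strides[i] = (d.extents[i] == 1) ? 0 : stride;
    stride *= d.extents[i];
  }
  return d;
}

int main() {
  // Shapes from the new FloatWithBroadcastTest5D: {3,1,1,1,2} vs {2}.
  BroadcastDesc in1 = MakeDesc({3, 1, 1, 1, 2});
  BroadcastDesc in2 = MakeDesc({2});
  BroadcastDesc out = MakeDesc({3, 1, 1, 1, 2});
  std::vector<float> d1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0};
  std::vector<float> d2 = {0.5, 2.0};
  std::vector<float> result(6);

  int idx[N] = {0};
  // The N nested loops that NDOpsHelper<N> generates recursively.
  for (idx[0] = 0; idx[0] < out.extents[0]; ++idx[0])
    for (idx[1] = 0; idx[1] < out.extents[1]; ++idx[1])
      for (idx[2] = 0; idx[2] < out.extents[2]; ++idx[2])
        for (idx[3] = 0; idx[3] < out.extents[3]; ++idx[3])
          for (idx[4] = 0; idx[4] < out.extents[4]; ++idx[4]) {
            auto flat = [&](const BroadcastDesc& d) {
              int offset = 0;
              for (int i = 0; i < N; ++i) offset += idx[i] * d.strides[i];
              return offset;
            };
            result[flat(out)] = std::max(d1[flat(in1)], d2[flat(in2)]);
          }
  for (float v : result) std::printf("%g ", v);  // 1 2 0.5 2 0.5 11
  std::printf("\n");
}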


@@ -88,7 +88,7 @@ struct MinimumOp {
 template <typename data_type, typename op_type>
 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
                      const OpContext& op_context) {
-  reference_ops::MaximumMinimumBroadcast4DSlow(
+  reference_ops::MaximumMinimumBroadcastSlow(
       GetTensorShape(op_context.input1),
       GetTensorData<data_type>(op_context.input1),
       GetTensorShape(op_context.input2),
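For context, the op_type parameter threaded through TFLiteOperation is a small functor type (the hunk headers show the real MaximumOp/MinimumOp structs live in this file). A simplified restatement of that pattern, with a hypothetical ElementwisePair helper standing in for the kernel plumbing:

// Sketch of the functor dispatch pattern; simplified, not the exact kernel.
#include <cstdio>

struct MaximumOp {
  template <typename T>
  static T op(T a, T b) { return a > b ? a : b; }
};
struct MinimumOp {
  template <typename T>
  static T op(T a, T b) { return a < b ? a : b; }
};

// Applies OpType element-wise, the way TFLiteOperation forwards op_type
// into the shared broadcast routine.
template <typename T, typename OpType>
void ElementwisePair(const T* x, const T* y, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = OpType::template op<T>(x[i], y[i]);
}

int main() {
  float x[] = {1, 5}, y[] = {3, 2}, out[2];
  ElementwisePair<float, MaximumOp>(x, y, out, 2);
  std::printf("%g %g\n", out[0], out[1]);  // 3 5
}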


@@ -186,5 +186,46 @@ TEST(MaximumOpTest, Int32WithBroadcastTest_ScalarY) {
                      {TensorType_INT32, {}}, {TensorType_INT32, {3, 1, 2}},
                      data1, data2, {1, 0, -1, -2, 2, 2}, /*is_constant=*/true);
 }

+TEST(MaxMinOpTest, Int8Test8D) {
+  std::initializer_list<int8_t> data1 = {1, 0, 2, 11, 2, 23};
+  std::initializer_list<int8_t> data2 = {0, 0, 1, 12, 123, 1};
+  TestModel<int8_t>(BuiltinOperator_MAXIMUM,
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}},
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}},
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}}, data1, data2,
+                    {1, 0, 2, 12, 123, 23});
+  TestModel<int8_t>(BuiltinOperator_MINIMUM,
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}},
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}},
+                    {TensorType_INT8, {3, 1, 2, 1, 1, 1, 1, 1}}, data1, data2,
+                    {0, 0, 1, 11, 2, 1});
+}
+
+TEST(MaximumOpTest, FloatWithBroadcastTest5D) {
+  std::initializer_list<float> data1 = {1.0, 0.0, -1.0, -2.0, -1.44, 11.0};
+  std::initializer_list<float> data2 = {0.5, 2.0};
+  TestModel<float>(
+      BuiltinOperator_MAXIMUM, {TensorType_FLOAT32, {3, 1, 1, 1, 2}},
+      {TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {3, 1, 1, 1, 2}}, data1,
+      data2, {1.0, 2.0, 0.5, 2.0, 0.5, 11.0});
+  TestModel<float>(
+      BuiltinOperator_MINIMUM, {TensorType_FLOAT32, {3, 1, 1, 1, 2}},
+      {TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {3, 1, 1, 1, 2}}, data1,
+      data2, {0.5, 0.0, -1.0, -2.0, -1.44, 2.0});
+}
+
+TEST(MaximumOpTest, Int32WithBroadcastTest5D) {
+  std::initializer_list<int32_t> data1 = {1, 0, -1, -2, 3, 11};
+  std::initializer_list<int32_t> data2 = {2};
+  TestModel<int32_t>(
+      BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2, 1, 1}},
+      {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2, 1, 1}}, data1,
+      data2, {2, 2, 2, 2, 3, 11});
+  TestModel<int32_t>(
+      BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2, 1, 1}},
+      {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2, 1, 1}}, data1,
+      data2, {1, 0, -1, -2, 2, 2});
+}
+
 }  // namespace
 }  // namespace tflite
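Worth noting: the Int8Test8D case passes identical rank-8 shapes, so it exercises the new same-shape fast path in the reference kernel, which runs a flat loop regardless of rank; only mismatched shapes take the rank-capped (N = 5) broadcast path. A condensed, illustrative restatement with simplified signatures (not the real kernel):

// Why an 8-D tensor works: equal shapes skip the NdArrayDesc<N> machinery.
#include <cstdio>
#include <vector>

template <typename T, typename Op>
void MaxMinSketch(const std::vector<int>& s1, const std::vector<T>& in1,
                  const std::vector<int>& s2, const std::vector<T>& in2,
                  std::vector<T>* out, Op op) {
  if (s1 == s2) {  // fast path: rank-agnostic flat loop
    for (size_t i = 0; i < in1.size(); ++i) (*out)[i] = op(in1[i], in2[i]);
  } else {
    // Broadcast path (rank <= 5) would go here; see the kernel above.
  }
}

int main() {
  std::vector<int> shape = {3, 1, 2, 1, 1, 1, 1, 1};  // rank 8, equal shapes
  std::vector<int> d1 = {1, 0, 2, 11, 2, 23}, d2 = {0, 0, 1, 12, 123, 1};
  std::vector<int> out(6);
  MaxMinSketch(shape, d1, shape, d2, &out,
               [](int a, int b) { return a > b ? a : b; });
  for (int v : out) std::printf("%d ", v);  // 1 0 2 12 123 23
  std::printf("\n");
}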


@@ -164,10 +164,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
              /* min_version */ 1,
-             /* max_version */ 3);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(),
              /* min_version */ 1,
-             /* max_version */ 3);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(),
              /* min_version */ 1,
              /* max_version */ 2);


@@ -68,7 +68,7 @@ struct MinimumOp {
 template <typename data_type, typename op_type>
 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
                      const OpContext& op_context) {
-  reference_ops::MaximumMinimumBroadcast4DSlow(
+  reference_ops::MaximumMinimumBroadcastSlow(
       GetTensorShape(op_context.input1),
       GetTensorData<data_type>(op_context.input1),
       GetTensorShape(op_context.input2),


@@ -29,8 +29,10 @@ def make_maximum_tests(options):
   test_parameters = [{
       "input_dtype": [tf.float32],
-      "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
-      "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+      "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3],
+                        [5, 32, 32, 3, 1], [5, 32, 32, 3, 1]],
+      "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3],
+                        [5, 32, 32, 3, 3], [5, 32, 32, 3, 1]],
       "fully_quantize": [False, True],
   }]

@@ -69,4 +71,4 @@ def make_maximum_tests(options):
       test_parameters,
       build_graph,
       build_inputs,
-      expected_tf_failures=16)
+      expected_tf_failures=44)


@@ -29,8 +29,10 @@ def make_minimum_tests(options):
   test_parameters = [{
       "input_dtype": [tf.float32],
-      "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
-      "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+      "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3],
+                        [5, 32, 32, 1, 1], [5, 32, 32, 1, 1]],
+      "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3],
+                        [5, 32, 32, 1, 1], [5, 32, 32, 1, 3]],
       "fully_quantize": [False, True],
   }]

@@ -69,4 +71,4 @@ def make_minimum_tests(options):
       test_parameters,
      build_graph,
       build_inputs,
-      expected_tf_failures=16)
+      expected_tf_failures=44)


@@ -274,10 +274,10 @@ class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
     ::tflite::OpSignature op_sig =
         GetVersioningOpSig(builtin_op(), op_signature);
     if (input1_array.has_shape() && input2_array.has_shape()) {
-      op_sig.options.sub.num_dims =
+      op_sig.options.broadcast.num_dims =
           std::max(input1_array.shape().dimensions_count(),
                    input2_array.shape().dimensions_count());
-      op_sig.options.sub.need_broadcast =
+      op_sig.options.broadcast.need_broadcast =
           (input1_array.shape() != input2_array.shape());
     }
     return ::tflite::GetBuiltinOperatorVersion(op_sig);


@@ -318,14 +318,17 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_MAXIMUM:
     case BuiltinOperator_MINIMUM:
-      if (op_sig.input_types.at(0) == TensorType_INT8) {
-        return 2;
-      }
       if (op_sig.input_types.at(0) == TensorType_INT16 &&
           op_sig.output_types.at(0) == TensorType_INT16) {
+        return 4;
+      }
+      if (op_sig.options.broadcast.need_broadcast &&
+          op_sig.options.broadcast.num_dims > 4) {
         return 3;
       }
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
       return 1;

     case BuiltinOperator_PACK:

@@ -357,8 +360,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;

     case BuiltinOperator_SUB:
-      if (op_sig.options.sub.need_broadcast &&
-          op_sig.options.sub.num_dims > 4) {
+      if (op_sig.options.broadcast.need_broadcast &&
+          op_sig.options.broadcast.num_dims > 4) {
         return 3;
       }
       if (op_sig.input_types.at(0) == TensorType_INT8) {

@@ -509,9 +512,12 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
       op_sig.options.space_batch.num_dims = GetNumDims(subgraph, op, 0);
     } break;
-    case BuiltinOperator_SUB: {
-      op_sig.options.sub.need_broadcast = !HaveSameShapes(subgraph, op, 0, 1);
-      op_sig.options.sub.num_dims =
+    case BuiltinOperator_SUB:
+    case BuiltinOperator_MAXIMUM:
+    case BuiltinOperator_MINIMUM: {
+      op_sig.options.broadcast.need_broadcast =
+          !HaveSameShapes(subgraph, op, 0, 1);
+      op_sig.options.broadcast.num_dims =
           std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
     } break;
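The version ladder now checks the newest feature first: INT16 input and output selects version 4, a broadcast spanning more than four dimensions selects version 3, INT8 selects version 2, and everything else stays at version 1. A standalone restatement of that ordering, illustrative only (the real logic is GetBuiltinOperatorVersion above, and MaxMinVersion is a made-up name):

// Restates the new MAXIMUM/MINIMUM version ladder; most-recent-first,
// so INT16 wins over >4-D broadcast, which wins over INT8.
#include <cstdio>

int MaxMinVersion(bool int16_io, bool int8_input, bool need_broadcast,
                  int num_dims) {
  if (int16_io) return 4;                        // 16-bit quantization
  if (need_broadcast && num_dims > 4) return 3;  // 5-D broadcast (this commit)
  if (int8_input) return 2;                      // 8-bit quantization
  return 1;                                      // float/uint8 baseline
}

int main() {
  // Mirrors the VersioningMaxTest cases added below.
  std::printf("%d\n", MaxMinVersion(false, true, true, 5));   // 3
  std::printf("%d\n", MaxMinVersion(false, true, false, 5));  // 2
  std::printf("%d\n", MaxMinVersion(false, false, true, 4));  // 1
}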


@@ -58,7 +58,7 @@ typedef struct {
     struct {
       int32_t num_dims;
       bool need_broadcast;
-    } sub;
+    } broadcast;
   } options;
 } OpSignature;


@@ -221,11 +221,53 @@ TEST(OpVersionTest, VersioningL2NormTest) {
 }

 TEST(OpVersionTest, VersioningMaxTest) {
-  SimpleVersioningTest(BuiltinOperator_MAXIMUM);
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_MAXIMUM,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  fake_op_sig.options.broadcast.need_broadcast = true;
+  fake_op_sig.options.broadcast.num_dims = 5;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig.options.broadcast.need_broadcast = false;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+  fake_op_sig.options.broadcast.num_dims = 4;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_MAXIMUM,
+      .input_types = std::vector<TensorType>{TensorType_UINT8},
+  };
+  fake_op_sig.options.broadcast.need_broadcast = true;
+  fake_op_sig.options.broadcast.num_dims = 5;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig.options.broadcast.num_dims = 4;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
 }

 TEST(OpVersionTest, VersioningMinTest) {
-  SimpleVersioningTest(BuiltinOperator_MINIMUM);
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_MINIMUM,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  fake_op_sig.options.broadcast.need_broadcast = true;
+  fake_op_sig.options.broadcast.num_dims = 5;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig.options.broadcast.need_broadcast = false;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+  fake_op_sig.options.broadcast.num_dims = 4;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_MINIMUM,
+      .input_types = std::vector<TensorType>{TensorType_UINT8},
+  };
+  fake_op_sig.options.broadcast.need_broadcast = true;
+  fake_op_sig.options.broadcast.num_dims = 5;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig.options.broadcast.num_dims = 4;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
 }

 TEST(OpVersionTest, VersioningMeanTest) {