Perform MEAN axis check while choosing nodes to delegate

PiperOrigin-RevId: 323448703
Change-Id: If1282ca7bb3faf66769b3f4e8324082a540ffa43
This commit is contained in:
Sachin Joglekar 2020-07-27 15:02:02 -07:00 committed by TensorFlower Gardener
parent dd3ce26d7b
commit 247db5c30e

View File

@ -2562,8 +2562,23 @@ class MeanOperationParser : public TFLiteOperationParser {
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/1));
// Simple mechanism to check if MEAN is to be performed only on HW plane.
auto* axes = &context->tensors[tflite_node->inputs->data[1]];
if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
return absl::UnimplementedError("Mean has unsupported tensor for axes");
}
auto* axes_data = axes->data.i32;
const bool is_hw_mean = tflite::NumElements(axes) == 2 &&
((axes_data[0] == 1 && axes_data[1] == 2) ||
(axes_data[0] == 2 && axes_data[1] == 1));
if (!is_hw_mean) {
return absl::UnimplementedError("Mean operation supports only HW plane");
}
return absl::OkStatus();
}
absl::Status Parse(const TfLiteNode* tflite_node,