Perform MEAN axis check while choosing nodes to delegate
PiperOrigin-RevId: 323448703 Change-Id: If1282ca7bb3faf66769b3f4e8324082a540ffa43
This commit is contained in:
parent
dd3ce26d7b
commit
247db5c30e
@ -2562,8 +2562,23 @@ class MeanOperationParser : public TFLiteOperationParser {
|
|||||||
absl::Status IsSupported(const TfLiteContext* context,
|
absl::Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
/*outputs=*/1);
|
/*runtime_inputs=*/1,
|
||||||
|
/*outputs=*/1));
|
||||||
|
|
||||||
|
// Simple mechanism to check if MEAN is to be performed only on HW plane.
|
||||||
|
auto* axes = &context->tensors[tflite_node->inputs->data[1]];
|
||||||
|
if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
|
||||||
|
return absl::UnimplementedError("Mean has unsupported tensor for axes");
|
||||||
|
}
|
||||||
|
auto* axes_data = axes->data.i32;
|
||||||
|
const bool is_hw_mean = tflite::NumElements(axes) == 2 &&
|
||||||
|
((axes_data[0] == 1 && axes_data[1] == 2) ||
|
||||||
|
(axes_data[0] == 2 && axes_data[1] == 1));
|
||||||
|
if (!is_hw_mean) {
|
||||||
|
return absl::UnimplementedError("Mean operation supports only HW plane");
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||||
|
Loading…
Reference in New Issue
Block a user