Perform MEAN axis check while choosing nodes to delegate
PiperOrigin-RevId: 323448703 Change-Id: If1282ca7bb3faf66769b3f4e8324082a540ffa43
This commit is contained in:
parent
dd3ce26d7b
commit
247db5c30e
@ -2562,8 +2562,23 @@ class MeanOperationParser : public TFLiteOperationParser {
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
||||
/*outputs=*/1);
|
||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/1,
|
||||
/*outputs=*/1));
|
||||
|
||||
// Simple mechanism to check if MEAN is to be performed only on HW plane.
|
||||
auto* axes = &context->tensors[tflite_node->inputs->data[1]];
|
||||
if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
|
||||
return absl::UnimplementedError("Mean has unsupported tensor for axes");
|
||||
}
|
||||
auto* axes_data = axes->data.i32;
|
||||
const bool is_hw_mean = tflite::NumElements(axes) == 2 &&
|
||||
((axes_data[0] == 1 && axes_data[1] == 2) ||
|
||||
(axes_data[0] == 2 && axes_data[1] == 1));
|
||||
if (!is_hw_mean) {
|
||||
return absl::UnimplementedError("Mean operation supports only HW plane");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
|
Loading…
Reference in New Issue
Block a user