Added check for unsupported Mul case.

PiperOrigin-RevId: 314569439
Change-Id: I788a87b6ac20cbd3f01d522029d330f9b6057e20
This commit is contained in:
Raman Sarokin 2020-06-03 11:27:03 -07:00 committed by TensorFlower Gardener
parent 442f5e2b80
commit e2d7d94549

View File

@ -1143,6 +1143,30 @@ class MulOperationParser : public TFLiteOperationParser {
if (tflite_node->inputs->size != 2) {
return absl::UnimplementedError("MUL requires two input tensors.");
}
auto input0 = tflite::GetInput(context, tflite_node, 0);
auto input1 = tflite::GetInput(context, tflite_node, 1);
if (input0->dims->size == input1->dims->size) {
// this code checks that at least one input of Mul not smaller in all
// dimensions. Sometimes Mul used for matrix-vector multiplication that we
// currently don't support. For example input0 HWC(1, 256, 1), input1
// HWC(1, 1, 256) -> output HWC (1, 256, 256). In this case it can be
// replaced with Convolution operation.
bool first_has_smaller_dim = false;
bool second_has_smaller_dim = false;
for (int i = 0; i < input0->dims->size; ++i) {
if (input0->dims->data[i] < input1->dims->data[i]) {
first_has_smaller_dim = true;
}
if (input1->dims->data[i] < input0->dims->data[i]) {
second_has_smaller_dim = true;
}
}
if (first_has_smaller_dim && second_has_smaller_dim) {
return absl::UnimplementedError(
"MUL requires one tensor that not less than second in all "
"dimensions.");
}
}
const TfLiteMulParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
return IsActivationSupported(tf_options->activation);