Added broadcast params to addsub structure.

Change-Id: I61d7d4a94087d052a782890799211031f6ed3015
This commit is contained in:
Elena Zhelezina 2020-06-16 18:16:05 +01:00
parent bec6f3b14f
commit 3c219a46ce
2 changed files with 8 additions and 7 deletions

View File

@ -449,8 +449,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 4;
}
}
if (op_sig.options.broadcast.need_broadcast &&
op_sig.options.broadcast.num_dims > 4) {
if (op_sig.options.addsub.need_broadcast &&
op_sig.options.addsub.num_dims > 4) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
@ -648,14 +648,13 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
case BuiltinOperator_SUB: {
auto sub_option = op->builtin_options_as_SubOptions();
op_sig.options.addsub.pot_scale_int16 = false;
op_sig.options.addsub.need_broadcast = false;
op_sig.options.addsub.num_dims = 1;
if (sub_option) {
op_sig.options.addsub.pot_scale_int16 = sub_option->pot_scale_int16();
}
if (op_code->builtin_code() == BuiltinOperator_SUB) {
op_sig.options.broadcast.need_broadcast =
op_sig.options.addsub.need_broadcast =
!HaveSameShapes(subgraph, op, 0, 1);
op_sig.options.broadcast.num_dims =
op_sig.options.addsub.num_dims =
std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
}
} break;

View File

@ -62,6 +62,8 @@ typedef struct {
} broadcast;
struct {
bool pot_scale_int16;
int32_t num_dims;
bool need_broadcast;
} addsub;
} options;
} OpSignature;