Added broadcast params to addsub structure.
Change-Id: I61d7d4a94087d052a782890799211031f6ed3015
This commit is contained in:
parent
bec6f3b14f
commit
3c219a46ce
@ -449,8 +449,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
if (op_sig.options.broadcast.need_broadcast &&
|
||||
op_sig.options.broadcast.num_dims > 4) {
|
||||
if (op_sig.options.addsub.need_broadcast &&
|
||||
op_sig.options.addsub.num_dims > 4) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
@ -648,14 +648,13 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||
case BuiltinOperator_SUB: {
|
||||
auto sub_option = op->builtin_options_as_SubOptions();
|
||||
op_sig.options.addsub.pot_scale_int16 = false;
|
||||
op_sig.options.addsub.need_broadcast = false;
|
||||
op_sig.options.addsub.num_dims = 1;
|
||||
if (sub_option) {
|
||||
op_sig.options.addsub.pot_scale_int16 = sub_option->pot_scale_int16();
|
||||
}
|
||||
|
||||
if (op_code->builtin_code() == BuiltinOperator_SUB) {
|
||||
op_sig.options.broadcast.need_broadcast =
|
||||
op_sig.options.addsub.need_broadcast =
|
||||
!HaveSameShapes(subgraph, op, 0, 1);
|
||||
op_sig.options.broadcast.num_dims =
|
||||
op_sig.options.addsub.num_dims =
|
||||
std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1));
|
||||
}
|
||||
} break;
|
||||
|
@ -62,6 +62,8 @@ typedef struct {
|
||||
} broadcast;
|
||||
struct {
|
||||
bool pot_scale_int16;
|
||||
int32_t num_dims;
|
||||
bool need_broadcast;
|
||||
} addsub;
|
||||
} options;
|
||||
} OpSignature;
|
||||
|
Loading…
Reference in New Issue
Block a user