diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 830dce468f5..51eca51a4fe 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -449,8 +449,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 4; } } - if (op_sig.options.broadcast.need_broadcast && - op_sig.options.broadcast.num_dims > 4) { + if (op_sig.options.addsub.need_broadcast && + op_sig.options.addsub.num_dims > 4) { return 3; } if (op_sig.input_types.at(0) == TensorType_INT8) { @@ -648,14 +648,13 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, case BuiltinOperator_SUB: { auto sub_option = op->builtin_options_as_SubOptions(); op_sig.options.addsub.pot_scale_int16 = false; + op_sig.options.addsub.need_broadcast = false; + op_sig.options.addsub.num_dims = 1; if (sub_option) { op_sig.options.addsub.pot_scale_int16 = sub_option->pot_scale_int16(); - } - - if (op_code->builtin_code() == BuiltinOperator_SUB) { - op_sig.options.broadcast.need_broadcast = + op_sig.options.addsub.need_broadcast = !HaveSameShapes(subgraph, op, 0, 1); - op_sig.options.broadcast.num_dims = + op_sig.options.addsub.num_dims = std::max(GetNumDims(subgraph, op, 0), GetNumDims(subgraph, op, 1)); } } break; diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h index 4582d08a879..0e8c086eb14 100644 --- a/tensorflow/lite/tools/versioning/op_version.h +++ b/tensorflow/lite/tools/versioning/op_version.h @@ -62,6 +62,8 @@ typedef struct { } broadcast; struct { bool pot_scale_int16; + int32_t num_dims; + bool need_broadcast; } addsub; } options; } OpSignature;