Update op_version for split, it should be checking for input[1] which is the input. Input[0] is the axis
PiperOrigin-RevId: 303214325 Change-Id: Id93c8a66d173fd6af460a0db8263268d50aa5f6b
This commit is contained in:
parent
dce27acd28
commit
3009664be0
@ -216,10 +216,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
|
||||
case BuiltinOperator_SPLIT:
|
||||
// If the op take int8 input, it is version 2, for int32 it's version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT32) {
|
||||
// The input tensor is at index 1 not 0, 0 is the axis.
|
||||
if (op_sig.input_types.at(1) == TensorType_INT32) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
if (op_sig.input_types.at(1) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user