Update op_version for split, it should be checking for input[1] which is the input. Input[0] is the axis
PiperOrigin-RevId: 303214325 Change-Id: Id93c8a66d173fd6af460a0db8263268d50aa5f6b
This commit is contained in:
parent
dce27acd28
commit
3009664be0
@ -216,10 +216,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||||||
|
|
||||||
case BuiltinOperator_SPLIT:
|
case BuiltinOperator_SPLIT:
|
||||||
// If the op take int8 input, it is version 2, for int32 it's version 3.
|
// If the op take int8 input, it is version 2, for int32 it's version 3.
|
||||||
if (op_sig.input_types.at(0) == TensorType_INT32) {
|
// The input tensor is at index 1 not 0, 0 is the axis.
|
||||||
|
if (op_sig.input_types.at(1) == TensorType_INT32) {
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
if (op_sig.input_types.at(1) == TensorType_INT8) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
return 1;
|
return 1;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user